# Source code for pyexasol.http_transport

from __future__ import annotations

import hashlib
import os
import re
import socket
import socketserver
import struct
import sys
import threading
import zlib
from collections.abc import Iterable
from dataclasses import dataclass
from ssl import SSLContext
from typing import (
    TYPE_CHECKING,
    Optional,
    Union,
)

from packaging.version import Version

if TYPE_CHECKING:
    from pyexasol import ExaConnection


@dataclass
class SqlQuery:
    connection: ExaConnection
    compression: bool
    # set these values in param dictionary to ExaConnection
    column_delimiter: Optional[str] = None
    column_separator: Optional[str] = None
    columns: Optional[Iterable[str]] = None
    comment: Optional[str] = None
    csv_cols: Optional[Iterable[str]] = None
    encoding: Optional[str] = None
    format: Optional[str] = None
    null: Optional[str] = None
    row_separator: Optional[str] = None

    def _build_csv_cols(self) -> str:
        if self.csv_cols is not None:
            safe_csv_cols_regexp = re.compile(
                r"^(\d+|\d+\.\.\d+)(\sFORMAT='[^'\n]+')?$", re.IGNORECASE
            )
            for c in self.csv_cols:
                if not safe_csv_cols_regexp.match(c):
                    raise ValueError(f"Value [{c}] is not a safe csv_cols part")

            csv_cols = ",".join(self.csv_cols)
            if csv_cols != "":
                return f"({csv_cols})"

        return ""

    @staticmethod
    def _split_exa_address_into_components(exa_address: str) -> tuple[str, str | None]:
        """
        Split ip_address:port and public key from exa address, where the expected
        patterns are:
            ip_address:port
            ip_address:port/public_key
        The value for public key is expected to be a SHA-256 hash of the public key,
        which is then base64-encoded.
        """
        pattern = r"^([\d\.]+:\d+)(?:\/([a-zA-Z0-9_\-+\/]+=))?$"
        match = re.match(pattern, exa_address)
        if match is None:
            raise ValueError(
                f"Could not split exa_address {exa_address} into known components"
            )
        ip_address, public_key = match.groups()
        if not public_key:
            return ip_address, None
        return ip_address, public_key

    def _get_file_list(self, exa_address_list: list[str]) -> list[str]:
        file_ext = self._file_ext
        prefix = self._url_prefix

        csv_cols = self._build_csv_cols()
        files = []
        for i, exa_address in enumerate(exa_address_list):
            ip_address_port, public_key = self._split_exa_address_into_components(
                exa_address
            )
            statement = f"AT '{prefix}{ip_address_port}'"
            if self._requires_tls_public_key():
                if not public_key:
                    raise ValueError(
                        "Public key is required to be in the 'exa_address' for encrypted connections with Exasol DB >= 8.32.0"
                    )
                statement += f" PUBLIC KEY 'sha256//{public_key}'"
            statement += f" FILE '{str(i).rjust(3, '0')}.{file_ext}'{csv_cols}"
            files.append(statement)
        return files

    @staticmethod
    def _get_query_str(query_lines: list[Optional[str]]) -> str:
        filtered_query_lines = [q for q in query_lines if q is not None]
        return "\n".join(filtered_query_lines)

    def _requires_tls_public_key(self) -> bool:
        version = self.connection.exasol_db_version
        return (
            version is not None
            and version >= Version("8.32.0")
            and self.connection.options["encryption"]
        )

    @property
    def _column_spec(self) -> str:
        """
        Return either empty string or comma-separated list of columns in parentheses,
        e.g. '("A", "B")'
        """
        if self.columns is not None:
            formatted = [
                self.connection.format.default_format_ident(c) for c in self.columns
            ]
            comma_sep = ",".join(formatted)
            if comma_sep != "":
                return f"({comma_sep})"
        return ""

    @property
    def _column_delimiter(self) -> Optional[str]:
        if self.column_delimiter is None:
            return None
        return (
            f"COLUMN DELIMITER = {self.connection.format.quote(self.column_delimiter)}"
        )

    @property
    def _column_separator(self) -> Optional[str]:
        if self.column_separator is None:
            return None
        return (
            f"COLUMN SEPARATOR = {self.connection.format.quote(self.column_separator)}"
        )

    @property
    def _comment(self) -> Optional[str]:
        if self.comment is None:
            return None

        if "*/" in self.comment:
            raise ValueError(
                f'Invalid comment "{self.comment}". Comment must not contain "*/".'
            )
        return f"/*{self.comment}*/"

    @property
    def _encoding(self) -> Optional[str]:
        if self.encoding is None:
            return None
        return f"ENCODING = {self.connection.format.quote(self.encoding)}"

    @property
    def _file_ext(self) -> str:
        if self.format is None:
            if self.compression:
                return "gz"
            return "csv"
        if self.format not in ("gz", "bz2", "zip"):
            raise ValueError(f"Unsupported compression format: {self.format}")
        return self.format

    @property
    def _null(self) -> Optional[str]:
        if self.null is None:
            return None
        return f"NULL = {self.connection.format.quote(self.null)}"

    @property
    def _url_prefix(self) -> str:
        if self.connection.options["encryption"]:
            return "https://"
        return "http://"

    @property
    def _row_separator(self) -> Optional[str]:
        if self.row_separator is None:
            return None
        return f"ROW SEPARATOR = {self.connection.format.quote(self.row_separator)}"


@dataclass
class ImportQuery(SqlQuery):
    """Builder for 'IMPORT INTO ... FROM CSV' statements."""

    # set these values in param dictionary to ExaConnection
    skip: Optional[Union[str, int]] = None
    trim: Optional[str] = None

    def build_query(self, table: str, exa_address_list: list[str]) -> str:
        """Assemble the complete IMPORT statement for the given table and addresses."""
        lines: list[Optional[str]] = [self._comment, self._get_import(table=table)]
        lines.extend(self._get_file_list(exa_address_list=exa_address_list))
        lines.extend(
            [
                self._encoding,
                self._null,
                self._skip,
                self._trim,
                self._row_separator,
                self._column_separator,
                self._column_delimiter,
            ]
        )
        return self._get_query_str(lines)

    @staticmethod
    def load_from_dict(
        connection: ExaConnection, compression: bool, params: dict
    ) -> ImportQuery:
        """
        Build an ImportQuery from the params dictionary.

        Keys in `params` that do not match attributes of the `ImportQuery`
        class will raise an Exception.
        """
        return ImportQuery(connection=connection, compression=compression, **params)

    def _get_import(self, table: str) -> str:
        return f"IMPORT INTO {table}{self._column_spec} FROM CSV"

    @property
    def _skip(self) -> Optional[str]:
        if self.skip is None:
            return None
        return f"SKIP = {self.connection.format.safe_decimal(self.skip)}"

    @property
    def _trim(self) -> Optional[str]:
        if self.trim is None:
            return None
        # Exasol accepts exactly these three TRIM modes
        trim = str(self.trim).upper()
        if trim in ("TRIM", "LTRIM", "RTRIM"):
            return trim
        raise ValueError(f"Invalid value for import parameter TRIM: {trim}")


@dataclass
class ExportQuery(SqlQuery):
    """Builder for 'EXPORT ... INTO CSV' statements."""

    # set these values in param dictionary to ExaConnection
    delimit: Optional[str] = None
    with_column_names: Optional[str] = None

    def build_query(self, table: str, exa_address_list: list[str]) -> str:
        """Assemble the complete EXPORT statement for the given source and addresses."""
        lines: list[Optional[str]] = [self._comment, self._get_export(table=table)]
        lines.extend(self._get_file_list(exa_address_list=exa_address_list))
        lines.extend(
            [
                self._delimit,
                self._encoding,
                self._null,
                self._row_separator,
                self._column_separator,
                self._column_delimiter,
                self._with_column_names,
            ]
        )
        return self._get_query_str(lines)

    @staticmethod
    def load_from_dict(
        connection: ExaConnection, compression: bool, params: dict
    ) -> ExportQuery:
        """
        Build an ExportQuery from the params dictionary.

        Keys in `params` that do not match attributes of the `ExportQuery`
        class will raise an Exception.
        """
        return ExportQuery(connection=connection, compression=compression, **params)

    def _get_export(self, table: str) -> str:
        return f"EXPORT {table}{self._column_spec} INTO CSV"

    @property
    def _delimit(self) -> Optional[str]:
        if self.delimit is None:
            return None
        # Exasol accepts exactly these three DELIMIT modes
        delimit = str(self.delimit).upper()
        if delimit in ("AUTO", "ALWAYS", "NEVER"):
            return f"DELIMIT={delimit}"
        raise ValueError(f"Invalid value for export parameter DELIMIT: {delimit}")

    @property
    def _with_column_names(self) -> Optional[str]:
        # Value acts as a boolean switch; any non-None value enables the clause
        return None if self.with_column_names is None else "WITH COLUMN NAMES"


class ExaSQLThread(threading.Thread):
    """
    Base thread for running IMPORT/EXPORT SQL statements.

    Any exception raised inside the thread is captured and re-raised in the
    parent thread via ``join_with_exc()``.
    """

    def __init__(self, connection: ExaConnection, compression: bool):
        self.connection = connection
        self.compression = compression

        self.params: dict = {}
        self.http_thread = None
        self.exa_address_list: list[str] = []
        self.exc = None

        super().__init__()

    def set_http_thread(self, http_thread):
        # Single-process transport: one HTTP thread supplies the only address
        self.http_thread = http_thread
        self.exa_address_list = [http_thread.exa_address]

    def set_exa_address_list(self, exa_address_list):
        # Parallel transport: addresses were collected from child processes
        self.exa_address_list = exa_address_list

    def run(self):
        try:
            self.run_sql()
        except BaseException as caught:
            # Remember the exception for join_with_exc(), then stop the HTTP
            # server so pipes close and callback I/O is interrupted
            self.exc = caught
            if self.http_thread:
                self.http_thread.terminate()

    def run_sql(self):
        # Implemented by subclasses
        pass

    def join_with_exc(self, *args):
        super().join(*args)
        if self.exc:
            raise self.exc


class ExaSQLExportThread(ExaSQLThread):
    """
    Build and run an EXPORT query in a separate thread.
    Main thread is busy outputting data in callbacks.
    """

    def __init__(
        self,
        connection: ExaConnection,
        compression: bool,
        query_or_table,
        export_params: dict,
    ):
        super().__init__(connection, compression)

        self.query_or_table = query_or_table
        self.params = export_params

    def run_sql(self):
        source = self.query_or_table
        # A tuple or a single word is treated as a table name; anything with
        # whitespace is treated as an SQL query to be wrapped in parentheses
        is_table = isinstance(source, tuple) or " " not in str(source).strip()

        if is_table:
            export_source = self.connection.format.default_format_ident(source)
        else:
            if self.params.get("columns"):
                raise ValueError(
                    "Export option 'columns' is not compatible with SQL query export source"
                )
            # New lines are mandatory to handle queries with single-line comments '--'
            trimmed = source.lstrip(" \n").rstrip(" \n;")
            export_source = f"(\n{trimmed}\n)"

        statement = ExportQuery.load_from_dict(
            connection=self.connection, compression=self.compression, params=self.params
        ).build_query(table=export_source, exa_address_list=self.exa_address_list)
        self.connection.execute(statement)


class ExaSQLImportThread(ExaSQLThread):
    """
    Build and run an IMPORT query in a separate thread.
    Main thread is busy feeding data in callbacks.
    """

    def __init__(
        self,
        connection: ExaConnection,
        compression: bool,
        table: str,
        import_params: dict,
    ):
        super().__init__(connection, compression)

        self.table = table
        self.params = import_params

    def run_sql(self):
        target = self.connection.format.default_format_ident(self.table)

        statement = ImportQuery.load_from_dict(
            connection=self.connection, compression=self.compression, params=self.params
        ).build_query(table=target, exa_address_list=self.exa_address_list)
        self.connection.execute(statement)


class ExaHttpThread(threading.Thread):
    """
    HTTP communication and compression / decompression is offloaded to a separate thread.
    PyExasol uses a thread instead of a subprocess or multiprocessing to avoid
    compatibility issues on Windows operating systems. For further details, see
    - https://github.com/exasol/pyexasol/issues/73
    - https://pythonforthelab.com/blog/differences-between-multiprocessing-windows-and-linux/
    """

    def __init__(self, ipaddr: str, port: int, compression: bool, encryption: bool):
        # ExaTCPServer tunnels HTTP traffic over a single socket to the Exasol
        # node and exposes a pipe pair for data exchange with callbacks
        self.server = ExaTCPServer(
            (ipaddr, port),
            ExaHttpRequestHandler,
            compression=compression,
            encryption=encryption,
        )

        # read_pipe carries data received from Exasol (EXPORT);
        # write_pipe carries data to be sent to Exasol (IMPORT)
        self.read_pipe = self.server.read_pipe
        self.write_pipe = self.server.write_pipe

        # Exception raised inside run(); re-raised in parent via join_with_exc()
        self.exc: Optional[BaseException] = None

        super().__init__()

    @property
    def exa_address(self) -> str:
        # "ipaddr:port" of the internal Exasol address, optionally suffixed
        # with "/<public_key>" when the server generated a TLS key pair
        address = f"{self.server.exa_address_ipaddr}:{self.server.exa_address_port}"
        if public_key := self.server.exa_address_public_key:
            address = f"{address}/{public_key}"
        return address

    def run(self):
        try:
            # Handle exactly one HTTP request
            # Exit loop if thread was explicitly terminated prior to receiving HTTP request
            while self.server.total_clients == 0 and not self.server.is_terminated:
                self.server.handle_request()
        except BaseException as e:
            self.exc = e
        finally:
            self.server.server_close()

    def join(self, timeout=None):
        # Unblock a GET handler waiting to send its final chunk before joining
        self.server.can_finish_get.set()
        super().join(timeout)

    def join_with_exc(self):
        self.join()

        if self.exc:
            raise self.exc

    def terminate(self):
        self.server.is_terminated = True
        self.server.can_finish_get.set()

        # Must close pipes here to prevent infinite lock in callback function
        # Termination pipe order is important for Windows
        self.write_pipe.close()
        self.read_pipe.close()


class ExaHTTPTransportWrapper:
    """
    Wrapper for :ref:`http_transport_parallel`.

    You may create this wrapper using :func:`pyexasol.http_transport`.

    Note:
        Starts an HTTP server, obtains the address (the ``"ipaddr:port"`` string),
        and sends it to the parent process.

        Block into ``export_*()`` or ``import_*()`` call, wait for incoming
        connection, process data and exit.
    """

    def __init__(
        self,
        ipaddr: str,
        port: int,
        compression: bool = False,
        encryption: bool = True,
    ):
        self.http_thread = ExaHttpThread(ipaddr, port, compression, encryption)
        self.http_thread.start()

    @property
    def exa_address(self) -> str:
        """
        Internal Exasol address as ``ipaddr:port`` string.

        Note:
            This string should be passed from child processes to parent process
            and used as an argument for ``export_parallel()`` and
            ``import_parallel()`` functions.
        """
        return self.http_thread.exa_address

    def get_proxy(self):
        """
        Caution:
            **DEPRECATED**, please use ``.exa_address`` property
        """
        return self.http_thread.exa_address

    def export_to_callback(self, callback, dst, callback_params=None):
        """
        Exports chunk of data using callback function.

        Args:
            callback: Callback function.
            dst: Export destination for callback function.
            callback_params: Dict with additional parameters for callback function.

        Returns:
            Result of the callback function.

        Note:
            You may use exactly the same callbacks utilized by standard
            non-parallel ``export_to_callback()`` function.
        """
        if not callable(callback):
            raise ValueError("Callback argument is not callable")

        if callback_params is None:
            callback_params = {}

        try:
            result = callback(self.http_thread.read_pipe, dst, **callback_params)

            self.http_thread.read_pipe.close()
            self.http_thread.join_with_exc()

            return result
        except (Exception, KeyboardInterrupt) as e:
            # Stop the HTTP thread so the server-side query fails instead of
            # accepting partial data
            self.http_thread.terminate()
            self.http_thread.join()

            raise e

    def import_from_callback(self, callback, src, callback_params=None):
        """
        Import chunk of data using callback function.

        Args:
            callback: Callback function.
            src: Import source for the callback function.
            callback_params: Dict with additional parameters for the callback function.

        Returns:
            Result of callback function

        Note:
            You may use exactly the same callbacks utilized by standard
            non-parallel ``import_from_callback()`` function.
        """
        if not callable(callback):
            raise ValueError("Callback argument is not callable")

        if callback_params is None:
            callback_params = {}

        try:
            result = callback(self.http_thread.write_pipe, src, **callback_params)

            self.http_thread.write_pipe.close()
            self.http_thread.join_with_exc()

            return result
        except (Exception, KeyboardInterrupt) as e:
            # Stop the HTTP thread so the server-side query fails instead of
            # accepting partial data
            self.http_thread.terminate()
            self.http_thread.join()

            raise e

    def __repr__(self):
        return f"<{self.__class__.__name__} exa_address={self.exa_address}>"
class ExaTCPServer(socketserver.TCPServer):
    """
    Special TCP server for HTTP transport.

    Unlike a normal TCP server it connects OUT to an Exasol node and tunnels
    the HTTP transport traffic through that single, already-connected socket
    (see ``server_bind``), so ``server_activate`` / accept are no-ops.
    """

    exa_address_ipaddr: str
    exa_address_port: int
    exa_address_public_key: Optional[str] = None
    total_clients: int = 0
    is_terminated: bool = False
    timeout: int | None = 1

    def __init__(self, *args, **kwargs):
        self.compression: bool = kwargs.pop("compression", False)
        self.encryption: bool = kwargs.pop("encryption", True)

        # Unbuffered OS pipe pair connecting the HTTP thread with the
        # callback running in the main thread
        r_fd, w_fd = os.pipe()
        self.read_pipe = open(r_fd, "rb", 0)
        self.write_pipe = open(w_fd, "wb", 0)

        # GET method calls (IMPORT) require extra protection
        #
        # Callback function may close pipe abruptly and raise an exception
        # It may cause partial valid data to be sent to Exasol server
        #
        # This event waits for callback function to finish before sending final chunk of data
        #
        # If callback function failed with exception, HTTP thread will be terminated
        # Final chunk will not be sent, causing IMPORT query to fail and to discard partial data
        self.can_finish_get = threading.Event()

        super().__init__(*args, **kwargs)

    def server_bind(self):
        self.set_sock_opts()

        # Special Exasol packet to establish tunneling and return an internal
        # Exasol address, which can be used in a query
        self.socket.connect(self.server_address)
        self.socket.sendall(struct.pack("iii", 0x02212102, 1, 1))

        _, port, ipaddr = struct.unpack("ii16s", self.socket.recv(24))
        self.exa_address_ipaddr = ipaddr.replace(b"\x00", b"").decode()
        self.exa_address_port = port

        if self.encryption:
            context, public_key_sha = self.generate_adhoc_ssl_context()
            self.socket = context.wrap_socket(
                self.socket, server_side=True, do_handshake_on_connect=False
            )
            self.exa_address_public_key = public_key_sha

    def server_activate(self):
        # No listen(): the tunnel socket is already connected
        pass

    def get_request(self):
        # Hand the tunnel socket itself to the request handler
        return self.socket, self.server_address

    def shutdown_request(self, request):
        pass

    def close_request(self, request):
        pass

    def set_sock_opts(self):
        # only large packets are expected for HTTP transport
        self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

        # default keep-alive once a minute, 5 probes
        keepidle = 60
        keepintvl = 60
        keepcnt = 5

        if sys.platform.startswith("linux"):
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, keepidle)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keepintvl)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keepcnt)
        elif sys.platform.startswith("darwin"):
            # TCP_KEEPALIVE = 0x10
            # https://bugs.python.org/issue34932
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket.setsockopt(socket.IPPROTO_TCP, 0x10, keepidle)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, keepintvl)
            self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, keepcnt)
        elif sys.platform.startswith("win"):
            self.socket.ioctl(
                socket.SIO_KEEPALIVE_VALS, (1, keepidle * 1000, keepintvl * 1000)
            )

    @staticmethod
    def generate_adhoc_ssl_context() -> tuple[SSLContext, str]:
        """
        Create temporary self-signed certificate for encrypted HTTP transport
        Exasol does not check validity of certificates

        Returns:
            Tuple of the server-side SSLContext and the base64-encoded SHA-256
            hash of the certificate's public key (for certificate pinning).
        """
        from base64 import b64encode
        from datetime import (
            datetime,
            timedelta,
            timezone,
        )
        from pathlib import Path
        from ssl import (
            CERT_NONE,
            PROTOCOL_TLS_SERVER,
        )
        from tempfile import TemporaryDirectory

        from cryptography import x509
        from cryptography.hazmat.primitives import (
            hashes,
            serialization,
        )
        from cryptography.hazmat.primitives.asymmetric import rsa
        from cryptography.x509.oid import NameOID

        key_pair = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )

        # For a self-signed certificate, subject and issuer are identical.
        subject = issuer = x509.Name(
            [
                x509.NameAttribute(NameOID.COUNTRY_NAME, "DE"),
                x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "Franconia"),
                x509.NameAttribute(NameOID.LOCALITY_NAME, "Nuremberg"),
                x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Exasol AG"),
                x509.NameAttribute(NameOID.COMMON_NAME, "exasol.com"),
            ]
        )

        today = datetime.now(timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(key_pair.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(today)
            .not_valid_after(today + timedelta(days=365))
            .sign(key_pair, hashes.SHA256())
        )

        # Base64-encoded SHA-256 of the DER-encoded public key, used by the
        # "PUBLIC KEY 'sha256//...'" clause of IMPORT/EXPORT statements
        der_data = key_pair.public_key().public_bytes(
            encoding=serialization.Encoding.DER,
            format=serialization.PublicFormat.SubjectPublicKeyInfo,
        )
        public_key_sha256 = b64encode(hashlib.sha256(der_data).digest()).decode("utf-8")

        # TemporaryDirectory is used instead of NamedTemporaryFile for compatibility with Windows
        with TemporaryDirectory(prefix="pyexasol_ssl_") as tempdir:
            directory = Path(tempdir)

            cert_path = directory / "cert"
            cert_path.write_bytes(cert.public_bytes(serialization.Encoding.PEM))

            key_path = directory / "key"
            key_path.write_bytes(
                key_pair.private_bytes(
                    encoding=serialization.Encoding.PEM,
                    format=serialization.PrivateFormat.TraditionalOpenSSL,
                    encryption_algorithm=serialization.NoEncryption(),
                )
            )

            context = SSLContext(PROTOCOL_TLS_SERVER)
            context.verify_mode = CERT_NONE
            context.load_cert_chain(certfile=cert_path, keyfile=key_path)

        return context, public_key_sha256


class ExaHttpRequestHandler(socketserver.StreamRequestHandler):
    """
    Handles the single HTTP request Exasol sends through the tunnel socket:
    PUT streams data from Exasol into write_pipe (EXPORT), GET streams data
    from read_pipe to Exasol (IMPORT). Bodies use chunked transfer encoding.
    """

    success_headers = (
        b"HTTP/1.1 200 OK\r\n"
        b"Connection: close\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
    )

    error_headers = (
        b"HTTP/1.1 500 Internal Server Error\r\n"
        b"Connection: close\r\n"
        b"\r\n"
    )

    server: ExaTCPServer

    def handle(self):
        self.server.total_clients += 1

        # Extract method from the first header
        method = str(self.rfile.readline(), "iso-8859-1").split()[0]

        # Skip all other headers
        while self.rfile.readline() != b"\r\n":
            pass

        if method == "PUT":
            if self.server.compression:
                self.method_put_compressed()
            else:
                self.method_put_raw()
        elif method == "GET":
            if self.server.compression:
                self.method_get_compressed()
            else:
                self.method_get_raw()
        else:
            raise RuntimeError(f"Unsupported HTTP method [{method}]")

    def method_put_raw(self):
        try:
            while not self.server.is_terminated:
                data = self.read_chunk()

                if data is None:
                    break

                self.server.write_pipe.write(data)
        except Exception as e:
            self.write_error_headers()
            raise e
        else:
            self.write_success_headers()
            self.write_final_chunk()
        finally:
            self.server.write_pipe.close()

    def method_put_compressed(self):
        try:
            d = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS)

            while not self.server.is_terminated:
                data = self.read_chunk()

                if data is None:
                    # Flush any data buffered inside the decompressor
                    self.server.write_pipe.write(d.flush())
                    break

                self.server.write_pipe.write(d.decompress(data))
        except Exception as e:
            self.write_error_headers()
            raise e
        else:
            self.write_success_headers()
            self.write_final_chunk()
        finally:
            self.server.write_pipe.close()

    def method_get_raw(self):
        try:
            self.write_success_headers()

            while not self.server.is_terminated:
                # Exasol server produces chunks of this size (without chunk header part)
                data = self.server.read_pipe.read(65524)

                if data is None or len(data) == 0:
                    break

                self.write_chunk(data)
        except Exception as e:
            raise e
        finally:
            self.server.read_pipe.close()

        # Only send the final chunk after the callback finished successfully;
        # on termination the chunk is withheld so the IMPORT query fails
        if self.server.can_finish_get.wait() and not self.server.is_terminated:
            self.write_final_chunk()

    def method_get_compressed(self):
        try:
            self.write_success_headers()

            c = zlib.compressobj(wbits=16 + zlib.MAX_WBITS)

            while not self.server.is_terminated:
                # Linux common pipe buffer, 64Kb
                data = self.server.read_pipe.read(65536)

                if data is None or len(data) == 0:
                    self.write_chunk(c.flush(zlib.Z_FINISH))
                    break

                self.write_chunk(c.compress(data))
        except Exception as e:
            raise e
        finally:
            self.server.read_pipe.close()

        # Only send the final chunk after the callback finished successfully;
        # on termination the chunk is withheld so the IMPORT query fails
        if self.server.can_finish_get.wait() and not self.server.is_terminated:
            self.write_final_chunk()

    def read_chunk(self):
        """Read one chunked-transfer-encoding chunk; None marks end of stream."""
        hex_length = self.rfile.readline().rstrip()

        if len(hex_length) == 0:
            chunk_len = 0
        else:
            chunk_len = int(hex_length, 16)

        if chunk_len == 0:
            return None

        data = self.rfile.read(chunk_len)

        if self.rfile.read(2) != b"\r\n":
            raise RuntimeError("Invalid chunk delimiter in HTTP stream")

        return data

    def write_chunk(self, data):
        """Write one chunked-transfer-encoding chunk; empty data is skipped."""
        chunk_len = len(data)

        if chunk_len == 0:
            return

        self.wfile.write(b"%X\r\n%b\r\n" % (chunk_len, data))

    def write_final_chunk(self):
        self.wfile.write(b"0\r\n\r\n")

    def write_success_headers(self):
        self.wfile.write(self.success_headers)

    def write_error_headers(self):
        self.wfile.write(self.error_headers)