# Source code for exasol.nb_connector.secret_store

import contextlib
import logging
from collections.abc import Iterable
from inspect import cleandoc
from pathlib import Path
from typing import (
    Optional,
    Union,
)

from sqlcipher3 import dbapi2 as sqlcipher  # type: ignore

from exasol.nb_connector.ai_lab_config import AILabConfig as CKey

# Name of the single table (columns: key, value) that holds every secret.
TABLE_NAME = "secrets"
# Logger named after this module.
_logger = logging.getLogger(__name__)


class InvalidPassword(Exception):
    """
    Raised when the secrets database file cannot be read, which may
    indicate that the supplied master password is wrong.
    """


class Secrets:
    """
    Key/value store for secrets, persisted in an SQLCipher database file
    that is keyed with a master password.

    Keys are plain strings or ``AILabConfig`` members (imported as ``CKey``);
    for a member, its ``name`` is used as the key. The database connection
    is opened lazily on first use and cached until ``close()`` is called.
    """

    # Names that __getattr__ must never resolve via the database. Resolving
    # a key needs a connection, and opening one reads these very attributes,
    # so on an instance whose __init__ has not run (e.g. created via
    # ``Secrets.__new__`` or by unpickling) the lookup would recurse forever.
    _INTERNAL_ATTRIBUTES = frozenset({"db_file", "_master_password", "_con"})

    def __init__(self, db_file: Path, master_password: str) -> None:
        """
        Args:
            db_file: Path of the SQLCipher database file; it is created on
                first use if it does not exist.
            master_password: Password used to key the database.
        """
        self.db_file = db_file
        self._master_password = master_password
        self._con = None

    def close(self) -> None:
        """Close the cached database connection, if one is open."""
        if self._con is not None:
            self._con.close()
            self._con = None

    # disable error messages about unresolved types in type hints as
    # sqlcipher is a c library and does not provide this information.
    def connection(
        self,
    ) -> sqlcipher.Connection:  # pylint: disable=E1101
        """
        Return the database connection, opening and initializing it on first use.

        On first use this connects to ``db_file`` (logging if the file does
        not exist yet), keys the connection with the master password, and
        then either verifies access to the existing file or creates the
        secrets table in a new one.

        Raises:
            InvalidPassword: if an existing file cannot be read, which also
                happens when the master password is incorrect.
        """
        if self._con is None:
            db_file_found = self.db_file.exists()
            if not db_file_found:
                _logger.info("Creating file %s", self.db_file)
            # fmt: off
            self._con = sqlcipher.connect(self.db_file)  # pylint: disable=E1101
            # fmt: on
            try:
                self._use_master_password()
                self._initialize(db_file_found)
            except BaseException:
                # Never keep a half-initialized connection (e.g. one keyed
                # with a wrong password): close it so it does not leak and
                # the next call starts over and re-verifies access.
                self.close()
                raise
        return self._con

    def _initialize(self, db_file_found: bool) -> None:
        """Verify access to an existing file, or create the table in a new one."""
        if db_file_found:
            self._verify_access()
            return
        _logger.info('Creating table "%s".', TABLE_NAME)
        with self._cursor() as cur:
            cur.execute(f"CREATE TABLE {TABLE_NAME} (key TEXT PRIMARY KEY, value TEXT)")

    def _use_master_password(self) -> None:
        """
        Key the connection with the master password via SQLCipher's
        ``PRAGMA key``: a new database is encrypted with this key, and an
        existing encrypted database becomes readable if the key matches.
        Does nothing when no master password is set.
        """
        if self._master_password is not None:
            # PRAGMA statements cannot use bound parameters, so the password
            # is embedded as an SQL string literal. Inside such a literal a
            # single quote is escaped by doubling it; SQLite does not treat
            # backslash as an escape character.
            sanitized = self._master_password.replace("'", "''")
            with self._cursor() as cur:
                cur.execute(f"PRAGMA key = '{sanitized}'")

    def _verify_access(self) -> None:
        """
        Check that the existing database file is readable with the current key.

        Raises:
            InvalidPassword: if the driver reports "file is not a database",
                which also happens when the master password is incorrect.
            sqlcipher.DatabaseError: re-raised unchanged for any other
                database error.
        """
        try:
            with self._cursor() as cur:
                cur.execute("SELECT * FROM sqlite_master")
        # fmt: off
        except (sqlcipher.DatabaseError) as ex:  # pylint: disable=E1101
        # fmt: on
            _logger.debug("Cannot access database file %s: %s", self.db_file, ex)
            if str(ex) == "file is not a database":
                raise InvalidPassword(
                    cleandoc(
                        f"""
                        Cannot access database file {self.db_file}.
                        This also happens if master password is incorrect.
                        """
                    )
                ) from ex
            raise

    @contextlib.contextmanager
    def _cursor(
        self,
    ) -> sqlcipher.Cursor:  # pylint: disable=E1101
        """
        Yield a cursor on the connection. Commit if the with-block succeeds;
        roll back and re-raise if it raises. The cursor is always closed.
        """
        # Bind the connection once so commit/rollback act on the same
        # connection that created the cursor.
        con = self.connection()
        cur = con.cursor()
        try:
            yield cur
            con.commit()
        except BaseException:
            con.rollback()
            raise
        finally:
            cur.close()

    def save(self, key: Union[str, CKey], value: str) -> "Secrets":
        """
        Store ``value`` under ``key``, overwriting any existing value.

        ``key`` represents a system, service, or application; an
        ``AILabConfig`` member is stored under its ``name``.

        Returns:
            This store, so calls can be chained.
        """
        key = key.name if isinstance(key, CKey) else key

        def entry_exists(cur) -> bool:
            res = cur.execute(f"SELECT * FROM {TABLE_NAME} WHERE key=?", [key])
            return res is not None and res.fetchone() is not None

        def update(cur) -> None:
            cur.execute(f"UPDATE {TABLE_NAME} SET value=? WHERE key=?", [value, key])

        def insert(cur) -> None:
            cur.execute(
                f"INSERT INTO {TABLE_NAME} (key,value) VALUES (?, ?)", [key, value]
            )

        # Check and write inside one cursor block, i.e. one commit.
        with self._cursor() as cur:
            if entry_exists(cur):
                update(cur)
            else:
                insert(cur)
        return self

    def get(
        self, key: Union[str, CKey], default_value: Optional[str] = None
    ) -> Optional[str]:
        """Return the value stored under ``key``, or ``default_value`` if absent."""
        key = key.name if isinstance(key, CKey) else key
        with self._cursor() as cur:
            res = cur.execute(f"SELECT value FROM {TABLE_NAME} WHERE key=?", [key])
            row = res.fetchone() if res else None
        return row[0] if row else default_value

    def __getattr__(self, key) -> str:
        """
        Support ``secrets.some_key`` as a shortcut for ``secrets.get("some_key")``.

        Python calls this only for names not found by normal attribute lookup.

        Raises:
            AttributeError: if no entry exists for ``key``. Also raised
                immediately, without touching the database, for dunder names
                and for the instance's own internal attributes.
        """
        if (
            key.startswith("__") and key.endswith("__")
        ) or key in self._INTERNAL_ATTRIBUTES:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {key!r}"
            )
        val = self.get(key)
        if val is None:
            raise AttributeError(f'Unknown key "{key}"')
        return val

    def __getitem__(self, item) -> str:
        """
        Support ``secrets["some_key"]``.

        Raises:
            AttributeError: (not KeyError) if no entry exists for ``item``;
                kept as-is for backward compatibility with existing callers.
        """
        val = self.get(item)
        if val is None:
            raise AttributeError(f'Unknown key "{item}"')
        return val

    def keys(self) -> Iterable[str]:
        """
        Iterator over keys akin to dict.keys().

        The cursor stays open until the iteration finishes or is abandoned.
        """
        with self._cursor() as cur:
            res = cur.execute(f"SELECT key FROM {TABLE_NAME}")
            for row in res:
                yield row[0]

    def values(self) -> Iterable[str]:
        """
        Iterator over values akin to dict.values().

        The cursor stays open until the iteration finishes or is abandoned.
        """
        with self._cursor() as cur:
            res = cur.execute(f"SELECT value FROM {TABLE_NAME}")
            for row in res:
                yield row[0]

    def items(self) -> Iterable[tuple[str, str]]:
        """
        Iterator over keys and values akin to dict.items().

        The cursor stays open until the iteration finishes or is abandoned.
        """
        with self._cursor() as cur:
            res = cur.execute(f"SELECT key, value FROM {TABLE_NAME}")
            for row in res:
                yield row[0], row[1]

    def remove(self, key: Union[str, CKey]) -> None:
        """
        Deletes entry with the specified key if it exists.
        Doesn't raise any exception if the key doesn't exist.
        """
        key = key.name if isinstance(key, CKey) else key
        with self._cursor() as cur:
            cur.execute(f"DELETE FROM {TABLE_NAME} WHERE key=?", [key])