Source code for

from __future__ import annotations

import logging

from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING

import pandas as pd
import yaml

from geoalchemy2.types import Geometry
from sqlalchemy import create_engine, func
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.session import Session
from sshtunnel import SSHTunnelForwarder

    from edisgo import EDisGo

logger = logging.getLogger(__name__)

[docs]def config_settings(path: Path | str) -> dict[str, dict[str, str | int | Path]]: """ Return a nested dictionary containing the configuration settings. It's a nested dictionary because the top level has command names as keys and dictionaries as values where the second level dictionary has command line switches applicable to the command as keys and the supplied values as values. So you would obtain the ``--database-name`` configuration setting used by the current invocation of ``egon-data`` via .. code-block:: python settings()["egon-data"]["--database-name"] Parameters ---------- path : pathlib.Path or str Path to configuration YAML file of egon-data database. Returns ------- dict Nested dictionary containing the egon-data and optional ssh tunnel configuration settings. """ if isinstance(path, str): path = Path(path) if not path.is_file(): raise ValueError(f"Configuration file {path} not found.") with open(path) as f: return yaml.safe_load(f)
[docs]def credentials(path: Path | str) -> dict[str, str | int | Path]: """ Return local database connection parameters. Parameters ---------- path : pathlib.Path or str Path to configuration YAML file of egon-data database. Returns ------- dict Complete DB connection information. """ translated = { "--database-name": "POSTGRES_DB", "--database-password": "POSTGRES_PASSWORD", "--database-host": "HOST", "--database-port": "PORT", "--database-user": "POSTGRES_USER", } configuration = config_settings(path=path) egon_config = configuration["egon-data"] update = { translated[flag]: egon_config[flag] for flag in egon_config if flag in translated } if "PORT" in update.keys(): update["PORT"] = int(update["PORT"]) egon_config.update(update) if "ssh-tunnel" in configuration.keys(): translated = { "ssh-host": "SSH_HOST", "ssh-user": "SSH_USER", "ssh-pkey": "SSH_PKEY", "pgres-host": "PGRES_HOST", } update = { translated[flag]: configuration["ssh-tunnel"][flag] for flag in configuration["ssh-tunnel"] if flag in translated } egon_config.update(update) if "SSH_PKEY" in egon_config.keys(): egon_config["SSH_PKEY"] = Path(egon_config["SSH_PKEY"]).expanduser() if not egon_config["SSH_PKEY"].is_file(): raise ValueError(f"{egon_config['SSH_PKEY']} is not a file.") return egon_config
[docs]def ssh_tunnel(cred: dict) -> str: """ Initialize an SSH tunnel to a remote host according to the input arguments. See for more information. Parameters ---------- cred : dict Complete DB connection information. Returns ------- str Name of local port. """ server = SSHTunnelForwarder( ssh_address_or_host=(cred["SSH_HOST"], 22), ssh_username=cred["SSH_USER"], ssh_pkey=cred["SSH_PKEY"], remote_bind_address=(cred["PGRES_HOST"], cred["PORT"]), ) server.start() return str(server.local_bind_port)
[docs]def engine(path: Path | str, ssh: bool = False) -> Engine: """ Engine for local or remote database. Parameters ---------- path : str Path to configuration YAML file of egon-data database. ssh : bool If True try to establish ssh tunnel from given information within the configuration YAML. If False try to connect to local database. Returns ------- sqlalchemy.engine.base.Engine Database engine """ cred = credentials(path=path) if not ssh: return create_engine( f"postgresql+psycopg2://{cred['POSTGRES_USER']}:" f"{cred['POSTGRES_PASSWORD']}@{cred['HOST']}:" f"{cred['PORT']}/{cred['POSTGRES_DB']}", echo=False, ) local_port = ssh_tunnel(cred) return create_engine( f"postgresql+psycopg2://{cred['POSTGRES_USER']}:" f"{cred['POSTGRES_PASSWORD']}@{cred['PGRES_HOST']}:" f"{local_port}/{cred['POSTGRES_DB']}", echo=False, )
[docs]@contextmanager def session_scope_egon_data(engine: Engine): """Provide a transactional scope around a series of operations.""" Session = sessionmaker(bind=engine) session = Session() try: yield session session.commit() except: # noqa: E722 session.rollback() raise finally: session.close()
[docs]def sql_grid_geom(edisgo_obj: EDisGo) -> Geometry: return func.ST_GeomFromText( edisgo_obj.topology.grid_district["geom"].wkt, edisgo_obj.topology.grid_district["srid"], )
[docs]def get_srid_of_db_table(session: Session, geom_col: InstrumentedAttribute) -> int: query = session.query(func.ST_SRID(geom_col)).limit(1) return pd.read_sql(sql=query.statement, con=query.session.bind).iat[0, 0]
[docs]def sql_within(geom_a: Geometry, geom_b: Geometry, srid: int): """ Checks if geometry a is completely within geometry b. Parameters ---------- geom_a : Geometry Geometry within `geom_b`. geom_b : Geometry Geometry containing `geom_a`. srid : int SRID geometries are transformed to in order to use the same SRID for both geometries. """ return func.ST_Within( func.ST_Transform( geom_a, srid, ), func.ST_Transform( geom_b, srid, ), )
[docs]def sql_intersects(geom_col: InstrumentedAttribute, geom_shape: Geometry, srid: int): return func.ST_Intersects( func.ST_Transform( geom_col, srid, ), func.ST_Transform( geom_shape, srid, ), )