You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
3.0 KiB
128 lines
3.0 KiB
"""
|
|
Internal utilities to manipulate connection strings
|
|
"""
|
|
|
|
# Copyright (C) 2024 The Psycopg Team
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
from typing import Any
|
|
from functools import lru_cache
|
|
from ipaddress import ip_address
|
|
from dataclasses import dataclass
|
|
from typing_extensions import TypeAlias
|
|
|
|
from . import pq
|
|
from . import errors as e
|
|
|
|
# Type of a dict of connection parameters, keyed by conninfo keyword
# (e.g. "host", "hostaddr", "port"), with values of any type.
# The alias is a quoted string because `from __future__ import annotations`
# only defers annotations, not this right-hand side; presumably kept so that
# `dict[...]` is never evaluated at runtime on older Pythons — TODO confirm.
ConnDict: TypeAlias = "dict[str, Any]"
|
|
|
|
|
|
def split_attempts(params: ConnDict) -> list[ConnDict]:
    """
    Split connection parameters with a sequence of hosts into separate attempts.

    ``host``, ``hostaddr`` and ``port`` may each hold a comma-separated list,
    either in *params* or in the corresponding ``PG*`` environment variable
    (as resolved by `get_param()`). Return one dict per host. Each dict is a
    copy of *params* with those keys set to the matching list element.

    If at most one host is specified, return ``[params]`` (the same dict
    object, not a copy) so the original conninfo is not mangled.

    :raises OperationalError: if the number of host names doesn't match the
        number of hostaddr values, or if more than one port is given and
        their number doesn't match the number of hosts.
    """

    def split_val(key: str) -> list[str]:
        # A missing or empty parameter yields no elements at all.
        val = get_param(params, key)
        return val.split(",") if val else []

    hosts = split_val("host")
    hostaddrs = split_val("hostaddr")
    ports = split_val("port")

    if hosts and hostaddrs and len(hosts) != len(hostaddrs):
        raise e.OperationalError(
            f"could not match {len(hosts)} host names"
            f" with {len(hostaddrs)} hostaddr values"
        )

    # Either list may be empty (e.g. only hostaddr given); the non-empty one
    # determines how many hosts there are.
    nhosts = max(len(hosts), len(hostaddrs))

    # A single port is shared by all hosts; otherwise one port per host is
    # required. Report nhosts, not len(hosts): hosts may be empty when only
    # hostaddr values were given.
    if 1 < len(ports) != nhosts:
        raise e.OperationalError(
            f"could not match {len(ports)} port numbers to {nhosts} hosts"
        )

    # A single attempt to make. Don't mangle the conninfo string.
    if nhosts <= 1:
        return [params]

    if len(ports) == 1:
        ports *= nhosts

    # Now all lists are either empty or have the same length
    rv = []
    for i in range(nhosts):
        attempt = params.copy()
        # Values taken from env vars are written into the attempt explicitly.
        if hosts:
            attempt["host"] = hosts[i]
        if hostaddrs:
            attempt["hostaddr"] = hostaddrs[i]
        if ports:
            attempt["port"] = ports[i]
        rv.append(attempt)

    return rv
|
|
|
|
|
|
def get_param(params: ConnDict, name: str) -> str | None:
    """
    Look up the connection parameter *name*.

    A value present in *params* takes precedence and is returned converted
    to `!str`. Otherwise, if libpq knows the parameter, return the value of
    its ``PG*`` environment variable, if that is set. Return `!None` when
    neither source provides a value.
    """
    if name in params:
        return str(params[name])

    # TODO: check if in service

    pdef = get_param_def(name)
    return os.environ.get(pdef.envvar) if pdef else None
|
|
|
|
|
|
@dataclass
class ParamDef:
    """
    Information about defaults and env vars for connection params.

    Instances are built by `get_param_def()` from the entries returned by
    `pq.Conninfo.get_defaults()`.
    """

    # Parameter name as used in a conninfo string (e.g. "host").
    keyword: str
    # Name of the environment variable that can supply the value; "" when
    # libpq reports none for this parameter.
    envvar: str
    # The libpq ``compiled`` value for this parameter (presumably its
    # compiled-in fallback default), or None when libpq reports none.
    compiled: str | None
|
|
|
|
|
|
def get_param_def(keyword: str, _cache: dict[str, ParamDef] = {}) -> ParamDef | None:
    """
    Return the ParamDef of a connection string parameter.

    The parameter table is read from `pq.Conninfo.get_defaults()` on first use.
    It is memoized in *_cache*. The mutable default argument is deliberate:
    it is the cache shared by every call for the life of the module, so
    callers should not pass it.

    Return `!None` if *keyword* is not a parameter known to libpq.
    """
    if not _cache:
        # Build the complete table before publishing it into the shared cache.
        # Filling _cache entry by entry has two failure modes. First, an
        # exception half-way (e.g. a decode error) leaves it non-empty but
        # incomplete, and later calls would never refill it. Second, a
        # concurrent caller could see the partial table. Either way, valid
        # keywords would map to None.
        defs: dict[str, ParamDef] = {}
        for d in pq.Conninfo.get_defaults():
            cd = ParamDef(
                keyword=d.keyword.decode(),
                envvar=d.envvar.decode() if d.envvar else "",
                compiled=d.compiled.decode() if d.compiled is not None else None,
            )
            defs[cd.keyword] = cd
        _cache.update(defs)

    return _cache.get(keyword)
|
|
|
|
|
|
@lru_cache(maxsize=128)
def is_ip_address(s: str) -> bool:
    """
    Tell whether *s* is a literal IPv4 or IPv6 address.

    Results are memoized for up to 128 distinct strings.
    """
    try:
        parsed = ip_address(s)
    except ValueError:
        parsed = None
    return parsed is not None
|