""" Separate connection attempts from a connection string. """ # Copyright (C) 2024 The Psycopg Team from __future__ import annotations import socket import logging from random import shuffle from . import errors as e from ._conninfo_utils import ConnDict, get_param, is_ip_address, get_param_def from ._conninfo_utils import split_attempts logger = logging.getLogger("psycopg") def conninfo_attempts(params: ConnDict) -> list[ConnDict]: """Split a set of connection params on the single attempts to perform. A connection param can perform more than one attempt more than one ``host`` is provided. Also perform async resolution of the hostname into hostaddr. Because a host can resolve to more than one address, this can lead to yield more attempts too. Raise `OperationalError` if no host could be resolved. Because the libpq async function doesn't honour the timeout, we need to reimplement the repeated attempts. """ last_exc = None attempts = [] for attempt in split_attempts(params): try: attempts.extend(_resolve_hostnames(attempt)) except OSError as ex: logger.debug("failed to resolve host %r: %s", attempt.get("host"), str(ex)) last_exc = ex if not attempts: assert last_exc # We couldn't resolve anything raise e.OperationalError(str(last_exc)) if get_param(params, "load_balance_hosts") == "random": shuffle(attempts) return attempts def _resolve_hostnames(params: ConnDict) -> list[ConnDict]: """ Perform DNS lookup of the hosts and return a list of connection attempts. If a ``host`` param is present but not ``hostname``, resolve the host addresses asynchronously. :param params: The input parameters, for instance as returned by `~psycopg.conninfo.conninfo_to_dict()`. The function expects at most a single entry for host, hostaddr because it is designed to further process the input of split_attempts(). :return: A list of attempts to make (to include the case of a hostname resolving to more than one IP). """ host = get_param(params, "host") if not host or host.startswith("/") or host[1:2] == ":": # Local path, or no host to resolve return [params] hostaddr = get_param(params, "hostaddr") if hostaddr: # Already resolved return [params] if is_ip_address(host): # If the host is already an ip address don't try to resolve it return [{**params, "hostaddr": host}] port = get_param(params, "port") if not port: port_def = get_param_def("port") port = port_def and port_def.compiled or "5432" ans = socket.getaddrinfo( host, port, proto=socket.IPPROTO_TCP, type=socket.SOCK_STREAM ) return [{**params, "hostaddr": item[4][0]} for item in ans]