You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
290 lines
8.8 KiB
290 lines
8.8 KiB
# type: ignore
|
|
import collections
|
|
import sys
|
|
import typing
|
|
|
|
try:
|
|
from typing_extensions import _AnnotatedAlias
|
|
except Exception:
|
|
try:
|
|
from typing import _AnnotatedAlias
|
|
except Exception:
|
|
_AnnotatedAlias = None
|
|
|
|
try:
|
|
from typing_extensions import get_type_hints as _get_type_hints
|
|
except Exception:
|
|
from typing import get_type_hints as _get_type_hints
|
|
|
|
try:
|
|
from typing_extensions import NotRequired, Required
|
|
except Exception:
|
|
try:
|
|
from typing import NotRequired, Required
|
|
except Exception:
|
|
Required = NotRequired = None
|
|
|
|
|
|
if Required is None and _AnnotatedAlias is None:
|
|
# No extras available, so no `include_extras`
|
|
get_type_hints = _get_type_hints
|
|
else:
|
|
|
|
def get_type_hints(obj):
|
|
return _get_type_hints(obj, include_extras=True)
|
|
|
|
|
|
# The `is_class` argument was new in 3.11, but was backported to 3.9 and 3.10.
|
|
# It's _likely_ to be available for 3.9/3.10, but may not be. Easiest way to
|
|
# check is to try it and see. This check can be removed when we drop support
|
|
# for Python 3.10.
|
|
try:
|
|
typing.ForwardRef("Foo", is_class=True)
|
|
except TypeError:
|
|
|
|
def _forward_ref(value):
|
|
return typing.ForwardRef(value, is_argument=False)
|
|
|
|
else:
|
|
|
|
def _forward_ref(value):
|
|
return typing.ForwardRef(value, is_argument=False, is_class=True)
|
|
|
|
|
|
def _apply_params(obj, mapping):
|
|
if params := getattr(obj, "__parameters__", None):
|
|
args = tuple(mapping.get(p, p) for p in params)
|
|
return obj[args]
|
|
elif isinstance(obj, typing.TypeVar):
|
|
return mapping.get(obj, obj)
|
|
return obj
|
|
|
|
|
|
def _get_class_mro_and_typevar_mappings(obj):
|
|
mapping = {}
|
|
|
|
if isinstance(obj, type):
|
|
cls = obj
|
|
else:
|
|
cls = obj.__origin__
|
|
|
|
def inner(c, scope):
|
|
if isinstance(c, type):
|
|
cls = c
|
|
new_scope = {}
|
|
else:
|
|
cls = getattr(c, "__origin__", None)
|
|
if cls in (None, object, typing.Generic) or cls in mapping:
|
|
return
|
|
params = cls.__parameters__
|
|
args = tuple(_apply_params(a, scope) for a in c.__args__)
|
|
assert len(params) == len(args)
|
|
mapping[cls] = new_scope = dict(zip(params, args))
|
|
|
|
if issubclass(cls, typing.Generic):
|
|
bases = getattr(cls, "__orig_bases__", cls.__bases__)
|
|
for b in bases:
|
|
inner(b, new_scope)
|
|
|
|
inner(obj, {})
|
|
return cls.__mro__, mapping
|
|
|
|
|
|
def get_class_annotations(obj):
|
|
"""Get the annotations for a class.
|
|
|
|
This is similar to ``typing.get_type_hints``, except:
|
|
|
|
- We maintain it
|
|
- It leaves extras like ``Annotated``/``ClassVar`` alone
|
|
- It resolves any parametrized generics in the class mro. The returned
|
|
mapping may still include ``TypeVar`` values, but those should be treated
|
|
as their unparametrized variants (i.e. equal to ``Any`` for the common case).
|
|
|
|
Note that this function doesn't check that Generic types are being used
|
|
properly - invalid uses of `Generic` may slip through without complaint.
|
|
|
|
The assumption here is that the user is making use of a static analysis
|
|
tool like ``mypy``/``pyright`` already, which would catch misuse of these
|
|
APIs.
|
|
"""
|
|
hints = {}
|
|
mro, typevar_mappings = _get_class_mro_and_typevar_mappings(obj)
|
|
|
|
for cls in mro:
|
|
if cls in (typing.Generic, object):
|
|
continue
|
|
|
|
mapping = typevar_mappings.get(cls)
|
|
cls_locals = dict(vars(cls))
|
|
cls_globals = getattr(sys.modules.get(cls.__module__, None), "__dict__", {})
|
|
|
|
ann = cls.__dict__.get("__annotations__", {})
|
|
for name, value in ann.items():
|
|
if name in hints:
|
|
continue
|
|
if value is None:
|
|
value = type(None)
|
|
elif isinstance(value, str):
|
|
value = _forward_ref(value)
|
|
value = typing._eval_type(value, cls_locals, cls_globals)
|
|
if mapping is not None:
|
|
value = _apply_params(value, mapping)
|
|
hints[name] = value
|
|
return hints
|
|
|
|
|
|
# A mapping from a type annotation (or annotation __origin__) to the concrete
|
|
# python type that msgspec will use when decoding. THIS IS PRIVATE FOR A
|
|
# REASON. DON'T MUCK WITH THIS.
|
|
_CONCRETE_TYPES = {
|
|
list: list,
|
|
tuple: tuple,
|
|
set: set,
|
|
frozenset: frozenset,
|
|
dict: dict,
|
|
typing.List: list,
|
|
typing.Tuple: tuple,
|
|
typing.Set: set,
|
|
typing.FrozenSet: frozenset,
|
|
typing.Dict: dict,
|
|
typing.Collection: list,
|
|
typing.MutableSequence: list,
|
|
typing.Sequence: list,
|
|
typing.MutableMapping: dict,
|
|
typing.Mapping: dict,
|
|
typing.MutableSet: set,
|
|
typing.AbstractSet: set,
|
|
collections.abc.Collection: list,
|
|
collections.abc.MutableSequence: list,
|
|
collections.abc.Sequence: list,
|
|
collections.abc.MutableSet: set,
|
|
collections.abc.Set: set,
|
|
collections.abc.MutableMapping: dict,
|
|
collections.abc.Mapping: dict,
|
|
}
|
|
|
|
|
|
def get_typeddict_info(obj):
|
|
if isinstance(obj, type):
|
|
cls = obj
|
|
else:
|
|
cls = obj.__origin__
|
|
|
|
raw_hints = get_class_annotations(obj)
|
|
|
|
if hasattr(cls, "__required_keys__"):
|
|
required = set(cls.__required_keys__)
|
|
elif cls.__total__:
|
|
required = set(raw_hints)
|
|
else:
|
|
required = set()
|
|
|
|
# Both `typing.TypedDict` and `typing_extensions.TypedDict` have a bug
|
|
# where `Required`/`NotRequired` aren't properly detected at runtime when
|
|
# `__future__.annotations` is enabled, meaning the `__required_keys__`
|
|
# isn't correct. This code block works around this issue by amending the
|
|
# set of required keys as needed, while also stripping off any
|
|
# `Required`/`NotRequired` wrappers.
|
|
hints = {}
|
|
for k, v in raw_hints.items():
|
|
origin = getattr(v, "__origin__", False)
|
|
if origin is Required:
|
|
required.add(k)
|
|
hints[k] = v.__args__[0]
|
|
elif origin is NotRequired:
|
|
required.discard(k)
|
|
hints[k] = v.__args__[0]
|
|
else:
|
|
hints[k] = v
|
|
return hints, required
|
|
|
|
|
|
def get_dataclass_info(obj):
|
|
if isinstance(obj, type):
|
|
cls = obj
|
|
else:
|
|
cls = obj.__origin__
|
|
hints = get_class_annotations(obj)
|
|
required = []
|
|
optional = []
|
|
defaults = []
|
|
|
|
if hasattr(cls, "__dataclass_fields__"):
|
|
from dataclasses import _FIELD, _FIELD_INITVAR, MISSING
|
|
|
|
for field in cls.__dataclass_fields__.values():
|
|
if field._field_type is not _FIELD:
|
|
if field._field_type is _FIELD_INITVAR:
|
|
raise TypeError(
|
|
"dataclasses with `InitVar` fields are not supported"
|
|
)
|
|
continue
|
|
name = field.name
|
|
typ = hints[name]
|
|
if field.default is not MISSING:
|
|
defaults.append(field.default)
|
|
optional.append((name, typ, False))
|
|
elif field.default_factory is not MISSING:
|
|
defaults.append(field.default_factory)
|
|
optional.append((name, typ, True))
|
|
else:
|
|
required.append((name, typ, False))
|
|
|
|
required.extend(optional)
|
|
|
|
pre_init = None
|
|
post_init = getattr(cls, "__post_init__", None)
|
|
else:
|
|
from attrs import NOTHING, Factory
|
|
|
|
fields_with_validators = []
|
|
|
|
for field in cls.__attrs_attrs__:
|
|
name = field.name
|
|
typ = hints[name]
|
|
default = field.default
|
|
if default is not NOTHING:
|
|
if isinstance(default, Factory):
|
|
if default.takes_self:
|
|
raise NotImplementedError(
|
|
"Support for default factories with `takes_self=True` "
|
|
"is not implemented. File a GitHub issue if you need "
|
|
"this feature!"
|
|
)
|
|
defaults.append(default.factory)
|
|
optional.append((name, typ, True))
|
|
else:
|
|
defaults.append(default)
|
|
optional.append((name, typ, False))
|
|
else:
|
|
required.append((name, typ, False))
|
|
|
|
if field.validator is not None:
|
|
fields_with_validators.append(field)
|
|
|
|
required.extend(optional)
|
|
|
|
pre_init = getattr(cls, "__attrs_pre_init__", None)
|
|
post_init = getattr(cls, "__attrs_post_init__", None)
|
|
|
|
if fields_with_validators:
|
|
post_init = _wrap_attrs_validators(fields_with_validators, post_init)
|
|
|
|
return cls, tuple(required), tuple(defaults), pre_init, post_init
|
|
|
|
|
|
def _wrap_attrs_validators(fields, post_init):
|
|
def inner(obj):
|
|
for field in fields:
|
|
field.validator(obj, field, getattr(obj, field.name))
|
|
if post_init is not None:
|
|
post_init(obj)
|
|
|
|
return inner
|
|
|
|
|
|
def rebuild(cls, kwargs):
|
|
"""Used to unpickle Structs with keyword-only fields"""
|
|
return cls(**kwargs)
|