You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

440 lines
16 KiB

from __future__ import annotations
import re
import textwrap
from collections.abc import Iterable
from typing import Any, Optional, Callable
from . import inspect as mi, to_builtins
__all__ = ("schema", "schema_components")
def schema(
type: Any, *, schema_hook: Optional[Callable[[type], dict[str, Any]]] = None
) -> dict[str, Any]:
"""Generate a JSON Schema for a given type.
Any schemas for (potentially) shared components are extracted and stored in
a top-level ``"$defs"`` field.
If you want to generate schemas for multiple types, or to have more control
over the generated schema you may want to use ``schema_components`` instead.
Parameters
----------
type : type
The type to generate the schema for.
schema_hook : callable, optional
An optional callback to use for generating JSON schemas of custom
types. Will be called with the custom type, and should return a dict
representation of the JSON schema for that type.
Returns
-------
schema : dict
The generated JSON Schema.
See Also
--------
schema_components
"""
(out,), components = schema_components((type,), schema_hook=schema_hook)
if components:
out["$defs"] = components
return out
def schema_components(
types: Iterable[Any],
*,
schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
ref_template: str = "#/$defs/{name}",
) -> tuple[tuple[dict[str, Any], ...], dict[str, Any]]:
"""Generate JSON Schemas for one or more types.
Any schemas for (potentially) shared components are extracted and returned
in a separate ``components`` dict.
Parameters
----------
types : Iterable[type]
An iterable of one or more types to generate schemas for.
schema_hook : callable, optional
An optional callback to use for generating JSON schemas of custom
types. Will be called with the custom type, and should return a dict
representation of the JSON schema for that type.
ref_template : str, optional
A template to use when generating ``"$ref"`` fields. This template is
formatted with the type name as ``template.format(name=name)``. This
can be useful if you intend to store the ``components`` mapping
somewhere other than a top-level ``"$defs"`` field. For example, you
might use ``ref_template="#/components/{name}"`` if generating an
OpenAPI schema.
Returns
-------
schemas : tuple[dict]
A tuple of JSON Schemas, one for each type in ``types``.
components : dict
A mapping of name to schema for any shared components used by
``schemas``.
See Also
--------
schema
"""
type_infos = mi.multi_type_info(types)
component_types = _collect_component_types(type_infos)
name_map = _build_name_map(component_types)
gen = _SchemaGenerator(name_map, schema_hook, ref_template)
schemas = tuple(gen.to_schema(t) for t in type_infos)
components = {
name_map[cls]: gen.to_schema(t, False) for cls, t in component_types.items()
}
return schemas, components
def _collect_component_types(type_infos: Iterable[mi.Type]) -> dict[Any, mi.Type]:
"""Find all types in the type tree that are "nameable" and worthy of being
extracted out into a shared top-level components mapping.
Currently this looks for Struct, Dataclass, NamedTuple, TypedDict, and Enum
types.
"""
components = {}
def collect(t):
if isinstance(
t, (mi.StructType, mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)
):
if t.cls not in components:
components[t.cls] = t
for f in t.fields:
collect(f.type)
elif isinstance(t, mi.EnumType):
components[t.cls] = t
elif isinstance(t, mi.Metadata):
collect(t.type)
elif isinstance(t, mi.CollectionType):
collect(t.item_type)
elif isinstance(t, mi.TupleType):
for st in t.item_types:
collect(st)
elif isinstance(t, mi.DictType):
collect(t.key_type)
collect(t.value_type)
elif isinstance(t, mi.UnionType):
for st in t.types:
collect(st)
for t in type_infos:
collect(t)
return components
def _type_repr(obj):
return obj.__name__ if isinstance(obj, type) else repr(obj)
def _get_class_name(cls: Any) -> str:
if hasattr(cls, "__origin__"):
name = cls.__origin__.__name__
args = ", ".join(_type_repr(a) for a in cls.__args__)
return f"{name}[{args}]"
return cls.__name__
def _get_doc(t: mi.Type) -> str:
assert hasattr(t, "cls")
cls = getattr(t.cls, "__origin__", t.cls)
doc = getattr(cls, "__doc__", "")
if not doc:
return ""
doc = textwrap.dedent(doc).strip("\r\n")
if isinstance(t, mi.EnumType):
if doc == "An enumeration.":
return ""
elif isinstance(t, (mi.NamedTupleType, mi.DataclassType)):
if doc.startswith(f"{cls.__name__}(") and doc.endswith(")"):
return ""
return doc
def _build_name_map(component_types: dict[Any, mi.Type]) -> dict[Any, str]:
"""A mapping from nameable subcomponents to a generated name.
The generated name is usually a normalized version of the class name. In
the case of conflicts, the name will be expanded to also include the full
import path.
"""
def normalize(name):
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
def fullname(cls):
return normalize(f"{cls.__module__}.{cls.__qualname__}")
conflicts = set()
names: dict[str, Any] = {}
for cls in component_types:
name = normalize(_get_class_name(cls))
if name in names:
old = names.pop(name)
conflicts.add(name)
names[fullname(old)] = old
if name in conflicts:
names[fullname(cls)] = cls
else:
names[name] = cls
return {v: k for k, v in names.items()}
class _SchemaGenerator:
def __init__(
self,
name_map: dict[Any, str],
schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
ref_template: str = "#/$defs/{name}",
):
self.name_map = name_map
self.schema_hook = schema_hook
self.ref_template = ref_template
def to_schema(self, t: mi.Type, check_ref: bool = True) -> dict[str, Any]:
"""Converts a Type to a json-schema."""
schema: dict[str, Any] = {}
while isinstance(t, mi.Metadata):
schema = mi._merge_json(schema, t.extra_json_schema)
t = t.type
if check_ref and hasattr(t, "cls"):
if name := self.name_map.get(t.cls):
schema["$ref"] = self.ref_template.format(name=name)
return schema
if isinstance(t, (mi.AnyType, mi.RawType)):
pass
elif isinstance(t, mi.NoneType):
schema["type"] = "null"
elif isinstance(t, mi.BoolType):
schema["type"] = "boolean"
elif isinstance(t, (mi.IntType, mi.FloatType)):
schema["type"] = "integer" if isinstance(t, mi.IntType) else "number"
if t.ge is not None:
schema["minimum"] = t.ge
if t.gt is not None:
schema["exclusiveMinimum"] = t.gt
if t.le is not None:
schema["maximum"] = t.le
if t.lt is not None:
schema["exclusiveMaximum"] = t.lt
if t.multiple_of is not None:
schema["multipleOf"] = t.multiple_of
elif isinstance(t, mi.StrType):
schema["type"] = "string"
if t.max_length is not None:
schema["maxLength"] = t.max_length
if t.min_length is not None:
schema["minLength"] = t.min_length
if t.pattern is not None:
schema["pattern"] = t.pattern
elif isinstance(t, (mi.BytesType, mi.ByteArrayType, mi.MemoryViewType)):
schema["type"] = "string"
schema["contentEncoding"] = "base64"
if t.max_length is not None:
schema["maxLength"] = 4 * ((t.max_length + 2) // 3)
if t.min_length is not None:
schema["minLength"] = 4 * ((t.min_length + 2) // 3)
elif isinstance(t, mi.DateTimeType):
schema["type"] = "string"
if t.tz is True:
schema["format"] = "date-time"
elif isinstance(t, mi.TimeType):
schema["type"] = "string"
if t.tz is True:
schema["format"] = "time"
elif t.tz is False:
schema["format"] = "partial-time"
elif isinstance(t, mi.DateType):
schema["type"] = "string"
schema["format"] = "date"
elif isinstance(t, mi.TimeDeltaType):
schema["type"] = "string"
schema["format"] = "duration"
elif isinstance(t, mi.UUIDType):
schema["type"] = "string"
schema["format"] = "uuid"
elif isinstance(t, mi.DecimalType):
schema["type"] = "string"
schema["format"] = "decimal"
elif isinstance(t, mi.CollectionType):
schema["type"] = "array"
if not isinstance(t.item_type, mi.AnyType):
schema["items"] = self.to_schema(t.item_type)
if t.max_length is not None:
schema["maxItems"] = t.max_length
if t.min_length is not None:
schema["minItems"] = t.min_length
elif isinstance(t, mi.TupleType):
schema["type"] = "array"
schema["minItems"] = schema["maxItems"] = len(t.item_types)
if t.item_types:
schema["prefixItems"] = [self.to_schema(i) for i in t.item_types]
schema["items"] = False
elif isinstance(t, mi.DictType):
schema["type"] = "object"
# If there are restrictions on the keys, specify them as propertyNames
if isinstance(key_type := t.key_type, mi.StrType):
property_names: dict[str, Any] = {}
if key_type.min_length is not None:
property_names["minLength"] = key_type.min_length
if key_type.max_length is not None:
property_names["maxLength"] = key_type.max_length
if key_type.pattern is not None:
property_names["pattern"] = key_type.pattern
if property_names:
schema["propertyNames"] = property_names
if not isinstance(t.value_type, mi.AnyType):
schema["additionalProperties"] = self.to_schema(t.value_type)
if t.max_length is not None:
schema["maxProperties"] = t.max_length
if t.min_length is not None:
schema["minProperties"] = t.min_length
elif isinstance(t, mi.UnionType):
structs = {}
other = []
tag_field = None
for subtype in t.types:
real_type = subtype
while isinstance(real_type, mi.Metadata):
real_type = real_type.type
if isinstance(real_type, mi.StructType) and not real_type.array_like:
tag_field = real_type.tag_field
structs[real_type.tag] = real_type
else:
other.append(subtype)
options = [self.to_schema(a) for a in other]
if len(structs) >= 2:
mapping = {
k: self.ref_template.format(name=self.name_map[v.cls])
for k, v in structs.items()
}
struct_schema = {
"anyOf": [self.to_schema(v) for v in structs.values()],
"discriminator": {"propertyName": tag_field, "mapping": mapping},
}
if options:
options.append(struct_schema)
schema["anyOf"] = options
else:
schema.update(struct_schema)
elif len(structs) == 1:
_, subtype = structs.popitem()
options.append(self.to_schema(subtype))
schema["anyOf"] = options
else:
schema["anyOf"] = options
elif isinstance(t, mi.LiteralType):
schema["enum"] = sorted(t.values)
elif isinstance(t, mi.EnumType):
schema.setdefault("title", t.cls.__name__)
if doc := _get_doc(t):
schema.setdefault("description", doc)
schema["enum"] = sorted(e.value for e in t.cls)
elif isinstance(t, mi.StructType):
schema.setdefault("title", _get_class_name(t.cls))
if doc := _get_doc(t):
schema.setdefault("description", doc)
required = []
names = []
fields = []
if t.tag_field is not None:
required.append(t.tag_field)
names.append(t.tag_field)
fields.append({"enum": [t.tag]})
for field in t.fields:
field_schema = self.to_schema(field.type)
if field.required:
required.append(field.encode_name)
elif field.default is not mi.NODEFAULT:
field_schema["default"] = to_builtins(field.default, str_keys=True)
elif field.default_factory in (list, dict, set, bytearray):
field_schema["default"] = field.default_factory()
names.append(field.encode_name)
fields.append(field_schema)
if t.array_like:
n_trailing_defaults = 0
for n_trailing_defaults, f in enumerate(reversed(t.fields)):
if f.required:
break
schema["type"] = "array"
schema["prefixItems"] = fields
schema["minItems"] = len(fields) - n_trailing_defaults
if t.forbid_unknown_fields:
schema["maxItems"] = len(fields)
else:
schema["type"] = "object"
schema["properties"] = dict(zip(names, fields))
schema["required"] = required
if t.forbid_unknown_fields:
schema["additionalProperties"] = False
elif isinstance(t, (mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)):
schema.setdefault("title", _get_class_name(t.cls))
if doc := _get_doc(t):
schema.setdefault("description", doc)
names = []
fields = []
required = []
for field in t.fields:
field_schema = self.to_schema(field.type)
if field.required:
required.append(field.encode_name)
elif field.default is not mi.NODEFAULT:
field_schema["default"] = to_builtins(field.default, str_keys=True)
names.append(field.encode_name)
fields.append(field_schema)
if isinstance(t, mi.NamedTupleType):
schema["type"] = "array"
schema["prefixItems"] = fields
schema["minItems"] = len(required)
schema["maxItems"] = len(fields)
else:
schema["type"] = "object"
schema["properties"] = dict(zip(names, fields))
schema["required"] = required
elif isinstance(t, mi.ExtType):
raise TypeError("json-schema doesn't support msgpack Ext types")
elif isinstance(t, mi.CustomType):
if self.schema_hook:
try:
schema = mi._merge_json(self.schema_hook(t.cls), schema)
except NotImplementedError:
pass
if not schema:
raise TypeError(
"Generating JSON schema for custom types requires either:\n"
"- specifying a `schema_hook`\n"
"- annotating the type with `Meta(extra_json_schema=...)`\n"
"\n"
f"type {t.cls!r} is not supported"
)
else:
# This should be unreachable
raise TypeError(f"json-schema doesn't support type {t!r}")
return schema