You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
440 lines
16 KiB
440 lines
16 KiB
from __future__ import annotations
|
|
|
|
import re
|
|
import textwrap
|
|
from collections.abc import Iterable
|
|
from typing import Any, Optional, Callable
|
|
|
|
from . import inspect as mi, to_builtins
|
|
|
|
__all__ = ("schema", "schema_components")
|
|
|
|
|
|
def schema(
    type: Any, *, schema_hook: Optional[Callable[[type], dict[str, Any]]] = None
) -> dict[str, Any]:
    """Build the JSON Schema for a single type.

    Schemas of (potentially) shared components are pulled out of the main
    schema and attached under a top-level ``"$defs"`` key. The key is only
    added when at least one such component exists.

    Use ``schema_components`` instead when several types must be described
    at once, or when finer control over the output is needed.

    Parameters
    ----------
    type : type
        The type to describe.
    schema_hook : callable, optional
        Callback used to produce schemas for custom types. It receives the
        custom type and should return that type's JSON schema as a dict.

    Returns
    -------
    schema : dict
        The generated JSON Schema.

    See Also
    --------
    schema_components
    """
    schemas, shared = schema_components((type,), schema_hook=schema_hook)
    # Exactly one type went in, so exactly one schema must come back.
    (result,) = schemas
    if not shared:
        return result
    result["$defs"] = shared
    return result
|
|
|
|
|
|
def schema_components(
    types: Iterable[Any],
    *,
    schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
    ref_template: str = "#/$defs/{name}",
) -> tuple[tuple[dict[str, Any], ...], dict[str, Any]]:
    """Build JSON Schemas for one or more types.

    Schemas of (potentially) shared components are not inlined; they are
    handed back separately in a ``components`` dict and referenced from the
    main schemas through ``"$ref"`` entries.

    Parameters
    ----------
    types : Iterable[type]
        One or more types to describe.
    schema_hook : callable, optional
        Callback used to produce schemas for custom types. It receives the
        custom type and should return that type's JSON schema as a dict.
    ref_template : str, optional
        Template for the generated ``"$ref"`` values, filled in with
        ``template.format(name=name)``. Change it when ``components`` will
        live somewhere other than a top-level ``"$defs"`` field, for example
        ``ref_template="#/components/{name}"`` for an OpenAPI schema.

    Returns
    -------
    schemas : tuple[dict]
        One JSON Schema per entry of ``types``.
    components : dict
        Mapping of component name to schema for every shared component
        referenced by ``schemas``.

    See Also
    --------
    schema
    """
    infos = mi.multi_type_info(types)
    shared_types = _collect_component_types(infos)
    names = _build_name_map(shared_types)
    generator = _SchemaGenerator(names, schema_hook, ref_template)

    top_level = []
    for info in infos:
        top_level.append(generator.to_schema(info))

    components = {}
    for cls, info in shared_types.items():
        key = names[cls]
        # check_ref=False: emit the component's full definition rather than
        # a "$ref" pointing back at itself.
        components[key] = generator.to_schema(info, False)

    return tuple(top_level), components
|
|
|
|
|
|
def _collect_component_types(type_infos: Iterable[mi.Type]) -> dict[Any, mi.Type]:
    """Walk the type tree and gather every "nameable" type.

    Nameable types are the ones worth extracting into a shared top-level
    components mapping: Struct, Dataclass, NamedTuple, TypedDict, and Enum
    types. The result maps each such class to its type info.
    """
    found: dict[Any, mi.Type] = {}
    record_like = (
        mi.StructType,
        mi.TypedDictType,
        mi.DataclassType,
        mi.NamedTupleType,
    )

    def visit(node):
        if isinstance(node, record_like):
            # Already-seen classes are not descended into again, which also
            # stops the walk from looping on self-referential types.
            if node.cls in found:
                return
            found[node.cls] = node
            children = [field.type for field in node.fields]
        elif isinstance(node, mi.EnumType):
            found[node.cls] = node
            return
        elif isinstance(node, mi.Metadata):
            children = [node.type]
        elif isinstance(node, mi.CollectionType):
            children = [node.item_type]
        elif isinstance(node, mi.TupleType):
            children = node.item_types
        elif isinstance(node, mi.DictType):
            children = [node.key_type, node.value_type]
        elif isinstance(node, mi.UnionType):
            children = node.types
        else:
            # Any other type has nothing nameable underneath it.
            return

        for child in children:
            visit(child)

    for root in type_infos:
        visit(root)

    return found
|
|
|
|
|
|
def _type_repr(obj):
|
|
return obj.__name__ if isinstance(obj, type) else repr(obj)
|
|
|
|
|
|
def _get_class_name(cls: Any) -> str:
|
|
if hasattr(cls, "__origin__"):
|
|
name = cls.__origin__.__name__
|
|
args = ", ".join(_type_repr(a) for a in cls.__args__)
|
|
return f"{name}[{args}]"
|
|
return cls.__name__
|
|
|
|
|
|
def _get_doc(t: mi.Type) -> str:
|
|
assert hasattr(t, "cls")
|
|
cls = getattr(t.cls, "__origin__", t.cls)
|
|
doc = getattr(cls, "__doc__", "")
|
|
if not doc:
|
|
return ""
|
|
doc = textwrap.dedent(doc).strip("\r\n")
|
|
if isinstance(t, mi.EnumType):
|
|
if doc == "An enumeration.":
|
|
return ""
|
|
elif isinstance(t, (mi.NamedTupleType, mi.DataclassType)):
|
|
if doc.startswith(f"{cls.__name__}(") and doc.endswith(")"):
|
|
return ""
|
|
return doc
|
|
|
|
|
|
def _build_name_map(component_types: dict[Any, mi.Type]) -> dict[Any, str]:
    """Assign a generated name to every nameable subcomponent.

    Each class normally gets a normalized version of its class name. When
    two or more classes normalize to the same short name, all of them are
    named by their normalized ``module.qualname`` path instead.
    """

    def sanitize(text):
        # Anything outside letters, digits, ".", "-" and "_" becomes "_".
        return re.sub(r"[^a-zA-Z0-9.\-_]", "_", text)

    def qualified(klass):
        return sanitize(f"{klass.__module__}.{klass.__qualname__}")

    ambiguous = set()
    by_name: dict[str, Any] = {}

    for klass in component_types:
        short = sanitize(_get_class_name(klass))

        if short in by_name:
            # First clash on this short name: move the earlier holder to its
            # fully qualified name and remember the short name as ambiguous.
            ambiguous.add(short)
            earlier = by_name.pop(short)
            by_name[qualified(earlier)] = earlier

        key = qualified(klass) if short in ambiguous else short
        by_name[key] = klass

    # Callers look names up by class, so invert the mapping.
    inverted = {}
    for key, klass in by_name.items():
        inverted[klass] = key
    return inverted
|
|
|
|
|
|
class _SchemaGenerator:
    """Translate inspected type objects (``mi.Type``) into JSON Schema dicts.

    Classes present in ``name_map`` are emitted as ``"$ref"`` references
    (formatted through ``ref_template``) instead of being inlined, unless
    ``to_schema`` is called with ``check_ref=False``.
    """

    def __init__(
        self,
        name_map: dict[Any, str],
        schema_hook: Optional[Callable[[type], dict[str, Any]]] = None,
        ref_template: str = "#/$defs/{name}",
    ):
        """Store the generator configuration.

        Parameters
        ----------
        name_map : dict
            Mapping of class to the component name used when referencing it.
        schema_hook : callable, optional
            Callback invoked with a custom type's class; expected to return
            the JSON schema dict for that type.
        ref_template : str, optional
            Template for ``"$ref"`` values, filled in via
            ``ref_template.format(name=name)``.
        """
        self.name_map = name_map
        self.schema_hook = schema_hook
        self.ref_template = ref_template

    def to_schema(self, t: mi.Type, check_ref: bool = True) -> dict[str, Any]:
        """Converts a Type to a json-schema.

        Parameters
        ----------
        t : mi.Type
            The inspected type to convert. Any ``mi.Metadata`` wrappers are
            unwrapped first and their ``extra_json_schema`` merged into the
            result.
        check_ref : bool, optional
            If true (the default) and ``t`` has a ``cls`` registered in
            ``name_map``, only a ``"$ref"`` (plus any merged metadata) is
            returned. Pass false to get the full definition of such a type.

        Returns
        -------
        schema : dict
            The JSON Schema for ``t``.

        Raises
        ------
        TypeError
            For ``mi.ExtType``, for custom types that end up with an empty
            schema (no usable ``schema_hook`` result and no extra schema
            metadata), and for any unrecognized type object.
        """
        schema: dict[str, Any] = {}

        # Peel off metadata wrappers, accumulating their extra schema.
        while isinstance(t, mi.Metadata):
            schema = mi._merge_json(schema, t.extra_json_schema)
            t = t.type

        # Named components are referenced rather than inlined.
        if check_ref and hasattr(t, "cls"):
            if name := self.name_map.get(t.cls):
                schema["$ref"] = self.ref_template.format(name=name)
                return schema

        if isinstance(t, (mi.AnyType, mi.RawType)):
            # No constraints: the schema stays as-is (possibly empty).
            pass
        elif isinstance(t, mi.NoneType):
            schema["type"] = "null"
        elif isinstance(t, mi.BoolType):
            schema["type"] = "boolean"
        elif isinstance(t, (mi.IntType, mi.FloatType)):
            schema["type"] = "integer" if isinstance(t, mi.IntType) else "number"
            if t.ge is not None:
                schema["minimum"] = t.ge
            if t.gt is not None:
                schema["exclusiveMinimum"] = t.gt
            if t.le is not None:
                schema["maximum"] = t.le
            if t.lt is not None:
                schema["exclusiveMaximum"] = t.lt
            if t.multiple_of is not None:
                schema["multipleOf"] = t.multiple_of
        elif isinstance(t, mi.StrType):
            schema["type"] = "string"
            if t.max_length is not None:
                schema["maxLength"] = t.max_length
            if t.min_length is not None:
                schema["minLength"] = t.min_length
            if t.pattern is not None:
                schema["pattern"] = t.pattern
        elif isinstance(t, (mi.BytesType, mi.ByteArrayType, mi.MemoryViewType)):
            schema["type"] = "string"
            schema["contentEncoding"] = "base64"
            # Length limits are given in bytes; convert them to the length
            # of the padded base64 text (4 characters per 3-byte group).
            if t.max_length is not None:
                schema["maxLength"] = 4 * ((t.max_length + 2) // 3)
            if t.min_length is not None:
                schema["minLength"] = 4 * ((t.min_length + 2) // 3)
        elif isinstance(t, mi.DateTimeType):
            schema["type"] = "string"
            # A format is only emitted when tz is exactly True.
            if t.tz is True:
                schema["format"] = "date-time"
        elif isinstance(t, mi.TimeType):
            schema["type"] = "string"
            # tz values other than True/False produce no format at all.
            if t.tz is True:
                schema["format"] = "time"
            elif t.tz is False:
                schema["format"] = "partial-time"
        elif isinstance(t, mi.DateType):
            schema["type"] = "string"
            schema["format"] = "date"
        elif isinstance(t, mi.TimeDeltaType):
            schema["type"] = "string"
            schema["format"] = "duration"
        elif isinstance(t, mi.UUIDType):
            schema["type"] = "string"
            schema["format"] = "uuid"
        elif isinstance(t, mi.DecimalType):
            schema["type"] = "string"
            schema["format"] = "decimal"
        elif isinstance(t, mi.CollectionType):
            schema["type"] = "array"
            if not isinstance(t.item_type, mi.AnyType):
                schema["items"] = self.to_schema(t.item_type)
            if t.max_length is not None:
                schema["maxItems"] = t.max_length
            if t.min_length is not None:
                schema["minItems"] = t.min_length
        elif isinstance(t, mi.TupleType):
            schema["type"] = "array"
            schema["minItems"] = schema["maxItems"] = len(t.item_types)
            if t.item_types:
                schema["prefixItems"] = [self.to_schema(i) for i in t.item_types]
                schema["items"] = False
        elif isinstance(t, mi.DictType):
            schema["type"] = "object"
            # If there are restrictions on the keys, specify them as propertyNames
            if isinstance(key_type := t.key_type, mi.StrType):
                property_names: dict[str, Any] = {}
                if key_type.min_length is not None:
                    property_names["minLength"] = key_type.min_length
                if key_type.max_length is not None:
                    property_names["maxLength"] = key_type.max_length
                if key_type.pattern is not None:
                    property_names["pattern"] = key_type.pattern
                if property_names:
                    schema["propertyNames"] = property_names
            if not isinstance(t.value_type, mi.AnyType):
                schema["additionalProperties"] = self.to_schema(t.value_type)
            if t.max_length is not None:
                schema["maxProperties"] = t.max_length
            if t.min_length is not None:
                schema["minProperties"] = t.min_length
        elif isinstance(t, mi.UnionType):
            # Split members into object-like structs (keyed by tag) and
            # everything else. Metadata is looked through only to classify
            # a member.
            structs = {}
            other = []
            tag_field = None
            for subtype in t.types:
                real_type = subtype
                while isinstance(real_type, mi.Metadata):
                    real_type = real_type.type
                if isinstance(real_type, mi.StructType) and not real_type.array_like:
                    tag_field = real_type.tag_field
                    structs[real_type.tag] = real_type
                else:
                    other.append(subtype)

            options = [self.to_schema(a) for a in other]

            if len(structs) >= 2:
                # Several structs: describe them as a discriminated union.
                mapping = {
                    k: self.ref_template.format(name=self.name_map[v.cls])
                    for k, v in structs.items()
                }
                struct_schema = {
                    "anyOf": [self.to_schema(v) for v in structs.values()],
                    "discriminator": {"propertyName": tag_field, "mapping": mapping},
                }
                if options:
                    options.append(struct_schema)
                    schema["anyOf"] = options
                else:
                    schema.update(struct_schema)
            elif len(structs) == 1:
                _, subtype = structs.popitem()
                options.append(self.to_schema(subtype))
                schema["anyOf"] = options
            else:
                schema["anyOf"] = options
        elif isinstance(t, mi.LiteralType):
            schema["enum"] = sorted(t.values)
        elif isinstance(t, mi.EnumType):
            # setdefault: values merged from metadata take precedence.
            schema.setdefault("title", t.cls.__name__)
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            schema["enum"] = sorted(e.value for e in t.cls)
        elif isinstance(t, mi.StructType):
            schema.setdefault("title", _get_class_name(t.cls))
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            required = []
            names = []
            fields = []

            # A tagged struct carries its tag as a leading, required field
            # whose only permitted value is the tag itself.
            if t.tag_field is not None:
                required.append(t.tag_field)
                names.append(t.tag_field)
                fields.append({"enum": [t.tag]})

            for field in t.fields:
                field_schema = self.to_schema(field.type)
                if field.required:
                    required.append(field.encode_name)
                elif field.default is not mi.NODEFAULT:
                    field_schema["default"] = to_builtins(field.default, str_keys=True)
                elif field.default_factory in (list, dict, set, bytearray):
                    # Only these known factories are called to produce a
                    # default value for the schema.
                    field_schema["default"] = field.default_factory()
                names.append(field.encode_name)
                fields.append(field_schema)

            if t.array_like:
                # Count the run of non-required fields at the end; those
                # positions are excluded from minItems. An explicit counter
                # is used so that a struct with no required fields counts
                # all of them (an enumerate-based loop that never breaks
                # stops one short, making minItems one too large).
                n_trailing_defaults = 0
                for f in reversed(t.fields):
                    if f.required:
                        break
                    n_trailing_defaults += 1
                schema["type"] = "array"
                schema["prefixItems"] = fields
                schema["minItems"] = len(fields) - n_trailing_defaults
                if t.forbid_unknown_fields:
                    schema["maxItems"] = len(fields)
            else:
                schema["type"] = "object"
                schema["properties"] = dict(zip(names, fields))
                schema["required"] = required
                if t.forbid_unknown_fields:
                    schema["additionalProperties"] = False
        elif isinstance(t, (mi.TypedDictType, mi.DataclassType, mi.NamedTupleType)):
            schema.setdefault("title", _get_class_name(t.cls))
            if doc := _get_doc(t):
                schema.setdefault("description", doc)
            names = []
            fields = []
            required = []
            for field in t.fields:
                field_schema = self.to_schema(field.type)
                if field.required:
                    required.append(field.encode_name)
                elif field.default is not mi.NODEFAULT:
                    field_schema["default"] = to_builtins(field.default, str_keys=True)
                names.append(field.encode_name)
                fields.append(field_schema)
            if isinstance(t, mi.NamedTupleType):
                # Named tuples are positional: at least the required fields,
                # at most every field.
                schema["type"] = "array"
                schema["prefixItems"] = fields
                schema["minItems"] = len(required)
                schema["maxItems"] = len(fields)
            else:
                schema["type"] = "object"
                schema["properties"] = dict(zip(names, fields))
                schema["required"] = required
        elif isinstance(t, mi.ExtType):
            raise TypeError("json-schema doesn't support msgpack Ext types")
        elif isinstance(t, mi.CustomType):
            if self.schema_hook:
                try:
                    schema = mi._merge_json(self.schema_hook(t.cls), schema)
                except NotImplementedError:
                    # The hook declined this type; fall back to whatever the
                    # metadata supplied, if anything.
                    pass
            if not schema:
                raise TypeError(
                    "Generating JSON schema for custom types requires either:\n"
                    "- specifying a `schema_hook`\n"
                    "- annotating the type with `Meta(extra_json_schema=...)`\n"
                    "\n"
                    f"type {t.cls!r} is not supported"
                )
        else:
            # This should be unreachable
            raise TypeError(f"json-schema doesn't support type {t!r}")

        return schema
|