Warning
This document is for an old release of Galaxy. You can alternatively view this page in the latest release if it exists or view the top of the latest release's documentation.
Source code for galaxy.webapps.openapi.utils
"""
Allows merging overlapping openapi definitions.
Taken from https://github.com/tiangolo/fastapi/pull/4727
"""
import http.client
import inspect
import itertools
import warnings
from enum import Enum
from typing import (
Any,
Callable,
cast,
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Union,
)
from fastapi import routing
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import (
get_flat_dependant,
get_flat_params,
)
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import (
METHODS_WITH_BODY,
REF_PREFIX,
)
from fastapi.openapi.models import OpenAPI
from fastapi.params import (
Body,
Param,
)
from fastapi.responses import Response
from fastapi.utils import (
deep_dict_update,
generate_operation_id_for_path,
get_model_definitions,
is_body_allowed_for_status_code,
)
from pydantic import BaseModel
from pydantic.fields import (
ModelField,
Undefined,
)
from pydantic.schema import (
field_schema,
get_flat_models_from_fields,
get_model_name_map,
)
from pydantic.utils import lenient_issubclass
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
validation_error_definition = {
"title": "ValidationError",
"type": "object",
"properties": {
"loc": {"title": "Location", "type": "array", "items": {"type": "string"}},
"msg": {"title": "Message", "type": "string"},
"type": {"title": "Error Type", "type": "string"},
},
"required": ["loc", "msg", "type"],
}
validation_error_response_definition = {
"title": "HTTPValidationError",
"type": "object",
"properties": {
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
}
},
}
status_code_ranges: Dict[str, str] = {
"1XX": "Information",
"2XX": "Success",
"3XX": "Redirection",
"4XX": "Client Error",
"5XX": "Server Error",
"DEFAULT": "Default Response",
}
[docs]def get_openapi_security_definitions(
flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
security_definitions = {}
operation_security = []
for security_requirement in flat_dependant.security_requirements:
security_definition = jsonable_encoder(
security_requirement.security_scheme.model,
by_alias=True,
exclude_none=True,
)
security_name = security_requirement.security_scheme.scheme_name
security_definitions[security_name] = security_definition
operation_security.append({security_name: security_requirement.scopes})
return security_definitions, operation_security
[docs]def get_openapi_operation_parameters(
*,
all_route_params: Sequence[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> List[Dict[str, Any]]:
parameters = []
for param in all_route_params:
field_info = param.field_info
field_info = cast(Param, field_info)
if not field_info.include_in_schema:
continue
parameter = {
"name": param.alias,
"in": field_info.in_.value,
"required": param.required,
"schema": field_schema(param, model_name_map=model_name_map, ref_prefix=REF_PREFIX)[0],
}
if field_info.description:
parameter["description"] = field_info.description
if field_info.examples:
parameter["examples"] = jsonable_encoder(field_info.examples)
elif field_info.example != Undefined:
parameter["example"] = jsonable_encoder(field_info.example)
if field_info.deprecated:
parameter["deprecated"] = field_info.deprecated
parameters.append(parameter)
return parameters
[docs]def get_openapi_operation_request_body(
*,
body_field: Optional[ModelField],
model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Optional[Dict[str, Any]]:
if not body_field:
return None
assert isinstance(body_field, ModelField)
body_schema, _, _ = field_schema(body_field, model_name_map=model_name_map, ref_prefix=REF_PREFIX)
field_info = cast(Body, body_field.field_info)
request_media_type = field_info.media_type
required = body_field.required
request_body_oai: Dict[str, Any] = {}
if required:
request_body_oai["required"] = required
request_media_content: Dict[str, Any] = {"schema": body_schema}
if field_info.examples:
request_media_content["examples"] = jsonable_encoder(field_info.examples)
elif field_info.example != Undefined:
request_media_content["example"] = jsonable_encoder(field_info.example)
request_body_oai["content"] = {request_media_type: request_media_content}
return request_body_oai
[docs]def generate_operation_id(*, route: routing.APIRoute, method: str) -> str: # pragma: nocover
warnings.warn(
"fastapi.openapi.utils.generate_operation_id() was deprecated, "
"it is not used internally, and will be removed soon",
DeprecationWarning,
stacklevel=2,
)
if route.operation_id:
return route.operation_id
path: str = route.path_format
return generate_operation_id_for_path(name=route.name, path=path, method=method)
[docs]def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
if route.summary:
return route.summary
return route.name.replace("_", " ").title()
[docs]def get_openapi_operation_metadata(*, route: routing.APIRoute, method: str, operation_ids: Set[str]) -> Dict[str, Any]:
operation: Dict[str, Any] = {}
if route.tags:
operation["tags"] = route.tags
operation["summary"] = generate_operation_summary(route=route, method=method)
if route.description:
operation["description"] = route.description
operation_id = route.operation_id or route.unique_id
if operation_id in operation_ids:
message = f"Duplicate Operation ID {operation_id} for function " + f"{route.endpoint.__name__}"
file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
if file_name:
message += f" at {file_name}"
warnings.warn(message)
operation_ids.add(operation_id)
operation["operationId"] = operation_id
if route.deprecated:
operation["deprecated"] = route.deprecated
return operation
[docs]def get_openapi_path(
*, route: routing.APIRoute, model_name_map: Dict[type, str], operation_ids: Set[str]
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
path = {}
security_schemes: Dict[str, Any] = {}
definitions: Dict[str, Any] = {}
assert route.methods is not None, "Methods must be a list"
if isinstance(route.response_class, DefaultPlaceholder):
current_response_class: Type[Response] = route.response_class.value
else:
current_response_class = route.response_class
assert current_response_class, "A response class is needed to generate OpenAPI"
route_response_media_type: Optional[str] = current_response_class.media_type
if route.include_in_schema:
for method in route.methods:
operation = get_openapi_operation_metadata(route=route, method=method, operation_ids=operation_ids)
parameters: List[Dict[str, Any]] = []
flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
security_definitions, operation_security = get_openapi_security_definitions(flat_dependant=flat_dependant)
if operation_security:
operation.setdefault("security", []).extend(operation_security)
if security_definitions:
security_schemes.update(security_definitions)
all_route_params = get_flat_params(route.dependant)
operation_parameters = get_openapi_operation_parameters(
all_route_params=all_route_params, model_name_map=model_name_map
)
parameters.extend(operation_parameters)
if parameters:
operation["parameters"] = list({param["name"]: param for param in parameters}.values())
if method in METHODS_WITH_BODY:
request_body_oai = get_openapi_operation_request_body(
body_field=route.body_field, model_name_map=model_name_map
)
if request_body_oai:
operation["requestBody"] = request_body_oai
if route.callbacks:
callbacks = {}
for callback in route.callbacks:
if isinstance(callback, routing.APIRoute):
(cb_path, cb_security_schemes, cb_definitions,) = get_openapi_path(
route=callback,
model_name_map=model_name_map,
operation_ids=operation_ids,
)
callbacks[callback.name] = {callback.path: cb_path}
operation["callbacks"] = callbacks
if route.status_code is not None:
status_code = str(route.status_code)
else:
# It would probably make more sense for all response classes to have an
# explicit default status_code, and to extract it from them, instead of
# doing this inspection tricks, that would probably be in the future
# TODO: probably make status_code a default class attribute for all
# responses in Starlette
response_signature = inspect.signature(current_response_class.__init__)
status_code_param = response_signature.parameters.get("status_code")
if status_code_param is not None:
if isinstance(status_code_param.default, int):
status_code = str(status_code_param.default)
operation.setdefault("responses", {}).setdefault(status_code, {})[
"description"
] = route.response_description
if route_response_media_type and is_body_allowed_for_status_code(route.status_code):
response_schema = {"type": "string"}
if lenient_issubclass(current_response_class, JSONResponse):
if route.response_field:
response_schema, _, _ = field_schema(
route.response_field,
model_name_map=model_name_map,
ref_prefix=REF_PREFIX,
)
else:
response_schema = {}
operation.setdefault("responses", {}).setdefault(status_code, {}).setdefault("content", {}).setdefault(
route_response_media_type, {}
)["schema"] = response_schema
if route.responses:
operation_responses = operation.setdefault("responses", {})
for (
additional_status_code,
additional_response,
) in route.responses.items():
process_response = additional_response.copy()
process_response.pop("model", None)
status_code_key = str(additional_status_code).upper()
if status_code_key == "DEFAULT":
status_code_key = "default"
openapi_response = operation_responses.setdefault(status_code_key, {})
assert isinstance(process_response, dict), "An additional response must be a dict"
field = route.response_fields.get(additional_status_code)
additional_field_schema: Optional[Dict[str, Any]] = None
if field:
additional_field_schema, _, _ = field_schema(
field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
)
media_type = route_response_media_type or "application/json"
additional_schema = (
process_response.setdefault("content", {})
.setdefault(media_type, {})
.setdefault("schema", {})
)
deep_dict_update(additional_schema, additional_field_schema)
status_text: Optional[str] = status_code_ranges.get(
str(additional_status_code).upper()
) or http.client.responses.get(int(additional_status_code))
description = (
process_response.get("description")
or openapi_response.get("description")
or status_text
or "Additional Response"
)
deep_dict_update(openapi_response, process_response)
openapi_response["description"] = description
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
if (all_route_params or route.body_field) and not any(
[status in operation["responses"] for status in [http422, "4XX", "default"]]
):
operation["responses"][http422] = {
"description": "Validation Error",
"content": {"application/json": {"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}}},
}
if "ValidationError" not in definitions:
definitions.update(
{
"ValidationError": validation_error_definition,
"HTTPValidationError": validation_error_response_definition,
}
)
if route.openapi_extra:
deep_dict_update(operation, route.openapi_extra)
path[method.lower()] = operation
return path, security_schemes, definitions
[docs]def get_flat_models_from_routes(
routes: Sequence[BaseRoute],
) -> Set[Union[Type[BaseModel], Type[Enum]]]:
body_fields_from_routes: List[ModelField] = []
responses_from_routes: List[ModelField] = []
request_fields_from_routes: List[ModelField] = []
callback_flat_models: Set[Union[Type[BaseModel], Type[Enum]]] = set()
for route in routes:
if getattr(route, "include_in_schema", None) and isinstance(route, routing.APIRoute):
if route.body_field:
assert isinstance(route.body_field, ModelField), "A request body must be a Pydantic Field"
body_fields_from_routes.append(route.body_field)
if route.response_field:
responses_from_routes.append(route.response_field)
if route.response_fields:
responses_from_routes.extend(route.response_fields.values())
if route.callbacks:
callback_flat_models |= get_flat_models_from_routes(route.callbacks)
params = get_flat_params(route.dependant)
request_fields_from_routes.extend(params)
flat_models = callback_flat_models | get_flat_models_from_fields(
body_fields_from_routes + responses_from_routes + request_fields_from_routes,
known_models=set(),
)
return flat_models
[docs]def get_openapi(
*,
title: str,
version: str,
openapi_version: str = "3.0.2",
description: Optional[str] = None,
routes: Sequence[BaseRoute],
tags: Optional[List[Dict[str, Any]]] = None,
servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
terms_of_service: Optional[str] = None,
contact: Optional[Dict[str, Union[str, Any]]] = None,
license_info: Optional[Dict[str, Union[str, Any]]] = None,
) -> Dict[str, Any]:
info: Dict[str, Any] = {"title": title, "version": version}
if description:
info["description"] = description
if terms_of_service:
info["termsOfService"] = terms_of_service
if contact:
info["contact"] = contact
if license_info:
info["license"] = license_info
output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
if servers:
output["servers"] = servers
components: Dict[str, Dict[str, Any]] = {}
paths: Dict[str, Dict[str, Any]] = {}
operation_ids: Set[str] = set()
flat_models = get_flat_models_from_routes(routes)
model_name_map = get_model_name_map(flat_models)
definitions = get_model_definitions(flat_models=flat_models, model_name_map=model_name_map)
for route in routes:
if isinstance(route, routing.APIRoute):
result = get_openapi_path(route=route, model_name_map=model_name_map, operation_ids=operation_ids)
if result:
path, security_schemes, path_definitions = result
if path:
existing_path = paths.get(route.path_format)
paths[route.path_format] = merge_paths(existing_path, path) if existing_path else path
if security_schemes:
components.setdefault("securitySchemes", {}).update(security_schemes)
if path_definitions:
definitions.update(path_definitions)
if definitions:
components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
if components:
output["components"] = components
output["paths"] = paths
if tags:
output["tags"] = tags
return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True) # type: ignore
[docs]def merge_paths(existing_path: Dict[str, Any], new_path: Dict[str, Any]) -> Dict[str, Any]:
"""Merge two openapi path descriptions for the same route, e.g. for response content with different accept-encoding."""
path_by_operation_id: Dict[str, Dict[str, Any]] = {}
for method, operation in itertools.chain(existing_path.items(), new_path.items()):
path_by_operation_id.setdefault(operation["operationId"], {})[method] = operation
return deep_dict_operation_merge(path_by_operation_id) or {}
[docs]def merge_tags(tags_by_id: Dict[str, Optional[List[str]]]) -> Optional[List[str]]:
tags: Optional[List[str]] = None
for update_tags in tags_by_id.values():
if update_tags:
tags = tags or []
tags.extend(t for t in update_tags if t not in tags)
return tags
[docs]def merge_summary(summary_by_id: Dict[str, Optional[str]]) -> Optional[str]:
summary: Optional[str] = None
for update_summary in summary_by_id.values():
if update_summary and summary != update_summary:
summary = f"{summary} / {update_summary}" if summary else update_summary
return summary
[docs]def merge_description(desc_by_id: Dict[str, Optional[str]]) -> Optional[str]:
desc: Optional[str] = None
for update_desc in desc_by_id.values():
if update_desc and desc != update_desc:
desc = f"{desc}\n\n OR \n\n {update_desc}" if desc else update_desc
return desc
[docs]def merge_operation_id(operation_id_by_id: Dict[str, Optional[str]]) -> Optional[str]:
operation_id: Optional[str] = None
for update_operation_id in operation_id_by_id.values():
if update_operation_id and operation_id != update_operation_id:
if operation_id:
message = f"Merging operation with id {operation_id} with operation with id {update_operation_id}."
warnings.warn(message)
operation_id = f"{operation_id}+{update_operation_id}"
else:
operation_id = update_operation_id
return operation_id
[docs]def merge_deprecated(deprecated_by_id: Dict[str, Optional[bool]]) -> Optional[bool]:
if any(deprecated_by_id.values()):
return True
return None
[docs]def merge_parameters(params_by_id: Dict[str, Optional[List[Dict[str, Any]]]]) -> Optional[List[Dict[str, Any]]]:
with_params = [(param_id, params) for (param_id, params) in params_by_id.items() if params]
if not with_params:
return None
if len(with_params) == 1:
return with_params[0][1]
param_by_id_by_name: Dict[str, Any] = {}
for (param_id, params) in with_params:
for param in params:
param_by_id_by_name.setdefault(param["name"], {})[param_id] = param
return [deep_dict_operation_merge(param_by_id) for param_by_id in param_by_id_by_name.values()]
[docs]def deep_dict_operation_merge(dict_by_id: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
if not dict_by_id:
return dict_by_id
if len(dict_by_id) == 1:
return next(iter(dict_by_id.values()))
keys = {k for d in dict_by_id.values() for k in d}
result_dict: Optional[Dict[Any, Any]] = None
for key in keys:
merge_method = MERGE_METHODS_BY_KEY.get(key)
if not merge_method and _all_dict_or_none_at_key(dict_by_id, key):
merge_method = deep_dict_operation_merge
result_dict = result_dict or {}
dict_at_key_by_id = {d_id: d[key] for d_id, d in dict_by_id.items() if key in d}
if merge_method:
result_dict[key] = merge_method(dict_at_key_by_id)
else:
# use last entry at key for each
result_dict[key] = next(reversed(list(dict_at_key_by_id.values())), None)
return {k: v for k, v in result_dict.items() if v is not None} if result_dict else {}
def _all_dict_or_none_at_key(dict_by_id: Dict[str, Dict[str, Any]], key: str) -> bool:
return all(isinstance(d[key], dict) for d in dict_by_id.values() if key in d if d[key] is not None)
MERGE_METHODS_BY_KEY: Dict[str, Callable[[Dict[str, Any]], Any]] = {
"tags": merge_tags,
"summary": merge_summary,
"description": merge_description,
"operationId": merge_operation_id,
"parameters": merge_parameters,
"deprecated": merge_deprecated,
"content": deep_dict_operation_merge,
"requestBody": deep_dict_operation_merge,
"callbacks": deep_dict_operation_merge,
}