"""Example Flask application with RESTful API using flask_restx_marshmallow."""
# ruff: noqa: PLR6301
from collections.abc import Iterable
from dataclasses import dataclass, field
from http import HTTPStatus
from typing import TYPE_CHECKING, Self, Unpack
from uuid import UUID, uuid4

from flask.app import Flask
from flask.testing import FlaskClient
from flask_sqlalchemy.extension import SQLAlchemy as BaseSQLAlchemy
from marshmallow.fields import Integer, Nested
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm.base import Mapped
from sqlalchemy_utils.types.json import JSONType

from flask_restx_marshmallow import Resource
from flask_restx_marshmallow.api import Api
from flask_restx_marshmallow.namespace import Namespace
from flask_restx_marshmallow.schema import Schema
from flask_restx_marshmallow.type import (
    Result,
    SchemaInitKwargs,
    Success,
    TupleResponse,
    Warn,
)
from flask_restx_marshmallow.util import class_schema, make_default_schema
try:
from loguru import logger
except ImportError:
from logging import getLogger
logger = getLogger(__name__)
if TYPE_CHECKING:
    # Typing-only declarations: this branch is skipped at runtime, so none of
    # these names exist when the module is executed.
    from typing import Self

    from flask_sqlalchemy.extension import _FSAModel
    from flask_sqlalchemy.query import Query
    from sqlalchemy.orm.query import Query as _Query

    class Model(_FSAModel):
        """Declarative base seen by type checkers.

        ``query`` is annotated with ``Self`` so that ``SomeModel.query`` is
        typed as a query over ``SomeModel`` rather than the generic base.
        """

        query: _Query[Self] | Query

    class SQLAlchemy(BaseSQLAlchemy):
        """Extension type seen by type checkers.

        Re-annotates ``Model`` with the ``Self``-aware base declared above.
        """

        Model: type[Model]
@dataclass(kw_only=True)
class CreateTaskParameters:
"""Parameters for creating a task."""
name: str = field(
metadata={"metadata": {"description": "task name"}},
)
settings: dict | None = field(
default=None,
metadata={"metadata": {"description": "task settings"}},
)
@dataclass(kw_only=True)
class UpdateTaskParameters:
"""Parameters for updating a task."""
task_id: UUID = field(
metadata={"metadata": {"description": "task id", "location": "path"}},
)
name: str | None = field(
default=None,
metadata={"metadata": {"description": "task name"}},
)
settings: dict | None = field(
default=None,
metadata={"metadata": {"description": "task settings"}},
)
@dataclass(kw_only=True)
class DeleteTaskParameters:
"""Parameters for deleting a task."""
task_id: UUID = field(
metadata={"metadata": {"description": "task id", "location": "path"}},
)
@dataclass(kw_only=True)
class QueryTasksParameters:
"""Parameters for querying a task."""
name: str | None = field(
default=None,
metadata={"metadata": {"description": "task name"}},
)
@dataclass(kw_only=True)
class QueryTaskParameters:
"""Parameters for querying a task."""
task_id: "UUID | None" = field(
default=None,
metadata={"metadata": {"description": "task id", "location": "path"}},
)
name: str | None = field(
default=None,
metadata={"metadata": {"description": "task name"}},
)
@dataclass(kw_only=True)
class TaskProfile:
"""Task profile schema."""
task_id: UUID = field(
metadata={"metadata": {"description": "task unique identifier"}},
)
name: str = field(metadata={"metadata": {"description": "task name"}})
settings: dict | None = field(
default=None,
metadata={"metadata": {"description": "task settings"}},
)
class TaskProfileSchema(class_schema(TaskProfile, base=Schema)):
    """Marshmallow schema generated from the ``TaskProfile`` dataclass.

    The base class is produced at import time by ``class_schema`` with
    ``Schema`` as its base; this subclass only gives it a stable public name.
    Presumably the per-field ``metadata`` of ``TaskProfile`` ends up in the
    generated field docs — TODO confirm against ``class_schema``.
    """
class StandardSchema(make_default_schema(name="StandardResponse")):
    """Default response envelope ("StandardResponse") with a presettable message."""

    def __init__(
        self,
        message: str | None = None,
        header_fields: Iterable[str] | None = None,
        **kwargs: Unpack[SchemaInitKwargs],
    ) -> None:
        """Build the schema and optionally preset its ``message`` field.

        Args:
            message (str | None): when given, used as the ``message`` field's
                dump default and as its documented ``example``. Defaults to None.
            header_fields (Iterable[str] | None): forwarded to the base schema.
                Defaults to None.
            kwargs (Unpack[SchemaInitKwargs]): remaining schema init options,
                forwarded unchanged.
        """
        super().__init__(header_fields=header_fields, **kwargs)
        if message is None:
            return
        message_field = self.fields["message"]
        message_field.dump_default = message
        # Work on a copy so the original metadata mapping (possibly shared with
        # other field objects) is not mutated.
        patched_metadata = dict(**message_field.metadata)
        patched_metadata["example"] = message
        message_field.metadata = patched_metadata
class TaskSchema(StandardSchema):
    """Response schema for task queries and creation.

    Extends the standard envelope with ``result``, an inline nested schema
    holding ``items`` (a list of ``TaskProfileSchema``) and ``total`` (an int).
    """

    result = Nested(
        {
            "items": Nested(
                TaskProfileSchema,
                many=True,
                metadata={"description": "task list"},
            ),
            "total": Integer(
                metadata={"description": "total number of tasks"},
            ),
        },
        metadata={"description": "query result"},
    )
if TYPE_CHECKING:
    # Type checkers see the annotated ``SQLAlchemy`` subclass declared in the
    # TYPE_CHECKING block above.
    db = SQLAlchemy()
else:
    # ``expire_on_commit=False`` keeps loaded attribute values on instances
    # after ``commit()`` — presumably so freshly committed tasks can be
    # serialized without being reloaded.
    db = BaseSQLAlchemy(session_options={"expire_on_commit": False})
class Task(db.Model):
    """Task model for managing tasks in the database.

    Columns: ``task_id`` (UUID primary key), ``name`` (unique) and ``settings``
    (JSON, nullable). The classmethods implement the CRUD operations used by
    the REST resources and return ready-to-send ``TupleResponse`` objects.
    """

    if TYPE_CHECKING:

        def __init__(
            self,
            *,
            task_id: UUID,
            name: str,
            settings: dict | None = None,
        ) -> None:
            """Add init annotations for type checking.

            Args:
                task_id (UUID): The unique identifier for the task.
                name (str): The name of the task.
                settings (dict, optional): The settings for the task.
                    Defaults to None.
            """

    __tablename__ = "task"
    task_id: Mapped[UUID] = mapped_column(primary_key=True, comment="task id")
    name: Mapped[str] = mapped_column(unique=True, comment="task name")
    settings: Mapped[dict | None] = mapped_column(JSONType, comment="task settings")

    @staticmethod
    def _warn_response(message: str, detail: object) -> TupleResponse:
        """Log a warning and build the 400 ``Warn`` response used on failures.

        Args:
            message (str): client-facing message, also used as the log prefix.
            detail (object): value appended to the log line (id, name or params).

        Returns:
            TupleResponse: ``Warn(message=message)`` with status 400.
        """
        # Pre-format the log line: loguru (preferred at import time) formats
        # with ``str.format`` and silently drops ``%s`` arguments, while the
        # stdlib fallback uses ``%``-formatting. A plain string with no extra
        # arguments is logged verbatim by both.
        logger.warning(f"{message}: {detail}")  # noqa: G004
        return TupleResponse(
            Warn(message=message),
            status_code=HTTPStatus.BAD_REQUEST,
        )

    @classmethod
    def create(
        cls,
        create_task_info: "CreateTaskParameters",
    ) -> TupleResponse["Self"] | TupleResponse:
        """Create a new task instance with a freshly generated UUID.

        Args:
            create_task_info (CreateTaskParameters): The parameters for creating.

        Returns:
            TupleResponse: ``Success`` with the created task, or a 400 ``Warn``
            ("task already exists") when the commit violates a constraint,
            e.g. a duplicate ``name``.
        """
        task_ = cls(
            task_id=uuid4(),
            name=create_task_info.name,
            settings=create_task_info.settings,
        )
        db.session.add(task_)
        try:
            db.session.commit()
        except IntegrityError:
            # Rollback discards the pending instance and resets the session so
            # later requests are not stuck in a failed transaction. Other
            # database errors are real failures and are left to propagate.
            db.session.rollback()
            return cls._warn_response("task already exists", create_task_info.name)
        return TupleResponse(
            Success(
                message="create task successfully",
                result=Result(items=[task_], total=1),
            ),
        )

    @classmethod
    def delete(
        cls,
        task_id: UUID,
    ) -> TupleResponse:
        """Delete a task instance by its ID.

        Args:
            task_id (UUID): The unique identifier for the task.

        Returns:
            TupleResponse: ``Success`` when the task was deleted, or a 400
            ``Warn`` ("task not found").
        """
        task = cls.query.filter_by(task_id=task_id).one_or_none()
        if task is None:
            return cls._warn_response("task not found", task_id)
        db.session.delete(task)
        # Without an explicit commit the deletion is rolled back when
        # Flask-SQLAlchemy removes the session at app-context teardown.
        db.session.commit()
        return TupleResponse(Success(message="delete task successfully"))

    @classmethod
    def update(cls, task_update_info: "UpdateTaskParameters") -> TupleResponse:
        """Update a task instance by its ID.

        Only fields that are not ``None`` in ``task_update_info`` are changed.

        Args:
            task_update_info (UpdateTaskParameters): The parameters for updating
                the task.

        Returns:
            TupleResponse: ``Success`` on update, a 400 ``Warn`` ("task not
            found") for an unknown id, or a 400 ``Warn`` ("task already
            exists") when the new values violate a constraint such as the
            unique ``name``.
        """
        task = cls.query.filter_by(task_id=task_update_info.task_id).one_or_none()
        if task is None:
            return cls._warn_response("task not found", task_update_info.task_id)
        if task_update_info.name is not None:
            task.name = task_update_info.name
        if task_update_info.settings is not None:
            task.settings = task_update_info.settings
        try:
            db.session.commit()
        except IntegrityError:
            # Reset the session instead of leaving it in a failed transaction.
            db.session.rollback()
            return cls._warn_response("task already exists", task_update_info.name)
        return TupleResponse(Success(message="update task successfully"))

    @classmethod
    def query_task(cls, params: "QueryTaskParameters") -> TupleResponse["Self"]:
        """Query a single task instance by its ID and/or name.

        Every criterion that is not ``None`` must match. ``task_id`` is the
        primary key and ``name`` is unique, so at most one row can match.

        Args:
            params (QueryTaskParameters): The parameters for querying the task.

        Returns:
            TupleResponse: ``Success`` with the task, or a 400 ``Warn`` ("task
            not found") when nothing matches or no criterion is given.
        """
        criteria: dict[str, object] = {}
        if params.task_id is not None:
            criteria["task_id"] = params.task_id
        if params.name is not None:
            criteria["name"] = params.name
        # With no criteria filter_by() would match every row and
        # one_or_none() could raise MultipleResultsFound.
        task = cls.query.filter_by(**criteria).one_or_none() if criteria else None
        if task is None:
            return cls._warn_response("task not found", params)
        return TupleResponse(
            Success(
                message="query task successfully",
                result=Result(items=[task], total=1),
            ),
        )

    @classmethod
    def query_tasks(cls, params: "QueryTasksParameters") -> TupleResponse["Self"]:
        """List tasks, optionally filtered by name.

        The name filter uses ``ILIKE``: the match is case-insensitive, and
        ``%``/``_`` in the given name act as SQL wildcards.

        Args:
            params (QueryTasksParameters): The parameters for querying tasks.

        Returns:
            TupleResponse: ``Success`` with the matching tasks and their count,
            or a 400 ``Warn`` ("task not found") when nothing matches.
        """
        query = cls.query
        if params.name is not None:
            query = query.filter(cls.name.ilike(params.name))
        # A single round trip; the total is derived from the fetched rows, so
        # it cannot disagree with ``items``.
        items = query.all()
        if not items:
            return cls._warn_response("task not found", params)
        return TupleResponse(
            Success(
                message="query task successfully",
                result=Result(items=items, total=len(items)),
            ),
        )
# Namespace "example"; every task route below is mounted under ``/task``.
ns = Namespace(
    "example",
    path="/task",
    description="example namespace for task management",
)
@ns.route("/", endpoint="tasks")
class TasksManage(Resource):
    """Collection endpoint: list tasks (GET) and create a task (POST)."""

    @ns.parameters(params=QueryTasksParameters, location="query")
    # The documented message matches the one actually returned by
    # ``Task.query_tasks`` ("query task successfully").
    @ns.responses(schema=TaskSchema(message="query task successfully"))
    @ns.responses(message="task not found", code=HTTPStatus.BAD_REQUEST)
    def get(self, params: "QueryTasksParameters") -> "TupleResponse[Task]":
        """List tasks, optionally filtered by name.

        Args:
            params (QueryTasksParameters): query-string parameters; ``name``
                is an optional filter.

        Returns:
            TupleResponse[Task]: the matching tasks, or a 400 warning when
            none match.
        """
        return Task.query_tasks(params)

    @ns.parameters(params=CreateTaskParameters, location="body")
    @ns.responses(schema=TaskSchema(message="create task successfully"))
    @ns.responses(message="task already exists", code=HTTPStatus.BAD_REQUEST)
    def post(
        self,
        params: "CreateTaskParameters",
    ) -> "TupleResponse[Task] | TupleResponse":
        """Create a new task from the request body.

        Args:
            params (CreateTaskParameters): The parameters for creating a task.

        Returns:
            TupleResponse[Task]: the created task, or a 400 warning when the
            task cannot be created (e.g. duplicate name).
        """
        return Task.create(params)
@ns.route("/<task_id>", endpoint="task")
class TaskManage(Resource):
    """Single-task endpoint addressed by ``task_id`` in the URL path.

    Flask passes the ``task_id`` URL variable as a keyword argument; each
    handler swallows it with ``**_`` and reads the id from ``params``
    instead (presumably filled from the path through the field's
    ``location`` metadata — TODO confirm with flask_restx_marshmallow).
    """

    @ns.parameters(params=QueryTaskParameters, location="query")
    # The documented message matches the one actually returned by
    # ``Task.query_task`` ("query task successfully").
    @ns.responses(schema=TaskSchema(message="query task successfully"))
    @ns.responses(message="task not found", code=HTTPStatus.BAD_REQUEST)
    def get(self, params: "QueryTaskParameters", **_) -> "TupleResponse[Task]":  # noqa: ANN003
        """Query a task by its ID or name.

        Args:
            params (QueryTaskParameters): The parameters for querying the task.

        Returns:
            TupleResponse[Task]: the task, or a 400 warning when not found.
        """
        return Task.query_task(params)

    @ns.parameters(params=DeleteTaskParameters, location="path")
    @ns.responses(message="delete task successfully")
    @ns.responses(message="task not found", code=HTTPStatus.BAD_REQUEST)
    def delete(self, params: "DeleteTaskParameters", **_) -> "TupleResponse":  # noqa: ANN003
        """Delete a task by its ID.

        Args:
            params (DeleteTaskParameters): The parameters for deleting the task.

        Returns:
            TupleResponse: success, or a 400 warning when the task is unknown.
        """
        return Task.delete(params.task_id)

    @ns.parameters(params=UpdateTaskParameters, location="body")
    @ns.responses(message="update task successfully")
    @ns.responses(message="task not found", code=HTTPStatus.BAD_REQUEST)
    def patch(self, params: "UpdateTaskParameters", **_) -> "TupleResponse":  # noqa: ANN003
        """Update a task by its ID.

        Args:
            params (UpdateTaskParameters): The parameters for updating the task.

        Returns:
            TupleResponse: success, or a 400 warning when the task is unknown.
        """
        return Task.update(params)
api = Api(
    title="example API",
    version="0.2.0",
    description="api interface for example app",
)
api.add_namespace(ns)

app = Flask(__name__, subdomain_matching=True)
# In-memory SQLite: the data lives only as long as the process.
app.config.update(SQLALCHEMY_DATABASE_URI="sqlite:///:memory:")
app.test_client_class = FlaskClient

# Bind both extensions to the application.
api.init_app(app)
db.init_app(app)

if __name__ == "__main__":
    # Tables must exist before serving; ``db`` needs an active app context.
    with app.app_context():
        db.create_all()
    app.run(debug=True)