Source code for galaxy.celery

import os
import uuid
from functools import (
    lru_cache,
    wraps,
)
from multiprocessing import get_context
from threading import local
from typing import (
    Any,
    Callable,
    Dict,
)

import pebble
from celery import (
    Celery,
    shared_task,
    Task,
)
from celery.signals import (
    worker_init,
    worker_shutting_down,
)
from kombu import serialization

from galaxy.celery.base_task import GalaxyTaskBeforeStart
from galaxy.config import Configuration
from galaxy.main_config import find_config
from galaxy.util import ExecutionTimer
from galaxy.util.custom_logging import get_logger
from galaxy.util.properties import load_app_properties
from ._serialization import (
    schema_dumps,
    schema_loads,
)

# Module-level logger, named after this module.
log = get_logger(__name__)

# Dotted path of the module holding Galaxy's task definitions; also used by
# setup_periodic_tasks() to derive the names of the scheduled tasks.
MAIN_TASK_MODULE = "galaxy.celery.tasks"
# Queue name passed to the Celery app as ``task_default_queue``.
DEFAULT_TASK_QUEUE = "galaxy.internal"
# Modules passed to the Celery app as ``include``.
TASKS_MODULES = [MAIN_TASK_MODULE]
# Name under which the schema-aware JSON serializer is registered with kombu;
# galaxy_task() uses it as the default ``serializer`` for tasks.
PYDANTIC_AWARE_SERIALIZER_NAME = "pydantic-aware-json"

# Thread-local holder for the Galaxy app (see set_thread_app / get_galaxy_app).
APP_LOCAL = local()

# Register the custom encoder/decoder pair with kombu at import time.
serialization.register(
    PYDANTIC_AWARE_SERIALIZER_NAME, encoder=schema_dumps, decoder=schema_loads, content_type="application/json"
)


class GalaxyCelery(Celery):
    """Celery application subclass that generates shortened names for Galaxy tasks."""

    # Process pool assigned onto the app instance by ``setup_worker_pool``.
    fork_pool: pebble.ProcessPool

    def __init__(self, *args, **kwargs):
        """Forward every argument unchanged to ``Celery.__init__``."""
        super().__init__(*args, **kwargs)

    def gen_task_name(self, name, module):
        """Build the task name from ``name`` and the trimmed form of ``module``."""
        trimmed = self.trim_module_name(module)
        return super().gen_task_name(name, trimmed)

    def trim_module_name(self, module):
        """
        Remove the "celery.tasks" infix so task names are less verbose.

        Examples:

        - galaxy.celery.tasks.do_foo >> galaxy.do_foo
        - galaxy.celery.tasks.subtasks.do_fuz >> galaxy.subtasks.do_fuz

        A ``module`` that does not start with "galaxy.celery.tasks" is
        returned unchanged.
        """
        verbose_prefix = "galaxy.celery.tasks"
        if not module.startswith(verbose_prefix):
            return module
        remainder = module[len(verbose_prefix) :]
        return f"galaxy{remainder}"
class GalaxyTask(Task):
    """
    Custom Celery task base class for Galaxy tasks.

    It is described by its authors as the task type used to limit the number
    of task executions per user per second; the behavior itself comes from
    the ``GalaxyTaskBeforeStart`` object that ``before_start`` invokes.
    """

    def before_start(self, task_id, args, kwargs):
        """
        Run the before-start object obtained from the DI container.

        The Galaxy app is looked up with ``get_galaxy_app()``, indexed with
        ``GalaxyTaskBeforeStart``, and the resulting callable is invoked with
        this task, its id and its call arguments.
        """
        galaxy_app = get_galaxy_app()
        assert galaxy_app
        before_start_hook = galaxy_app[GalaxyTaskBeforeStart]
        before_start_hook(self, task_id, args, kwargs)
def set_thread_app(app):
    """Store ``app`` on the thread-local ``APP_LOCAL`` so ``get_galaxy_app`` returns it for this thread."""
    setattr(APP_LOCAL, "app", app)
def get_galaxy_app():
    """
    Return the Galaxy app to use in the current thread.

    Lookup order:

    1. the app stored on ``APP_LOCAL`` for this thread;
    2. ``galaxy.app.app`` when it is set (truthy);
    3. the result of ``build_app()``.
    """
    try:
        thread_app = APP_LOCAL.app
    except AttributeError:
        # Nothing stored for this thread; imported lazily, only on this path.
        import galaxy.app

        return galaxy.app.app or build_app()
    return thread_app
@lru_cache(maxsize=1)
def build_app():
    """
    Build a ``GalaxyManagerApplication`` from the discovered app properties.

    Returns ``None`` when ``get_app_properties()`` yields nothing. The result
    is cached, so the application is constructed at most once per process.
    """
    app_kwargs = get_app_properties()
    if not app_kwargs:
        return None

    # NOTE(review): this updates, in place, the dict cached by the
    # lru_cache'd get_app_properties(); later callers of that function see
    # these keys too. Preserved as-is -- confirm whether that is intended.
    app_kwargs.update(
        check_migrate_databases=False,
        use_display_applications=False,
        use_converters=False,
    )

    import galaxy.app

    return galaxy.app.GalaxyManagerApplication(configure_logging=False, **app_kwargs)
@lru_cache(maxsize=1)
def get_app_properties():
    """
    Load the "galaxy" section of the Galaxy configuration file.

    The file is taken from the ``GALAXY_CONFIG_FILE`` environment variable.
    When that is unset but ``GALAXY_ROOT_DIR`` is set, ``find_config`` is
    asked to locate one under that root. When ``GALAXY_ROOT_DIR`` is set it
    is also stored in the result under ``"root_dir"``.

    Returns the loaded properties, or ``None`` when no config file is known.
    The result is cached.
    """
    env = os.environ
    config_file = env.get("GALAXY_CONFIG_FILE")
    galaxy_root_dir = env.get("GALAXY_ROOT_DIR")

    if galaxy_root_dir and not config_file:
        config_file = find_config(config_file, galaxy_root_dir)

    if not config_file:
        return None

    absolute_path = os.path.abspath(config_file)
    properties = load_app_properties(config_file=absolute_path, config_section="galaxy")
    if galaxy_root_dir:
        properties["root_dir"] = galaxy_root_dir
    return properties
@lru_cache(maxsize=1)
def get_config():
    """
    Build a ``Configuration`` from the discovered app properties.

    An empty dict is used when no properties are available, and
    ``override_tempdir`` is always passed as ``False``. The result is cached.
    """
    properties = get_app_properties()
    if not properties:
        properties = {}
    # NOTE(review): when properties were found, this writes into the dict
    # cached by get_app_properties(). Preserved as-is.
    properties["override_tempdir"] = False
    return Configuration(**properties)
def init_fork_pool():
    """
    Initializer for the worker process pool.

    Passed as ``initializer`` to the ``pebble.ProcessPool`` created in
    ``setup_worker_pool``. The imported names are not used here (hence the
    ``noqa: F401`` markers); the imports are done for their side effect only.
    """
    # Do slow imports when workers boot.
    from galaxy.datatypes import registry  # noqa: F401
    from galaxy.metadata import set_metadata  # noqa: F401
@worker_init.connect
def setup_worker_pool(sender=None, conf=None, instance=None, **kwargs):
    """
    Create the "forkserver" process pool when a Celery worker initializes.

    Connected to the ``worker_init`` signal. The pool is sized from
    ``sender.concurrency``, is given ``max_tasks=100``, runs
    ``init_fork_pool`` as its initializer, and is stored on
    ``celery_app.fork_pool``.
    """
    forkserver_context = get_context("forkserver")
    pool = pebble.ProcessPool(
        max_workers=sender.concurrency,
        max_tasks=100,
        initializer=init_fork_pool,
        context=forkserver_context,
    )
    celery_app.fork_pool = pool
@worker_shutting_down.connect
def tear_down_pool(sig, how, exitcode, **kwargs):
    """
    Stop the process pool when the Celery worker shuts down.

    Connected to the ``worker_shutting_down`` signal. The pool stored on
    ``celery_app.fork_pool`` is stopped and then joined with a 5 second
    timeout.
    """
    log.debug("shutting down forkserver pool")
    pool = celery_app.fork_pool
    pool.stop()
    pool.join(timeout=5)
def galaxy_task(*args, action=None, **celery_task_kwd):
    """
    Decorator that registers a function as a Galaxy Celery task.

    Can be applied bare (``@galaxy_task``) or with arguments
    (``@galaxy_task(action="...", **celery_options)``).

    The wrapped function runs through ``app.magic_partial`` of the Galaxy
    app, inside a request scope identified by a fresh UUID, and its outcome
    is logged together with a timer.

    :param args: when used bare, the single callable being decorated.
    :param action: optional text appended to the function name (as
        ``"<name> to <action>"``) in timer and log descriptions.
    :param celery_task_kwd: keyword arguments forwarded to ``shared_task``;
        ``serializer`` defaults to ``PYDANTIC_AWARE_SERIALIZER_NAME``.
    :returns: the registered task when used bare, otherwise a decorator.
    """
    celery_task_kwd.setdefault("serializer", PYDANTIC_AWARE_SERIALIZER_NAME)

    def decorate(func: Callable):
        @shared_task(base=GalaxyTask, **celery_task_kwd)
        @wraps(func)
        def wrapper(*args, **kwds):
            app = get_galaxy_app()
            assert app

            # Ensure sqlalchemy session registry scope is specific to this instance of the celery task
            scoped_id = str(uuid.uuid4())
            app.model.set_request_id(scoped_id)

            desc = func.__name__
            if action is not None:
                desc += f" to {action}"

            try:
                # The key must be an f-string: the previous plain literal sent the
                # text "{func.__name__}" verbatim, so all tasks shared one timer key.
                timer = app.execution_timer_factory.get_timer(f"internals.tasks.{func.__name__}", desc)
            except AttributeError:
                # Presumably for app objects without an execution_timer_factory;
                # fall back to a plain timer.
                timer = ExecutionTimer()

            try:
                rval = app.magic_partial(func)(*args, **kwds)
                message = f"Successfully executed Celery task {desc} {timer}"
                log.info(message)
                return rval
            except Exception:
                log.warning(f"Celery task execution failed for {desc} {timer}")
                raise
            finally:
                # Close and remove any open session this task has created
                app.model.unset_request_id(scoped_id)

        return wrapper

    # Bare usage: the only positional argument is the function to decorate.
    if len(args) == 1 and callable(args[0]):
        return decorate(args[0])
    else:
        return decorate
def init_celery_app():
    """
    Create and configure the Galaxy Celery application.

    The app is created with Galaxy's task modules and default queue, made
    the default Celery app, configured from Galaxy's configuration, and
    given its periodic-task schedule.

    :returns: the configured ``GalaxyCelery`` instance.
    """
    app = GalaxyCelery(
        "galaxy",
        include=TASKS_MODULES,
        task_default_queue=DEFAULT_TASK_QUEUE,
        task_create_missing_queues=True,
        timezone="UTC",
    )
    app.set_default()

    galaxy_config = get_config()
    config_celery_app(galaxy_config, app)
    setup_periodic_tasks(galaxy_config, app)
    return app
def config_celery_app(config, celery_app):
    """
    Copy Celery settings from Galaxy's configuration onto ``celery_app``.

    :param config: object exposing ``celery_conf`` and
        ``amqp_internal_connection``.
    :param celery_app: Celery application whose ``conf`` is updated in place.
    """
    conf = celery_app.conf

    # Apply settings from galaxy's config
    overrides = config.celery_conf
    if overrides:
        conf.update(overrides)

    # Handle special cases: use the internal AMQP connection when no broker is set
    if not conf.broker_url:
        conf.broker_url = config.amqp_internal_connection
def setup_periodic_tasks(config, celery_app):
    """
    Build the periodic-task schedule from Galaxy's configuration.

    A task is scheduled only when its configured interval is greater than
    zero. The notification tasks are considered only when
    ``config.enable_notification_system`` is truthy, and the object store
    cache cleanup only when ``config.object_store_cache_monitor_driver`` is
    "auto" or "celery". ``celery_app.conf.beat_schedule`` is assigned only
    when at least one task was scheduled.

    :param config: Galaxy configuration providing the interval settings.
    :param celery_app: application providing ``trim_module_name`` and ``conf``.
    """
    candidates = [
        ("prune_history_audit_table", config.history_audit_table_prune_interval),
        ("cleanup_short_term_storage", config.short_term_storage_cleanup_interval),
    ]
    if config.enable_notification_system:
        candidates.append(("cleanup_expired_notifications", config.expired_notifications_cleanup_interval))
        candidates.append(("dispatch_pending_notifications", config.dispatch_notifications_interval))
    if config.object_store_cache_monitor_driver in ("auto", "celery"):
        candidates.append(("clean_object_store_caches", config.object_store_cache_monitor_interval))

    beat_schedule: Dict[str, Dict[str, Any]] = {}
    for task, interval in candidates:
        if not interval > 0:
            continue
        # Schedule keys use dashes; task names use the trimmed module path.
        module_name = celery_app.trim_module_name(MAIN_TASK_MODULE)
        beat_schedule[task.replace("_", "-")] = {
            "task": f"{module_name}.{task}",
            "schedule": interval,
        }

    if beat_schedule:
        celery_app.conf.beat_schedule = beat_schedule
# Module-level Celery application, created when this module is imported.
# setup_worker_pool() and tear_down_pool() attach and stop its ``fork_pool``.
celery_app = init_celery_app()