Source code for galaxy.celery

import os
from functools import (
    lru_cache,
    wraps,
)
from multiprocessing import get_context
from threading import local
from typing import (
    Any,
    Callable,
    Dict,
)

import pebble
from celery import (
    Celery,
    shared_task,
)
from celery.signals import (
    worker_init,
    worker_shutting_down,
)
from kombu import serialization

from galaxy.config import Configuration
from galaxy.main_config import find_config
from galaxy.util import ExecutionTimer
from galaxy.util.custom_logging import get_logger
from galaxy.util.properties import load_app_properties
from ._serialization import (
    schema_dumps,
    schema_loads,
)

log = get_logger(__name__)

MAIN_TASK_MODULE = "galaxy.celery.tasks"
DEFAULT_TASK_QUEUE = "galaxy.internal"
TASKS_MODULES = [MAIN_TASK_MODULE]
PYDANTIC_AWARE_SERIALIZER_NAME = "pydantic-aware-json"

APP_LOCAL = local()

serialization.register(
    PYDANTIC_AWARE_SERIALIZER_NAME, encoder=schema_dumps, decoder=schema_loads, content_type="application/json"
)


[docs]def set_thread_app(app): APP_LOCAL.app = app
[docs]def get_galaxy_app(): try: return APP_LOCAL.app except AttributeError: import galaxy.app if galaxy.app.app: return galaxy.app.app return build_app()
[docs]@lru_cache(maxsize=1) def build_app(): kwargs = get_app_properties() if kwargs: kwargs["check_migrate_databases"] = False kwargs["use_display_applications"] = False kwargs["use_converters"] = False import galaxy.app galaxy_app = galaxy.app.GalaxyManagerApplication(configure_logging=False, **kwargs) return galaxy_app
[docs]@lru_cache(maxsize=1) def get_app_properties(): config_file = os.environ.get("GALAXY_CONFIG_FILE") galaxy_root_dir = os.environ.get("GALAXY_ROOT_DIR") if not config_file and galaxy_root_dir: config_file = find_config(config_file, galaxy_root_dir) if config_file: properties = load_app_properties( config_file=os.path.abspath(config_file), config_section="galaxy", ) if galaxy_root_dir: properties["root_dir"] = galaxy_root_dir return properties
[docs]@lru_cache(maxsize=1) def get_config(): kwargs = get_app_properties() if kwargs: kwargs["override_tempdir"] = False return Configuration(**kwargs)
[docs]def get_broker(): config = get_config() if config: return config.celery_broker or config.amqp_internal_connection
[docs]def get_backend(): config = get_config() if config: return config.celery_backend
[docs]def get_history_audit_table_prune_interval(): config = get_config() if config: return config.history_audit_table_prune_interval else: return 3600
[docs]def get_cleanup_short_term_storage_interval(): config = get_config() if config: return config.short_term_storage_cleanup_interval else: return 3600
broker = get_broker() backend = get_backend() celery_app_kwd: Dict[str, Any] = { "broker": broker, "include": TASKS_MODULES, "task_default_queue": DEFAULT_TASK_QUEUE, "task_create_missing_queues": True, } if backend: celery_app_kwd["backend"] = backend celery_app = Celery("galaxy", **celery_app_kwd) celery_app.set_default() # setup cron like tasks... beat_schedule: Dict[str, Dict[str, Any]] = {}
[docs]def init_fork_pool(): # Do slow imports when workers boot. from galaxy.datatypes import registry # noqa: F401 from galaxy.metadata import set_metadata # noqa: F401
[docs]@worker_init.connect def setup_worker_pool(sender=None, conf=None, instance=None, **kwargs): context = get_context("forkserver") celery_app.fork_pool = pebble.ProcessPool( max_workers=sender.concurrency, max_tasks=100, initializer=init_fork_pool, context=context )
[docs]@worker_shutting_down.connect def tear_down_pool(sig, how, exitcode, **kwargs): log.debug("shutting down forkserver pool") celery_app.fork_pool.stop() celery_app.fork_pool.join(timeout=5)
prune_interval = get_history_audit_table_prune_interval() if prune_interval > 0: beat_schedule["prune-history-audit-table"] = { "task": f"{MAIN_TASK_MODULE}.prune_history_audit_table", "schedule": prune_interval, } cleanup_interval = get_cleanup_short_term_storage_interval() if cleanup_interval > 0: beat_schedule["cleanup-short-term-storage"] = { "task": f"{MAIN_TASK_MODULE}.cleanup_short_term_storage", "schedule": cleanup_interval, } if beat_schedule: celery_app.conf.beat_schedule = beat_schedule celery_app.conf.timezone = "UTC"
[docs]def galaxy_task(*args, action=None, **celery_task_kwd): if "serializer" not in celery_task_kwd: celery_task_kwd["serializer"] = PYDANTIC_AWARE_SERIALIZER_NAME def decorate(func: Callable): @shared_task(**celery_task_kwd) @wraps(func) def wrapper(*args, **kwds): app = get_galaxy_app() assert app desc = func.__name__ if action is not None: desc += f" to {action}" try: timer = app.execution_timer_factory.get_timer("internals.tasks.{func.__name__}", desc) except AttributeError: timer = ExecutionTimer() try: rval = app.magic_partial(func)(*args, **kwds) message = f"Successfully executed Celery task {desc} {timer}" log.info(message) return rval except Exception: log.warning(f"Celery task execution failed for {desc} {timer}") raise return wrapper if len(args) == 1 and callable(args[0]): return decorate(args[0]) else: return decorate
if __name__ == "__main__": celery_app.start()