Source code for galaxy.webapps.galaxy.api.plugins

"""
Plugins resource control over the API.
"""

import json
import logging
from typing import (
    Any,
    cast,
    Literal,
    Optional,
    Union,
)

from fastapi import (
    Body,
    Path,
    Query,
    Request,
)
from fastapi.responses import (
    JSONResponse,
    StreamingResponse,
)
from openai import (
    APIError,
    AsyncOpenAI,
)
from openai._streaming import AsyncStream
from openai.types.chat import (
    ChatCompletion,
    ChatCompletionChunk,
    ChatCompletionMessageParam,
    ChatCompletionToolParam,
)
from pydantic import BaseModel

from galaxy.config import GalaxyAppConfiguration
from galaxy.exceptions import (
    MessageException,
    ObjectNotFound,
)
from galaxy.managers import (
    hdas,
    histories,
)
from galaxy.model import (
    HistoryDatasetAssociation,
    User,
)
from galaxy.schema.fields import DecodedDatabaseIdField
from galaxy.schema.visualization import VisualizationPluginResponse
from galaxy.structured_app import StructuredApp
from galaxy.webapps.galaxy.api import (
    depends,
    DependsOnApp,
    DependsOnTrans,
    DependsOnUser,
    Router,
)
from galaxy.webapps.galaxy.fast_app import limiter
from galaxy.work.context import SessionRequestContext

log = logging.getLogger(__name__)

router = Router(tags=["plugins"])

GALAXY_PROMPT = """
You are a Galaxy agent.
You assist users with scientific data analysis and research workflows.
Respond only to scientific, computational, or data analysis related questions.
"""

# Set constants
MAX_MESSAGES = 1024
MAX_TOOLS = 128
MAX_TOOL_BYTES = 16384
TEMPERATURE = 0.3
TIMEOUT = 120.0
TOKENS_DEFAULT = 1024
TOKENS_MAX = 8192
TOP_P = 0.9


class ChatMessage(BaseModel):
    role: Literal["assistant", "system", "tool", "user"]
    content: Optional[str] = None
    tool_calls: Optional[list[dict[str, Any]]] = None
    model_config = dict(extra="allow")


class ChatToolFunction(BaseModel):
    name: str
    model_config = dict(extra="allow")


class ChatTool(BaseModel):
    type: Literal["function"]
    function: ChatToolFunction
    model_config = dict(extra="allow")


class ChatCompletionRequest(BaseModel):
    messages: list[ChatMessage]
    tools: Optional[list[ChatTool]] = None
    stream: Optional[bool] = False
    max_tokens: Optional[int] = None
    model_config = dict(extra="allow")


class PluginDatasetEntry(BaseModel):
    id: str
    hid: int
    name: str


class PluginDatasetsResponse(BaseModel):
    hdas: list[PluginDatasetEntry]
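

# For orientation, a request body the ChatCompletionRequest model above would
# accept (values and the tool name are illustrative; extra keys are tolerated
# because the models declare extra="allow"):
#
#     {
#         "messages": [{"role": "user", "content": "Summarize my FASTQ QC report."}],
#         "tools": [{"type": "function", "function": {"name": "list_datasets"}}],
#         "stream": false,
#         "max_tokens": 512
#     }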


@router.cbv
class FastAPIPlugins:
    """RESTful controller for interactions with visualization plugins."""

    app: StructuredApp = DependsOnApp
    config: GalaxyAppConfiguration = depends(GalaxyAppConfiguration)
    hda_manager: hdas.HDAManager = depends(hdas.HDAManager)
    history_manager: histories.HistoryManager = depends(histories.HistoryManager)

    @router.post("/api/plugins/{plugin_name}/chat/completions", unstable=True)
    @limiter.limit("30/minute")
    async def plugins_chat_adapter(
        self,
        request: Request,
        payload: ChatCompletionRequest = Body(...),
        user: User = DependsOnUser,
        plugin_name: str = Path(
            ...,
            title="Plugin Name",
            description="Visualization plugin name used to resolve the AI prompt.",
            examples=["jupyterlite"],
        ),
    ):
        """Proxy chat completion requests for a plugin to the configured AI provider."""
        registry = self.app.visualizations_registry
        if registry:
            try:
                plugin = registry.get_plugin(plugin_name)
            except ObjectNotFound:
                return self._create_error(f"Plugin does not exist: {plugin_name}.")
            plugin_specs = plugin and plugin.config.get("specs")
            plugin_ai_prompt = plugin_specs and plugin_specs.get("ai_prompt")
            if plugin_ai_prompt:
                return await self._open_ai_adapter(payload, plugin_ai_prompt, plugin_name)
            else:
                return self._create_error("Selected plugin has no AI prompt.")
        else:
            return self._create_error("Visualization registry is not available.")

    def _get_plugin_config(self, plugin_name: str, key: str) -> Optional[str]:
        """Get config for a plugin with fallback through inference_services.

        Precedence:
        1. Plugin-specific: inference_services.<plugin_name>.<key>
        2. Default inference: inference_services.default.<key>
        3. Global config: ai_model / ai_api_key / ai_api_base_url
        """
        inference_config = getattr(self.config, "inference_services", None)
        if isinstance(inference_config, dict):
            plugin_specific = inference_config.get(plugin_name)
            if isinstance(plugin_specific, dict) and key in plugin_specific:
                return plugin_specific[key]
            default_config = inference_config.get("default")
            if isinstance(default_config, dict) and key in default_config:
                return default_config[key]
        if key == "model":
            return self.config.ai_model
        elif key == "api_key":
            return self.config.ai_api_key
        elif key == "api_base_url":
            return self.config.ai_api_base_url
        return None

    async def _open_ai_adapter(
        self,
        payload: ChatCompletionRequest,
        prompt: str,
        plugin_name: str,
    ):
        """Galaxy-managed chat completion adapter with prompt injection."""
        # Collect configuration via the inference_services fallback chain
        ai_api_key = self._get_plugin_config(plugin_name, "api_key")
        ai_api_base_url = self._get_plugin_config(plugin_name, "api_base_url")
        ai_model = self._get_plugin_config(plugin_name, "model")
        if ai_api_key is None:
            return self._create_error("AI service not configured: API key is required.")
        if ai_model is None:
            return self._create_error("AI service not configured: Model is required.")

        # Limit max tokens
        max_tokens = min(payload.max_tokens or TOKENS_DEFAULT, TOKENS_MAX)

        # Validate messages
        messages: list[ChatCompletionMessageParam] = cast(
            list[ChatCompletionMessageParam],
            [
                dict(role="system", content=GALAXY_PROMPT),
                dict(role="system", content=prompt),
            ],
        )
        original_messages = payload.messages
        for msg in original_messages:
            role = msg.role
            content = msg.content
            tool_calls = msg.tool_calls
            if role == "assistant":
                msg_dict: dict[str, Any] = dict(role="assistant")
                if content is not None:
                    msg_dict["content"] = content
                if isinstance(tool_calls, list):
                    msg_dict["tool_calls"] = tool_calls
                if len(msg_dict) > 1:
                    messages.append(cast(ChatCompletionMessageParam, msg_dict))
            elif role in ("user", "tool") and isinstance(content, str):
                messages.append(cast(ChatCompletionMessageParam, dict(role=role, content=content)))
            else:
                continue
            if len(messages) >= MAX_MESSAGES:
                return self._create_error("You have exceeded the maximum number of messages.")

        # Detect streaming flag
        stream = payload.stream is True

        # Limit number and size of tools
        tools: list[ChatCompletionToolParam] = []
        original_tools = payload.tools or []
        if len(original_tools) <= MAX_TOOLS:
            for tool in original_tools:
                tool_dict = tool.model_dump()
                func = tool_dict.get("function", {})
                if func.get("parameters") is None:
                    func["parameters"] = {"type": "object", "properties": {}}
                size = len(json.dumps(tool_dict, separators=(",", ":")).encode("utf-8"))
                if size > MAX_TOOL_BYTES:
                    return self._create_error("Tool schema too large.")
                tools.append(cast(ChatCompletionToolParam, tool_dict))
        else:
            return self._create_error("Maximum number of tools exceeded.")

        # Build the OpenAI client with a timeout
        try:
            client = AsyncOpenAI(
                api_key=ai_api_key,
                timeout=TIMEOUT,
                base_url=ai_api_base_url or None,
            )
        except Exception as e:
            log.debug("Failed to initialize OpenAI client.", exc_info=e)
            return self._create_error("Failed to initialize OpenAI client.", 500)

        # Connect to the AI provider
        log.info(f"Proxying to {ai_model}, tokens: {max_tokens}.")
        try:
            response = await client.chat.completions.create(
                max_tokens=max_tokens,
                messages=messages,
                model=ai_model,
                stream=stream,
                temperature=TEMPERATURE,
                tools=tools,
                top_p=TOP_P,
            )
        except APIError as e:
            log.debug("Failed to complete OpenAI request.", exc_info=e)
            status_code = getattr(e, "status_code", 500)
            if hasattr(e, "body") and isinstance(e.body, dict):
                return JSONResponse(content=dict(error=e.body), status_code=status_code)
            return self._create_error("Failed to complete OpenAI request.", status_code)

        # Parse the response
        if stream:
            stream_response: AsyncStream[ChatCompletionChunk] = cast(AsyncStream[ChatCompletionChunk], response)

            async def generate():
                try:
                    async for chunk in stream_response:
                        yield f"data: {json.dumps(chunk.model_dump())}\n\n"
                    yield "data: [DONE]\n\n"
                finally:
                    await client.close()

            return StreamingResponse(
                generate(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                },
            )
        else:
            completion_response: ChatCompletion = cast(ChatCompletion, response)
            return JSONResponse(content=completion_response.model_dump())

    def _create_error(self, message: str, status_code=400):
        """Error handling helper."""
        log.debug(message)
        return JSONResponse(content=dict(error=dict(message=message)), status_code=status_code)
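
    # A sketch of the inference_services mapping _get_plugin_config consults;
    # the plugin names and values below are hypothetical, only the key names
    # ("model", "api_key", "api_base_url") come from the code above:
    #
    #     inference_services = {
    #         "jupyterlite": {"model": "small-model", "api_key": "..."},
    #         "default": {"model": "large-model", "api_key": "...", "api_base_url": "https://llm.example.org/v1"},
    #     }
    #
    # With that mapping, _get_plugin_config("jupyterlite", "model") returns the
    # plugin-specific entry, any other plugin falls back to the "default" block,
    # and when inference_services is absent the global ai_model / ai_api_key /
    # ai_api_base_url settings are used.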

    @router.get("/api/plugins")
    def index(
        self,
        trans: SessionRequestContext = DependsOnTrans,
        dataset_id: Optional[DecodedDatabaseIdField] = Query(
            default=None,
            title="Dataset ID",
            description="Filter to visualizations compatible with this dataset.",
        ),
        embeddable: Optional[bool] = Query(
            default=None,
            title="Embeddable",
            description="Filter to embeddable visualizations only.",
        ),
    ) -> list[dict[str, Any]]:
        """List available visualization plugins."""
        registry = self._get_registry()
        target_object = None
        if dataset_id is not None:
            target_object = self.hda_manager.get_accessible(dataset_id, trans.user)
        return registry.get_visualizations(trans, target_object=target_object, embeddable=embeddable or False)

    @router.get("/api/plugins/{id}")
    def show(
        self,
        trans: SessionRequestContext = DependsOnTrans,
        id: str = Path(
            ...,
            title="Plugin ID",
            description="The visualization plugin identifier.",
        ),
        history_id: Optional[DecodedDatabaseIdField] = Query(
            default=None,
            title="History ID",
            description="Filter datasets compatible with this plugin from the specified history.",
        ),
    ) -> Union[PluginDatasetsResponse, VisualizationPluginResponse]:
        """Get details of a specific visualization plugin."""
        registry = self._get_registry()
        if history_id is not None:
            history = self.history_manager.get_owned(history_id, trans.user, current_history=trans.history)
            hdas: list[PluginDatasetEntry] = []
            for item in history.contents_iter(types=["dataset"], deleted=False, visible=True):
                hda = cast(HistoryDatasetAssociation, item)
                if hda.hid is not None and registry.get_visualization(trans, id, hda):
                    hdas.append(
                        PluginDatasetEntry(
                            id=trans.security.encode_id(hda.id),
                            hid=hda.hid,
                            name=hda.name,
                        )
                    )
            hdas.sort(key=lambda h: h.hid, reverse=True)
            return PluginDatasetsResponse(hdas=hdas)
        else:
            return VisualizationPluginResponse(**registry.get_plugin(id).to_dict())
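
    # Example requests for the two read endpoints above (placeholders in angle
    # brackets; authentication omitted):
    #
    #     GET /api/plugins?dataset_id=<encoded_hda_id>&embeddable=true
    #     GET /api/plugins/<plugin_id>
    #     GET /api/plugins/<plugin_id>?history_id=<encoded_history_id>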

    def _get_registry(self):
        """Get the visualizations registry or raise an error if not configured."""
        if not self.app.visualizations_registry:
            raise MessageException("The visualization registry has not been configured.")
        return self.app.visualizations_registry
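

# A minimal client-side sketch for the chat adapter above, assuming a Galaxy
# instance at GALAXY_URL and API-key authentication via Galaxy's x-api-key
# header; GALAXY_URL and GALAXY_API_KEY are hypothetical placeholders:
#
#     from openai import AsyncOpenAI
#
#     client = AsyncOpenAI(
#         api_key="unused",  # provider credentials are injected server-side
#         base_url=f"{GALAXY_URL}/api/plugins/jupyterlite",
#         default_headers={"x-api-key": GALAXY_API_KEY},
#     )
#     # The client posts to <base_url>/chat/completions, matching the route
#     # above. A "model" argument is required by the client library, but the
#     # adapter resolves the actual model server-side.
#     completion = await client.chat.completions.create(
#         model="galaxy-managed",
#         messages=[{"role": "user", "content": "Plot per-cycle quality scores."}],
#     )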