Warning

This document is for an in-development version of Galaxy. You can alternatively view this page in the latest release if it exists or view the top of the latest release's documentation.

Source code for galaxy.workflow.refactor.execute

import logging
from typing import (
    Any,
    Dict,
)

from galaxy.exceptions import RequestParameterInvalidException
from galaxy.tools.parameters import visit_input_values
from galaxy.tools.parameters.basic import contains_workflow_parameter
from galaxy.tools.parameters.workflow_utils import (
    ConnectedValue,
    runtime_to_json,
)
from .schema import (
    AddInputAction,
    AddStepAction,
    ConnectAction,
    DisconnectAction,
    ExtractInputAction,
    ExtractUntypedParameter,
    FileDefaultsAction,
    FillStepDefaultsAction,
    InputReferenceByOrderIndex,
    OutputReferenceByOrderIndex,
    RefactorActionExecution,
    RefactorActionExecutionMessage,
    RefactorActionExecutionMessageTypeEnum,
    RefactorActions,
    RemoveUnlabeledWorkflowOutputs,
    step_reference_union,
    StepReferenceByLabel,
    UpdateAnnotationAction,
    UpdateCreatorAction,
    UpdateLicenseAction,
    UpdateNameAction,
    UpdateOutputLabelAction,
    UpdateReportAction,
    UpdateStepLabelAction,
    UpdateStepPositionAction,
    UpgradeAllStepsAction,
    UpgradeSubworkflowAction,
    UpgradeToolAction,
)
from ..modules import (
    InputParameterModule,
    NO_REPLACEMENT,
)

log = logging.getLogger(__name__)


[docs]class WorkflowRefactorExecutor:
[docs] def __init__(self, raw_workflow_description, workflow, module_injector): # we mostly use the ga representation, but there may be cases where the # models/modules of existing workflow are more usable. self.raw_workflow_description = raw_workflow_description self.workflow = workflow self.module_injector = module_injector self.module_injector.inject_all(workflow, ignore_tool_missing_exception=True)
[docs] def refactor(self, refactor_request: RefactorActions): action_executions = [] for action in refactor_request.actions: # TODO: we need to regenerate a detached workflow from as_dict after # after each iteration here. Otherwise one set of changes might render # the workflow state out of sync. It is fine if you're just executing one # action at a time or just performing actions that use raw_workflow_description. action_type = action.action_type refactor_method_name = f"_apply_{action_type}" refactor_method = getattr(self, refactor_method_name, None) if refactor_method is None: raise RequestParameterInvalidException(f"Unknown workflow editing action encountered [{action_type}]") execution = RefactorActionExecution( action=action, messages=[], ) refactor_method(action, execution) action_executions.append(execution) return action_executions
def _apply_update_step_label(self, action: UpdateStepLabelAction, execution: RefactorActionExecution): step = self._find_step_for_action(action) step["label"] = action.label def _apply_update_step_position(self, action: UpdateStepPositionAction, execution: RefactorActionExecution): step = self._find_step_for_action(action) position_update = action.position_shift.to_dict() step["position"]["left"] = step["position"].get("left", 0) + position_update["left"] step["position"]["top"] = step["position"].get("top", 0) + position_update["top"] def _apply_update_output_label(self, action: UpdateOutputLabelAction, execution: RefactorActionExecution): output_reference = action.output output_step_dict = self._find_step(output_reference) output_name = output_reference.output_name for workflow_output_def in output_step_dict.get("workflow_outputs", []): if workflow_output_def["output_name"] == output_name: workflow_output_def["label"] = action.output_label def _apply_update_name(self, action: UpdateNameAction, execution: RefactorActionExecution): self._as_dict["name"] = action.name def _apply_update_annotation(self, action: UpdateAnnotationAction, execution: RefactorActionExecution): self._as_dict["annotation"] = action.annotation def _apply_update_license(self, action: UpdateLicenseAction, execution: RefactorActionExecution): self._as_dict["license"] = action.license def _apply_update_creator(self, action: UpdateCreatorAction, execution: RefactorActionExecution): self._as_dict["creator"] = action.creator def _apply_update_report(self, action: UpdateReportAction, execution: RefactorActionExecution): self._as_dict["report"] = {"markdown": action.report.markdown} def _apply_add_step(self, action: AddStepAction, execution: RefactorActionExecution): steps = self._as_dict["steps"] order_index = len(steps) step_dict = { "order_index": order_index, "id": f"new_{order_index}", "type": action.type, } if action.tool_state: step_dict["tool_state"] = action.tool_state if action.label: step_dict["label"] = action.label if action.position: step_dict["position"] = action.position.to_dict() steps[order_index] = step_dict def _apply_add_input(self, action: AddInputAction, execution: RefactorActionExecution): input_type = action.type module_type = None tool_state: Dict[str, Any] = {} if input_type in ["data", "dataset"]: module_type = "data_input" elif input_type in ["data_collection", "dataset_collection"]: module_type = "data_collection_input" tool_state["collection_type"] = action.collection_type else: if input_type not in InputParameterModule.POSSIBLE_PARAMETER_TYPES: raise RequestParameterInvalidException(f"Invalid input type {input_type} encountered") module_type = "parameter_input" tool_state["parameter_type"] = input_type for action_key in ["restrictions", "suggestions", "optional", "default"]: value = getattr(action, action_key, None) if value is not None: tool_state[action_key] = value if action.restrict_on_connections is not None: tool_state["restrictOnConnections"] = action.restrict_on_connections add_step_kwds = {} if action.label: add_step_kwds["label"] = action.label add_step_action = AddStepAction( action_type="add_step", type=module_type, tool_state=tool_state, position=action.position, **add_step_kwds ) self._apply_add_step(add_step_action, execution) def _apply_disconnect(self, action: DisconnectAction, execution: RefactorActionExecution): input_step_dict, input_name, output_step_dict, output_name = self._connection(action) output_order_index = output_step_dict["id"] # wish this was order_index... # default name is name used for input's output terminal - following # format2 convention of allowing this be absent for clean references # to workflow inputs. all_input_connections = input_step_dict.get("input_connections") self.normalize_input_connections_to_list(all_input_connections, input_name) input_connections = all_input_connections[input_name] # multiple outputs attached to this inputs, just detach # that specific one. delete_index = None for connection_index, output in enumerate(input_connections): if output["id"] == output_order_index and output["output_name"] == output_name: delete_index = connection_index break if delete_index is None: raise RequestParameterInvalidException("Failed to locate connection to disconnect") del input_connections[delete_index] def _apply_connect(self, action: ConnectAction, execution: RefactorActionExecution): input_step_dict, input_name, output_step_dict, output_name = self._connection(action) output_order_index = output_step_dict["id"] # wish this was order_index... all_input_connections = input_step_dict.get("input_connections") self.normalize_input_connections_to_list(all_input_connections, input_name, add_if_missing=True) input_connections = all_input_connections[input_name] input_connections.append( { "id": output_order_index, "output_name": output_name, } ) def _apply_fill_defaults(self, action: FileDefaultsAction, execution: RefactorActionExecution): for _, step in self._iterate_over_step_pairs(execution): module = step.module if module.type != "tool": continue self._as_dict["steps"][step.order_index]["tool_state"] = step.module.get_tool_state() def _apply_fill_step_defaults(self, action: FillStepDefaultsAction, execution: RefactorActionExecution): step = self._find_step_with_module_for_action(action, execution) self._as_dict["steps"][step.order_index]["tool_state"] = step.module.get_tool_state() def _apply_extract_input(self, action: ExtractInputAction, execution: RefactorActionExecution): input_step_dict, input_name = self._input_from_action(action) step = self._step_with_module(input_step_dict["id"], execution) module = step.module inputs = module.get_all_inputs() input_def = None found_input_names = [] for input in inputs: found_input_name = input["name"] found_input_names.append(found_input_name) if found_input_name == input_name: input_def = input break if input_def is None: raise RequestParameterInvalidException( f"Failed to find input with name {input_name} on step {input_step_dict['id']} - input names found {found_input_names}" ) if input_def.get("multiple", False): raise RequestParameterInvalidException("Cannot extract input for multi-input inputs") module_input_type = input_def.get("input_type") # convert dataset, dataset_collection => data, data_collection for refactor API input_type = { "dataset": "data", "dataset_collection": "data_collection", }.get(module_input_type, module_input_type) input_action = AddInputAction( action_type="add_input", optional=input_def.get("optional"), type=input_type, label=action.label, position=action.position, ) new_input_order_index = self._add_input_get_order_index(input_action, execution) connect_action = ConnectAction( action_type="connect", input=action.input, output=OutputReferenceByOrderIndex(order_index=new_input_order_index), ) self._apply_connect(connect_action, execution) def _apply_extract_untyped_parameter(self, action: ExtractUntypedParameter, execution: RefactorActionExecution): untyped_parameter_name = action.name new_label = action.label or untyped_parameter_name target_value = f"${{{untyped_parameter_name}}}" target_tool_inputs = [] rename_pjas = [] for step_def, step in self._iterate_over_step_pairs(execution): module = step.module if module.type != "tool": continue # TODO: require a clean tool state for all tools to do this. tool = module.tool tool_inputs = module.state replace_tool_state = False def callback(input, prefixed_name, context, value=None, **kwargs): nonlocal replace_tool_state # data parameters cannot have untyped parameter values if input.type in ["data", "data_collection"]: return NO_REPLACEMENT if not contains_workflow_parameter(value): return NO_REPLACEMENT if value == target_value: target_tool_inputs.append((step.order_index, input, prefixed_name)) # noqa: B023 replace_tool_state = True return runtime_to_json(ConnectedValue()) else: return NO_REPLACEMENT visit_input_values(tool.inputs, tool_inputs.inputs, callback, no_replacement_value=NO_REPLACEMENT) if replace_tool_state: step_def["tool_state"] = step.module.get_tool_state() for post_job_action in self._iterate_over_rename_pjas(): newname = post_job_action.get("action_arguments", {}).get("newname") if target_value in newname: rename_pjas.append(post_job_action) if len(target_tool_inputs) == 0 and len(rename_pjas) == 0: raise RequestParameterInvalidException( f"Failed to find {target_value} in the tool state or any workflow steps." ) as_parameter_type = { "text": "text", "integer": "integer", "float": "float", "select": "text", "genomebuild": "text", } target_parameter_types = set() for _, tool_input, _ in target_tool_inputs: tool_input_type = tool_input.type if tool_input_type not in as_parameter_type: raise RequestParameterInvalidException( "Extracting inputs for parameters on tool inputs of type {tool_input_type} is unsupported" ) target_parameter_type = as_parameter_type[tool_input_type] target_parameter_types.add(target_parameter_type) if len(target_parameter_types) > 1: raise RequestParameterInvalidException( "Extracting inputs for parameters on conflicting tool input types (e.g. numeric and non-numeric) input types is unsupported" ) if len(target_parameter_types) == 1: (target_parameter_type,) = target_parameter_types else: # only used in PJA, hence only used a string target_parameter_type = "text" for rename_pja in rename_pjas: # if name != label, got to rewrite this rename with new label. if untyped_parameter_name != new_label: action_arguments = rename_pja.get("action_arguments") old_newname = action_arguments["newname"] new_newname = old_newname.replace(target_value, f"${{{new_label}}}") action_arguments["newname"] = new_newname optional = False input_action = AddInputAction( action_type="add_input", optional=optional, type=target_parameter_type, label=new_label, position=action.position, ) new_input_order_index = self._add_input_get_order_index(input_action, execution) for order_index, _tool_input, prefixed_name in target_tool_inputs: connect_input = InputReferenceByOrderIndex(order_index=order_index, input_name=prefixed_name) connect_action = ConnectAction( action_type="connect", input=connect_input, output=OutputReferenceByOrderIndex(order_index=new_input_order_index), ) self._apply_connect(connect_action, execution) def _apply_remove_unlabeled_workflow_outputs( self, action: RemoveUnlabeledWorkflowOutputs, execution: RefactorActionExecution ): for step in self._as_dict["steps"].values(): new_outputs = [] for workflow_output in step.get("workflow_outputs", []): if workflow_output.get("label") is None: continue new_outputs.append(workflow_output) step["workflow_outputs"] = new_outputs def _apply_upgrade_subworkflow(self, action: UpgradeSubworkflowAction, execution: RefactorActionExecution): step_def = self._find_step(action.step) assert step_def["content_id"] is not None trans = self.module_injector.trans content_id = action.content_id if content_id is None: old_workflow = trans.app.workflow_manager.get_owned_workflow(trans, step_def["content_id"]) stored_workflow = old_workflow.stored_workflow content_id = trans.security.encode_id(stored_workflow.latest_workflow.id) step_def["content_id"] = content_id step = self.workflow.steps[step_def["id"]] new_workflow = trans.app.workflow_manager.get_owned_workflow(trans, content_id) step.subworkflow = new_workflow self._inject_for_updated_step(step, execution) self._patch_step(execution, step, step_def) def _apply_upgrade_tool(self, action: UpgradeToolAction, execution: RefactorActionExecution): step_def = self._find_step(action.step) tool_id = step_def["content_id"] trans = self.module_injector.trans tool_version = action.tool_version if tool_version is None: latest_tool = trans.app.toolbox.get_tool(tool_id, get_all_versions=True)[-1] tool_version = latest_tool.version tool_id = latest_tool.id step = self.workflow.steps[step_def["id"]] step.tool_id = tool_id step.tool_version = tool_version self._inject_for_updated_step(step, execution) step_def["tool_version"] = tool_version step_def["tool_state"] = step.module.get_tool_state() if step_def.get("tool_id"): step_def["tool_id"] = step.module.get_content_id() if step_def.get("content_id"): step_def["content_id"] = step.module.get_content_id() self._patch_step(execution, step, step_def) def _apply_upgrade_all_steps(self, action: UpgradeAllStepsAction, execution: RefactorActionExecution): for step_order_index, step in self._as_dict["steps"].items(): if step.get("type") == "subworkflow": step_action_s = UpgradeSubworkflowAction( action_type="upgrade_subworkflow", step={"order_index": step_order_index} ) self._apply_upgrade_subworkflow(step_action_s, execution) elif step.get("type") == "tool": step_action_t = UpgradeToolAction(action_type="upgrade_tool", step={"order_index": step_order_index}) self._apply_upgrade_tool(step_action_t, execution) def _find_step(self, step_reference: step_reference_union): order_index = None if isinstance(step_reference, StepReferenceByLabel): label = step_reference.label if not label: raise RequestParameterInvalidException("Empty label provided.") for step_order_index, step in self._as_dict["steps"].items(): if step["label"] == label: order_index = step_order_index break else: order_index = step_reference.order_index if order_index is None: raise RequestParameterInvalidException(f"Failed to resolve step_reference {step_reference}") if len(self._as_dict["steps"]) <= order_index: raise RequestParameterInvalidException(f"Failed to resolve step_reference {step_reference}") return self._as_dict["steps"][order_index] def _find_step_for_action(self, action): step_reference = action.step return self._find_step(step_reference) def _find_step_with_module_for_action(self, action, execution): step_reference = action.step step_def = self._find_step(step_reference) step = self.workflow.steps[step_def["id"]] return self._inject(step, execution) def _step_with_module(self, order_index, execution): step = self.workflow.steps[order_index] return self._inject(step, execution) def _iterate_over_step_pairs(self, execution): # walk over both the dict-ified steps and the model steps (ensuring) # module is attached. for order_index, step_def in self._as_dict["steps"].items(): if order_index >= len(self.workflow.steps): # newly added step during refactoring, don't iterate over it... continue else: step = self._step_with_module(order_index, execution) yield step_def, step def _inject_for_updated_step(self, step, execution): step.clear_module_extras() return self._inject(step, execution) def _inject(self, step, execution): # compute runtime state, capture upgrade messages that result if not hasattr(step, "module"): self.module_injector.inject(step) self.module_injector.compute_runtime_state(step) if getattr(step, "upgrade_messages", None): for key, value in step.upgrade_messages.items(): message = RefactorActionExecutionMessage( message=value, message_type=RefactorActionExecutionMessageTypeEnum.tool_state_adjustment, input_name=key, step_label=step.label, order_index=step.order_index, ) execution.messages.append(message) if getattr(step.module, "version_changes", None): for version_change in step.module.version_changes: message = RefactorActionExecutionMessage( message=version_change, message_type=RefactorActionExecutionMessageTypeEnum.tool_version_change, step_label=step.label, order_index=step.order_index, ) execution.messages.append(message) return step def _iterate_over_rename_pjas(self): for _, step_def in self._as_dict["steps"].items(): if step_def["type"] != "tool": continue post_job_actions = step_def.get("post_job_actions", []) for post_job_action in post_job_actions.values(): if post_job_action["action_type"] == "RenameDatasetAction": yield post_job_action def _add_input_get_order_index(self, input_action: AddInputAction, execution: RefactorActionExecution): self._apply_add_input(input_action, execution) return len(self._as_dict["steps"]) - 1 def _input_from_action(self, action): input_reference = action.input input_step_dict = self._find_step(input_reference) input_name = input_reference.input_name return input_step_dict, input_name def _connection(self, action): input_step_dict, input_name = self._input_from_action(action) output_reference = action.output output_step_dict = self._find_step(output_reference) output_name = output_reference.output_name return input_step_dict, input_name, output_step_dict, output_name def _patch_step(self, execution, step, step_def): """ patch a workflow step after upgrading the tool / subworkflow """ # TODO: find workflow outputs that need to be dropped and report them upgrade_inputs = step.module.get_all_inputs() upgrade_outputs = step.module.get_all_outputs() upgrade_output_names = {u["name"] for u in upgrade_outputs} upgrade_order_index = step_def["id"] upgrade_label = step_def.get("label") all_input_connections = step_def.get("input_connections") inputs_to_delete = [] for input_name, input_connections in all_input_connections.items(): # try and find an input connection for each input matching_input = None for upgrade_input in upgrade_inputs: if upgrade_input["name"] == input_name: matching_input = upgrade_input break elif step.when_expression and f"inputs.{input_name}" in step.when_expression: # TODO: eventually track step inputs more formally matching_input = upgrade_input # In the future check parameter type, format, mapping status... if matching_input is None: inputs_to_delete.append(input_name) for input_connection in _listify_connections(input_connections): message_text = ( f"Tool or subworkflow input '{input_name}' no longer available, dropping connection to it." ) from_order_index = input_connection["id"] from_step_label = self._as_dict["steps"].get(from_order_index, {}).get("label") message = RefactorActionExecutionMessage( message=message_text, message_type=RefactorActionExecutionMessageTypeEnum.connection_drop_forced, input_name=input_name, step_label=upgrade_label, order_index=upgrade_order_index, output_name=input_connection["output_name"], from_order_index=from_order_index, from_step_label=from_step_label, ) execution.messages.append(message) for input_name in inputs_to_delete: del all_input_connections[input_name] for _, step in self._as_dict["steps"].items(): all_input_connections = step.get("input_connections") for input_name, input_connections in all_input_connections.items(): rebuilt_valid_connections = [] for input_connection in _listify_connections(input_connections): include = False if input_connection["id"] != upgrade_order_index: include = True else: output_name = input_connection["output_name"] if output_name in upgrade_output_names: include = True else: # dropped outputs message_text = f"Tool or subworkflow output '{output_name}' no longer available, dropping connection from it." message = RefactorActionExecutionMessage( message=message_text, message_type=RefactorActionExecutionMessageTypeEnum.connection_drop_forced, input_name=input_name, step_label=step.get("label"), order_index=step["id"], output_name=output_name, from_step_label=upgrade_label, from_order_index=upgrade_order_index, ) execution.messages.append(message) if include: rebuilt_valid_connections.append(input_connection) all_input_connections[input_name] = rebuilt_valid_connections workflow_outputs_to_delete = [] for workflow_output in step_def.get("workflow_outputs", []): output_label = workflow_output.get("label") output_name = workflow_output["output_name"] if not output_label: continue if output_name not in upgrade_output_names: workflow_outputs_to_delete.append(workflow_output) message_text = f"Subworkflow output '{output_name}' no longer available, dropping corresponding workflow output label {output_label}." message = RefactorActionExecutionMessage( message=message_text, message_type=RefactorActionExecutionMessageTypeEnum.workflow_output_drop_forced, step_label=upgrade_label, order_index=upgrade_order_index, output_name=workflow_output.get("output_name"), output_label=output_label, ) execution.messages.append(message)
[docs] @staticmethod def normalize_input_connections_to_list(all_input_connections, input_name, add_if_missing=False): if add_if_missing and input_name not in all_input_connections: all_input_connections[input_name] = [] input_connections = all_input_connections[input_name] all_input_connections[input_name] = _listify_connections(input_connections)
@property def _as_dict(self): return self.raw_workflow_description.as_dict
def _listify_connections(input_connections): if not isinstance(input_connections, list): return [input_connections] return input_connections