Source code for

""" Compute tool recommendations """

import json
import logging
import os

import h5py
import numpy as np
import requests
import yaml

from import populate_state
from import workflow_building_modes
from galaxy.util import DEFAULT_SOCKET_TIMEOUT
from galaxy.workflow.modules import module_factory

log = logging.getLogger(__name__)

[docs]class ToolRecommendations: max_seq_len = 25 ff_dim = 128 embed_dim = 128 num_heads = 4 dropout = 0.1
[docs] def __init__(self): self.tool_recommendation_model_path = None self.admin_tool_recommendations_path = None self.deprecated_tools = dict() self.admin_recommendations = dict() self.model_data_dictionary = dict() self.reverse_dictionary = dict() self.all_tools = dict() self.tool_weights_sorted = dict() self.loaded_model = None self.compatible_tools = None self.standard_connections = None self.model_ok = None
[docs] def create_transformer_model(self, vocab_size): try: from tensorflow.keras.layers import ( Dense, Dropout, Embedding, GlobalAveragePooling1D, Input, Layer, LayerNormalization, MultiHeadAttention, ) from tensorflow.keras.models import ( Model, Sequential, ) except Exception as e: log.exception(e) return None class TransformerBlock(Layer): def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1): super().__init__() self.att = MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=rate) self.ffn = Sequential([Dense(ff_dim, activation="relu"), Dense(embed_dim)]) self.layernorm1 = LayerNormalization(epsilon=1e-6) self.layernorm2 = LayerNormalization(epsilon=1e-6) self.dropout1 = Dropout(rate) self.dropout2 = Dropout(rate) def call(self, inputs, training): attn_output, attention_scores = self.att( inputs, inputs, inputs, return_attention_scores=True, training=training ) attn_output = self.dropout1(attn_output, training=training) out1 = self.layernorm1(inputs + attn_output) ffn_output = self.ffn(out1) ffn_output = self.dropout2(ffn_output, training=training) return self.layernorm2(out1 + ffn_output), attention_scores class TokenAndPositionEmbedding(Layer): def __init__(self, maxlen, vocab_size, embed_dim): super().__init__() self.token_emb = Embedding(input_dim=vocab_size, output_dim=embed_dim, mask_zero=True) self.pos_emb = Embedding(input_dim=maxlen, output_dim=embed_dim, mask_zero=True) def call(self, x): try: import tensorflow as tf maxlen = tf.shape(x)[-1] positions = tf.range(start=0, limit=maxlen, delta=1) positions = self.pos_emb(positions) x = self.token_emb(x) return x + positions except Exception as e: log.exception(e) return None inputs = Input(shape=(ToolRecommendations.max_seq_len,)) embedding_layer = TokenAndPositionEmbedding( ToolRecommendations.max_seq_len, vocab_size, ToolRecommendations.embed_dim ) x = embedding_layer(inputs) transformer_block = TransformerBlock( ToolRecommendations.embed_dim, ToolRecommendations.num_heads, ToolRecommendations.ff_dim ) x, weights = transformer_block(x) x = GlobalAveragePooling1D()(x) x = Dropout(ToolRecommendations.dropout)(x) x = Dense(ToolRecommendations.ff_dim, activation="relu")(x) x = Dropout(ToolRecommendations.dropout)(x) outputs = Dense(vocab_size, activation="sigmoid")(x) return Model(inputs=inputs, outputs=[outputs, weights])
[docs] def get_predictions(self, trans, tool_sequence, remote_model_url): """ Compute tool predictions """ recommended_tools = dict() self.__collect_admin_preferences( if self.model_ok is None: self.__set_model(trans, remote_model_url) recommended_tools = self.__compute_tool_prediction(trans, tool_sequence) return tool_sequence, recommended_tools
def __set_model(self, trans, remote_model_url): """ Create model and associated dictionaries for recommendations """ self.tool_recommendation_model_path = self.__download_model(remote_model_url) model_file = h5py.File(self.tool_recommendation_model_path, "r") self.reverse_dictionary = json.loads(model_file["reverse_dict"][()].decode("utf-8")) self.loaded_model = self.create_transformer_model(len(self.reverse_dictionary) + 1) self.loaded_model.load_weights(self.tool_recommendation_model_path) self.model_data_dictionary = {v: k for k, v in self.reverse_dictionary.items()} # set the list of compatible tools self.compatible_tools = json.loads(model_file["compatible_tools"][()].decode("utf-8")) tool_weights = json.loads(model_file["class_weights"][()].decode("utf-8")) self.standard_connections = json.loads(model_file["standard_connections"][()].decode("utf-8")) # sort the tools' usage dictionary tool_pos_sorted = [int(key) for key in tool_weights.keys()] for k in tool_pos_sorted: self.tool_weights_sorted[k] = tool_weights[str(k)] # collect ids and names of all the installed tools for tool_id, tool in t_id_renamed = tool_id if t_id_renamed.find("/") > -1: t_id_renamed = t_id_renamed.split("/")[-2] self.all_tools[t_id_renamed] = (tool_id, self.model_ok = True def __collect_admin_preferences(self, admin_path): """ Collect preferences for recommendations of tools set by admins as dictionaries of deprecated tools and additional recommendations """ if not self.admin_tool_recommendations_path and admin_path is not None: self.admin_tool_recommendations_path = os.path.join(os.getcwd(), admin_path) if os.path.exists(self.admin_tool_recommendations_path): with open(self.admin_tool_recommendations_path) as admin_recommendations: admin_recommendation_preferences = yaml.safe_load(admin_recommendations) if admin_recommendation_preferences: for tool_id in admin_recommendation_preferences: tool_info = admin_recommendation_preferences[tool_id] if "is_deprecated" in tool_info[0]: self.deprecated_tools[tool_id] = tool_info[0]["text_message"] else: if tool_id not in self.admin_recommendations: self.admin_recommendations[tool_id] = tool_info def __download_model(self, model_url, download_local="database/"): """ Download the model from remote server """ local_dir = os.path.join(os.getcwd(), download_local, "tool_recommendation_model.hdf5") # read model from remote model_binary = requests.get(model_url, timeout=DEFAULT_SOCKET_TIMEOUT) # save model to a local directory with open(local_dir, "wb") as model_file: model_file.write(model_binary.content) return local_dir def __get_tool_extensions(self, trans, tool_id): """ Get the input and output extensions of a tool """ payload = {"type": "tool", "tool_id": tool_id, "_": "true"} inputs = payload.get("inputs", {}) trans.workflow_building_mode = workflow_building_modes.ENABLED module = module_factory.from_dict(trans, payload) if "tool_state" not in payload: module_state = {} populate_state(trans, module.get_inputs(), inputs, module_state, check=False) module.recover_state(module_state) inputs = module.get_all_inputs(connectable_only=True) outputs = module.get_all_outputs() input_extensions = list() output_extensions = list() for i_ext in inputs: input_extensions.extend(i_ext["extensions"]) for o_ext in outputs: output_extensions.extend(o_ext["extensions"]) return input_extensions, output_extensions def __filter_tool_predictions(self, trans, prediction_data, tool_ids, tool_scores, last_tool_name): """ Filter tool predictions based on datatype compatibility and tool connections. Add admin preferences to recommendations. """ last_compatible_tools = list() if last_tool_name in self.model_data_dictionary: last_tool_name_id = self.model_data_dictionary[last_tool_name] if last_tool_name_id in self.compatible_tools: last_compatible_tools = [ self.reverse_dictionary[t_id] for t_id in self.compatible_tools[last_tool_name_id] ] prediction_data["is_deprecated"] = False # get the list of datatype extensions of the last tool of the tool sequence _, last_output_extensions = self.__get_tool_extensions(trans, self.all_tools[last_tool_name][0]) prediction_data["o_extensions"] = list(set(last_output_extensions)) t_ids_scores = zip(tool_ids, tool_scores) # form the payload of the predicted tools to be shown for child, score in t_ids_scores: c_dict = dict() for t_id in self.all_tools: # select the name and tool id if it is installed in Galaxy if ( t_id == child and score >= 0.0 and t_id in last_compatible_tools and child not in self.deprecated_tools ): full_tool_id = self.all_tools[t_id][0] pred_input_extensions, _ = self.__get_tool_extensions(trans, full_tool_id) c_dict["name"] = self.all_tools[t_id][1] c_dict["tool_id"] = full_tool_id c_dict["i_extensions"] = list(set(pred_input_extensions)) prediction_data["children"].append(c_dict) break # incorporate preferences set by admins if self.admin_tool_recommendations_path: # filter out deprecated tools t_ids_scores = [ (tid, score) for tid, score in zip(tool_ids, tool_scores) if tid not in self.deprecated_tools ] # set the property if the last tool of the sequence is deprecated if last_tool_name in self.deprecated_tools: prediction_data["is_deprecated"] = True prediction_data["message"] = self.deprecated_tools[last_tool_name] # add the recommendations given by admins for tool_id in self.admin_recommendations: if last_tool_name == tool_id: admin_recommendations = self.admin_recommendations[tool_id] if is True: prediction_data["children"] = admin_recommendations else: prediction_data["children"].extend(admin_recommendations) break # get the root name for displaying after tool run for t_id in self.all_tools: if t_id == last_tool_name: prediction_data["name"] = self.all_tools[t_id][1] break return prediction_data def __get_predicted_tools(self, base_tools, predictions, topk): """ Get predicted tools. If predicted tools are less in number, combine them with published tools """ t_intersect = list(set(predictions).intersection(set(base_tools))) t_diff = list(set(predictions).difference(set(base_tools))) t_intersect, u_intersect = self.__sort_by_usage( t_intersect, self.tool_weights_sorted, self.model_data_dictionary ) t_diff, u_diff = self.__sort_by_usage(t_diff, self.tool_weights_sorted, self.model_data_dictionary) t_intersect.extend(t_diff) u_intersect.extend(u_diff) return t_intersect, u_intersect def __sort_by_usage(self, t_list, class_weights, d_dict): """ Sort predictions by usage/class weights """ tool_dict = dict() for tool in t_list: t_id = d_dict[tool] tool_dict[tool] = class_weights[int(t_id)] tool_dict = dict(sorted(tool_dict.items(), key=lambda kv: kv[1], reverse=True)) return list(tool_dict.keys()), list(tool_dict.values()) def __separate_predictions(self, base_tools, predictions, last_tool_name, weight_values, topk): """ Get predictions from published and normal workflows """ last_base_tools = list() prediction_pos = np.argsort(predictions, axis=-1) topk_prediction_pos = prediction_pos[-topk:] # get tool ids pred_tool_names = [self.reverse_dictionary[str(tool_pos)] for tool_pos in topk_prediction_pos] if last_tool_name in base_tools: last_base_tools = base_tools[last_tool_name] if type(last_base_tools).__name__ == "str": # get published or compatible tools for the last tool in a sequence of tools last_base_tools = last_base_tools.split(",") # get predicted tools sorted_c_t, sorted_c_v = self.__get_predicted_tools(last_base_tools, pred_tool_names, topk) return sorted_c_t, sorted_c_v def __compute_tool_prediction(self, trans, tool_sequence): """ Compute the predicted tools for a tool sequences Return a payload with the tool sequences and recommended tools Return an empty payload with just the tool sequence if anything goes wrong within the try block """ topk = prediction_data = dict() tool_sequence = tool_sequence.split(",")[::-1] prediction_data["name"] = ",".join(tool_sequence) prediction_data["children"] = list() last_tool_name = tool_sequence[-1] # do prediction only if the last is present in the collections of tools if last_tool_name in self.model_data_dictionary: sample = np.zeros(self.max_seq_len) # get tool names without slashes and create a sequence vector for idx, tool_name in enumerate(tool_sequence): if tool_name.find("/") > -1: tool_name = tool_name.split("/")[-2] try: sample[idx] = int(self.model_data_dictionary[tool_name]) except Exception: log.exception(f"Failed to find tool {tool_name} in model") return prediction_data sample = np.reshape(sample, (1, self.max_seq_len)) # boost the predicted scores using tools' usage weight_values = list(self.tool_weights_sorted.values()) # predict next tools for a test path try: import tensorflow as tf sample = tf.convert_to_tensor(sample, dtype=tf.int64) prediction, _ = self.loaded_model(sample, training=False) except Exception as e: log.exception(e) return prediction_data # get dimensions nw_dimension = prediction.shape[1] prediction = np.reshape(prediction, (nw_dimension,)) # get recommended tools from published workflows pub_t, pub_v = self.__separate_predictions( self.standard_connections, prediction, last_tool_name, weight_values, topk ) # remove duplicates if any pub_t = list(dict.fromkeys(pub_t)) pub_v = list(dict.fromkeys(pub_v)) prediction_data = self.__filter_tool_predictions(trans, prediction_data, pub_t, pub_v, last_tool_name) return prediction_data