def visualise_attention(smiles_str, attn_model: Transformer):
    """Display the attention weights of a Transformer for one SMILES string.

    The string is encoded with the model's own token set, fed to
    ``attn_model.output_attns`` as both of its inputs, and the resulting
    attention arrays are passed to ``attention.show``.

    :param smiles_str: SMILES string input, e.g. ``'Cc1ccccc1'``
    :param attn_model: a loaded Transformer model; its ``i_tokens``,
        ``len_limit`` and ``output_attns`` attributes are read here
    :return: None -- array shapes are printed and ``attention.show`` is called
    """

    def _stack_for_display(attn_group):
        # Stack the group into a single array, insert a new axis at
        # position 1, then exchange axes 3 and 4.
        stacked = np.array(attn_group)
        with_extra_axis = np.expand_dims(stacked, 1)
        return np.swapaxes(with_extra_axis, 3, 4)

    vocab = attn_model.i_tokens
    encoded = SmilesToArray(smiles_str, tokens=vocab,
                            max_len=attn_model.len_limit)
    # The same encoded sequence is supplied for both model inputs.
    all_attns = attn_model.output_attns.predict_on_batch([encoded, encoded])

    # The outputs are cut into three consecutive groups; the first two hold
    # one third of the outputs each and the last takes whatever remains.
    # NOTE(review): presumably one entry per layer in each group -- confirm
    # against how ``output_attns`` is built.
    per_group = len(all_attns) // 3
    enc_atts = _stack_for_display(all_attns[:per_group])
    dec_atts = _stack_for_display(all_attns[per_group:2 * per_group])
    encdec_atts = _stack_for_display(all_attns[2 * per_group:])

    shape_report = "Shapes of arrays:\n\tenc_atts:\t{}\n\tdec_atts:\t{}\n\tencdec_atts:\t{}"
    print(shape_report.format(
        np.shape(enc_atts), np.shape(dec_atts), np.shape(encdec_atts)))

    call_html()

    # Output labels: the input characters followed by the end token.
    # Input labels: the start token followed by those same output labels.
    end_token = tokens_end = vocab.token(vocab.endid())
    out_str = list(smiles_str) + [tokens_end]
    start_token = vocab.token(vocab.startid())
    in_str = [start_token] + out_str

    attention.show(in_str, out_str, enc_atts, dec_atts, encdec_atts)
# Convert inputs and outputs to subwords inp_text = to_tokens(encoders["inputs"].encode(inputs)) out_text = to_tokens(encoders["inputs"].encode(outputs)) # Run eval to collect attention weights example = encode_eval(inputs, outputs) with tfe.restore_variables_on_create( tf.train.latest_checkpoint(checkpoint_dir)): translate_model.set_mode(Modes.EVAL) translate_model(example) # Get normalized attention weights for each layer enc_atts, dec_atts, encdec_atts = get_att_mats() call_html() attention.show(inp_text, out_text, enc_atts, dec_atts, encdec_atts) """# Train a custom model on MNIST""" # Create your own model class MySimpleModel(t2t_model.T2TModel): def body(self, features): inputs = features["inputs"] filters = self.hparams.hidden_size h1 = tf.layers.conv2d(inputs, filters, kernel_size=(5, 5), strides=(2, 2)) h2 = tf.layers.conv2d(tf.nn.relu(h1), filters,