def __call__(self, x: ExpressionSequence) -> dy.Expression:
  """
  Move the time dimension of an input expression into the batch dimension via a reshape.

  Args:
    x: expression sequence of dimensions ((hidden, timesteps), batch_size)

  Returns:
    expression of dimensions ((hidden,), timesteps*batch_size)
  """
  batch_size = x[0].dim()[1]
  model_dim = x[0].dim()[0][0]
  seq_len = len(x)
  total_words = seq_len * batch_size
  input_tensor = x.as_tensor()
  return dy.reshape(input_tensor, (model_dim,), batch_size=total_words)
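
# Illustrative shape sketch (not part of the original module): the reshape above is
# equivalent to folding the time dimension of a ((hidden, timesteps), batch) tensor
# into the batch dimension. The helper name and the concrete sizes below are
# hypothetical; dy and np are assumed to be this module's existing imports.
def _demo_fold_time_into_batch() -> None:
  dy.renew_cg()
  hidden, timesteps, batch = 8, 5, 3
  t = dy.inputTensor(np.zeros((hidden, timesteps, batch)), batched=True)
  assert t.dim() == ((hidden, timesteps), batch)
  flat = dy.reshape(t, (hidden,), batch_size=timesteps * batch)
  assert flat.dim() == ((hidden,), timesteps * batch)  # same layout SAAMTimeDistributed returns
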
def __call__(self, x: ExpressionSequence, att_mask: np.ndarray, batch_mask: np.ndarray,
             p: numbers.Real) -> dy.Expression:
  """
  Args:
    x: expression sequence of dimensions (input_dim, time) x batch
    att_mask: numpy array of dimensions (time, time); pre-transposed
    batch_mask: numpy array of dimensions (batch, time)
    p: dropout probability
  """
  sent_len = x.dim()[0][1]
  batch_size = x[0].dim()[1]

  if self.downsample_factor > 1:
    if sent_len % self.downsample_factor != 0:
      raise ValueError(
        "For 'reshape' downsampling, sequence lengths must be multiples of the downsampling factor. "
        "Configure batcher accordingly.")
    if batch_mask is not None:
      batch_mask = batch_mask[:, ::self.downsample_factor]
    sent_len_out = sent_len // self.downsample_factor
    sent_len = sent_len_out
    out_mask = x.mask
    if out_mask is not None:
      out_mask = out_mask.lin_subsampled(reduce_factor=self.downsample_factor)
    # Fold every `downsample_factor` consecutive time steps into the hidden dimension.
    x = ExpressionSequence(expr_tensor=dy.reshape(x.as_tensor(),
                                                  (x.dim()[0][0] * self.downsample_factor,
                                                   x.dim()[0][1] // self.downsample_factor),
                                                  batch_size=batch_size),
                           mask=out_mask)
    residual = SAAMTimeDistributed()(x)
  else:
    residual = SAAMTimeDistributed()(x)
    sent_len_out = sent_len
  if self.model_dim != self.input_dim * self.downsample_factor:
    residual = self.res_shortcut.transform(residual)

  # Concatenate all the words together for doing vectorized affine transform
  if self.kq_pos_encoding_type is None:
    kvq_lin = self.linear_kvq.transform(SAAMTimeDistributed()(x))
    key_up = self.shape_projection(
      dy.pick_range(kvq_lin, 0, self.head_count * self.dim_per_head), batch_size)
    value_up = self.shape_projection(
      dy.pick_range(kvq_lin, self.head_count * self.dim_per_head,
                    2 * self.head_count * self.dim_per_head), batch_size)
    query_up = self.shape_projection(
      dy.pick_range(kvq_lin, 2 * self.head_count * self.dim_per_head,
                    3 * self.head_count * self.dim_per_head), batch_size)
  else:
    assert self.kq_pos_encoding_type == "embedding"
    encoding = self.kq_positional_embedder.embed_sent(sent_len).as_tensor()
    kq_lin = self.linear_kq.transform(SAAMTimeDistributed()(
      ExpressionSequence(expr_tensor=dy.concatenate([x.as_tensor(), encoding]))))
    key_up = self.shape_projection(
      dy.pick_range(kq_lin, 0, self.head_count * self.dim_per_head), batch_size)
    query_up = self.shape_projection(
      dy.pick_range(kq_lin, self.head_count * self.dim_per_head,
                    2 * self.head_count * self.dim_per_head), batch_size)
    v_lin = self.linear_v.transform(SAAMTimeDistributed()(x))
    value_up = self.shape_projection(v_lin, batch_size)

  if self.cross_pos_encoding_type:
    assert self.cross_pos_encoding_type == "embedding"
    emb1 = dy.pick_range(dy.parameter(self.cross_pos_emb_p1), 0, sent_len)
    emb2 = dy.pick_range(dy.parameter(self.cross_pos_emb_p2), 0, sent_len)
    key_up = dy.reshape(key_up, (sent_len, self.dim_per_head, self.head_count),
                        batch_size=batch_size)
    key_up = dy.concatenate_cols([dy.cmult(key_up, emb1), dy.cmult(key_up, emb2)])
    key_up = dy.reshape(key_up, (sent_len, self.dim_per_head * 2),
                        batch_size=self.head_count * batch_size)
    query_up = dy.reshape(query_up, (sent_len, self.dim_per_head, self.head_count),
                          batch_size=batch_size)
    query_up = dy.concatenate_cols([dy.cmult(query_up, emb2), dy.cmult(query_up, -emb1)])
    query_up = dy.reshape(query_up, (sent_len, self.dim_per_head * 2),
                          batch_size=self.head_count * batch_size)

  # Scale before the matrix multiplication to save memory.
  scaled = query_up * dy.transpose(key_up / math.sqrt(self.dim_per_head))

  # Apply masks here
  if not self.ignore_masks:
    if att_mask is not None:
      att_mask_inp = att_mask * -100.0
      if self.downsample_factor > 1:
        att_mask_inp = att_mask_inp[::self.downsample_factor, ::self.downsample_factor]
      scaled += dy.inputTensor(att_mask_inp)
    if batch_mask is not None:
      # Reshape (batch, time) -> (time, head_count*batch), then multiply by -100.
      inp = np.resize(np.broadcast_to(batch_mask.T[:, np.newaxis, :],
                                      (sent_len, self.head_count, batch_size)),
                      (1, sent_len, self.head_count * batch_size)) \
            * -100
      mask_expr = dy.inputTensor(inp, batched=True)
      scaled += mask_expr
    if self.diag_gauss_mask:
      # Soft diagonal mask: penalize attention quadratically with distance |i-j|,
      # scaled by a learned (per-head) standard deviation.
      diag_growing = np.zeros((sent_len, sent_len, self.head_count))
      for i in range(sent_len):
        for j in range(sent_len):
          diag_growing[i, j, :] = -(i - j)**2 / 2.0
      e_diag_gauss_mask = dy.inputTensor(diag_growing)
      e_sigma = dy.parameter(self.diag_gauss_mask_sigma)
      if self.square_mask_std:
        e_sigma = dy.square(e_sigma)
      e_sigma_sq_inv = dy.cdiv(dy.ones(e_sigma.dim()[0], batch_size=batch_size),
                               dy.square(e_sigma))
      e_diag_gauss_mask_final = dy.cmult(e_diag_gauss_mask, e_sigma_sq_inv)
      scaled += dy.reshape(e_diag_gauss_mask_final, (sent_len, sent_len),
                           batch_size=batch_size * self.head_count)

  # Compute the attention softmax (normalizing over keys).
  attn = dy.softmax(scaled, d=1)

  if LOG_ATTENTION:
    yaml_logger.info({"key": "selfatt_mat_ax0", "value": np.average(attn.value(), axis=0).dumps(),
                      "desc": self.desc})
    yaml_logger.info({"key": "selfatt_mat_ax1", "value": np.average(attn.value(), axis=1).dumps(),
                      "desc": self.desc})
    yaml_logger.info({"key": "selfatt_mat_ax0_ent", "value": entropy(attn.value()).dumps(),
                      "desc": self.desc})
    yaml_logger.info({"key": "selfatt_mat_ax1_ent", "value": entropy(attn.value().transpose()).dumps(),
                      "desc": self.desc})

  if self.select_att_head is not None:
    # Zero out all attention heads except the selected one.
    attn = dy.reshape(attn, (sent_len, sent_len, self.head_count), batch_size=batch_size)
    sel_mask = np.zeros((1, 1, self.head_count))
    sel_mask[0, 0, self.select_att_head] = 1.0
    attn = dy.cmult(attn, dy.inputTensor(sel_mask))
    attn = dy.reshape(attn, (sent_len, sent_len), batch_size=self.head_count * batch_size)

  # Apply dropout to the attention weights
  if p > 0.0:
    drop_attn = dy.dropout(attn, p)
  else:
    drop_attn = attn

  # Compute the attention-weighted values
  attn_prod = drop_attn * value_up

  # Reshape attn_prod back to the input query dimensions
  out = dy.reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count),
                   batch_size=batch_size)
  out = dy.transpose(out)
  out = dy.reshape(out, (self.model_dim,), batch_size=batch_size * sent_len_out)
  # out = dy.reshape_transpose_reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count),
  #                                    (self.model_dim,), pre_batch_size=batch_size,
  #                                    post_batch_size=batch_size*sent_len_out)

  if self.plot_attention:
    from sklearn.metrics.pairwise import cosine_similarity
    assert batch_size == 1
    mats = []
    for i in range(attn.dim()[1]):
      mats.append(dy.pick_batch_elem(attn, i).npvalue())
      self.plot_att_mat(mats[-1],
                        "{}.sent_{}.head_{}.png".format(self.plot_attention,
                                                        self.plot_attention_counter, i),
                        300)
    avg_mat = np.average(mats, axis=0)
    self.plot_att_mat(avg_mat,
                      "{}.sent_{}.head_avg.png".format(self.plot_attention,
                                                       self.plot_attention_counter),
                      300)
    cosim_before = cosine_similarity(x.as_tensor().npvalue().T)
    self.plot_att_mat(cosim_before,
                      "{}.sent_{}.cosim_before.png".format(self.plot_attention,
                                                           self.plot_attention_counter),
                      600)
    cosim_after = cosine_similarity(out.npvalue().T)
    self.plot_att_mat(cosim_after,
                      "{}.sent_{}.cosim_after.png".format(self.plot_attention,
                                                          self.plot_attention_counter),
                      600)
    self.plot_attention_counter += 1

  # Apply dropout, add the residual connection, and layer-normalize
  if p > 0.0:
    res = dy.dropout(out, p) + residual
  else:
    res = out + residual
  ret = self.layer_norm.transform(res)
  return ret
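
# Reference sketch (illustrative only, not the class API): the core computation of
# __call__ above, written per head for a single batch element in plain NumPy,
# ignoring masks, dropout, positional encodings, and downsampling. The function
# name is hypothetical; it only documents the math: softmax(Q K^T / sqrt(d)) V.
def _reference_scaled_dot_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
  """q, k, v: (time, dim_per_head) arrays for one attention head."""
  dim_per_head = q.shape[1]
  logits = q @ k.T / math.sqrt(dim_per_head)     # (time, time), like `scaled` above
  logits -= logits.max(axis=1, keepdims=True)    # stabilize the exponentials
  attn = np.exp(logits)
  attn /= attn.sum(axis=1, keepdims=True)        # row-wise softmax, cf. dy.softmax(scaled, d=1)
  return attn @ v                                # (time, dim_per_head) weighted values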