def broadcast_factor(mask: batchers.Mask, tensor_expr: dy.Expression) -> numbers.Integral:
    """
    Return product(tensor_expr dims) / product(mask dims).

    The tensor size is the product of all entries of ``tensor_expr.dim()[0]``
    multiplied by ``tensor_expr.dim()[1]``; the mask size is ``mask.np_arr.size``.

    Args:
      mask: mask whose ``np_arr`` is compared against the tensor size
      tensor_expr: expression whose total number of elements is counted
    Returns:
      The ratio of the two sizes. It is returned as an exact integer whenever the
      mask size divides the tensor size evenly, so it can be used directly as an
      element count. If the sizes do not divide evenly, the true-division ratio
      is returned, as before.
    """
    tensor_dim = tensor_expr.dim()
    tensor_expr_size = tensor_dim[1]
    for d in tensor_dim[0]:
        tensor_expr_size *= d
    mask_size = mask.np_arr.size
    factor, remainder = divmod(tensor_expr_size, mask_size)
    if remainder:
        # Sizes do not divide evenly: preserve the historical (fractional) result
        # rather than silently truncating it.
        return tensor_expr_size / mask_size
    return factor
def transform(self, input_expr: dy.Expression, mask: Optional[batchers.Mask]=None) -> dy.Expression:
    """
    Apply batch norm.

    When ``self.train`` is set, the mean and standard deviation are computed from
    the given input and the running population statistics are updated using
    ``BN_MOMENTUM``. Otherwise the stored population statistics are used.

    Args:
      input_expr: input
      mask: compute statistics only over unmasked parts of the input expression
    Returns:
      ``gamma * xhat + beta``, with the same dimensions as ``input_expr``
      (checked by an assertion).
    """
    dim_in = input_expr.dim()
    param_bn_gamma = dy.parameter(self.gamma)
    param_bn_beta = dy.parameter(self.beta)
    if self.train:
        num_unmasked = 0
        if mask is not None:
            # Masked positions are overwritten with the unmasked mean; the number of
            # unmasked elements is then passed on to the moment computations below.
            input_expr = set_masked_to_mean(mask, input_expr, self.time_first)
            num_unmasked = (mask.np_arr.size - np.count_nonzero(mask.np_arr)) * broadcast_factor(mask, input_expr)
        bn_mean = dy.moment_dim(input_expr, self.get_stat_dimensions(), 1, True, num_unmasked)
        # Negate exactly once so that adding this term below subtracts the batch mean,
        # consistent with the inference branch.
        neg_bn_mean_reshaped = -dy.reshape(bn_mean, self.get_normalizer_dimensionality())
        # running = (1 - BN_MOMENTUM) * running + BN_MOMENTUM * batch statistic
        self.population_running_mean += (-BN_MOMENTUM) * self.population_running_mean + BN_MOMENTUM * bn_mean.npvalue()
        bn_std = dy.std_dim(input_expr, self.get_stat_dimensions(), True, num_unmasked)
        self.population_running_std += (-BN_MOMENTUM) * self.population_running_std + BN_MOMENTUM * bn_std.npvalue()
    else:
        neg_bn_mean_reshaped = -dy.reshape(dy.inputVector(self.population_running_mean),
                                           self.get_normalizer_dimensionality())
        bn_std = dy.inputVector(self.population_running_std)
    bn_numerator = input_expr + neg_bn_mean_reshaped
    bn_xhat = dy.cdiv(bn_numerator, dy.reshape(bn_std, self.get_normalizer_dimensionality()) + BN_EPS)
    bn_y = dy.cmult(param_bn_gamma, bn_xhat) + param_bn_beta  # y = gamma * xhat + beta
    dim_out = bn_y.dim()
    self.save_processed_arg("population_running_mean", self.population_running_mean)
    self.save_processed_arg("population_running_std", self.population_running_std)
    assert dim_out == dim_in
    return bn_y
def calc_attention(self, state: dy.Expression) -> dy.Expression:
    """
    Compute softmax-normalized attention scores for ``state``.

    The raw scores are ``self.I * state``. If ``self.scale`` is set they are
    divided by the square root of the first dimension of ``state``. If the current
    sentence has a mask, it is added to the scores with a multiplicator of -100
    before the softmax.

    Args:
      state: expression the scores are computed from
    Returns:
      The softmax over the scores; it is also appended to ``self.attention_vecs``.
    """
    logits = self.I * state
    if self.scale:
        state_size = state.dim()[0][0]
        logits /= math.sqrt(state_size)
    sent_mask = self.curr_sent.mask
    if sent_mask is not None:
        logits = sent_mask.add_to_tensor_expr(logits, multiplicator=-100.0)
    attention = dy.softmax(logits)
    self.attention_vecs.append(attention)
    return attention
def set_masked_to_mean(mask: batchers.Mask, tensor_expr: dy.Expression, time_first: bool = False) -> dy.Expression:
    """
    Set masked parts of the tensor expr to the mean of the unmasked parts.

    Args:
      mask: mask; nonzero entries of ``mask.np_arr`` mark the masked positions
      tensor_expr: expression whose masked positions are replaced
      time_first: forwarded to ``mask_reshape_size`` when computing the mask shape
    Returns:
      ``tensor_expr`` itself if nothing is masked, otherwise a new expression of
      the same dimensions (checked by an assertion).
    """
    num_masked = np.count_nonzero(mask.np_arr)
    if num_masked == 0:
        return tensor_expr

    dim_before = tensor_expr.dim()
    reshape_size = mask_reshape_size(mask, dim_before, time_first)
    mask_tensor = np.reshape(mask.np_arr.transpose(), reshape_size)

    # Zero out the masked positions.
    inv_mask_expr = dy.inputTensor(1.0 - mask_tensor, batched=True)
    unmasked = dy.cmult(tensor_expr, inv_mask_expr)

    # Sum everything down to a single element.
    # loop because mean_dim only supports reducing up to 2 dimensions at a time
    total = unmasked
    while True:
        cur_dim = total.dim()
        if sum(cur_dim[0]) <= 1:
            break
        reduce_dims = list(range(min(2, len(cur_dim[0]))))
        # this is mean without normalization == sum
        total = dy.mean_dim(total, reduce_dims, cur_dim[1] > 1, n=1)

    # Normalize by the number of unmasked elements to obtain the mean.
    num_unmasked = (mask.np_arr.size - num_masked) * broadcast_factor(mask, tensor_expr)
    unmasked_mean = dy.cdiv(total, dy.inputTensor(np.asarray([num_unmasked]), batched=False))

    # Fill the masked positions with that mean.
    mask_expr = dy.cmult(dy.inputTensor(mask_tensor, batched=True), unmasked_mean)
    ret = unmasked + mask_expr
    assert ret.dim() == dim_before
    return ret
def __call__(self, x: dy.Expression, att_mask: np.ndarray, batch_mask: np.ndarray, p: numbers.Real):
    """
    Apply multi-headed self-attention followed by a residual connection and layer norm.

    Args:
      x: expression of dimensions (input_dim, time) x batch. The code accesses
         ``x.mask`` and ``x.as_tensor()``, so the object passed in must provide
         both.
      att_mask: numpy array of dimensions (time, time); pre-transposed
      batch_mask: numpy array of dimensions (batch, time)
      p: dropout prob
    Returns:
      Output of ``self.layer_norm.transform`` applied to attention output plus residual.
    Raises:
      ValueError: if downsampling is enabled and the sequence length is not a
                  multiple of the downsampling factor.
    """
    # NOTE(review): block nesting below was reconstructed from whitespace-flattened
    # source; the placement of the ``diag_gauss_mask`` section and of the
    # ``select_att_head`` assignment should be confirmed against the original file.
    sent_len = x.dim()[0][1]
    batch_size = x[0].dim()[1]

    if self.downsample_factor > 1:
        if sent_len % self.downsample_factor != 0:
            raise ValueError(
                "For 'reshape' downsampling, sequence lengths must be multiples of the downsampling factor. "
                "Configure batcher accordingly.")
        if batch_mask is not None:
            batch_mask = batch_mask[:, ::self.downsample_factor]
        sent_len_out = sent_len // self.downsample_factor
        sent_len = sent_len_out
        out_mask = x.mask
        if out_mask is not None:
            out_mask = out_mask.lin_subsampled(reduce_factor=self.downsample_factor)
        # Fold `downsample_factor` consecutive time steps into the feature dimension.
        # The new time dimension must be an integer (divisibility was checked above),
        # so use the floor-divided length rather than a true-division float.
        x = ExpressionSequence(
            expr_tensor=dy.reshape(
                x.as_tensor(),
                (x.dim()[0][0] * self.downsample_factor, sent_len_out),
                batch_size=batch_size),
            mask=out_mask)
        residual = SAAMTimeDistributed()(x)
    else:
        residual = SAAMTimeDistributed()(x)
        sent_len_out = sent_len

    if self.model_dim != self.input_dim * self.downsample_factor:
        residual = self.res_shortcut.transform(residual)

    # Concatenate all the words together for doing vectorized affine transform
    proj_dim = self.head_count * self.dim_per_head
    if self.kq_pos_encoding_type is None:
        kvq_lin = self.linear_kvq.transform(SAAMTimeDistributed()(x))
        key_up = self.shape_projection(dy.pick_range(kvq_lin, 0, proj_dim), batch_size)
        value_up = self.shape_projection(dy.pick_range(kvq_lin, proj_dim, 2 * proj_dim), batch_size)
        query_up = self.shape_projection(dy.pick_range(kvq_lin, 2 * proj_dim, 3 * proj_dim), batch_size)
    else:
        assert self.kq_pos_encoding_type == "embedding"
        encoding = self.kq_positional_embedder.embed_sent(sent_len).as_tensor()
        kq_lin = self.linear_kq.transform(
            SAAMTimeDistributed()(
                ExpressionSequence(expr_tensor=dy.concatenate([x.as_tensor(), encoding]))))
        key_up = self.shape_projection(dy.pick_range(kq_lin, 0, proj_dim), batch_size)
        query_up = self.shape_projection(dy.pick_range(kq_lin, proj_dim, 2 * proj_dim), batch_size)
        v_lin = self.linear_v.transform(SAAMTimeDistributed()(x))
        value_up = self.shape_projection(v_lin, batch_size)

    if self.cross_pos_encoding_type:
        assert self.cross_pos_encoding_type == "embedding"
        emb1 = dy.pick_range(dy.parameter(self.cross_pos_emb_p1), 0, sent_len)
        emb2 = dy.pick_range(dy.parameter(self.cross_pos_emb_p2), 0, sent_len)
        key_up = dy.reshape(key_up, (sent_len, self.dim_per_head, self.head_count), batch_size=batch_size)
        key_up = dy.concatenate_cols([dy.cmult(key_up, emb1), dy.cmult(key_up, emb2)])
        key_up = dy.reshape(key_up, (sent_len, self.dim_per_head * 2), batch_size=self.head_count * batch_size)
        query_up = dy.reshape(query_up, (sent_len, self.dim_per_head, self.head_count), batch_size=batch_size)
        query_up = dy.concatenate_cols([dy.cmult(query_up, emb2), dy.cmult(query_up, -emb1)])
        query_up = dy.reshape(query_up, (sent_len, self.dim_per_head * 2), batch_size=self.head_count * batch_size)

    # scale before the matrix multiplication to save memory
    scaled = query_up * dy.transpose(key_up / math.sqrt(self.dim_per_head))

    # Apply Mask here
    if not self.ignore_masks:
        if att_mask is not None:
            att_mask_inp = att_mask * -100.0
            if self.downsample_factor > 1:
                att_mask_inp = att_mask_inp[::self.downsample_factor, ::self.downsample_factor]
            scaled += dy.inputTensor(att_mask_inp)
        if batch_mask is not None:
            # reshape (batch, time) -> (time, head_count*batch), then *-100
            inp = np.resize(np.broadcast_to(batch_mask.T[:, np.newaxis, :],
                                            (sent_len, self.head_count, batch_size)),
                            (1, sent_len, self.head_count * batch_size)) \
                  * -100
            mask_expr = dy.inputTensor(inp, batched=True)
            scaled += mask_expr
        if self.diag_gauss_mask:
            # diag_growing[i, j, h] = -(i - j)^2 / 2 for every head h, built without
            # a Python-level double loop.
            positions = np.arange(sent_len)
            neg_half_sq_dist = -((positions[:, np.newaxis] - positions[np.newaxis, :]) ** 2) / 2.0
            diag_growing = np.repeat(neg_half_sq_dist[:, :, np.newaxis], self.head_count, axis=2)
            e_diag_gauss_mask = dy.inputTensor(diag_growing)
            e_sigma = dy.parameter(self.diag_gauss_mask_sigma)
            if self.square_mask_std:
                e_sigma = dy.square(e_sigma)
            e_sigma_sq_inv = dy.cdiv(dy.ones(e_sigma.dim()[0], batch_size=batch_size), dy.square(e_sigma))
            e_diag_gauss_mask_final = dy.cmult(e_diag_gauss_mask, e_sigma_sq_inv)
            scaled += dy.reshape(e_diag_gauss_mask_final, (sent_len, sent_len),
                                 batch_size=batch_size * self.head_count)

    # Computing Softmax here.
    attn = dy.softmax(scaled, d=1)
    if LOG_ATTENTION:
        yaml_logger.info({
            "key": "selfatt_mat_ax0",
            "value": np.average(attn.value(), axis=0).dumps(),
            "desc": self.desc
        })
        yaml_logger.info({
            "key": "selfatt_mat_ax1",
            "value": np.average(attn.value(), axis=1).dumps(),
            "desc": self.desc
        })
        yaml_logger.info({
            "key": "selfatt_mat_ax0_ent",
            "value": entropy(attn.value()).dumps(),
            "desc": self.desc
        })
        yaml_logger.info({
            "key": "selfatt_mat_ax1_ent",
            "value": entropy(attn.value().transpose()).dumps(),
            "desc": self.desc
        })

    # NOTE(review): this unconditionally overwrites any configured value, so only
    # head 0 contributes below. It looks like a debugging leftover; kept to
    # preserve current behavior -- confirm whether it should be removed.
    self.select_att_head = 0
    if self.select_att_head is not None:
        # Zero out the attention of every head except the selected one.
        attn = dy.reshape(attn, (sent_len, sent_len, self.head_count), batch_size=batch_size)
        sel_mask = np.zeros((1, 1, self.head_count))
        sel_mask[0, 0, self.select_att_head] = 1.0
        attn = dy.cmult(attn, dy.inputTensor(sel_mask))
        attn = dy.reshape(attn, (sent_len, sent_len), batch_size=self.head_count * batch_size)

    # Applying dropout to attention
    if p > 0.0:
        drop_attn = dy.dropout(attn, p)
    else:
        drop_attn = attn

    # Computing weighted attention score
    attn_prod = drop_attn * value_up

    # Reshaping the attn_prod to input query dimensions
    out = dy.reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count), batch_size=batch_size)
    out = dy.transpose(out)
    out = dy.reshape(out, (self.model_dim, ), batch_size=batch_size * sent_len_out)
    # out = dy.reshape_transpose_reshape(attn_prod, (sent_len_out, self.dim_per_head * self.head_count), (self.model_dim,), pre_batch_size=batch_size, post_batch_size=batch_size*sent_len_out)

    if self.plot_attention:
        from sklearn.metrics.pairwise import cosine_similarity
        assert batch_size == 1
        mats = []
        for i in range(attn.dim()[1]):
            mats.append(dy.pick_batch_elem(attn, i).npvalue())
            self.plot_att_mat(
                mats[-1],
                "{}.sent_{}.head_{}.png".format(self.plot_attention, self.plot_attention_counter, i),
                300)
        avg_mat = np.average(mats, axis=0)
        self.plot_att_mat(
            avg_mat,
            "{}.sent_{}.head_avg.png".format(self.plot_attention, self.plot_attention_counter),
            300)
        cosim_before = cosine_similarity(x.as_tensor().npvalue().T)
        self.plot_att_mat(
            cosim_before,
            "{}.sent_{}.cosim_before.png".format(self.plot_attention, self.plot_attention_counter),
            600)
        cosim_after = cosine_similarity(out.npvalue().T)
        self.plot_att_mat(
            cosim_after,
            "{}.sent_{}.cosim_after.png".format(self.plot_attention, self.plot_attention_counter),
            600)
        self.plot_attention_counter += 1

    # Adding dropout and layer normalization
    if p > 0.0:
        res = dy.dropout(out, p) + residual
    else:
        res = out + residual
    ret = self.layer_norm.transform(res)
    return ret