def construct(self, s_t_hat, encoder_outputs, encoder_feature, enc_padding_mask, coverage):
    '''calculate attention distribution and context vector, with optional coverage'''
    b, t_k, n = encoder_outputs.shape
    dec_fea = self.decode_proj(s_t_hat)  # (B, 2 * hidden_dim)
    dec_fea_expand = P.ExpandDims()(dec_fea, 1)
    dec_fea_expand = P.BroadcastTo((b, t_k, n))(dec_fea_expand)
    att_features = encoder_feature + dec_fea_expand
    if self.is_coverage:
        coverage_input = coverage.view(-1, 1)  # (B * t_k, 1)
        coverage_feature = self.W_c(coverage_input)  # (B * t_k, 2 * hidden_dim)
        att_features = att_features + coverage_feature

    e = P.Tanh()(att_features)  # (B * t_k, 2 * hidden_dim)
    scores = self.v(e)  # (B * t_k, 1)
    scores = scores.view(-1, t_k)  # (B, t_k)

    attn_dist_ = P.Softmax(1)(scores) * enc_padding_mask  # (B, t_k)
    normalization_factor = P.ReduceSum(True)(attn_dist_, 1)
    attn_dist = attn_dist_ / normalization_factor

    attn_dist = P.ExpandDims()(attn_dist, 1)  # (B, 1, t_k)
    c_t = P.BatchMatMul()(attn_dist, encoder_outputs)  # (B, 1, n)
    c_t = c_t.view(-1, self.hidden_dim * 2)  # (B, 2 * hidden_dim)

    attn_dist = attn_dist.view(-1, t_k)

    if self.is_coverage:
        coverage = coverage.view(-1, t_k)
        coverage = coverage + attn_dist

    return c_t, attn_dist, coverage
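# A minimal, self-contained sketch (separate from the module above) of the masked-softmax
# renormalization and BatchMatMul context-vector step used in the attention construct.
# The shapes, random inputs, and the padding mask below are illustrative assumptions.
import numpy as np
import mindspore
from mindspore import Tensor, ops as P

b, t_k, n = 2, 5, 8  # assumed batch size, source length, 2 * hidden_dim
scores = Tensor(np.random.randn(b, t_k), mindspore.float32)
enc_padding_mask = Tensor(np.array([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 1]]), mindspore.float32)
encoder_outputs = Tensor(np.random.randn(b, t_k, n), mindspore.float32)

attn_dist_ = P.Softmax(1)(scores) * enc_padding_mask       # zero out padded positions
attn_dist = attn_dist_ / P.ReduceSum(True)(attn_dist_, 1)  # renormalize over real tokens
c_t = P.BatchMatMul()(P.ExpandDims()(attn_dist, 1), encoder_outputs)  # (b, 1, n)
print(attn_dist.sum(axis=1), c_t.shape)                    # rows sum to 1, (2, 1, 8)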
def variable_recurrent(self, x, h, seq_length):
    '''recurrent steps with sequence length'''
    time_step = range(x.shape[0])
    h_t = h
    if self.is_lstm:
        hidden_size = h[0].shape[-1]
        zero_output = P.ZerosLike()(h_t[0])
    else:
        hidden_size = h.shape[-1]
        zero_output = P.ZerosLike()(h_t)
    seq_length = P.BroadcastTo((hidden_size, -1))(seq_length)
    seq_length = P.Transpose()(seq_length, (1, 0))

    outputs = []
    state_t = h_t
    for t in time_step:
        h_t = self.cell(x[t], state_t)
        seq_cond = seq_length > t
        if self.is_lstm:
            state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
            state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1])
            output = P.Select()(seq_cond, h_t[0], zero_output)
            state_t = (state_t_0, state_t_1)
        else:
            state_t = P.Select()(seq_cond, h_t, state_t)
            output = P.Select()(seq_cond, h_t, zero_output)
        outputs.append(output)
    outputs = P.Stack()(outputs)
    return outputs, state_t
def variable_recurrent(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
    '''recurrent steps with sequence length'''
    time_step = x.shape[0]
    h_t = h
    if self.is_lstm:
        hidden_size = h[0].shape[-1]
        zero_output = P.ZerosLike()(h_t[0])
    else:
        hidden_size = h.shape[-1]
        zero_output = P.ZerosLike()(h_t)
    seq_length = P.Cast()(seq_length, mindspore.float32)
    seq_length = P.BroadcastTo((hidden_size, -1))(seq_length)
    seq_length = P.Cast()(seq_length, mindspore.int32)
    seq_length = P.Transpose()(seq_length, (1, 0))

    outputs = []
    state_t = h_t
    t = 0
    while t < time_step:
        x_t = x[t:t + 1:1]
        x_t = P.Squeeze(0)(x_t)
        h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh)
        seq_cond = seq_length > t
        if self.is_lstm:
            state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
            state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1])
            output = P.Select()(seq_cond, h_t[0], zero_output)
            state_t = (state_t_0, state_t_1)
        else:
            state_t = P.Select()(seq_cond, h_t, state_t)
            output = P.Select()(seq_cond, h_t, zero_output)
        outputs.append(output)
        t += 1
    outputs = P.Stack()(outputs)
    return outputs, state_t
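# A minimal, self-contained sketch (separate from the class above) of the P.Select
# masking trick used in both variable_recurrent variants: once a sequence is past its
# true length, its hidden state is frozen and its output is zeroed. The shapes, the
# fake "+1" cell update, and the lengths below are illustrative assumptions.
import numpy as np
import mindspore
from mindspore import Tensor, ops as P

batch, hidden = 3, 4
seq_length = Tensor(np.array([2, 3, 1]), mindspore.int32)       # true length per sample
state = Tensor(np.zeros((batch, hidden)), mindspore.float32)
zero_output = P.ZerosLike()(state)

# broadcast lengths to (batch, hidden) so the comparison below is elementwise
lengths = P.Transpose()(P.BroadcastTo((hidden, -1))(seq_length), (1, 0))

for t in range(3):
    new_state = state + 1.0                                     # stands in for self.cell(...)
    seq_cond = lengths > t                                      # True while t is a real step
    state = P.Select()(seq_cond, new_state, state)              # freeze finished sequences
    output = P.Select()(seq_cond, new_state, zero_output)       # zero outputs past the end
print(state)  # each row ends up equal to its sequence length: 2, 3, 1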
def construct(self, inputs, targets):
    """
    Args:
    - inputs: feature matrix with shape (batch_size, feat_dim)
    - targets: ground truth labels with shape (batch_size)
    """
    n = inputs.shape[0]

    # Compute pairwise distance; replace with the official op once it is merged
    pow = P.Pow()
    sum = P.ReduceSum(keep_dims=True)
    expand = P.BroadcastTo((n, n))
    transpose = P.Transpose()
    mul = P.Mul()
    add = P.Add()
    sqrt = P.Sqrt()
    equal = P.Equal()
    cat = P.Concat()
    ones_like = P.OnesLike()

    dist = pow(inputs, 2)
    dist = sum(dist, axis=1)
    dist = expand(dist)
    dist = dist + transpose(dist, (1, 0))

    temp1 = P.matmul(inputs, transpose(inputs, (1, 0)))
    temp1 = mul(-2, temp1)
    dist = add(dist, temp1)
    # clamp for numerical stability; a large clip_value_max acts as "no upper bound"
    dist = P.composite.clip_by_value(dist, clip_value_min=1e-12, clip_value_max=100000000)
    dist = sqrt(dist)

    # For each anchor, find the hardest positive and negative
    targets = expand(targets)
    mask = equal(targets, transpose(targets, (1, 0)))
    dist_ap = []
    dist_an = []
    for i in range(n):
        minval = -1.0
        maxval = -1.0
        for j in range(n):
            if mask[i][j] and dist[i][j] > maxval:
                maxval = dist[i][j]
            if not mask[i][j] and (dist[i][j] < minval or minval == -1):
                minval = dist[i][j]
        if (not isinstance(minval, Tensor) or not isinstance(maxval, Tensor)
                or minval == -1.0 or maxval == -1.0):
            if self.error_msg is not None:
                print("Error Msg", file=self.error_msg)
                print("mask {} is".format(i), file=self.error_msg)
                print(mask[i], file=self.error_msg)
                print("dist is:", file=self.error_msg)
                print(dist[i], file=self.error_msg)
                print(maxval, file=self.error_msg)
                print(minval, file=self.error_msg)
                print(type(maxval), file=self.error_msg)
                print(type(minval), file=self.error_msg)
                self.error_msg.flush()
        # assert minval != -1.0 and isinstance(minval, Tensor)
        # assert maxval != -1.0 and isinstance(maxval, Tensor)
        dist_ap.append(maxval.asnumpy())
        dist_an.append(minval.asnumpy())

    dist_ap = Tensor(dist_ap, ms.float32)
    dist_an = Tensor(dist_an, ms.float32)

    # Compute ranking hinge loss
    y = ones_like(dist_an)
    loss = self.ranking_loss(dist_an, dist_ap, y)
    return loss


# class GradOriTripletLoss(nn.Cell):
#     def __init__(self, net):
#         super(GradOriTripletLoss, self).__init__()
#         self.net = net
#         self.grad_op = P.GradOperation(get_all=True)
#
#     def construct(self, inputs, targets):
#         gradient_function = self.grad_op(self.net)
#         return gradient_function(inputs, targets)
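# A minimal, self-contained sketch of the same pipeline as the triplet-loss construct
# above: squared-norm expansion for the pairwise distance matrix, hardest positive /
# negative mining per anchor, and a margin ranking (hinge) loss. The toy 4x2 embeddings,
# the margin value, and the NumPy-side mining are illustrative assumptions, not the
# module's actual configuration.
import numpy as np
import mindspore as ms
from mindspore import Tensor, ops as P

inputs = Tensor(np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0]]), ms.float32)
targets = np.array([0, 0, 1, 1])                   # two samples per identity
n = inputs.shape[0]

# dist[i, j] = ||x_i||^2 + ||x_j||^2 - 2 * x_i . x_j, clamped then square-rooted
sq = P.ReduceSum(True)(P.Pow()(inputs, 2.0), 1)    # (n, 1) squared norms
dist = P.BroadcastTo((n, n))(sq)
dist = dist + P.Transpose()(dist, (1, 0)) - 2 * P.MatMul()(inputs, P.Transpose()(inputs, (1, 0)))
dist = P.Sqrt()(P.composite.clip_by_value(dist, 1e-12, 1e8)).asnumpy()

mask = targets[:, None] == targets[None, :]        # True where labels match
dist_ap = np.array([dist[i][mask[i]].max() for i in range(n)])   # hardest positive
dist_an = np.array([dist[i][~mask[i]].min() for i in range(n)])  # hardest negative

margin = 0.3                                       # assumed margin
loss = np.maximum(0.0, dist_ap - dist_an + margin).mean()        # hinge on the gap
print(loss)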