Example #1
    def construct(self, s_t_hat, encoder_outputs, encoder_feature,
                  enc_padding_mask, coverage):
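        """Compute one step of attention over the encoder outputs."""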
        b, t_k, n = encoder_outputs.shape

        dec_fea = self.decode_proj(s_t_hat)  # (B, 2 * hidden_dim)
        dec_fea_expand = P.ExpandDims()(dec_fea, 1)  # (B, 1, 2 * hidden_dim)
        dec_fea_expand = P.BroadcastTo((b, t_k, n))(dec_fea_expand)
        dec_fea_expand = dec_fea_expand.view(-1, n)  # (B * t_k, 2 * hidden_dim)

        att_features = encoder_feature + dec_fea_expand
        if self.is_coverage:
            coverage_input = coverage.view(-1, 1)  # (B * t_k, 1)
            coverage_feature = self.W_c(
                coverage_input)  # (B * t_k, 2 * hidden_dim)
            att_features = att_features + coverage_feature

        e = P.Tanh()(att_features)  # (B * t_k, 2 * hidden_dim)
        scores = self.v(e)  # (B * t_k, 1)
        scores = scores.view(-1, t_k)  # (B, t_k)

        attn_dist_ = P.Softmax(1)(scores) * enc_padding_mask  # (B, t_k)
        normalization_factor = P.ReduceSum(True)(attn_dist_, 1)
        attn_dist = attn_dist_ / normalization_factor

        attn_dist = P.ExpandDims()(attn_dist, 1)  # (B, 1, t_k)
        c_t = P.BatchMatMul()(attn_dist, encoder_outputs)  # (B, 1, n)
        c_t = c_t.view(-1, self.hidden_dim * 2)  # (B, 2 * hidden_dim)

        attn_dist = attn_dist.view(-1, t_k)

        if self.is_coverage:
            coverage = coverage.view(-1, t_k)
            coverage = coverage + attn_dist

        return c_t, attn_dist, coverage
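
The masked softmax above first zeroes out attention on padded encoder positions and then renormalizes each row. A minimal standalone sketch of that step (the tensor values are illustrative, not from the source):

import mindspore as ms
import mindspore.ops as P
from mindspore import Tensor

scores = Tensor([[1.0, 2.0, 3.0], [0.5, 0.1, 0.2]], ms.float32)  # (B, t_k)
mask = Tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]], ms.float32)    # 1 = real token

attn = P.Softmax(1)(scores) * mask         # zero out padded positions
attn = attn / P.ReduceSum(True)(attn, 1)   # renormalize each row to sum to 1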
Example #2
    def variable_recurrent(self, x, h, seq_length):
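        '''recurrent steps with per-sample sequence lengths'''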
        time_step = range(x.shape[0])
        h_t = h
        if self.is_lstm:
            hidden_size = h[0].shape[-1]
            zero_output = P.ZerosLike()(h_t[0])
        else:
            hidden_size = h.shape[-1]
            zero_output = P.ZerosLike()(h_t)

        seq_length = P.BroadcastTo((hidden_size, -1))(seq_length)
        seq_length = P.Transpose()(seq_length, (1, 0))

        outputs = []
        state_t = h_t
        for t in time_step:
            h_t = self.cell(x[t], state_t)
            seq_cond = seq_length > t
            if self.is_lstm:
                state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
                state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1])
                output = P.Select()(seq_cond, h_t[0], zero_output)
                state_t = (state_t_0, state_t_1)
            else:
                state_t = P.Select()(seq_cond, h_t, state_t)
                output = P.Select()(seq_cond, h_t, zero_output)
            outputs.append(output)
        outputs = P.Stack()(outputs)
        return outputs, state_t
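
P.Select is what freezes finished sequences: once the step counter t reaches a sample's length, its state stops updating and its output becomes zeros. A small sketch with assumed shapes (batch of 2, hidden size 2):

import mindspore as ms
import mindspore.ops as P
from mindspore import Tensor

# Lengths already tiled to (B, hidden): sample 0 has length 3, sample 1 length 1.
seq_length = Tensor([[3, 3], [1, 1]], ms.int32)
t = 2
seq_cond = seq_length > t                   # (B, hidden) boolean mask

h_new = Tensor([[0.9, 0.9], [0.7, 0.7]], ms.float32)  # candidate state at step t
h_old = Tensor([[0.1, 0.1], [0.5, 0.5]], ms.float32)  # state carried from t - 1
state = P.Select()(seq_cond, h_new, h_old)  # sample 0 updates, sample 1 stays frozen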
Example #3
    def variable_recurrent(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
        '''recurrent steps with sequence length'''
        time_step = x.shape[0]
        h_t = h
        if self.is_lstm:
            hidden_size = h[0].shape[-1]
            zero_output = P.ZerosLike()(h_t[0])
        else:
            hidden_size = h.shape[-1]
            zero_output = P.ZerosLike()(h_t)
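        # The float32 round-trip below is presumably a workaround for backends
        # without an int32 BroadcastTo kernel; the length values are unchanged.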
        seq_length = P.Cast()(seq_length, mindspore.float32)
        seq_length = P.BroadcastTo((hidden_size, -1))(seq_length)
        seq_length = P.Cast()(seq_length, mindspore.int32)
        seq_length = P.Transpose()(seq_length, (1, 0))

        outputs = []
        state_t = h_t
        t = 0
        while t < time_step:
            x_t = x[t:t + 1:1]
            x_t = P.Squeeze(0)(x_t)
            h_t = self.cell(x_t, state_t, w_ih, w_hh, b_ih, b_hh)
            seq_cond = seq_length > t
            if self.is_lstm:
                state_t_0 = P.Select()(seq_cond, h_t[0], state_t[0])
                state_t_1 = P.Select()(seq_cond, h_t[1], state_t[1])
                output = P.Select()(seq_cond, h_t[0], zero_output)
                state_t = (state_t_0, state_t_1)
            else:
                state_t = P.Select()(seq_cond, h_t, state_t)
                output = P.Select()(seq_cond, h_t, zero_output)
            outputs.append(output)
            t += 1
        outputs = P.Stack()(outputs)
        return outputs, state_t
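
The length preparation above tiles the (B,) length vector across the hidden dimension and transposes it to (B, hidden_size), so it can be compared elementwise against the step counter. For example (values illustrative):

import mindspore as ms
import mindspore.ops as P
from mindspore import Tensor

seq_length = Tensor([3, 1], ms.int32)              # (B,)
hidden_size = 4
tiled = P.Cast()(seq_length, ms.float32)
tiled = P.BroadcastTo((hidden_size, -1))(tiled)    # (hidden_size, B)
tiled = P.Cast()(tiled, ms.int32)
tiled = P.Transpose()(tiled, (1, 0))               # (B, hidden_size)
# tiled -> [[3, 3, 3, 3], [1, 1, 1, 1]]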
Example #4
    def construct(self, inputs, targets):
        """
        Args:
        - inputs: feature matrix with shape (batch_size, feat_dim)
        - targets: ground truth labels with shape (num_classes)
        """
        n = inputs.shape[0]

        # Compute the pairwise distance matrix; replace this with the official
        # op once it is merged upstream.
        pow_op = P.Pow()
        reduce_sum = P.ReduceSum(keep_dims=True)
        expand = P.BroadcastTo((n, n))
        transpose = P.Transpose()
        mul = P.Mul()
        add = P.Add()
        sqrt = P.Sqrt()
        equal = P.Equal()
        ones_like = P.OnesLike()

        # dist[i][j] = ||x_i||^2 + ||x_j||^2 - 2 * <x_i, x_j>
        dist = pow_op(inputs, 2)
        dist = reduce_sum(dist, 1)
        dist = expand(dist)
        dist = dist + transpose(dist, (1, 0))

        temp1 = P.matmul(inputs, transpose(inputs, (1, 0)))
        temp1 = mul(-2, temp1)
        dist = add(dist, temp1)
        # Clip for numerical stability: the floor removes tiny negatives left
        # by floating-point cancellation and keeps the sqrt gradient finite at
        # zero; clip_by_value requires both bounds here, so the ceiling is set
        # large enough to never bind.
        dist = P.composite.clip_by_value(
            dist, clip_value_min=1e-12, clip_value_max=1e8)
        dist = sqrt(dist)

        # For each anchor, find the hardest positive and negative
        targets = expand(targets)
        mask = equal(targets, transpose(targets, (1, 0)))
        dist_ap = []
        dist_an = []

        for i in range(n):
            minval = -1.0
            maxval = -1.0
            for j in range(n):
                if mask[i][j] and dist[i][j] > maxval:
                    maxval = dist[i][j]
                if not mask[i][j] and (dist[i][j] < minval or minval == -1):
                    minval = dist[i][j]

            if (not isinstance(minval, Tensor)
                    or not isinstance(maxval, Tensor) or minval == -1.0
                    or maxval == -1.0):
                if self.error_msg is not None:
                    print("Error Msg", file=self.error_msg)
                    print("mask {} is".format(i), file=self.error_msg)
                    print(mask[i], file=self.error_msg)
                    print("dist is:", file=self.error_msg)
                    print(dist[i], file=self.error_msg)
                    print(maxval, file=self.error_msg)
                    print(minval, file=self.error_msg)
                    print(type(maxval), file=self.error_msg)
                    print(type(minval), file=self.error_msg)
                    self.error_msg.flush()

            # assert minval != -1.0 and isinstance(minval, Tensor)
            # assert maxval != -1.0 and isinstance(maxval, Tensor)
            dist_ap.append(maxval.asnumpy())
            dist_an.append(minval.asnumpy())

        dist_ap = Tensor(dist_ap, ms.float32)
        dist_an = Tensor(dist_an, ms.float32)

        # Compute ranking hinge loss
        y = ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)

        # # compute accuracy
        # correct = torch.ge(dist_an, dist_ap).sum().item()
        return loss


# class GradOriTripletLoss(nn.Cell):
#     def __init__(self, net):
#         super(GradOriTripletLoss, self).__init__()
#         self.net = net
#         self.grad_op = P.GradOperation(get_all=True)
#
#     def construct(self, inputs, targets):
#         gradient_function = self.grad_op(self.net)
#         return gradient_function(inputs, targets)
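
For comparison, the per-anchor Python loops in Example #4 are usually replaced by masked reductions. A hedged vectorized sketch of the same hard-example mining (hard_example_mining is a hypothetical helper, not part of the source):

import mindspore as ms
import mindspore.ops as P
from mindspore import Tensor

def hard_example_mining(dist, mask):
    '''dist: (n, n) pairwise distances; mask: (n, n) bool, True where labels match.'''
    maskf = P.Cast()(mask, ms.float32)
    # Hardest positive: the largest distance among same-label pairs.
    dist_ap = P.ReduceMax()(dist * maskf, 1)
    # Hardest negative: the smallest distance among different-label pairs;
    # same-label entries are pushed out of range before taking the minimum.
    dist_an = P.ReduceMin()(dist + maskf * Tensor(1e6, ms.float32), 1)
    return dist_ap, dist_an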