예제 #1
0
 def __init__(self, layer, time_axis, reshape_with_axis=None):
     """Wrap ``layer`` for application along a time axis.

     Args:
         layer: the ``Cell`` or ``Primitive`` instance to wrap.
         time_axis: time axis index, stored as ``self.time_axis``.
         reshape_with_axis: optional value stored as
             ``self.reshape_with_axis`` (``None`` by default).

     Raises:
         TypeError: if ``layer`` is neither a ``Cell`` nor a ``Primitive``
             instance. This is checked before the base initializer runs.
     """
     accepted = (Cell, Primitive)
     if not isinstance(layer, accepted):
         message = (
             "Please initialize TimeDistributed with mindspore.nn.Cell or "
             f"mindspore.ops.Primitive instance. You passed: {layer}"
         )
         raise TypeError(message)

     super().__init__()

     # Attribute order kept as-is: assigning a Cell attribute may register
     # it as a sub-cell, so registration order is preserved.
     self.layer = layer
     self.time_axis = time_axis
     self.reshape_with_axis = reshape_with_axis

     # Operator instances reused by the rest of the class.
     self.transpose = Transpose()
     self.reshape = Reshape()
예제 #2
0
    def construct(self, x, feat, t=None):
        """Refine ``feat`` with cross-part attention computed from ``x``.

        Part features are average-pooled from ``x`` and projected by
        ``self.fc1``, ``self.fc2`` and ``self.fc3``. A softmax attention map
        over all ``T * part`` positions (``fc1`` x ``fc2``) aggregates the
        ``fc3`` projection. The result is multiplied by ``self.gate``,
        reshaped to 4-D ``(N, C', 1, 1)``, added to ``feat`` and passed
        through ``self.bottleneck``.

        Args:
            x: 4-D tensor unpacked as ``(bt, c, h, w)``. Its leading axis
                holds ``b * t`` frames, which are regrouped as ``(b, t, ...)``.
            feat: tensor added to the gated part feature. Presumably its shape
                is broadcast-compatible with ``(N, C', 1, 1)``; confirm with
                the caller.
            t: number of frames per sequence. It defaults to ``None`` only to
                keep the existing signature. A value is required because the
                sequence batch size is computed as ``bt // t``.

        Returns:
            The output of ``self.bottleneck``.

        Raises:
            TypeError: if ``t`` is ``None``.
        """
        bt, c, _, _ = x.shape
        if t is None:
            # Fail with a clear message instead of an opaque `bt // None`.
            raise TypeError(
                "construct() needs 't', the number of frames per sequence")
        b = bt // t

        # Get part features. The view below requires the pooled map to hold
        # exactly `self.part` values per channel and frame.
        # NOTE(review): a new AvgPool2d cell is built on every call and its
        # (6, 9) kernel / (6, 1) stride are hard-coded. Consider creating it
        # once in __init__. Creating cells inside construct may not be
        # supported in graph mode; confirm for the targeted MindSpore version.
        part_feat_pool = nn.AvgPool2d(kernel_size=(6, 9), stride=(6, 1))
        part_feat = part_feat_pool(x)
        part_feat = part_feat.view(b, t, c, self.part)
        transpose = Transpose()
        part_feat = transpose(part_feat, (0, 2, 1, 3))  # B, C, T, Part

        part_feat1 = self.fc1(part_feat).view(b, self.inter_channels,
                                              -1)  # B, C//r, T*part
        part_feat1 = transpose(part_feat1, (0, 2, 1))  # B, T*part, C//r

        part_feat2 = self.fc2(part_feat).view(b, self.inter_channels,
                                              -1)  # B, C//r, T*part

        part_feat3 = self.fc3(part_feat).view(b, self.inter_channels,
                                              -1)  # B, C//r, T*part
        part_feat3 = transpose(part_feat3, (0, 2, 1))  # B, T*part, C//r

        # Get cross-part attention.
        cpa_att = mat_mul(part_feat1, part_feat2)  # B, T*part, T*part
        cpa_att = self.softmax(cpa_att)

        # Collect contextual information.
        refined_part_feat = P.matmul(cpa_att, part_feat3)  # B, T*Part, C//r
        refined_part_feat = transpose(refined_part_feat,
                                      (0, 2, 1))  # B, C//r, T*part
        # NOTE(review): the input here is (B, C//r, T*part), but this 3-D view
        # is (B, C//r, part), so the element counts only match when T == 1.
        # An earlier comment described the result as (B, C//r, T, part).
        # Confirm the intended shape before changing it.
        refined_part_feat = refined_part_feat.view(
            (b, self.inter_channels, self.part))

        # Weight the parts with the learned gate, then expand to 4-D so the
        # result can be added to `feat` and fed to the bottleneck.
        weight_part_feat = P.matmul(refined_part_feat, self.gate)
        weight_part_feat = weight_part_feat.view(
            (weight_part_feat.shape[0], weight_part_feat.shape[1], 1, 1))

        weight_part_feat = weight_part_feat + feat
        feat = self.bottleneck(weight_part_feat)

        return feat