import mindspore.numpy as mnp


def mnp_repeat(x):
    a = mnp.repeat(x, 2)                         # no axis: flatten, then repeat every element twice
    b = mnp.repeat(x, 3, axis=0)                 # repeat each slice 3 times along axis 0
    c = mnp.repeat(x, (4, 1, 5), axis=1)         # per-slice repeat counts along axis 1
    d = mnp.repeat(x, (3, 2, 1, 0, 4), axis=-1)  # per-slice repeat counts along the last axis
    e = mnp.repeat(x, 0)                         # zero repeats: returns an empty tensor
    return a, b, c, d, e
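# A usage sketch (assumption: a 3-D input of shape (2, 3, 5), so that the
# per-slice repeat counts (4, 1, 5) and (3, 2, 1, 0, 4) match the sizes of
# axis 1 and the last axis respectively).
x = mnp.ones((2, 3, 5))
a, b, c, d, e = mnp_repeat(x)
print(a.shape)  # (60,)       flattened, every element repeated twice
print(b.shape)  # (6, 3, 5)   repeated 3 times along axis 0
print(c.shape)  # (2, 10, 5)  axis 1 grows to 4 + 1 + 5 = 10
print(d.shape)  # (2, 3, 10)  last axis grows to 3 + 2 + 1 + 0 + 4 = 10
print(e.shape)  # (0,)        zero repeats yields an empty tensor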
import mindspore.numpy as mnp
import mindspore.ops as ops
from mindspore.ops import stop_gradient


def construct(self, teacher, student, neg):
    expand_dims = ops.ExpandDims()  # unsqueeze operator
    reduce_sum = ops.ReduceSum()
    teacher_vgg, student_vgg, neg_vgg = self.vgg(teacher), self.vgg(student), self.vgg(neg)
    loss = 0
    for i in range(len(teacher_vgg)):
        neg_i = expand_dims(neg_vgg[i], 0)  # [8, n_feats, w, h]
        # Tensor.repeat is only supported from MindSpore 1.3, so use mnp.repeat instead
        neg_i = mnp.repeat(neg_i, student_vgg[i].shape[0], axis=0)  # [16, 8, n_feats, w, h]
        neg_i = neg_i.transpose((1, 0, 2, 3, 4))  # [8, 16, n_feats, w, h]
        d_ts = self.l1(stop_gradient(teacher_vgg[i]), student_vgg[i])
        # Tensor.sum is only supported from MindSpore 1.3, so reduce with ops.ReduceSum
        d_sn = (stop_gradient(neg_i) - student_vgg[i]).abs()  # [8, 16, n_feats, w, h]
        d_sn = reduce_sum(d_sn, 0).mean()
        contrastive = d_ts / (d_sn + 1e-7)
        loss += self.weights[i] * contrastive
    return self.get_loss(loss)
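# A minimal sketch of the class around construct, assuming the loss derives from
# mindspore.nn.LossBase (which provides get_loss) and that `vgg` is a multi-scale
# feature extractor returning a list of feature maps. The class name and the
# default per-layer weights are illustrative placeholders, not values from the
# original code.
import mindspore.nn as nn


class ContrastiveLoss(nn.LossBase):
    def __init__(self, vgg, weights=(1.0, 1.0, 1.0, 1.0, 1.0)):
        super().__init__()
        self.vgg = vgg            # frozen VGG-style feature extractor
        self.l1 = nn.L1Loss()     # mean absolute error between feature maps
        self.weights = weights    # per-layer weights for the contrastive terms

    # construct(self, teacher, student, neg) as defined above completes the Cell.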