示例#1
0
    def forward(self, x):
        """Run the split-attention block on ``x`` and return the re-weighted features.

        The input goes through ``self.conv`` (then ``self.bn0`` when
        ``self.use_bn`` is set) and ``self.relu``. When ``self.radix > 1`` the
        result is split along axis 1 into ``int(self.radix)`` groups which are
        summed before pooling; otherwise the feature map is pooled directly.
        The pooled descriptor passes through ``fc1`` / optional ``bn1`` /
        ``relu`` / ``fc2`` / ``rsoftmax`` and the resulting attention, reshaped
        to ``(batch, -1, 1, 1)``, scales the group(s).

        Args:
            x: input tensor; only ``x.shape[0]`` is read here, as the batch size.

        Returns:
            The attention-weighted sum of the groups (``radix > 1``) or the
            attention-weighted feature map (otherwise).
        """
        feat = self.conv(x)
        batch_size = x.shape[0]
        if self.use_bn:
            feat = self.bn0(feat)
        feat = self.relu(feat)

        # Squeeze: fuse the radix groups (if any) into one map before pooling.
        if self.radix > 1:
            splits = F.split(feat, int(self.radix), axis=1)
            pooled = sum(splits)
        else:
            pooled = feat
        pooled = F.adaptive_avg_pool2d(pooled, 1)

        # Excite: two projections with an optional batch-norm in between.
        pooled = self.fc1(pooled)
        if self.use_bn:
            pooled = self.bn1(pooled)
        pooled = self.relu(pooled)
        atten = self.rsoftmax(self.fc2(pooled))
        atten = atten.reshape(batch_size, -1, 1, 1)

        if self.radix > 1:
            atten_groups = F.split(atten, int(self.radix), axis=1)
            return sum(
                [weight * group for weight, group in zip(atten_groups, splits)]
            )
        return atten * feat
示例#2
0
    def forward(self, x):
        """Run the split-attention block on ``x``.

        Pipeline: ``conv`` -> optional ``bn0`` -> optional drop-block ->
        ``relu``; the feature map is split along axis 1 into ``self.radix``
        groups (when ``radix > 1``), summed, globally average-pooled and passed
        through ``fc1`` / optional ``bn1`` / ``fc2`` / ``rsoftmax`` to obtain
        the attention that re-weights the groups.

        Args:
            x: input tensor fed to ``self.conv``.

        Returns:
            The attention-weighted sum of the radix groups, or the
            attention-weighted feature map when ``radix <= 1``.
        """
        # do the conv
        net = self.conv(x)
        if self.use_bn:
            net = self.bn0(net)

        if self.droupblock_prob > 0.0:
            net = self.droupblock(net)

        net = self.relu(net)
        # split from the channels
        batch = net.shape[0]

        if self.radix > 1:
            splited = F.split(net, self.radix, axis=1)
            gap = sum(splited)
        else:
            # Without this branch ``gap`` was unbound for radix <= 1 and the
            # pooling below raised UnboundLocalError.
            gap = net
        # calculate the attention
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)

        if self.use_bn:
            gap = self.bn1(gap)
        # NOTE(review): the reference ResNeSt implementation applies a ReLU
        # here before fc2 -- confirm the omission is intentional.

        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).reshape(batch, -1, 1, 1)

        if self.radix > 1:
            attens = F.split(atten, self.radix, axis=1)

            out = sum([att * split for (att, split) in zip(attens, splited)])
        else:
            # The attention must scale the conv feature map, not the raw
            # input ``x`` (whose channel count may differ from ``net``).
            out = atten * net

        return out
示例#3
0
def test_split():
    """Check that the first piece of ``F.split`` matches ``np.split``.

    The last axis (length 5) is split once by section count (2) and once by
    explicit indices ``[3, 5]``; both first pieces must equal the first piece
    produced by ``np.split(data, [3, 5], axis=3)``.
    """
    data = np.random.random((2, 3, 4, 5)).astype(np.float32)

    by_count = F.split(tensor(data), 2, axis=3)
    by_index = F.split(tensor(data), [3, 5], axis=3)
    expected = np.split(data, [3, 5], axis=3)

    for reference in (by_index[0].numpy(), expected[0]):
        np.testing.assert_equal(by_count[0].numpy(), reference)
示例#4
0
    def _ternary_transform_mge(image):
        """Return the soft ternary (census-like) transform of ``image``.

        ``image`` is unpacked as a 4-element shape whose second entry is the
        channel count; 3-channel input is converted to a single intensity
        channel with the weights 0.2989 / 0.5870 / 0.1140, 1-channel input is
        used as is. Each pixel is then compared with the pixels of its
        ``patch_size`` x ``patch_size`` neighbourhood (gathered by a
        convolution with a one-hot kernel) and the differences are normalised
        as ``d / sqrt(0.81 + d**2)``.

        ``patch_size`` and ``max_distance`` come from the enclosing scope.

        Raises:
            ValueError: if the channel count is neither 3 nor 1.
        """
        _, channels, _, _ = image.shape
        if channels == 3:
            red, green, blue = F.split(image, 3, 1)
            # Weighted sum of the three channels -> one intensity channel.
            gray = 0.2989 * red + 0.5870 * green + 0.1140 * blue
        elif channels == 1:
            gray = image
        else:
            raise ValueError('image channel should be 3 or 1: %s' % channels)

        # One-hot kernels: output channel k copies the k-th pixel of the patch.
        n_offsets = patch_size * patch_size
        one_hot = np.eye(n_offsets).reshape(
            (patch_size, patch_size, 1, n_offsets))  # (h, w, 1, out_c)
        kernel = np.transpose(one_hot, (3, 2, 0, 1))  # (out_c, 1, h, w)
        weight = mge.tensor(kernel.astype(np.float32))

        patches = F.nn.conv2d(
            inp=gray,
            weight=weight,
            bias=None,
            stride=[1, 1],
            padding=[max_distance, max_distance],
        )
        diff = patches - gray
        return diff / F.sqrt(0.81 + diff**2)
示例#5
0
def test_split():
    """Exercise ``F.split`` by section count and by index list.

    Both calls must return two pieces that match the first two pieces of
    ``np.split(data, [3, 5], axis=3)``. Splitting axis 0 (length 2) into 4
    sections and passing the section list ``[3, 3, 5]`` must each raise
    ``ValueError``.
    """
    data = np.random.random((2, 3, 4, 5)).astype(np.float32)
    inp = tensor(data)

    by_count = F.split(inp, 2, axis=3)
    by_index = F.split(inp, [3], axis=3)
    expected = np.split(data, [3, 5], axis=3)

    assert len(by_count) == 2
    assert len(by_index) == 2

    for idx, want in enumerate(expected[:2]):
        np.testing.assert_equal(by_count[idx].numpy(), want)
        np.testing.assert_equal(by_index[idx].numpy(), want)

    # More sections than elements along the default axis must be rejected.
    try:
        F.split(inp, 4)
    except ValueError:
        pass
    else:
        assert False

    # An invalid section list must be rejected with the exact library message.
    try:
        F.split(inp, [3, 3, 5], axis=3)
    except ValueError as err:
        assert str(err) == "Invalid nsplits_or_secions: [3, 3, 5]"
    else:
        assert False
示例#6
0
    def backward(self, label, gm):
        """Back-propagate every cached input chunk through this stage.

        For each chunk stored in ``self.inp_chunks`` the stage's features are
        recomputed under ``gm``. On rank 3 the classifier head and the
        cross-entropy loss against the matching label chunk are evaluated and
        back-propagated; on every other rank the gradient obtained from
        ``grad_fr_next_gpu()`` is back-propagated through the features.
        Every rank except 0 then passes the chunk's input gradient to
        ``grad_to_prev_gpu``.

        Args:
            label: labels, split into 4 chunks with ``F.split``.
            gm: gradient manager used as a context manager and for
                ``attach`` / ``backward``.

        Returns:
            The sum of the chunk losses divided by ``self.num_chunks``, or
            ``None`` when no loss was computed on this rank.
        """
        label_parts = F.split(label, 4)
        chunk_losses = []
        rank = dist.get_rank()

        for idx, chunk in enumerate(self.inp_chunks):
            with gm:
                gm.attach(chunk)  # query gradient of the input
                out = self.features(chunk)

                if rank != 3:
                    upstream_grad = grad_fr_next_gpu()
                    gm.backward(out, dy=upstream_grad)
                else:
                    out = F.avg_pool2d(out, 7)
                    out = F.flatten(out, 1)
                    out = self.classifier(out)
                    chunk_loss = F.nn.cross_entropy(out, label_parts[idx])
                    chunk_losses.append(chunk_loss)
                    gm.backward(chunk_loss)

                if rank != 0:
                    _ = grad_to_prev_gpu(chunk.grad)

        if not chunk_losses:
            return None
        return sum(chunk_losses) / self.num_chunks
示例#7
0
    def forward(self, x):
        """Run this stage's forward pass over 4 chunks.

        Rank 0 splits ``x`` into chunks with ``F.split``; every other rank
        obtains each chunk from ``recv_fr_prev_gpu()`` and caches it in
        ``self.inp_chunks``. Each chunk goes through ``self.features``; rank 3
        additionally applies pooling, flattening and ``self.classifier``,
        while all other ranks hand the features to ``send_to_next_gpu``.

        Args:
            x: input tensor; only read on rank 0.

        Returns:
            The per-chunk outputs of this rank concatenated with ``F.concat``.
        """
        self.num_chunks = 4
        self.inp_chunks = []
        self.oup_chunks = []

        rank = dist.get_rank()
        if rank == 0:
            self.inp_chunks = F.split(x, self.num_chunks)

        for idx in range(self.num_chunks):
            if rank == 0:
                chunk = self.inp_chunks[idx]
            else:
                chunk = recv_fr_prev_gpu()
                self.inp_chunks.append(chunk)

            feat = self.features(chunk)
            if rank == 3:
                feat = F.avg_pool2d(feat, 7)
                feat = F.flatten(feat, 1)
                feat = self.classifier(feat)
            else:
                _ = send_to_next_gpu(feat)
            self.oup_chunks.append(feat)

        return F.concat(self.oup_chunks)
示例#8
0
def test_split_basic(is_varnode):
    """Exercise ``F.split`` on tensors and, optionally, on graph var-nodes.

    Args:
        is_varnode: when true, the input is created inside a ``Network`` and
            symbolic shape is switched off for the duration of the test.

    Both a section count (2) and an index list (``[3]``) must return two
    pieces matching ``np.split(data, [3, 5], axis=3)``; over-splitting axis 0
    and the section list ``[3, 2, 5]`` must raise ``ValueError``.
    """
    if is_varnode:
        network = Network()
        saved_symbolic_shape = set_symbolic_shape(False)
    else:
        network = None

    # Restore the global symbolic-shape flag even when an assertion fails, so
    # a failure here cannot leak the setting into later tests.
    try:
        data = np.random.random((2, 3, 4, 5)).astype(np.float32)
        inp = make_tensor(data, network)

        mge_out0 = F.split(inp, 2, axis=3)
        mge_out1 = F.split(inp, [3], axis=3)

        np_out = np.split(data, [3, 5], axis=3)

        assert len(mge_out0) == 2
        assert len(mge_out1) == 2

        np.testing.assert_equal(mge_out0[0].numpy(), np_out[0])
        np.testing.assert_equal(mge_out1[0].numpy(), np_out[0])

        np.testing.assert_equal(mge_out0[1].numpy(), np_out[1])
        np.testing.assert_equal(mge_out1[1].numpy(), np_out[1])

        # More sections than elements along the default axis is an error.
        try:
            F.split(inp, 4)
            assert False
        except ValueError:
            pass

        # An invalid section list is an error with this exact message.
        try:
            F.split(inp, [3, 2, 5], axis=3)
            assert False
        except ValueError as e:
            assert str(e) == "Invalid nsplits_or_secions: [3, 2, 5]"
    finally:
        if is_varnode:
            set_symbolic_shape(saved_symbolic_shape)
示例#9
0
 def func(inp, nsplits_or_sections, axis):
     """Forward all three arguments to ``F.split`` and return its result."""
     return F.split(inp, nsplits_or_sections, axis=axis)
示例#10
0
 ("relu", MF.relu, TF.relu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
 ("relu6", MF.relu6, TF.relu6, [(100, 100)], [(64, 512, 16, 16)], True,
  1000),
 (
     "repeat",
     lambda x: MF.repeat(x, 5),
     lambda x: torch.repeat_interleave(x, 5),
     [(100, 100)],
     [(64, 512, 16, 16)],
     True,
     1000,
 ),
 ("silu", MF.silu, TF.silu, [(100, 100)], [(64, 512, 16, 16)], True, 1000),
 (
     "split",
     lambda x: MF.split(x, 5),
     # MF.split(x, 5) cuts axis 0 into 5 sections, whereas torch.split(x, 5)
     # cuts it into chunks of length 5 (20 / 13 pieces for the shapes below).
     # torch.chunk(x, 5) produces the same 5 sections, so both columns
     # benchmark the same amount of work.
     lambda x: torch.chunk(x, 5),
     [(100, 100)],
     [(64, 512, 16, 16)],
     True,
     1000,
 ),
 ("sigmoid", MF.sigmoid, TF.sigmoid, [(100, 100)], [(64, 512, 16, 16)],
  True, 1000),
 (
     "softmax",
     lambda x: MF.softmax(x, axis=1),
     lambda x: TF.softmax(x, dim=1),
     [(100, 100)],
     [(64, 512, 16, 16)],
     True,
示例#11
0
    def forward(self, features, label=None, mask=None):
        """
        Compute the (supervised) contrastive loss.

        If label and mask are both None, the loss degenerates to the
        SimCLR unsupervised loss.
        Reference:
            "A Simple Framework for Contrastive Learning of Visual Representations"<https://arxiv.org/pdf/2002.05709.pdf>
            "Supervised Contrastive Learning"<https://arxiv.org/abs/2004.11362>
        Args:
            features(tensor): The embedding feature. shape=[bs, n_views, ...]
            label(tensor): The label of images, shape=[bs]
            mask(tensor): contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        return:
            loss
        Raises:
            ValueError: if features has fewer than 3 dimensions, if label and
                mask are both given, or if ``self.contrast_mode`` is unknown.
            RuntimeError: if the number of labels differs from the batch size.
        """
        if len(features.shape) < 3:
            raise ValueError("Features need have 3 dimensions at least")
        bs, num_view = features.shape[:2]
        #if dimension > 3, change the shape of the features to [bs, num_view, -1]
        if len(features.shape) > 3:
            features = features.reshape(bs, num_view, -1)

        #label and mask cannot be provided at the same time
        if (label is not None) and (mask is not None):
            raise ValueError("label and mask cannot provided at the same time")
        elif (label is None) and (mask is None):
            mask = F.eye(bs, dtype="float32")
        elif label is not None:
            label = label.reshape(-1, 1)
            if label.shape[0] != bs:
                raise RuntimeError(
                    "Num of labels does not match num of features")
            # F.equal yields a boolean tensor; cast it so that the mask can
            # take part in the float arithmetic below like the other branches.
            mask = F.equal(label, label.T).astype("float32")
        else:
            mask = mask.astype("float32")

        contrast_count = features.shape[1]
        # Keep ``features`` as a tensor: the "one" mode below indexes it, which
        # is impossible on the list returned by F.split.
        views = F.split(features, contrast_count, axis=1)
        contrast_feature = F.squeeze(F.concat(views, axis=0), axis=1)
        if self.contrast_mode == "one":
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == "all":
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError("Unknown mode:{}".format(self.contrast_mode))
        #compute logits
        anchor_dot_contrast = F.div(
            F.matmul(anchor_feature, contrast_feature.T), self.temperate)

        #for numerical stability
        logits_max = F.max(anchor_dot_contrast, axis=-1, keepdims=True)
        logits = anchor_dot_contrast - logits_max

        #tile mask to [bs * anchor_count, bs * contrast_count]; the mask has
        #to be repeated block-wise (stacking copies and reshaping interleaves
        #the rows instead of tiling them)
        mask_row = F.concat([mask] * int(contrast_count), axis=1)
        mask = F.concat([mask_row] * int(anchor_count), axis=0)
        # mask-out self-contrast cases: zero the diagonal of an all-ones
        # matrix; the scattered zeros use the same dtype as the destination
        num_anchor = int(bs * anchor_count)
        logits_mask = F.scatter(
            F.ones_like(mask), 1,
            F.arange(0, num_anchor, dtype="int32").reshape(-1, 1),
            F.zeros(num_anchor, dtype="float32").reshape(-1, 1))
        mask = mask * logits_mask
        #compute log_prob
        exp_logits = F.exp(logits) * logits_mask
        log_prob = logits - F.log(F.sum(exp_logits, axis=1,
                                        keepdims=True))  #equation 2

        #mean of the log-likelihood over the positives of every anchor
        mean_log_prob_pos = F.sum(mask * log_prob, axis=1) / F.sum(mask,
                                                                   axis=1)

        #loss
        loss = -(self.temperate / self.base_temperate) * mean_log_prob_pos
        loss = F.mean(loss.reshape(anchor_count, bs))
        return loss