Esempio n. 1
0
 def call(self, x, *args, **kwargs):
     """Run every child on pooled features and restore the input resolution.

     For each module yielded by ``self.children()``, the running tensor is
     reduced with ``ops.mean`` (presumably a global average pool over the
     spatial dims -- TODO confirm against the backend), passed through the
     child, and resized back to the spatial size of ``x`` with
     ``ops.interpolate``.

     :param x: input feature map; dims ``[2:]`` of its shape are treated as
         the spatial size to restore
     :param args: unused
     :param kwargs: unused
     :return: the processed tensor, or ``x`` itself when there are no children
     """
     spatial_size = ops.get_shape(x)[2:]
     result = x
     for child in self.children():
         pooled = ops.mean(result)
         result = ops.interpolate(child(pooled), spatial_size)
     return result
Esempio n. 2
0
def channel_shuffle(x, groups):
    """Interleave the channels of a feature map across ``groups`` groups.

    The channel axis is split into ``(groups, channels // groups)``. Those two
    axes are then swapped, and the result is flattened back, so channels from
    different groups end up interleaved.

    :param x: 4-D feature map, unpacked as (batch, channels, height, width)
    :type x: tensor
    :param groups: number of channel groups; it should divide the channel
        count, otherwise the first reshape's element count does not match ``x``
    :type groups: int
    :return: tensor with the same shape as ``x`` and its channels reordered
    :rtype: tensor
    """
    n, c, h, w = ops.get_shape(x)
    per_group = c // groups
    # (n, c, h, w) -> (n, groups, per_group, h, w): split the channel axis.
    grouped = ops.View([n, groups, per_group, h, w])(x)
    # Swap the group and per-group axes so that channels interleave.
    swapped = ops.Transpose(1, 2)(grouped)
    # Flatten the two channel axes back into one.
    return ops.View([n, c, h, w])(swapped)
Esempio n. 3
0
 def call(self, x):
     """Fuse the feature maps selected by ``self.collect_inds`` and apply ReLU.

     The maps ``x[i]`` for ``i`` in ``self.collect_inds`` are folded left to
     right. Before each merge, the two operands' spatial shapes (dims ``[2:]``)
     are compared lexicographically. The one that compares smaller is
     bilinearly resized to the other's spatial size. Maps are merged with
     ``ops.concat`` when ``self.agg_concat`` is true; otherwise they are summed
     element-wise.

     :param x: indexable collection of feature maps
     :return: the fused feature map after ReLU
     """
     out = x[self.collect_inds[0]]
     for idx in self.collect_inds[1:]:
         collect = x[idx]
         out_size = ops.get_shape(out)[2:]
         collect_size = ops.get_shape(collect)[2:]
         # Compare full spatial shapes, not only the height. Maps with equal
         # height but a different width are then aligned before merging too.
         if tuple(out_size) > tuple(collect_size):
             # upsample collect
             collect = ops.interpolate(collect, size=out_size, mode='bilinear', align_corners=True)
         elif tuple(collect_size) > tuple(out_size):
             out = ops.interpolate(out, size=collect_size, mode='bilinear', align_corners=True)
         if self.agg_concat:
             out = ops.concat([out, collect])
         else:
             # Out-of-place add. On the first iteration ``out`` still aliases
             # the caller's ``x[self.collect_inds[0]]``, so an in-place ``+=``
             # would mutate that input (and can break autograd on backends
             # with in-place tensor ops).
             out = out + collect
     return ops.Relu()(out)
Esempio n. 4
0
    def call(self, x1, x2):
        """Merge two feature maps after matching their spatial sizes.

        When ``self.pre_transform`` is set, ``x1`` and ``x2`` are first passed
        through ``self.branch_1`` and ``self.branch_2`` respectively.

        :param x1: first input feature map
        :param x2: second input feature map
        :return: ``self.conv1x1`` applied to the concatenation of both maps
            when ``self.concat`` is true, otherwise their element-wise sum
        """
        if self.pre_transform:
            x1, x2 = self.branch_1(x1), self.branch_2(x2)
        size_1 = ops.get_shape(x1)[2:]
        size_2 = ops.get_shape(x2)[2:]
        # Spatial shapes are compared lexicographically (e.g. height, then
        # width for an NCHW layout -- TODO confirm). The map that compares
        # smaller is bilinearly resized to the other one's size.
        if tuple(size_1) < tuple(size_2):
            x1 = ops.interpolate(x1, size=size_2, mode='bilinear', align_corners=True)
        elif tuple(size_1) > tuple(size_2):
            x2 = ops.interpolate(x2, size=size_1, mode='bilinear', align_corners=True)
        if not self.concat:
            return x1 + x2
        return self.conv1x1(ops.concat([x1, x2]))