def test_neg_conv_relu(data_shape):
    # conv + relu can't be fused in this case
    # eg.1
    # conv -----------> relu
    #  |
    #  |
    #  ---------------> [custom op]
    class NegConvReLU(nn.HybridBlock):
        def __init__(self, **kwargs):
            super(NegConvReLU, self).__init__(**kwargs)
            self.conv1 = nn.Conv2D(channels=64,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   use_bias=False)
            self.act = nn.Activation('relu')
            self.pool = nn.AvgPool2D(pool_size=(4, 4))
            self.tailneg = TailNegBlock()

        def forward(self, x):
            conv = self.conv1(x)
            relu = self.act(conv)
            pool = self.pool(conv)
            return self.tailneg(relu, pool)

    attrs = []
    excluded_attrs = []
    net = NegConvReLU()
    check_neg_fusion(net, attrs, excluded_attrs, data_shape)
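
# check_neg_fusion comes from the shared subgraph test helpers and is not defined in this
# file. The helper below is only a rough sketch of what such a negative-fusion check could
# look like for the single-input case; the backend name, the node-name matching and the use
# of the internal _cached_graph attribute are assumptions, not the real helper's code.
def _check_neg_fusion_sketch(net, attrs, excluded_attrs, data_shape, name='conv'):
    data = mx.np.random.uniform(size=data_shape)    # dummy input of the tested shape
    net.initialize()
    net.optimize_for(data, backend='ONEDNN')        # partition the graph for the oneDNN backend
    _, sym = net._cached_graph                      # internal: (inputs, partitioned symbol)
    for node_name, node_attrs in sym.attr_dict().items():
        if name in node_name:                       # the (partially) fused operator node
            for attr in attrs:                      # fusion attrs that should still be present
                assert attr in node_attrs
            for attr in excluded_attrs:             # fusion attrs that must not appear
                assert attr not in node_attrs
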
def test_neg_conv_add(data_shape):
  # conv + add can't be fused in this case
  # eg.1
  #  ---------------> [custom op]
  #  |
  #  |
  # conv -----------> add
  #                   |
  #                   |
  # add_value -------->
  class NegConvAdd(nn.HybridBlock):
    def __init__(self, **kwargs):
      super(NegConvAdd, self).__init__(**kwargs)
      self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=(1,1), use_bias=False)
      self.act = nn.Activation('relu')
      self.pool = nn.AvgPool2D(pool_size=(4,4))
      self.tailneg = TailNegBlock()
      self.add_value = mx.gluon.Parameter('add_value', init=mx.init.Xavier(magnitude=2.24),
                                          dtype='float32', allow_deferred_init=True)

    def forward(self, x):
      conv = self.conv1(x)
      sum1 = conv + self.add_value.data(x.device)
      pool = self.pool(conv)
      return self.tailneg(sum1, pool)
    
    def infer_shape(self, x):
      self.add_value.shape = (data_shape[0], 64, data_shape[2]-2, data_shape[3]-2)

  attrs = []
  excluded_attrs = ['with_sum']
  net = NegConvAdd()
  check_neg_fusion(net, attrs, excluded_attrs, data_shape)
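
# NegConvAdd's 'add_value' parameter uses allow_deferred_init=True: its shape depends on the
# input shape, so allocation is postponed until the first forward pass, when Gluon invokes the
# infer_shape(x) hook above to resolve it. For the 3x3, stride-1, unpadded convolution the add
# operand has to match the conv output, i.e. (N, 64, H-2, W-2), which is exactly what
# infer_shape sets.
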
def test_neg_conv_bn(data_shape):
  # conv + bn can't be fused in this case
  # eg.1
  # conv -----------> bn
  #  |
  #  |
  #  -------------> [custom op]
  class NegConvBN(nn.HybridBlock):
    def __init__(self, **kwargs):
      super(NegConvBN, self).__init__(**kwargs)
      self.conv1 = nn.Conv2D(channels=64, kernel_size=(3, 3), strides=(1,1), use_bias=False)
      self.bn1 = nn.BatchNorm()
      self.pool = nn.AvgPool2D(pool_size=(4,4))
      self.tailneg = TailNegBlock()

    def forward(self, x):
      conv = self.conv1(x)
      bn = self.bn1(conv)
      pool = self.pool(conv)

      return self.tailneg(bn, pool)

  attrs = []
  excluded_attrs = []
  net = NegConvBN()
  check_neg_fusion(net, attrs, excluded_attrs, data_shape)
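
# TailNegBlock is defined in the shared test helpers, not in this file. The class below is only
# an illustrative sketch of the role it plays here (an assumption, not the real implementation):
# it consumes both graph branches, so the intermediate tensor routed to the "[custom op]" branch
# keeps an external consumer and the pattern must stay unfused.
class _TailNegBlockSketch(nn.HybridBlock):
    def __init__(self, **kwargs):
        super(_TailNegBlockSketch, self).__init__(**kwargs)
        self.fc1 = nn.Dense(units=10, flatten=True)
        self.fc2 = nn.Dense(units=10, flatten=True)

    def forward(self, x1, x2):
        # keep both inputs alive in the output so neither branch can be pruned away
        return mx.np.concatenate([self.fc1(x1), self.fc2(x2)], axis=-1)
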
def test_neg_fc_add(data_shape, add_op, flatten, fc_out_add, scaled_fc_out):
    '''
    Test that a FullyConnected operator is not fused with 'add_op' when its output has
    consumers other than a single 'add_op' input.
    See NegFCAdd for the graph being used (a rough sketch also follows this test).
    '''
    flatten = (flatten == 'flat')
    num_hidden = 10
    net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, flatten)
    if flatten:
        data_shapes = [data_shape, (data_shape[0], num_hidden)]
    else:
        data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
    attrs = []
    excluded_attrs = ['with_sum']
    check_neg_fusion(net, attrs, excluded_attrs, data_shapes, name='fc')
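
# NegFCAdd itself is defined in the shared test helpers. For orientation only, the graph it
# builds looks roughly like this (the names below are assumptions, not the real helper's code):
#
#     fc_out = FullyConnected(x, num_hidden, flatten=flatten)
#     sum1   = add_op(fc_out, second_input)    # optionally with fc_out scaled first
#     tail   = [custom op](fc_out)             # extra consumer of the FC output
#
# Because fc_out is consumed outside the add as well, folding the add into the FullyConnected
# node ('with_sum') would be illegal, hence that attribute is expected to be excluded.
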
def test_neg_conv_bn_relu(data_shape):
    # conv + bn + relu can't be fused in these cases
    # eg.1
    #   --------------> [custom op]
    #   |
    # conv -----------> bn -----------> relu
    #
    # eg.2
    #                   --------------> [custom op]
    #                   |
    # conv -----------> bn -----------> relu
    class NegConvBNRelu(nn.HybridBlock):
        def __init__(self, batchnorm_pool=False, **kwargs):
            super(NegConvBNRelu, self).__init__(**kwargs)
            self.conv1 = nn.Conv2D(channels=64,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   use_bias=False)
            self.bn = nn.BatchNorm()
            self.act = nn.Activation('relu')
            self.pool = nn.AvgPool2D(pool_size=(4, 4))
            self.tailneg = TailNegBlock()
            self.batchnorm_pool = batchnorm_pool

        def forward(self, x):
            conv = self.conv1(x)
            bn = self.bn(conv)
            relu = self.act(bn)
            pool = self.pool(bn) if self.batchnorm_pool else self.pool(conv)
            return self.tailneg(relu, pool)

    # eg.1 ([custom op] = pool)
    net1 = NegConvBNRelu()
    attrs1 = []
    excluded_attrs1 = []
    check_neg_fusion(net1, attrs1, excluded_attrs1, data_shape)

    # eg.2 ([custom op] = pool)
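    # Here the custom op (pool) consumes the bn output: conv + bn can still be fused
    # ('with_bn'), but relu cannot be folded in because the bn result is needed outside
    # the fused node ('with_act' excluded).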
    net2 = NegConvBNRelu(batchnorm_pool=True)
    attrs2 = ['with_bn']
    excluded_attrs2 = ['with_act']
    check_neg_fusion(net2, attrs2, excluded_attrs2, data_shape)


def test_neg_fc_relu(data_shape, use_bias, flatten):
    # fc + relu can't be fused in this case
    # eg.1
    # fc -----------> relu
    #  |
    #  |
    #  ---------------> [custom op]
    class NegFCReLU(nn.HybridBlock):
        def __init__(self, use_bias, flatten, **kwargs):
            super(NegFCReLU, self).__init__(**kwargs)
            self.fc = nn.Dense(units=64, use_bias=use_bias, flatten=flatten)
            self.act1 = nn.Activation('relu')
            self.act2 = nn.Activation('sigmoid')
            self.tail_neg = TailNegBlock()

        def forward(self, x):
            fc_out = self.fc(x)
            return self.tail_neg(self.act1(fc_out), self.act2(fc_out))

    attrs, excluded_attrs = [], []
    net = NegFCReLU(use_bias, flatten)
    check_neg_fusion(net, attrs, excluded_attrs, data_shape, name='fc')


def test_neg_conv_bn_add_relu(data_shape):
    # conv + bn + add + relu can't be fused in these cases
    # eg.1
    #   --------------> [custom op]
    #   |
    # conv -----------> bn -----------> add -----------> relu
    #
    # eg.2
    #                    -------------> [custom op]
    #                    |
    # conv -----------> bn -----------> add -----------> relu
    #
    # eg.3
    #                                    --------------> [custom op]
    #                                    |
    # conv -----------> bn -----------> add -----------> relu

    class NegConvBNAddRelu(nn.HybridBlock):
        def __init__(self, connect_mode="conv_customop", **kwargs):
            super(NegConvBNAddRelu, self).__init__(**kwargs)
            self.conv1 = nn.Conv2D(channels=64,
                                   kernel_size=(3, 3),
                                   strides=(1, 1),
                                   use_bias=False)
            self.bn = nn.BatchNorm()
            self.act = nn.Activation('relu')
            self.pool = nn.AvgPool2D(pool_size=(4, 4))
            self.tailneg = TailNegBlock()
            self.connect_mode = connect_mode
            self.add_value = mx.gluon.Parameter(
                'add_value',
                init=mx.init.Xavier(magnitude=2.24),
                dtype='float32',
                allow_deferred_init=True)

        def forward(self, x):
            conv = self.conv1(x)
            bn = self.bn(conv)
            sum1 = bn + self.add_value.data(x.device)
            relu = self.act(sum1)
            if self.connect_mode == "conv_customop":
                pool = self.pool(conv)
            elif self.connect_mode == "bn_customop":
                pool = self.pool(bn)
            else:
                pool = self.pool(sum1)
            return self.tailneg(relu, pool)

        def infer_shape(self, x):
            # bn preserves the conv output shape, so add_value must be (N, 64, H-2, W-2)
            self.add_value.shape = (data_shape[0], 64, data_shape[2] - 2, data_shape[3] - 2)
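
    # In each case the custom op (pool) reads from a different point of the
    # conv -> bn -> add -> relu chain; fusion can only extend up to the last tensor that
    # has no consumer outside the chain:
    #   eg.1: pool consumes the conv output, so nothing past conv gets fused.
    #   eg.2: pool consumes the bn output, so only bn is fused ('with_bn').
    #   eg.3: pool consumes the sum, so bn and the add are fused ('with_bn', 'with_sum'),
    #         but the post-sum relu is not ('with_postsum_act' excluded).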

    # eg.1
    net1 = NegConvBNAddRelu(connect_mode="conv_customop")
    attrs1 = []
    excluded_attrs1 = ['with_sum', 'with_postsum_act', 'with_bn']
    check_neg_fusion(net1, attrs1, excluded_attrs1, data_shape)

    # eg.2
    net2 = NegConvBNAddRelu(connect_mode="bn_customop")
    attrs2 = ['with_bn']
    excluded_attrs2 = ['with_sum', 'with_postsum_act']
    check_neg_fusion(net2, attrs2, excluded_attrs2, data_shape)

    # eg.3
    net3 = NegConvBNAddRelu(connect_mode="add_customop")
    attrs3 = ['with_bn', 'with_sum']
    excluded_attrs3 = ['with_postsum_act']
    check_neg_fusion(net3, attrs3, excluded_attrs3, data_shape)
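
# These tests are presumably driven by pytest parametrization declared at the top of the file
# (not shown in this section): data_shape iterating over a set of NCHW input shapes, and
# use_bias, flatten, add_op, fc_out_add and scaled_fc_out covering the operator variants.
# The decorator below is only an assumed illustration of that setup, not the actual fixtures:
#
#     @pytest.mark.parametrize('data_shape', [(4, 4, 10, 10), (32, 3, 24, 24)])
#     def test_neg_conv_bn(data_shape):
#         ...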