def test_pad(self):
    """Compare Jittor 2-D padding layers with the matching PyTorch layers.

    For each layer type a fresh random array of shape (16, 3, 224, 224) is
    drawn; every constructor-argument tuple listed for that type is used to
    build a Jittor layer and a PyTorch layer of the same class name, and
    ``check_equal`` compares the two on that array.
    """
    # (layer class name, constructor argument tuples to try with it)
    layer_cases = (
        ("ReplicationPad2d", [(10,), ((1, 23, 4, 5),), ((1, 0, 1, 5),), (100,)]),
        ("ConstantPad2d", [(10, -2), ((2, 3, 34, 1), 10.2)]),
        ("ZeroPad2d", [(1,), ((2, 3, 34, 1),)]),
        ("ReflectionPad2d", [(20,), ((2, 3, 34, 1),), ((10, 123, 34, 1),), (100,)]),
    )
    for layer_name, arg_sets in layer_cases:
        # One fresh random input per layer type.
        arr = np.random.randn(16, 3, 224, 224)
        jittor_cls = getattr(jnn, layer_name)
        torch_cls = getattr(tnn, layer_name)
        for args in arg_sets:
            check_equal(arr, jittor_cls(*args), torch_cls(*args))
def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
    """Build a two-convolution block as an ``nn.Sequential``.

    Layer order: [pad] -> Conv(dim, dim, 3) -> norm_layer(dim) -> activation
    -> [Dropout(0.5)] -> [pad] -> Conv(dim, dim, 3) -> norm_layer(dim).

    Args:
        dim: channel count used for both convolutions and both norm layers.
        padding_type: 'reflect' or 'replicate' insert an explicit 1-pixel
            padding layer before each conv (conv padding 0); 'zero' inserts
            no padding layer and gives each conv padding=1 instead.
        norm_layer: callable taking a channel count and returning a layer.
        activation: object appended as-is after the first norm layer
            (presumably a layer instance; the same object is used, not a copy).
        use_dropout: when truthy, a Dropout(0.5) follows the activation.

    Returns:
        ``nn.Sequential`` containing the layers above.

    Raises:
        NotImplementedError: if ``padding_type`` is not a supported value.
    """

    def _pad_layers(kind):
        """Return (explicit padding layers, conv padding) for padding ``kind``."""
        if kind == 'reflect':
            return [nn.ReflectionPad2d(1)], 0
        if kind == 'replicate':
            return [nn.ReplicationPad2d(1)], 0
        if kind == 'zero':
            return [], 1
        # Wrap in a 1-tuple so a tuple-valued kind is reported, not a TypeError.
        raise NotImplementedError('padding [%s] is not implemented' % (kind,))

    conv_block = []

    # First conv: helper is called per conv so each gets its own pad layer.
    pad_layers, p = _pad_layers(padding_type)
    conv_block += pad_layers
    conv_block += [nn.Conv(dim, dim, 3, padding=p), norm_layer(dim), activation]
    if use_dropout:
        conv_block += [nn.Dropout(0.5)]

    # Second conv (no activation after its norm layer).
    pad_layers, p = _pad_layers(padding_type)
    conv_block += pad_layers
    conv_block += [nn.Conv(dim, dim, 3, padding=p), norm_layer(dim)]

    return nn.Sequential(*conv_block)
def test_pad(self):
    """Check Jittor padding layers and ``jnn.pad`` against PyTorch.

    Random inputs are fed to matching Jittor/PyTorch layer pairs through
    ``check_equal``; then ``jnn.pad`` is compared with
    ``torch.nn.functional.pad`` for each listed padding mode.
    """
    # ***************************************************************
    # Test ReplicationPad2d Layer
    # ***************************************************************
    arr = np.random.randn(16, 3, 224, 224)
    check_equal(arr, jnn.ReplicationPad2d(10), tnn.ReplicationPad2d(10))
    check_equal(arr, jnn.ReplicationPad2d((1, 23, 4, 5)), tnn.ReplicationPad2d((1, 23, 4, 5)))
    check_equal(arr, jnn.ReplicationPad2d((1, 0, 1, 5)), tnn.ReplicationPad2d((1, 0, 1, 5)))
    check_equal(arr, jnn.ReplicationPad2d(100), tnn.ReplicationPad2d(100))

    # ***************************************************************
    # Test ConstantPad2d Layer (4-D input, then 5-D input)
    # ***************************************************************
    arr = np.random.randn(16, 3, 224, 224)
    check_equal(arr, jnn.ConstantPad2d(10, -2), tnn.ConstantPad2d(10, -2))
    check_equal(arr, jnn.ConstantPad2d((2, 3, 34, 1), 10.2), tnn.ConstantPad2d((2, 3, 34, 1), 10.2))
    arr = np.random.randn(16, 3, 224, 10, 10)
    check_equal(arr, jnn.ConstantPad2d(10, -2), tnn.ConstantPad2d(10, -2))
    check_equal(arr, jnn.ConstantPad2d((2, 3, 34, 1), 10.2), tnn.ConstantPad2d((2, 3, 34, 1), 10.2))

    # ***************************************************************
    # Test ZeroPad2d Layer
    # ***************************************************************
    arr = np.random.randn(16, 3, 224, 224)
    check_equal(arr, jnn.ZeroPad2d(1), tnn.ZeroPad2d(1))
    check_equal(arr, jnn.ZeroPad2d((2, 3, 34, 1)), tnn.ZeroPad2d((2, 3, 34, 1)))

    # ***************************************************************
    # Test ReflectionPad2d Layer
    # ***************************************************************
    arr = np.random.randn(16, 3, 224, 224)
    check_equal(arr, jnn.ReflectionPad2d(20), tnn.ReflectionPad2d(20))
    check_equal(arr, jnn.ReflectionPad2d((2, 3, 34, 1)), tnn.ReflectionPad2d((2, 3, 34, 1)))
    check_equal(arr, jnn.ReflectionPad2d((10, 123, 34, 1)), tnn.ReflectionPad2d((10, 123, 34, 1)))
    check_equal(arr, jnn.ReflectionPad2d(100), tnn.ReflectionPad2d(100))

    # ***************************************************************
    # Test function pad
    # ***************************************************************
    arr = np.random.randn(16, 3, 224, 224)
    padding = (10, 11, 2, 3)
    # The inputs do not depend on the mode, so convert them once.
    j_data = jt.array(arr)
    t_data = torch.tensor(arr)
    for mode in ['constant', 'replicate', 'reflect', 'circular']:
        t_output = tnn.functional.pad(t_data, padding, mode=mode).detach().numpy()
        j_output = jnn.pad(j_data, padding, mode).numpy()
        assert np.allclose(t_output, j_output), "jnn.pad mismatch for mode=%r" % (mode,)
def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0,
             norm='none', activation='relu', pad_type='zero'):
    """Create the padding, convolution, normalization and activation layers.

    Padding is done by a separate ``self.pad`` layer; the convolution itself is
    constructed without a padding argument. (Presumably ``forward`` applies
    pad -> conv -> norm -> activation; forward is not shown here.)

    Args:
        input_dim: input channel count of the convolution.
        output_dim: output channel count of the convolution; also the channel
            count used for the normalization layer.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: amount passed to the padding layer.
        norm: 'bn', 'in', 'adain' or 'none' (``self.norm`` is None).
        activation: 'relu', 'tanh' or 'none' (``self.activation`` is None).
        pad_type: 'reflect', 'replicate' or 'zero'.

    Raises:
        ValueError: if ``pad_type``, ``norm`` or ``activation`` is unsupported.
            An explicit raise is used instead of ``assert`` so the check
            still runs under ``python -O``.
    """
    super(ConvBlock, self).__init__()
    self.use_bias = True

    # initialize padding
    if pad_type == 'reflect':
        self.pad = nn.ReflectionPad2d(padding)
    elif pad_type == 'replicate':
        self.pad = nn.ReplicationPad2d(padding)
    elif pad_type == 'zero':
        self.pad = nn.ZeroPad2d(padding)
    else:
        raise ValueError("Unsupported padding type: {}".format(pad_type))

    # initialize normalization
    norm_dim = output_dim
    if norm == 'bn':
        self.norm = nn.BatchNorm(norm_dim)
    elif norm == 'in':
        self.norm = nn.InstanceNorm2d(norm_dim)
    elif norm == 'adain':
        self.norm = AdaptiveInstanceNorm2d(norm_dim)
    elif norm == 'none':
        self.norm = None
    else:
        raise ValueError("Unsupported normalization: {}".format(norm))

    # initialize activation
    if activation == 'relu':
        self.activation = nn.ReLU()
    elif activation == 'tanh':
        self.activation = nn.Tanh()
    elif activation == 'none':
        self.activation = None
    else:
        raise ValueError("Unsupported activation: {}".format(activation))

    # No padding argument here: self.pad handles it.
    self.conv = nn.Conv(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)