def res_unit(x, scope_name, rng, dn=False, test=False):
    """Bottleneck residual unit built from binary-weight convolutions.

    Structure: 1x1 conv (C/2) -> BN -> ReLU -> 3x3 conv (C/2) -> BN -> ReLU
    -> 1x1 conv (C) -> BN, then add the identity shortcut and apply ReLU.
    Optionally followed by 2x2 max pooling.

    Args:
        x: input variable; its channel count C is read from ``x.shape[1]``.
        scope_name: parameter scope under which the unit's parameters live.
        rng: unused; kept only so existing positional callers keep working.
        dn: if True, apply 2x2 max pooling with stride 2 after the unit.
        test: if True, batch normalization uses stored statistics
            (``batch_stat=not test``).

    Returns:
        Output variable with C channels (spatially halved when ``dn``).
    """
    C = x.shape[1]
    # Output map counts must be integers; ``C / 2`` is a float in Python 3.
    half = C // 2
    with nn.parameter_scope(scope_name):
        # Conv -> BN -> Relu
        with nn.parameter_scope("conv1"):
            h = PF.binary_weight_convolution(x, half, kernel=(1, 1), pad=(0, 0),
                                             with_bias=False)
            h = PF.batch_normalization(h, batch_stat=not test)
            h = F.relu(h)
        # Conv -> BN -> Relu
        with nn.parameter_scope("conv2"):
            h = PF.binary_weight_convolution(h, half, kernel=(3, 3), pad=(1, 1),
                                             with_bias=False)
            h = PF.batch_normalization(h, batch_stat=not test)
            h = F.relu(h)
        # Conv -> BN (restores the original channel count for the shortcut add)
        with nn.parameter_scope("conv3"):
            h = PF.binary_weight_convolution(h, C, kernel=(1, 1), pad=(0, 0),
                                             with_bias=False)
            h = PF.batch_normalization(h, batch_stat=not test)
        # Residual -> Relu
        h = F.relu(h + x)
        # Maxpooling
        if dn:
            h = F.max_pooling(h, kernel=(2, 2), stride=(2, 2))
    return h
def res_unit(x, scope):
    """Bottleneck residual unit with binary-weight convolutions and ELU.

    Structure: 1x1 conv (C/2) -> bn -> ELU -> 3x3 conv (C/2) -> bn -> ELU
    -> 1x1 conv (C) -> bn, then ELU(x + h).

    NOTE(review): ``bn`` is not defined in this function; it must resolve to a
    module-level name when this is called — confirm one exists.

    Args:
        x: input variable; its channel count C is read from ``x.shape[1]``.
        scope: parameter scope under which the unit's parameters live.

    Returns:
        Output variable with C channels.
    """
    C = x.shape[1]
    # Output map counts must be integers; ``C / 2`` is a float in Python 3.
    half = C // 2
    with nn.parameter_scope(scope):
        with nn.parameter_scope('conv1'):
            h = F.elu(bn(PF.binary_weight_convolution(
                x, half, (1, 1), with_bias=False)))
        with nn.parameter_scope('conv2'):
            h = F.elu(bn(PF.binary_weight_convolution(
                h, half, (3, 3), pad=(1, 1), with_bias=False)))
        with nn.parameter_scope('conv3'):
            h = bn(PF.binary_weight_convolution(
                h, C, (1, 1), with_bias=False))
    return F.elu(x + h)
def res_unit(x, scope):
    """Bottleneck residual unit with binary-weight convolutions and ELU.

    Structure: 1x1 conv (C/2) -> bn -> ELU -> 3x3 conv (C/2) -> bn -> ELU
    -> 1x1 conv (C) -> bn, then ELU(x + h).

    NOTE(review): ``bn`` is not defined in this function; it must resolve to a
    module-level name when this is called — confirm one exists.

    Args:
        x: input variable; its channel count C is read from ``x.shape[1]``.
        scope: parameter scope under which the unit's parameters live.

    Returns:
        Output variable with C channels.
    """
    C = x.shape[1]
    # Output map counts must be integers; ``C / 2`` is a float in Python 3.
    half = C // 2
    with nn.parameter_scope(scope):
        with nn.parameter_scope('conv1'):
            h = F.elu(bn(PF.binary_weight_convolution(
                x, half, (1, 1), with_bias=False)))
        with nn.parameter_scope('conv2'):
            h = F.elu(bn(PF.binary_weight_convolution(
                h, half, (3, 3), pad=(1, 1), with_bias=False)))
        with nn.parameter_scope('conv3'):
            h = bn(PF.binary_weight_convolution(
                h, C, (1, 1), with_bias=False))
    return F.elu(x + h)
def mnist_binary_weight_lenet_prediction(image, test=False):
    """
    Construct LeNet for MNIST (Binary Weight Network version).

    Two stages of (5x5 binary-weight conv with 16 maps -> 2x2 average pooling
    -> ELU), then a 50-unit binary-weight affine layer with ELU, then a
    10-unit binary-weight affine output.

    Args:
        image: input image variable.
        test: accepted for interface parity; not used by this network.

    Returns:
        Output of the final 10-unit affine layer (no softmax applied).
    """
    feat = image
    # The two convolutional stages are identical apart from their scope name.
    for conv_scope in ("conv1", "conv2"):
        with nn.parameter_scope(conv_scope):
            feat = PF.binary_weight_convolution(feat, 16, (5, 5))
        # Pooling/activation create no parameters, so scope placement is irrelevant.
        feat = F.elu(F.average_pooling(feat, (2, 2)))

    with nn.parameter_scope("fc3"):
        hidden = F.elu(PF.binary_weight_affine(feat, 50))
    with nn.parameter_scope("fc4"):
        logits = PF.binary_weight_affine(hidden, 10)
    return logits
def mnist_binary_weight_lenet_prediction(image, test=False):
    """
    Construct LeNet for MNIST (Binary Weight Network version).

    Two stages of (5x5 binary-weight conv with 16 maps -> 2x2 average pooling
    -> ELU), then a 50-unit binary-weight affine layer with ELU, then a
    10-unit binary-weight affine output.

    Args:
        image: input image variable.
        test: accepted for interface parity; not used by this network.

    Returns:
        Output of the final 10-unit affine layer (no softmax applied).
    """
    feat = image
    # The two convolutional stages are identical apart from their scope name.
    for conv_scope in ("conv1", "conv2"):
        with nn.parameter_scope(conv_scope):
            feat = PF.binary_weight_convolution(feat, 16, (5, 5))
        # Pooling/activation create no parameters, so scope placement is irrelevant.
        feat = F.elu(F.average_pooling(feat, (2, 2)))

    with nn.parameter_scope("fc3"):
        hidden = F.elu(PF.binary_weight_affine(feat, 50))
    with nn.parameter_scope("fc4"):
        logits = PF.binary_weight_affine(hidden, 10)
    return logits
def mnist_binary_weight_resnet_prediction(image, test=False):
    """
    Construct ResNet for MNIST (Binary Weight Network version).

    Args:
        image: input image variable (presumably a (N, 1, 28, 28) MNIST batch,
            which the padded stem conv grows to 32x32 — TODO confirm).
        test: if True, batch normalization uses stored statistics
            (``batch_stat=not test``).

    Returns:
        Output of the 10-unit binary-weight affine classifier (no softmax).
    """
    def bn(x):
        return PF.batch_normalization(x, batch_stat=not test)

    def res_unit(x, scope):
        # Bottleneck: 1x1 (C/2) -> 3x3 (C/2) -> 1x1 (C), identity shortcut, ELU.
        C = x.shape[1]
        # Output map counts must be integers; ``C / 2`` is a float in Python 3.
        half = C // 2
        with nn.parameter_scope(scope):
            with nn.parameter_scope('conv1'):
                h = F.elu(bn(PF.binary_weight_convolution(
                    x, half, (1, 1), with_bias=False)))
            with nn.parameter_scope('conv2'):
                h = F.elu(bn(PF.binary_weight_convolution(
                    h, half, (3, 3), pad=(1, 1), with_bias=False)))
            with nn.parameter_scope('conv3'):
                h = bn(PF.binary_weight_convolution(
                    h, C, (1, 1), with_bias=False))
        return F.elu(x + h)

    # Conv1 --> 64 x 32 x 32
    with nn.parameter_scope("conv1"):
        c1 = F.elu(bn(PF.binary_weight_convolution(
            image, 64, (3, 3), pad=(3, 3), with_bias=False)))
    # Conv2 --> 64 x 16 x 16
    c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
    # Conv3 --> 64 x 8 x 8
    c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
    # Conv4 --> 64 x 8 x 8
    c4 = res_unit(c3, "conv4")
    # Conv5 --> 64 x 4 x 4
    c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
    # Conv6 --> 64 x 4 x 4
    c6 = res_unit(c5, "conv6")
    pl = F.average_pooling(c6, (4, 4))
    with nn.parameter_scope("classifier"):
        y = PF.binary_weight_affine(pl, 10)
    return y
def mnist_binary_weight_resnet_prediction(image, test=False):
    """
    Construct ResNet for MNIST (Binary Weight Network version).

    Args:
        image: input image variable (presumably a (N, 1, 28, 28) MNIST batch,
            which the padded stem conv grows to 32x32 — TODO confirm).
        test: if True, batch normalization uses stored statistics
            (``batch_stat=not test``).

    Returns:
        Output of the 10-unit binary-weight affine classifier (no softmax).
    """
    def bn(x):
        return PF.batch_normalization(x, batch_stat=not test)

    def res_unit(x, scope):
        # Bottleneck: 1x1 (C/2) -> 3x3 (C/2) -> 1x1 (C), identity shortcut, ELU.
        C = x.shape[1]
        # Output map counts must be integers; ``C / 2`` is a float in Python 3.
        half = C // 2
        with nn.parameter_scope(scope):
            with nn.parameter_scope('conv1'):
                h = F.elu(bn(PF.binary_weight_convolution(
                    x, half, (1, 1), with_bias=False)))
            with nn.parameter_scope('conv2'):
                h = F.elu(bn(PF.binary_weight_convolution(
                    h, half, (3, 3), pad=(1, 1), with_bias=False)))
            with nn.parameter_scope('conv3'):
                h = bn(PF.binary_weight_convolution(
                    h, C, (1, 1), with_bias=False))
        return F.elu(x + h)

    # Conv1 --> 64 x 32 x 32
    with nn.parameter_scope("conv1"):
        c1 = F.elu(bn(PF.binary_weight_convolution(
            image, 64, (3, 3), pad=(3, 3), with_bias=False)))
    # Conv2 --> 64 x 16 x 16
    c2 = F.max_pooling(res_unit(c1, "conv2"), (2, 2))
    # Conv3 --> 64 x 8 x 8
    c3 = F.max_pooling(res_unit(c2, "conv3"), (2, 2))
    # Conv4 --> 64 x 8 x 8
    c4 = res_unit(c3, "conv4")
    # Conv5 --> 64 x 4 x 4
    c5 = F.max_pooling(res_unit(c4, "conv5"), (2, 2))
    # Conv6 --> 64 x 4 x 4
    c6 = res_unit(c5, "conv6")
    pl = F.average_pooling(c6, (4, 4))
    with nn.parameter_scope("classifier"):
        y = PF.binary_weight_affine(pl, 10)
    return y
def cifar10_binary_weight_resnet23_prediction(image, maps=64, test=False):
    """
    Construct BinaryWeight using resnet23.

    Binary Weight binarizes weights, but uses approximate scaling coefficients
    to alleviate the binary quantization error.

    References:
        Rastegari Mohammad, Ordonez Vicente, Redmon Joseph, and Farhadi Ali,
        "XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks",
        arXiv:1603.05279

    Args:
        image: input image batch; divided by 255.0 before the first convolution
            (presumably (N, 3, 32, 32) CIFAR-10 images — TODO confirm).
        maps: number of feature maps produced by the stem convolution; every
            residual unit preserves this channel count.
        test: if True, skips data augmentation and makes batch normalization
            use stored statistics (``batch_stat=not test``).

    Returns:
        Output of the 10-unit binary-weight affine classifier (no softmax).
    """
    # Residual Unit
    def res_unit(x, scope_name, dn=False):
        # ``dn`` is the third positional parameter so calls such as
        # ``res_unit(h, "conv3", True)`` really downsample, and ``test`` is taken
        # from the enclosing function so inference uses stored BN statistics.
        C = x.shape[1]
        # Output map counts must be integers; ``C / 2`` is a float in Python 3.
        half = C // 2
        with nn.parameter_scope(scope_name):
            # Conv -> BN -> Relu
            with nn.parameter_scope("conv1"):
                h = PF.binary_weight_convolution(x, half, kernel=(1, 1), pad=(0, 0),
                                                 with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
                h = F.relu(h)
            # Conv -> BN -> Relu
            with nn.parameter_scope("conv2"):
                h = PF.binary_weight_convolution(h, half, kernel=(3, 3), pad=(1, 1),
                                                 with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
                h = F.relu(h)
            # Conv -> BN
            with nn.parameter_scope("conv3"):
                h = PF.binary_weight_convolution(h, C, kernel=(1, 1), pad=(0, 0),
                                                 with_bias=False)
                h = PF.batch_normalization(h, batch_stat=not test)
            # Residual -> Relu
            h = F.relu(h + x)
            # Maxpooling
            if dn:
                h = F.max_pooling(h, kernel=(2, 2), stride=(2, 2))
        return h

    ncls = 10

    # Conv -> BN -> Relu
    with nn.parameter_scope("conv1"):
        # Preprocess
        image /= 255.0
        if not test:
            image = F.image_augmentation(image, contrast=1.0,
                                         angle=0.25,
                                         flip_lr=True)
            image.need_grad = False
        h = PF.binary_weight_convolution(image, maps, kernel=(3, 3), pad=(1, 1),
                                         with_bias=False)
        h = PF.batch_normalization(h, batch_stat=not test)
        h = F.relu(h)

    h = res_unit(h, "conv2", False)    # -> 32x32
    h = res_unit(h, "conv3", True)     # -> 16x16
    h = res_unit(h, "conv4", False)    # -> 16x16
    h = res_unit(h, "conv5", True)     # -> 8x8
    h = res_unit(h, "conv6", False)    # -> 8x8
    h = res_unit(h, "conv7", True)     # -> 4x4
    h = res_unit(h, "conv8", False)    # -> 4x4
    h = F.average_pooling(h, kernel=(4, 4))  # -> 1x1
    pred = PF.binary_weight_affine(h, ncls)

    return pred
def cifar10_binary_weight_resnet23_prediction(image, maps=64, test=False):
    """
    Construct a BinaryWeight network using resnet23.

    Args:
        image: input image batch; scaled by 1/255 before the stem convolution.
        maps: feature maps of the stem convolution; residual units keep this
            channel count.
        test: if True, skips data augmentation and makes batch normalization
            use stored statistics (``batch_stat=not test``).

    Returns:
        Output of the 10-unit binary-weight affine layer (no softmax).
    """
    def conv_bn(inp, outmaps, kernel, pad, scope):
        # Binary-weight conv (no bias) followed by BN, both under ``scope``.
        with nn.parameter_scope(scope):
            out = PF.binary_weight_convolution(inp, outmaps, kernel=kernel,
                                               pad=pad, with_bias=False)
            return PF.batch_normalization(out, batch_stat=not test)

    def res_unit(x, scope_name, dn=False):
        # Bottleneck 1x1 -> 3x3 -> 1x1 with identity shortcut; optional 2x2 downsampling.
        channels = x.shape[1]
        bottleneck = channels // 2
        with nn.parameter_scope(scope_name):
            y = F.relu(conv_bn(x, bottleneck, (1, 1), (0, 0), "conv1"))
            y = F.relu(conv_bn(y, bottleneck, (3, 3), (1, 1), "conv2"))
            y = conv_bn(y, channels, (1, 1), (0, 0), "conv3")
            y = F.relu(y + x)
        if not dn:
            return y
        return F.max_pooling(y, kernel=(2, 2), stride=(2, 2))

    num_classes = 10

    # Preprocessing: scale to [0, 1]; augment only while training.
    image /= 255.0
    if not test:
        image = F.image_augmentation(image, contrast=1.0,
                                     angle=0.25,
                                     flip_lr=True)
        image.need_grad = False
    feat = F.relu(conv_bn(image, maps, (3, 3), (1, 1), "conv1"))

    # (scope, downsample) per residual stage; comments give the output resolution.
    stages = (
        ("conv2", False),  # -> 32x32
        ("conv3", True),   # -> 16x16
        ("conv4", False),  # -> 16x16
        ("conv5", True),   # -> 8x8
        ("conv6", False),  # -> 8x8
        ("conv7", True),   # -> 4x4
        ("conv8", False),  # -> 4x4
    )
    for stage_scope, downsample in stages:
        feat = res_unit(feat, stage_scope, downsample)

    pooled = F.average_pooling(feat, kernel=(4, 4))  # -> 1x1
    return PF.binary_weight_affine(pooled, num_classes)