# Example #1
    def __init__(self, input_channels, n_filters, filter_size, dropout,
                 bias=True, dilation=1, stride=(1, 1), bn_momentum=0.1,
                 ini='random'):
        """
        It builds a block with: Batch Norm, Dropout, ReLU, Convolution.

        Input:
            - input_channels: int. Number of input feature maps.
            - n_filters: int. Number of output feature maps.
            - filter_size: int. Convolution filter size.
            - dropout: float. Percentage of dropout.
            - bias: bool. Bias in convolution.
            - dilation: int. Dilation rate for dilated convolution.
                        If 1, traditional convolution is used.
            - stride: int or tuple. Stride used in the convolution.
                      (Only applied when dilation == 1; the dilated branch
                      keeps the default stride of 1.)
            - bn_momentum: float. Batch-norm momentum.
            - ini: string. Initialization for the dilated convolution
                   weights. It can be 'random' or 'identity'.
        """
        super(BnReluConv, self).__init__()
        self.bn = nn.BatchNorm2d(input_channels, eps=0.001,
                                 momentum=bn_momentum)
        # NOTE: self.drop only exists when dropout > 0; the forward pass is
        # expected to gate on self.dropout before touching it.
        if dropout > 0:
            self.drop = nn.Dropout(dropout)

        if dilation == 1:
            # Standard convolution; padding keeps spatial size for odd kernels.
            self.conv = nn.Conv2d(input_channels, n_filters,
                                  kernel_size=filter_size,
                                  padding=(filter_size - 1) // 2, bias=bias,
                                  stride=stride)
        else:
            # Dilated convolution, used in the transformation blocks between
            # ResNets. Padding keeps spatial size for odd effective kernels:
            # effective size = filter_size + (filter_size - 1) * (dilation - 1).
            self.conv = nn.Conv2d(input_channels, n_filters,
                                  kernel_size=filter_size, dilation=dilation,
                                  padding=((filter_size + (filter_size - 1) * (
                                              dilation - 1)) - 1) // 2,
                                  bias=bias)

        # Initialize modules. Single pass instead of the duplicated loop the
        # original had in both branches; 'identity' (dirac) initialization is
        # only meaningful for the dilated variant.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if dilation != 1 and ini == 'identity':
                    dirac(m.weight)
                else:
                    kaiming_uniform(m.weight)
                # BUGFIX: with bias=False, m.bias is None and calling
                # .zero_() on it raised AttributeError; guard against that.
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        self.dropout = dropout
# Example #2
 def init_params(self, out_channels):
     """Create the per-channel scale parameters and the fixed Dirac buffer.

     alpha/beta are learnable per-output-channel scalars; 'delta' is a
     non-trainable Dirac-initialized tensor with the same shape as the
     convolution weight.
     """
     self.alpha = nn.Parameter(torch.ones(out_channels))
     self.beta = nn.Parameter(torch.full((out_channels,), 0.1))
     # Fixed identity kernel registered as a buffer (saved, not trained).
     identity = dirac(self.weight.data.clone())
     self.register_buffer('delta', identity)
     assert self.delta.size() == self.weight.size()
     # Broadcast shape for reshaping alpha/beta against the weight tensor.
     self.v = (-1,) + (1,) * (self.weight.dim() - 1)
# Example #3
def dirac_delta(ni, no, k):
    """Return a Dirac (identity) conv kernel of shape (no, ni, k, k).

    The base identity kernel is built on the smaller of the two channel
    counts and then tiled along the larger dimension.
    """
    base = min(ni, no)
    kernel = dirac(torch.Tensor(base, base, k, k))
    out_tiles = max(no // ni, 1)
    in_tiles = max(ni // no, 1)
    return kernel.repeat(out_tiles, in_tiles, 1, 1)
# Example #4
def dirac_delta(ni, no, k):
    """Return a Dirac (identity) conv kernel for an N-d kernel size tuple k.

    Generalization of the 2-d version: k is a tuple of spatial kernel
    sizes, so the result has shape (no, ni, *k) after tiling the base
    identity kernel along the larger channel dimension.
    """
    base = min(ni, no)
    kernel = dirac(torch.Tensor(*((base, base) + k)))
    tiles = (max(no // ni, 1), max(ni // no, 1)) + (1,) * len(k)
    return kernel.repeat(*tiles)
# Example #5
def define_diracnet(depth, width, dataset):
    """Build a functional DiracNet: forward function + flat param/stat dicts.

    Input:
        - depth: int. Network depth (CIFAR: depth = 6n + 4; ImageNet: 18 or 34).
        - width: width multiplier applied to the base channel counts.
        - dataset: str. 'CIFAR10', 'CIFAR100' or 'ImageNet'.

    Returns:
        (f, flat_params, flat_stats) where f(inputs, params, stats, mode)
        computes the forward pass from the flattened dictionaries.

    Raises:
        ValueError: if dataset is not a CIFAR variant or 'ImageNet'.

    NOTE(review): conv_params, cast, bnparams, bnstats, batch_norm, group,
    linear_params, flatten_params, flatten_stats and size2name are project
    helpers defined elsewhere in this file; their exact semantics are assumed
    from usage here.
    """
    def gen_group_params(ni, no, count):
        # One dict per block: conv weight, per-block alpha/beta scalars
        # (DiracNet's learnable mixing coefficients) and batch-norm params.
        # Only block0 takes ni input channels; later blocks are no -> no.
        return {
            'block%d' % i: {
                'conv': conv_params(ni if i == 0 else no, no, k=3, gain=1),
                'alpha': cast(torch.Tensor([1])),
                'beta': cast(torch.Tensor([0.1])),
                'bn': bnparams(no)
            }
            for i in range(count)
        }

    def gen_group_stats(no, count):
        # Running batch-norm statistics for each block in a group.
        return {'block%d' % i: {'bn': bnstats(no)} for i in range(count)}

    if dataset.startswith('CIFAR'):
        # depth = 6n + 4 -> n blocks per group stage, doubled below (n * 2).
        n = (depth - 4) // 6
        widths = torch.Tensor([16, 32, 64]).mul(width).int()

        def f(inputs, params, stats, mode):
            # Stem conv, then three groups separated by 2x max-pooling.
            o = F.conv2d(inputs, params['conv'], padding=1)
            o = F.relu(batch_norm(o, params, stats, 'bn', mode))
            o = group(o, params, stats, 'group0', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group1', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group2', mode, n * 2)
            # 8x8 average pooling assumes 32x32 inputs (CIFAR image size).
            o = F.avg_pool2d(F.relu(o), 8)
            o = F.linear(o.view(o.size(0), -1), params['fc.weight'],
                         params['fc.bias'])
            return o

        params = {
            'conv': cast(kaiming_normal(torch.Tensor(widths[0], 3, 3, 3))),
            'bn': bnparams(widths[0]),
            'group0': gen_group_params(widths[0], widths[0], n * 2),
            'group1': gen_group_params(widths[0], widths[1], n * 2),
            'group2': gen_group_params(widths[1], widths[2], n * 2),
            'fc': linear_params(widths[2],
                                10 if dataset == 'CIFAR10' else 100),
        }

        stats = {
            'group%d' % i: gen_group_stats(no, n * 2)
            for i, no in enumerate(widths)
        }
        stats['bn'] = bnstats(widths[0])

    elif dataset == 'ImageNet':
        # Blocks per group for the two supported ResNet-style depths.
        definitions = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}
        widths = torch.Tensor([64, 128, 256, 512]).mul(width).int()
        blocks = definitions[depth]

        def f(inputs, params, stats, mode):
            # 7x7 stride-2 stem followed by 3x3 stride-2 max-pool, as in ResNet.
            o = F.conv2d(inputs, params['conv'], padding=3, stride=2)
            o = batch_norm(o, params, stats, 'bn', mode)
            o = F.max_pool2d(o, 3, 2, 1)
            o = group(o, params, stats, 'group0', mode, blocks[0] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group1', mode, blocks[1] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group2', mode, blocks[2] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group3', mode, blocks[3] * 2)
            # Global average pool over whatever spatial size remains.
            o = F.avg_pool2d(F.relu(o), o.size(-1))
            o = F.linear(o.view(o.size(0), -1), params['fc.weight'],
                         params['fc.bias'])
            return o

        params = {
            'conv': cast(kaiming_normal(torch.Tensor(widths[0], 3, 7, 7))),
            'group0': gen_group_params(widths[0], widths[0], 2 * blocks[0]),
            'group1': gen_group_params(widths[0], widths[1], 2 * blocks[1]),
            'group2': gen_group_params(widths[1], widths[2], 2 * blocks[2]),
            'group3': gen_group_params(widths[2], widths[3], 2 * blocks[3]),
            'bn': bnparams(widths[0]),
            'fc': linear_params(widths[-1], 1000),
        }

        stats = {
            'group%d' % i: gen_group_stats(no, 2 * b)
            for i, (no, b) in enumerate(zip(widths, blocks))
        }
        stats['bn'] = bnstats(widths[0])
    else:
        raise ValueError('dataset not understood')

    # Flatten nested dicts into dotted-key dicts ('group0.block0.conv', ...).
    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    # For every conv weight, register a fixed Dirac (identity) tensor of the
    # same shape in the stats dict, keyed by its size; presumably consumed by
    # the DiracConv forward pass — verify against the group/block helpers.
    for k, v in list(flat_params.items()):
        if k.find('.conv') > -1:
            flat_stats[size2name(v.size())] = cast(dirac(v.data.clone()))

    return f, flat_params, flat_stats