def __init__(self, input_channels, n_filters, filter_size, dropout, bias=True,
             dilation=1, stride=(1, 1), bn_momentum=0.1, ini='random'):
    """
    It builds a block with: Batch Norm, Dropout, ReLU, Convolution.

    Input:
    - input_channels: int. Number of input feature maps.
    - n_filters: int. Number of output feature maps.
    - filter_size: int. Convolution filter size.
    - dropout: float. Percentage of dropout.
    - bias: bool. Bias in convolution.
    - dilation: int. Dilation rate for dilated convolution. If 1,
      traditional convolution is used.
    - stride: int or tuple. Stride used in the convolution.
    - bn_momentum: float. Batch-norm momentum.
    - ini: string. Initialization for the dilated convolution weights.
      It can be 'random' or 'identity'.
    """
    super(BnReluConv, self).__init__()
    self.bn = nn.BatchNorm2d(input_channels, eps=0.001, momentum=bn_momentum)
    # NOTE: self.drop only exists when dropout > 0; forward code is expected
    # to check self.dropout before using it.
    if dropout > 0:
        self.drop = nn.Dropout(dropout)
    if dilation == 1:
        # "Same" padding for odd filter sizes.
        self.conv = nn.Conv2d(input_channels, n_filters,
                              kernel_size=filter_size,
                              padding=(filter_size - 1) // 2,
                              bias=bias, stride=stride)
    else:
        # Dilated convolution used in the transformation blocks between
        # ResNets. Padding keeps the spatial size for the effective
        # (dilated) kernel extent.
        self.conv = nn.Conv2d(input_channels, n_filters,
                              kernel_size=filter_size, dilation=dilation,
                              padding=((filter_size + (filter_size - 1) *
                                        (dilation - 1)) - 1) // 2,
                              bias=bias)
    # Initialize modules. Single pass replacing the two identical loops of
    # the original: only dilated convs honor ini == 'identity'.
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            if dilation != 1 and ini == 'identity':
                dirac(m.weight)
            else:
                kaiming_uniform(m.weight)
            # BUGFIX: the original called m.bias.data.zero_() even when
            # bias=False, where m.bias is None and this raised.
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
    self.dropout = dropout
def init_params(self, out_channels):
    """Create per-channel mixing parameters and the fixed Dirac buffer."""
    # Learnable per-output-channel scales: alpha starts at 1, beta at 0.1.
    self.alpha = nn.Parameter(torch.ones(out_channels))
    self.beta = nn.Parameter(torch.full((out_channels,), 0.1))
    # Non-learnable identity (Dirac) kernel with the same layout as the
    # convolution weight; registered as a buffer so it moves with the module.
    self.register_buffer('delta', dirac(self.weight.data.clone()))
    assert self.delta.size() == self.weight.size()
    # View shape used to broadcast the per-channel params over the weight:
    # (-1, 1, 1, ...) with one trailing 1 per remaining weight dimension.
    self.v = (-1,) + (1,) * (self.weight.dim() - 1)
def dirac_delta(ni, no, k):
    """Return a Dirac (identity) kernel of shape (no, ni, k, k).

    Builds the identity kernel on the smaller of the two channel counts
    and tiles it along the channel axes; assumes the larger count is a
    multiple of the smaller one.
    """
    base = min(ni, no)
    kernel = dirac(torch.Tensor(base, base, k, k))
    # Tile output channels when no > ni, input channels when ni > no;
    # spatial dimensions are never repeated.
    return kernel.repeat(max(no // ni, 1), max(ni // no, 1), 1, 1)
def dirac_delta(ni, no, k):
    """Return a Dirac (identity) kernel for an arbitrary spatial size tuple k.

    Generalization of the square-kernel variant: k is a tuple of spatial
    dimensions. The identity kernel is built on min(ni, no) channels and
    tiled along the channel axes only.
    """
    base = min(ni, no)
    kernel = dirac(torch.Tensor(*((base, base) + k)))
    # Channel-axis tiling factors, followed by a 1 for every spatial dim.
    tiling = (max(no // ni, 1), max(ni // no, 1)) + (1,) * len(k)
    return kernel.repeat(*tiling)
def define_diracnet(depth, width, dataset):
    """Build a functional DiracNet: returns (forward_fn, flat_params, flat_stats).

    - depth: int. Network depth (CIFAR: 6n+4 style; ImageNet: 18 or 34).
    - width: multiplier applied to the base channel widths.
    - dataset: 'CIFAR10', 'CIFAR100' or 'ImageNet'.

    Raises ValueError for any other dataset string.
    NOTE(review): relies on module-level helpers (conv_params, cast, bnparams,
    bnstats, batch_norm, group, linear_params, flatten_params, flatten_stats,
    size2name, F, dirac) defined elsewhere in this project.
    """
    def gen_group_params(ni, no, count):
        # One param dict per block; only block0 changes the channel count.
        return {
            'block%d' % i: {
                'conv': conv_params(ni if i == 0 else no, no, k=3, gain=1),
                'alpha': cast(torch.Tensor([1])),
                'beta': cast(torch.Tensor([0.1])),
                'bn': bnparams(no)
            } for i in range(count)
        }

    def gen_group_stats(no, count):
        # Running batch-norm statistics, one entry per block.
        return {'block%d' % i: {'bn': bnstats(no)} for i in range(count)}

    if dataset.startswith('CIFAR'):
        # depth = 6n + 4; each group holds n*2 blocks.
        n = (depth - 4) // 6
        widths = torch.Tensor([16, 32, 64]).mul(width).int()

        def f(inputs, params, stats, mode):
            # Stem conv + BN/ReLU, then three groups with 2x max-pool
            # between them, global 8x8 avg-pool and a linear classifier.
            o = F.conv2d(inputs, params['conv'], padding=1)
            o = F.relu(batch_norm(o, params, stats, 'bn', mode))
            o = group(o, params, stats, 'group0', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group1', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group2', mode, n * 2)
            o = F.avg_pool2d(F.relu(o), 8)
            o = F.linear(o.view(o.size(0), -1),
                         params['fc.weight'], params['fc.bias'])
            return o

        params = {
            'conv': cast(kaiming_normal(torch.Tensor(widths[0], 3, 3, 3))),
            'bn': bnparams(widths[0]),
            'group0': gen_group_params(widths[0], widths[0], n * 2),
            'group1': gen_group_params(widths[0], widths[1], n * 2),
            'group2': gen_group_params(widths[1], widths[2], n * 2),
            # CIFAR10 -> 10 classes, otherwise (CIFAR100) -> 100.
            'fc': linear_params(widths[2], 10 if dataset == 'CIFAR10' else 100),
        }
        stats = {
            'group%d' % i: gen_group_stats(no, n * 2)
            for i, no in enumerate(widths)
        }
        stats['bn'] = bnstats(widths[0])
    elif dataset == 'ImageNet':
        # blocks-per-group for the two supported depths (ResNet-18/34 style).
        definitions = {18: [2, 2, 2, 2], 34: [3, 4, 6, 3]}
        widths = torch.Tensor([64, 128, 256, 512]).mul(width).int()
        blocks = definitions[depth]

        def f(inputs, params, stats, mode):
            # 7x7/2 stem + 3x3/2 max-pool, four groups separated by
            # 2x max-pool, global avg-pool and a linear classifier.
            o = F.conv2d(inputs, params['conv'], padding=3, stride=2)
            o = batch_norm(o, params, stats, 'bn', mode)
            o = F.max_pool2d(o, 3, 2, 1)
            o = group(o, params, stats, 'group0', mode, blocks[0] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group1', mode, blocks[1] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group2', mode, blocks[2] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, stats, 'group3', mode, blocks[3] * 2)
            o = F.avg_pool2d(F.relu(o), o.size(-1))
            o = F.linear(o.view(o.size(0), -1),
                         params['fc.weight'], params['fc.bias'])
            return o

        params = {
            'conv': cast(kaiming_normal(torch.Tensor(widths[0], 3, 7, 7))),
            'group0': gen_group_params(widths[0], widths[0], 2 * blocks[0]),
            'group1': gen_group_params(widths[0], widths[1], 2 * blocks[1]),
            'group2': gen_group_params(widths[1], widths[2], 2 * blocks[2]),
            'group3': gen_group_params(widths[2], widths[3], 2 * blocks[3]),
            'bn': bnparams(widths[0]),
            'fc': linear_params(widths[-1], 1000),
        }
        stats = {
            'group%d' % i: gen_group_stats(no, 2 * b)
            for i, (no, b) in enumerate(zip(widths, blocks))
        }
        stats['bn'] = bnstats(widths[0])
    else:
        raise ValueError('dataset not understood')

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)
    # Register a fixed Dirac (identity) buffer per unique conv weight size,
    # keyed by a name derived from the size.
    for k, v in list(flat_params.items()):
        if k.find('.conv') > -1:
            flat_stats[size2name(v.size())] = cast(dirac(v.data.clone()))
    return f, flat_params, flat_stats