Example 1
def define_student(depth, width):
    definitions = {18: [2, 2, 2, 2],
                   34: [3, 4, 6, 3]}  # ResNet-34 uses 3-4-6-3 basic blocks
    assert depth in definitions
    widths = [int(w * width) for w in (64, 128, 256, 512)]
    blocks = definitions[depth]

    def gen_block_params(ni, no):
        return {'conv0': utils.conv_params(ni, no, 3),
                'conv1': utils.conv_params(no, no, 3),
                'bn0': utils.bnparams(no),
                'bn1': utils.bnparams(no),
                'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
                }

    def gen_group_params(ni, no, count):
        return {'block%d'%i: gen_block_params(ni if i==0 else no, no)
                for i in range(count)}

    flat_params = OrderedDict(utils.flatten({
        'conv0': utils.conv_params(3, 64, 7),
        'bn0': utils.bnparams(64),
        'group0': gen_group_params(64, widths[0], blocks[0]),
        'group1': gen_group_params(widths[0], widths[1], blocks[1]),
        'group2': gen_group_params(widths[1], widths[2], blocks[2]),
        'group3': gen_group_params(widths[2], widths[3], blocks[3]),
        'fc': utils.linear_params(widths[3], 1000),
    }))

    utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        y = F.conv2d(x, params[base+'.conv0'], stride=stride, padding=1)
        o1 = F.relu(utils.batch_norm(y, params, base+'.bn0', mode), inplace=True)
        z = F.conv2d(o1, params[base+'.conv1'], stride=1, padding=1)
        o2 = utils.batch_norm(z, params, base+'.bn1', mode)
        if base + '.convdim' in params:
            return F.relu(o2 + F.conv2d(x, params[base+'.convdim'], stride=stride), inplace=True)
        else:
            return F.relu(o2 + x, inplace=True)

    def group(o, params, base, mode, stride, n):
        for i in range(n):
            o = block(o, params, '%s.block%d' % (base, i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode, pr=''):
        o = F.conv2d(input, params[pr+'conv0'], stride=2, padding=3)
        o = F.relu(utils.batch_norm(o, params, pr+'bn0', mode), inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        g0 = group(o, params, pr+'group0', mode, 1, blocks[0])
        g1 = group(g0, params, pr+'group1', mode, 2, blocks[1])
        g2 = group(g1, params, pr+'group2', mode, 2, blocks[2])
        g3 = group(g2, params, pr+'group3', mode, 2, blocks[3])
        o = F.avg_pool2d(g3, 7)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr+'fc.weight'], params[pr+'fc.bias'])
        return o, [g0, g1, g2, g3]

    return f, flat_params
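
A minimal usage sketch for this functional interface (illustrative only: it assumes torch and the repo's `utils` helpers are importable; the 7x7 stride-2 stem and 7x7 average pool imply 224x224 inputs):

import torch

f, flat_params = define_student(depth=34, width=1.0)
x = torch.randn(2, 3, 224, 224)                 # dummy ImageNet-sized batch
logits, (g0, g1, g2, g3) = f(x, flat_params, mode=True)
print(logits.shape)                             # torch.Size([2, 1000])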
Example 2
def resnet(depth, width, num_classes, dropout):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(v * width) for v in (16, 32, 64)]

    def gen_block_params(ni, no):
        return {
            'conv0': utils.conv_params(ni, no, 3),
            'conv1': utils.conv_params(no, no, 3),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    flat_params = utils.cast(utils.flatten({
        'conv0': utils.conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': utils.bnparams(widths[2]),
        'fc': utils.linear_params(widths[2], num_classes),
    }))

    utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode), inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode), inplace=True)
        if dropout > 0:
            o2 = F.dropout(o2, p=dropout, training=mode, inplace=False)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, base, mode, stride):
        for i in range(n):
            o = block(o, params, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        x = F.conv2d(input, params['conv0'], stride=2, padding=1)
        g0 = group(x, params, 'group0', mode, 1)
        g1 = group(g0, params, 'group1', mode, 2)
        g2 = group(g1, params, 'group2', mode, 2)
        o = F.relu(utils.batch_norm(g2, params, 'bn', mode))
        o = F.avg_pool2d(o, 12, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
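
Here the stride-2 stem combined with the 12x12 average pool implies 96x96 inputs (STL-10-sized); a hypothetical call, again assuming torch and `utils` are importable:

f, flat_params = resnet(depth=16, width=2, num_classes=10, dropout=0.3)
logits = f(torch.randn(4, 3, 96, 96), flat_params, mode=True)   # shape [4, 10]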
Example 3
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(v * width) for v in (16, 32, 64)]

    def gen_block_params(ni, no):
        return {
            'conv0': utils.conv_params(ni, no, 3),
            'conv1': utils.conv_params(no, no, 3),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    flat_params = utils.cast(utils.flatten({
        'conv0': utils.conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': utils.bnparams(widths[2]),
        'fc': utils.linear_params(widths[2], num_classes),
    }))

    utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode), inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode), inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, base, mode, stride):
        for i in range(n):
            o = block(o, params, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, 'group0', mode, 1)
        g1 = group(g0, params, 'group1', mode, 2)
        g2 = group(g1, params, 'group2', mode, 2)
        o = F.relu(utils.batch_norm(g2, params, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
Example 4
def define_student(depth, width):
    definitions = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],  # standard ResNet-34 block counts
    }
    assert depth in definitions
    widths = np.floor(np.asarray([64, 128, 256, 512]) * width).astype(int)
    blocks = definitions[depth]

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(no),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 64, 7),
        'bn0': bnparams(64),
        'group0': gen_group_params(64, widths[0], blocks[0]),
        'group1': gen_group_params(widths[0], widths[1], blocks[1]),
        'group2': gen_group_params(widths[1], widths[2], blocks[2]),
        'group3': gen_group_params(widths[2], widths[3], blocks[3]),
        'fc': linear_params(widths[3], 1000),
    }

    stats = {
        'bn0': bnstats(64),
        'group0': gen_group_stats(widths[0], blocks[0]),
        'group1': gen_group_stats(widths[1], blocks[1]),
        'group2': gen_group_stats(widths[2], blocks[2]),
        'group3': gen_group_stats(widths[3], blocks[3]),
    }

    # flatten parameters and additional buffers
    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def block(x, params, stats, base, mode, stride):
        y = F.conv2d(x, params[base + '.conv0'], stride=stride, padding=1)
        o1 = F.relu(batch_norm(y, params, stats, base + '.bn0', mode),
                    inplace=True)
        z = F.conv2d(o1, params[base + '.conv1'], stride=1, padding=1)
        o2 = batch_norm(z, params, stats, base + '.bn1', mode)
        if base + '.convdim' in params:
            return F.relu(
                o2 + F.conv2d(x, params[base + '.convdim'], stride=stride),
                inplace=True)
        else:
            return F.relu(o2 + x, inplace=True)

    def group(o, params, stats, base, mode, stride, n):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode, pr=''):
        o = F.conv2d(input, params[pr + 'conv0'], stride=2, padding=3)
        o = F.relu(batch_norm(o, params, stats, pr + 'bn0', mode),
                   inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        g0 = group(o, params, stats, pr + 'group0', mode, 1, blocks[0])
        g1 = group(g0, params, stats, pr + 'group1', mode, 2, blocks[1])
        g2 = group(g1, params, stats, pr + 'group2', mode, 2, blocks[2])
        g3 = group(g2, params, stats, pr + 'group3', mode, 2, blocks[3])
        o = F.avg_pool2d(g3, 7)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        return o, [g0, g1, g2, g3]

    return f, flat_params, flat_stats
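
This variant threads the running batch-norm statistics through explicitly; a sketch of a call, assuming the helpers (`conv_params`, `bnparams`, `bnstats`, `flatten_params`, `flatten_stats`, `batch_norm`) come from the source repo:

f, flat_params, flat_stats = define_student(depth=18, width=1.0)
x = torch.randn(2, 3, 224, 224)
logits, feats = f(x, flat_params, flat_stats, mode=True)   # feats = [g0, g1, g2, g3]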
Example 5
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = np.floor(np.asarray([16., 32., 64.]) * width).astype(int)

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = nested_dict({
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    stats = nested_dict({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    flat_params = OrderedDict()
    flat_stats = OrderedDict()
    for keys, v in params.iteritems_flat():
        if v is not None:
            flat_params['.'.join(keys)] = Variable(v, requires_grad=True)
    for keys, v in stats.iteritems_flat():
        flat_stats['.'.join(keys)] = v

    def activation(x, params, stats, base, mode):
        return F.relu(
            F.batch_norm(x,
                         weight=params[base + '.weight'],
                         bias=params[base + '.bias'],
                         running_mean=stats[base + '.running_mean'],
                         running_var=stats[base + '.running_var'],
                         training=mode,
                         momentum=0.1,
                         eps=1e-5))

    def block(x, params, stats, base, mode, stride):
        o1 = activation(x, params, stats, base + '.bn0', mode)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = activation(y, params, stats, base + '.bn1', mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode, prefix=''):
        x = F.conv2d(input, params[prefix + 'conv0'], padding=1)
        g0 = group(x, params, stats, prefix + 'group0', mode, 1)
        g1 = group(g0, params, stats, prefix + 'group1', mode, 2)
        g2 = group(g1, params, stats, prefix + 'group2', mode, 2)
        o = activation(g2, params, stats, prefix + 'bn', mode)
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[prefix + 'fc.weight'],
                     params[prefix + 'fc.bias'])
        return o, [g0, g1, g2]

    return f, flat_params, flat_stats
Example 6
def resnet(depth, width, num_classes, stu_depth=0):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    if stu_depth != 0:
        assert (stu_depth - 4) % 6 == 0, 'student depth should be 6n+4'
        n_s = (stu_depth - 4) // 6
    else:
        n_s = 0

    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'bns0': bnparams(ni),
            'bns1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no),
                'bns0': bnstats(ni if i == 0 else no),
                'bns1': bnstats(no)
            }
            for i in range(count)
        }

    if stu_depth != 0 and not opt.param_share:
        params = {
            'conv0': conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'groups0': gen_group_params(16, widths[0], n_s),
            'groups1': gen_group_params(widths[0], widths[1], n_s),
            'groups2': gen_group_params(widths[1], widths[2], n_s),
            'bn': bnparams(widths[2]),
            'bns': bnparams(widths[2]),
            'fc': linear_params(widths[2], num_classes),
            'fcs': linear_params(widths[2], num_classes),
        }

        stats = {
            'group0': gen_group_stats(16, widths[0], n),
            'group1': gen_group_stats(widths[0], widths[1], n),
            'group2': gen_group_stats(widths[1], widths[2], n),
            'groups0': gen_group_stats(16, widths[0], n_s),
            'groups1': gen_group_stats(widths[0], widths[1], n_s),
            'groups2': gen_group_stats(widths[1], widths[2], n_s),
            'bn': bnstats(widths[2]),
            'bns': bnstats(widths[2]),
        }
    else:
        params = {
            'conv0': conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': bnparams(widths[2]),
            'bns': bnparams(widths[2]),
            'fc': linear_params(widths[2], num_classes),
            'fcs': linear_params(widths[2], num_classes),
        }

        stats = {
            'group0': gen_group_stats(16, widths[0], n),
            'group1': gen_group_stats(widths[0], widths[1], n),
            'group2': gen_group_stats(widths[1], widths[2], n),
            'bn': bnstats(widths[2]),
            'bns': bnstats(widths[2]),
        }

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def block(x, params, stats, base, mode, stride, flag, drop_switch=True):
        if flag == 's':
            o1 = F.relu(batch_norm(x, params, stats, base + '.bns0', mode))
            y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
            o2 = F.relu(batch_norm(y, params, stats, base + '.bns1', mode))
            z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
            if base + '.convdim' in params:
                return z + F.conv2d(
                    o1, params[base + '.convdim'], stride=stride)
            else:
                return z + x
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode))
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode))
        if opt.dropout > 0 and drop_switch:
            o2 = F.dropout(o2, p=opt.dropout, training=mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1, 't', False)
        return o

    def group_student(o, params, stats, base, mode, stride, n_layer):
        for i in range(n_layer):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1, 's', False)
        return o

    def f(input, params, stats, mode, prefix=''):
        x = F.conv2d(input, params[prefix + 'conv0'], padding=1)
        g0 = group(x, params, stats, prefix + 'group0', mode, 1)
        g1 = group(g0, params, stats, prefix + 'group1', mode, 2)
        g2 = group(g1, params, stats, prefix + 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, prefix + 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[prefix + 'fc.weight'],
                     params[prefix + 'fc.bias'])
        #x_s = F.conv2d(input, params[prefix+'conv0_s'], padding=1)
        if stu_depth != 0:
            if opt.param_share:
                gs0 = group_student(x, params, stats, prefix + 'group0', mode,
                                    1, n_s)
                gs1 = group_student(gs0, params, stats, prefix + 'group1',
                                    mode, 2, n_s)
                gs2 = group_student(gs1, params, stats, prefix + 'group2',
                                    mode, 2, n_s)
            else:
                gs0 = group_student(x, params, stats, prefix + 'groups0', mode,
                                    1, n_s)
                gs1 = group_student(gs0, params, stats, prefix + 'groups1',
                                    mode, 2, n_s)
                gs2 = group_student(gs1, params, stats, prefix + 'groups2',
                                    mode, 2, n_s)

            os = F.relu(batch_norm(gs2, params, stats, prefix + 'bns', mode))
            os = F.avg_pool2d(os, 8, 1, 0)
            os = os.view(os.size(0), -1)
            os = F.linear(os, params[prefix + 'fcs.weight'],
                          params[prefix + 'fcs.bias'])
            return os, o, [g0, g1, g2, gs0, gs1, gs2]
        else:
            return o, [g0, g1, g2]

    return f, flat_params, flat_stats
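
An illustrative call for this teacher/student variant; note that it also reads a module-level `opt` namespace (`opt.param_share`, `opt.dropout`), which must exist before `resnet` is called:

f, flat_params, flat_stats = resnet(depth=40, width=2, num_classes=10, stu_depth=16)
x = torch.randn(4, 3, 32, 32)
out_student, out_teacher, feats = f(x, flat_params, flat_stats, mode=True)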
Example 7
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(x * width) for x in [16, 32, 64]]

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    def gen_group_stats(ni, no, count):
        return {'block%d' % i: {'bn0': bnstats(ni if i == 0 else no), 'bn1': bnstats(no)}
                for i in range(count)}

    flat_params = flatten_params({
        'conv0': conv_params(3,16,3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode), inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode), inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
Example 8
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    }

    stats = {
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    }

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def activation(x, params, stats, base, mode):
        return F.relu(F.batch_norm(x,
                                   weight=params[base + '.weight'],
                                   bias=params[base + '.bias'],
                                   running_mean=stats[base + '.running_mean'],
                                   running_var=stats[base + '.running_var'],
                                   training=mode,
                                   momentum=0.1,
                                   eps=1e-5),
                      inplace=True)

    def block(x, params, stats, base, mode, stride):
        o1 = activation(x, params, stats, base + '.bn0', mode)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = activation(y, params, stats, base + '.bn1', mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
Example 9
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(v * width) for v in (16, 32, 64)]

    def gen_block_params(ni, no):
        return {
            'conv0': utils.conv_params(ni, no, 3),
            'conv1': utils.conv_params(no, no, 3),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    flat_params = utils.cast(
        utils.flatten({
            'conv0': utils.conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': utils.bnparams(widths[2]),
            'fc': utils.linear_params(widths[2], num_classes),
        }))

    utils.set_requires_grad_except_bn_(flat_params)

    def block(x, params, base, mode, stride, out_dict):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode),
                    inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode),
                    inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            o = z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            o = z + x
        out_dict[base + '.relu0'] = o1
        out_dict[base + '.relu1'] = o2
        return o, out_dict

    def group(o, params, base, mode, stride, out_dict):
        for i in range(n):
            o, out_dict = block(o, params, '{}.block{}'.format(base, i), mode,
                                stride if i == 0 else 1, out_dict)
        return o, out_dict

    def f(input, params, mode, base=''):
        out_dict = {}
        x = F.conv2d(input, params['{}conv0'.format(base)], padding=1)
        g0, out_dict = group(x, params, '{}group0'.format(base), mode, 1,
                             out_dict)
        g1, out_dict = group(g0, params, '{}group1'.format(base), mode, 2,
                             out_dict)
        g2, out_dict = group(g1, params, '{}group2'.format(base), mode, 2,
                             out_dict)
        o = F.relu(utils.batch_norm(g2, params, '{}bn'.format(base), mode))
        out_dict[base + "last_relu"] = o
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['{}fc.weight'.format(base)],
                     params['{}fc.bias'.format(base)])
        return o, (g0, g1, g2), out_dict

    return f, flat_params
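
Besides the logits and group outputs, this version collects every intermediate ReLU in a dict keyed by parameter prefix, which is convenient for attention-transfer-style losses; a sketch, assuming CIFAR-sized inputs:

f, flat_params = resnet(depth=16, width=1, num_classes=10)
logits, (g0, g1, g2), out_dict = f(torch.randn(2, 3, 32, 32), flat_params, mode=True)
print(sorted(out_dict))   # e.g. ['group0.block0.relu0', 'group0.block0.relu1', ..., 'last_relu']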
Example 10
def resnet(depth, width, num_classes, dropout_prob, activation_dropout):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(v * width) for v in (16, 32, 64)]

    def gen_block_params(ni, no):
        return {
            'conv0': utils.conv_params(ni, no, 3),
            'conv1': utils.conv_params(no, no, 3),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    flat_params = utils.cast(
        utils.flatten({
            'conv0': utils.conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': utils.bnparams(widths[2]),
            'fc': utils.linear_params(widths[2], num_classes),
        }))

    utils.set_requires_grad_except_bn_(flat_params)

    # named `apply_activation_dropout` so it does not shadow the boolean
    # `activation_dropout` argument of `resnet` above
    def apply_activation_dropout(x, p_drop, training):
        if training:
            # p_drop is the base dropout probability, so P is the base retain probability
            P = 1. - p_drop
            # x.size() = [bs, f, h, w]; sum over h and w to get each filter's
            # total activation across space -> [bs, f], then L1-normalize so
            # the activations sum to 1 for each example in the batch
            bs, N, w, h = x.size()
            p_act = F.normalize(x.sum(-1).sum(-1), p=1, dim=-1)
            p_retain = 1. - ((1. - P) * (N - 1.) * p_act) / ((
                (1. - P) * N - 1.) * p_act + P)
            mask = torch.bernoulli(p_retain)
            scale = mask.mean(-1)
            mask = mask / torch.stack([scale for i in range(N)], -1)
            mask = torch.stack([mask for i in range(w)], -1)
            mask = torch.stack([mask for i in range(h)], -1)

            return mask * x
        else:
            return x

    def block(x, params, base, mode, stride):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode),
                    inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode),
                    inplace=True)
        if activation_dropout:
            o2 = apply_activation_dropout(o2, p_drop=dropout_prob, training=mode)
        elif dropout_prob:
            o2 = F.dropout2d(o2, p=dropout_prob, training=mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, base, mode, stride):
        for i in range(n):
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, 'group0', mode, 1)
        g1 = group(g0, params, 'group1', mode, 2)
        g2 = group(g1, params, 'group2', mode, 2)
        o = F.relu(utils.batch_norm(g2, params, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
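
The boolean `activation_dropout` flag selects between the activation-weighted dropout and plain F.dropout2d; an illustrative call:

f, flat_params = resnet(depth=16, width=2, num_classes=10,
                        dropout_prob=0.3, activation_dropout=True)
logits = f(torch.randn(4, 3, 32, 32), flat_params, mode=True)   # shape [4, 10]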
Example 11
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    # n is the number of residual blocks in each group
    n = (depth - 4) // 6
    # widths is the number of filters at each stage of the network
    widths = [int(v * width) for v in (16, 32, 64)]

    # initialization: generate the parameters of one residual block
    def gen_block_params(ni, no):
        return {
            'conv0': utils.conv_params(ni, no, 3),
            'conv1': utils.conv_params(no, no, 3),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    # initialization: generate the parameters of a residual group containing `count` blocks
    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }
    # initialization: parameters of the whole network (conv - group0 - group1 - group2 - ...)
    # flatten: lay all of the params out into one flat dict

    flat_params = utils.cast(
        utils.flatten({
            'conv0': utils.conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            # each group uses a different width
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': utils.bnparams(widths[2]),
            'fc': utils.linear_params(widths[2], num_classes),
        }))

    utils.set_requires_grad_except_bn_(flat_params)

    # build one residual block:
    # relu - conv - relu - conv (each preceded by batch norm)
    def block(x, params, base, mode, stride):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode),
                    inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode),
                    inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        # projection shortcut: a 1x1 conv when the input/output widths differ
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    # build one residual group
    def group(o, params, base, mode, stride):
        for i in range(n):
            #o = block(o, params, f'{base}.block{i}',mode,stride if i == 0 else 1)
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o


    # build the whole network:
    # conv - group0 - group1 - group2 - relu - pooling - flatten (fully connected)

    def f(input, params, mode, base=''):
        x = F.conv2d(input, params[base + 'conv0'], padding=1)
        g0 = group(x, params, base + 'group0', mode, 1)
        g1 = group(g0, params, base + 'group1', mode, 2)
        g2 = group(g1, params, base + 'group2', mode, 2)
        o = F.relu(utils.batch_norm(g2, params, base + 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        # reshape the multi-dimensional tensor into one row per example
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[base + 'fc.weight'], params[base + 'fc.bias'])
        return o, (g0, g1, g2)
        # returns the final output together with a tuple of each group's output

    return f, flat_params
Example 12
    def __init__(self,
                 depth,
                 width,
                 ninputs=3,
                 num_groups=3,
                 num_classes=None,
                 dropout=0.):

        super(WideResNet, self).__init__()
        self.depth = depth
        self.width = width
        self.num_groups = num_groups
        self.num_classes = num_classes
        self.dropout = dropout
        self.mode = True  # Training

        #widths = torch.Tensor([16, 32, 64]).mul(width).int()
        widths = (np.array([16, 32, 64]) * width).astype(int)

        def gen_block_params(ni, no):
            return {
                'conv0': conv_params(ni, no, 3),
                'conv1': conv_params(no, no, 3),
                'bn0': bnparams(ni),
                'bn1': bnparams(no),
                'convdim': conv_params(ni, no, 1) if ni != no else None,
            }

        def gen_group_params(ni, no, count):
            return {
                'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)
            }

        def gen_group_stats(ni, no, count):
            return {
                'block%d' % i: {
                    'bn0': bnstats(ni if i == 0 else no),
                    'bn1': bnstats(no)
                }
                for i in range(count)
            }

        params = {'conv0': conv_params(ni=ninputs, no=widths[0], k=3)}
        stats = {}

        for i in range(num_groups + 1):
            if i == 0:
                params.update({
                    'group' + str(i):
                    gen_group_params(widths[i], widths[i], depth)
                })
                stats.update({
                    'group' + str(i):
                    gen_group_stats(widths[i], widths[i], depth)
                })
            else:
                params.update({
                    'group' + str(i):
                    gen_group_params(widths[i - 1], widths[i], depth)
                })
                stats.update({
                    'group' + str(i):
                    gen_group_stats(widths[i - 1], widths[i], depth)
                })

        if num_classes is not None:
            params.update({'fc': linear_params(widths[i], num_classes)})
        params.update({'bn': bnparams(widths[i])})
        stats.update({'bn': bnstats(widths[i])})

        params = flatten_params(params)
        stats = flatten_stats(stats)

        self.params = nn.ParameterDict({})
        self.stats = nn.ParameterDict({})
        # nn.ParameterDict forbids '.' in key names, so the flattened
        # 'group0.block0.conv0'-style keys are stored with '_' separators
        for key in params.keys():
            self.params.update(
                {key.replace('.', '_'): nn.Parameter(params[key],
                                                     requires_grad=True)})
        for key in stats.keys():
            self.stats.update(
                {key.replace('.', '_'): nn.Parameter(stats[key],
                                                     requires_grad=False)})
Example 13
def resnet(depth,
           width,
           num_classes,
           is_full_wrn=True,
           is_fully_convolutional=False):
    #assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    #n = (depth - 4) // 6
    #wrn = WideResNet(depth, width, ninputs=3,useCuda=True, num_groups=3, num_classes=num_classes)
    n = depth
    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    }

    stats = {
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    }
    if not is_full_wrn:
        ''' omniglot '''
        params['bn'] = bnparams(widths[1])
        #params['fc'] = linear_params(widths[1]*16*16, num_classes)
        params['fc'] = linear_params(widths[1], num_classes)
        stats['bn'] = bnstats(widths[1])
        '''
        # banknote
        params['bn'] = bnparams(widths[2])
        #params['fc'] = linear_params(widths[2]*16*16, num_classes)
        params['fc'] = linear_params(widths[2], num_classes)
        stats['bn'] = bnstats(widths[2])
        '''

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def activation(x, params, stats, base, mode):
        return F.relu(F.batch_norm(x,
                                   weight=params[base + '.weight'],
                                   bias=params[base + '.bias'],
                                   running_mean=stats[base + '.running_mean'],
                                   running_var=stats[base + '.running_var'],
                                   training=mode,
                                   momentum=0.1,
                                   eps=1e-5),
                      inplace=True)

    def block(x, params, stats, base, mode, stride):
        o1 = activation(x, params, stats, base + '.bn0', mode)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = activation(y, params, stats, base + '.bn1', mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, o.shape[2], 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    def not_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        # omniglot
        o = activation(g1, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, o.shape[2], 1, 0)
        # banknote
        '''
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, 16, 1, 0)
        '''
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    def fcn_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        return o

    def fcn_not_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        o = activation(g1, params, stats, 'bn', mode)
        return o

    if is_fully_convolutional:
        if is_full_wrn:
            return fcn_full_wrn, flat_params, flat_stats
        else:
            return fcn_not_full_wrn, flat_params, flat_stats
    else:
        if is_full_wrn:
            return full_wrn, flat_params, flat_stats
        else:
            return not_full_wrn, flat_params, flat_stats
Example 14
def define_student(depth, width):
    definitions = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],  # standard ResNet-34 block counts
    }
    assert depth in definitions.keys()
    widths = np.floor(np.asarray([64, 128, 256, 512]) * width).astype(int)
    blocks = definitions[depth]

    def batch_norm(x, params, stats, base, mode):
        return F.batch_norm(x,
                            weight=params[base + '.weight'],
                            bias=params[base + '.bias'],
                            running_mean=stats[base + '.running_mean'],
                            running_var=stats[base + '.running_var'],
                            training=mode,
                            momentum=0.1,
                            eps=1e-5)

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(no),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = nested_dict({
        'conv0': conv_params(3, 64, 7),
        'bn0': bnparams(64),
        'group0': gen_group_params(64, widths[0], blocks[0]),
        'group1': gen_group_params(widths[0], widths[1], blocks[1]),
        'group2': gen_group_params(widths[1], widths[2], blocks[2]),
        'group3': gen_group_params(widths[2], widths[3], blocks[3]),
        'fc': linear_params(widths[3], 1000),
    })

    stats = nested_dict({
        'bn0': bnstats(64),
        'group0': gen_group_stats(widths[0], blocks[0]),
        'group1': gen_group_stats(widths[1], blocks[1]),
        'group2': gen_group_stats(widths[2], blocks[2]),
        'group3': gen_group_stats(widths[3], blocks[3]),
    })

    # flatten parameters and additional buffers
    flat_params = OrderedDict()
    flat_stats = OrderedDict()
    for keys, v in params.iteritems_flat():
        if v is not None:
            flat_params['.'.join(keys)] = Variable(v, requires_grad=True)
    for keys, v in stats.iteritems_flat():
        flat_stats['.'.join(keys)] = v

    def block(x, params, stats, base, mode, stride):
        y = F.conv2d(x, params[base + '.conv0'], stride=stride, padding=1)
        o1 = F.relu(batch_norm(y, params, stats, base + '.bn0', mode),
                    inplace=True)
        z = F.conv2d(o1, params[base + '.conv1'], stride=1, padding=1)
        o2 = batch_norm(z, params, stats, base + '.bn1', mode)
        if base + '.convdim' in params:
            return F.relu(
                o2 + F.conv2d(x, params[base + '.convdim'], stride=stride),
                inplace=True)
        else:
            return F.relu(o2 + x, inplace=True)

    def group(o, params, stats, base, mode, stride, n):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode, pr=''):
        o = F.conv2d(input, params[pr + 'conv0'], stride=2, padding=3)
        o = F.relu(batch_norm(o, params, stats, pr + 'bn0', mode),
                   inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        g0 = group(o, params, stats, pr + 'group0', mode, 1, blocks[0])
        g1 = group(g0, params, stats, pr + 'group1', mode, 2, blocks[1])
        g2 = group(g1, params, stats, pr + 'group2', mode, 2, blocks[2])
        g3 = group(g2, params, stats, pr + 'group3', mode, 2, blocks[3])
        o = F.avg_pool2d(g3, 7)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        return o, [g0, g1, g2, g3]

    return f, flat_params, flat_stats
Example 15
def gen_classifier_params():
    return {
        'fc1': linear_params(512, 4096),
        'fc2': linear_params(4096, 4096),
        'fc3': linear_params(4096, num_classes),
    }
Example 16
def resnet(depth, width, num_classes, dropout, level=None):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    assert level is None or level in [2, 3], 'level should be 2, 3 or None'
    n = (depth - 4) // 6
    widths = [int(v * width) for v in (16, 32, 64)]

    def gen_harmonic_params(ni,
                            no,
                            k,
                            normalize=False,
                            level=None,
                            linear=False):
        nf = k**2 if level is None else level * (level + 1) // 2
        paramdict = {
            'conv':
            utils.dct_params(ni, no, nf) if linear else utils.conv_params(
                ni * nf, no, 1)
        }
        if normalize and not linear:
            paramdict.update({'bn': utils.bnparams(ni * nf, affine=False)})
        return paramdict

    def gen_block_params(ni, no):
        return {
            'harmonic0': gen_harmonic_params(ni, no, k=3, normalize=False,
                                             level=level, linear=True),
            'harmonic1': gen_harmonic_params(no, no, k=3, normalize=False,
                                             level=level, linear=True),
            'bn0': utils.bnparams(ni),
            'bn1': utils.bnparams(no),
            'convdim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    flat_params = utils.cast(
        utils.flatten({
            'dct0': utils.dct_filters(n=3, groups=3),
            'dct': utils.dct_filters(n=3, groups=int(width) * 64,
                                     expand_dim=0, level=level),
            'harmonic0': gen_harmonic_params(3, 16, k=3, normalize=True, level=None),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': utils.bnparams(widths[2]),
            'fc': utils.linear_params(widths[2], num_classes),
        }))

    utils.set_requires_grad_except_bn_(flat_params)

    def harmonic_block(x, params, base, mode, stride=1, padding=1):
        y = F.conv2d(x,
                     params['dct0'],
                     stride=stride,
                     padding=padding,
                     groups=x.size(1))
        if base + '.bn.running_mean' in params:
            y = utils.batch_norm(y, params, base + '.bn', mode, affine=False)
        z = F.conv2d(y, params[base + '.conv'], padding=0)
        return z

    def lin_harmonic_block(x, params, base, mode, stride=1, padding=1):
        filt = torch.sum(params[base + '.conv'] *
                         params['dct'][:x.size(1), ...],
                         dim=2)
        y = F.conv2d(x, filt, stride=stride, padding=padding)
        return y

    def block(x, params, base, mode, stride):
        o1 = F.relu(utils.batch_norm(x, params, base + '.bn0', mode),
                    inplace=True)
        y = lin_harmonic_block(o1,
                               params,
                               base + '.harmonic0',
                               mode,
                               stride=stride,
                               padding=1)
        o2 = F.relu(utils.batch_norm(y, params, base + '.bn1', mode),
                    inplace=True)
        if dropout > 0:
            o2 = F.dropout(o2, p=dropout, training=mode, inplace=False)
        z = lin_harmonic_block(o2,
                               params,
                               base + '.harmonic1',
                               mode,
                               stride=1,
                               padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, base, mode, stride):
        for i in range(n):
            o = block(o, params, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, mode):
        x = harmonic_block(input,
                           params,
                           'harmonic0',
                           mode,
                           stride=2,
                           padding=1)
        g0 = group(x, params, 'group0', mode, 1)
        g1 = group(g0, params, 'group1', mode, 2)
        g2 = group(g1, params, 'group2', mode, 2)
        o = F.relu(utils.batch_norm(g2, params, 'bn', mode))
        o = F.avg_pool2d(o, 12, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params
Example 17
def resnet(depth, width, num_classes, activation):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int()
    # select the activation; the 0.2 slope belongs only to leaky ReLU
    # (`swish` and `new` are assumed to be unary callables defined elsewhere)
    activations = {
        'swish': swish,
        'new': new,
        'elu': F.elu,
        'tanh': torch.tanh,
        'lrelu': lambda x: F.leaky_relu(x, 0.2),
        'relu': F.relu,
    }
    actfun = activations[activation]

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    def gen_group_stats(ni, no, count):
        return {'block%d' % i: {'bn0': bnstats(ni if i == 0 else no), 'bn1': bnstats(no)}
                for i in range(count)}

    flat_params = flatten_params({
        'conv0': conv_params(3,16,3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = actfun(batch_norm(x, params, stats, base + '.bn0', mode))
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = actfun(batch_norm(y, params, stats, base + '.bn1', mode))
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = actfun(batch_norm(g2, params, stats, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
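
A hypothetical call selecting leaky ReLU (slope 0.2) as the block activation; `swish` and `new` would have to be unary callables defined elsewhere in the source module:

f, flat_params, flat_stats = resnet(depth=16, width=1, num_classes=10,
                                    activation='lrelu')
logits = f(torch.randn(4, 3, 32, 32), flat_params, flat_stats, mode=True)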
Example 18
def define_student(depth, width):
    # wide-resnet-14-2, 21530792
    definitions = {14: [1, 1, 1, 1]}
    assert depth in list(definitions.keys())
    widths = [int(w * width) for w in (64, 128, 256, 512)]
    blocks = definitions[depth]
    print("student model is resnet-{}-{}".format(depth, width))

    def gen_block_params(ni, nm, no):
        return {
            'conv0': utils.conv_params(ni, nm, 1),
            'conv1': utils.conv_params(nm, nm, 3),
            'conv2': utils.conv_params(nm, no, 1),
            'conv_dim': utils.conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, nm, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, nm, no)
            for i in range(count)
        }

    flat_params = OrderedDict(
        utils.flatten({
            'conv0':
            utils.conv_params(3, 64, 7),
            'group0':
            gen_group_params(64, widths[0], widths[0] * 2, blocks[0]),
            'group1':
            gen_group_params(widths[0] * 2, widths[1], widths[1] * 2,
                             blocks[1]),
            'group2':
            gen_group_params(widths[1] * 2, widths[2], widths[2] * 2,
                             blocks[2]),
            'group3':
            gen_group_params(widths[2] * 2, widths[3], widths[3] * 2,
                             blocks[3]),
            'fc':
            utils.linear_params(widths[3] * 2, 1000),
        }))

    utils.set_requires_grad_except_bn_(flat_params)

    def conv2d(input, params, base, stride=1, pad=0):
        # return F.conv2d(input, params[base + '.weight'], params[base + '.bias'], stride, pad)
        return F.conv2d(input, params[base], stride=stride, padding=pad)

    def group(input, params, base, stride, n):
        o = input
        for i in range(0, n):
            b_base = '%s.block%d.conv' % (base, i)
            x = o
            o = conv2d(x, params, b_base + '0')
            o = F.relu(o)
            o = conv2d(o, params, b_base + '1',
                       stride=stride if i == 0 else 1, pad=1)
            o = F.relu(o)
            o = conv2d(o, params, b_base + '2')
            if i == 0:
                o += conv2d(x, params, b_base + '_dim', stride=stride)
            else:
                o += x
            o = F.relu(o)
        return o

    def f(input, params, mode, pr=''):
        # o = F.conv2d(input, params['conv0.weight'], params['conv0.bias'], 2, 3)
        o = conv2d(input, params, pr + 'conv0', 2, 3)
        o = F.relu(o)
        o = F.max_pool2d(o, 3, 2, 1)
        o_g0 = group(o, params, pr + 'group0', 1, blocks[0])
        o_g1 = group(o_g0, params, pr + 'group1', 2, blocks[1])
        o_g2 = group(o_g1, params, pr + 'group2', 2, blocks[2])
        o_g3 = group(o_g2, params, pr + 'group3', 2, blocks[3])
        o = F.avg_pool2d(o_g3, 7, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        return o, (o_g0, o_g1, o_g2, o_g3)

    return f, flat_params
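
A minimal sketch for this bottleneck student (same assumptions as the earlier sketches: torch and the repo's `utils` importable, 224x224 inputs to match the 7x7 stem and 7x7 average pool):

f, flat_params = define_student(depth=14, width=2)
logits, (o_g0, o_g1, o_g2, o_g3) = f(torch.randn(2, 3, 224, 224), flat_params, mode=True)
print(logits.shape)   # torch.Size([2, 1000])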