Example #1
def permutify(params1, params2):
    """Permute the parameters of params2 to match params1 as closely as possible.
  Returns the permuted version of params2. Only works on sequences of Dense
  layers for now."""
    p1f = flatten_params(params1)
    p2f = flatten_params(params2)

    p2f_new = {**p2f}
    num_layers = max(
        int(kmatch("**/Dense_*/**", k).group(2)) for k in p1f.keys())
    # range is [0, num_layers), so we're safe here since we don't want to be
    # reordering the output of the last layer.
    for layer in range(num_layers):
        # Maximize since we're dealing with similarities, not distances.
        ri, ci = linear_sum_assignment(cosine_similarity(
            p1f[f"params/Dense_{layer}/kernel"].T,
            p2f_new[f"params/Dense_{layer}/kernel"].T),
                                       maximize=True)
        assert (ri == jnp.arange(len(ri))).all()

        p2f_new = {
            **p2f_new, f"params/Dense_{layer}/kernel":
            p2f_new[f"params/Dense_{layer}/kernel"][:, ci],
            f"params/Dense_{layer}/bias":
            p2f_new[f"params/Dense_{layer}/bias"][ci],
            f"params/Dense_{layer+1}/kernel":
            p2f_new[f"params/Dense_{layer+1}/kernel"][ci, :]
        }

    new_params2 = unflatten_params(p2f_new)

    return new_params2
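
A minimal usage sketch (hypothetical names: params_a and params_b stand for two independently trained parameter sets; tree_map comes from jax.tree_util): align the second model to the first before averaging.

from jax import tree_util

# Hypothetical usage: permute params_b into params_a's "coordinates" first,
# then average; without the alignment, naive averaging mixes unrelated units.
params_b_aligned = permutify(params_a, params_b)
averaged = tree_util.tree_map(lambda a, b: 0.5 * (a + b), params_a, params_b_aligned)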
Example #2
    def second_order_approximation(self, X_train, Y_train, X_val, Y_val):
        _, loss_train = self.loss(X_train, Y_train)
        w = flatten_params(self.parameters())
        momentum = self.weights_optimizer.param_groups[0]["momentum"]
        weight_decay = self.weights_optimizer.param_groups[0]["weight_decay"]
        eta = self.weights_optimizer.param_groups[0]["lr"]
        try:
            velocity = flatten_params(
                self.weights_optimizer.state[v]['momentum_buffer']
                for v in self.parameters())
        except KeyError:
            velocity = torch.zeros_like(w)
        # Gradient of the weight parameters, plus L2 regularization if necessary.
        w_grad = flatten_params(
            torch.autograd.grad(loss_train,
                                self.parameters())) + weight_decay * w
        velocity = momentum * velocity + w_grad
        w_prime = w - eta * velocity

        unrolled_model = deepcopy(self)
        params, offset = unrolled_model.state_dict(), 0
        for k, v in self.named_parameters():
            v_length = v.numel()
            params[k] = w_prime[offset:offset + v_length].view(v.size())
            offset += v_length
        # Load the unrolled weights w' into the copied model; without this the
        # dict edits above never reach unrolled_model's parameters.
        unrolled_model.load_state_dict(params)

        _, loss_val = unrolled_model.loss(X_val, Y_val)
        loss_val.backward()

        alpha_grads = [
            alpha.grad for alpha in unrolled_model.arch_parameters
        ]  # derivative of loss_val w.r.t architecture parameters
        w_prime_grads = [
            w_prime.grad for w_prime in unrolled_model.parameters()
        ]  # derivative of loss_val w.r.t w' parameters
        hessian_vector_grads = self._hessian_vector_product(
            w_prime_grads, X_train, Y_train)

        with torch.no_grad():
            for alpha, alpha_grad, hessian_vector_grad in zip(
                    self.arch_parameters, alpha_grads, hessian_vector_grads):
                alpha.grad = alpha_grad - hessian_vector_grad
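
For reference, this is a DARTS-style unrolled (second-order) architecture gradient. Sketched in the usual notation (the reference formulation scales the correction by the learning rate \(\eta\); the code above subtracts the Hessian-vector term directly):

\[
\nabla_\alpha \mathcal{L}_{\mathrm{val}}(w', \alpha)
  \;-\; \eta \,\nabla^2_{\alpha, w} \mathcal{L}_{\mathrm{train}}(w, \alpha)\,
        \nabla_{w'} \mathcal{L}_{\mathrm{val}}(w', \alpha),
\qquad w' = w - \eta\, v
\]

where v is the momentum- and weight-decay-adjusted velocity computed above.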
Example #3
    def _hessian_vector_product(self, w_prime_grads, X_train, Y_train, r=1e-2):
        R = r / flatten_params(w_prime_grads).norm()

        with torch.no_grad():
            for p, v in zip(self.parameters(), w_prime_grads):
                p.add_(v, alpha=R)
        _, loss = self.loss(X_train, Y_train)
        grads_p = torch.autograd.grad(loss, self.arch_parameters)

        with torch.no_grad():
            for p, v in zip(self.parameters(), w_prime_grads):
                p.sub_(v, alpha=2 * R)
        _, loss = self.loss(X_train, Y_train)
        grads_n = torch.autograd.grad(loss, self.arch_parameters)

        with torch.no_grad():
            for p, v in zip(self.parameters(), w_prime_grads):
                p.add_(v, alpha=R)

        return [(x - y).div_(2 * R) for x, y in zip(grads_p, grads_n)]
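
Written out, the routine above is a central finite-difference approximation of the Hessian-vector product needed for the architecture-gradient correction, with step \(\epsilon = r / \lVert v \rVert\):

\[
\nabla^2_{\alpha, w} \mathcal{L}_{\mathrm{train}}(w, \alpha)\, v \;\approx\;
\frac{\nabla_\alpha \mathcal{L}_{\mathrm{train}}(w + \epsilon v, \alpha)
      - \nabla_\alpha \mathcal{L}_{\mathrm{train}}(w - \epsilon v, \alpha)}{2\,\epsilon}
\]

The final in-place add restores the original weights.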
Example #4
def define_student(depth, width):
    definitions = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
    }
    assert depth in definitions
    widths = np.floor(np.asarray([64, 128, 256, 512]) * width).astype(int)
    blocks = definitions[depth]

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(no),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 64, 7),
        'bn0': bnparams(64),
        'group0': gen_group_params(64, widths[0], blocks[0]),
        'group1': gen_group_params(widths[0], widths[1], blocks[1]),
        'group2': gen_group_params(widths[1], widths[2], blocks[2]),
        'group3': gen_group_params(widths[2], widths[3], blocks[3]),
        'fc': linear_params(widths[3], 1000),
    }

    stats = {
        'bn0': bnstats(64),
        'group0': gen_group_stats(widths[0], blocks[0]),
        'group1': gen_group_stats(widths[1], blocks[1]),
        'group2': gen_group_stats(widths[2], blocks[2]),
        'group3': gen_group_stats(widths[3], blocks[3]),
    }

    # flatten parameters and additional buffers
    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def block(x, params, stats, base, mode, stride):
        y = F.conv2d(x, params[base + '.conv0'], stride=stride, padding=1)
        o1 = F.relu(batch_norm(y, params, stats, base + '.bn0', mode),
                    inplace=True)
        z = F.conv2d(o1, params[base + '.conv1'], stride=1, padding=1)
        o2 = batch_norm(z, params, stats, base + '.bn1', mode)
        if base + '.convdim' in params:
            return F.relu(
                o2 + F.conv2d(x, params[base + '.convdim'], stride=stride),
                inplace=True)
        else:
            return F.relu(o2 + x, inplace=True)

    def group(o, params, stats, base, mode, stride, n):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode, pr=''):
        o = F.conv2d(input, params[pr + 'conv0'], stride=2, padding=3)
        o = F.relu(batch_norm(o, params, stats, pr + 'bn0', mode),
                   inplace=True)
        o = F.max_pool2d(o, 3, 2, 1)
        g0 = group(o, params, stats, pr + 'group0', mode, 1, blocks[0])
        g1 = group(g0, params, stats, pr + 'group1', mode, 2, blocks[1])
        g2 = group(g1, params, stats, pr + 'group2', mode, 2, blocks[2])
        g3 = group(g2, params, stats, pr + 'group3', mode, 2, blocks[3])
        o = F.avg_pool2d(g3, 7)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[pr + 'fc.weight'], params[pr + 'fc.bias'])
        return o, [g0, g1, g2, g3]

    return f, flat_params, flat_stats
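
A minimal usage sketch (hypothetical batch; conv_params, bnparams, batch_norm, etc. are assumed to come from the surrounding module). The 7x7 stem, 3x3 max-pool and 7x7 average pool imply ImageNet-style 224x224 inputs:

import torch

f, flat_params, flat_stats = define_student(depth=18, width=0.5)
x = torch.randn(2, 3, 224, 224)                       # hypothetical input batch
logits, group_feats = f(x, flat_params, flat_stats, mode=True)
# logits has shape (2, 1000); group_feats holds the four residual-group outputs.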
Example #5
def resnet(depth, width, num_classes, stu_depth=0):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    if stu_depth != 0:
        assert (stu_depth - 4) % 6 == 0, 'student depth should be 6n+4'
        n_s = (stu_depth - 4) // 6
    else:
        n_s = 0

    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'bns0': bnparams(ni),
            'bns1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no),
                'bns0': bnstats(ni if i == 0 else no),
                'bns1': bnstats(no)
            }
            for i in range(count)
        }

    if stu_depth != 0 and not opt.param_share:
        params = {
            'conv0': conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'groups0': gen_group_params(16, widths[0], n_s),
            'groups1': gen_group_params(widths[0], widths[1], n_s),
            'groups2': gen_group_params(widths[1], widths[2], n_s),
            'bn': bnparams(widths[2]),
            'bns': bnparams(widths[2]),
            'fc': linear_params(widths[2], num_classes),
            'fcs': linear_params(widths[2], num_classes),
        }

        stats = {
            'group0': gen_group_stats(16, widths[0], n),
            'group1': gen_group_stats(widths[0], widths[1], n),
            'group2': gen_group_stats(widths[1], widths[2], n),
            'groups0': gen_group_stats(16, widths[0], n_s),
            'groups1': gen_group_stats(widths[0], widths[1], n_s),
            'groups2': gen_group_stats(widths[1], widths[2], n_s),
            'bn': bnstats(widths[2]),
            'bns': bnstats(widths[2]),
        }
    else:
        params = {
            'conv0': conv_params(3, 16, 3),
            'group0': gen_group_params(16, widths[0], n),
            'group1': gen_group_params(widths[0], widths[1], n),
            'group2': gen_group_params(widths[1], widths[2], n),
            'bn': bnparams(widths[2]),
            'bns': bnparams(widths[2]),
            'fc': linear_params(widths[2], num_classes),
            'fcs': linear_params(widths[2], num_classes),
        }

        stats = {
            'group0': gen_group_stats(16, widths[0], n),
            'group1': gen_group_stats(widths[0], widths[1], n),
            'group2': gen_group_stats(widths[1], widths[2], n),
            'bn': bnstats(widths[2]),
            'bns': bnstats(widths[2]),
        }

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def block(x, params, stats, base, mode, stride, flag, drop_switch=True):
        if flag == 's':
            o1 = F.relu(batch_norm(x, params, stats, base + '.bns0', mode))
            y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
            o2 = F.relu(batch_norm(y, params, stats, base + '.bns1', mode))
            z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
            if base + '.convdim' in params:
                return z + F.conv2d(
                    o1, params[base + '.convdim'], stride=stride)
            else:
                return z + x
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode))
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode))
        if opt.dropout > 0 and drop_switch:
            o2 = F.dropout(o2, p=opt.dropout, training=mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1, 't', False)
        return o

    def group_student(o, params, stats, base, mode, stride, n_layer):
        for i in range(n_layer):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1, 's', False)
        return o

    def f(input, params, stats, mode, prefix=''):
        x = F.conv2d(input, params[prefix + 'conv0'], padding=1)
        g0 = group(x, params, stats, prefix + 'group0', mode, 1)
        g1 = group(g0, params, stats, prefix + 'group1', mode, 2)
        g2 = group(g1, params, stats, prefix + 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, prefix + 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params[prefix + 'fc.weight'],
                     params[prefix + 'fc.bias'])
        #x_s = F.conv2d(input, params[prefix+'conv0_s'], padding=1)
        if stu_depth != 0:
            if opt.param_share:
                gs0 = group_student(x, params, stats, prefix + 'group0', mode,
                                    1, n_s)
                gs1 = group_student(gs0, params, stats, prefix + 'group1',
                                    mode, 2, n_s)
                gs2 = group_student(gs1, params, stats, prefix + 'group2',
                                    mode, 2, n_s)
            else:
                gs0 = group_student(x, params, stats, prefix + 'groups0', mode,
                                    1, n_s)
                gs1 = group_student(gs0, params, stats, prefix + 'groups1',
                                    mode, 2, n_s)
                gs2 = group_student(gs1, params, stats, prefix + 'groups2',
                                    mode, 2, n_s)

            os = F.relu(batch_norm(gs2, params, stats, prefix + 'bns', mode))
            os = F.avg_pool2d(os, 8, 1, 0)
            os = os.view(os.size(0), -1)
            os = F.linear(os, params[prefix + 'fcs.weight'],
                          params[prefix + 'fcs.bias'])
            return os, o, [g0, g1, g2, gs0, gs1, gs2]
        else:
            return o, [g0, g1, g2]

    return f, flat_params, flat_stats
Example #6
def resnet(depth,
           width,
           num_classes,
           is_full_wrn=True,
           is_fully_convolutional=False):
    #assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    #n = (depth - 4) // 6
    #wrn = WideResNet(depth, width, ninputs=3,useCuda=True, num_groups=3, num_classes=num_classes)
    n = depth
    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    }

    stats = {
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    }
    if not is_full_wrn:
        ''' omniglot '''
        params['bn'] = bnparams(widths[1])
        #params['fc'] = linear_params(widths[1]*16*16, num_classes)
        params['fc'] = linear_params(widths[1], num_classes)
        stats['bn'] = bnstats(widths[1])
        '''
        # banknote
        params['bn'] = bnparams(widths[2])
        #params['fc'] = linear_params(widths[2]*16*16, num_classes)
        params['fc'] = linear_params(widths[2], num_classes)
        stats['bn'] = bnstats(widths[2])
        '''

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def activation(x, params, stats, base, mode):
        return F.relu(F.batch_norm(x,
                                   weight=params[base + '.weight'],
                                   bias=params[base + '.bias'],
                                   running_mean=stats[base + '.running_mean'],
                                   running_var=stats[base + '.running_var'],
                                   training=mode,
                                   momentum=0.1,
                                   eps=1e-5),
                      inplace=True)

    def block(x, params, stats, base, mode, stride):
        o1 = activation(x, params, stats, base + '.bn0', mode)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = activation(y, params, stats, base + '.bn1', mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, o.shape[2], 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    def not_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        # omniglot
        o = activation(g1, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, o.shape[2], 1, 0)
        # banknote
        '''
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, 16, 1, 0)
        '''
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    def fcn_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        return o

    def fcn_not_full_wrn(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        o = activation(g1, params, stats, 'bn', mode)
        return o

    if is_fully_convolutional:
        if is_full_wrn:
            return fcn_full_wrn, flat_params, flat_stats
        else:
            return fcn_not_full_wrn, flat_params, flat_stats
    else:
        if is_full_wrn:
            return full_wrn, flat_params, flat_stats
        else:
            return not_full_wrn, flat_params, flat_stats
Example #7
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = [int(x * width) for x in [16, 32, 64]]

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    def gen_group_stats(ni, no, count):
        return {'block%d' % i: {'bn0': bnstats(ni if i == 0 else no), 'bn1': bnstats(no)}
                for i in range(count)}

    flat_params = flatten_params({
        'conv0': conv_params(3,16,3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode), inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode), inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
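
A minimal usage sketch for this functional wide ResNet (hypothetical batch; the fixed 8x8 average pool implies CIFAR-style 32x32 inputs):

import torch

f, flat_params, flat_stats = resnet(depth=28, width=2, num_classes=10)   # WRN-28-2
x = torch.randn(4, 3, 32, 32)
logits = f(x, flat_params, flat_stats, mode=True)     # mode toggles batch-norm training behaviour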
Example #8
def find_feature_importance(net):
    """Get a vector indicating the importance of features in the network"""
    with torch.no_grad():
        w_t = utils.flatten_params(net.get_params(), net.params)
        return (w_t - w_t.mean()).abs() / w_t.abs().sum()
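
Written out, the importance assigned to weight \(w_i\) is its absolute deviation from the mean weight, normalized by the total absolute weight mass:

\[
s_i = \frac{\lvert w_i - \bar{w} \rvert}{\sum_j \lvert w_j \rvert}
\]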
Example #9
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int().numpy().tolist()

    def gen_block_params(ni, no, scalar):
        if scalar:
            return {
                'bn0': bnparams(ni),
                'bn1': bnparams(no),
            }
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count, bias=False):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no, bias)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    flat_vectors = flatten_params({
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'conv1': conv_params(widths[2], num_classes, 1),
    })

    flat_scalars = flatten_params({
        'group0': gen_group_params(16, widths[0], n, True),
        'group1': gen_group_params(widths[0], widths[1], n, True),
        'group2': gen_group_params(widths[1], widths[2], n, True),
        'bn': bnparams(widths[2]),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode, 1.),
                    inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode, 1.),
                    inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, 'bn', mode, 1.))
        o = F.conv2d(o, params['conv1'])
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        return o

    return f, flat_vectors, flat_scalars, flat_stats
Example #10
def resnet(depth, width, num_classes, activation):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int()
    # Map the activation name to a callable. The blocks below pass 0.2 as a
    # second argument (used as the leaky-relu slope), so activations that take
    # no such parameter are wrapped to ignore it.
    if activation == 'swish':
        actfun = swish
    elif activation == 'new':
        actfun = new
    elif activation == 'elu':
        actfun = F.elu
    elif activation == 'tanh':
        actfun = lambda x, _slope=None: torch.tanh(x)
    elif activation == 'lrelu':
        actfun = F.leaky_relu
    elif activation == 'relu':
        actfun = lambda x, _slope=None: F.relu(x)
    else:
        raise ValueError('unknown activation: %s' % activation)

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    def gen_group_stats(ni, no, count):
        return {'block%d' % i: {'bn0': bnstats(ni if i == 0 else no), 'bn1': bnstats(no)}
                for i in range(count)}

    flat_params = flatten_params({
        'conv0': conv_params(3,16,3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = actfun(batch_norm(x, params, stats, base + '.bn0', mode),0.2)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = actfun(batch_norm(y, params, stats, base + '.bn1', mode),0.2)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = actfun(batch_norm(g2, params, stats, 'bn', mode),0.2)
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
Example #11
input('\nProgram paused. Press enter to continue.\n')

## ================ Part 2: Loading Parameters ================
# In this part of the exercise, we load some pre-initialized 
# neural network parameters.

print('> Loading Saved Neural Network Parameters ...')

# Load the weights into variables Theta1 and Theta2
weights = sio.loadmat('../data/ex4weights.mat')
Theta1 = weights['Theta1']
Theta2 = weights['Theta2']

# Unroll parameters 
nn_params = flatten_params(Theta1, Theta2)

## ================ Part 3: Compute Cost (Feedforward) ================
#  To train the neural network, you should first start by implementing the
#  feedforward part of the neural network that returns the cost only. You
#  should complete the code in nnCostFunction.m to return cost. After
#  implementing the feedforward to compute the cost, you can verify that
#  your implementation is correct by verifying that you get the same cost
#  as us for the fixed debugging parameters.
#
#  We suggest implementing the feedforward cost *without* regularization
#  first so that it will be easier for you to debug. Later, in part 4, you
#  will get to implement the regularized cost.
#
print('> Feedforward Using Neural Network ...')
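
The port's flatten_params is not shown in this snippet; a plausible minimal version of the unrolling step it performs, assuming NumPy arrays (for illustration only, not necessarily the repository's implementation):

import numpy as np

def flatten_params(*thetas):
    # Unroll any number of weight matrices into a single 1-D parameter vector.
    return np.concatenate([theta.ravel() for theta in thetas])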
Example #12
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)}

    def gen_group_stats(ni, no, count):
        return {'block%d' % i: {'bn0': bnstats(ni if i == 0 else no), 'bn1': bnstats(no)}
                for i in range(count)}

    flat_params = flatten_params({
        'conv0': conv_params(3,16,3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    })

    flat_stats = flatten_stats({
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    })

    def block(x, params, stats, base, mode, stride):
        o1 = F.relu(batch_norm(x, params, stats, base + '.bn0', mode), inplace=True)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = F.relu(batch_norm(y, params, stats, base + '.bn1', mode), inplace=True)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base,i), mode, stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = F.relu(batch_norm(g2, params, stats, 'bn', mode))
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
Example #13
input('\nProgram paused. Press enter to continue.\n')

## ================ Part 2: Loading Parameters ================
# In this part of the exercise, we load some pre-initialized
# neural network parameters.

print('> Loading Saved Neural Network Parameters ...')

# Load the weights into variables Theta1 and Theta2
weights = sio.loadmat('../data/ex4weights.mat')
Theta1 = weights['Theta1']
Theta2 = weights['Theta2']

# Unroll parameters
nn_params = flatten_params(Theta1, Theta2)

## ================ Part 3: Compute Cost (Feedforward) ================
#  To train the neural network, you should first start by implementing the
#  feedforward part of the neural network that returns the cost only. You
#  should complete the code in nnCostFunction.m to return cost. After
#  implementing the feedforward to compute the cost, you can verify that
#  your implementation is correct by verifying that you get the same cost
#  as us for the fixed debugging parameters.
#
#  We suggest implementing the feedforward cost *without* regularization
#  first so that it will be easier for you to debug. Later, in part 4, you
#  will get to implement the regularized cost.
#
print('> Feedforward Using Neural Network ...')
Example #14
    def __init__(self,
                 depth,
                 width,
                 ninputs=3,
                 num_groups=3,
                 num_classes=None,
                 dropout=0.):

        super(WideResNet, self).__init__()
        self.depth = depth
        self.width = width
        self.num_groups = num_groups
        self.num_classes = num_classes
        self.dropout = dropout
        self.mode = True  # Training

        #widths = torch.Tensor([16, 32, 64]).mul(width).int()
        widths = np.array([16, 32, 64]).astype(int) * width

        def gen_block_params(ni, no):
            return {
                'conv0': conv_params(ni, no, 3),
                'conv1': conv_params(no, no, 3),
                'bn0': bnparams(ni),
                'bn1': bnparams(no),
                'convdim': conv_params(ni, no, 1) if ni != no else None,
            }

        def gen_group_params(ni, no, count):
            return {
                'block%d' % i: gen_block_params(ni if i == 0 else no, no)
                for i in range(count)
            }

        def gen_group_stats(ni, no, count):
            return {
                'block%d' % i: {
                    'bn0': bnstats(ni if i == 0 else no),
                    'bn1': bnstats(no)
                }
                for i in range(count)
            }

        params = {'conv0': conv_params(ni=ninputs, no=widths[0], k=3)}
        stats = {}

        # One parameter/stat group per residual group (widths has num_groups entries).
        for i in range(num_groups):
            if i == 0:
                params.update({
                    'group' + str(i):
                    gen_group_params(widths[i], widths[i], depth)
                })
                stats.update({
                    'group' + str(i):
                    gen_group_stats(widths[i], widths[i], depth)
                })
            else:
                params.update({
                    'group' + str(i):
                    gen_group_params(widths[i - 1], widths[i], depth)
                })
                stats.update({
                    'group' + str(i):
                    gen_group_stats(widths[i - 1], widths[i], depth)
                })

        if num_classes is not None:
            params.update({'fc': linear_params(widths[i], num_classes)})
        params.update({'bn': bnparams(widths[i])})
        stats.update({'bn': bnstats(widths[i])})

        params = flatten_params(params)
        stats = flatten_stats(stats)

        self.params = nn.ParameterDict({})
        self.stats = nn.ParameterDict({})
        for key in params.keys():
            self.params.update(
                {key: nn.Parameter(params[key], requires_grad=True)})
        for key in stats.keys():
            self.stats.update(
                {key: nn.Parameter(stats[key], requires_grad=False)})
Example #15
# Train all weights.
net = make_net([2048] * 6)
train(net,
      init_params=net.init(random.PRNGKey(0), jnp.zeros((1, 28 * 28))),
      trainable_predicate=lambda k: True,
      log_prefix="normal")

print("Training only gains model...")
# Train only gains, keeping all other weights fixed.
net = make_net([2048] * 6)
only_gains_final_params = train(
    net,
    init_params=net.init(random.PRNGKey(0), jnp.zeros((1, 28 * 28))),
    trainable_predicate=lambda k: kmatch("**/gain", k),
    log_prefix="only_gain")
only_gains_final_params_flat = flatten_params(only_gains_final_params)
print("  full model params:")
print(tree_map(jnp.shape, only_gains_final_params_flat))
gain_params = {
    k: v
    for k, v in only_gains_final_params_flat.items() if kmatch("**/gain", k)
}
gain_params_flat, unravel = ravel_pytree(gain_params)
cutoff = jnp.percentile(jnp.abs(gain_params_flat), config.remove_percentile)
# A mask that identifies only those gains that have the largest absolute
# value. Only keep the top `(100 - config.remove_percentile)%` gains. Note
# that this isn't the traditional LTH approach. It's more accurately described
# as a structural lottery ticket.
gain_mask = binarize(unravel(jnp.abs(gain_params_flat) > cutoff))
print(tree_map(jnp.sum, gain_mask))
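
A toy illustration of the percentile-based masking above (hypothetical values):

import jax.numpy as jnp

g = jnp.array([0.01, -0.5, 0.2, -0.03])
cutoff = jnp.percentile(jnp.abs(g), 50.0)   # with remove_percentile = 50
mask = jnp.abs(g) > cutoff                  # -> [False, True, True, False]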
Example #16
def resnet(depth, width, num_classes):
    assert (depth - 4) % 6 == 0, 'depth should be 6n+4'
    n = (depth - 4) // 6
    widths = torch.Tensor([16, 32, 64]).mul(width).int()

    def gen_block_params(ni, no):
        return {
            'conv0': conv_params(ni, no, 3),
            'conv1': conv_params(no, no, 3),
            'bn0': bnparams(ni),
            'bn1': bnparams(no),
            'convdim': conv_params(ni, no, 1) if ni != no else None,
        }

    def gen_group_params(ni, no, count):
        return {
            'block%d' % i: gen_block_params(ni if i == 0 else no, no)
            for i in range(count)
        }

    def gen_group_stats(ni, no, count):
        return {
            'block%d' % i: {
                'bn0': bnstats(ni if i == 0 else no),
                'bn1': bnstats(no)
            }
            for i in range(count)
        }

    params = {
        'conv0': conv_params(3, 16, 3),
        'group0': gen_group_params(16, widths[0], n),
        'group1': gen_group_params(widths[0], widths[1], n),
        'group2': gen_group_params(widths[1], widths[2], n),
        'bn': bnparams(widths[2]),
        'fc': linear_params(widths[2], num_classes),
    }

    stats = {
        'group0': gen_group_stats(16, widths[0], n),
        'group1': gen_group_stats(widths[0], widths[1], n),
        'group2': gen_group_stats(widths[1], widths[2], n),
        'bn': bnstats(widths[2]),
    }

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def activation(x, params, stats, base, mode):
        return F.relu(F.batch_norm(x,
                                   weight=params[base + '.weight'],
                                   bias=params[base + '.bias'],
                                   running_mean=stats[base + '.running_mean'],
                                   running_var=stats[base + '.running_var'],
                                   training=mode,
                                   momentum=0.1,
                                   eps=1e-5),
                      inplace=True)

    def block(x, params, stats, base, mode, stride):
        o1 = activation(x, params, stats, base + '.bn0', mode)
        y = F.conv2d(o1, params[base + '.conv0'], stride=stride, padding=1)
        o2 = activation(y, params, stats, base + '.bn1', mode)
        z = F.conv2d(o2, params[base + '.conv1'], stride=1, padding=1)
        if base + '.convdim' in params:
            return z + F.conv2d(o1, params[base + '.convdim'], stride=stride)
        else:
            return z + x

    def group(o, params, stats, base, mode, stride):
        for i in range(n):
            o = block(o, params, stats, '%s.block%d' % (base, i), mode,
                      stride if i == 0 else 1)
        return o

    def f(input, params, stats, mode):
        assert input.get_device() == params['conv0'].get_device()
        x = F.conv2d(input, params['conv0'], padding=1)
        g0 = group(x, params, stats, 'group0', mode, 1)
        g1 = group(g0, params, stats, 'group1', mode, 2)
        g2 = group(g1, params, stats, 'group2', mode, 2)
        o = activation(g2, params, stats, 'bn', mode)
        o = F.avg_pool2d(o, 8, 1, 0)
        o = o.view(o.size(0), -1)
        o = F.linear(o, params['fc.weight'], params['fc.bias'])
        return o

    return f, flat_params, flat_stats
Example #17
def train(net, init_params, trainable_predicate, log_prefix):
    def loss(trainable_params, untrainable_params, batch):
        inputs, targets = batch
        preds = net.apply(merge_params(trainable_params, untrainable_params),
                          inputs)
        return -jnp.mean(jnp.sum(preds * targets, axis=1))

    def accuracy(trainable_params, untrainable_params, batch):
        inputs, targets = batch
        target_class = jnp.argmax(targets, axis=1)
        params = merge_params(trainable_params, untrainable_params)
        predicted_class = jnp.argmax(net.apply(params, inputs), axis=1)
        return jnp.mean(predicted_class == target_class)

    tx = optax.adam(config.learning_rate)

    @jit
    def update(opt_state, trainable_params, untrainable_params, batch):
        batch_loss, g = value_and_grad(loss)(trainable_params,
                                             untrainable_params, batch)
        # Standard gradient update on the smooth part.
        updates, opt_state = tx.update(g, opt_state)
        trainable_params = optax.apply_updates(trainable_params, updates)
        # TODO: Proximal update on the L1 non-smooth part.
        return opt_state, trainable_params, untrainable_params, batch_loss

    trainable_params, untrainable_params = partition_dict(
        trainable_predicate, flatten_params(init_params))
    print("Trainable params:")
    print(tree_map(jnp.shape, trainable_params))
    assert len(trainable_params) > 0

    opt_state = tx.init(trainable_params)
    itercount = itertools.count()
    batches = data_stream()
    start_time = time.time()
    for epoch in tqdm(range(config.num_epochs)):
        for _ in range(num_batches):
            step = next(itercount)
            opt_state, trainable_params, untrainable_params, batch_loss = update(
                opt_state, trainable_params, untrainable_params, next(batches))
            wandb.log({
                f"{log_prefix}/batch_loss": batch_loss,
                "step": step,
                "wallclock": time.time() - start_time
            })

        # Calculate the proportion of gains that are dead.
        # gains, _ = ravel_pytree(
        #     tree_map(lambda x: x.gain if isinstance(x, ProximalGainLayerWeights) else jnp.array([]),
        #              params,
        #              is_leaf=lambda x: isinstance(x, ProximalGainLayerWeights)))
        # dead_units_proportion = jnp.sum(jnp.abs(gains) < 1e-12) / jnp.size(gains)
        # print(dead_units_proportion)

        wandb.log({
            f"{log_prefix}/train_loss":
            loss(trainable_params, untrainable_params,
                 (train_images, train_labels)),
            f"{log_prefix}/test_loss":
            loss(trainable_params, untrainable_params,
                 (test_images, test_labels)),
            f"{log_prefix}/train_accuracy":
            accuracy(trainable_params, untrainable_params,
                     (train_images, train_labels)),
            f"{log_prefix}/test_accuracy":
            accuracy(trainable_params, untrainable_params,
                     (test_images, test_labels)),
            # f"{log_prefix}/dead_units_proportion": dead_units_proportion,
            "step":
            step,
            "epoch":
            epoch,
            "wallclock":
            time.time() - start_time
        })

    return merge_params(trainable_params, untrainable_params)
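
A minimal sketch of what the proximal step in the TODO above might look like: soft-thresholding for an L1 penalty on the trainable parameters. l1_strength is a hypothetical hyperparameter, not part of the original code:

import jax.numpy as jnp
from jax.tree_util import tree_map

def l1_prox(params, step_size, l1_strength):
    # Soft-thresholding: shrink every weight toward zero by step_size * l1_strength.
    thresh = step_size * l1_strength
    return tree_map(lambda p: jnp.sign(p) * jnp.maximum(jnp.abs(p) - thresh, 0.0), params)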
Example #18

def vgg(depth, width, num_classes):
    assert depth in [11, 13, 16, 19]
    depth_str = str(int(depth))
    _cfg = cfg[depth_str]

    def gen_feature_params():
        in_channels = 3
        dic = {}
        for i in range(len(_cfg)):
            if not _cfg[i] == 'M':
                dic['conv{0}'.format(i)] = conv_params(in_channels, _cfg[i], 3)
                dic['bn{0}'.format(i)] = bnparams(_cfg[i])
                in_channels = _cfg[i]

        return dic

    def gen_feature_stats():
        dic = {}
        for i in range(len(_cfg)):
            if not _cfg[i] == 'M':
                dic['bn{0}'.format(i)] = bnstats(_cfg[i])
        return dic

    def gen_classifier_params():
        return {
            'fc1': linear_params(512, 4096),
            'fc2': linear_params(4096, 4096),
            'fc3': linear_params(4096, num_classes),
        }

    def feature(input, params, stats, mode):
        out = input
        for i in range(len(_cfg)):
            if _cfg[i] == 'M':
                out = F.max_pool2d(out, 2, 2, 0)
            else:
                out = F.conv2d(out, params['conv{0}'.format(i)], padding=1)
                out = activation(out, params, stats, 'bn{0}'.format(i), mode)

        return out

    def activation(x, params, stats, base, mode):
        return F.relu(F.batch_norm(x,
                                   weight=params[base + '.weight'],
                                   bias=params[base + '.bias'],
                                   running_mean=stats[base + '.running_mean'],
                                   running_var=stats[base + '.running_var'],
                                   training=mode,
                                   momentum=0.1,
                                   eps=1e-5),
                      inplace=True)

    def classifier(input, params, num_classes, mode):
        out = F.relu(F.linear(input, params['fc1.weight'], params['fc1.bias']),
                     inplace=False)
        #         out = F.dropout(out, p=0.3, training=mode)
        out = F.relu(F.linear(out, params['fc2.weight'], params['fc2.bias']),
                     inplace=False)
        #         out = F.dropout(out, p=0.3, training=mode)
        out = F.linear(out, params['fc3.weight'], params['fc3.bias'])

        return out

    params = {**gen_feature_params(), **gen_classifier_params()}
    stats = gen_feature_stats()

    flat_params = flatten_params(params)
    flat_stats = flatten_stats(stats)

    def f(input, params, stats, mode):
        out = feature(input, params, stats, mode)
        out = out.view(-1, np.prod(out.size()[1:])).contiguous()
        out = classifier(out, params, num_classes, mode)

        return out

    return f, flat_params, flat_stats
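
A minimal usage sketch (hypothetical batch; assumes the module-level cfg follows the standard VGG layout, whose five max-pools reduce 32x32 inputs to the 1x1x512 map expected by fc1):

import torch

f, flat_params, flat_stats = vgg(depth=16, width=1, num_classes=10)
x = torch.randn(4, 3, 32, 32)
logits = f(x, flat_params, flat_stats, mode=True)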