Example #1
    def __init__(self, hparams):
        super(Model, self).__init__()
        self.optimizer = None
        self.scheduler = None
        self._metric: Optional[Metric] = None
        self.metrics: Dict[str, Metric] = dict()
        self.trn_datasets: Optional[List[Dataset]] = None
        self.val_datasets: Optional[List[Dataset]] = None
        self.tst_datasets: Optional[List[Dataset]] = None
        self.padding: Dict[str, int] = {}
        self.base_dir: str = ""

        self._batch_per_epoch: int = -1
        self._comparsion: Optional[str] = None
        self._selection_criterion: Optional[str] = None

        if isinstance(hparams, dict):
            hparams = Namespace(**hparams)
        self.hparams: Namespace = hparams
        pl.seed_everything(hparams.seed)

        self.tokenizer = AutoTokenizer.from_pretrained(hparams.pretrain)
        self.model = self.build_model()
        self.freeze_layers()

        self.weight = nn.Parameter(torch.zeros(self.num_layers))
        self.mapping = None
        if hparams.mapping:
            assert os.path.isfile(hparams.mapping)
            self.mapping = torch.load(hparams.mapping)
            util.freeze(self.mapping)

        self.projector = self.build_projector()
        self.dropout = module.InputVariationalDropout(hparams.input_dropout)
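The zero-initialized per-layer weight and the optional mapping suggest an ELMo-style scalar mix over encoder layers, but the forward pass is not shown here. A minimal sketch of that pattern follows; the softmax mix and the tensor shapes are assumptions, not the project's actual code.

import torch
import torch.nn.functional as F

# Hypothetical scalar-mix sketch: average per-layer hidden states with a
# learned softmax weight initialized to zeros (a uniform mix at the start).
num_layers, batch, seq, hidden = 4, 2, 5, 8
layer_outputs = torch.randn(num_layers, batch, seq, hidden)  # stand-in hidden states
weight = torch.nn.Parameter(torch.zeros(num_layers))         # mirrors self.weight above
probs = F.softmax(weight, dim=0)
mixed = torch.einsum('l,lbsh->bsh', probs, layer_outputs)
assert mixed.shape == (batch, seq, hidden)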
Example #2
    def __init__(self, hparams):
        super(Aligner, self).__init__(hparams)

        self._comparsion = "min"
        self._selection_criterion = "val_loss"

        self.aligner_projector = self.build_aligner_projector()
        self.orig_model = deepcopy(self.model)
        util.freeze(self.orig_model)

        self.mappings = nn.ModuleList([])
        for _ in range(self.num_layers):
            m = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
            nn.init.eye_(m.weight)
            self.mappings.append(m)

        self.padding = {
            "src_sent": self.tokenizer.pad_token_id,
            "tgt_sent": self.tokenizer.pad_token_id,
            "src_align": PAD_ALIGN,
            "tgt_align": PAD_ALIGN,
            "src_lang": 0,
            "tgt_lang": 0,
            "lang": 0,
        }

        self.setup_metrics()
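Each per-layer mapping is a bias-free linear layer initialized with nn.init.eye_, so it starts out as the identity transform and only drifts from it during training. A small self-contained check of that property:

import torch
import torch.nn as nn

hidden_size = 8
m = nn.Linear(hidden_size, hidden_size, bias=False)
nn.init.eye_(m.weight)            # weight matrix starts as the identity
x = torch.randn(3, hidden_size)
assert torch.allclose(m(x), x)    # identity mapping until training updates it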
Example #3
    def freeze_layer(self, layer):
        if isinstance(self.model,
                      (transformers.BertModel, transformers.RobertaModel)):
            util.freeze(self.model.encoder.layer[layer - 1])
        elif isinstance(self.model, transformers.XLMModel):
            util.freeze(self.model.attentions[layer - 1])
            util.freeze(self.model.layer_norm1[layer - 1])
            util.freeze(self.model.ffns[layer - 1])
            util.freeze(self.model.layer_norm2[layer - 1])
        else:
            raise ValueError("Unsupported model")
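util.freeze itself is not shown in these examples; a minimal sketch of what such a helper typically does, offered as an assumption rather than the project's actual implementation:

import torch.nn as nn

def freeze(module: nn.Module) -> None:
    """Exclude a module's parameters from gradient updates."""
    for p in module.parameters():
        p.requires_grad = False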
Example #4
    def freeze_embeddings(self):
        if isinstance(self.model,
                      (transformers.BertModel, transformers.RobertaModel)):
            util.freeze(self.model.embeddings)
        elif isinstance(self.model, transformers.XLMModel):
            util.freeze(self.model.position_embeddings)
            if self.model.n_langs > 1 and self.model.use_lang_emb:
                util.freeze(self.model.lang_embeddings)
            util.freeze(self.model.embeddings)
        else:
            raise ValueError("Unsupported model")
Example #5
def compare_events(expected_events, output, mode):
    # The order of expected_events is only meaningful inside a partition, so
    # let's convert it into a map indexed by partition key.
    expected_events_map = {}
    for event in expected_events:
        expected_type, expected_key, expected_old_image, expected_new_image = event
        # For simplicity, we actually use the entire key, not just the partition
        # key. We only lose a bit of testing power we didn't plan to test anyway
        # (that events for different items in the same partition are ordered).
        key = freeze(expected_key)
        if key not in expected_events_map:
            expected_events_map[key] = []
        expected_events_map[key].append(event)
    # Iterate over the events in output. An event for a certain key needs to
    # be the *first* remaining event for this key in expected_events_map (and
    # then we remove this matched event from expected_events_map).
    for event in output:
        # In DynamoDB, eventSource is 'aws:dynamodb'. We decided to set it to
        # a *different* value - 'scylladb:alternator'. Issue #6931.
        assert 'eventSource' in event
        # Alternator is missing "awsRegion", which makes little sense for it
        # (although maybe we should have provided the DC name). Issue #6931.
        #assert 'awsRegion' in event
        # Alternator is also missing the "eventVersion" entry. Issue #6931.
        #assert 'eventVersion' in event
        # Check that eventID appears, but can't check much on what it is.
        assert 'eventID' in event
        op = event['eventName']
        record = event['dynamodb']
        # record['Keys'] is "serialized" JSON (e.g. {'S': 'thestring'}), so we
        # want to deserialize it to match our expected_events content.
        deserializer = TypeDeserializer()
        key = {x:deserializer.deserialize(y) for (x,y) in record['Keys'].items()}
        expected_type, expected_key, expected_old_image, expected_new_image = expected_events_map[freeze(key)].pop(0)
        if expected_type != '*': # hack to allow a caller to not test op, to bypass issue #6918.
            assert op == expected_type
        assert record['StreamViewType'] == mode
        # Check that all the expected members appear in the record, even if
        # we don't have anything to compare them to (TODO: we should probably
        # at least check they have proper format).
        assert 'ApproximateCreationDateTime' in record
        assert 'SequenceNumber' in record
        # Alternator doesn't set the SizeBytes member. Issue #6931.
        #assert 'SizeBytes' in record
        if mode == 'KEYS_ONLY':
            assert 'NewImage' not in record
            assert 'OldImage' not in record
        elif mode == 'NEW_IMAGE':
            assert 'OldImage' not in record
            if expected_new_image is None:
                assert 'NewImage' not in record
            else:
                new_image = {x: deserializer.deserialize(y) for (x, y) in record['NewImage'].items()}
                assert expected_new_image == new_image
        elif mode == 'OLD_IMAGE':
            assert 'NewImage' not in record
            if expected_old_image is None:
                assert 'OldImage' not in record
            else:
                old_image = {x: deserializer.deserialize(y) for (x, y) in record['OldImage'].items()}
                assert expected_old_image == old_image
        elif mode == 'NEW_AND_OLD_IMAGES':
            if expected_new_image is None:
                assert 'NewImage' not in record
            else:
                new_image = {x: deserializer.deserialize(y) for (x, y) in record['NewImage'].items()}
                assert expected_new_image == new_image
            if expected_old_image is None:
                assert 'OldImage' not in record
            else:
                old_image = {x: deserializer.deserialize(y) for (x, y) in record['OldImage'].items()}
                assert expected_old_image == old_image
        else:
            pytest.fail('cannot happen')
    # After the above loop, every list in expected_events_map should be empty.
    # If any is not, one of the expected events has not happened yet, so return False.
    for entry in expected_events_map.values():
        if len(entry) > 0:
            return False
    return True
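A hedged usage sketch of compare_events: the tuples follow the (eventName, key, old_image, new_image) unpacking above, '*' skips the eventName check, and records stands for a hypothetical list of stream records fetched elsewhere in the test.

# Hypothetical call -- `records` would come from reading the stream's shards.
expected = [
    ('INSERT', {'p': 'dog'}, None, {'p': 'dog', 'x': 1}),
    ('MODIFY', {'p': 'dog'}, {'p': 'dog', 'x': 1}, {'p': 'dog', 'x': 2}),
    ('*', {'p': 'cat'}, {'p': 'cat'}, None),   # '*' bypasses the eventName check
]
assert compare_events(expected, records, mode='NEW_AND_OLD_IMAGES')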
Example #6
    def compute_layer_stats(layer):
        refreeze = False
        if hasattr(layer, 'frozen') and layer.frozen:
            u.unfreeze(layer)
            refreeze = True

        s = AttrDefault(str, {})
        n = args.stats_batch_size
        param = u.get_param(layer)
        _d = len(param.flatten())  # dimensionality of parameters
        layer_idx = model.layers.index(layer)
        # TODO: get layer type, include it in name
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        _loss, _output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
        s.grad_fro = g.flatten().norm()
        s.param_fro = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        s.a_sparsity = neg_activations.float() / (
            pos_activations + neg_activations)  # 1 sparsity means all 0's
        activation_size = len(layer.data_output.flatten())
        s.a_magnitude = torch.sum(layer.data_output) / activation_size

        _output = backprop_output()
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
        H = J.t() @ J / n

        s.hessian_l2 = u.l2_norm(H)
        s.jacobian_l2 = u.l2_norm(J)
        J1 = J.sum(dim=0) / n  # single output Jacobian
        s.J1_l2 = J1.norm()

        # newton decrement
        def loss_direction(direction, eps):
            """loss improvement if we take step eps in direction dir"""
            return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                      eps**2 * direction @ H @ direction.t())

        s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

        # TODO: gradient diversity is stuck at 1
        # TODO: newton/gradient angle
        # TODO: newton step magnitude
        s.grad_curvature = u.to_python_scalar(
            g @ H @ g.t())  # curvature in direction of g
        s.step_openai = u.to_python_scalar(
            s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

        s.regret_gradient = loss_direction(g, s.step_openai)

        if refreeze:
            u.freeze(layer)
        return s
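The newton-decrement block above follows the quadratic model of the loss: regret_newton = g H^{-1} g^T / 2, and step_openai = ||g||^2 / (g H g^T) is the optimal step along the gradient for that model. Below is a small self-contained check, illustrative only, with torch.linalg.pinv standing in for u.pinv, that the gradient step's predicted improvement never exceeds the Newton regret.

import torch

torch.manual_seed(0)
d = 3
A = torch.randn(d, d)
H = A @ A.t() + torch.eye(d)        # SPD stand-in for the curvature matrix
g = torch.randn(1, d)               # row-vector gradient, as in the code above
eta = (g.norm() ** 2 / (g @ H @ g.t())).item()                   # step_openai
regret_gradient = eta * (g @ g.t()) - 0.5 * eta ** 2 * (g @ H @ g.t())
regret_newton = 0.5 * (g @ torch.linalg.pinv(H) @ g.t())
assert regret_gradient.item() <= regret_newton.item() + 1e-6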
Example #7
def main():
    global logger, stats_data, stats_targets

    parser = argparse.ArgumentParser()
    # Data
    parser.add_argument('--dataset',
                        type=str,
                        choices=[DATASET_MNIST],
                        default=DATASET_MNIST,
                        help='name of dataset')
    parser.add_argument('--root',
                        type=str,
                        default='./data',
                        help='root of dataset')
    parser.add_argument('--epochs',
                        type=int,
                        default=100,
                        help='number of epochs to train')
    parser.add_argument('--batch_size',
                        type=int,
                        default=128,
                        help='input batch size for training')
    parser.add_argument(
        '--stats_batch_size',
        type=int,
        default=1,
        help='size of batch to use for second order statistics')
    parser.add_argument('--val_batch_size',
                        type=int,
                        default=128,
                        help='input batch size for validation')
    parser.add_argument('--normalizing_data',
                        action='store_true',
                        help='[data pre processing] normalizing data')
    parser.add_argument('--random_crop',
                        action='store_true',
                        help='[data augmentation] random crop')
    parser.add_argument('--random_horizontal_flip',
                        action='store_true',
                        help='[data augmentation] random horizontal flip')
    # Training Settings
    parser.add_argument('--arch_file',
                        type=str,
                        default=None,
                        help='name of file which defines the architecture')
    parser.add_argument('--arch_name',
                        type=str,
                        default='LeNet5',
                        help='name of the architecture')
    parser.add_argument('--arch_args',
                        type=json.loads,
                        default=None,
                        help='[JSON] arguments for the architecture')
    parser.add_argument('--optim_name',
                        type=str,
                        default=SecondOrderOptimizer.__name__,
                        help='name of the optimizer')
    parser.add_argument('--optim_args',
                        type=json.loads,
                        default=None,
                        help='[JSON] arguments for the optimizer')
    parser.add_argument('--curv_args',
                        type=json.loads,
                        default=None,
                        help='[JSON] arguments for the curvature')
    # Options
    parser.add_argument(
        '--download',
        action='store_true',
        default=True,
        help='if True, downloads the dataset from the internet')
    parser.add_argument('--create_graph',
                        action='store_true',
                        default=False,
                        help='create graph of the derivative')
    parser.add_argument('--no_cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--num_workers',
                        type=int,
                        default=0,
                        help='number of sub processes for data loading')
    parser.add_argument(
        '--log_interval',
        type=int,
        default=50,
        help='how many batches to wait before logging training status')
    parser.add_argument('--log_file_name',
                        type=str,
                        default='log',
                        help='log file name')
    parser.add_argument(
        '--checkpoint_interval',
        type=int,
        default=50,
        help='how many epochs to wait before saving a checkpoint')
    parser.add_argument('--resume',
                        type=str,
                        default=None,
                        help='checkpoint path for resume training')
    parser.add_argument('--logdir',
                        type=str,
                        default='/temp/graph_test/run',
                        help='dir to save output files')
    parser.add_argument('--out',
                        type=str,
                        default=None,
                        help='dir to save output files')
    parser.add_argument('--config', default=None, help='config file path')
    parser.add_argument('--fisher_mc_approx',
                        action='store_true',
                        default=False,
                        help='if True, Fisher is estimated by MC sampling')
    parser.add_argument('--fisher_num_mc',
                        type=int,
                        default=1,
                        help='number of MC samples for estimating Fisher')
    parser.add_argument('--log_wandb',
                        type=int,
                        default=1,
                        help='log to wandb')

    args = parser.parse_args()

    # get run name from logdir
    root_logdir = args.logdir
    count = 0
    while os.path.exists(f"{root_logdir}{count:02d}"):
        count += 1
    args.logdir = f"{root_logdir}{count:02d}"

    run_name = os.path.basename(args.logdir)

    assert args.out is None, "Use args.logdir instead of args.out"
    args.out = args.logdir

    # Copy this file & config to args.out
    if not os.path.isdir(args.out):
        os.makedirs(args.out)
    shutil.copy(os.path.realpath(__file__), args.out)

    # Load config file
    if args.config:
        with open(args.config) as f:
            config = json.load(f)
        dict_args = vars(args)
        dict_args.update(config)

    if args.config is not None:
        shutil.copy(args.config, args.out)
    if args.arch_file is not None:
        shutil.copy(args.arch_file, args.out)

    # Setup logger
    logger = Logger(args.out, args.log_file_name)
    logger.start()

    g.event_writer = SummaryWriter(args.logdir)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.log_wandb:
            wandb.init(project='test-graphs_test', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['config'] = args.config
            wandb.config['batch'] = args.batch_size
            wandb.config['optim'] = args.optim_name
    except Exception as e:
        if args.log_wandb:
            print(f"wandb crashed with {e}")

    # Set device
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set random seed
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # Setup data augmentation & data pre processing
    train_transforms, val_transforms = [], []
    if args.random_crop:
        train_transforms.append(transforms.RandomCrop(32, padding=4))

    if args.random_horizontal_flip:
        train_transforms.append(transforms.RandomHorizontalFlip())

    train_transforms.append(transforms.ToTensor())
    val_transforms.append(transforms.ToTensor())

    if args.normalizing_data:
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        train_transforms.append(normalize)
        val_transforms.append(normalize)

    train_transform = transforms.Compose(train_transforms)
    # val_transform = transforms.Compose(val_transforms)

    num_classes = 10
    dataset_class = SimpleMNIST

    class Net(nn.Module):
        def __init__(self, d, nonlin=True):
            super().__init__()
            self.layers = []
            self.all_layers = []
            self.d = d
            for i in range(len(d) - 1):
                linear = nn.Linear(d[i], d[i + 1], bias=False)
                self.layers.append(linear)
                self.all_layers.append(linear)
                if nonlin:
                    self.all_layers.append(nn.ReLU())
            self.predict = torch.nn.Sequential(*self.all_layers)

        def forward(self, x: torch.Tensor):
            x = x.reshape((-1, self.d[0]))
            return self.predict(x)

    def compute_layer_stats(layer):
        refreeze = False
        if hasattr(layer, 'frozen') and layer.frozen:
            u.unfreeze(layer)
            refreeze = True

        s = AttrDefault(str, {})
        n = args.stats_batch_size
        param = u.get_param(layer)
        _d = len(param.flatten())  # dimensionality of parameters
        layer_idx = model.layers.index(layer)
        # TODO: get layer type, include it in name
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        _loss, _output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
        s.grad_fro = g.flatten().norm()
        s.param_fro = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        s.a_sparsity = neg_activations.float() / (
            pos_activations + neg_activations)  # 1 sparsity means all 0's
        activation_size = len(layer.data_output.flatten())
        s.a_magnitude = torch.sum(layer.data_output) / activation_size

        _output = backprop_output()
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
        H = J.t() @ J / n

        s.hessian_l2 = u.l2_norm(H)
        s.jacobian_l2 = u.l2_norm(J)
        J1 = J.sum(dim=0) / n  # single output Jacobian
        s.J1_l2 = J1.norm()

        # newton decrement
        def loss_direction(direction, eps):
            """loss improvement if we take step eps in direction dir"""
            return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                      eps**2 * direction @ H @ direction.t())

        s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

        # TODO: gradient diversity is stuck at 1
        # TODO: newton/gradient angle
        # TODO: newton step magnitude
        s.grad_curvature = u.to_python_scalar(
            g @ H @ g.t())  # curvature in direction of g
        s.step_openai = u.to_python_scalar(
            s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

        s.regret_gradient = loss_direction(g, s.step_openai)

        if refreeze:
            u.freeze(layer)
        return s

    train_dataset = dataset_class(root=gl.dataset,
                                  train=True,
                                  download=args.download,
                                  transform=train_transform)
    # val_dataset = t(root=args.root, train=False, download=args.download, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers)
    # val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers)
    stats_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        num_workers=args.num_workers)
    stats_data, stats_targets = next(iter(stats_loader))

    arch_kwargs = {} if args.arch_args is None else args.arch_args
    arch_kwargs['num_classes'] = num_classes

    model = Net([NUM_CHANNELS * IMAGE_SIZE**2, 8, 1], nonlin=False)
    setattr(model, 'num_classes', num_classes)
    model = model.to(device)
    param = u.get_param(model.layers[0])
    param.data.copy_(torch.ones_like(param) / len(param.data.flatten()))
    param = u.get_param(model.layers[1])
    param.data.copy_(torch.ones_like(param) / len(param.data.flatten()))

    u.freeze(model.layers[0])

    # the second-order optimizer below is constructed only to register backward
    # hooks and is never stepped, so these settings do not change the parameters
    optim_kwargs = dict(lr=0.001,
                        momentum=0.9,
                        weight_decay=0,
                        l2_reg=0,
                        bias_correction=False,
                        acc_steps=1,
                        curv_type="Cov",
                        curv_shapes={"Linear": "Kron"},
                        momentum_type="preconditioned",
                        update_inv=False,
                        precondition_grad=False)
    curv_args = dict(damping=0, ema_decay=1)
    SecondOrderOptimizer(
        model, **optim_kwargs,
        curv_kwargs=curv_args)  # constructed only to register backward hooks
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    start_epoch = 1

    # Run training
    for epoch in range(start_epoch, args.epochs + 1):
        num_examples_processed = epoch * len(
            train_loader) * train_loader.batch_size
        g.token_count = num_examples_processed
        for i in range(len(model.layers)):
            if i == 0:
                continue  # skip initial expensive layer

            layer = model.layers[i]
            layer_stats = compute_layer_stats(layer)
            layer_name = f"{i:02d}-{layer.__class__.__name__.lower()}"
            log_scalars(u.nest_stats(f'stats/{layer_name}', layer_stats))

        # train
        accuracy, loss, confidence = train(model, device, train_loader,
                                           optimizer, epoch, args, logger)

        # save log
        iteration = epoch * len(train_loader)
        metrics = {
            'epoch': epoch,
            'iteration': iteration,
            'accuracy': accuracy,
            'loss': loss,
            'val_accuracy': 0,
            'val_loss': 0,
            'lr': optimizer.param_groups[0]['lr'],
            'momentum': optimizer.param_groups[0].get('momentum', 0)
        }
        log_scalars(metrics)

        # save checkpoint
        if epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
            path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch))
            data = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch
            }
            torch.save(data, path)


def test_linesearch():
    """Implement line search with sanity checks."""
    global logger, stats_data, stats_targets, args
    run_name = 'default'  # name of the run

    torch.set_default_dtype(torch.float32)

    # Copy this file & config to args.out
    out = '/tmp'
    if not os.path.isdir(out):
        os.makedirs(out)
    shutil.copy(os.path.realpath(__file__), out)

    # Setup logger
    log_file_name = 'log'
    logger = Logger(out, log_file_name)
    logger.start()

    # Set device
    use_cuda = False
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set random seed
    u.seed_random(1)

    # Setup data augmentation & data pre processing
    train_transforms, val_transforms = [], []

    train_transforms.append(transforms.ToTensor())
    val_transforms.append(transforms.ToTensor())

    train_transform = transforms.Compose(train_transforms)
    # val_transform = transforms.Compose(val_transforms)

    num_classes = 10
    dataset_class = SimpleMNIST

    class Net(nn.Module):
        def __init__(self, d, nonlin=True):
            super().__init__()
            self.layers = []
            self.all_layers = []
            self.d = d
            for i in range(len(d) - 1):
                linear = nn.Linear(d[i], d[i + 1], bias=False)
                self.layers.append(linear)
                self.all_layers.append(linear)
                if nonlin:
                    self.all_layers.append(nn.ReLU())
            self.predict = torch.nn.Sequential(*self.all_layers)

        def forward(self, x: torch.Tensor):
            x = x.reshape((-1, self.d[0]))
            return self.predict(x)

    stats_batch_size = 1

    def compute_layer_stats(layer):
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        loss, output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        stats.sparsity = pos_activations.float() / (pos_activations +
                                                    neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)

        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Take steps between start and end; return steps + 2 loss values
            (the first is at the unmodified parameters, the last at the end point)."""
            param0 = param.data.clone()
            param0v = u.vec(param0).t()
            losses = []
            for i in range(steps + 1):
                output = model(
                    stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                losses.append(loss)
                offset = start + i * ((end - start) / steps)
                param1v = param0v + offset * directionv

                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)

            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)

            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
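        # losses[0] is the loss at the unmodified parameters; losses[k] for k >= 1
        # is measured after moving (k - 1) * 0.2 of the Newton direction, so
        # index 6 is the full Newton step and indices 5 and 7 bracket it.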
        u.check_equal(line_losses[0], loss)
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
        return stats

    train_dataset = dataset_class(root='/tmp/data',
                                  train=True,
                                  download=True,
                                  transform=train_transform)

    batch_size = 32
    num_workers = 0
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)
    stats_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=stats_batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)
    stats_data, stats_targets = next(iter(stats_loader))

    model = Net([NUM_CHANNELS * IMAGE_SIZE**2, 8, 8, 1], nonlin=False)
    setattr(model, 'num_classes', num_classes)
    model = model.to(device)
    u.freeze(model.layers[0])
    u.freeze(model.layers[2])

    # the second-order optimizer below is constructed only to register backward
    # hooks and is never stepped, so these settings do not change the parameters
    optim_kwargs = dict(lr=0.001,
                        momentum=0.9,
                        weight_decay=0,
                        l2_reg=0,
                        bias_correction=False,
                        acc_steps=1,
                        curv_type="Cov",
                        curv_shapes={"Linear": "Kron"},
                        momentum_type="preconditioned",
                        update_inv=False,
                        precondition_grad=False)
    curv_args = dict(damping=0, ema_decay=1)
    SecondOrderOptimizer(
        model, **optim_kwargs,
        curv_kwargs=curv_args)  # constructed only to register backward hooks
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    start_epoch = 1

    # Run training
    epochs = 100
    for epoch in range(start_epoch, epochs + 1):
        num_examples_processed = epoch * len(
            train_loader) * train_loader.batch_size
        layer_stats = compute_layer_stats(model.layers[1])