def __init__(self, hparams):
    super(Model, self).__init__()
    self.optimizer = None
    self.scheduler = None
    self._metric: Optional[Metric] = None
    self.metrics: Dict[str, Metric] = dict()
    self.trn_datasets: List[Dataset] = None
    self.val_datasets: List[Dataset] = None
    self.tst_datasets: List[Dataset] = None
    self.padding: Dict[str, int] = {}
    self.base_dir: str = ""
    self._batch_per_epoch: int = -1
    self._comparsion: Optional[str] = None
    self._selection_criterion: Optional[str] = None

    if isinstance(hparams, dict):
        hparams = Namespace(**hparams)
    self.hparams: Namespace = hparams

    pl.seed_everything(hparams.seed)
    self.tokenizer = AutoTokenizer.from_pretrained(hparams.pretrain)
    self.model = self.build_model()
    self.freeze_layers()
    self.weight = nn.Parameter(torch.zeros(self.num_layers))

    self.mapping = None
    if hparams.mapping:
        assert os.path.isfile(hparams.mapping)
        self.mapping = torch.load(hparams.mapping)
        util.freeze(self.mapping)

    self.projector = self.build_projector()
    self.dropout = module.InputVariationalDropout(hparams.input_dropout)
def __init__(self, hparams):
    super(Aligner, self).__init__(hparams)
    self._comparsion = "min"
    self._selection_criterion = "val_loss"

    self.aligner_projector = self.build_aligner_projector()

    self.orig_model = deepcopy(self.model)
    util.freeze(self.orig_model)

    self.mappings = nn.ModuleList([])
    for _ in range(self.num_layers):
        m = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        nn.init.eye_(m.weight)
        self.mappings.append(m)

    self.padding = {
        "src_sent": self.tokenizer.pad_token_id,
        "tgt_sent": self.tokenizer.pad_token_id,
        "src_align": PAD_ALIGN,
        "tgt_align": PAD_ALIGN,
        "src_lang": 0,
        "tgt_lang": 0,
        "lang": 0,
    }
    self.setup_metrics()
def freeze_layer(self, layer):
    if isinstance(self.model, (transformers.BertModel, transformers.RobertaModel)):
        util.freeze(self.model.encoder.layer[layer - 1])
    elif isinstance(self.model, transformers.XLMModel):
        util.freeze(self.model.attentions[layer - 1])
        util.freeze(self.model.layer_norm1[layer - 1])
        util.freeze(self.model.ffns[layer - 1])
        util.freeze(self.model.layer_norm2[layer - 1])
    else:
        raise ValueError("Unsupported model")
def freeze_embeddings(self):
    if isinstance(self.model, (transformers.BertModel, transformers.RobertaModel)):
        util.freeze(self.model.embeddings)
    elif isinstance(self.model, transformers.XLMModel):
        util.freeze(self.model.position_embeddings)
        if self.model.n_langs > 1 and self.model.use_lang_emb:
            util.freeze(self.model.lang_embeddings)
        util.freeze(self.model.embeddings)
    else:
        raise ValueError("Unsupported model")
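# `util.freeze` (and the matching `util.unfreeze` used elsewhere) is an external
# helper that is not part of this excerpt. A minimal sketch, assuming it simply
# toggles `requires_grad` on every parameter of the given module:

import torch.nn as nn


def freeze(module: nn.Module) -> None:
    """Exclude all parameters of `module` from gradient computation."""
    for p in module.parameters():
        p.requires_grad = False


def unfreeze(module: nn.Module) -> None:
    """Re-enable gradient computation for all parameters of `module`."""
    for p in module.parameters():
        p.requires_grad = True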
def compare_events(expected_events, output, mode):
    # The order of expected_events is only meaningful inside a partition, so
    # let's convert it into a map indexed by partition key.
    expected_events_map = {}
    for event in expected_events:
        expected_type, expected_key, expected_old_image, expected_new_image = event
        # For simplicity, we actually use the entire key, not just the partition
        # key. We only lose a bit of testing power we didn't plan to test anyway
        # (that events for different items in the same partition are ordered).
        key = freeze(expected_key)
        if key not in expected_events_map:
            expected_events_map[key] = []
        expected_events_map[key].append(event)
    # Iterate over the events in output. An event for a certain key needs to
    # be the *first* remaining event for this key in expected_events_map (and
    # then we remove this matched event from expected_events_map).
    for event in output:
        # In DynamoDB, eventSource is 'aws:dynamodb'. We decided to set it to
        # a *different* value - 'scylladb:alternator'. Issue #6931.
        assert 'eventSource' in event
        # Alternator is missing "awsRegion", which makes little sense for it
        # (although maybe we should have provided the DC name). Issue #6931.
        #assert 'awsRegion' in event
        # Alternator is also missing the "eventVersion" entry. Issue #6931.
        #assert 'eventVersion' in event
        # Check that eventID appears, but can't check much on what it is.
        assert 'eventID' in event
        op = event['eventName']
        record = event['dynamodb']
        # record['Keys'] is "serialized" JSON ({'S': 'thestring'}), so we
        # want to deserialize it to match our expected_events content.
        deserializer = TypeDeserializer()
        key = {x: deserializer.deserialize(y) for (x, y) in record['Keys'].items()}
        expected_type, expected_key, expected_old_image, expected_new_image = expected_events_map[freeze(key)].pop(0)
        if expected_type != '*':
            # hack to allow a caller to not test op, to bypass issue #6918
            assert op == expected_type
        assert record['StreamViewType'] == mode
        # Check that all the expected members appear in the record, even if
        # we don't have anything to compare them to (TODO: we should probably
        # at least check they have proper format).
        assert 'ApproximateCreationDateTime' in record
        assert 'SequenceNumber' in record
        # Alternator doesn't set the SizeBytes member. Issue #6931.
        #assert 'SizeBytes' in record
        if mode == 'KEYS_ONLY':
            assert 'NewImage' not in record
            assert 'OldImage' not in record
        elif mode == 'NEW_IMAGE':
            assert 'OldImage' not in record
            if expected_new_image is None:
                assert 'NewImage' not in record
            else:
                new_image = {x: deserializer.deserialize(y) for (x, y) in record['NewImage'].items()}
                assert expected_new_image == new_image
        elif mode == 'OLD_IMAGE':
            assert 'NewImage' not in record
            if expected_old_image is None:
                assert 'OldImage' not in record
            else:
                old_image = {x: deserializer.deserialize(y) for (x, y) in record['OldImage'].items()}
                assert expected_old_image == old_image
        elif mode == 'NEW_AND_OLD_IMAGES':
            if expected_new_image is None:
                assert 'NewImage' not in record
            else:
                new_image = {x: deserializer.deserialize(y) for (x, y) in record['NewImage'].items()}
                assert expected_new_image == new_image
            if expected_old_image is None:
                assert 'OldImage' not in record
            else:
                old_image = {x: deserializer.deserialize(y) for (x, y) in record['OldImage'].items()}
                assert expected_old_image == old_image
        else:
            pytest.fail('cannot happen')
    # After the above loop, expected_events_map should contain only empty
    # lists. If it doesn't, one of the expected events did not yet happen, so
    # return False.
    for entry in expected_events_map.values():
        if len(entry) > 0:
            return False
    return True
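# `freeze` above (used to make key dicts hashable) comes from the test helpers
# and is not shown in this excerpt. A minimal sketch, assuming it recursively
# converts dicts and lists into frozensets and tuples, followed by a purely
# illustrative call (the attribute name 'p' and the `records` variable are
# hypothetical):

def freeze(item):
    """Return a hashable representation of a (possibly nested) key dict."""
    if isinstance(item, dict):
        return frozenset((k, freeze(v)) for k, v in item.items())
    if isinstance(item, list):
        return tuple(freeze(v) for v in item)
    return item


# Each expected event is (eventName, key, old image, new image), matching the
# unpacking in compare_events:
# expected = [
#     ('INSERT', {'p': 'dog'}, None, {'p': 'dog', 'x': 1}),
#     ('MODIFY', {'p': 'dog'}, {'p': 'dog', 'x': 1}, {'p': 'dog', 'x': 2}),
# ]
# assert compare_events(expected, records, 'NEW_AND_OLD_IMAGES')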
def compute_layer_stats(layer):
    refreeze = False
    if hasattr(layer, 'frozen') and layer.frozen:
        u.unfreeze(layer)
        refreeze = True
    s = AttrDefault(str, {})
    n = args.stats_batch_size
    param = u.get_param(layer)
    _d = len(param.flatten())  # dimensionality of parameters
    layer_idx = model.layers.index(layer)
    # TODO: get layer type, include it in name
    assert layer_idx >= 0
    assert stats_data.shape[0] == n

    def backprop_loss():
        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        loss.backward()
        return loss, output

    def backprop_output():
        model.zero_grad()
        output = model(stats_data)
        output.backward(gradient=torch.ones_like(output))
        return output

    # per-example gradients, (n, d)
    _loss, _output = backprop_loss()
    At = layer.data_input
    Bt = layer.grad_output * n
    G = u.khatri_rao_t(At, Bt)
    g = G.sum(dim=0, keepdim=True) / n
    u.check_close(g, u.vec(param.grad).t())

    s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
    s.grad_fro = g.flatten().norm()
    s.param_fro = param.data.flatten().norm()
    pos_activations = torch.sum(layer.data_output > 0)
    neg_activations = torch.sum(layer.data_output <= 0)
    # 1 sparsity means all 0's
    s.a_sparsity = neg_activations.float() / (pos_activations + neg_activations)
    activation_size = len(layer.data_output.flatten())
    s.a_magnitude = torch.sum(layer.data_output) / activation_size

    _output = backprop_output()
    B2t = layer.grad_output
    J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
    H = J.t() @ J / n
    s.hessian_l2 = u.l2_norm(H)
    s.jacobian_l2 = u.l2_norm(J)
    J1 = J.sum(dim=0) / n  # single output Jacobian
    s.J1_l2 = J1.norm()

    # newton decrement
    def loss_direction(direction, eps):
        """loss improvement if we take step eps in direction dir"""
        return u.to_python_scalar(eps * (direction @ g.t()) -
                                  0.5 * eps**2 * direction @ H @ direction.t())

    s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

    # TODO: gradient diversity is stuck at 1
    # TODO: newton/gradient angle
    # TODO: newton step magnitude
    s.grad_curvature = u.to_python_scalar(g @ H @ g.t())  # curvature in direction of g
    s.step_openai = u.to_python_scalar(s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999
    s.regret_gradient = loss_direction(g, s.step_openai)

    if refreeze:
        u.freeze(layer)
    return s
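# `u.khatri_rao_t` above is an external helper (not shown). A minimal sketch,
# assuming it forms the row-wise (transposed) Khatri-Rao product, i.e. row i of
# the result is the Kronecker product of row i of A with row i of B; this is
# what turns per-example activations and backpropagated output gradients into
# per-example parameter gradients G:

import torch


def khatri_rao_t(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Row-wise Kronecker product: (n, p) and (n, q) -> (n, p * q)."""
    n = A.shape[0]
    assert B.shape[0] == n, "inputs must have the same number of rows"
    return (A.unsqueeze(2) * B.unsqueeze(1)).reshape(n, -1)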
def main():
    global logger, stats_data, stats_targets

    parser = argparse.ArgumentParser()
    # Data
    parser.add_argument('--dataset', type=str,
                        choices=[DATASET_MNIST], default=DATASET_MNIST,
                        help='name of dataset')
    parser.add_argument('--root', type=str, default='./data',
                        help='root of dataset')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='input batch size for training')
    parser.add_argument('--stats_batch_size', type=int, default=1,
                        help='size of batch to use for second order statistics')
    parser.add_argument('--val_batch_size', type=int, default=128,
                        help='input batch size for validation')
    parser.add_argument('--normalizing_data', action='store_true',
                        help='[data preprocessing] normalizing data')
    parser.add_argument('--random_crop', action='store_true',
                        help='[data augmentation] random crop')
    parser.add_argument('--random_horizontal_flip', action='store_true',
                        help='[data augmentation] random horizontal flip')
    # Training Settings
    parser.add_argument('--arch_file', type=str, default=None,
                        help='name of file which defines the architecture')
    parser.add_argument('--arch_name', type=str, default='LeNet5',
                        help='name of the architecture')
    parser.add_argument('--arch_args', type=json.loads, default=None,
                        help='[JSON] arguments for the architecture')
    parser.add_argument('--optim_name', type=str, default=SecondOrderOptimizer.__name__,
                        help='name of the optimizer')
    parser.add_argument('--optim_args', type=json.loads, default=None,
                        help='[JSON] arguments for the optimizer')
    parser.add_argument('--curv_args', type=json.loads, default=None,
                        help='[JSON] arguments for the curvature')
    # Options
    parser.add_argument('--download', action='store_true', default=True,
                        help='if True, downloads the dataset from the internet')
    parser.add_argument('--create_graph', action='store_true', default=False,
                        help='create graph of the derivative')
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of sub processes for data loading')
    parser.add_argument('--log_interval', type=int, default=50,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--log_file_name', type=str, default='log',
                        help='log file name')
    parser.add_argument('--checkpoint_interval', type=int, default=50,
                        help='how many epochs to wait before saving a checkpoint')
    parser.add_argument('--resume', type=str, default=None,
                        help='checkpoint path for resume training')
    parser.add_argument('--logdir', type=str, default='/temp/graph_test/run',
                        help='dir to save output files')
    parser.add_argument('--out', type=str, default=None,
                        help='dir to save output files')
    parser.add_argument('--config', default=None,
                        help='config file path')
    parser.add_argument('--fisher_mc_approx', action='store_true', default=False,
                        help='if True, Fisher is estimated by MC sampling')
    parser.add_argument('--fisher_num_mc', type=int, default=1,
                        help='number of MC samples for estimating Fisher')
    parser.add_argument('--log_wandb', type=int, default=1,
                        help='log to wandb')
    args = parser.parse_args()

    # Get run name from logdir: pick the first unused "<logdir>NN" suffix.
    root_logdir = args.logdir
    count = 0
    while os.path.exists(f"{root_logdir}{count:02d}"):
        count += 1
    args.logdir = f"{root_logdir}{count:02d}"
    run_name = os.path.basename(args.logdir)

    assert args.out is None, "Use args.logdir instead of args.out"
    args.out = args.logdir
of args.out" args.out = args.logdir # Copy this file & config to args.out if not os.path.isdir(args.out): os.makedirs(args.out) shutil.copy(os.path.realpath(__file__), args.out) # Load config file if args.config: with open(args.config) as f: config = json.load(f) dict_args = vars(args) dict_args.update(config) if args.config is not None: shutil.copy(args.config, args.out) if args.arch_file is not None: shutil.copy(args.arch_file, args.out) # Setup logger logger = Logger(args.out, args.log_file_name) logger.start() g.event_writer = SummaryWriter(args.logdir) try: # os.environ['WANDB_SILENT'] = 'true' if args.log_wandb: wandb.init(project='test-graphs_test', name=run_name) wandb.tensorboard.patch(tensorboardX=False) wandb.config['config'] = args.config wandb.config['batch'] = args.batch_size wandb.config['optim'] = args.optim_name except Exception as e: if args.log_wandb: print(f"wandb crash with {e}") pass # Set device use_cuda = not args.no_cuda and torch.cuda.is_available() device = torch.device('cuda' if use_cuda else 'cpu') # Set random seed torch.manual_seed(args.seed) random.seed(args.seed) np.random.seed(args.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) # Setup data augmentation & data pre processing train_transforms, val_transforms = [], [] if args.random_crop: train_transforms.append(transforms.RandomCrop(32, padding=4)) if args.random_horizontal_flip: train_transforms.append(transforms.RandomHorizontalFlip()) train_transforms.append(transforms.ToTensor()) val_transforms.append(transforms.ToTensor()) if args.normalizing_data: normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) train_transforms.append(normalize) val_transforms.append(normalize) train_transform = transforms.Compose(train_transforms) # val_transform = transforms.Compose(val_transforms) num_classes = 10 dataset_class = SimpleMNIST class Net(nn.Module): def __init__(self, d, nonlin=True): super().__init__() self.layers = [] self.all_layers = [] self.d = d for i in range(len(d) - 1): linear = nn.Linear(d[i], d[i + 1], bias=False) self.layers.append(linear) self.all_layers.append(linear) if nonlin: self.all_layers.append(nn.ReLU()) self.predict = torch.nn.Sequential(*self.all_layers) def forward(self, x: torch.Tensor): x = x.reshape((-1, self.d[0])) return self.predict(x) def compute_layer_stats(layer): refreeze = False if hasattr(layer, 'frozen') and layer.frozen: u.unfreeze(layer) refreeze = True s = AttrDefault(str, {}) n = args.stats_batch_size param = u.get_param(layer) _d = len(param.flatten()) # dimensionality of parameters layer_idx = model.layers.index(layer) # TODO: get layer type, include it in name assert layer_idx >= 0 assert stats_data.shape[0] == n def backprop_loss(): model.zero_grad() output = model( stats_data) # use last saved data batch for backprop loss = compute_loss(output, stats_targets) loss.backward() return loss, output def backprop_output(): model.zero_grad() output = model(stats_data) output.backward(gradient=torch.ones_like(output)) return output # per-example gradients, n, d _loss, _output = backprop_loss() At = layer.data_input Bt = layer.grad_output * n G = u.khatri_rao_t(At, Bt) g = G.sum(dim=0, keepdim=True) / n u.check_close(g, u.vec(param.grad).t()) s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2 s.grad_fro = g.flatten().norm() s.param_fro = param.data.flatten().norm() pos_activations = torch.sum(layer.data_output > 0) neg_activations = torch.sum(layer.data_output <= 0) s.a_sparsity = neg_activations.float() / 
        activation_size = len(layer.data_output.flatten())
        s.a_magnitude = torch.sum(layer.data_output) / activation_size

        _output = backprop_output()
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
        H = J.t() @ J / n
        s.hessian_l2 = u.l2_norm(H)
        s.jacobian_l2 = u.l2_norm(J)
        J1 = J.sum(dim=0) / n  # single output Jacobian
        s.J1_l2 = J1.norm()

        # newton decrement
        def loss_direction(direction, eps):
            """loss improvement if we take step eps in direction dir"""
            return u.to_python_scalar(eps * (direction @ g.t()) -
                                      0.5 * eps**2 * direction @ H @ direction.t())

        s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

        # TODO: gradient diversity is stuck at 1
        # TODO: newton/gradient angle
        # TODO: newton step magnitude
        s.grad_curvature = u.to_python_scalar(g @ H @ g.t())  # curvature in direction of g
        s.step_openai = u.to_python_scalar(s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999
        s.regret_gradient = loss_direction(g, s.step_openai)

        if refreeze:
            u.freeze(layer)
        return s

    train_dataset = dataset_class(root=gl.dataset, train=True,
                                  download=args.download,
                                  transform=train_transform)
    # val_dataset = t(root=args.root, train=False, download=args.download, transform=val_transform)

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers)
    # val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers)
    stats_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.stats_batch_size,
                                               shuffle=False,
                                               num_workers=args.num_workers)
    stats_data, stats_targets = next(iter(stats_loader))

    arch_kwargs = {} if args.arch_args is None else args.arch_args
    arch_kwargs['num_classes'] = num_classes

    model = Net([NUM_CHANNELS * IMAGE_SIZE**2, 8, 1], nonlin=False)
    setattr(model, 'num_classes', num_classes)
    model = model.to(device)

    param = u.get_param(model.layers[0])
    param.data.copy_(torch.ones_like(param) / len(param.data.flatten()))
    param = u.get_param(model.layers[1])
    param.data.copy_(torch.ones_like(param) / len(param.data.flatten()))
    u.freeze(model.layers[0])

    # use learning rate 0 to avoid changing parameter vector
    optim_kwargs = dict(lr=0.001, momentum=0.9, weight_decay=0, l2_reg=0,
                        bias_correction=False, acc_steps=1,
                        curv_type="Cov", curv_shapes={"Linear": "Kron"},
                        momentum_type="preconditioned", update_inv=False,
                        precondition_grad=False)
    curv_args = dict(damping=0, ema_decay=1)
    # call optimizer to add backward hooks
    SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=curv_args)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    start_epoch = 1

    # Run training
    for epoch in range(start_epoch, args.epochs + 1):
        num_examples_processed = epoch * len(train_loader) * train_loader.batch_size
        g.token_count = num_examples_processed

        for i in range(len(model.layers)):
            if i == 0:
                continue  # skip initial expensive layer
            layer = model.layers[i]
            layer_stats = compute_layer_stats(layer)
            layer_name = f"{i:02d}-{layer.__class__.__name__.lower()}"
            log_scalars(u.nest_stats(f'stats/{layer_name}', layer_stats))

        # train
        accuracy, loss, confidence = train(model, device, train_loader,
                                           optimizer, epoch, args, logger)

        # save log
        iteration = epoch * len(train_loader)
        metrics = {
            'epoch': epoch,
            'iteration': iteration,
            'accuracy': accuracy,
            'loss': loss,
            'val_accuracy': 0,
            'val_loss': 0,
            'lr': optimizer.param_groups[0]['lr'],
            'momentum': optimizer.param_groups[0].get('momentum', 0),
        }
        log_scalars(metrics)

        # save checkpoint
        if epoch % args.checkpoint_interval == 0 or epoch == args.epochs:
            path = os.path.join(args.out, 'epoch{}.ckpt'.format(epoch))
            data = {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'epoch': epoch,
            }
            torch.save(data, path)
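# The '--resume' flag is parsed in main() but not wired up in this excerpt.
# A minimal sketch of how the checkpoints written above could be restored
# (assumption: `model` and `optimizer` are rebuilt the same way as in main()
# before calling this; the helper name is illustrative):

import torch


def load_checkpoint(path, model, optimizer):
    """Restore state saved by main() and return the epoch to resume from."""
    data = torch.load(path)
    model.load_state_dict(data['model'])
    optimizer.load_state_dict(data['optimizer'])
    return data['epoch'] + 1

# Usage: start_epoch = load_checkpoint(args.resume, model, optimizer)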
def test_linesearch():
    """Line search with sanity checks."""
    global logger, stats_data, stats_targets, args

    run_name = 'default'  # name of run
    torch.set_default_dtype(torch.float32)

    # Copy this file & config to args.out
    out = '/tmp'
    if not os.path.isdir(out):
        os.makedirs(out)
    shutil.copy(os.path.realpath(__file__), out)

    # Setup logger
    log_file_name = 'log'
    logger = Logger(out, log_file_name)
    logger.start()

    # Set device
    use_cuda = False
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Set random seed
    u.seed_random(1)

    # Setup data augmentation & data preprocessing
    train_transforms, val_transforms = [], []
    train_transforms.append(transforms.ToTensor())
    val_transforms.append(transforms.ToTensor())
    train_transform = transforms.Compose(train_transforms)
    # val_transform = transforms.Compose(val_transforms)

    num_classes = 10
    dataset_class = SimpleMNIST

    class Net(nn.Module):
        def __init__(self, d, nonlin=True):
            super().__init__()
            self.layers = []
            self.all_layers = []
            self.d = d
            for i in range(len(d) - 1):
                linear = nn.Linear(d[i], d[i + 1], bias=False)
                self.layers.append(linear)
                self.all_layers.append(linear)
                if nonlin:
                    self.all_layers.append(nn.ReLU())
            self.predict = torch.nn.Sequential(*self.all_layers)

        def forward(self, x: torch.Tensor):
            x = x.reshape((-1, self.d[0]))
            return self.predict(x)

    stats_batch_size = 1

    def compute_layer_stats(layer):
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, (n, d)
        loss, output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        stats.sparsity = pos_activations.float() / (pos_activations + neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)
        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Takes steps between start and end, returns steps+1 loss entries"""
            param0 = param.data.clone()
            param0v = u.vec(param0).t()
            losses = []
            for i in range(steps + 1):
                output = model(stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                losses.append(loss)
                offset = start + i * ((end - start) / steps)
                param1v = param0v + offset * directionv
                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)
            output = model(stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)
            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
        u.check_equal(line_losses[0], loss)
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
        return stats

    train_dataset = dataset_class(root='/tmp/data', train=True, download=True,
                                  transform=train_transform)

    batch_size = 32
    num_workers = 0
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)
    stats_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=stats_batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)
    stats_data, stats_targets = next(iter(stats_loader))

    model = Net([NUM_CHANNELS * IMAGE_SIZE**2, 8, 8, 1], nonlin=False)
    setattr(model, 'num_classes', num_classes)
    model = model.to(device)
    u.freeze(model.layers[0])
    u.freeze(model.layers[2])

    # use learning rate 0 to avoid changing parameter vector
    optim_kwargs = dict(lr=0.001, momentum=0.9, weight_decay=0, l2_reg=0,
                        bias_correction=False, acc_steps=1,
                        curv_type="Cov", curv_shapes={"Linear": "Kron"},
                        momentum_type="preconditioned", update_inv=False,
                        precondition_grad=False)
    curv_args = dict(damping=0, ema_decay=1)
    # call optimizer to add backward hooks
    SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=curv_args)
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

    start_epoch = 1

    # Run training
    epochs = 100
    for epoch in range(start_epoch, epochs + 1):
        num_examples_processed = epoch * len(train_loader) * train_loader.batch_size
        layer_stats = compute_layer_stats(model.layers[1])
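# `u.vec` / `u.unvec` used by line_search are external helpers. A minimal
# sketch, assuming vec is column-major (column-stacking) vectorization and
# unvec is its inverse given the number of rows; under this assumption
# unvec(vec(a), a.shape[0]) returns `a` unchanged, which is what line_search
# relies on when it writes perturbed parameters back via param.data.copy_():

import torch


def vec(a: torch.Tensor) -> torch.Tensor:
    """Stack the columns of `a` into a single column vector."""
    return a.t().reshape(-1, 1)


def unvec(v: torch.Tensor, rows: int) -> torch.Tensor:
    """Inverse of vec: rebuild a (rows, cols) matrix from a column vector."""
    return v.reshape(-1, rows).t()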