def init_function(h_model):
    h_criterion = torch.nn.CrossEntropyLoss()
    h_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
    h_train_evaluator = SupervisedEvaluator(model=h_model, criterion=h_criterion, device=device)
    h_optimizer = torch.optim.Adam(params=h_model.parameters(), lr=1e-3)
    h_lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(h_optimizer, 'max', verbose=True,
                                                                patience=5, factor=0.5)
    h_trainer = SupervisedTrainer(model=h_model, optimizer=h_optimizer, criterion=h_criterion, device=device)

    # Tqdm logger
    h_pbar = ProgressBar(persist=False, bar_format=config.IGNITE_BAR_FORMAT)
    h_pbar.attach(h_trainer.engine, metric_names='all')
    h_tqdm_logger = TqdmLogger(pbar=h_pbar)
    # noinspection PyTypeChecker
    h_tqdm_logger.attach_output_handler(
        h_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(h_trainer.engine),
    )
    # noinspection PyTypeChecker
    h_tqdm_logger.attach_output_handler(
        h_train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(h_trainer.engine),
    )

    # Learning rate scheduling
    # The PyTorch Ignite LRScheduler class does not work with ReduceLROnPlateau
    h_evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: h_lr_scheduler.step(engine.state.metrics['accuracy']))

    # Model checkpoints
    h_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                require_empty=False, score_name='acc',
                                score_function=lambda engine: engine.state.metrics['accuracy'],
                                global_step_transform=global_step_from_engine(h_trainer.engine))
    h_evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, h_handler, {'m': h_model})

    # Early stopping
    h_es_handler = EarlyStopping(patience=15, min_delta=0.0001,
                                 score_function=lambda engine: engine.state.metrics['accuracy'],
                                 trainer=h_trainer.engine, cumulative_delta=True)
    h_es_handler.logger.setLevel(logging.DEBUG)
    h_evaluator.engine.add_event_handler(Events.COMPLETED, h_es_handler)

    return h_trainer, h_train_evaluator, h_evaluator
def test_global_step_from_engine():
    iteration = 12
    epoch = 23

    trainer = Engine(lambda e, b: None)
    trainer.state.iteration = iteration
    trainer.state.epoch = epoch

    gst = global_step_from_engine(trainer)
    assert gst(MagicMock(), Events.EPOCH_COMPLETED) == epoch

    gst = global_step_from_engine(trainer, custom_event_name=Events.ITERATION_COMPLETED)
    assert gst(MagicMock(), Events.EPOCH_COMPLETED) == iteration
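# The snippets in this collection all use ignite.handlers.global_step_from_engine. A minimal,
# self-contained sketch of its behavior, assuming only illustrative names (`trainer`, `evaluator`):
# it returns a callable `(engine, event_name) -> int` that reads the step counter of the engine it
# was built from, so loggers and checkpoints attached to an evaluator report the *trainer's*
# epoch or iteration instead of the evaluator's own counters.
from ignite.engine import Engine, Events
from ignite.handlers import global_step_from_engine

trainer = Engine(lambda e, b: None)
trainer.state.epoch = 7
trainer.state.iteration = 700

evaluator = Engine(lambda e, b: None)

gst = global_step_from_engine(trainer)
print(gst(evaluator, Events.EPOCH_COMPLETED))       # -> 7, the trainer's epoch

gst_iter = global_step_from_engine(trainer, custom_event_name=Events.ITERATION_COMPLETED)
print(gst_iter(evaluator, Events.EPOCH_COMPLETED))  # -> 700, the trainer's iteration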
def update_hparams(engine, finished=False):
    hparam_dict['total_iterations'] = global_step_from_engine(engine)(engine, Events.ITERATION_COMPLETED)
    hparam_dict['total_epochs'] = global_step_from_engine(engine)(engine, Events.EPOCH_COMPLETED)
    hparam_dict['timeout'] = not finished
    if hparam_dict['train_set_size'] is None:
        hparam_dict['train_set_size'] = hparam_dict['training_set_size']
    try:
        shutil.copyfile(os.path.join(output_dir, 'hparams.pickle'),
                        os.path.join(output_dir, 'hparams.pickle.backup'))
        with open(os.path.join(output_dir, 'hparams.pickle'), 'wb') as f:
            pickle.dump(hparam_dict, f)
    except AttributeError as e:
        print('Could not pickle one of the total vars.', e)
        os.replace(os.path.join(output_dir, 'hparams.pickle.backup'),
                   os.path.join(output_dir, 'hparams.pickle'))
def dev_fn(engine, batch):
    model.eval()
    optimizer.zero_grad()
    with torch.no_grad():
        batch = tuple(t.to(self.device) for t in batch)
        labels = batch[3]
        inputs = {
            "input_ids": batch[0],
            "token_type_ids": batch[1],
            "attention_mask": batch[2],
            "label_ids": labels
        }
        loss, sequence_tags = model(**inputs)

    score = f1_score(labels.detach().cpu().numpy(),
                     y_pred=sequence_tags.detach().cpu().numpy(),
                     average="macro")
    if self.n_gpu > 1:
        loss = loss.mean()

    ## tensorboard
    global_step = global_step_from_engine(engine)(engine, engine.last_event_name)
    # tb_writer.add_scalar('learning_rate', scheduler.get_lr()[0], global_step)
    tb_writer.add_scalar('dev_loss', loss.item(), global_step)
    tb_writer.add_scalar('dev_score', score, global_step)
    return loss.item(), score
def _build_objects(acc_list):
    model = DummyModel().to(device)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.5)

    def update_fn(engine, batch):
        x = torch.rand((4, 1)).to(device)
        optim.zero_grad()
        y = model(x)
        loss = y.pow(2.0).sum()
        loss.backward()
        if idist.has_xla_support:
            import torch_xla.core.xla_model as xm

            xm.optimizer_step(optim, barrier=True)
        else:
            optim.step()
        lr_scheduler.step()

    trainer = Engine(update_fn)
    evaluator = Engine(lambda e, b: None)
    acc_iter = iter(acc_list)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def setup_result():
        evaluator.state.metrics["accuracy"] = next(acc_iter)

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_eval():
        evaluator.run([0, 1, 2])

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    save_handler = DiskSaver(dirname, create_dir=True, require_empty=False)
    early_stop = EarlyStopping(score_function=score_function, patience=2, trainer=trainer)
    evaluator.add_event_handler(Events.COMPLETED, early_stop)

    checkpointer = Checkpoint(
        {
            "trainer": trainer,
            "model": model,
            "optim": optim,
            "lr_scheduler": lr_scheduler,
            "early_stop": early_stop,
        },
        save_handler,
        include_self=True,
        global_step_transform=global_step_from_engine(trainer),
    )
    evaluator.add_event_handler(Events.COMPLETED, checkpointer)

    return trainer, evaluator, model, optim, lr_scheduler, early_stop, checkpointer
def setup_checkpoint_saver(self, to_save):
    if self.hparams.checkpoint_params is not None:
        from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

        handler = Checkpoint(to_save,
                             DiskSaver(self.hparams.checkpoint_params["save_dir"], require_empty=False),
                             n_saved=self.hparams.checkpoint_params["n_saved"],
                             filename_prefix=self.hparams.checkpoint_params["prefix_name"],
                             score_function=self.score_function,
                             score_name="score",
                             global_step_transform=global_step_from_engine(self.trainer))
        self.evaluator.add_event_handler(Events.COMPLETED, handler)
def train_fn(engine, batch):
    model.train()
    optimizer.zero_grad()
    batch = tuple(t.to(self.device) for t in batch)
    labels = batch[3]
    inputs = {
        "input_ids": batch[0],
        "attention_mask": batch[1],
        "token_type_ids": batch[2],
        "labels": labels,
        "is_nested": self.is_nested
    }
    loss, sequence_tags = model(**inputs)

    if not self.is_nested:
        score = (sequence_tags == labels).float().detach().cpu().numpy()
        condition_1 = (labels != self.label_list.index("O")).detach().cpu().numpy()
        condition_2 = (labels != self.label_list.index("<PAD>")).detach().cpu().numpy()
        patten = np.logical_and(condition_1, condition_2)
        score = score[patten].mean()
    else:
        '''
        y_pred = sequence_tags.detach().cpu().numpy()
        labels_np = labels.detach().cpu().numpy()
        score = ((y_pred > self.multi_label_threshold) == (labels_np > 0)).mean()
        '''
        score = ((sequence_tags > self.multi_label_threshold) == (labels > 0)).float().detach().cpu().numpy()
        condition_1 = (labels != self.label_list.index("O")).detach().cpu().numpy()
        condition_2 = (labels != self.label_list.index("<PAD>")).detach().cpu().numpy()
        patten = np.logical_and(condition_1, condition_2)
        score = score[patten].mean()

    if self.n_gpu > 1:
        loss = loss.mean()

    ## tensorboard
    global_step = global_step_from_engine(engine)(engine, engine.last_event_name)
    tb_writer.add_scalar('learning_rate', scheduler.get_lr()[0], global_step)
    tb_writer.add_scalar('train_loss', loss.item(), global_step)
    tb_writer.add_scalar('train_score', score.item(), global_step)

    loss.backward()
    optimizer.step()
    scheduler.step()
    model.zero_grad()
    return loss.item(), score.item()
def save_checkpoint(trainer, evaluator, to_save, score_function, save_dir, n_saved, prefix_name):
    if save_dir is not None and score_function is not None:
        handler = Checkpoint(to_save,
                             DiskSaver(save_dir, require_empty=False),
                             n_saved=n_saved,
                             filename_prefix=prefix_name,
                             score_function=score_function,
                             score_name="score",
                             global_step_transform=global_step_from_engine(trainer))
        evaluator.add_event_handler(Events.COMPLETED, handler)
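# A usage sketch for the save_checkpoint helper above, wired to throwaway engines and a dummy
# model; every name, directory, and metric value below is illustrative, not taken from the source.
import torch
from ignite.engine import Engine, Events

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)

@evaluator.on(Events.COMPLETED)
def fake_metric(engine):
    # stand-in metric so the checkpoint's score_function has something to read
    engine.state.metrics["accuracy"] = 0.5

save_checkpoint(trainer, evaluator,
                to_save={"model": model, "optimizer": optimizer},
                score_function=lambda engine: engine.state.metrics["accuracy"],
                save_dir="./checkpoints", n_saved=3, prefix_name="best")
# Saved files are then named after the trainer's global step (via global_step_from_engine),
# not the evaluator's, which is the whole point of passing global_step_transform.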
def store(self, engine: ignite.engine.Engine):
    """Evaluation engine stores its state with computed metrics, which will be sent to the main logger."""
    metrics = {}
    if not hasattr(engine.state, 'metrics') or len(engine.state.metrics) == 0:
        return

    kwargs = dict(current_step=global_step_from_engine(self.train_engine)(
        self.train_engine, self.train_engine.last_event_name)) if self.train_engine else {}

    for key, val in engine.state.metrics.items():
        metric_name = key
        metrics[metric_name] = val

    self.liveplot.update(metrics, **kwargs)
    self.liveplot.send()
# Validset evaluation
valid_state = valid_evaluator.run(validset_testset[0])
valid_metrics = {f'valid_{n}': float(v) for n, v in valid_state.metrics.items()}
for n, v in valid_metrics.items():
    mlflow.log_metric(n, v, step=engine.state.epoch)
if not is_nni_run_standalone():
    # TODO: make sure `valid_state.metrics` is ordered so that reported default metric to NNI is always the same
    nni.report_intermediate_result({'default': list(valid_state.metrics.values())[0],
                                    **train_metrics, **valid_metrics})

if backend_conf.rank == 0:
    event = Events.ITERATION_COMPLETED(every=hp['log_progress_every_iters'] if hp['log_progress_every_iters'] else None)
    ProgressBar(persist=False, desc='Train evaluation').attach(train_evaluator, event_name=event)
    ProgressBar(persist=False, desc='Test evaluation').attach(valid_evaluator)

    log_handler = OutputHandler(tag='train', metric_names=list(metrics.keys()),
                                global_step_transform=global_step_from_engine(trainer))
    tb_logger.attach(train_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

    log_handler = OutputHandler(tag='test', metric_names=list(metrics.keys()),
                                global_step_transform=global_step_from_engine(trainer))
    tb_logger.attach(valid_evaluator, log_handler=log_handler, event_name=Events.COMPLETED)

    # Store the best model by validation accuracy:
    common.save_best_model_by_val_score(str(output_path), valid_evaluator, model=model,
                                        metric_name='accuracy', n_saved=3, trainer=trainer, tag='val')

    if hp['log_grads_every_iters'] is not None and hp['log_grads_every_iters'] > 0:
        tb_logger.attach(trainer,
                         log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                         event_name=Events.ITERATION_COMPLETED(every=hp['log_grads_every_iters']))

if hp['crash_iteration'] is not None and hp['crash_iteration'] >= 0:
    @trainer.on(Events.ITERATION_STARTED(once=hp['crash_iteration']))
    def _(engine):
        raise Exception('STOP at iteration: {}'.format(engine.state.iteration))
def run(model, dataset, val_dataset, device=None, optimizer=None, criterion=None, epochs=10,
        batch_size=1, log_interval=10, model_name='unknown', log_dir=None, save=True,
        model_name_prefix='', path=None):
    start_time = time.time()
    if path is None:
        path = join(MODELS_PATH, model_name)

    # writer = None
    # if log_dir is not None and log_dir != '':
    #     current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    #     writer = SummaryWriter(log_dir=log_dir+current_time)

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Model device: {device}')
    model.to(device)

    train_loader = DataLoader(dataset=dataset, batch_size=batch_size, collate_fn=batcher(device))
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, collate_fn=batcher(device))

    if optimizer is None:
        # create the optimizer
        optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01, weight_decay=1e-4)
    if criterion is None:
        criterion = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 7.0], device=device))

    def update_model(engine, batch):
        g = batch.graph
        optimizer.zero_grad()
        outputs = model(g)
        loss = criterion(outputs, batch.labels)
        # if torch.isnan(loss).any():
        #     print(f'Loss is NAN at step: {engine.state.iteration}')
        #     return 0
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(update_model)

    training_history = {'accuracy': [], 'loss': []}
    whole_training_history = {'loss': []}
    validation_history = {'accuracy': [], 'loss': []}

    val_metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}
    train_evaluator = create_my_supervised_evaluator(model, metrics=val_metrics)
    val_evaluator = create_my_supervised_evaluator(model, metrics=val_metrics)

    handler = EarlyStopping(patience=50, score_function=score_function, trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    to_save = {'model': model}
    if save:
        handler = Checkpoint(to_save,
                             DiskSaver(path, create_dir=True, require_empty=False),
                             n_saved=1,
                             score_function=score_function,
                             score_name="loss",
                             filename_prefix=model_name_prefix,
                             global_step_transform=global_step_from_engine(trainer))
        val_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        whole_training_history['loss'].append(engine.state.output)
        # if writer is not None:
        #     writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        print(f'Epoch {engine.state.epoch:05d} | '
              f'Train: loss {avg_loss:7.4f}, acc {avg_accuracy:7.4f} | ', end='')
        training_history['accuracy'].append(avg_accuracy)
        training_history['loss'].append(avg_loss)
        # if writer is not None:
        #     writer.add_scalar("training/avg_loss", avg_loss, engine.state.epoch)
        #     writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        val_evaluator.run(val_loader)
        metrics = val_evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["loss"]
        print(f'Val: loss {avg_loss:7.4f}, acc {avg_accuracy:7.4f} |')
        # print(f'Val results | '
        #       f'Epoch {engine.state.epoch:05d} | '
        #       f'Avg loss {avg_loss:.4f} | '
        #       f'Avg accuracy {avg_accuracy:.4f} |')
        validation_history['accuracy'].append(avg_accuracy)
        validation_history['loss'].append(avg_loss)
        # if writer is not None:
        #     writer.add_scalar("valdation/avg_loss", avg_loss, engine.state.epoch)
        #     writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    # if writer is not None:
    #     writer.close()

    print(f'Model trained in {(time.time() - start_time):.1f}s')
    return training_history, validation_history, whole_training_history
def adv_prune_train_loop(model, params, ds, dset, min_y, base_data, model_id, prune_type, device,
                         batch_size, tpa, max_epochs=5):
    # assert prune_type in ['global_unstructured', 'structured']
    total_prune_amount = tpa
    remove_amount = tpa
    ds_train, ds_valid = ds
    train_set, valid_set = dset
    min_y_train, min_y_val = min_y
    original_model = copy.deepcopy(model)
    original_model.eval()
    model_id = f'{model_id}_{prune_type}_pruning_{tpa}_l1'
    valid_freq = 200 * 500 // batch_size // 3

    conv_layers = [model.conv1]
    for sequential in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for bottleneck in sequential:
            conv_layers.extend([bottleneck.conv1, bottleneck.conv2, bottleneck.conv3])
    conv_layers = conv_layers[:22]

    def prune_model(model):
        print(f'pruned model by {total_prune_amount}')
        if prune_type == 'global_unstructured':
            parameters_to_prune = [(layer, 'weight') for layer in conv_layers]
            prune.global_unstructured(
                parameters_to_prune,
                pruning_method=prune.L1Unstructured,
                amount=total_prune_amount,
            )
        else:
            for layer in conv_layers:
                # per-layer structured pruning body is not recoverable from the source
                pass

    prune_model(model)

    def valid_eval(model, dataset, dataloader, device, label):
        right = 0
        total = 0
        model.eval()
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader), total=len(dataset) / dataloader.batch_size):
                data, y = data
                data = data.to(device)
                y = y.to(device) - label
                ans = model.forward(data)
                right += torch.sum(torch.eq(torch.argmax(ans, dim=1), y))
                total += y.shape[0]
        return right / total

    valid_acc = valid_eval(model, valid_set, ds_valid, device, min_y_val)
    print('initial accuracy:', valid_acc.item())

    with create_summary_writer(model, ds_train, base_data, model_id, device=device) as writer:
        lr = params['lr']
        mom = params['momentum']
        wd = params['l2_wd']
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=mom, weight_decay=wd)
        sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
        funcs = {'accuracy': Accuracy(), 'loss': Loss(F.cross_entropy)}
        loss = funcs['loss']._loss_fn

        acc_metric = Accuracy(device=device)
        loss_metric = Loss(F.cross_entropy, device=device)
        acc_val_metric = Accuracy(device=device)
        loss_val_metric = Loss(F.cross_entropy, device=device)
        # attack = GradientSignAttack(original_model, loss_fn=loss, eps=0.2)

        def train_step(engine, batch):
            model.train()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            # with ctx_noparamgrad_and_eval(model):
            #     x_adv = attack.perturb(x, y)
            # optimizer.zero_grad()
            # x = torch.cat((x, x_adv))
            # y = torch.cat((y, y))
            ans = model.forward(x)
            l = loss(ans, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            with torch.no_grad():
                for layer in conv_layers:
                    layer.weight *= layer.weight_mask
            return l.item()

        trainer = Engine(train_step)

        def train_eval_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            # x_adv = attack.perturb(x, y)
            # x = torch.cat((x, x_adv))
            # y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        train_evaluator = Engine(train_eval_step)
        acc_metric.attach(train_evaluator, "accuracy")
        loss_metric.attach(train_evaluator, 'loss')

        def validation_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_val
            # x_adv = attack.perturb(x, y)
            # x = torch.cat((x, x_adv))
            # y = torch.cat((y, y))
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        valid_evaluator = Engine(validation_step)
        acc_val_metric.attach(valid_evaluator, "accuracy")
        loss_val_metric.attach(valid_evaluator, 'loss')

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        def log_validation_results(engine):
            valid_evaluator.run(ds_valid)
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            avg_nll = metrics['loss']
            print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
                  .format(engine.state.epoch, valid_avg_accuracy, avg_nll))
            writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
            writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy, engine.state.epoch)
            writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy, engine.state.epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            metrics = valid_evaluator.state.metrics
            avg_nll = metrics['accuracy']
            sched.step(avg_nll)

        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def log_training_loss(engine):
            batch = engine.state.batch
            ds = DataLoader(TensorDataset(*batch), batch_size=batch_size)
            train_evaluator.run(ds)
            metrics = train_evaluator.state.metrics
            accuracy = metrics['accuracy']
            nll = metrics['loss']
            iter = (engine.state.iteration - 1) % len(ds_train) + 1
            if (iter % 50) == 0:
                print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}"
                      .format(engine.state.epoch, iter, len(ds_train), accuracy, nll))
            writer.add_scalar("batchtraining/detloss", nll, engine.state.epoch)
            writer.add_scalar("batchtraining/accuracy", accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/error", 1. - accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/loss", engine.state.output, engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_lr(engine):
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'], engine.state.epoch)

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        def validation_value(engine):
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            return valid_avg_accuracy

        to_save = {'model': model}
        handler = Checkpoint(to_save,
                             DiskSaver(os.path.join(base_data, model_id), create_dir=True),
                             score_function=validation_value,
                             score_name="val_acc",
                             global_step_transform=global_step_from_engine(trainer),
                             n_saved=None)

        # kick everything off
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=valid_freq), handler)
        trainer.run(ds_train, max_epochs=max_epochs)
def __init__(self, trainer_args={}):
    self.log_prefix = trainer_args.get("model_name")
    self.save_path = os.path.join(trainer_args.get("fold_dir"))
    create_dirs(self.save_path)
    self.device = trainer_args.get("device")
    self.epochs = trainer_args.get("epochs")
    self.early_stopping = trainer_args.get("early_stopping", None)
    self.exclude_anomalies = trainer_args.get("exclude_anomalies")
    self.include_metrics = trainer_args.get("include_metrics")
    self.model_class = trainer_args.get("model_class")
    self.model_args = trainer_args.get("model_args", {})
    self.optimizer_class = trainer_args.get("optimizer_class")
    self.optimizer_args = trainer_args.get("optimizer_args", {})
    self.training_data_stats = trainer_args.get("training_data_stats")
    self.loss = trainer_args.get("loss_func")
    self.node_classes = trainer_args.get("node_classes")
    self.graph_classes = trainer_args.get("graph_classes")
    self.score_function = score_function
    self.resume_from_checkpoint = trainer_args.get("resume_from_checkpoint", {})

    # create model, optimizer
    self.model = self.model_class({
        **self.model_args,
        "pred_collector_function": self._pred_collector_function
    }).to(self.device).double()
    self.optimizer = self.optimizer_class(self.model.parameters(), **self.optimizer_args)

    self.trainer = create_supervised_trainer(
        self.model,
        self.optimizer,
        loss_fn=self.loss,
        device=self.device,
        non_blocking=True,
        output_transform=lambda x, y, y_pred, loss: (y_pred, y)  # so that loss-metric can work with transformed output
    )

    ##################### log some values ################################
    if not len(trainer_args.get("resume_from_checkpoint", {})):
        # only print all this on initial training start, not on resume
        self.custom_print("Device:", self.device)
        self.custom_print("Max. epochs:", self.epochs)
        self.custom_print("Early stopping:", self.early_stopping)
        self.custom_print("Excluded Anomalies:", self.exclude_anomalies)
        self.custom_print("Included Metrics:", self.include_metrics)
        self.custom_print("Loss class:", self.loss)
        self.custom_print("Node anomaly classes:", self.node_classes)
        self.custom_print("Graph anomaly classes:", self.graph_classes)
        self.custom_print("Model class:", self.model_class)
        self.custom_print("Model args:", json.dumps(self.model_args))
        self.custom_print("Model - All Parameters:", self.model.all_params)
        self.custom_print("Model - Trainable Parameters:", self.model.all_trainable_params)
        self.custom_print("Optimizer args:", json.dumps(self.optimizer_args))
        self.custom_print("Training Data Statistics:", self.training_data_stats)
        self.custom_print("Train indices", trainer_args.get("train_indices"))
        self.custom_print("Val indices", trainer_args.get("val_indices"))
        self.custom_print("Test indices", trainer_args.get("test_indices"))
    ######################################################################

    # configure behavior for early stopping
    self.stopper = None
    if self.early_stopping:
        self.stopper = EarlyStopping(patience=self.early_stopping,
                                     score_function=self.score_function,
                                     trainer=self.trainer)

    # configure behavior for checkpointing
    to_save: dict = {
        "model_state_dict": self.model,
        "optimizer_state_dict": self.optimizer,
        "trainer_state_dict": self.trainer
    }
    if self.stopper:
        to_save["stopper_state_dict"] = self.stopper
    save_handler = DiskSaver(self.save_path, create_dir=True, require_empty=False, atomic=True)

    # save the best checkpoints
    self.best_checkpoint_handler = Checkpoint(
        to_save,
        save_handler,
        filename_prefix=f"{self.log_prefix}_best",
        score_name="val_loss",
        score_function=self.score_function,
        include_self=True,
        global_step_transform=global_step_from_engine(self.trainer),
        n_saved=5)

    # save the latest checkpoint (important for resuming training)
    self.latest_checkpoint_handler = Checkpoint(
        to_save,
        save_handler,
        filename_prefix=f"{self.log_prefix}_latest",
        include_self=True,
        global_step_transform=global_step_from_engine(self.trainer),
        n_saved=1)

    # resume from checkpoint
    if len(self.resume_from_checkpoint):
        (self.model, self.optimizer, self.trainer, self.stopper,
         self.best_checkpoint_handler, self.latest_checkpoint_handler) = self._load_checkpoint(
            self.model, self.optimizer, self.trainer, self.stopper,
            self.best_checkpoint_handler, self.latest_checkpoint_handler,
            checkpoint_path_dict=self.resume_from_checkpoint)

    self.persist_collection = False
    self.persist_collection_dict: OrderedDict = OrderedDict()
def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
    # Metrics
    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')

    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluator
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED(every=100), train_loader, val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                     lr_lambda=lambda epoch: 1 - epoch / params['epochs'])
    lr_scheduler = LRScheduler(lr_scheduler)
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

    # Model checkpoints
    mc_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                 require_empty=False,
                                 global_step_transform=global_step_from_engine(trainer.engine))
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # Create a TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {})
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        trainer.engine,
        event_name=Events.EPOCH_COMPLETED,
        tag="train",
        metric_names=["unit_conv"]
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsScalarHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsHistHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return tb_logger