def run_training(model, train, valid, optimizer, loss, lr_find=False):
    print_file(f'Experiment: {rcp.experiment}\nDescription: {rcp.description}',
               f'{rcp.base_path}description.txt')
    print_file(model, f'{rcp.models_path}model.txt')
    print_file(get_transforms(), f'{rcp.models_path}transform_{rcp.stage}.txt')

    # Data
    train.transform = get_transforms()
    valid.transform = get_transforms()
    train.save_csv(f'{rcp.base_path}train_df_{rcp.stage}.csv')
    valid.save_csv(f'{rcp.base_path}valid_df_{rcp.stage}.csv')
    train_loader = DataLoader(train, batch_size=rcp.bs, num_workers=8, shuffle=rcp.shuffle_batch)
    valid_loader = DataLoader(valid, batch_size=rcp.bs, num_workers=8, shuffle=rcp.shuffle_batch)

    if lr_find:
        lr_finder(model, optimizer, loss, train_loader, valid_loader)

    # Model graph and summary
    one_batch = next(iter(train_loader))
    dot = make_dot(model(one_batch[0].to(cfg.device)), params=dict(model.named_parameters()))
    dot.render(f'{rcp.models_path}graph', './', format='png', cleanup=True)
    summary(model, one_batch[0].shape[-3:], batch_size=rcp.bs, device=cfg.device,
            to_file=f'{rcp.models_path}summary_{rcp.stage}.txt')

    # Engines
    trainer = create_supervised_trainer(model, optimizer, loss, device=cfg.device)
    t_evaluator = create_supervised_evaluator(
        model,
        metrics={
            'accuracy': Accuracy(),
            'nll': Loss(loss),
            'precision': Precision(average=True),
            'recall': Recall(average=True),
            'topK': TopKCategoricalAccuracy(),
        },
        device=cfg.device)
    v_evaluator = create_supervised_evaluator(
        model,
        metrics={
            'accuracy': Accuracy(),
            'nll': Loss(loss),
            'precision_avg': Precision(average=True),
            'recall_avg': Recall(average=True),
            'topK': TopKCategoricalAccuracy(),
            'conf_mat': ConfusionMatrix(num_classes=len(valid.classes), average=None),
        },
        device=cfg.device)

    # Tensorboard
    tb_logger = TensorboardLogger(log_dir=f'{rcp.tb_log_path}{rcp.stage}')
    tb_writer = tb_logger.writer
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'), event_name=Events.EPOCH_STARTED)
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_and_log_training_stats(engine):
        t_evaluator.run(train_loader)
        v_evaluator.run(valid_loader)
        tb_and_log_train_valid_stats(engine, t_evaluator, v_evaluator, tb_writer)

    @trainer.on(Events.ITERATION_COMPLETED(every=int(1 + len(train_loader) / 100)))
    def print_dash(engine):
        print('-', sep='', end='', flush=True)

    if cfg.show_batch_images:
        @trainer.on(Events.STARTED)
        def show_batch_images(engine):
            imgs, lbls = next(iter(train_loader))
            denormalize = DeNormalize(**rcp.transforms.normalize)
            for i in range(len(imgs)):
                imgs[i] = denormalize(imgs[i])
            imgs = imgs.to(cfg.device)
            grid = thv.utils.make_grid(imgs)
            tb_writer.add_image('images', grid, 0)
            tb_writer.add_graph(model, imgs)
            tb_writer.flush()

    if cfg.show_top_losses:
        @trainer.on(Events.COMPLETED)
        def show_top_losses(engine, k=6):
            nll_loss = nn.NLLLoss(reduction='none')
            df = predict_dataset(model, valid, nll_loss, transform=None, bs=rcp.bs, device=cfg.device)
            df.sort_values('loss', ascending=False, inplace=True)
            df.reset_index(drop=True, inplace=True)
            for i, row in df.iterrows():
                img = cv2.imread(str(row['fname']))
                img = th.as_tensor(img.transpose(2, 0, 1))  # HWC -> CHW
                tag = f'TopLoss_{engine.state.epoch}/{row.loss:.4f}/{row.target}/{row.pred}/{row.pred2}'
                tb_writer.add_image(tag, img, 0)
                if i >= k - 1:
                    break
            tb_writer.flush()

    if cfg.tb_projector:
        images, labels = train.select_n_random(250)
        # Get the class label for each image
        class_labels = [train.classes[lab] for lab in labels]
        # Log embeddings
        features = images.view(-1, images.shape[-1] * images.shape[-2])
        tb_writer.add_embedding(features, metadata=class_labels, label_img=images)

    if cfg.log_pr_curve:
        @trainer.on(Events.COMPLETED)
        def log_pr_curve(engine):
            """
            1. Gets the probability predictions in a test_size x num_classes tensor
            2. Gets the predictions in a test_size tensor
            Takes ~10 seconds to run.
            """
            class_probs = []
            class_preds = []
            with th.no_grad():
                for data in valid_loader:
                    imgs, lbls = data
                    imgs, lbls = imgs.to(cfg.device), lbls.to(cfg.device)
                    output = model(imgs)
                    class_probs_batch = [th.softmax(el, dim=0) for el in output]
                    _, class_preds_batch = th.max(output, 1)
                    class_probs.append(class_probs_batch)
                    class_preds.append(class_preds_batch)
            test_probs = th.cat([th.stack(batch) for batch in class_probs])
            test_preds = th.cat(class_preds)
            # Plot the precision-recall curve for each class index
            for i in range(len(valid.classes)):
                tensorboard_preds = test_preds == i
                tensorboard_probs = test_probs[:, i]
                tb_writer.add_pr_curve(f'{rcp.stage}/{valid.classes[i]}',
                                       tensorboard_preds,
                                       tensorboard_probs,
                                       global_step=engine.state.epoch,
                                       num_thresholds=127)
            tb_writer.flush()
            print()

    if cfg.lr_scheduler:
        # lr_scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=.5, min_lr=1e-7, verbose=True)
        # v_evaluator.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: lr_scheduler.step(v_evaluator.state.metrics['nll']))
        lr_scheduler = DelayedCosineAnnealingLR(optimizer, 10, 5)
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  lambda engine: lr_scheduler.step(trainer.state.epoch))

    if cfg.early_stopping:
        def score_function(engine):
            score = -1 * round(engine.state.metrics['nll'], 5)
            # score = engine.state.metrics['accuracy']
            return score

        es_handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
        v_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    if cfg.save_last_checkpoint:
        @trainer.on(Events.EPOCH_COMPLETED(every=1))
        def save_last_checkpoint(engine):
            checkpoint = {}
            objects = {'model': model, 'optimizer': optimizer}
            if cfg.lr_scheduler:
                objects['lr_scheduler'] = lr_scheduler
            for k, obj in objects.items():
                checkpoint[k] = obj.state_dict()
            th.save(checkpoint, f'{rcp.models_path}last_{rcp.stage}_checkpoint.pth')

    if cfg.save_best_checkpoint:
        def score_function(engine):
            score = -1 * round(engine.state.metrics['nll'], 5)
            # score = engine.state.metrics['accuracy']
            return score

        objects = {'model': model, 'optimizer': optimizer}
        if cfg.lr_scheduler:
            objects['lr_scheduler'] = lr_scheduler
        save_best = Checkpoint(objects,
                               DiskSaver(f'{rcp.models_path}', require_empty=False, create_dir=True),
                               n_saved=4,
                               filename_prefix=f'best_{rcp.stage}',
                               score_function=score_function,
                               score_name='val_loss',
                               global_step_transform=global_step_from_engine(trainer))
        v_evaluator.add_event_handler(Events.EPOCH_COMPLETED(every=1), save_best)

        load_checkpoint = False
        if load_checkpoint:
            resume_epoch = 6
            cp = f'{rcp.models_path}last_{rcp.stage}_checkpoint.pth'
            obj = th.load(f'{cp}')
            Checkpoint.load_objects(objects, obj)

            @trainer.on(Events.STARTED)
            def resume_training(engine):
                engine.state.iteration = (resume_epoch - 1) * len(engine.state.dataloader)
                engine.state.epoch = resume_epoch - 1

    if cfg.save_confusion_matrix:
        @trainer.on(Events.STARTED)
        def init_best_loss(engine):
            engine.state.metrics['best_loss'] = 1e99

        @trainer.on(Events.EPOCH_COMPLETED)
        def confusion_matrix(engine):
            if engine.state.metrics['best_loss'] > v_evaluator.state.metrics['nll']:
                engine.state.metrics['best_loss'] = v_evaluator.state.metrics['nll']
                cm = v_evaluator.state.metrics['conf_mat']
                cm_df = pd.DataFrame(cm.numpy(), index=valid.classes, columns=valid.classes)
                pretty_plot_confusion_matrix(cm_df,
                                             f'{rcp.results_path}cm_{rcp.stage}_{trainer.state.epoch}.png',
                                             False)

    if cfg.log_stats:
        class Hook:
            """Forward hook that records the mean and std of a module's output."""

            def __init__(self, module):
                self.name = module[0]
                self.hook = module[1].register_forward_hook(self.hook_fn)
                self.stats_mean = 0
                self.stats_std = 0

            def hook_fn(self, module, input, output):
                self.stats_mean = output.mean()
                self.stats_std = output.std()

            def close(self):
                self.hook.remove()

        hookF = [Hook(layer) for layer in list(model.cnn.named_children())]

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_stats(engine):
            for hook in hookF:
                tb_writer.add_scalar(f'std/{hook.name}', hook.stats_std, engine.state.iteration)
                tb_writer.add_scalar(f'mean/{hook.name}', hook.stats_mean, engine.state.iteration)

    cfg.save_yaml()
    rcp.save_yaml()
    print(f'# batches: train: {len(train_loader)}, valid: {len(valid_loader)}')
    trainer.run(data=train_loader, max_epochs=rcp.max_epochs)
    tb_writer.close()
    tb_logger.close()
    return model
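# A hypothetical invocation of run_training(); `build_model`, `train_ds` and
# `valid_ds` are placeholders (not defined in this file), the learning rate is
# illustrative, and `rcp`/`cfg` are assumed to be the module-level
# recipe/config objects used throughout the function above.
model = build_model()
optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
loss = nn.NLLLoss()
model = run_training(model, train_ds, valid_ds, optimizer, loss, lr_find=False)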
net = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)
loss = monai.losses.DiceLoss(do_sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)
device = torch.device('cuda:0')

# The ignite trainer expects batch=(img, seg) and returns output=loss at every iteration;
# users can add an output_transform to return other values, like: y_pred, y, etc.
trainer = create_supervised_trainer(net, opt, loss, device, False)

# Add a checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={'net': net, 'opt': opt})

# StatsHandler prints loss at every iteration and prints metrics at every epoch.
# We don't set metrics for the trainer here, so it just prints loss; users can also customize print functions.
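# A minimal sketch of wiring up the StatsHandler described above, following the
# MONAI examples (assumes `from monai.handlers import StatsHandler`); the
# output_transform simply forwards the per-iteration loss for printing.
train_stats_handler = StatsHandler(name='trainer', output_transform=lambda x: x)
train_stats_handler.attach(trainer)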
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = create_summary_writer(model, train_loader, log_dir)
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    trainer = create_supervised_trainer(model, optimizer, F.nll_loss, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'nll': Loss(F.nll_loss)},
                                            device=device)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
              .format(engine.state.epoch, engine.state.iteration, len(train_loader), engine.state.output))
        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
              .format(engine.state.epoch, avg_accuracy, avg_nll))
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # Kick everything off
    trainer.run(train_loader, max_epochs=epochs)
    writer.close()
cross_entropy_loss = nn.CrossEntropyLoss()
adam_optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
home = os.environ['HOME']
writer = SummaryWriter(home + '/log.json')
lowest_loss = np.Inf

train_loader, val_loader, test_loader = load_data(args.batch_size)
model.to(device)
if args.summary:
    summary(model, (1, 28, 28))

print('Batch size: ', args.batch_size)
print('Epochs:', args.num_epochs)

trainer = create_supervised_trainer(model, adam_optimizer, cross_entropy_loss, device=device)
evaluator = create_supervised_evaluator(model,
                                        metrics={"accuracy": Accuracy(),
                                                 "cross": Loss(cross_entropy_loss),
                                                 "prec": Precision(),
                                                 "recall": Recall()},
                                        device=device)

desc = "ITERATION - loss: {:.2f}"
pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    iter = (engine.state.iteration - 1) % len(train_loader) + 1
    # Update the progress bar with the latest training loss,
    # resetting it at the end of each epoch
    pbar.desc = desc.format(engine.state.output)
    pbar.update(1)
    if iter == len(train_loader):
        pbar.n = pbar.last_print_n = 0
def do_train(model, train_loader, val_loader, optimizer, scheduler, checkpointer,
             loss_fn, device, checkpoint_period, log_period, epochs):
    logger = logging.getLogger("template_model.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={'accuracy': Accuracy(),
                                                     'ce_loss': Loss(loss_fn)},
                                            device=device)

    desc = "ITERATION - loss: {:.3f}"
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_period)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['ce_loss']
        logger.info("Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                    .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_loss = metrics['ce_loss']
            logger.info("Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                        .format(engine.state.epoch, avg_accuracy, avg_loss))
            pbar.n = pbar.last_print_n = 0

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
def run(
    train_batch_size,
    val_batch_size,
    epochs,
    lr,
    momentum,
    log_interval,
    log_dir,
    checkpoint_every,
    resume_from,
    crash_iteration=-1,
    deterministic=False,
):
    # Setup seed to have same model's initialization:
    manual_seed(75)

    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    criterion = nn.NLLLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.5)

    # Setup trainer and evaluator
    if deterministic:
        tqdm.write("Setup deterministic trainer")
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device, deterministic=deterministic)

    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(criterion)}, device=device
    )

    # Apply learning rate scheduling
    @trainer.on(Events.EPOCH_COMPLETED)
    def lr_step(engine):
        lr_scheduler.step()

    desc = "Epoch {} - loss: {:.4f} - lr: {:.4f}"
    pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0, 0, lr))

    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        lr = optimizer.param_groups[0]["lr"]
        pbar.desc = desc.format(engine.state.epoch, engine.state.output, lr)
        pbar.update(log_interval)
        writer.add_scalar("training/loss", engine.state.output, engine.state.iteration)
        writer.add_scalar("lr", lr, engine.state.iteration)

    if crash_iteration > 0:
        @trainer.on(Events.ITERATION_COMPLETED(once=crash_iteration))
        def _(engine):
            raise Exception(f"STOP at {engine.state.iteration}")

    if resume_from is not None:
        @trainer.on(Events.STARTED)
        def _(engine):
            pbar.n = engine.state.iteration % engine.state.epoch_length

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Training Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    # Compute and log validation metrics
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        tqdm.write(
            f"Validation Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}"
        )
        pbar.n = pbar.last_print_n = 0
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # Setup objects to checkpoint
    objects_to_checkpoint = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    training_checkpoint = Checkpoint(
        to_save=objects_to_checkpoint,
        save_handler=DiskSaver(log_dir, require_empty=False),
        n_saved=None,
        global_step_transform=lambda *_: trainer.state.epoch,
    )
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_every), training_checkpoint)

    # Setup logger to print and dump into file: model weights, model grads and data stats
    # - first 3 iterations
    # - 4 iterations after checkpointing
    # This helps to compare resumed training with checkpointed training
    def log_event_filter(e, event):
        if event in [1, 2, 3]:
            return True
        elif 0 <= (event % (checkpoint_every * e.state.epoch_length)) < 5:
            return True
        return False

    fp = Path(log_dir) / ("run.log" if resume_from is None else "resume_run.log")
    fp = fp.as_posix()
    for h in [log_data_stats, log_model_weights, log_model_grads]:
        trainer.add_event_handler(Events.ITERATION_COMPLETED(event_filter=log_event_filter), h, model=model, fp=fp)

    if resume_from is not None:
        tqdm.write(f"Resume from the checkpoint: {resume_from}")
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    try:
        # Synchronize random states
        manual_seed(15)
        trainer.run(train_loader, max_epochs=epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    pbar.close()
    writer.close()
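# A sketch of the crash-and-resume workflow this script is built for; every
# argument value below is an illustrative assumption, and the checkpoint
# filename is hypothetical (ignite's Checkpoint chooses the actual name).
if __name__ == "__main__":
    # First run: crash deliberately at iteration 1000, leaving checkpoints behind
    run(train_batch_size=64, val_batch_size=1000, epochs=10, lr=0.01, momentum=0.5,
        log_interval=10, log_dir="/tmp/ckpt_demo", checkpoint_every=2,
        resume_from=None, crash_iteration=1000)
    # Second run: restore trainer/model/optimizer/lr_scheduler state and continue
    run(train_batch_size=64, val_batch_size=1000, epochs=10, lr=0.01, momentum=0.5,
        log_interval=10, log_dir="/tmp/ckpt_demo", checkpoint_every=2,
        resume_from="/tmp/ckpt_demo/checkpoint_2.pt")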