def test_weights_scalar_handler_wrong_setup():
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    model = MagicMock(spec=torch.nn.Module)
    with pytest.raises(TypeError, match="Argument reduction should be callable"):
        WeightsScalarHandler(model, reduction=123)

    with pytest.raises(TypeError, match="Output of the reduction function should be a scalar"):
        WeightsScalarHandler(model, reduction=lambda x: x)

    wrapper = WeightsScalarHandler(model)
    mock_logger = MagicMock()
    mock_engine = MagicMock()
    with pytest.raises(
        RuntimeError, match="Handler 'WeightsScalarHandler' works only with TensorboardLogger"
    ):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
def test_weights_scalar_handler_frozen_layers(dummy_model_factory):
    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)

    wrapper = WeightsScalarHandler(model)
    mock_logger = MagicMock(spec=TensorboardLogger)
    mock_logger.writer = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    mock_logger.writer.add_scalar.assert_has_calls(
        [
            call("weights_norm/fc2/weight", 12.0, 5),
            call("weights_norm/fc2/bias", math.sqrt(12.0), 5),
        ],
        any_order=True,
    )

    with pytest.raises(AssertionError):
        mock_logger.writer.add_scalar.assert_has_calls(
            [
                call("weights_norm/fc1/weight", 12.0, 5),
                call("weights_norm/fc1/bias", math.sqrt(12.0), 5),
            ],
            any_order=True,
        )
    assert mock_logger.writer.add_scalar.call_count == 2
def test_weights_scalar_handler_wrong_setup():
    model = MagicMock(spec=torch.nn.Module)
    wrapper = WeightsScalarHandler(model)
    mock_logger = MagicMock()
    mock_engine = MagicMock()
    with pytest.raises(
        RuntimeError, match="Handler 'WeightsScalarHandler' works only with TensorboardLogger"
    ):
        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)
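# For contrast with the wrong-setup tests above, here is a minimal sketch of a working setup.
# This sketch is an illustrative assumption, not taken from the snippets: the import path uses
# the classic `ignite.contrib.handlers.tensorboard_logger` location, and the engine/data below
# are placeholders.
import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, WeightsScalarHandler

model = torch.nn.Linear(10, 2)
trainer = Engine(lambda engine, batch: None)  # placeholder update function

tb_logger = TensorboardLogger(log_dir="/tmp/tb_logs")
# Log a scalar (the norm, by default) for every weight tensor of the model.
tb_logger.attach(
    trainer,
    log_handler=WeightsScalarHandler(model, reduction=torch.norm),
    event_name=Events.ITERATION_COMPLETED,
)
trainer.run([0] * 10, max_epochs=1)
tb_logger.close()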
def custom_setup(self):
    if self.tensorboard_logs:
        tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
        tb_logger.attach(self.trainer,
                         log_handler=OutputHandler(tag="training",
                                                   output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(self.evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=["LossMetric"],
                                                   another_engine=self.trainer),
                         event_name=Events.EPOCH_COMPLETED)
        if self.optional_tensorboard_features:
            tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer),
                             event_name=Events.ITERATION_STARTED)
            tb_logger.attach(self.trainer, log_handler=WeightsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.trainer, log_handler=WeightsHistHandler(self.model),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer, log_handler=GradsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)

        # This is important to close the tensorboard file logger
        @self.trainer.on(Events.COMPLETED)
        def end_tensorboard(trainer):
            logger.info("Training completed")
            tb_logger.close()

    if self.embeddings_name:

        @self.trainer.on(Events.COMPLETED)
        def log_embeddings(trainer):
            if hasattr(self.model, self.embeddings_name) and hasattr(self.dataset_splits, "vectorizer") and TENSORBOARD:
                logger.info(f"Logging embeddings ({self.embeddings_name}) to Tensorboard!")
                embeddings = getattr(self.model, self.embeddings_name).weight.data
                metadata = [
                    str(self.dataset_splits.vectorizer.data_vocab._id2token[token_index]).encode('utf-8')
                    for token_index in range(embeddings.shape[0])
                ]
                self.writer.add_embedding(mat=embeddings, metadata=metadata,
                                          global_step=self.trainer.state.epoch)
def test_weights_scalar_handler_whitelist(dummy_model_factory):
    model = dummy_model_factory()

    wrapper = WeightsScalarHandler(model, whitelist=["fc2.weight"])
    mock_logger = MagicMock(spec=TensorboardLogger)
    mock_logger.writer = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
    mock_logger.writer.add_scalar.assert_called_once_with("weights_norm/fc2/weight", 12.0, 5)
    mock_logger.writer.reset_mock()

    wrapper = WeightsScalarHandler(model, tag="model", whitelist=["fc1"])
    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    mock_logger.writer.add_scalar.assert_has_calls(
        [
            call("model/weights_norm/fc1/weight", 0.0, 5),
            call("model/weights_norm/fc1/bias", 0.0, 5),
        ],
        any_order=True,
    )
    assert mock_logger.writer.add_scalar.call_count == 2
    mock_logger.writer.reset_mock()

    def weight_selector(n, _):
        return "bias" in n

    wrapper = WeightsScalarHandler(model, tag="model", whitelist=weight_selector)
    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    mock_logger.writer.add_scalar.assert_has_calls(
        [
            call("model/weights_norm/fc1/bias", 0.0, 5),
            call("model/weights_norm/fc2/bias", pytest.approx(math.sqrt(12.0)), 5),
        ],
        any_order=True,
    )
    assert mock_logger.writer.add_scalar.call_count == 2
def _test(tag=None):
    wrapper = WeightsScalarHandler(model, tag=tag)
    mock_logger = MagicMock(spec=TensorboardLogger)
    mock_logger.writer = MagicMock()

    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.epoch = 5

    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)

    tag_prefix = f"{tag}/" if tag else ""

    assert mock_logger.writer.add_scalar.call_count == 4
    mock_logger.writer.add_scalar.assert_has_calls(
        [
            call(tag_prefix + "weights_norm/fc1/weight", 0.0, 5),
            call(tag_prefix + "weights_norm/fc1/bias", 0.0, 5),
            call(tag_prefix + "weights_norm/fc2/weight", 12.0, 5),
            call(tag_prefix + "weights_norm/fc2/bias", math.sqrt(12.0), 5),
        ],
        any_order=True,
    )
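# The `_test` helper above is a nested closure; in the ignite test suite it is driven by an
# enclosing test that builds the model fixture and exercises both the default and the tagged
# paths. A minimal sketch of that wrapper (the fixture arguments and exact test name are
# assumptions for illustration):
def test_weights_scalar_handler(dummy_model_factory):
    model = dummy_model_factory(with_grads=True, with_frozen_layer=False)

    def _test(tag=None):
        ...  # body as in the helper above, closing over `model`

    _test()             # scalars logged under "weights_norm/..."
    _test(tag="model")  # scalars logged under "model/weights_norm/..."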
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"
    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3,):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`"
            )

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
    )

    tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
def train_fold(fold, args):
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size

    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf

    if args.semi == True:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle,  # set to False to disable pytorch dataset shuffle
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )

    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False,  # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size,  # in valid time, have to use one image by one
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9,
    #     weight_decay=args.weight_decay, nesterov=True)

    # ignite trainer process function
    def train_step(engine, batch):
        # set model to train
        model.train()
        # clear gradients
        optimizer.zero_grad()

        # additional params to feed into model
        add_params = {}
        inputs = batch['input'].cuda(non_blocking=True)
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)

        # for TAPNet, add attention maps
        if 'TAPNet' in args.model:
            add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

        outputs = model(inputs, **add_params)

        loss_kwargs = {}
        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            loss = loss_func_semi(outputs, targets, **loss_kwargs)
        else:
            loss = loss_func(outputs, targets, **loss_kwargs)
        loss.backward()
        optimizer.step()

        return_dict = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }

        # for TAPNet, update attention maps after each iteration
        if 'TAPNet' in args.model:
            # output_classes and target_classes: <b, h, w>
            output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            # update attention maps
            train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy())
            return_dict['attmap'] = add_params['attmap']

        return return_dict

    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler and handler
    # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr)
    # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler)
    # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler)

    step_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)

    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        logging_logger.info('training fold {}, {} train / {} valid files'.format(
            fold, len(train_filenames), len(valid_filenames)))

        # resume training
        if args.resume:
            # ckpt for current fold fold_<fold>_model_<epoch>.pth
            ckpt_dir = Path(args.ckpt_dir)
            # glob returns a generator, so materialize it before indexing
            ckpt_filename = list(ckpt_dir.glob('fold_%d_model_[0-9]*.pth' % fold))[0]
            res = re.match(r'fold_%d_model_(\d+).pth' % fold, ckpt_filename.name)
            # restore epoch
            engine.state.epoch = int(res.groups()[0])
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))

        # record metrics history every epoch
        engine.state.metrics_records = {}

    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' %
            (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))

        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")

    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        # logging_logger.info(engine.state.metrics)
        pass

    # monitor loss
    # running average loss
    train_ra_loss = imetrics.RunningAverage(output_transform=lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(loss_func_semi,
            output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'),
            event_name=engine.Events.EPOCH_STARTED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
            event_name=engine.Events.ITERATION_COMPLETED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model, reduction=torch.norm),
            event_name=engine.Events.ITERATION_COMPLETED)

        # tb_logger.attach(trainer, log_handler=tb_log_train_vars,
        #     event_name=engine.Events.ITERATION_COMPLETED)

    # ignite validator process function
    def valid_step(engine, batch):
        with torch.no_grad():
            model.eval()
            inputs = batch['input'].cuda(non_blocking=True)
            targets = batch['target'].cuda(non_blocking=True)

            # additional arguments
            add_params = {}
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

            # output logits
            outputs = model(inputs, **add_params)
            # loss
            loss = loss_func(outputs, targets)

            output_softmaxs = torch.softmax(outputs, dim=1)
            output_argmaxs = output_softmaxs.argmax(dim=1)
            # output_classes and target_classes: <b, h, w>
            output_classes = output_argmaxs.cpu().numpy()
            target_classes = targets.cpu().numpy()

            # record current batch metrics
            iou_mRecords = MetricRecord()
            dice_mRecords = MetricRecord()

            cm_b = np.zeros((num_classes, num_classes), dtype=np.uint32)

            for output_class, target_class in zip(output_classes, target_classes):
                # calculate metrics for each frame
                # calculate using confusion matrix or directly using definition
                cm = calculate_confusion_matrix_from_arrays(output_class, target_class, num_classes)
                iou_mRecords.update_record(calculate_iou(cm))
                dice_mRecords.update_record(calculate_dice(cm))
                cm_b += cm

                ######## calculate directly using definition ##########
                # iou_mRecords.update_record(iou_multi_np(target_class, output_class))
                # dice_mRecords.update_record(dice_multi_np(target_class, output_class))

            # accumulate batch metrics to engine state
            engine.state.epoch_metrics['confusion_matrix'] += cm_b
            engine.state.epoch_metrics['iou'].merge(iou_mRecords)
            engine.state.epoch_metrics['dice'].merge(dice_mRecords)

            return_dict = {
                'loss': loss.item(),
                'output': outputs,
                'output_argmax': output_argmaxs,
                'target': targets,
                # for monitoring
                'iou': iou_mRecords,
                'dice': dice_mRecords,
            }

            if 'TAPNet' in args.model:
                # for TAPNet, update attention maps after each iteration
                valid_loader.dataset.update_attmaps(output_softmaxs.cpu().numpy(), batch['abs_idx'].numpy())
                # for TAPNet, return extra internal values
                return_dict['attmap'] = add_params['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return return_dict

    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(output_transform=lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')

    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)

    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes,
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm, ignore_index=0).attach(validator, 'iou')
    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm, ignore_index=0).attach(validator, 'mean_iou')

    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint32),
        }

    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):
        # log ignite metrics
        # logging_logger.info(engine.state.metrics)
        # ious = engine.state.metrics['iou']
        # msg = 'IoU: '
        # for ins_id, iou in enumerate(ious):
        #     msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou)
        # logging_logger.info(msg)
        # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean()))

        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        ######### NOTICE: Two metrics are available but different ##########
        ### 1. mean metrics for all data calculated by confusion matrix ####
        '''
        compared with using confusion_matrix[1:, 1:] in original code,
        we use the full confusion matrix and only present non-background result
        '''
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' % (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' % (mean_dices, std_dices, dices))

        ### 2. mean metrics for all data calculated by definition ###
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
            (len(iou_data_mean['items']), iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
            (len(dice_data_mean['items']), dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        # trainer.state.metrics_records[trainer.state.epoch] = \
        #     {'miou': mean_ious, 'std_miou': std_ious,
        #      'mdice': mean_dices, 'std_mdice': std_dices}
        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
             'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}

    # log internal variables (attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        res_grid = tvutils.make_grid(torch.cat([
            output['output_argmax'].unsqueeze(1),
            output['target'].unsqueeze(1),
        ]), padding=2,
            normalize=False,  # show origin image
            nrow=batch_size).cpu()

        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag), img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]), padding=2, normalize=True, nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag), img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars,
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
            event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']

    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix,
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=model_ckpt_handler,
        to_save={'model': model})

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience,
        score_function=ckpt_score_function, trainer=trainer)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, handler=early_stopping)

    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
def setup(self, training_metrics):
    def metric_name(n) -> str:
        if n.endswith('Accuracy'):
            n = 'acc'
        else:
            n = n[:-6] if n.endswith('Metric') else n
        return n

    def print_metrics(metrics) -> str:
        rv = ''
        metric_keys = sorted(k for k in metrics)
        for k in metric_keys:
            if k == 'Accuracy':
                rv += f'{metric_name(k)}: {metrics[k]:.3}'
            else:
                rv += f'{metric_name(k)}: {metrics[k]:.6}'
        return rv

    if self.seed:
        set_seed_everywhere(self.seed, self.cuda)

    pbar = ProgressBar()
    names = []
    for k, v in training_metrics.items():
        name = f'r{k}'
        names.append(name)
        RunningAverage(v).attach(self.trainer, name)

    RunningAverage(None, output_transform=lambda x: x[-1] * self.loss_accumulation_steps).attach(self.trainer, 'rloss')
    names.append('rloss')
    pbar.attach(self.trainer, names)

    pbar = ProgressBar()
    pbar.attach(self.evaluator)

    # A few event handlers. To add / modify the event handlers, you need to extend the __init__ method of RunnerABC
    # Ignite provides the necessary abstractions and a furnished repository of useful tools

    @self.trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        self.evaluator.run(self.dataset_splits.val_data_loader())
        metrics = self.evaluator.state.metrics
        logger.info(f"Validation Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}")
        if self.scheduler:
            self.scheduler.step(metrics[self.loss_metric.__class__.__name__])

    @self.trainer.on(Events.COMPLETED)
    def log_test_results(trainer):
        self.evaluator.run(self.dataset_splits.test_data_loader())
        metrics = self.evaluator.state.metrics
        logger.info(f"Test Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}")

    if self.tensorboard_logs:
        tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
        tb_logger.attach(self.trainer,
                         log_handler=OutputHandler(tag="training",
                                                   output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(self.evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=["LossMetric"],
                                                   another_engine=self.trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler(self.optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(self.trainer, log_handler=WeightsScalarHandler(self.model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(self.trainer, log_handler=WeightsHistHandler(self.model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(self.trainer, log_handler=GradsScalarHandler(self.model),
                         event_name=Events.ITERATION_COMPLETED)

        # This is important to close the tensorboard file logger
        @self.trainer.on(Events.COMPLETED)
        def end_tensorboard(trainer):
            logger.info("Training completed")
            tb_logger.close()

    if self.embeddings_name:

        @self.trainer.on(Events.COMPLETED)
        def log_embeddings(trainer):
            if hasattr(self.model, self.embeddings_name) and hasattr(self.dataset_splits, "vectorizer"):
                logger.info(f"Logging embeddings ({self.embeddings_name}) to Tensorboard!")
                embeddings = getattr(self.model, self.embeddings_name).weight.data
                metadata = [
                    str(self.dataset_splits.vectorizer.data_vocab._id2token[token_index]).encode('utf-8')
                    for token_index in range(embeddings.shape[0])
                ]
                self.writer.add_embedding(mat=embeddings, metadata=metadata,
                                          global_step=self.trainer.state.epoch)
def train(epochs=500, batch_size=32, bptt_len=70, lr=0.00025, log_steps=200,
          clip_grad=0.25, log_dir="experiments"):
    ###################################################################
    # Dataset
    ###################################################################
    wt = wikitext103(batch_size=batch_size, bptt_len=bptt_len)
    # wt = wikitext2(batch_size=batch_size, bptt_len=bptt_len)

    ###################################################################
    # Configs
    ###################################################################
    embedding_config = DropEmbedding.Hyperparams(len(wt.text_field.vocab) + 3, ninp=512)
    encoder_config = TransformerEncoder.Hyperparams(
        att_num_units=[512, 512, 512, 512, 512, 512], max_ext=384)

    ###################################################################
    # Models
    ###################################################################
    base_embedding = DropEmbedding(embedding_config)
    embedding = TransformerEmbedding(embedding=base_embedding,
                                     max_length=bptt_len,
                                     embedding_size=embedding_config.ninp,
                                     use_positional_embedding=False)
    encoder = TransformerEncoder(encoder_config)
    model = TransformerLanguageModel(embedding, encoder)
    model.init_weight()

    ###################################################################
    # Loss
    ###################################################################
    criterion = lm_criterion(in_features=encoder_config.att_num_units[-1],
                             vocab_size=len(wt.text_field.vocab))

    ###################################################################
    # Parameters + Train ops
    ###################################################################
    parameters = (list(model.parameters()) + list(criterion.parameters()))
    tot_params = 0
    for p in parameters:
        tot_params += reduce(lambda x, y: x * y, p.size())
    print("Total Parameters: ", tot_params)
    opt = optim.Adam(parameters, lr=lr)
    model.to(DEVICE)
    criterion.to(DEVICE)

    ###################################################################
    # Train + Evaluation
    ###################################################################
    def train_step(engine, batch):
        model.train()
        opt.zero_grad()

        text = batch.text.to(DEVICE).t().contiguous()
        target = batch.target.to(DEVICE).t().contiguous()

        out, out_past = model(text, engine.state.train_past)
        engine.state.train_past = out_past
        raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
        loss = raw_loss[1]

        loss.backward()
        nn.utils.clip_grad_norm_(parameters, clip_grad)
        opt.step()

        return {"train_loss": loss.item(), "train_ppl": loss.exp().item()}

    def eval_step(engine, batch):
        model.eval()

        if not hasattr(engine.state, "eval_past"):
            engine.state.eval_past = None

        with torch.no_grad():
            text = batch.text.to(DEVICE).t().contiguous()
            target = batch.target.to(DEVICE).t().contiguous()

            out, out_past = model(text, engine.state.eval_past)
            engine.state.eval_past = out_past
            raw_loss = criterion(out.view(-1, out.size(2)), target.view(-1))
            loss = raw_loss[1]

            return {"val_loss": loss.item()}

    train_engine = Engine(train_step)
    eval_engine = Engine(eval_step)

    def reset_state(engine):
        engine.state.train_past = None

    def run_eval(_):
        print("start running eval")
        eval_engine.run(wt.valid_iter)
        metrics = eval_engine.state.metrics
        print("Validation loss: ", metrics["val_loss"], ", ppl: ", np.exp(metrics["val_loss"]))

    train_engine.add_event_handler(Events.EPOCH_STARTED, reset_state)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, run_eval)

    ###################################################################
    # LR Scheduler
    ###################################################################
    cosine_scheduler = CosineAnnealingScheduler(opt.param_groups[0], "lr",
                                                0.0, 2.5e-4, cycle_size=len(wt.train_iter))
    warmup_scheduler = create_lr_scheduler_with_warmup(cosine_scheduler, 0.0, 2.5e-4, 200)
    train_engine.add_event_handler(Events.ITERATION_STARTED, warmup_scheduler)

    ###################################################################
    # Metrics
    ###################################################################
    RunningAverage(output_transform=lambda x: x["train_ppl"]).attach(train_engine, "train_ppl")
    RunningAverage(output_transform=lambda x: x["train_loss"]).attach(train_engine, "train_loss")
    RunningAverage(output_transform=lambda x: x["val_loss"]).attach(eval_engine, "val_loss")
    progress_bar = ProgressBar(persist=True)
    progress_bar.attach(train_engine, ["train_ppl", "train_loss"])
    progress_bar_val = ProgressBar(persist=True)
    progress_bar_val.attach(eval_engine, ["val_loss"])

    ###################################################################
    # Tensorboard
    ###################################################################
    tb_logger = TensorboardLogger(log_dir=log_dir)

    def stepn_logger(num_steps, handler):
        def logger_runner(engine, log_handler, event_name):
            if engine.state.iteration % num_steps == 0:
                handler(engine, log_handler, event_name)
        return logger_runner

    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps,
                                              OutputHandler(tag="training",
                                                            output_transform=lambda loss: loss)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(eval_engine,
                     log_handler=OutputHandler(tag="validation",
                                               output_transform=lambda loss: loss,
                                               another_engine=train_engine),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, OptimizerParamsHandler(opt)),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, WeightsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(log_steps, GradsScalarHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, WeightsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(train_engine,
                     log_handler=stepn_logger(500, GradsHistHandler(model)),
                     event_name=Events.ITERATION_COMPLETED)

    try:
        train_engine.run(wt.train_iter, max_epochs=epochs)
    except Exception:
        pass
    finally:
        tb_logger.close()
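# Note: the `stepn_logger` wrapper above emulates periodic logging by checking the iteration
# count. On ignite versions that support event filtering (0.3+), the same effect can be had
# inside `train` with a filtered event instead of the wrapper; a hedged sketch:
#
#     tb_logger.attach(train_engine,
#                      log_handler=WeightsScalarHandler(model),
#                      event_name=Events.ITERATION_COMPLETED(every=log_steps))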