Beispiel #1
0
def evaluator_epoch_comp_callback(engine):
    """Save the predicted masks of the current batch, then set up and run validation/training.

    NOTE(review): this looks like a scraped fragment in which the body of an
    evaluator callback and the tail of a training routine were merged into a
    single function. The parameter ``engine`` is used as an engine instance
    (``engine.state.output``) but later lines use ``engine.Engine`` /
    ``engine.Events`` as if it were the ignite ``engine`` module — confirm
    against the original source. Almost every name below (``evaluator``,
    ``eval_loader``, ``valid_step``, ``trainer``, ``model``, ``args``,
    ``mask_save_dir``, ...) is defined outside this function.

    Args:
        engine: engine whose ``state.output`` is a dict holding at least
            ``'input_filename'`` and ``'mask'`` (plus ``'attmap'`` when
            ``args.model`` contains ``'TAPNet'``).

    Returns:
        ``trainer.state.metrics_records`` after ``trainer.run`` has finished.
    """
    # save masks for each batch
    batch_output = engine.state.output
    input_filenames = batch_output['input_filename']
    masks = batch_output['mask']

    for i, input_filename in enumerate(input_filenames):
        # resize the predicted mask to the cropped frame size
        mask = cv2.resize(masks[i],
                          dsize=(utils.cropped_width, utils.cropped_height),
                          interpolation=cv2.INTER_AREA)

        # if pad:
        #     h_start, w_start = utils.h_start, utils.w_start
        #     h, w = mask.shape
        #     # recover to original shape
        #     full_mask = np.zeros((original_height, original_width))
        #     full_mask[h_start:h_start + h, w_start:w_start + w] = t_mask
        #     mask = full_mask

        # the instrument folder is two directory levels above the frame file;
        # os.path accepts both plain strings and os.PathLike filenames
        instrument_folder_name = os.path.basename(
            os.path.dirname(os.path.dirname(input_filename)))

        # mask_save_dir/instrument_dataset_x/problem_type_masks/framexxx.png
        mask_folder = mask_save_dir / instrument_folder_name / utils.mask_folder[
            args.problem_type]
        mask_folder.mkdir(exist_ok=True, parents=True)
        mask_filename = mask_folder / os.path.basename(input_filename)
        cv2.imwrite(str(mask_filename), mask)

        if 'TAPNet' in args.model:
            attmap = batch_output['attmap'][i]

            # mask_save_dir/instrument_dataset_x/<problem_type>_attmaps/framexxx.png
            # str.join takes a single iterable; passing two positional
            # arguments raises TypeError
            attmap_folder = mask_save_dir / instrument_folder_name / '_'.join(
                [args.problem_type, 'attmaps'])
            attmap_folder.mkdir(exist_ok=True, parents=True)
            # same basename logic as for the mask: ``.name`` exists only on
            # Path objects, while os.path.basename also works on strings
            attmap_filename = attmap_folder / os.path.basename(input_filename)

            cv2.imwrite(str(attmap_filename), attmap)

    # NOTE(review): from here on the code reads like the body of the enclosing
    # training routine rather than part of this callback (it calls
    # evaluator.run from inside what is named an evaluator epoch callback).
    # Left unchanged; confirm against the source repository.
    evaluator.run(eval_loader)

    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(
        output_transform=lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func,
                               output_transform=lambda x:
                               (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')

    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(
        output_transform=lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(
        output_transform=lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)

    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes,
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm,
    #     ignore_index=0
    #     ).attach(validator, 'iou')

    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm,
    #     ignore_index=0
    #     ).attach(validator, 'mean_iou')

    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        pass

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        # reset the per-epoch accumulators that valid_step adds into
        engine.state.epoch_metrics = {
            # directly use definition to calculate
            'iou':
            MetricRecord(),
            'dice':
            MetricRecord(),
            'confusion_matrix':
            np.zeros((num_classes, num_classes), dtype=np.uint32),
        }

    # evaluate after iter finish
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        pass

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):

        # log ignite metrics
        # logging_logger.info(engine.state.metrics)
        # ious = engine.state.metrics['iou']
        # msg = 'IoU: '
        # for ins_id, iou in enumerate(ious):
        #     msg += '{:d}: {:.3f}, '.format(ins_id + 1, iou)
        # logging_logger.info(msg)
        # logging_logger.info('nonzero mean IoU for all data: {:.3f}'.format(ious[ious > 0].mean()))

        # log monitored epoch metrics
        epoch_metrics = engine.state.epoch_metrics

        ######### NOTICE: Two metrics are available but different ##########
        ### 1. mean metrics for all data calculated by confusion matrix ####
        # compared with using confusion_matrix[1:, 1:] in original code,
        # we use the full confusion matrix and only present non-background result
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        std_ious = np.std(list(ious.values()))
        std_dices = np.std(list(dices.values()))

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
                            (mean_ious, std_ious, ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
                            (mean_dices, std_dices, dices))

        ### 2. mean metrics for all data calculated by definition ###
        iou_data_mean = epoch_metrics['iou'].data_mean()
        dice_data_mean = epoch_metrics['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
                            (len(iou_data_mean['items']),
                             iou_data_mean['mean'], iou_data_mean['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
                            (len(dice_data_mean['items']),
                             dice_data_mean['mean'], dice_data_mean['std']))

        # record metrics in trainer every epoch
        # trainer.state.metrics_records[trainer.state.epoch] = \
        #     {'miou': mean_ious, 'std_miou': std_ious,
        #     'mdice': mean_dices, 'std_mdice': std_dices}

        trainer.state.metrics_records[trainer.state.epoch] = \
            {'miou': iou_data_mean['mean'], 'std_miou': iou_data_mean['std'],
            'mdice': dice_data_mean['mean'], 'std_mdice': dice_data_mean['std']}

    # log interal variables(attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        log_tag = 'valid_iter'
        output = engine.state.output
        batch_size = output['output'].shape[0]
        # one row of predictions above one row of targets
        res_grid = tvutils.make_grid(
            torch.cat([
                output['output_argmax'].unsqueeze(1),
                output['target'].unsqueeze(1),
            ]),
            padding=2,
            normalize=False,  # show origin image
            nrow=batch_size).cpu()

        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag),
                                img_tensor=res_grid)

        if 'TAPNet' in args.model:
            # log attention maps and other internal values
            inter_vals_grid = tvutils.make_grid(torch.cat([
                output['attmap'],
            ]),
                                                padding=2,
                                                normalize=True,
                                                nrow=batch_size).cpu()
            logger.writer.add_image(tag='%s internal vals' % (log_tag),
                                    img_tensor=inter_vals_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        # log confusion-matrix based mean IoU / Dice once per validation epoch
        epoch_metrics = engine.state.epoch_metrics
        confusion_matrix = epoch_metrics['confusion_matrix']  # [1:, 1:]
        ious = calculate_iou(confusion_matrix)
        dices = calculate_dice(confusion_matrix)

        mean_ious = np.mean(list(ious.values()))
        mean_dices = np.mean(list(dices.values()))
        logger.writer.add_scalar('mIoU', mean_ious, engine.state.epoch)
        # Dice previously reused the 'mIoU' tag, mixing both curves together
        logger.writer.add_scalar('mDice', mean_dices, engine.state.epoch)

    if args.tb_log:
        # log internal values
        tb_logger.attach(validator,
                         log_handler=tb_log_valid_iter_vars,
                         event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator,
                         log_handler=tb_log_valid_epoch_vars,
                         event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator,
                         log_handler=OutputHandler('valid_epoch',
                                                   ['valid_loss']),
                         event_name=engine.Events.EPOCH_COMPLETED)

    # score function for model saving and early stopping: mean of the
    # per-class IoU values computed from the epoch confusion matrix
    def ckpt_score_function(engine):
        return np.mean(list(calculate_iou(
            engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']

    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix,
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=model_ckpt_handler,
                                to_save={
                                    'model': model,
                                })

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience,
                                            score_function=ckpt_score_function,
                                            trainer=trainer)

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
                                handler=early_stopping)

    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        validator.run(valid_loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
Beispiel #2
0
def run_utk(model,
            optimizer,
            epochs,
            log_interval,
            dataloaders,
            dirname='resnet_models',
            filename_prefix='resnet',
            n_saved=2,
            log_dir='../../fer2013/logs',
            launch_tensorboard=False,
            patience=10,
            resume_model=None,
            resume_optimizer=None,
            backup_step=1,
            backup_path=None,
            n_epochs_freeze=5,
            n_cycle=None,
            lr_after_freeze=1e-3,
            lr_cycle_start=1e-4,
            lr_cycle_end=1e-1,
            loss_weights=None,
            lr_plot=True):
    """
    Utility function that encapsulates pytorch models training routine.

    :param model: pytorch model to be trained
    :param optimizer: pytorch optimizer that updates the `model`'s parameters
    :param epochs: maximum number of epoch to train for
    :param log_interval: print training loss each `log_interval` iterations during training
    :param dataloaders: dictionary with `train` and `valid` as keys, the corresponding values being resp. train
            and validation pytorch `DataLoader` objects
    :param dirname: path to the directory where to save model checkpoints during training
    :param filename_prefix: string, name under which to save the model checkpoint file
    :param n_saved: int, save n_saved best model during training (ranked by validation loss)
    :param log_dir: optional path to a directory where to write tensorboard logs
    :param launch_tensorboard: boolean, whether to write metrics and histograms using tensorboard
    :param patience: int, number of validation epochs to wait for before stopping training if no improvement
                     of the validation loss is recorded
    :param resume_model: optional path to checkpoint of trained model to load weights from and continue training
    :param resume_optimizer: optional path to a previous optimizer checkpoint to load state_dict from
    :param backup_step: optional, copy the model checkpoints from `dirname` each `backup_step` epochs,
                        This is useful for me in situation where I train on google colab and want backup my checkpoints
                        to my google drive
    :param backup_path: optional path to backup (copy) model checkpoints to, each `backup_step` epochs.
    :param n_epochs_freeze: at the start of epoch `n_epochs_freeze` unfreeze the model's frozen layers,
                            useful when doing transfer learning
    :param n_cycle: optional int, in terms of number of epochs, to be used for cycle size when doing learning rate
                    scheduling
    :param lr_after_freeze: float, the new learning rate to set after unfreezing the model's layer for finetuning
    :param lr_cycle_start: starting value for learning rate when doing learning rate scheduling
    :param lr_cycle_end: end value for learning rate when doing learning rate scheduling
    :param loss_weights: list of float to be used for weighting model's outputs;
                         defaults to ``[1 / 10, 1 / 0.16, 1 / 0.44]`` when None
    :param lr_plot: if truthy, record the learning rate of the first parameter group after every iteration and
                    plot it with `plot_lr` once training finishes
    :return: None
    """
    if loss_weights is None:
        # build a fresh list per call instead of sharing a mutable default
        loss_weights = [1 / 10, 1 / 0.16, 1 / 0.44]

    count_parameters(model)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create the tensorboard log directory if relevant
    if launch_tensorboard:
        os.makedirs(log_dir, exist_ok=True)

    # In case a path of previous model and optimizer checkpoints are provided load weights and state from them.
    # map_location lets CUDA-saved checkpoints be restored on a CPU-only machine.
    if resume_model:
        model.load_state_dict(torch.load(resume_model, map_location=device))

    if resume_optimizer:
        optimizer.load_state_dict(torch.load(resume_optimizer, map_location=device))
        # move optimizer state tensors to the training device (was a hard-coded `.cuda()`)
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    # Get the training and validation data loaders
    train_loader, val_loader = dataloaders['train'], dataloaders['valid']

    # create tensorboard writers
    writer = val_writer = None
    if launch_tensorboard:
        writer, val_writer = create_summary_writer(model, train_loader,
                                                   log_dir)

    # create trainer and evaluator engines that handle model training and evaluation resp.
    trainer = create_supervised_trainer_multitask(model,
                                                  optimizer,
                                                  loss_fn=my_multi_task_loss,
                                                  loss_weights=loss_weights,
                                                  device=device)

    def make_evaluator():
        # each engine gets its own metric instances
        return create_supervised_evaluator_multitask(
            model,
            metrics={
                'mt_accuracy': MultiTaskAccuracy(),
                'mt_loss': MutliTaskLoss()
            },
            device=device,
            loss_weights=loss_weights)

    # Separate engines for the training and validation sets, so that checkpointing and
    # early stopping (attached to the validation engine only) never fire after a
    # training-set evaluation.
    evaluator = make_evaluator()
    val_evaluator = make_evaluator()

    # function to schedule learning rate if needed
    @trainer.on(Events.EPOCH_STARTED)
    def schedule_learning_rate(engine):
        if engine.state.epoch > n_epochs_freeze and n_cycle not in [None, 0] \
                and not getattr(trainer, 'scheduler_set', False):
            scheduler = LinearCyclicalScheduler(optimizer, 'lr',
                                                lr_cycle_start, lr_cycle_end,
                                                len(train_loader) * n_cycle)
            trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
            setattr(trainer, 'scheduler_set', True)

    # functions to write metrics during training
    desc = "ITERATION - loss: {:.3f}"
    pbar = tqdm.tqdm(initial=0,
                     leave=False,
                     total=len(train_loader),
                     desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # iteration index within the current epoch, 1-based
        iter_ = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter_ % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

        if launch_tensorboard:
            writer.add_scalar('training/loss', engine.state.output,
                              engine.state.iteration)

    def evaluate_and_log(eval_engine, loader, label, epoch, tb_writer):
        """Run `eval_engine` on `loader`, print its metrics, write them to `tb_writer` if given,
        and return the average loss."""
        eval_engine.run(loader)
        metrics = eval_engine.state.metrics
        age_l1_loss, gender_acc, race_acc = metrics['mt_accuracy']
        avg_nll = metrics['mt_loss']
        tqdm.tqdm.write(
            "{} Results - Epoch: {} Age L1-loss: {:.3f} ** Gender accuracy: {:.3f} "
            "** Race accuracy: {:.3f} ** Avg loss: {:.3f}".format(
                label, epoch, age_l1_loss, gender_acc, race_acc, avg_nll))
        if tb_writer is not None:
            tb_writer.add_scalar('avg_loss', avg_nll, epoch)
            tb_writer.add_scalar('age_l1_loss', age_l1_loss, epoch)
            tb_writer.add_scalar('gender_accuracy', gender_acc, epoch)
            tb_writer.add_scalar('race_accuracy', race_acc, epoch)
        return avg_nll

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        # kept for backward compatibility: the module-level history list of validation losses
        global val_loss
        pbar.refresh()
        epoch = engine.state.epoch
        # print metrics on training set
        evaluate_and_log(evaluator, train_loader, 'Training', epoch, writer)
        # print metrics on validation set
        avg_val_nll = evaluate_and_log(val_evaluator, val_loader, 'Validation', epoch, val_writer)
        val_loss.append(avg_val_nll)

        pbar.n = pbar.last_print_n = 0

    # Utility function for unfreezing frozen layer for finetuning
    @trainer.on(Events.EPOCH_STARTED)
    def unfreeze(engine):
        if engine.state.epoch == n_epochs_freeze:
            print('****Unfreezing frozen layers ... ***')
            for param in model.parameters():
                if not param.requires_grad:
                    param.requires_grad = True
                    optimizer.add_param_group({
                        'params': param,
                        "lr": lr_after_freeze
                    })
            count_parameters(model)

    # Negative validation loss of the run that just finished, so a lower loss scores higher.
    # Read from the engine's own metrics: the `val_loss` history is only appended after
    # `val_evaluator.run` returns, i.e. it would lag one epoch behind while these handlers fire.
    def get_val_loss(engine):
        return -engine.state.metrics['mt_loss']

    # callback to save the best model during training
    checkpointer = handlers.ModelCheckpoint(
        dirname=dirname,
        filename_prefix=filename_prefix,
        score_function=get_val_loss,
        score_name='val_loss',
        n_saved=n_saved,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True)

    # callback to stop training if no improvement is observed; it only sees validation
    # runs now, so `patience` counts epochs directly (no doubling needed)
    earlystop = handlers.EarlyStopping(patience, get_val_loss, trainer)

    # COMPLETED fires after the (single) epoch's metrics have been computed
    val_evaluator.add_event_handler(Events.COMPLETED, checkpointer, {
        'optimizer': optimizer,
        'model': model
    })
    val_evaluator.add_event_handler(Events.COMPLETED, earlystop)

    # optimizer and model that are in the backup_path, created from a previous run; never deleted
    original_files = (glob.glob(os.path.join(backup_path, '*.pth*'))
                      if backup_path is not None else [])

    # utility function to periodically copy best model to `backup_path` folder
    @trainer.on(Events.EPOCH_COMPLETED)
    def backup_checkpoints(engine):
        if backup_path is None or engine.state.epoch % backup_step != 0:
            return
        # get old model and optimizer files paths so that we can remove them after copying the newer ones
        old_files = glob.glob(os.path.join(backup_path, '*.pth'))

        # get new model and optimizer checkpoints
        new_files = glob.glob(os.path.join(dirname, '*.pth*'))
        if not new_files:
            return

        # copy new checkpoints from local checkpoint folder to the backup_path folder
        for f_ in new_files:
            shutil.copy2(f_, backup_path)

        # remove older checkpoints, but never one that was just overwritten by a
        # still-current checkpoint with the same file name
        new_names = {os.path.basename(f_) for f_ in new_files}
        for f_ in old_files:
            if f_ not in original_files and os.path.basename(f_) not in new_names:
                os.remove(f_)

    @trainer.on(Events.COMPLETED)
    def final_backup(_):
        if backup_path is not None:
            for f_ in glob.glob(os.path.join(dirname, '*.pth*')):
                shutil.copy2(f_, backup_path)

    # plot learning rate (first parameter group only)
    list_lr = [p['lr'] for p in optimizer.param_groups[:1]]
    list_steps = [0]

    @trainer.on(Events.ITERATION_COMPLETED)
    def track_learning_rate(engine):
        # same truthiness test as the final `if lr_plot:` below
        if lr_plot:
            list_steps.append(engine.state.iteration)
            list_lr.extend(p['lr'] for p in optimizer.param_groups[:1])

    @trainer.on(Events.EPOCH_COMPLETED)
    def add_histograms(engine):
        if launch_tensorboard:
            for name, param in model.named_parameters():
                writer.add_histogram(name,
                                     param.clone().cpu().data.numpy(),
                                     engine.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()
    if launch_tensorboard:
        writer.close()
        val_writer.close()

    if lr_plot:
        plot_lr(list_lr, list_steps)
Beispiel #3
0
def train_fold(fold, args):
    # loggers
    logging_logger = args.logging_logger
    if args.tb_log:
        tb_logger = args.tb_logger

    num_classes = utils.problem_class[args.problem_type]

    # init model
    model = eval(args.model)(in_channels=3, num_classes=num_classes, bn=False)
    model = nn.DataParallel(model, device_ids=args.device_ids).cuda()

    # transform for train/valid data
    train_transform, valid_transform = get_transform(args.model)

    # loss function
    loss_func = LossMulti(num_classes, args.jaccard_weight)
    if args.semi:
        loss_func_semi = LossMultiSemi(num_classes, args.jaccard_weight, args.semi_loss_alpha, args.semi_method)

    # train/valid filenames
    train_filenames, valid_filenames = utils.trainval_split(args.train_dir, fold)

    # DataLoader and Dataset args
    train_shuffle = True
    train_ds_kwargs = {
        'filenames': train_filenames,
        'problem_type': args.problem_type,
        'transform': train_transform,
        'model': args.model,
        'mode': 'train',
        'semi': args.semi,
    }

    valid_num_workers = args.num_workers
    valid_batch_size = args.batch_size
    if 'TAPNet' in args.model:
        # for TAPNet, cancel default shuffle, use self-defined shuffle in torch.Dataset instead
        train_shuffle = False
        train_ds_kwargs['batch_size'] = args.batch_size
        train_ds_kwargs['mf'] = args.mf
    if args.semi == True:
        train_ds_kwargs['semi_method'] = args.semi_method
        train_ds_kwargs['semi_percentage'] = args.semi_percentage

    # additional valid dataset kws
    valid_ds_kwargs = {
        'filenames': valid_filenames,
        'problem_type': args.problem_type,
        'transform': valid_transform,
        'model': args.model,
        'mode': 'valid',
    }

    if 'TAPNet' in args.model:
        # in validation, num_workers should be set to 0 for sequences
        valid_num_workers = 0
        # in validation, batch_size should be set to 1 for sequences
        valid_batch_size = 1
        valid_ds_kwargs['mf'] = args.mf

    # train dataloader
    train_loader = DataLoader(
        dataset=RobotSegDataset(**train_ds_kwargs),
        shuffle=train_shuffle, # set to False to disable pytorch dataset shuffle
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        pin_memory=True
    )
    # valid dataloader
    valid_loader = DataLoader(
        dataset=RobotSegDataset(**valid_ds_kwargs),
        shuffle=False, # in validation, no need to shuffle
        num_workers=valid_num_workers,
        batch_size=valid_batch_size, # in valid time. have to use one image by one
        pin_memory=True
    )

    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, 
    #     weight_decay=args.weight_decay, nesterov=True)    

    # ignite trainer process function
    def train_step(engine, batch):
        # set model to train
        model.train()
        # clear gradients
        optimizer.zero_grad()
        
        # additional params to feed into model
        add_params = {}
        inputs = batch['input'].cuda(non_blocking=True)
        with torch.no_grad():
            targets = batch['target'].cuda(non_blocking=True)
            # for TAPNet, add attention maps
            if 'TAPNet' in args.model:
                add_params['attmap'] = batch['attmap'].cuda(non_blocking=True)

        outputs = model(inputs, **add_params)

        loss_kwargs = {}

        if args.semi:
            loss_kwargs['labeled'] = batch['labeled']
            if args.semi_method == 'rev_flow':
                loss_kwargs['optflow'] = batch['optflow']
            loss = loss_func_semi(outputs, targets, **loss_kwargs)
        else:
            loss = loss_func(outputs, targets, **loss_kwargs)
        loss.backward()
        optimizer.step()

        return_dict = {
            'output': outputs,
            'target': targets,
            'loss_kwargs': loss_kwargs,
            'loss': loss.item(),
        }

        # for TAPNet, update attention maps after each iteration
        if 'TAPNet' in args.model:
            # output_classes and target_classes: <b, h, w>
            output_softmax_np = torch.softmax(outputs, dim=1).detach().cpu().numpy()
            # update attention maps
            train_loader.dataset.update_attmaps(output_softmax_np, batch['abs_idx'].numpy())
            return_dict['attmap'] = add_params['attmap']

        return return_dict
    
    # init trainer
    trainer = engine.Engine(train_step)

    # lr scheduler and handler
    # cyc_scheduler = optim.lr_scheduler.CyclicLR(optimizer, args.lr / 100, args.lr)
    # lr_scheduler = c_handlers.param_scheduler.LRScheduler(cyc_scheduler)
    # trainer.add_event_handler(engine.Events.ITERATION_COMPLETED, lr_scheduler)

    step_scheduler = optim.lr_scheduler.StepLR(optimizer,
        step_size=args.lr_decay_epochs, gamma=args.lr_decay)
    lr_scheduler = c_handlers.param_scheduler.LRScheduler(step_scheduler)
    trainer.add_event_handler(engine.Events.EPOCH_STARTED, lr_scheduler)


    @trainer.on(engine.Events.STARTED)
    def trainer_start_callback(engine):
        """Log the fold split, optionally resume from a checkpoint, and init the metric history.

        When ``args.resume`` is set, the checkpoint ``fold_<fold>_model_<epoch>.pth``
        with the highest epoch number in ``args.ckpt_dir`` is loaded into ``model``.
        ``engine.state.epoch`` is set to that epoch number so training continues from it.

        Raises:
            FileNotFoundError: if resuming and no matching checkpoint exists.
        """
        logging_logger.info('training fold {}, {} train / {} valid files'.format(
            fold, len(train_filenames), len(valid_filenames)))

        # resume training
        if args.resume:
            # ckpt for current fold: fold_<fold>_model_<epoch>.pth
            ckpt_pattern = re.compile(r'fold_%d_model_(\d+)\.pth' % fold)
            candidates = []
            # Path.glob returns a generator (not indexable) and yields Path
            # objects, so match on the file name string
            for ckpt_path in Path(args.ckpt_dir).glob('fold_%d_model_[0-9]*.pth' % fold):
                res = ckpt_pattern.fullmatch(ckpt_path.name)
                if res is not None:
                    candidates.append((int(res.group(1)), ckpt_path))
            if not candidates:
                raise FileNotFoundError(
                    'no checkpoint matching fold_%d_model_<epoch>.pth in %s'
                    % (fold, args.ckpt_dir))
            # glob order is arbitrary; resume from the highest recorded epoch
            ckpt_epoch, ckpt_filename = max(candidates, key=lambda c: c[0])
            # restore epoch
            engine.state.epoch = ckpt_epoch
            # load model state dict
            model.load_state_dict(torch.load(str(ckpt_filename)))
            logging_logger.info('restore model [{}] from epoch {}.'.format(args.model, engine.state.epoch))
        else:
            logging_logger.info('train model [{}] from scratch'.format(args.model))

        # record metrics history every epoch
        engine.state.metrics_records = {}


    @trainer.on(engine.Events.EPOCH_STARTED)
    def trainer_epoch_start_callback(engine):
        # log learning rate on pbar
        train_pbar.log_message('model: %s, problem type: %s, fold: %d, lr: %.5f, batch size: %d' % \
            (args.model, args.problem_type, fold, lr_scheduler.get_param(), args.batch_size))
        
        # for TAPNet, change dataset schedule to random after the first epoch
        if 'TAPNet' in args.model and engine.state.epoch > 1:
            train_loader.dataset.set_dataset_schedule("shuffle")


    @trainer.on(engine.Events.ITERATION_COMPLETED)
    def trainer_iter_comp_callback(engine):
        # logging_logger.info(engine.state.metrics)
        pass

    # monitor loss
    # running average loss
    train_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    train_ra_loss.attach(trainer, 'train_ra_loss')

    # monitor train loss over epoch
    if args.semi:
        train_loss = imetrics.Loss(loss_func_semi, output_transform=lambda x: (x['output'], x['target'], x['loss_kwargs']))
    else:
        train_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    train_loss.attach(trainer, 'train_loss')

    # progress bar
    train_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    train_metric_names = ['train_ra_loss']
    train_pbar.attach(trainer, metric_names=train_metric_names)

    # tensorboardX: log train info
    if args.tb_log:
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer, 'lr'), 
            event_name=engine.Events.EPOCH_STARTED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_iter', train_metric_names),
            event_name=engine.Events.ITERATION_COMPLETED)

        tb_logger.attach(trainer, log_handler=OutputHandler('train_epoch', ['train_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)

        tb_logger.attach(trainer,
             log_handler=WeightsScalarHandler(model, reduction=torch.norm),
             event_name=engine.Events.ITERATION_COMPLETED)

        # tb_logger.attach(trainer, log_handler=tb_log_train_vars, 
        #     event_name=engine.Events.ITERATION_COMPLETED)


    # ignite validator process function
    def valid_step(engine, batch):
        """Process function of the validation engine for a single batch.

        Runs the model without gradients, folds per-frame IoU/Dice records and a
        summed confusion matrix into ``engine.state.epoch_metrics`` (initialised
        on EPOCH_STARTED), and returns a dict read by the attached metrics,
        progress bar and TensorBoard handlers.
        """
        use_attmap = 'TAPNet' in args.model

        with torch.no_grad():
            model.eval()
            images = batch['input'].cuda(non_blocking=True)
            labels = batch['target'].cuda(non_blocking=True)

            # TAPNet receives the current attention maps as an extra model input.
            model_kwargs = (
                {'attmap': batch['attmap'].cuda(non_blocking=True)}
                if use_attmap else {}
            )

            logits = model(images, **model_kwargs)
            loss = loss_func(logits, labels)

            probs = torch.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            pred_frames = preds.cpu().numpy()
            label_frames = labels.cpu().numpy()

            # Per-frame metrics for this batch, each computed from the frame's
            # confusion matrix; the matrices are also summed for the batch.
            batch_iou, batch_dice = MetricRecord(), MetricRecord()
            batch_cm = np.zeros((num_classes, num_classes), dtype=np.uint32)
            for pred_frame, label_frame in zip(pred_frames, label_frames):
                frame_cm = calculate_confusion_matrix_from_arrays(pred_frame, label_frame, num_classes)
                batch_iou.update_record(calculate_iou(frame_cm))
                batch_dice.update_record(calculate_dice(frame_cm))
                batch_cm += frame_cm

            # Fold this batch into the epoch-level accumulators.
            epoch_metrics = engine.state.epoch_metrics
            epoch_metrics['confusion_matrix'] += batch_cm
            epoch_metrics['iou'].merge(batch_iou)
            epoch_metrics['dice'].merge(batch_dice)

            step_output = dict(
                loss=loss.item(),
                output=logits,
                output_argmax=preds,
                target=labels,
                iou=batch_iou,
                dice=batch_dice,
            )

            if use_attmap:
                # Push this batch's softmax maps back into the dataset so the next
                # pass uses them as attention maps.
                valid_loader.dataset.update_attmaps(probs.cpu().numpy(), batch['abs_idx'].numpy())
                step_output['attmap'] = model_kwargs['attmap']
                # TODO: for TAPNet, return internal self-learned attention maps

            return step_output


    # validator engine
    validator = engine.Engine(valid_step)

    # monitor loss
    valid_ra_loss = imetrics.RunningAverage(output_transform=
        lambda x: x['loss'], alpha=0.98)
    valid_ra_loss.attach(validator, 'valid_ra_loss')

    # monitor validation loss over epoch
    valid_loss = imetrics.Loss(loss_func, output_transform=lambda x: (x['output'], x['target']))
    valid_loss.attach(validator, 'valid_loss')
    
    # monitor <data> mean metrics
    valid_data_miou = imetrics.RunningAverage(output_transform=
        lambda x: x['iou'].data_mean()['mean'], alpha=0.98)
    valid_data_miou.attach(validator, 'mIoU')
    valid_data_mdice = imetrics.RunningAverage(output_transform=
        lambda x: x['dice'].data_mean()['mean'], alpha=0.98)
    valid_data_mdice.attach(validator, 'mDice')

    # show metrics on progress bar (after every iteration)
    valid_pbar = c_handlers.ProgressBar(persist=True, dynamic_ncols=True)
    valid_metric_names = ['valid_ra_loss', 'mIoU', 'mDice']
    valid_pbar.attach(validator, metric_names=valid_metric_names)


    # ## monitor ignite IoU (the same as iou we are using) ###
    # cm = imetrics.ConfusionMatrix(num_classes, 
    #     output_transform=lambda x: (x['output'], x['target']))
    # imetrics.IoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'iou')

    # # monitor ignite mean iou (over all classes even not exist in gt)
    # mean_iou = imetrics.mIoU(cm, 
    #     ignore_index=0
    #     ).attach(validator, 'mean_iou')


    @validator.on(engine.Events.STARTED)
    def validator_start_callback(engine):
        """Hook for the validator's STARTED event; intentionally does nothing."""

    @validator.on(engine.Events.EPOCH_STARTED)
    def validator_epoch_start_callback(engine):
        """Reset the epoch-level metric accumulators filled by ``valid_step``.

        ``iou`` / ``dice`` collect per-frame records; ``confusion_matrix`` sums
        the per-batch confusion matrices over the whole validation epoch.
        The summed matrix counts every validated pixel, so it uses a 64-bit
        counter: a uint32 count wraps around after ~4.29e9 pixels (a few
        thousand megapixel frames) and silently corrupts the IoU/Dice and the
        checkpoint score computed from it.
        """
        engine.state.epoch_metrics = {
            # per-frame metrics, calculated directly from each frame
            'iou': MetricRecord(),
            'dice': MetricRecord(),
            'confusion_matrix': np.zeros((num_classes, num_classes), dtype=np.uint64),
        }


    # per-iteration hook (currently unused; metrics are accumulated in valid_step)
    @validator.on(engine.Events.ITERATION_COMPLETED)
    def validator_iter_comp_callback(engine):
        """Hook for the validator's ITERATION_COMPLETED event; intentionally does nothing."""

    # evaluate after epoch finish
    @validator.on(engine.Events.EPOCH_COMPLETED)
    def validator_epoch_comp_callback(engine):
        """Log this validation epoch's metrics and store a summary on the trainer.

        Two different aggregates are logged:

        1. Per-class IoU/Dice from the confusion matrix accumulated over the
           whole epoch. The full matrix is passed (the original reference code
           used ``confusion_matrix[1:, 1:]``); which classes appear in the
           result is decided by ``calculate_iou`` / ``calculate_dice``.
        2. Mean/std of the per-frame IoU/Dice records collected by valid_step.

        Only aggregate 2 is written to
        ``trainer.state.metrics_records[trainer.state.epoch]``.
        """
        records = engine.state.epoch_metrics

        # 1. dataset-level metrics from the accumulated confusion matrix
        cm = records['confusion_matrix']
        class_ious = calculate_iou(cm)
        class_dices = calculate_dice(cm)
        iou_list = list(class_ious.values())
        dice_list = list(class_dices.values())

        logging_logger.info('mean IoU: %.3f, std: %.3f, for each class: %s' %
            (np.mean(iou_list), np.std(iou_list), class_ious))
        logging_logger.info('mean Dice: %.3f, std: %.3f, for each class: %s' %
            (np.mean(dice_list), np.std(dice_list), class_dices))

        # 2. statistics over the per-frame records
        iou_stats = records['iou'].data_mean()
        dice_stats = records['dice'].data_mean()

        logging_logger.info('data (%d) mean IoU: %.3f, std: %.3f' %
            (len(iou_stats['items']), iou_stats['mean'], iou_stats['std']))
        logging_logger.info('data (%d) mean Dice: %.3f, std: %.3f' %
            (len(dice_stats['items']), dice_stats['mean'], dice_stats['std']))

        trainer.state.metrics_records[trainer.state.epoch] = dict(
            miou=iou_stats['mean'], std_miou=iou_stats['std'],
            mdice=dice_stats['mean'], std_mdice=dice_stats['std'],
        )


    # log interal variables(attention maps, outputs, etc.) on validation
    def tb_log_valid_iter_vars(engine, logger, event_name):
        """TensorBoard handler: log one validation batch as image grids.

        The first grid stacks the argmax predictions and the targets with
        ``nrow`` equal to the batch size, so predictions fill the first row and
        targets the second. For TAPNet the input attention maps are logged as a
        second, normalized grid.
        """
        log_tag = 'valid_iter'
        step_output = engine.state.output
        per_row = step_output['output'].shape[0]

        label_maps = torch.cat([
            step_output['output_argmax'].unsqueeze(1),
            step_output['target'].unsqueeze(1),
        ])
        # normalize=False: keep the raw label values
        label_grid = tvutils.make_grid(label_maps, padding=2, normalize=False, nrow=per_row).cpu()
        logger.writer.add_image(tag='%s (outputs, targets)' % (log_tag), img_tensor=label_grid)

        if 'TAPNet' not in args.model:
            return

        # attention maps and other internal values
        internal_grid = tvutils.make_grid(
            torch.cat([step_output['attmap']]), padding=2, normalize=True, nrow=per_row
        ).cpu()
        logger.writer.add_image(tag='%s internal vals' % (log_tag), img_tensor=internal_grid)

    def tb_log_valid_epoch_vars(engine, logger, event_name):
        """TensorBoard handler: log the epoch's validation mIoU and mDice.

        Both values are the mean over the per-class results of
        ``calculate_iou`` / ``calculate_dice`` on the confusion matrix that
        was accumulated during the validation epoch.
        """
        confusion_matrix = engine.state.epoch_metrics['confusion_matrix']
        mean_ious = np.mean(list(calculate_iou(confusion_matrix).values()))
        mean_dices = np.mean(list(calculate_dice(confusion_matrix).values()))

        # The validator is re-run once per training epoch and its own
        # state.epoch restarts on every run, so it would pin every point to
        # the same step. Use the trainer's epoch as the x-axis instead.
        step = trainer.state.epoch
        logger.writer.add_scalar('mIoU', mean_ious, step)
        # Dice gets its own tag (it was previously written to 'mIoU',
        # clobbering the IoU curve).
        logger.writer.add_scalar('mDice', mean_dices, step)



    if args.tb_log:
        # log internal values
        tb_logger.attach(validator, log_handler=tb_log_valid_iter_vars, 
            event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=tb_log_valid_epoch_vars,
            event_name=engine.Events.EPOCH_COMPLETED)
        # tb_logger.attach(validator, log_handler=OutputHandler('valid_iter', valid_metric_names),
        #     event_name=engine.Events.ITERATION_COMPLETED)
        tb_logger.attach(validator, log_handler=OutputHandler('valid_epoch', ['valid_loss']),
            event_name=engine.Events.EPOCH_COMPLETED)


    # score function for model saving
    ckpt_score_function = lambda engine: \
        np.mean(list(calculate_iou(engine.state.epoch_metrics['confusion_matrix']).values()))
    # ckpt_score_function = lambda engine: engine.state.epoch_metrics['iou'].data_mean()['mean']
    
    ckpt_filename_prefix = 'fold_%d' % fold

    # model saving handler
    model_ckpt_handler = handlers.ModelCheckpoint(
        dirname=args.model_save_dir,
        filename_prefix=ckpt_filename_prefix, 
        score_function=ckpt_score_function,
        create_dir=True,
        require_empty=False,
        save_as_state_dict=True,
        atomic=True)


    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED, 
        handler=model_ckpt_handler,
        to_save={
            'model': model,
        })

    # early stop
    # trainer=trainer, but should be handled by validator
    early_stopping = handlers.EarlyStopping(patience=args.es_patience, 
        score_function=ckpt_score_function,
        trainer=trainer
        )

    validator.add_event_handler(event_name=engine.Events.EPOCH_COMPLETED,
        handler=early_stopping)


    # evaluate after epoch finish
    @trainer.on(engine.Events.EPOCH_COMPLETED)
    def trainer_epoch_comp_callback(engine):
        """After each training epoch, run one full pass of the validator.

        The validator's EPOCH_COMPLETED handlers (logging, checkpointing and
        early stopping) fire as part of this call.
        """
        loader = valid_loader
        validator.run(loader)

    trainer.run(train_loader, max_epochs=args.max_epochs)

    if args.tb_log:
        # close tb_logger
        tb_logger.close()

    return trainer.state.metrics_records
 def _init_early_stopping_handler(self) -> None:
     if self.train_cfg.early_stop:
         early_stop_handler = handlers.EarlyStopping(self.train_cfg.patience,
                                                     score_function=self.eval_func,
                                                     trainer=self.trainer)
         self.evaluator.add_event_handler(Events.COMPLETED, early_stop_handler)
Beispiel #5
0
        tqdm.write ("Training   :: Epoch {} Loss {:.2f}".format (TRAINER.state.epoch, np.log10(metrics['loss'])))
        L_TRAIN.append (metrics['loss'])
        PBAR.n = PBAR.last_print_n = 0

    @TRAINER.on (ie.Events.EPOCH_COMPLETED)
    def log_validation_results (TRAINER):
        """Evaluate on the validation loader after each epoch and record the metrics.

        Prints log10 of the validation loss and the accuracy in percent. Appends
        loss, accuracy, precision, recall and the confusion matrix to the
        L_EVAL / L_ACC / L_PRE / L_REC / L_CFM history lists, then rewinds
        the progress bar.
        """
        EVALUATOR.run (v_DataLoader)
        metrics = EVALUATOR.state.metrics
        tqdm.write ("Validation :: Epoch {} Loss {:.2f} Acc {:.2f}".format (TRAINER.state.epoch, np.log10(metrics['loss']), 100*metrics['acc']))
        L_EVAL.append (metrics['loss'])
        L_ACC.append (metrics['acc'])
        # L_PRE collects precision and L_REC recall (the two were swapped before).
        L_PRE.append (metrics['precision'])
        L_REC.append (metrics['recall'])
        L_CFM.append (metrics['cfm'])
        PBAR.n = PBAR.last_print_n = 0

    def loss_score (engine):
        """Early-stopping score: the negated evaluator loss, so lower loss scores higher."""
        loss = engine.state.metrics['loss']
        return -loss
    early_stopper = ih.EarlyStopping (patience=30,score_function=loss_score,trainer=TRAINER)
    EVALUATOR.add_event_handler (ie.Events.COMPLETED, early_stopper)
    #########################
    try:
        TRAINER.run (t_DataLoader, max_epochs=100)
        PBAR.close ()
    except KeyboardInterrupt:
        print ("Received keyboard interrupt")
    ######
    with open (os.path.join (MDIR, "losses1k.pkl"), 'wb') as lf:
        import pickle as pkl
        pkl.dump ([L_TRAIN, L_EVAL, L_ACC, L_PRE, L_REC, L_CFM], lf)
def run(path_to_model_script,
        epochs,
        log_interval,
        dataloaders,
        dirname='resnet_models',
        filename_prefix='resnet',
        n_saved=2,
        log_dir='../../fer2013/logs',
        launch_tensorboard=False,
        patience=10,
        resume_model=None,
        resume_optimizer=None,
        backup_step=1,
        backup_path=''):
    """Train the model from a model script with checkpointing and early stopping.

    Args:
        path_to_model_script: path of a Python script executed to obtain the
            objects ``my_model`` and ``optimizer``. The script is run with
            ``exec`` and must therefore be trusted.
        epochs: maximum number of training epochs.
        log_interval: iterations between progress-bar updates.
        dataloaders: mapping with ``'Training'`` and ``'PublicTest'`` loaders.
        dirname: directory where checkpoints are written.
        filename_prefix: prefix of the checkpoint file names.
        n_saved: number of best checkpoints kept by ModelCheckpoint.
        log_dir: TensorBoard log directory (used when ``launch_tensorboard``).
        launch_tensorboard: create summary writers and log scalars/histograms.
        patience: EarlyStopping patience, counted in validation runs.
        resume_model: optional path of a model state dict to load first.
        resume_optimizer: optional path of an optimizer state dict to load first.
        backup_step: mirror checkpoints to ``backup_path`` every this many epochs.
        backup_path: directory the checkpoints are mirrored to; mirroring is
            disabled when it is empty.
    """
    if launch_tensorboard:
        os.makedirs(log_dir, exist_ok=True)

    # Load model and optimizer from the script.
    # NOTE(review): exec() runs arbitrary code -- only pass trusted scripts.
    model_script = dict()
    with open(path_to_model_script) as f:
        exec(f.read(), model_script)

    model = model_script['my_model']
    optimizer = model_script['optimizer']

    if resume_model:
        model.load_state_dict(torch.load(resume_model))
    if resume_optimizer:
        # The optimizer state comes from its own file (previously resume_model
        # was loaded here by mistake).
        optimizer.load_state_dict(torch.load(resume_optimizer))

    train_loader, val_loader = dataloaders['Training'], dataloaders['PublicTest']

    if launch_tensorboard:
        writer, val_writer = create_summary_writer(model, train_loader,
                                                   log_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        F.cross_entropy,
                                        device=device)

    def _make_evaluator():
        """Build an evaluator reporting 'accuracy' and cross-entropy 'nll'."""
        return create_supervised_evaluator(model,
                                           metrics={
                                               'accuracy': Accuracy(),
                                               'nll': Loss(F.cross_entropy)
                                           },
                                           device=device)

    # `evaluator` scores the validation set and drives checkpointing and
    # early stopping. `train_evaluator` only reports training-set metrics.
    # Sharing one evaluator would save checkpoints scored by the *training*
    # loss and advance the early-stopping counter twice per epoch.
    evaluator = _make_evaluator()
    train_evaluator = _make_evaluator()

    def get_val_loss(engine):
        # Checkpoint/early-stopping scores are "higher is better": negate loss.
        return -engine.state.metrics['nll']

    checkpointer = handlers.ModelCheckpoint(dirname=dirname,
                                            filename_prefix=filename_prefix,
                                            score_function=get_val_loss,
                                            score_name='val_loss',
                                            n_saved=n_saved,
                                            create_dir=True,
                                            require_empty=False,
                                            save_as_state_dict=True)
    earlystop = handlers.EarlyStopping(patience, get_val_loss, trainer)

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'optimizer': optimizer,
        'model': model
    })
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, earlystop)

    desc = "ITERATION - loss: {:.3f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # 1-based iteration index within the current epoch
        iter_ = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter_ % log_interval == 0:
            pbar.desc = desc.format(engine.state.output)
            pbar.update(log_interval)

        if launch_tensorboard:
            writer.add_scalar('training/loss', engine.state.output,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        pbar.refresh()
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Training Results - Epoch: {}  Avg accuracy: {:.3f} Avg loss: {:.3f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        if launch_tensorboard:
            writer.add_scalar('avg_loss', avg_nll, engine.state.epoch)
            writer.add_scalar('avg_accuracy', avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # Also fires the checkpoint and early-stopping handlers of `evaluator`.
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_nll = metrics['nll']
        tqdm.write(
            "Validation Results - Epoch: {}  Avg accuracy: {:.3f} Avg loss: {:.3f}"
            .format(engine.state.epoch, avg_accuracy, avg_nll))

        pbar.n = pbar.last_print_n = 0

        if launch_tensorboard:
            val_writer.add_scalar('avg_loss', avg_nll, engine.state.epoch)
            val_writer.add_scalar('avg_accuracy', avg_accuracy,
                                  engine.state.epoch)

    # Checkpoints already in backup_path before this run (e.g. from a previous
    # run); they are never deleted.
    original_files = (set(glob.glob(os.path.join(backup_path, '*.pth*')))
                      if backup_path else set())

    @trainer.on(Events.EPOCH_COMPLETED)
    def backup_checkpoints(engine):
        """Mirror the current checkpoints of ``dirname`` into ``backup_path``.

        Skipped when ``backup_path`` is empty (copying to '' would raise).
        Backup files that are no longer current are removed afterwards,
        except those present before this run started.
        """
        if not backup_path or engine.state.epoch % backup_step != 0:
            return

        # Snapshot the backup folder before copying, so superseded checkpoints
        # can be removed once the new ones are in place.
        old_files = glob.glob(os.path.join(backup_path, '*.pth*'))
        new_files = glob.glob(os.path.join(dirname, '*.pth*'))
        if not new_files:
            return

        for f_ in new_files:
            shutil.copy2(f_, backup_path)

        # A backup whose name matches a current checkpoint was just overwritten
        # by the copy above and must be kept.
        current_names = {os.path.basename(f_) for f_ in new_files}
        for f_ in old_files:
            if f_ not in original_files and os.path.basename(f_) not in current_names:
                os.remove(f_)

    if launch_tensorboard:

        @trainer.on(Events.EPOCH_COMPLETED)
        def add_histograms(engine):
            for name, param in model.named_parameters():
                writer.add_histogram(name,
                                     param.detach().cpu().numpy(),
                                     engine.state.epoch)

    try:
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        # Release the progress bar and writers even if training is interrupted.
        pbar.close()
        if launch_tensorboard:
            writer.close()
            val_writer.close()
Beispiel #7
0
def run(path_to_model_script,
        epochs,
        log_interval,
        dataloaders,
        dirname='resnet_models',
        filename_prefix='resnet',
        n_saved=2,
        log_dir='../../fer2013/logs',
        launch_tensorboard=False,
        patience=10):

    # if launch_tensorboard:
    #     os.system('pkill tensorboard')
    #     os.system('tensorboard --logdir {} --host 0.0.0.0 --port 6006 &'.format(log_dir))
    #     os.system("npm install -g localtunnel")
    #     os.system('lt --port 6006 >> /content/url.txt 2>&1 &')
    #     os.system('cat /content/url.txt')

    # Get the model, optimizer and dataloaders from script
    model_script = dict()
    with open(path_to_model_script) as f:
        exec(f.read(), model_script)

    model = model_script['my_model']
    optimizer = model_script['optimizer']

    train_loader, val_loader = dataloaders['train'], dataloaders['valid']

    if launch_tensorboard:
        writer, val_writer = create_summary_writer(model, train_loader,
                                                   log_dir)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    trainer = create_supervised_trainer_multitask(model,
                                                  optimizer,
                                                  loss_fn=my_multi_task_loss,
                                                  device=device)
    evaluator = create_supervised_evaluator_multitask(model,
                                                      metrics={
                                                          'mt_accuracy':
                                                          MultiTaskAccuracy(),
                                                          'mt_loss':
                                                          MutliTaskLoss()
                                                      },
                                                      device=device)

    def get_val_loss(engine):
        """Checkpoint / early-stopping score: the negated multi-task loss."""
        loss = engine.state.metrics['mt_loss']
        return -loss

    checkpointer = handlers.ModelCheckpoint(dirname=dirname,
                                            filename_prefix=filename_prefix,
                                            score_function=get_val_loss,
                                            score_name='val_loss',
                                            n_saved=n_saved,
                                            create_dir=True,
                                            require_empty=False,
                                            save_as_state_dict=True)
    earlystop = handlers.EarlyStopping(patience, get_val_loss, trainer)
    #
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'optimizer': optimizer,
        'model': model
    })
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, earlystop)

    desc = "ITERATION - loss: {:.3f}"
    pbar = tqdm(initial=0,
                leave=False,
                total=len(train_loader),
                desc=desc.format(0))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """Advance the progress bar every ``log_interval`` iterations; optionally log the loss."""
        batch_loss = engine.state.output
        # 1-based iteration index within the current epoch
        iter_in_epoch = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter_in_epoch % log_interval == 0:
            pbar.desc = desc.format(batch_loss)
            pbar.update(log_interval)

        if launch_tensorboard:
            writer.add_scalar('training/loss', batch_loss, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        """Evaluate on the training loader after each epoch and print the metrics.

        A dedicated evaluator without event handlers is used here. The shared
        ``evaluator`` carries the ModelCheckpoint and EarlyStopping handlers.
        Running it on training data would save checkpoints scored by the
        training loss and advance the early-stopping counter twice per epoch.
        """
        pbar.refresh()
        train_evaluator = create_supervised_evaluator_multitask(
            model,
            metrics={
                'mt_accuracy': MultiTaskAccuracy(),
                'mt_loss': MutliTaskLoss()
            },
            device=device)
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        age_l1_loss, gender_acc, race_acc = metrics['mt_accuracy']
        avg_nll = metrics['mt_loss']
        tqdm.write(
            "Training Results - Epoch: {} Age L1-loss: {:.3f} ** Gender accuracy: {:.3f} "
            "** Race accuracy: {:.3f} ** Avg loss: {:.3f}".format(
                engine.state.epoch, age_l1_loss, gender_acc, race_acc,
                avg_nll))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        """Evaluate on the validation loader, print the metrics and rewind the progress bar.

        Running ``evaluator`` also fires its checkpoint and early-stopping handlers.
        """
        evaluator.run(val_loader)
        val_metrics = evaluator.state.metrics
        age_l1, gender_acc, race_acc = val_metrics['mt_accuracy']
        summary = (
            "Validation Results - Epoch: {} Age L1-loss: {:.3f} ** Gender accuracy: {:.3f} **"
            " Race accuracy: {:.3f} ** Avg loss: {:.3f}"
        ).format(engine.state.epoch, age_l1, gender_acc, race_acc, val_metrics['mt_loss'])
        tqdm.write(summary)

        # start the iteration progress bar from zero for the next epoch
        pbar.n = pbar.last_print_n = 0

    # if launch_tensorboard:
    #     @trainer.on(Events.EPOCH_COMPLETED)
    #     def add_histograms(engine):
    #         for name, param in model.named_parameters():
    #             writer.add_histogram(name, param.clone().cpu().data.numpy(), engine.state.epoch)

    trainer.run(train_loader, max_epochs=epochs)
    pbar.close()