Example No. 1
def main():
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Loading the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()
    unsup_criterion = UDALoss()

    # Initialize the TensorBoard logger (currently errors on the 37 server).
    tb_logger = TensorboardLogger(log_dir=log_path)

    # Training and validation progress bars.
    pbar = ProgressBar(persist=True, desc="Training")
    pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
    pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
    pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
    pbar_valid_airtel_payments = ProgressBar(persist=True, desc="Validation Airtel Payments")
    pbar_valid_airtel_hinglish = ProgressBar(persist=True, desc="Validation Airtel Hinglish")

    # load timer and best meter to keep track of state params
    timer = Timer(average=True)
    best_meter = params.get('best_stats', BestMeter())

    # load all the train data
    logger.info('Beginning to load Datasets')
    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(
        lmdb_commonvoice_root_path, 'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'train-labelled-en')
    trainAirtelHinglishPath = os.path.join(lmdb_airtel_hinglish_root_path, 'train-labelled-en')

    # test data
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'test-labelled-en')
    testAirtelHinglishPath = os.path.join(lmdb_airtel_hinglish_root_path, 'test-labelled-en')

    # dev-other is used as the unsupervised (unlabelled) data here
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    # form data loaders
    train_clean = lmdbMultiDataset(
        roots=[trainCleanPath, trainOtherPath, trainCommonVoicePath, trainAirtelPath, trainAirtelPaymentsPath, trainAirtelHinglishPath], transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath], transform=image_train_transform)

    test_clean = lmdbMultiDataset(roots=[testCleanPath], transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath], transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath], transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath], transform=image_val_transform)
    test_hinglish_airtel = lmdbMultiDataset(roots=[testAirtelHinglishPath], transform=image_val_transform)

    logger.info(
        f'Loaded Train & Test Datasets, train_labelled={len(train_clean)}, train_unlabelled={len(train_other)}, test_clean={len(test_clean)}, test_other={len(test_other)}, test_airtel={len(test_airtel)}, test_payments_airtel={len(test_payments_airtel)}, test_hinglish_airtel={len(test_hinglish_airtel)} examples')

    def train_update_function(engine, _):
        optimizer.zero_grad()

        # Supervised gt, pred
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labbeled)

        imgs_sup = imgs_sup.to(device)
        # with torch.autograd.detect_anomaly():
        probs_sup = model(imgs_sup)
        # Unsupervised gt, pred
        # imgs_unsup, augmented_imgs_unsup = next(engine.state.train_loader_unlabbeled)
        # with torch.no_grad():
        #     probs_unsup = model(imgs_unsup.to(device))
        # probs_aug_unsup = model(augmented_imgs_unsup.to(device))
        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths, input_lengths)
        # unsup_loss = unsup_criterion(probs_unsup, probs_aug_unsup)

        # Blend supervised and unsupervised losses till unsupervision_warmup_epoch
        # alpha = get_alpha(engine.state.epoch)
        # final_loss = ((1 - alpha) * sup_loss) + (alpha * unsup_loss)

        # final_loss = sup_loss
        sup_loss.backward()
        optimizer.step()

        return sup_loss.item()
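
    # Hypothetical sketch (not in the original source): the commented-out UDA
    # blending above calls get_alpha(); assuming a simple linear ramp up to
    # config.unsupervision_warmup_epoch, it could look like
    #   def get_alpha(epoch):
    #       return min(1.0, epoch / config.unsupervision_warmup_epoch)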

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx+length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader_labbeled_loader = torch.utils.data.DataLoader(
        train_clean, batch_size=train_batch_size, shuffle=True, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    train_loader_unlabbeled_loader = torch.utils.data.DataLoader(
        train_other, batch_size=train_batch_size * 4, shuffle=True, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean, batch_size=torch.cuda.device_count(), shuffle=False, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other, batch_size=torch.cuda.device_count(), shuffle=False, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel, batch_size=torch.cuda.device_count(), shuffle=False, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel, batch_size=torch.cuda.device_count(), shuffle=False, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel_hinglish = torch.utils.data.DataLoader(
        test_hinglish_airtel, batch_size=torch.cuda.device_count(), shuffle=False, num_workers=config.workers, pin_memory=True, collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)
    evaluator_airtel_hinglish = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    iteration_log_step = int(0.33 * len(train_loader_labbeled_loader))
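    # Evaluation is triggered roughly every third of a pass over the labelled
    # loader; iteration_completed() below bumps engine.state.epoch at each of these points.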
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)
        metric.attach(evaluator_airtel_hinglish, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=config.lr_gamma, patience=int(
        config.epochs * 0.05), verbose=True, threshold_mode="abs", cooldown=int(config.epochs * 0.025), min_lr=1e-5)
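    # The plateau scheduler is stepped with the "other" validation WER inside save_checkpoints() below.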

    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", output_transform=lambda loss: {'loss': loss}),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_clean,
                     log_handler=OutputHandler(tag="validation_clean", metric_names=[
                                               "wer", "cer"], another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_other,
                     log_handler=OutputHandler(tag="validation_other", metric_names=[
                                               "wer", "cer"], another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel,
                     log_handler=OutputHandler(tag="validation_airtel", metric_names=[
                                               "wer", "cer"], another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel_payments,
                     log_handler=OutputHandler(tag="validation_airtel_payments", metric_names=[
                                               "wer", "cer"], another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(evaluator_airtel_hinglish,
                     log_handler=OutputHandler(tag="validation_airtel_highlish", metric_names=[
                                               "wer", "cer"], another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator_clean, [
                      'wer', 'cer'], event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
    pbar_valid_other.attach(evaluator_other, [
                            'wer', 'cer'], event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
    pbar_valid_airtel.attach(evaluator_airtel, [
                            'wer', 'cer'], event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
    pbar_valid_airtel_payments.attach(evaluator_airtel_payments, [
                            'wer', 'cer'], event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
    pbar_valid_airtel_hinglish.attach(evaluator_airtel_hinglish, [
                            'wer', 'cer'], event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader_labbeled_loader)
        # engine.state.train_loader_unlabbeled = iter(train_loader_unlabbeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            evaluator_airtel_hinglish.run(test_loader_airtel_hinglish)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @evaluator_other.on(Events.EPOCH_COMPLETED)
    def save_checkpoints(engine):
        metrics = engine.state.metrics
        wer = metrics['wer']
        cer = metrics['cer']
        epoch = trainer.state.epoch
        scheduler.step(wer)
        save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
        best_meter.update(wer, cer, epoch)

    trainer.run(train_loader_labbeled_loader, max_epochs=epochs)
    tb_logger.close()
Example No. 2
model = C2F3(n_classes=10)
dataset = MNISTDataset('train_small.csv')
train_loader = DataLoader(dataset, batch_size=1024)
#eval_loader = DataLoader(dataset, batch_size=1024, shuffle=False)

criterion = nn.MSELoss(reduction='sum')    #nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(model.parameters(), lr=1e-4)    #torch.optim.SGD(model.parameters(), lr=1e-4)

#%% Initialize trainer and handlers which print info after iterations and epochs
trainer = create_supervised_trainer(model, optimizer, criterion)
evaluator = create_supervised_evaluator(model, 
                                        metrics={
                                                #'accuracy': CategoricalAccuracy(),
                                                'mse': Loss(criterion)
                                        })
timer = Timer(average=True)
timer.attach(trainer, start=Events.EPOCH_STARTED, pause=Events.EPOCH_COMPLETED,
             resume=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED)
    
@trainer.on(Events.ITERATION_COMPLETED)
def log_on_iteration_completed(trainer):
    print('Epoch [{}] Loss: {:.2f}'.format(
            trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_on_epoch_completed(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    #print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
    #      .format(trainer.state.epoch, metrics['accuracy'], metrics['mse']))
    print('Epoch [{}] Time: {:.2f} sec'.format(trainer.state.epoch, timer.value()))
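
# The snippet above wires up the evaluator and logging handlers but stops short
# of launching training; a minimal, assumed launch call (the epoch count is a placeholder):
trainer.run(train_loader, max_epochs=10)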
Example No. 3
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch, device_id,
             train_camstyle_loader):

    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    epochs = cfg.SOLVER.MAX_EPOCHS
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        device_id=device_id)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP':
            R1_mAP(num_query,
                   True,
                   False,
                   max_rank=50,
                   feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device,
        device_id=device_id)
    if device_id == 0:
        checkpointer = ModelCheckpoint(output_dir,
                                       cfg.MODEL.NAME,
                                       checkpoint_period,
                                       n_saved=10,
                                       require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()
        })

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
    RunningAverage(output_transform=lambda x: x[2]).attach(
        trainer, 'data_ratio')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    def cycle(iterable):
        while True:
            for i in iterable:
                yield i

    train_loader_iter = cycle(train_loader)
    train_camstyle_loader_iter = cycle(train_camstyle_loader)

    @trainer.on(Events.ITERATION_STARTED)
    def generate_batch(engine):
        current_iter = engine.state.iteration
        batch = next(train_loader_iter)
        camstyle_batch = next(train_camstyle_loader_iter)
        engine.state.batch = [batch, camstyle_batch]

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, ratio of data/cam_data: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        engine.state.metrics['data_ratio'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    num_iters = len(train_loader)
    data = list(range(num_iters))

    trainer.run(data, max_epochs=epochs)
Example No. 4
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./",
                                         "checkpoint",
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)

    #trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler,
    #                          to_save={'model': model, 'optimizer': optimizer,
    #                                   'annealers': (sigma_scheme.data, mu_scheme.data)})

    timer = Timer(average=True).attach(trainer,
                                       start=Events.EPOCH_STARTED,
                                       resume=Events.ITERATION_STARTED,
                                       pause=Events.ITERATION_COMPLETED,
                                       step=Events.ITERATION_COMPLETED)

    # TensorBoard writer
    writer = SummaryWriter(log_dir=args.log_dir)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        for key, value in engine.state.metrics.items():
            writer.add_scalar("training/{}".format(key), value,
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        print("Epoch Completed save_images")
Example No. 5
def add_stdout_handler(trainer, validator=None):
    """
    This adds the following handler to the trainer engine, and also sets up
    Timers:

    - log_epoch_to_stdout: This logs the results of a model after it has trained
      for a single epoch on both the training and validation set. The output typically
      looks like this:

      .. code-block:: none

            EPOCH SUMMARY
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            - Epoch number: 0010 / 0010
            - Training loss:   0.583591
            - Validation loss: 0.137209
            - Epoch took: 00:00:03
            - Time since start: 00:00:32
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
            Saving to test.
            Output @ tests/local/trainer
    
    Args:
        trainer (ignite.Engine): Engine for trainer

        validator (ignite.Engine, optional): Engine for validation. 
          Defaults to None.
    """
    # Set up timers for overall time taken and each epoch
    overall_timer = Timer(average=False)
    overall_timer.attach(trainer,
                         start=Events.STARTED, pause=Events.COMPLETED)

    epoch_timer = Timer(average=False)
    epoch_timer.attach(
        trainer, start=Events.EPOCH_STARTED,
        pause=ValidationEvents.VALIDATION_COMPLETED
    )

    @trainer.on(ValidationEvents.VALIDATION_COMPLETED)
    def log_epoch_to_stdout(trainer):
        epoch_time = epoch_timer.value()
        epoch_time = time.strftime(
            "%H:%M:%S", time.gmtime(epoch_time))
        overall_time = overall_timer.value()
        overall_time = time.strftime(
            "%H:%M:%S", time.gmtime(overall_time))

        epoch_number = trainer.state.epoch
        total_epochs = trainer.state.max_epochs

        try:
            validation_loss = (
                f"{trainer.state.epoch_history['validation/loss'][-1]:04f}")
        except Exception:
            validation_loss = 'N/A'

        train_loss = trainer.state.epoch_history['train/loss'][-1]
        saved_model_path = trainer.state.saved_model_path

        logging_str = (
            f"\n\n"
            f"EPOCH SUMMARY \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"- Epoch number: {epoch_number:04d} / {total_epochs:04d} \n"
            f"- Training loss:   {train_loss:04f} \n"
            f"- Validation loss: {validation_loss} \n"
            f"- Epoch took: {epoch_time} \n"
            f"- Time since start: {overall_time} \n"
            f"~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ \n"
            f"Saving to {saved_model_path}. \n"
            f"Output @ {trainer.state.output_folder} \n"
        )

        logging.info(logging_str)
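
# Minimal, hypothetical usage sketch (not part of the original source): assuming
# `trainer` is an ignite Engine wired to fire ValidationEvents and to populate
# state.epoch_history, the handler is attached with a single call:
#
#     trainer = Engine(train_step)
#     add_stdout_handler(trainer, validator=None)
#     trainer.run(train_loader, max_epochs=10)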
Example No. 6
def train(run_name, forward_func, model, train_set, val_set, n_epochs,
          batch_size, lr):

    # Make the run directory
    save_dir = os.path.join('training/simple/saved_runs', run_name)
    if run_name == 'debug':
        shutil.rmtree(save_dir, ignore_errors=True)
    os.mkdir(save_dir)

    model = model.to(device)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_set,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training step
    def step(engine, batch):
        model.train()

        if isinstance(batch, list):
            batch = [tensor.to(device) for tensor in batch]
        else:
            batch = batch.to(device)
        x_gen, x_q, _ = forward_func(model, batch)

        loss = F.l1_loss(x_gen, x_q)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        return {'L1': loss}

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ['L1']
    RunningAverage(output_transform=lambda x: x['L1']).attach(trainer, 'L1')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer,
                               start=Events.EPOCH_STARTED,
                               resume=Events.ITERATION_STARTED,
                               pause=Events.ITERATION_COMPLETED,
                               step=Events.ITERATION_COMPLETED)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'),
                                         type(model).__name__,
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'model': model,
                                  'optimizer': optimizer
                              })

    # TensorBoard writer
    writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs'))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_metrics(engine):
        if engine.state.iteration % 100 == 0:
            for metric, value in engine.state.metrics.items():
                writer.add_scalar('training/{}'.format(metric), value,
                                  engine.state.iteration)

    def save_images(engine, batch):
        x_gen, x_q, r = forward_func(model, batch)
        r_dim = r.shape[1]
        if isinstance(model, SimpleVVGQN):
            r = (r + 1) / 2
        r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim)))

        x_gen = x_gen.detach().cpu().float()
        r = r.detach().cpu().float()

        writer.add_image('representation', make_grid(r), engine.state.epoch)
        writer.add_image('generation', make_grid(x_gen), engine.state.epoch)
        writer.add_image('query', make_grid(x_q), engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        with torch.no_grad():
            batch = next(iter(val_loader))
            if isinstance(batch, list):
                batch = [tensor.to(device) for tensor in batch]
            else:
                batch = batch.to(device)
            x_gen, x_q, r = forward_func(model, batch)

            loss = F.l1_loss(x_gen, x_q)

            writer.add_scalar('validation/L1', loss.item(), engine.state.epoch)

            save_images(engine, batch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            import warnings
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    start_time = time.time()
    trainer.run(train_loader, n_epochs)
    writer.close()
    end_time = time.time()
    print('Total training time: {}'.format(
        timedelta(seconds=end_time - start_time)))
Example No. 7
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):

    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)
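    # fixed_noise is reused every epoch so the generated samples saved by
    # save_fake_example() below are directly comparable across epochs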

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(trainer, "errG")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = ["iteration",] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration),] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration % len(loader)}/{len(loader)}]"
        for name, value in zip(columns, values):
            message += f" | {name}: {value}"

        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"netG": netG, "netD": netD}
    )

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]")
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import matplotlib.pyplot as plt
            import pandas as pd

        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            create_plots(engine)
            checkpoint_handler(engine, {"netG_exception": netG, "netD_exception": netD})

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
Example No. 8
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    if device == "cuda":
        torch.cuda.set_device(cfg.MODEL.CUDA)

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    evaluator = create_supervised_evaluator(
        model, metrics={'r1_mAP': R1_mAP(num_query)}, device=device)
    checkpointer = ModelCheckpoint(dirname=output_dir,
                                   filename_prefix=cfg.MODEL.NAME,
                                   n_saved=None,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
Example No. 9
def main():
    args = parse_args()

    logger.info('Num GPU: {}'.format(num_gpus))
    logger.info('Load Dataset')
    data = get_dataset(args.dataset, args.data_root, args.batch_size)
    data1, _ = data['train'][0]

    dims = list(data1.shape)
    param = dict(
        zdim=args.zdim,
        hdim=args.hdim,
        quant=not args.no_quantization,
        layers=args.layers,
        sigma=args.sigma,
    )
    model, optimizer = get_model(args.model, args.learning_rate, param, *dims)

    model = torch.nn.DataParallel(model) if num_gpus > 1 else model
    model.to(device)
    logger.info(model)

    kwargs = {
        'pin_memory': True if use_gpu else False,
        'shuffle': True,
        'num_workers': num_gpus * 4
    }

    logdir = get_logdir_name(args, param)
    logger.info('Log Dir: {}'.format(logdir))
    writer = SummaryWriter(logdir)

    os.makedirs(logdir, exist_ok=True)

    train_loader = DataLoader(data['train'], args.batch_size * num_gpus,
                              **kwargs)
    kwargs['shuffle'] = True
    test_loader = DataLoader(data['test'], args.batch_size * num_gpus,
                             **kwargs)

    if not args.no_quantization:
        q = Quantization(device=device)
        # raise NotImplementedError('It is using sigmoid now')
    else:
        q = Range()

    sigma_default = args.sigma * torch.ones(1, args.layers, 1, 1, 1)
    if use_gpu:
        sigma_default = sigma_default.cuda()
    else:
        sigma_default = sigma_default.cpu()

    def get_recon_error(x, sigma, x_mu_k, log_ms_k, recon):
        batch, *xdims = x.shape
        n = Normal(x_mu_k, sigma)
        log_x_mu = n.log_prob(x.view(batch, 1, *xdims))
        log_mx = log_x_mu + log_ms_k
        ll = torch.logsumexp(log_mx, dim=1)  # numerically stable log-sum-exp over the mixture dimension
        return -ll.sum(dim=[1, 2, 3]).mean()

    def step(engine, batch):
        model.train()
        x, _ = batch
        x = x.to(device)
        x = q.preprocess(x)

        recon, recon_k, x_mu_k, log_ms_k, kl_m, kl_c = model(x)

        nll = get_recon_error(x, sigma_default, x_mu_k, log_ms_k, recon)
        kl_m = kl_m.sum(dim=[1, 2, 3, 4]).mean()
        kl_c = kl_c.sum(dim=[1, 2, 3, 4]).mean()
        optimizer.zero_grad()

        nll_ema = engine.global_info['nll_ema']
        kl_ema = engine.global_info['kl_ema']
        beta = engine.global_info['beta']

        nll_ema = get_ema(nll.detach(), nll_ema, args.geco_alpha)
        kl_ema = get_ema((kl_m + kl_c).detach(), kl_ema, args.geco_alpha)

        loss = nll + beta * (kl_c + kl_m)
        elbo = -loss
        loss.backward()
        optimizer.step()

        # GECO update
        n_pixels = x.shape[1] * x.shape[2] * x.shape[3]
        goal = args.geco_goal * n_pixels
        geco_lr = args.geco_lr
        beta = geco_beta_update(beta,
                                nll_ema,
                                goal,
                                geco_lr,
                                speedup=args.geco_speedup)

        engine.global_info['nll_ema'] = nll_ema
        engine.global_info['kl_ema'] = kl_ema
        engine.global_info['beta'] = beta

        lr = optimizer.param_groups[0]['lr']
        ret = {
            'elbo': elbo.item(),
            'nll': nll.item(),
            'kl_m': kl_m.item(),
            'kl_c': kl_c.item(),
            'lr': lr,
            'sigma': args.sigma,
            'beta': beta
        }
        return ret

    trainer = Engine(step)
    trainer.global_info = {
        'nll_ema': None,
        'kl_ema': None,
        'beta': torch.tensor(args.geco_init).to(device)
    }
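    # global_info persists the GECO EMA statistics and beta across iterations;
    # step() reads and updates these values every batch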
    metric_names = ['elbo', 'nll', 'kl_m', 'kl_c', 'lr', 'sigma', 'beta']

    RunningAverage(output_transform=lambda x: x['elbo']).attach(
        trainer, 'elbo')
    RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')
    RunningAverage(output_transform=lambda x: x['kl_m']).attach(
        trainer, 'kl_m')
    RunningAverage(output_transform=lambda x: x['kl_c']).attach(
        trainer, 'kl_c')
    RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr')
    RunningAverage(output_transform=lambda x: x['sigma']).attach(
        trainer, 'sigma')
    RunningAverage(output_transform=lambda x: x['beta']).attach(
        trainer, 'beta')

    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer)

    add_events(trainer, model, writer, logdir, args.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()

        val_elbo = 0
        val_kl_m = 0
        val_kl_c = 0
        val_nll = 0

        beta = engine.global_info['beta']
        with torch.no_grad():
            for i, (x, _) in enumerate(test_loader):
                x = x.to(device)
                x_processed = q.preprocess(x)
                recon_processed, recon_k_processed, x_mu_k, log_ms_k, kl_m, kl_c = model(
                    x_processed)
                nll = get_recon_error(x_processed, sigma_default, x_mu_k,
                                      log_ms_k, recon_processed)
                kl_m = kl_m.sum(dim=[1, 2, 3, 4]).mean()
                kl_c = kl_c.sum(dim=[1, 2, 3, 4]).mean()
                loss = nll + beta * (kl_m + kl_c)
                elbo = -loss

                val_elbo += elbo
                val_kl_m += kl_m
                val_kl_c += kl_c
                val_nll += nll
                if i == 0:
                    cat = []
                    max_col = (args.layers + 2)
                    for x1, mu1, mu1_k, l_k, c_k in zip(
                            x_processed, recon_processed, recon_k_processed,
                            log_ms_k, x_mu_k):
                        # What a lazy way..
                        cat.extend([x1, mu1])  # Recon per layer
                        cat.extend(mu1_k)
                        cat.extend(x1.new_zeros([2, 3, 64,
                                                 64]))  # Masks per layer
                        cat.extend(
                            q.preprocess(l_k.exp().expand(
                                args.layers, 3, 64, 64)))
                        cat.extend(x1.new_zeros([2, 3, 64,
                                                 64]))  #  components per layer
                        cat.extend(c_k)
                        if len(cat) > (max_col * 7 * 3):
                            break
                    cat = torch.stack(cat)
                    #if cat.shape[0] > max_col * 3:
                    #    cat = cat[:max_col * 3]
                    cat = q.postprocess(cat)
                    writer.add_image(
                        '{}/layers'.format(args.dataset),
                        make_grid(cat.detach().cpu(), nrow=max_col),
                        engine.state.iteration)
            val_elbo /= len(test_loader)
            val_kl_m /= len(test_loader)
            val_kl_c /= len(test_loader)
            val_nll /= len(test_loader)
            writer.add_scalar('val/elbo', val_elbo.item(),
                              engine.state.iteration)
            writer.add_scalar('val/beta', beta.item(), engine.state.iteration)
            writer.add_scalar('val/kl_m', val_kl_m.item(),
                              engine.state.iteration)
            writer.add_scalar('val/kl_c', val_kl_c.item(),
                              engine.state.iteration)
            writer.add_scalar('val/nll', val_nll.item(),
                              engine.state.iteration)
            print('{:3d} /{:3d} : ELBO: {:.4f}, KL-M: {:.4f}, '
                  'KL-C: {:.4f} NLL: {:.4f}'.format(engine.state.epoch,
                                                    engine.state.max_epochs,
                                                    val_elbo, val_kl_m,
                                                    val_kl_c, val_nll))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handler_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            logger.warning('KeyboardInterrupt caught. Exiting gracefully.')
        else:
            raise e

    logger.info(
        'Start training. Max epoch = {}, Batch = {}, # Trainset = {}'.format(
            args.epoch, args.batch_size, len(data['train'])))
    trainer.run(train_loader, args.epoch)
    logger.info('Done training')
    writer.close()
Example No. 10
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr,
         n_workers, cuda, n_init_batches, warmup_steps, output_dir,
         saved_optimizer, warmup, fresh, gpuid):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:'+str(gpuid)

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # LambdaLR already scales the optimizer's base lr
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    
    # set logging option
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(message)s')
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)

    hdlr = logging.FileHandler(output_dir + 'losses.log')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    
    writer = SummaryWriter(output_dir + 'losses')

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)
            
        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'model': model, 'optimizer': optimizer})

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach(evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach(evaluator, 'nll')
        monitoring_metrics.extend(['loss_classes'])
        RunningAverage(output_transform=lambda x: x['loss_classes']).attach(trainer, 'loss_classes')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['loss_classes'], torch.empty(x['loss_classes'].shape[0]))).attach(evaluator, 'loss_classes')



    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None,
                                        n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):

        evaluator.run(train_loader)

        metrics = evaluator.state.metrics
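        # note: the scalars below assume the y_condition metrics ('nll', 'loss_classes') were attached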
        
        writer.add_scalars('loss_classes', {'train': metrics['loss_classes']}, engine.state.epoch)
        writer.add_scalars('loss_nll', {'train': metrics['nll']}, engine.state.epoch)
        writer.add_scalars('loss_total', {'train': metrics['total_loss']}, engine.state.epoch)
        epoch_str = ("Epoch %d. Train_Loss_Classes: %f, Train_NLL: %f, Train_Total: %f " % (engine.state.epoch, metrics['loss_classes'], metrics['nll'], metrics['total_loss']))       

        logging.info(epoch_str)
        
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        
        writer.add_scalars('loss_classes', {'eval': metrics['loss_classes']}, engine.state.epoch)
        writer.add_scalars('loss_nll', {'eval': metrics['nll']}, engine.state.epoch)
        writer.add_scalars('loss_total', {'eval': metrics['total_loss']}, engine.state.epoch)    
        epoch_str = ("Epoch %d. Eval_Loss_Classes: %f, Eval_NLL: %f, Eval_Total: %f " % (engine.state.epoch, metrics['loss_classes'], metrics['nll'], metrics['total_loss']))       

        logging.info(epoch_str)

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
Example No. 11
    def run_once(self, fold_idx):

        log_dir = self.log_dir
        check_manual_seed(self.seed)
        train_pairs, valid_pairs = getattr(
            dataset, ('prepare_%s_data' % self.dataset))()
        print(len(train_pairs))
        print(len(valid_pairs))

        train_augmentors = self.train_augmentors()
        train_dataset = dataset.DatasetSerial(
            train_pairs[:],
            shape_augs=iaa.Sequential(train_augmentors[0]),
            input_augs=iaa.Sequential(train_augmentors[1]))

        infer_augmentors = self.infer_augmentors()  # HACK at has_aux
        infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                                              shape_augs=iaa.Sequential(
                                                  infer_augmentors[0]))

        train_loader = data.DataLoader(train_dataset,
                                       num_workers=self.nr_procs_train,
                                       batch_size=self.train_batch_size,
                                       shuffle=True,
                                       drop_last=True)

        valid_loader = data.DataLoader(infer_dataset,
                                       num_workers=self.nr_procs_valid,
                                       batch_size=self.infer_batch_size,
                                       shuffle=True,
                                       drop_last=False)

        # --------------- Training Sequence

        if self.logging:
            check_log_dir(log_dir)

        device = 'cuda'

        # networks
        input_chs = 3  # TODO: dynamic config
        net = EfficientNet.from_pretrained("efficientnet-b2", num_classes=2)

        # load pre-trained models
        if self.load_network:
            net = load_weight(net, self.save_net_path)

        net = torch.nn.DataParallel(net).to(device)

        # optimizers
        optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
        #scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

        #
        trainer = Engine(lambda engine, batch: self.train_step(
            net, batch, optimizer, device))
        valider = Engine(
            lambda engine, batch: self.infer_step(net, batch, device))

        infer_output = ['prob', 'true']
        ##

        if self.logging:
            checkpoint_handler = ModelCheckpoint(log_dir,
                                                 self.chkpts_prefix,
                                                 save_interval=1,
                                                 n_saved=30,
                                                 require_empty=False)
            # adding handlers using `trainer.add_event_handler` method API
            trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                      handler=checkpoint_handler,
                                      to_save={'net': net})

        timer = Timer(average=True)
        timer.attach(trainer,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)
        timer.attach(valider,
                     start=Events.EPOCH_STARTED,
                     resume=Events.ITERATION_STARTED,
                     pause=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_COMPLETED)

        # attach running average metrics computation
        # decay of EMA to 0.95 to match tensorpack default
        # TODO: refactor this
        RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(
            trainer, 'acc')
        RunningAverage(alpha=0.95,
                       output_transform=lambda x: x['loss']).attach(
                           trainer, 'loss')

        # attach progress bar
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=['loss'])
        pbar.attach(valider)

        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EXCEPTION_RAISED)
        def handle_exception(engine, e):
            if isinstance(e,
                          KeyboardInterrupt) and (engine.state.iteration > 1):
                engine.terminate()
                warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
                checkpoint_handler(engine, {'net_exception': net})
            else:
                raise e

        # writer for tensorboard logging
        tfwriter = None  # HACK temporary
        json_log_file = None  # keep defined so log_info_dict below works when logging is disabled
        if self.logging:
            tfwriter = SummaryWriter(log_dir)
            json_log_file = log_dir + '/stats.json'
            with open(json_log_file, 'w') as json_file:
                json.dump({}, json_file)  # create empty file

        ### TODO refactor again
        log_info_dict = {
            'logging': self.logging,
            'optimizer': optimizer,
            'tfwriter': tfwriter,
            'json_file': json_log_file,
            'nr_classes': self.nr_classes,
            'metric_names': infer_output,
            'infer_batch_size': self.infer_batch_size  # too cumbersome
        }
        # trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: scheduler.step())  # to change the lr
        trainer.add_event_handler(Events.EPOCH_COMPLETED,
                                  log_train_ema_results, log_info_dict)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider,
                                  valid_loader, log_info_dict)
        valider.add_event_handler(Events.ITERATION_COMPLETED,
                                  accumulate_outputs)

        # Setup is done. Now let's run the training
        trainer.run(train_loader, self.nr_epochs)
        return
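
The RunningAverage/ProgressBar wiring in the example above reduces to a few lines. Below is a hedged sketch with a fake step function that returns random loss and accuracy values; it assumes pytorch-ignite and tqdm are available.

import random

from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar


def train_step(engine, batch):
    # A real step would run forward/backward here; random numbers keep the sketch runnable.
    return {'loss': random.random(), 'acc': random.random()}


trainer = Engine(train_step)

# Exponential moving averages of the dict entries returned by train_step.
RunningAverage(alpha=0.95, output_transform=lambda out: out['loss']).attach(trainer, 'loss')
RunningAverage(alpha=0.95, output_transform=lambda out: out['acc']).attach(trainer, 'acc')

pbar = ProgressBar(persist=True)
pbar.attach(trainer, metric_names=['loss', 'acc'])

trainer.run(range(100), max_epochs=2)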
Example #12
0
def attach_decorators(trainer, SR, feature_extractor, domain_classifier,
                      resolution_classifier, sr_classif_critic, optim, loader):
    timer = Timer(average=True)

    checkpoint_handler = ModelCheckpoint(
        args.output_dir + '/checkpoints/domain_adaptation_training/',
        'training',
        save_interval=1,
        n_saved=300,
        require_empty=False,
        iteration=args.epoch_c)

    monitoring_metrics = [
        'tgt_loss', 'src_loss', 'sr_loss', 'loss', 'GP', 'res_down_loss',
        'res_up_loss', 'tv_loss', 'vgg_loss'
    ]
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['tgt_loss']).attach(
                       trainer, 'tgt_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['src_loss']).attach(
                       trainer, 'src_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['sr_loss']).attach(
        trainer, 'sr_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['loss']).attach(
        trainer, 'loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['GP']).attach(trainer, 'GP')
    # RunningAverage(alpha=0.98, output_transform=lambda x: x['g_loss']).attach(trainer, 'g_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['res_down_loss']).attach(
                       trainer, 'res_down_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['res_up_loss']).attach(
                       trainer, 'res_up_loss')
    RunningAverage(alpha=0.98, output_transform=lambda x: x['tv_loss']).attach(
        trainer, 'tv_loss')
    RunningAverage(alpha=0.98,
                   output_transform=lambda x: x['vgg_loss']).attach(
                       trainer, 'vgg_loss')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=checkpoint_handler,
        to_save={
            'feature_extractor': feature_extractor,
            'SR': SR,
            # 'optim_feature': optim_feature,
            # 'optim_domain_classif': optim_domain_classif,
            # 'optim_res_classif': optim_res_classif,
            'optim': optim,
            # 'optim_sr_critic': optim_sr_critic,
            'domain_D': domain_classifier,
            'res_D': resolution_classifier,
            'sr_D': sr_classif_critic
        })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(args.output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [
                str(round(value, 5))
                for value in engine.state.metrics.values()
            ]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            i = (engine.state.iteration % len(loader))
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=args.epochs,
                i=i,
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_real_example(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            if not os.path.exists(args.output_dir +
                                  '/imgs/domain_adaptation_training/'):
                os.makedirs(args.output_dir +
                            '/imgs/domain_adaptation_training/')
            px, py, px2, py2, px_up, _, px2_up, _ = engine.state.batch
            img = SR(feature_extractor(px2.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predtgt_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(img, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetY_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(py2, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                targetX_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(px2, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceX_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(px, path)
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                sourceY_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(py, path)
            img = SR(feature_extractor(px.cuda()))
            path = os.path.join(
                args.output_dir + '/imgs/domain_adaptation_training/',
                predsrc_IMG_FNAME.format(engine.state.epoch,
                                         engine.state.iteration))
            vutils.save_image(img, path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            checkpoint_handler(
                engine, {
                    'feature_extractor_{}'.format(engine.state.iteration):
                    feature_extractor,
                    'SR_{}'.format(engine.state.iteration): SR,
                    'DOMAIN_D_{}'.format(engine.state.iteration):
                    domain_classifier,
                    'RES_D_{}'.format(engine.state.iteration):
                    resolution_classifier,
                    'SR_D_{}'.format(engine.state.iteration):
                    sr_classif_critic,
                    'OPTIM_{}'.format(engine.state.iteration): optim
                })

        else:
            raise e

    @trainer.on(Events.STARTED)
    def loaded(engine):
        if args.epoch_c != 0:
            engine.state.epoch = args.epoch_c
            engine.state.iteration = args.epoch_c * len(loader)
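
The exception handler and the resume-from-epoch trick above can be reproduced without the surrounding models. The sketch below uses plain torch.save instead of ModelCheckpoint to stay independent of the ignite version; the file name, START_EPOCH, and ITERS_PER_EPOCH values are placeholders.

import warnings

import torch
import torch.nn as nn
from ignite.engine import Engine, Events

net = nn.Linear(4, 2)

START_EPOCH = 0          # set to a saved epoch count to resume (placeholder)
ITERS_PER_EPOCH = 100


def train_step(engine, batch):
    return 0.0


trainer = Engine(train_step)


@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    # Save a snapshot on Ctrl-C, re-raise anything else.
    if isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1:
        engine.terminate()
        warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
        torch.save(net.state_dict(), 'net_exception.pt')
    else:
        raise e


@trainer.on(Events.STARTED)
def restore_progress(engine):
    # Mimics the "resume from args.epoch_c" trick above.
    if START_EPOCH != 0:
        engine.state.epoch = START_EPOCH
        engine.state.iteration = START_EPOCH * ITERS_PER_EPOCH


trainer.run(range(ITERS_PER_EPOCH), max_epochs=1)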
Example #13
0
    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()
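
These three timers measure dataflow, processing, and event-handler time, presumably for a profiling helper. As a standalone illustration, ignite's Timer can also be driven manually with pause()/resume(); the sleep calls below are placeholders for real work.

import time

from ignite.handlers import Timer

# A Timer starts running as soon as it is constructed (or reset()).
data_timer = Timer()
compute_timer = Timer()

compute_timer.pause()            # not computing yet
time.sleep(0.05)                 # pretend we are loading a batch
data_timer.pause()

compute_timer.resume()
time.sleep(0.02)                 # pretend we are running the model
compute_timer.pause()

print(f'dataflow:   {data_timer.value():.3f}s')
print(f'processing: {compute_timer.value():.3f}s')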
Example #14
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_directory", type=Path)
    parser.add_argument("--generator-weights", type=Path)
    parser.add_argument("--discriminator-weights", type=Path)

    args = parser.parse_args()

    generator = Generator(GENERATOR_FILTERS)
    if args.generator_weights is not None:
        LOGGER.info(f"Loading generator weights: {args.generator_weights}")
        generator.load_state_dict(torch.load(args.generator_weights))
    else:
        generator.weight_init(mean=0.0, std=0.02)

    discriminator = Discriminator(DISCRIMINATOR_FILTERS)
    if args.discriminator_weights is not None:
        LOGGER.info(
            f"Loading discriminator weights: {args.discriminator_weights}")
        discriminator.load_state_dict(torch.load(args.discriminator_weights))
    else:
        discriminator.weight_init(mean=0.0, std=0.02)

    dataset = XView2Dataset(args.data_directory)
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - 10, 10])

    # Create a dev train dataset with just 10 samples
    # train_dataset, _ = torch.utils.data.random_split(train_dataset, [10, len(train_dataset) - 10])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=TRAIN_BATCH_SIZE)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=TEST_BATCH_SIZE)

    generator.cuda()
    discriminator.cuda()

    generator.train()
    discriminator.train()

    BCE_loss = nn.BCELoss().cuda()
    L1_loss = nn.L1Loss().cuda()

    generator_optimizer = optim.Adam(generator.parameters(),
                                     lr=GENERATOR_LR,
                                     betas=(BETA_1, BETA_2))
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=DISCRIMINATOR_LR,
                                         betas=(BETA_1, BETA_2))

    def step(engine, batch):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        discriminator.zero_grad()
        discriminator_result = discriminator(x, y).squeeze()
        discriminator_real_loss = BCE_loss(
            discriminator_result,
            torch.ones(discriminator_result.size()).cuda())

        generator_result = generator(x)
        discriminator_result = discriminator(x, generator_result).squeeze()

        discriminator_fake_loss = BCE_loss(
            discriminator_result,
            torch.zeros(discriminator_result.size()).cuda())
        discriminator_train_loss = (discriminator_real_loss +
                                    discriminator_fake_loss) * 0.5
        discriminator_train_loss.backward()
        discriminator_optimizer.step()

        generator.zero_grad()
        generator_result = generator(x)
        # TODO Work out if the below time saving technique impacts training.
        #generator_result = generator_result.detach()
        discriminator_result = discriminator(x, generator_result).squeeze()

        l1_loss = L1_loss(generator_result, y)
        bce_loss = BCE_loss(discriminator_result,
                            torch.ones(discriminator_result.size()).cuda())

        G_train_loss = bce_loss + L1_LAMBDA * l1_loss
        G_train_loss.backward()
        generator_optimizer.step()

        return {
            'generator_train_loss': G_train_loss.item(),
            'discriminator_real_loss': discriminator_real_loss.item(),
            'discriminator_fake_loss': discriminator_fake_loss.item(),
        }

    trainer = Engine(step)

    tb_logger = TensorboardLogger(log_dir=f"tensorboard/logdir/{uuid4()}")
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(
                         tag="training",
                         output_transform=lambda out: out,
                         metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def add_generated_images(engine):
        def min_max(image):
            return (image - image.min()) / (image.max() - image.min())

        for idx, (x, y) in enumerate(test_loader):
            generated = min_max(generator(x.cuda()).squeeze().cpu())
            real = min_max(y.squeeze())

            tb_logger.writer.add_image(
                f"generated_test_image_{idx}",
                # Concatenate the images into a single tiled image
                torch.cat([x.squeeze(), generated, real], 2),
                global_step=engine.state.epoch)

    checkpoint_handler = ModelCheckpoint("checkpoints/",
                                         "pix2pix",
                                         n_saved=1,
                                         require_empty=False,
                                         save_interval=1)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'generator': generator,
                                  'discriminator': discriminator
                              })

    timer = Timer(average=True)
    timer.attach(trainer,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}] Duration[{}] Losses: {}".format(
            engine.state.epoch, engine.state.iteration, timer.value(),
            engine.state.output))

    trainer.run(train_loader, max_epochs=TRAIN_EPOCHS)

    tb_logger.close()
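
The TensorboardLogger/OutputHandler attachment above can be tried in isolation. The sketch below assumes the tensorboard backend (tensorboardX or torch.utils.tensorboard) is available and uses a dummy step function and a placeholder log directory.

import random

from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler


def step(engine, batch):
    # Stand-in for the GAN step; returns the loss dict that gets logged.
    return {'generator_train_loss': random.random(),
            'discriminator_real_loss': random.random()}


trainer = Engine(step)

tb_logger = TensorboardLogger(log_dir='tensorboard/logdir/sketch')
tb_logger.attach(trainer,
                 log_handler=OutputHandler(tag='training',
                                           output_transform=lambda out: out),
                 event_name=Events.ITERATION_COMPLETED)

trainer.run(range(20), max_epochs=1)
tb_logger.close()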
Example #15
0
def do_train_with_feat(cfg, train_loader, valid, tr_comp: TrainComponent,
                       saver):
    tb_log = TensorBoardXLog(cfg, saver.save_dir)

    device = cfg.MODEL.DEVICE

    trainer = create_supervised_trainer(
        tr_comp.model,
        tr_comp.optimizer,
        tr_comp.loss,
        device=device,
        apex=cfg.APEX.IF_ON,
        has_center=cfg.LOSS.IF_WITH_CENTER,
        center_criterion=tr_comp.loss.center,
        optimizer_center=tr_comp.optimizer_center,
        center_loss_weight=cfg.LOSS.CENTER_LOSS_WEIGHT)

    saver.to_save = {
        'trainer': trainer,
        'module': tr_comp.model,
        'optimizer': tr_comp.optimizer,
        'center_param': tr_comp.loss_center,
        'optimizer_center': tr_comp.optimizer_center
    }

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=cfg.SAVER.CHECKPOINT_PERIOD),
        saver.train_checkpointer, saver.to_save)

    # multi-valid-dataset
    validation_evaluator_map = get_valid_eval_map(cfg, device, tr_comp.model,
                                                  valid)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss.loss_function_map.keys())

    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    # TODO start epoch
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Lr: {tr_comp.scheduler.get_lr()[0]:.2e}, " + \
                  f"Loss: {engine.state.metrics['Loss']:.4f}, " + \
                  f"Acc: {engine.state.metrics['Acc']:.4f}, "

        for loss_name in tr_comp.loss.loss_function_map.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "

        message += f"feat: {engine.state.metrics['feat']:.4f}"

        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD),
                saver=saver)
    def log_validation_results(engine, saver):
        # train_evaluator.run(train_loader)
        # cmc, mAP = validation_evaluator.state.metrics['r1_mAP']
        # logger.info("Train Results - Epoch: {}".format(engine.state.epoch))
        # logger.info("mAP: {:.1%}".format(mAP))
        # for r in [1, 5, 10]:
        #     logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

        logger.info(f"Valid - Epoch: {engine.state.epoch}")

        sum_result = eval_multi_dataset(device, validation_evaluator_map,
                                        valid)

        if saver.best_result < sum_result:
            logger.info(f'Save best: {sum_result:.4f}')
            saver.save_best_value(sum_result)
            saver.best_checkpointer(engine, saver.to_save)
            saver.best_result = sum_result
        else:
            logger.info(
                f"Not best: {saver.best_result:.4f} > {sum_result:.4f}")
        logger.info('-' * 80)

    tb_log.attach_handler(trainer, tr_comp.model, tr_comp.optimizer)

    # self.tb_logger.attach(
    #     validation_evaluator,
    #     log_handler=ReIDOutputHandler(tag="valid", metric_names=["r1_mAP"], another_engine=trainer),
    #     event_name=Events.EPOCH_COMPLETED,
    # )

    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)

    tb_log.close()
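
The every=N event filters used above, such as Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD), are available from ignite 0.3 onwards. A minimal sketch (the periods and print statements are placeholders):

from ignite.engine import Engine, Events

trainer = Engine(lambda engine, batch: None)


# Runs only every 5 epochs.
@trainer.on(Events.EPOCH_COMPLETED(every=5))
def run_validation(engine):
    print(f'validating at epoch {engine.state.epoch}')


# Runs only every 50 iterations.
@trainer.on(Events.ITERATION_COMPLETED(every=50))
def log_loss(engine):
    print(f'iteration {engine.state.iteration}')


trainer.run(range(100), max_epochs=10)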
Example #16
0
def run(config, plx_experiment):

    set_seed(config['seed'])

    device = "cuda"
    batch_size = config['batch_size']

    cutout_size = config['cutout_size']
    train_transforms = [
        DynamicCrop(32, 32),
        FlipLR(),
        DynamicCutout(cutout_size, cutout_size)
    ]
    train_loader, test_loader = get_fast_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        num_workers=config['num_workers'],
        device=device,
        train_transforms=train_transforms)

    bn_kwargs = config['bn_kwargs']
    conv_kwargs = config['conv_kwargs']

    model = FastResRecNet(conv_kwargs=conv_kwargs,
                          bn_kwargs=bn_kwargs,
                          final_weight=config['final_weight'])
    model = model.to(device)
    model = model.half()
    model_name = model.__class__.__name__

    criterion = nn.CrossEntropyLoss(reduction='sum').to(device)
    criterion = criterion.half()
    eval_criterion = criterion

    if config["enable_mixup"]:
        criterion = MixupCriterion(criterion)

    weight_decay = config['weight_decay']

    if not config['use_adam']:
        opt_kwargs = [("lr", 0.0), ("momentum", config['momentum']),
                      ("weight_decay", weight_decay), ("nesterov", True)]
        optimizer_cls = optim.SGD
    else:
        opt_kwargs = [
            ("lr", 0.0),
            ("betas", (0.9, 0.999)),
            ("eps", 1e-08),
            ("amsgrad", True),
            ("weight_decay", weight_decay),
        ]
        optimizer_cls = optim.Adam

    optimizer = optimizer_cls([
        # conv + bn
        dict([("params", model.prep.parameters())] + opt_kwargs),
        # conv + bn
        dict([("params", model.layer1[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer1[-1].conv_rec.parameters())] +
             opt_kwargs),
        # conv + bn
        dict([("params", model.layer2[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer2[-1].conv_rec.parameters())] +
             opt_kwargs),
        # conv + bn
        dict([("params", model.layer3[0].parameters())] + opt_kwargs),
        # identity residual recurrent blocks
        dict([("params", model.layer3[-1].conv_rec.parameters())] +
             opt_kwargs),
        # linear
        dict([("params", model.classifier.parameters())] + opt_kwargs),
    ])

    num_iterations_per_epoch = len(train_loader)
    num_iterations = num_iterations_per_epoch * config['num_epochs']
    layerwise_milestones_lr_values = []
    for i in range(len(optimizer.param_groups)):
        key = "lr_param_group_{}".format(i)
        assert key in config, "{} not in config".format(key)
        milestones_values = config[key]
        layerwise_milestones_lr_values.append([(m * num_iterations_per_epoch,
                                                v / batch_size)
                                               for m, v in milestones_values])

    lr_scheduler = get_layerwise_lr_scheduler(optimizer,
                                              layerwise_milestones_lr_values)

    def _prepare_batch_fp16(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device,
                               non_blocking=non_blocking).half(),
                convert_tensor(y, device=device,
                               non_blocking=non_blocking).long())

    def process_function(engine, batch):
        x, y = _prepare_batch_fp16(batch, device=device, non_blocking=True)

        if config['enable_mixup']:
            x, y = mixup_data(x, y, config['mixup_alpha'],
                              config['mixup_proba'])

        optimizer.zero_grad()
        y_pred = model(x)

        loss = criterion(y_pred, y)
        loss.backward()

        if config["clip_gradients"] is not None:
            clip_grad_norm_(model.parameters(), config["clip_gradients"])

        optimizer.step()
        loss = loss.item()

        return loss

    trainer = Engine(process_function)

    metrics = {
        "accuracy": Accuracy(),
        "loss": Loss(eval_criterion) / len(test_loader)
    }
    evaluator = create_supervised_evaluator(model,
                                            metrics,
                                            prepare_batch=_prepare_batch_fp16,
                                            device=device,
                                            non_blocking=True)

    train_evaluator = create_supervised_evaluator(
        model,
        metrics,
        prepare_batch=_prepare_batch_fp16,
        device=device,
        non_blocking=True)

    total_timer = Timer(average=False)
    train_timer = Timer(average=False)
    test_timer = Timer(average=False)

    table_logger = TableLogger()

    if config["use_tb_logger"]:
        path = os.environ.get("TB_LOGGER_PATH", "experiments/tb_logs")
        tb_logger = SummaryWriter(log_dir=path)

    test_timer.attach(evaluator, start=Events.EPOCH_STARTED)

    @trainer.on(Events.STARTED)
    def on_training_started(engine):
        print("Warming up cudnn on random inputs")
        for _ in range(5):
            for size in [batch_size, len(test_loader.dataset) % batch_size]:
                warmup_cudnn(model, criterion, size, config)

        total_timer.reset()

    @trainer.on(Events.EPOCH_STARTED)
    def on_epoch_started(engine):
        model.train()
        train_timer.reset()

        # Warm-up on small images
        if config['warmup_on_small_images']:
            if engine.state.epoch < config['warmup_duration']:
                train_loader.dataset.transforms[0].h = 20
                train_loader.dataset.transforms[0].w = 20
            elif engine.state.epoch == config['warmup_duration']:
                train_loader.dataset.transforms[0].h = 32
                train_loader.dataset.transforms[0].w = 32

        train_loader.dataset.set_random_choices()

        if config['reduce_cutout']:
            # after 15 epoch remove cutout augmentation
            if 14 <= engine.state.epoch < 16:
                train_loader.dataset.transforms[-1].h -= 1
                train_loader.dataset.transforms[-1].w -= 1
            elif engine.state.epoch == 16:
                train_loader.dataset.transforms.pop()

        if config['enable_mixup'] and config[
                'mixup_max_epochs'] == engine.state.epoch - 1:
            config['mixup_proba'] = 0.0

    if config["use_tb_logger"]:

        @trainer.on(Events.ITERATION_COMPLETED)
        def on_iteration_completed(engine):
            # log learning rate
            param_name = "lr"
            if len(optimizer.param_groups) == 1:
                param = float(optimizer.param_groups[0][param_name])
                tb_logger.add_scalar(param_name, param * batch_size,
                                     engine.state.iteration)
            else:
                for i, param_group in enumerate(optimizer.param_groups):
                    param = float(param_group[param_name])
                    tb_logger.add_scalar(
                        "{}/{}/group_{}".format(param_name, model_name, i),
                        param * batch_size, engine.state.iteration)

            # log training loss
            tb_logger.add_scalar("training/loss_vs_iterations",
                                 engine.state.output / batch_size,
                                 engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def on_epoch_completed(engine):
        trainer.state.train_time = train_timer.value()

        if config["use_tb_logger"]:
            # Log |w|^2 and gradients
            for i, p in enumerate(model.parameters()):
                tb_logger.add_scalar(
                    "w2/{}/{}_{}".format(model_name, i, list(p.data.shape)),
                    torch.norm(p.data), engine.state.epoch)
                tb_logger.add_scalar(
                    "mean_grad/{}/{}_{}".format(model_name, i,
                                                list(p.grad.shape)),
                    torch.mean(p.grad), engine.state.epoch)

        for i, p in enumerate(model.parameters()):
            plx_experiment.log_metrics(
                step=engine.state.epoch,
                **{
                    "w2/{}/{}_{}".format(model_name, i, list(p.data.shape)):
                    torch.norm(p.data).item()
                })

        evaluator.run(test_loader)

    trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)

    @evaluator.on(Events.COMPLETED)
    def log_results(engine):
        evaluator.state.test_time = test_timer.value()
        metrics = evaluator.state.metrics
        output = [("epoch", trainer.state.epoch)]
        output += [(key, trainer.state.param_history[key][-1][0] * batch_size)
                   for key in trainer.state.param_history]
        output += [("train time", trainer.state.train_time),
                   ("train loss", trainer.state.output / batch_size),
                   ("test time", evaluator.state.test_time),
                   ("test loss", metrics['loss'] / batch_size),
                   ("test acc", metrics['accuracy']),
                   ("total time", total_timer.value())]
        output = OrderedDict(output)
        table_logger.append(output)

        plx_experiment.log_metrics(step=trainer.state.epoch, **output)

        if config["use_tb_logger"]:
            tb_logger.add_scalar("training/total_time", total_timer.value(),
                                 trainer.state.epoch)
            tb_logger.add_scalar("test/loss", metrics['loss'] / batch_size,
                                 trainer.state.epoch)
            tb_logger.add_scalar("test/accuracy", metrics['accuracy'],
                                 trainer.state.epoch)

    @trainer.on(Events.COMPLETED)
    def on_training_completed(engine):
        if config["use_tb_logger"]:

            train_evaluator.run(train_loader)
            metrics = train_evaluator.state.metrics

            tb_logger.add_scalar("training/loss", metrics['loss'] / batch_size,
                                 0)
            tb_logger.add_scalar("training/loss", metrics['loss'] / batch_size,
                                 trainer.state.epoch)

            tb_logger.add_scalar("training/accuracy", metrics['accuracy'], 0)
            tb_logger.add_scalar("training/accuracy", metrics['accuracy'],
                                 trainer.state.epoch)

    trainer.run(train_loader, max_epochs=config['num_epochs'])

    if config["use_tb_logger"]:
        tb_logger.close()
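
The _prepare_batch_fp16 idea above is just a custom prepare_batch passed to ignite's factory functions. The sketch below keeps everything on CPU and casts to float instead of half so it stays runnable without a GPU; the model, shapes, and data are placeholders.

import torch
import torch.nn as nn
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.utils import convert_tensor


def prepare_batch(batch, device, non_blocking):
    # Move/cast the batch before it reaches the model.
    x, y = batch
    return (convert_tensor(x, device=device, non_blocking=non_blocking).float(),
            convert_tensor(y, device=device, non_blocking=non_blocking).long())


model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()

evaluator = create_supervised_evaluator(model,
                                        metrics={'accuracy': Accuracy(),
                                                 'loss': Loss(criterion)},
                                        device='cpu',
                                        non_blocking=False,
                                        prepare_batch=prepare_batch)

data = [(torch.randn(8, 10), torch.randint(0, 3, (8,))) for _ in range(4)]
evaluator.run(data)
print(evaluator.state.metrics)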
Example #17
0
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("template_model.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                "accuracy": Accuracy(),
                                                "ce_loss": Loss(loss_fn)
                                            },
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   "mnist",
                                   checkpoint_period,
                                   n_saved=10,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpointer,
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        },
    )
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    RunningAverage(output_transform=lambda x: x).attach(trainer, "avg_loss")

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch,
                iter,
                len(train_loader),
                engine.state.metrics["avg_loss"],
            ))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_loss = metrics["ce_loss"]
        logger.info(
            "Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
            .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics["accuracy"]
            avg_loss = metrics["ce_loss"]
            logger.info(
                "Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_loss))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            "Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]"
            .format(
                engine.state.epoch,
                timer.value() * timer.step_count,
                train_loader.batch_size / timer.value(),
            ))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
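
The create_supervised_trainer / create_supervised_evaluator combination above is the standard ignite loop. A minimal end-to-end sketch on random tensors (model, shapes, and hyperparameters are placeholders):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss, RunningAverage

model = nn.Linear(20, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(64, 20)
y = torch.randint(0, 2, (64,))
train_loader = DataLoader(TensorDataset(x, y), batch_size=8)

trainer = create_supervised_trainer(model, optimizer, loss_fn, device='cpu')
evaluator = create_supervised_evaluator(model,
                                        metrics={'accuracy': Accuracy(),
                                                 'ce_loss': Loss(loss_fn)},
                                        device='cpu')

# create_supervised_trainer outputs the scalar loss, so the transform is the identity.
RunningAverage(output_transform=lambda loss: loss).attach(trainer, 'avg_loss')


@trainer.on(Events.EPOCH_COMPLETED)
def log_results(engine):
    evaluator.run(train_loader)
    m = evaluator.state.metrics
    print(f"Epoch {engine.state.epoch}: avg_loss={engine.state.metrics['avg_loss']:.3f} "
          f"acc={m['accuracy']:.3f} ce={m['ce_loss']:.3f}")


trainer.run(train_loader, max_epochs=2)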
Example #18
0
def engine_train(cfg):
    prepare_config_train(cfg)

    ckpt_nets = cfg.train.general.ckpt_nets
    ckpt_path = cfg.train.general.ckpt_path
    epochs = cfg.train.solver.epochs
    gpu = cfg.general.gpu
    lr = cfg.train.solver.lr
    lr_gamma = cfg.train.solver.lr_gamma
    lr_step = cfg.train.solver.lr_step
    optim = cfg.train.solver.optim
    renderer_lr = cfg.train.solver.renderer_lr
    root_path = cfg.log.root_path
    save_freq = cfg.train.solver.save_freq
    seed = cfg.general.seed

    eu.redirect_stdout(root_path, 'train')
    eu.print_config(cfg)

    eu.seed_random(seed)

    device = eu.get_device(gpu)

    dataloader = get_dataloader_train(cfg)
    num_batches = len(dataloader)

    render_model, desc_model = get_models(cfg)
    render_model.to(device)
    render_model.train_mode()
    render_model.print_params('render_model')
    desc_model.to(device)
    desc_model.train_mode()
    desc_model.print_params('desc_model')

    crit = get_criterion(cfg)
    print('[*] Loss Function:', crit.__class__.__name__)

    render_params = render_model.params(True, named=False, add_prefix=False)
    render_optimizer = eu.get_optimizer(optim, renderer_lr, render_params)
    render_lr_scheduler = eu.get_lr_scheduler(lr_step, lr_gamma,
                                              render_optimizer)

    desc_params = desc_model.params(True, named=False, add_prefix=False)
    desc_optimizer = eu.get_optimizer(optim, lr, desc_params)
    desc_lr_scheduler = eu.get_lr_scheduler(lr_step, lr_gamma, desc_optimizer)

    if eu.is_not_empty(ckpt_path):
        render_model.load(ckpt_path, ckpt_nets)
        desc_model.load(ckpt_path, ckpt_nets)

    tbwriter = eu.get_tbwriter(root_path)

    engine = Engine(
        functools.partial(step_train,
                          render_model=render_model,
                          desc_model=desc_model,
                          render_optimizer=render_optimizer,
                          desc_optimizer=desc_optimizer,
                          criterion=crit,
                          tbwriter=tbwriter,
                          device=device,
                          cfg=cfg))
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             eu.step_lr_scheduler,
                             scheduler=render_lr_scheduler)
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             eu.step_lr_scheduler,
                             scheduler=desc_lr_scheduler)

    ckpt_handler = ModelCheckpoint(root_path,
                                   eu.PTH_PREFIX,
                                   atomic=False,
                                   save_interval=save_freq,
                                   n_saved=epochs // save_freq,
                                   require_empty=False)
    render_subnets = render_model.subnet_dict()
    desc_subnets = desc_model.subnet_dict()
    engine.add_event_handler(Events.EPOCH_COMPLETED,
                             ckpt_handler,
                             to_save={
                                 **render_subnets,
                                 **desc_subnets
                             })

    timer = Timer(average=True)
    timer.attach(engine,
                 start=Events.EPOCH_STARTED,
                 pause=Events.EPOCH_COMPLETED,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    engine.add_event_handler(Events.ITERATION_COMPLETED,
                             eu.print_train_log,
                             timer=timer,
                             num_batches=num_batches,
                             cfg=cfg)

    engine.add_event_handler(Events.EXCEPTION_RAISED, eu.handle_exception)

    engine.run(dataloader, epochs)

    tbwriter.close()

    return root_path
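
Binding extra objects into the step function with functools.partial, as done above, looks like this in isolation (the model, data, and hyperparameters below are placeholders):

import functools

import torch
import torch.nn as nn
from ignite.engine import Engine


def step_train(engine, batch, model, optimizer, criterion, device):
    # Everything after (engine, batch) is bound below with functools.partial.
    x, y = batch
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

engine = Engine(functools.partial(step_train,
                                  model=model,
                                  optimizer=optimizer,
                                  criterion=criterion,
                                  device='cpu'))

data = [(torch.randn(4, 8), torch.randint(0, 2, (4,))) for _ in range(5)]
engine.run(data, max_epochs=2)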
Example #19
0
def train_model(
    name="",
    resume="",
    base_dir=utils.BASE_DIR,
    model_name="v0",
    chosen_diseases=None,
    n_epochs=10,
    batch_size=4,
    oversample=False,
    max_os=None,
    shuffle=False,
    opt="sgd",
    opt_params={},
    loss_name="wbce",
    loss_params={},
    train_resnet=False,
    log_metrics=None,
    flush_secs=120,
    train_max_images=None,
    val_max_images=None,
    test_max_images=None,
    experiment_mode="debug",
    save=True,
    save_cms=True,  # Note that in this case, save_cms (to disk) includes write_cms (to TB)
    write_graph=False,
    write_emb=False,
    write_emb_img=False,
    write_img=False,
    image_format="RGB",
    multiple_gpu=False,
):

    # Choose GPU
    device = utilsT.get_torch_device()
    print("Using device: ", device)

    # Common folders
    dataset_dir = os.path.join(base_dir, "dataset")

    # Dataset handling
    print("Loading train dataset...")
    train_dataset, train_dataloader = utilsT.prepare_data(
        dataset_dir,
        "train",
        chosen_diseases,
        batch_size,
        oversample=oversample,
        max_os=max_os,
        shuffle=shuffle,
        max_images=train_max_images,
        image_format=image_format,
    )
    train_samples, _ = train_dataset.size()

    print("Loading val dataset...")
    val_dataset, val_dataloader = utilsT.prepare_data(
        dataset_dir,
        "val",
        chosen_diseases,
        batch_size,
        max_images=val_max_images,
        image_format=image_format,
    )
    val_samples, _ = val_dataset.size()

    # Should be the same than chosen_diseases
    chosen_diseases = list(train_dataset.classes)
    print("Chosen diseases: ", chosen_diseases)

    if resume:
        # Load model and optimizer
        model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = models.load_model(
            base_dir, resume, experiment_mode="", device=device)
        model.train(True)
    else:
        # Create model
        model = models.init_empty_model(model_name,
                                        chosen_diseases,
                                        train_resnet=train_resnet).to(device)

        # Create optimizer
        OptClass = optimizers.get_optimizer_class(opt)
        optimizer = OptClass(model.parameters(), **opt_params)
        # print("OPT: ", opt_params)

    # Allow multiple GPUs
    if multiple_gpu:
        model = DataParallel(model)

    # Tensorboard log options
    run_name = utils.get_timestamp()
    if name:
        run_name += "_{}".format(name)

    if len(chosen_diseases) == 1:
        run_name += "_{}".format(chosen_diseases[0])
    elif len(chosen_diseases) == 14:
        run_name += "_all"

    log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode)

    print("Run name: ", run_name)
    print("Saved TB in: ", log_dir)

    writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs)

    # Create validator engine
    validator = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           False))

    val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    val_loss.attach(validator, loss_name)

    utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(validator, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(validator,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    # Create trainer engine
    trainer = Engine(
        utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params,
                           True))

    train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1)
    train_loss.attach(trainer, loss_name)

    utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True)
    utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc",
                          utilsT.RocAucMetric, False)
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "cm",
                          ConfusionMatrix,
                          get_transform_fn=utilsT.get_transform_cm,
                          metric_args=(2, ))
    utilsT.attach_metrics(trainer,
                          chosen_diseases,
                          "positives",
                          RunningAverage,
                          get_transform_fn=utilsT.get_count_positives)

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 step=Events.EPOCH_COMPLETED)

    # TODO: Early stopping
    #     def score_function(engine):
    #         val_loss = engine.state.metrics[loss_name]
    #         return -val_loss

    #     handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    #     validator.add_event_handler(Events.COMPLETED, handler)

    # Metrics callbacks
    if log_metrics is None:
        log_metrics = list(ALL_METRICS)

    def _write_metrics(run_type, metrics, epoch, wall_time):
        loss = metrics.get(loss_name, 0)

        writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time)

        for metric_base_name in log_metrics:
            for disease in chosen_diseases:
                metric_value = metrics.get(
                    "{}_{}".format(metric_base_name, disease), -1)
                writer.add_scalar(
                    "{}_{}/{}".format(metric_base_name, disease, run_type),
                    metric_value, epoch, wall_time)

    @trainer.on(Events.EPOCH_COMPLETED)
    def tb_write_metrics(trainer):
        epoch = trainer.state.epoch
        max_epochs = trainer.state.max_epochs

        # Run on evaluation
        validator.run(val_dataloader, 1)

        # Common time
        wall_time = time.time()

        # Log all metrics to TB
        _write_metrics("train", trainer.state.metrics, epoch, wall_time)
        _write_metrics("val", validator.state.metrics, epoch, wall_time)

        train_loss = trainer.state.metrics.get(loss_name, 0)
        val_loss = validator.state.metrics.get(loss_name, 0)

        tb_write_histogram(writer, model, epoch, wall_time)

        print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})".
              format(epoch, max_epochs, train_loss, val_loss,
                     utils.duration_to_str(int(timer._elapsed()))))

    # Hparam dict
    hparam_dict = {
        "resume": resume,
        "n_diseases": len(chosen_diseases),
        "diseases": ",".join(chosen_diseases),
        "n_epochs": n_epochs,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "model_name": model_name,
        "opt": opt,
        "loss": loss_name,
        "samples (train, val)": "{},{}".format(train_samples, val_samples),
        "train_resnet": train_resnet,
        "multiple_gpu": multiple_gpu,
    }

    def copy_params(params_dict, base_name):
        for name, value in params_dict.items():
            hparam_dict["{}_{}".format(base_name, name)] = value

    copy_params(loss_params, "loss")
    copy_params(opt_params, "opt")
    print("HPARAM: ", hparam_dict)

    # Train
    print("-" * 50)
    print("Training...")
    trainer.run(train_dataloader, n_epochs)

    # Capture time
    secs_per_epoch = timer.value()
    duration_per_epoch = utils.duration_to_str(int(secs_per_epoch))
    print("Average time per epoch: ", duration_per_epoch)
    print("-" * 50)

    ## Write all hparams
    hparam_dict["duration_per_epoch"] = duration_per_epoch

    # FIXME: this is commented to avoid having too many hparams in TB frontend
    # metrics
    #     def copy_metrics(engine, engine_name):
    #         for metric_name, metric_value in engine.state.metrics.items():
    #             hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value
    #     copy_metrics(trainer, "train")
    #     copy_metrics(validator, "val")

    print("Writing TB hparams")
    writer.add_hparams(hparam_dict, {})

    # Save model to disk
    if save:
        print("Saving model...")
        models.save_model(base_dir, run_name, model_name, experiment_mode,
                          hparam_dict, trainer, model, optimizer)

    # Write graph to TB
    if write_graph:
        print("Writing TB graph...")
        tb_write_graph(writer, model, train_dataloader, device)

    # Write embeddings to TB
    if write_emb:
        print("Writing TB embeddings...")
        image_size = 256 if write_emb_img else 0

        # FIXME: be able to select images (balanced, train vs val, etc)
        image_list = list(train_dataset.label_index["FileName"])[:1000]
        # disease = chosen_diseases[0]
        # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1]
        # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0]
        # positive_images = list(positive["FileName"])[:25]
        # negative_images = list(negative["FileName"])[:25]
        # image_list = positive_images + negative_images

        all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings(
            model,
            train_dataset,
            device,
            image_list=image_list,
            image_size=image_size)
        tb_write_embeddings(
            writer,
            chosen_diseases,
            all_images,
            all_embeddings,
            all_predictions,
            all_ground_truths,
            global_step=n_epochs,
            use_images=write_emb_img,
            tag="1000_{}".format("img" if write_emb_img else "no_img"),
        )

    # Save confusion matrices (is expensive to calculate them afterwards)
    if save_cms:
        print("Saving confusion matrices...")
        # Assure folder
        cms_dir = os.path.join(base_dir, "cms", experiment_mode)
        os.makedirs(cms_dir, exist_ok=True)
        base_fname = os.path.join(cms_dir, run_name)

        n_diseases = len(chosen_diseases)

        def extract_cms(metrics):
            """Extract confusion matrices from a metrics dict."""
            cms = []
            for disease in chosen_diseases:
                key = "cm_" + disease
                if key not in metrics:
                    cm = np.array([[-1, -1], [-1, -1]])
                else:
                    cm = metrics[key].numpy()

                cms.append(cm)
            return np.array(cms)

        # Train confusion matrix
        train_cms = extract_cms(trainer.state.metrics)
        np.save(base_fname + "_train", train_cms)
        tb_write_cms(writer, "train", chosen_diseases, train_cms)

        # Validation confusion matrix
        val_cms = extract_cms(validator.state.metrics)
        np.save(base_fname + "_val", val_cms)
        tb_write_cms(writer, "val", chosen_diseases, val_cms)

        # All confusion matrix (train + val)
        all_cms = train_cms + val_cms
        np.save(base_fname + "_all", all_cms)

        # Print to console
        if len(chosen_diseases) == 1:
            print("Train CM: ")
            print(train_cms[0])
            print("Val CM: ")
            print(val_cms[0])


#             print("Train CM 2: ")
#             print(trainer.state.metrics["cm_" + chosen_diseases[0]])
#             print("Val CM 2: ")
#             print(validator.state.metrics["cm_" + chosen_diseases[0]])

    if write_img:
        # NOTE: this option is not recommended, use Testing notebook to plot and analyze images

        print("Writing images to TB...")

        test_dataset, test_dataloader = utilsT.prepare_data(
            dataset_dir,
            "test",
            chosen_diseases,
            batch_size,
            max_images=test_max_images,
        )

        # TODO: add a way to select images?
        # image_list = list(test_dataset.label_index["FileName"])[:3]

        # Examples in test_dataset (with bboxes available):
        image_list = [
            # "00010277_000.png", # (Effusion, Infiltrate, Mass, Pneumonia)
            # "00018427_004.png", # (Atelectasis, Effusion, Mass)
            # "00021703_001.png", # (Atelectasis, Effusion, Infiltrate)
            # "00028640_008.png", # (Effusion, Infiltrate)
            # "00019124_104.png", # (Pneumothorax)
            # "00019124_090.png", # (Nodule)
            # "00020318_007.png", # (Pneumothorax)
            "00000003_000.png",  # (0)
            # "00000003_001.png", # (0)
            # "00000003_002.png", # (0)
            "00000732_005.png",  # (Cardiomegaly, Pneumothorax)
            # "00012261_001.png", # (Cardiomegaly, Pneumonia)
            # "00013249_033.png", # (Cardiomegaly, Pneumonia)
            # "00029808_003.png", # (Cardiomegaly, Pneumonia)
            # "00022215_012.png", # (Cardiomegaly, Pneumonia)
            # "00011402_007.png", # (Cardiomegaly, Pneumonia)
            # "00019018_007.png", # (Cardiomegaly, Infiltrate)
            # "00021009_001.png", # (Cardiomegaly, Infiltrate)
            # "00013670_151.png", # (Cardiomegaly, Infiltrate)
            # "00005066_030.png", # (Cardiomegaly, Infiltrate, Effusion)
            "00012288_000.png",  # (Cardiomegaly)
            "00008399_007.png",  # (Cardiomegaly)
            "00005532_000.png",  # (Cardiomegaly)
            "00005532_014.png",  # (Cardiomegaly)
            "00005532_016.png",  # (Cardiomegaly)
            "00005827_000.png",  # (Cardiomegaly)
            # "00006912_007.png", # (Cardiomegaly)
            # "00007037_000.png", # (Cardiomegaly)
            # "00007043_000.png", # (Cardiomegaly)
            # "00012741_004.png", # (Cardiomegaly)
            # "00007551_020.png", # (Cardiomegaly)
            # "00007735_040.png", # (Cardiomegaly)
            # "00008339_010.png", # (Cardiomegaly)
            # "00008365_000.png", # (Cardiomegaly)
            # "00012686_003.png", # (Cardiomegaly)
        ]

        tb_write_images(writer, model, test_dataset, chosen_diseases, n_epochs,
                        device, image_list)

    # Close TB writer
    if experiment_mode != "debug":
        writer.close()

    # Run post_train
    print("-" * 50)
    print("Running post_train...")

    print("Loading test dataset...")
    test_dataset, test_dataloader = utilsT.prepare_data(
        dataset_dir,
        "test",
        chosen_diseases,
        batch_size,
        max_images=test_max_images)

    save_cms_with_names(run_name, experiment_mode, model, test_dataset,
                        test_dataloader, chosen_diseases)

    evaluate_model(run_name,
                   model,
                   optimizer,
                   device,
                   loss_name,
                   loss_params,
                   chosen_diseases,
                   test_dataloader,
                   experiment_mode=experiment_mode,
                   base_dir=base_dir)

    # Return values for debugging
    model_run = ModelRun(model, run_name, model_name, chosen_diseases)
    if experiment_mode == "debug":
        model_run.save_debug_data(writer, trainer, validator, train_dataset,
                                  train_dataloader, val_dataset,
                                  val_dataloader)

    return model_run
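
The confusion matrices collected above come from ignite's ConfusionMatrix metric, which yields a torch tensor that the code converts with .numpy(). A minimal sketch on random data, assuming two classes and a placeholder linear model:

import torch
import torch.nn as nn
from ignite.engine import create_supervised_evaluator
from ignite.metrics import ConfusionMatrix

model = nn.Linear(10, 2)

evaluator = create_supervised_evaluator(model,
                                        metrics={'cm': ConfusionMatrix(num_classes=2)},
                                        device='cpu')

data = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(4)]
evaluator.run(data)

# The computed metric is a 2x2 torch tensor; convert it for saving with np.save.
cm = evaluator.state.metrics['cm'].numpy()
print(cm)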
Example #20
0
def do_train(cfg,
             model,
             train_loader,
             val_loader,
             optimizer,
             loss_fns,
             n_fold=0):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    epochs = cfg.SOLVER.MAX_EPOCHS
    device = cfg.MODEL.DEVICE
    output_dir = cfg.OUTPUT_DIR
    lr_scheduler = ReduceLROnPlateau(optimizer,
                                     mode='min',
                                     factor=0.1,
                                     threshold=1e-3,
                                     patience=3,
                                     min_lr=5e-6,
                                     eps=1e-08,
                                     verbose=True)

    logger = logging.getLogger("MOA_MLP.train")
    logger.info("Start training")

    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fns,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'log_loss': LogLoss(loss_fns),
                                                'cv_score': CV_Score()
                                            },
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   'moa_mlp_' + str(n_fold),
                                   n_saved=100,
                                   require_empty=False)

    timer = Timer(average=True)

    # automatically adding handlers via a special `attach` method of `RunningAverage` handler
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    # automatically adding handlers via a special `attach` method of `Checkpointer` handler
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=checkpoint_period),
                              checkpointer, {
                                  'model': model,
                                  'optimizer': optimizer
                              })

    # ReduceLROnPlateau needs the monitored value; step it on the epoch's running loss
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['avg_loss']))

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info(
                "K:[{}] Epoch[{}] Iteration[{}/{}] LR: {} Log Loss: {:.3f}".
                format(n_fold, engine.state.epoch, iter, len(train_loader),
                       optimizer.param_groups[0]['lr'],
                       engine.state.metrics['avg_loss']))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics

        log_loss = metrics['log_loss']
        cv_score = metrics['cv_score']
        logger.info(
            "Training Results - K:[{}] Epoch: {} LR: {} Log Loss: {:.3f} CV Score: {:.3f}"
            .format(n_fold, engine.state.epoch,
                    optimizer.param_groups[0]['lr'], log_loss, cv_score))

    if val_loader is not None:
        # adding handlers using `trainer.on` decorator API
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            log_loss = metrics['log_loss']
            cv_score = metrics['cv_score']
            logger.info(
                "Validation Results - K:[{}] Epoch: {} LR: {} Log Loss: {:.3f} CV Score: {:.3f}"
                .format(n_fold, engine.state.epoch,
                        optimizer.param_groups[0]['lr'], log_loss, cv_score))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'K:[{}] Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(n_fold, engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
Ejemplo n.º 21
def do_train(
    cfg,
    model,
    train_loader,
    val_loader,
    optimizer,
    loss_fn,
):
    log_period = cfg.SOLVER.LOG_PERIOD
    output_dir = cfg.OUTPUT_DIR
    if cfg.MODEL.DEVICE == 'cuda' and torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger.info("Start training")
    logger.info("use {}".format(device))
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    evaluator = create_supervised_evaluator(model,
                                            metrics={
                                                'accuracy': Accuracy(),
                                                'ce_loss': Loss(loss_fn)
                                            },
                                            device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   'mnist',
                                   n_saved=10,
                                   require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpointer,
                              {'model': model})

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}".format(
                engine.state.epoch, iter, len(train_loader),
                engine.state.metrics['avg_loss']))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_loss = metrics['ce_loss']
        logger.info(
            "Training Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
            .format(engine.state.epoch, avg_accuracy, avg_loss))

    if val_loader is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(engine):
            evaluator.run(val_loader)
            metrics = evaluator.state.metrics
            avg_accuracy = metrics['accuracy']
            avg_loss = metrics['ce_loss']
            logger.info(
                "Validation Results - Epoch: {} Avg accuracy: {:.3f} Avg Loss: {:.3f}"
                .format(engine.state.epoch, avg_accuracy, avg_loss))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
Ejemplo n.º 22
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        classes_list,
        optimizers,
        schedulers,
        loss_fn,
        start_epoch
):
    # 1. Pull the training parameters out of cfg
    epochs = cfg.SOLVER.MAX_EPOCHS
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.SOLVER.OUTPUT_DIR
    device = cfg.MODEL.DEVICE

    # 2. Build the components
    logger = logging.getLogger("fundus_prediction.train")
    logger.info("Start training")

    # TensorBoard setup
    writer_train = {}
    for i in range(len(optimizers)):
        writer_train[i] = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/train/" + str(i))

    writer_val = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/val")

    writer_train["graph"] = SummaryWriter(cfg.SOLVER.OUTPUT_DIR + "/summary/train/graph")
    try:
        #print(model)
        images, labels = next(iter(train_loader))
        grid = torchvision.utils.make_grid(images)
        writer_train["graph"].add_image('images', grid, 0)
        writer_train["graph"].add_graph(model, images)
        writer_train["graph"].flush()
    except Exception as e:
        print("Failed to save model graph: {}".format(e))


    # Training metrics
    metrics_train = {"avg_total_loss": RunningAverage(output_transform=lambda x: x["total_loss"]),
                     "avg_precision": RunningAverage(Precision(output_transform=lambda x: (x["scores"], x["labels"]))),
                     "avg_accuracy": RunningAverage(Accuracy(output_transform=lambda x: (x["scores"], x["labels"]))),  #由于训练集样本均衡后远离原始样本集,故只采用平均metric
                     }

    lossKeys = cfg.LOSS.TYPE.split(" ")
    # Per-loss running-average metrics
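    # cfg.LOSS.TYPE is assumed to be a space-separated list of loss names,
    # e.g. "cross_entropy_loss ranked_loss".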
    for lossName in lossKeys:
        if lossName == "similarity_loss":
            metrics_train["AVG-" + "similarity_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["similarity_loss"])
        elif lossName == "ranked_loss":
            metrics_train["AVG-" + "ranked_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["ranked_loss"])
        elif lossName == "cranked_loss":
            metrics_train["AVG-" + "cranked_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cranked_loss"])
        elif lossName == "cross_entropy_loss":
            metrics_train["AVG-" + "cross_entropy_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cross_entropy_loss"])
        elif lossName == "cluster_loss":
            metrics_train["AVG-" + "cluster_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cluster_loss"][0])
        elif lossName == "one_vs_rest_loss":
            metrics_train["AVG-" + "one_vs_rest_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["one_vs_rest_loss"])
        elif lossName == "attention_loss":
            metrics_train["AVG-" + "attention_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["attention_loss"])
        elif lossName == "class_predict_loss":
            metrics_train["AVG-" + "class_predict_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["class_predict_loss"])
        elif lossName == "kld_loss":
            metrics_train["AVG-" + "kld_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["kld_loss"])
        elif lossName == "margin_loss":
            metrics_train["AVG-" + "margin_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["margin_loss"])
        elif lossName == "cross_entropy_multilabel_loss":
            metrics_train["AVG-" + "cross_entropy_multilabel_loss"] = RunningAverage(
                output_transform=lambda x: x["losses"]["cross_entropy_multilabel_loss"])
        else:
            raise Exception('expected LOSS.TYPE to contain only supported loss names '
                            '(similarity_loss, ranked_loss, cranked_loss, ...) but got {}'.format(cfg.LOSS.TYPE))


    trainer = create_supervised_trainer(model, optimizers, metrics_train, loss_fn, device=device)
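    # Note: this create_supervised_trainer is presumably a project-specific factory; unlike
    # ignite's built-in one it also takes the metric dict and a list of optimizers.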

    #CJY  at 2019.9.26
    def output_transform(output):
        # `output` variable is returned by above `process_function`
        y_pred = output['scores']
        y = output['labels']
        return y_pred, y  # output format is according to `Accuracy` docs

    metrics_eval = {"overall_accuracy": Accuracy(output_transform=output_transform),
                    "precision": Precision(output_transform=output_transform)}

    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=100, require_empty=False, start_step=start_epoch)
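    # Note: checkpoint_period is passed positionally as the (legacy) save interval, and
    # start_step is not an argument of stock ignite's ModelCheckpoint, so this is
    # presumably a customized checkpoint handler.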
    timer = Timer(average=True)

    # 3. Attach the components to the engine
    #CJY at 2019.9.23
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizers[0]})

    #trainer.add_event_handler(Events.STARTED, checkpointer, {'model': model,
    #                                                                'optimizer': optimizers[0]})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)


    # 4. Event handlers
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch
        engine.state.iteration = engine.state.iteration + start_epoch * len(train_loader)
        """
        metrics = do_inference(cfg, model, val_loader, classes_list, loss_fn, plotFlag=False)

        step = 0#len(train_loader) * (engine.state.epoch - 1) + engine.state.iteration
        for preKey in metrics['precision'].keys():
            writer_val.add_scalar("Precision/" + str(preKey), metrics['precision'][preKey], step)

        for recKey in metrics['recall'].keys():
            writer_val.add_scalar("Recall/" + str(recKey), metrics['recall'][recKey], step)

        for aucKey in metrics['roc_auc'].keys():
            writer_val.add_scalar("ROC_AUC/" + str(aucKey), metrics['roc_auc'][aucKey], step)

        writer_val.add_scalar("OverallAccuracy", metrics["overall_accuracy"], step)

        # writer.add_scalar("Val/"+"confusion_matrix", metrics['confusion_matrix'], step)

        # The confusion matrix and ROC curve can be stored as images
        roc_numpy = metrics["roc_figure"]
        writer_val.add_image("ROC", roc_numpy, step, dataformats='HWC')

        confusion_matrix_numpy = metrics["confusion_matrix_numpy"]
        writer_val.add_image("ConfusionMatrix", confusion_matrix_numpy, step, dataformats='HWC')

        writer_val.flush()
        #"""

    @trainer.on(Events.EPOCH_COMPLETED)  # not EPOCH_STARTED: since PyTorch 1.2, scheduler.step() should be called after optimizer.step()
    def adjust_learning_rate(engine):
        """
        #if (engine.state.epoch - 1) % engine.state.epochs_traverse_optimizers == 0:
        if engine.state.epoch == 2:
            op_i_scheduler1 = WarmupMultiStepLR(optimizers[0], cfg.SOLVER.SCHEDULER.STEPS, cfg.SOLVER.SCHEDULER.GAMMA,
                                               cfg.SOLVER.SCHEDULER.WARMUP_FACTOR,
                                               cfg.SOLVER.SCHEDULER.WARMUP_ITERS, cfg.SOLVER.SCHEDULER.WARMUP_METHOD)
            op_i_scheduler2 = WarmupMultiStepLR(optimizers[1], cfg.SOLVER.SCHEDULER.STEPS, cfg.SOLVER.SCHEDULER.GAMMA,
                                                cfg.SOLVER.SCHEDULER.WARMUP_FACTOR,
                                                cfg.SOLVER.SCHEDULER.WARMUP_ITERS, cfg.SOLVER.SCHEDULER.WARMUP_METHOD)
            engine.state.schedulers = [op_i_scheduler1, op_i_scheduler2]
            print("copy")
        """
        schedulers[engine.state.schedulers_epochs_index][engine.state.optimizer_index].step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % (log_period*accumulation_steps) == 0:
            step = engine.state.iteration

            # Write to the train summary
            # Record average precision
            avg_precision = engine.state.metrics['avg_precision'].numpy().tolist()
            avg_precisions = {}
            ap_sum = 0
            for index, ap in enumerate(avg_precision):
                avg_precisions[index] = float("{:.2f}".format(ap))
                ap_sum += avg_precisions[index]
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = avg_precisions[index]
                    writer_train[i].add_scalar("Precision/" + str(index), scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()
            avg_precisions["avg_precision"] = float("{:.2f}".format(ap_sum/len(avg_precision)))

            # Record average loss
            avg_losses = {}
            for lossName in lossKeys:
                avg_losses[lossName] = (float("{:.3f}".format(engine.state.metrics["AVG-" + lossName])))
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = avg_losses[lossName]
                    writer_train[i].add_scalar("Loss/" + lossName, scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()

            # Record the remaining scalars
            scalar_list = ["avg_accuracy", "avg_total_loss"]
            for scalar in scalar_list:
                scalarDict = {}
                for i in range(len(optimizers)):
                    if i != engine.state.optimizer_index:
                        scalarDict["optimizer" + str(i)] = 0
                    else:
                        scalarDict["optimizer" + str(i)] = engine.state.metrics[scalar]
                    writer_train[i].add_scalar("Train/" + scalar, scalarDict["optimizer" + str(i)], step)
                    writer_train[i].flush()

            # Record the learning rate
            LearningRateDict = {}
            for i in range(len(optimizers)):
                if i != engine.state.optimizer_index:
                    LearningRateDict["optimizer" + str(i)] = 0
                else:
                    LearningRateDict["optimizer" + str(i)] = schedulers[engine.state.schedulers_epochs_index][engine.state.optimizer_index].get_lr()[0]
                writer_train[i].add_scalar("Train/" + "LearningRate", LearningRateDict["optimizer" + str(i)], step)
                writer_train[i].flush()

            # Record selected weights
            choose_list = ["base.conv1.weight", "base.bn1.weight",
                          "base.layer1.0.conv1.weight", "base.layer1.2.conv3.weight",
                          "base.layer2.0.conv1.weight", "base.layer2.3.conv3.weight",
                          "base.layer3.0.conv1.weight", "base.layer3.5.conv3.weight",
                          "base.layer4.0.conv1.weight", "base.layer4.2.conv1.weight",
                          "bottleneck.weight", "classifier.weight"]
            """
            # Recording parameter histograms is very time-consuming
            params_dict = {}
            for name, parameters in model.named_parameters():
                #print(name, ':', parameters.size())
                params_dict[name] = parameters.detach().cpu().numpy()
            #print(len(params_dict))
                        
            for cp in params_dict.keys():
                writer_train["graph"].add_histogram("Train/" + cp, params_dict[cp], step)
                writer_train["graph"].flush()
            #"""

            logger.info("Epoch[{}] Iteration[{}/{}] Training {} - ATLoss: {:.3f}, AvgLoss: {}, Avg Pre: {}, Avg_Acc: {:.3f}, Base Lr: {:.2e}, step: {}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.losstype,
                                engine.state.metrics['avg_total_loss'], avg_losses, avg_precisions, engine.state.metrics['avg_accuracy'],
                                schedulers[engine.state.schedulers_epochs_index][engine.state.optimizer_index].get_lr()[0], step))

            #logger.info(engine.state.output["rf_loss"])

            if engine.state.output["losses"].get("cluster_loss") is not None:
                logger.info("Epoch[{}] Iteration[{}/{}] Center {} \n r_inter: {}, r_outer: {}, step: {}"
                            .format(engine.state.epoch, ITER, len(train_loader),
                                    engine.state.output["losses"]["cluster_loss"][-1]["center"].cpu().detach().numpy(),
                                    engine.state.output["losses"]["cluster_loss"][-1]["r_inter"].item(),
                                    engine.state.output["losses"]["cluster_loss"][-1]["r_outer"].item(),
                                    step))

        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            metrics = do_inference(cfg, model, val_loader, classes_list, loss_fn)

            step = engine.state.iteration
            for preKey in metrics['precision'].keys():
                writer_val.add_scalar("Precision/" + str(preKey), metrics['precision'][preKey], step)

            for recKey in metrics['recall'].keys():
                writer_val.add_scalar("Recall/" + str(recKey), metrics['recall'][recKey], step)

            for aucKey in metrics['roc_auc'].keys():
                writer_val.add_scalar("ROC_AUC/" + str(aucKey), metrics['roc_auc'][aucKey], step)

            writer_val.add_scalar("OverallAccuracy", metrics["overall_accuracy"], step)

            #writer.add_scalar("Val/"+"confusion_matrix", metrics['confusion_matrix'], step)

            # The confusion matrix and ROC curve are stored as images
            roc_numpy = metrics["roc_figure"]
            writer_val.add_image("ROC", roc_numpy, step, dataformats='HWC')

            confusion_matrix_numpy = metrics["confusion_matrix_numpy"]
            writer_val.add_image("ConfusionMatrix", confusion_matrix_numpy, step, dataformats='HWC')

            writer_val.flush()


    # 5. Run the engine
    trainer.run(train_loader, max_epochs=epochs)
    for key in writer_train.keys():
        writer_train[key].close()
    writer_val.close()
Ejemplo n.º 23
def do_train_with_center(
        cfg,
        model,
        center_criterion,
        train_loader,
        val_loader,
        optimizer,
        optimizer_center,
        scheduler,
        loss_fn,
        num_query,
        start_epoch
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(model, center_criterion, optimizer, optimizer_center, loss_fn,
                                                    cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
    evaluator = create_supervised_evaluator(model, metrics={
        'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model.state_dict(),
    #                                                                 'optimizer': optimizer.state_dict(),
    #                                                                 'optimizer_center': optimizer_center.state_dict()})
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizer,
                                                                     'centerloss': center_criterion,
                                                                     'optimizer_center': optimizer_center})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, len(train_loader),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
Ejemplo n.º 24
def get_trainer(model, optimizer, lr_scheduler=None, logger=None, writer=None, non_blocking=False, log_period=10,
                save_dir="checkpoints", prefix="model", gallery_loader=None, query_loader=None,
                eval_interval=None, dataset="sysu"):
    if logger is None:
        logger = logging.getLogger()
        logger.setLevel(logging.WARN)

    # trainer
    trainer = create_train_engine(model, optimizer, non_blocking)

    # checkpoint handler
    handler = ModelCheckpoint(save_dir, prefix, save_interval=eval_interval, n_saved=3, create_dir=True,
                              save_as_state_dict=True, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {"model": model})

    # metric
    timer = Timer(average=True)

    kv_metric = AutoKVMetric()
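    # AutoKVMetric presumably accumulates whatever key/value dict the train engine
    # returns as its output (see kv_metric.update(engine.state.output) below).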

    # evaluator
    evaluator = None
    if not isinstance(eval_interval, int):
        raise TypeError("The parameter 'eval_interval' must be an int.")
    if eval_interval > 0 and gallery_loader is not None and query_loader is not None:
        evaluator = create_eval_engine(model, non_blocking)

    @trainer.on(Events.EPOCH_STARTED)
    def epoch_started_callback(engine):
        kv_metric.reset()
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def epoch_completed_callback(engine):
        epoch = engine.state.epoch

        if lr_scheduler is not None:
            lr_scheduler.step()

        if epoch % eval_interval == 0:
            logger.info("Model saved at {}/{}_model_{}.pth".format(save_dir, prefix, epoch))

        if evaluator and epoch % eval_interval == 0:
            torch.cuda.empty_cache()

            # extract query feature
            evaluator.run(query_loader)

            q_feats = torch.cat(evaluator.state.feat_list, dim=0)
            q_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
            q_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()

            # extract gallery feature
            evaluator.run(gallery_loader)

            g_feats = torch.cat(evaluator.state.feat_list, dim=0)
            g_ids = torch.cat(evaluator.state.id_list, dim=0).numpy()
            g_cams = torch.cat(evaluator.state.cam_list, dim=0).numpy()
            g_img_paths = np.concatenate(evaluator.state.img_path_list, axis=0)
            
            if dataset == "sysu":
                perm = sio.loadmat(os.path.join(dataset_cfg.sysu.data_root, 'exp', 'rand_perm_cam.mat'))[
                    'rand_perm_cam']
                mAP, r1, r5, _, _ = eval_sysu(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams, g_img_paths, perm)
            else:
                mAP, r1, r5, _, _ = eval_regdb(q_feats, q_ids, q_cams, g_feats, g_ids, g_cams)

            if writer is not None:
                writer.add_scalar('eval/mAP', mAP, epoch)
                writer.add_scalar('eval/r1', r1, epoch)
                writer.add_scalar('eval/r5', r5, epoch)

            evaluator.state.feat_list.clear()
            evaluator.state.id_list.clear()
            evaluator.state.cam_list.clear()
            evaluator.state.img_path_list.clear()
            del q_feats, q_ids, q_cams, g_feats, g_ids, g_cams

            torch.cuda.empty_cache()

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_complete_callback(engine):
        timer.step()

        kv_metric.update(engine.state.output)

        epoch = engine.state.epoch
        iteration = engine.state.iteration
        iter_in_epoch = iteration - (epoch - 1) * len(engine.state.dataloader)

        if iter_in_epoch % log_period == 0:
            batch_size = engine.state.batch[0].size(0)
            speed = batch_size / timer.value()

            msg = "Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec" % (epoch, iter_in_epoch, speed)

            metric_dict = kv_metric.compute()

            # log output information
            if logger is not None:
                for k in sorted(metric_dict.keys()):
                    msg += "\t%s: %.4f" % (k, metric_dict[k])
                    if writer is not None:
                        writer.add_scalar('metric/{}'.format(k), metric_dict[k], iteration)

                logger.info(msg)

            kv_metric.reset()
            timer.reset()

    return trainer
Ejemplo n.º 25
def run(opt):
    # logging.basicConfig(filename=os.path.join(opt.log_dir, opt.log_file), level=logging.INFO)
    # logger = logging.getLogger()
    # # logger.addHandler(logging.StreamHandler())
    # logger = logger.info

    log = Logger(filename=os.path.join(opt.log_dir, opt.log_file),
                 level='debug')
    logger = log.logger.info

    # Decide what attrs to train
    attr, attr_name = get_tasks(opt)

    # Generate model based on tasks
    logger('Loading models')
    model, parameters, mean, std = generate_model(opt, attr)
    # parameters[0]['lr'] = 0
    # parameters[1]['lr'] = opt.lr / 3

    logger('Loading dataset')
    train_loader, val_loader = get_data(opt, attr, mean, std)
    writer = create_summary_writer(model, train_loader, opt.log_dir)
    # wrap with DataParallel only after the summary writer has traced the (unwrapped) model
    model = nn.DataParallel(model, device_ids=None)
    # Learning configurations
    if opt.optimizer == 'sgd':
        optimizer = SGD(parameters,
                        lr=opt.lr,
                        momentum=opt.momentum,
                        weight_decay=opt.weight_decay,
                        nesterov=opt.nesterov)
    elif opt.optimizer == 'adam':
        optimizer = Adam(parameters, lr=opt.lr, betas=opt.betas)
    else:
        raise Exception("Optimizer '{}' is not supported".format(opt.optimizer))
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               'max',
                                               patience=opt.lr_patience,
                                               factor=opt.factor,
                                               min_lr=1e-6)

    # Loading checkpoint
    if opt.checkpoint:
        logger('loading checkpoint {}'.format(opt.checkpoint))
        checkpoint = torch.load(opt.checkpoint)

        opt.begin_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    device = 'cuda'
    loss_fns, metrics = get_losses_metrics(attr, opt.categorical_loss, opt.at,
                                           opt.at_loss)
    trainer = my_trainer(
        model,
        optimizer,
        lambda pred, target, epoch: multitask_loss(
            pred, target, loss_fns, len(attr_name), opt.at_coe, epoch),
        device=device)
    train_evaluator = create_supervised_evaluator(
        model,
        metrics={'multitask': MultiAttributeMetric(metrics, attr_name)},
        device=device)
    val_evaluator = create_supervised_evaluator(
        model,
        metrics={'multitask': MultiAttributeMetric(metrics, attr_name)},
        device=device)

    # Training timer handlers
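    # model_timer measures the forward/backward time of each iteration, while data_timer
    # swaps the resume/pause events so it measures the data-loading time between iterations.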
    model_timer, data_timer = Timer(average=True), Timer(average=True)
    model_timer.attach(trainer,
                       start=Events.EPOCH_STARTED,
                       resume=Events.ITERATION_STARTED,
                       pause=Events.ITERATION_COMPLETED,
                       step=Events.ITERATION_COMPLETED)
    data_timer.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      resume=Events.ITERATION_COMPLETED,
                      pause=Events.ITERATION_STARTED,
                      step=Events.ITERATION_STARTED)

    # Training log/plot handlers
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter_num = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter_num % opt.log_interval == 0:
            logger(
                "Epoch[{}] Iteration[{}/{}] Sum Loss: {:.2f} Cls Loss: {:.2f} At Loss: {:.2f} "
                "Coe: {:.2f} Model Process: {:.3f}s/batch Data Preparation: {:.3f}s/batch"
                .format(engine.state.epoch, iter_num, len(train_loader),
                        engine.state.output['sum'], engine.state.output['cls'],
                        engine.state.output['at'], engine.state.output['coe'],
                        model_timer.value(), data_timer.value()))
            writer.add_scalar("training/loss", engine.state.output['sum'],
                              engine.state.iteration)

    # Log/Plot Learning rate
    @trainer.on(Events.EPOCH_STARTED)
    def log_learning_rate(engine):
        lr = optimizer.param_groups[-1]['lr']
        logger('Epoch[{}] Starts with lr={}'.format(engine.state.epoch, lr))
        writer.add_scalar("learning_rate", lr, engine.state.epoch)

    # Checkpointing
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        if engine.state.epoch % opt.save_interval == 0:
            save_file_path = os.path.join(
                opt.log_dir, 'save_{}.pth'.format(engine.state.epoch))
            states = {
                'epoch': engine.state.epoch,
                'arch': opt.model,
                'state_dict': model.module.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(states, save_file_path)
            # model.eval()
            # example = torch.rand(1, 3, 224, 224)
            # traced_script_module = torch.jit.trace(model, example)
            # traced_script_module.save(save_file_path)
            # model.train()
            # torch.save(model._modules.state_dict(), save_file_path)

    # val_evaluator event handlers
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        data_list = [train_loader, val_loader]
        name_list = ['train', 'val']
        eval_list = [train_evaluator, val_evaluator]

        for data, name, evl in zip(data_list, name_list, eval_list):
            evl.run(data)
            metrics_info = evl.state.metrics["multitask"]

            for m, val in metrics_info['metrics'].items():
                writer.add_scalar(name + '_metrics/{}'.format(m), val,
                                  engine.state.epoch)

            for m, val in metrics_info['summaries'].items():
                writer.add_scalar(name + '_summary/{}'.format(m), val,
                                  engine.state.epoch)

            logger(
                name +
                " evaluation results - Epoch: {}".format(engine.state.epoch))
            print_summar_table(logger, attr_name, metrics_info['logger'])

            # Update Learning Rate
            if name == 'train':
                scheduler.step(metrics_info['logger']['attr']['ap'][-1])

    # kick everything off
    logger('Start training')
    trainer.run(train_loader, max_epochs=opt.n_epochs)

    writer.close()
Ejemplo n.º 26
def main(config, needs_save, study_name, k, n_splits, output_dir_path):
    if config.run.visible_devices:
        os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices

    seed = check_manual_seed(config.run.seed)
    print('Using seed: {}'.format(seed))

    train_data_loader, test_data_loader, data_train = get_k_hold_data_loader(
        config.dataset,
        k=k,
        n_splits=n_splits,
    )

    data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True)
    data_train = torch.t(data_train)

    model = get_model(config.model)
    model.cuda()
    model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    if config.optimizer.optimizer_name == 'Adam':
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            config.optimizer.lr,
            [0.9, 0.9999],
            weight_decay=config.optimizer.weight_decay,
        )
    else:
        raise NotImplementedError

    # scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)

    def update(engine, batch):
        model.train()

        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        if config.run.transposed_matrix == 'overall':
            x_t = data_train
        elif config.run.transposed_matrix == 'batch':
            x_t = torch.t(x)

        def closure():
            optimizer.zero_grad()

            if 'MLP' in config.model.model_name:
                out, x_hat = model(x)
            else:
                out, x_hat = model(x, x_t)

            l_discriminative = criterion(out, y)

            l_feature = torch.tensor(0.0).cuda()
            if config.run.w_feature_selection:
                l_feature += config.run.w_feature_selection * torch.sum(torch.abs(model.module.Ue))

            l_recon = torch.tensor(0.0).cuda()
            if config.run.w_reconstruction:
                l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat)

            l_total = l_discriminative + l_feature + l_recon

            l_total.backward()
            return l_total, l_discriminative, l_feature, l_recon, out

        l_total, l_discriminative, l_feature, l_recon, out = optimizer.step(closure)
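        # optimizer.step(closure) returns whatever the closure returns (Adam invokes the
        # closure once), so the individual losses and the logits can be unpacked here.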

        metrics = calc_metrics(out, y)

        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
            'l_feature': l_feature.item(),
            'l_recon': l_recon.item(),
        })

        torch.cuda.synchronize()

        return metrics

    def inference(engine, batch):
        model.eval()

        x = batch['data'].float().cuda(non_blocking=True)
        y = batch['label'].long().cuda(non_blocking=True)

        if config.run.transposed_matrix == 'overall':
            x_t = data_train
        elif config.run.transposed_matrix == 'batch':
            x_t = torch.t(x)

        with torch.no_grad():
            if 'MLP' in config.model.model_name:
                out, x_hat = model(x)
            else:
                out, x_hat = model(x, x_t)

            l_discriminative = criterion(out, y)

            l_feature = torch.tensor(0.0).cuda()
            if config.run.w_feature_selection:
                l_feature += config.run.w_feature_selection * torch.sum(torch.abs(model.module.Ue))

            l_recon = torch.tensor(0.0).cuda()
            if config.run.w_reconstruction:
                l_recon += config.run.w_reconstruction * F.mse_loss(x, x_hat)

            l_total = l_discriminative + l_feature + l_recon

        metrics = calc_metrics(out, y)

        metrics.update({
            'l_total': l_total.item(),
            'l_discriminative': l_discriminative.item(),
            'l_feature': l_feature.item(),
            'l_recon': l_recon.item(),
        })

        torch.cuda.synchronize()

        return metrics

    trainer = Engine(update)
    evaluator = Engine(inference)
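    # Both engines return a metrics dict from update/inference; the RunningAverage
    # transforms attached below read individual entries from that dict by key.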
    timer = Timer(average=True)

    monitoring_metrics = ['l_total', 'l_discriminative', 'l_feature', 'l_recon', 'accuracy']

    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(trainer, metric)

    for metric in monitoring_metrics:
        RunningAverage(
            alpha=0.98,
            output_transform=partial(lambda x, metric: x[metric], metric=metric)
        ).attach(evaluator, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)
    pbar.attach(evaluator, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def events_started(engine):
        if needs_save:
            save_config(config, seed, output_dir_path)

    @trainer.on(Events.EPOCH_COMPLETED)
    def switch_training_to_evaluation(engine):
        if needs_save:
            save_logs('train', k, n_splits, trainer, trainer.state.epoch, trainer.state.iteration,
                      config, output_dir_path)

        evaluator.run(test_data_loader, max_epochs=1)

    @evaluator.on(Events.EPOCH_COMPLETED)
    def switch_evaluation_to_training(engine):
        if needs_save:
            save_logs('val', k, n_splits, evaluator, trainer.state.epoch, trainer.state.iteration,
                      config, output_dir_path)

            if trainer.state.epoch % 100 == 0:
                save_models(model, optimizer, k, n_splits, trainer.state.epoch, trainer.state.iteration,
                            config, output_dir_path)

        # scheduler.step()

    @trainer.on(Events.EPOCH_COMPLETED)
    @evaluator.on(Events.EPOCH_COMPLETED)
    def show_logs(engine):
        columns = ['k', 'n_splits', 'epoch', 'iteration'] + list(engine.state.metrics.keys())
        values = [str(k), str(n_splits), str(engine.state.epoch), str(engine.state.iteration)] \
               + [str(value) for value in engine.state.metrics.values()]

        message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(epoch=engine.state.epoch,
                                                              max_epoch=config.run.n_epochs,
                                                              i=engine.state.iteration,
                                                              max_i=len(train_data_loader))

        for name, value in zip(columns, values):
            message += ' | {name}: {value}'.format(name=name, value=value)

        pbar.log_message(message)

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.run.n_epochs, config.run.n_epochs * len(train_data_loader))
    )

    trainer.run(train_data_loader, config.run.n_epochs)
Ejemplo n.º 27
def run(path, model_name, imgaugs, train_batch_size, val_batch_size,
        num_workers, epochs, optim, lr, lr_update_every, gamma, restart_every,
        restart_factor, init_lr_factor, lr_reduce_patience,
        early_stop_patience, log_interval, output, debug):

    print("--- Cifar10 Playground : Training --- ")

    from datetime import datetime
    now = datetime.now()
    log_dir = os.path.join(
        output, "training_{}_{}".format(model_name,
                                        now.strftime("%Y%m%d_%H%M")))
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    log_level = logging.INFO
    if debug:
        log_level = logging.DEBUG
        print("Activated debug mode")

    logger = logging.getLogger("Cifar10 Playground: Train")
    setup_logger(logger, log_dir, log_level)

    logger.debug("Setup tensorboard writer")
    writer = SummaryWriter(log_dir=os.path.join(log_dir, "tensorboard"))

    save_conf(logger, writer, model_name, imgaugs, train_batch_size,
              val_batch_size, num_workers, epochs, optim, lr, lr_update_every,
              gamma, restart_every, restart_factor, init_lr_factor,
              lr_reduce_patience, early_stop_patience, log_dir)

    device = 'cpu'
    if torch.cuda.is_available():
        logger.debug("CUDA is enabled")
        from torch.backends import cudnn
        cudnn.benchmark = True
        device = 'cuda'

    logger.debug("Setup model: {}".format(model_name))

    if not os.path.isfile(model_name):
        assert model_name in MODEL_MAP, "Model name not in {}".format(
            MODEL_MAP.keys())
        model = MODEL_MAP[model_name](num_classes=10)
    else:
        model = torch.load(model_name)

    model_name = model.__class__.__name__
    if 'cuda' in device:
        model = model.to(device)

    logger.debug("Setup train/val dataloaders")
    train_loader, val_loader = get_data_loaders(path,
                                                imgaugs,
                                                train_batch_size,
                                                val_batch_size,
                                                num_workers,
                                                device=device)

    write_model_graph(writer, model, train_loader, device=device)

    logger.debug("Setup optimizer")
    assert optim in OPTIMIZER_MAP, "Optimizer name not in {}".format(
        OPTIMIZER_MAP.keys())
    optimizer = OPTIMIZER_MAP[optim](model.parameters(), lr=lr)

    logger.debug("Setup criterion")
    criterion = nn.CrossEntropyLoss()
    if 'cuda' in device:
        criterion = criterion.cuda()

    lr_scheduler = ExponentialLR(optimizer, gamma=gamma)
    lr_scheduler_restarts = LRSchedulerWithRestart(
        lr_scheduler,
        restart_every=restart_every,
        restart_factor=restart_factor,
        init_lr_factor=init_lr_factor)
    reduce_on_plateau = ReduceLROnPlateau(optimizer,
                                          mode='min',
                                          factor=0.1,
                                          patience=lr_reduce_patience,
                                          threshold=0.01,
                                          verbose=True)

    logger.debug("Setup ignite trainer and evaluator")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)

    metrics = {
        'accuracy': CategoricalAccuracy(),
        'precision': Precision(),
        'recall': Recall(),
        'nll': Loss(criterion)
    }
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    val_evaluator = create_supervised_evaluator(model,
                                                metrics=metrics,
                                                device=device)

    logger.debug("Setup handlers")
    # Setup timer to measure training time
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_interval == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.4f}".format(
                engine.state.epoch, iter, len(train_loader),
                engine.state.output))

            writer.add_scalar("training/loss_vs_iterations",
                              engine.state.output, engine.state.iteration)

    @trainer.on(Events.EPOCH_STARTED)
    def update_lr_schedulers(engine):
        if (engine.state.epoch - 1) % lr_update_every == 0:
            lr_scheduler_restarts.step()

    @trainer.on(Events.EPOCH_STARTED)
    def log_lrs(engine):
        if len(optimizer.param_groups) == 1:
            lr = float(optimizer.param_groups[0]['lr'])
            writer.add_scalar("learning_rate", lr, engine.state.epoch)
            logger.debug("Learning rate: {}".format(lr))
        else:
            for i, param_group in enumerate(optimizer.param_groups):
                lr = float(param_group['lr'])
                logger.debug("Learning rate (group {}): {}".format(i, lr))
                writer.add_scalar("learning_rate_group_{}".format(i), lr,
                                  engine.state.epoch)

    log_images_dir = os.path.join(log_dir, "figures")
    os.makedirs(log_images_dir)

    def log_precision_recall_results(metrics, epoch, mode):
        for metric_name in ['precision', 'recall']:
            value = metrics[metric_name]
            avg_value = torch.mean(value).item()
            writer.add_scalar("{}/avg_{}".format(mode, metric_name), avg_value,
                              epoch)
            # Save metric per class figure
            sorted_values = value.to('cpu').numpy()
            indices = np.argsort(sorted_values)
            sorted_values = sorted_values[indices]
            n_classes = len(sorted_values)
            classes = np.array(
                ["class_{}".format(i) for i in range(n_classes)])
            sorted_classes = classes[indices]
            fig = create_fig_param_per_class(sorted_values,
                                             metric_name,
                                             classes=sorted_classes,
                                             n_classes_per_fig=20)
            fname = os.path.join(
                log_images_dir,
                "{}_{}_{}_per_class.png".format(mode, epoch, metric_name))
            fig.savefig(fname)
            tag = "{}_{}".format(mode, metric_name)
            writer.add_figure(tag, fig, epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_metrics(engine):
        epoch = engine.state.epoch
        logger.info("One epoch training time (seconds): {}".format(
            timer.value()))
        metrics = train_evaluator.run(train_loader).metrics
        logger.info(
            "Training Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f}"
            .format(engine.state.epoch, metrics['accuracy'], metrics['nll']))
        writer.add_scalar("training/avg_accuracy", metrics['accuracy'], epoch)
        writer.add_scalar("training/avg_error", 1.0 - metrics['accuracy'],
                          epoch)
        writer.add_scalar("training/avg_loss", metrics['nll'], epoch)
        log_precision_recall_results(metrics, epoch, "training")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        epoch = engine.state.epoch
        metrics = val_evaluator.run(val_loader).metrics
        writer.add_scalar("validation/avg_loss", metrics['nll'], epoch)
        writer.add_scalar("validation/avg_accuracy", metrics['accuracy'],
                          epoch)
        writer.add_scalar("validation/avg_error", 1.0 - metrics['accuracy'],
                          epoch)
        logger.info(
            "Validation Results - Epoch: {}  Avg accuracy: {:.4f} Avg loss: {:.4f}"
            .format(engine.state.epoch, metrics['accuracy'], metrics['nll']))
        log_precision_recall_results(metrics, epoch, "validation")

    @val_evaluator.on(Events.COMPLETED)
    def update_reduce_on_plateau(engine):
        val_loss = engine.state.metrics['nll']
        reduce_on_plateau.step(val_loss)

    def score_function(engine):
        val_loss = engine.state.metrics['nll']
        # Objects with highest scores will be retained.
        return -val_loss

    # Setup early stopping:
    handler = EarlyStopping(patience=early_stop_patience,
                            score_function=score_function,
                            trainer=trainer)
    setup_logger(handler._logger, log_dir, log_level)
    val_evaluator.add_event_handler(Events.COMPLETED, handler)

    # Setup model checkpoint:
    best_model_saver = ModelCheckpoint(log_dir,
                                       filename_prefix="model",
                                       score_name="val_loss",
                                       score_function=score_function,
                                       n_saved=5,
                                       atomic=True,
                                       create_dir=True)
    val_evaluator.add_event_handler(Events.COMPLETED, best_model_saver,
                                    {model_name: model})

    last_model_saver = ModelCheckpoint(log_dir,
                                       filename_prefix="checkpoint",
                                       save_interval=1,
                                       n_saved=1,
                                       atomic=True,
                                       create_dir=True)
    trainer.add_event_handler(Events.COMPLETED, last_model_saver,
                              {model_name: model})

    logger.info("Start training: {} epochs".format(epochs))
    try:
        trainer.run(train_loader, max_epochs=epochs)
    except KeyboardInterrupt:
        logger.info("Caught KeyboardInterrupt -> exit")
    except Exception as e:  # noqa
        logger.exception("")
        if debug:
            try:
                # open an ipython shell if possible
                import IPython
                IPython.embed()  # noqa
            except ImportError:
                print("Failed to start IPython console")

    logger.debug("Training is ended")
    writer.close()
Ejemplo n.º 28
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
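    # Project-specific trainer factory: gamma, margin and beta are presumably loss
    # hyper-parameters forwarded to loss_fn.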
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        gamma=cfg.MODEL.GAMMA,
                                        margin=cfg.SOLVER.MARGIN,
                                        beta=cfg.MODEL.BETA)
    if cfg.TEST.PAIR == "no":
        evaluator = create_supervised_evaluator(
            model,
            metrics={
                'r1_mAP': R1_mAP(1, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
            },
            device=device)
    elif cfg.TEST.PAIR == "yes":
        evaluator = create_supervised_evaluator(
            model,
            metrics={
                'r1_mAP': R1_mAP_pair(1,
                                      max_rank=50,
                                      feat_norm=cfg.TEST.FEAT_NORM)
            },
            device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=10,
                                   require_empty=False)
    # checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    dataset = init_dataset(cfg.DATASETS.NAMES, root=cfg.DATASETS.ROOT_DIR)

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1
        if ITER % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, ITER, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        # multi_person_training_info2()
        train_loader, val_loader, num_query, num_classes = make_data_loader_train(
            cfg)
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        # if engine.state.epoch % eval_period == 0:
        if engine.state.epoch >= eval_period:
            all_cmc = []
            all_AP = []
            num_valid_q = 0
            q_pids = []
            for query_index in tqdm(range(num_query)):
                val_loader = make_data_loader_val(cfg, query_index, dataset)
                evaluator.run(val_loader)
                cmc, AP, q_pid = evaluator.state.metrics['r1_mAP']

                # Keep only queries with a valid AP and a full-length CMC curve.
                if AP < 0 or cmc.shape[0] < 50:
                    continue
                num_valid_q += 1
                all_cmc.append(cmc)
                all_AP.append(AP)
                q_pids.append(int(q_pid))

            all_cmc = np.asarray(all_cmc).astype(np.float32)
            cmc = all_cmc.sum(0) / num_valid_q
            mAP = np.mean(all_AP)
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
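The validation handler above collects one (cmc, AP) pair per query and then averages them into a single CMC curve and a mAP value. A minimal, self-contained sketch of that aggregation step follows; the function name and the (cmc, AP) input format are assumptions made for illustration, not part of the example above.

import numpy as np

def aggregate_query_results(per_query_results, max_rank=50):
    """Average per-query CMC curves and APs into one CMC curve and one mAP.

    per_query_results: iterable of (cmc, AP) pairs; AP is negative when the
    query produced no valid gallery match (mirroring the checks above).
    """
    all_cmc, all_ap = [], []
    for cmc, ap in per_query_results:
        # Skip invalid queries and truncated CMC curves, as in the handler.
        if ap < 0 or cmc.shape[0] < max_rank:
            continue
        all_cmc.append(cmc[:max_rank])
        all_ap.append(ap)
    if not all_cmc:
        raise ValueError("no valid queries to aggregate")
    mean_cmc = np.asarray(all_cmc, dtype=np.float32).sum(0) / len(all_cmc)
    return mean_cmc, float(np.mean(all_ap))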
Example No. 29
0
def main(
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         "glow",
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

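    # Data-dependent initialization: before the first gradient step, push a few
    # batches through the model under torch.no_grad() so the ActNorm layers can
    # set their scale and bias from the observed activation statistics.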
    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
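The lr_lambda above implements a linear warmup: the multiplier applied to the base learning rate grows from 1/warmup at the first epoch to 1.0 after warmup epochs, then stays flat. Below is a standalone sketch of the same schedule; the base_lr and warmup values are placeholders chosen for illustration, not taken from the example.

import torch

base_lr, warmup = 5e-4, 5  # placeholder values for illustration
optimizer = torch.optim.Adamax([torch.nn.Parameter(torch.zeros(1))], lr=base_lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup))

for epoch in range(8):
    # epoch 0 -> base_lr / 5, ..., epoch 4 and onwards -> base_lr
    print(epoch, scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()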
Example No. 30
0
def run(cfg, train_loader, tr_comp, saver, trainer, evaler, tb_log=None):
    # saver.checkpoint_params = {'trainer': trainer,
    #                  'model': tr_comp.model}
    # 'optimizer': tr_comp.optimizer,
    # 'center_param': tr_comp.loss_center,
    # 'optimizer_center': tr_comp.optimizer_center}
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=cfg.SAVER.CHECKPOINT_PERIOD),
        saver.train_checkpointer, saver.checkpoint_params)
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    # average metric to attach on trainer
    names = ["Acc", "Loss"]
    names.extend(tr_comp.loss.loss_function_map.keys())
    for n in names:
        RunningAverage(output_transform=Run(n)).attach(trainer, n)

    # TODO start epoch
    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = 0

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        tr_comp.loss.scheduler_step()
        tr_comp.scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.TRAIN.LOG_ITER_PERIOD))
    def log_training_loss(engine):
        message = f"Epoch[{engine.state.epoch}], " + \
                  f"Iteration[{engine.state.iteration}/{len(train_loader)}], " + \
                  f"Base Lr: {tr_comp.scheduler.get_last_lr()[0]:.2e}, "

        if tr_comp.loss.xent and tr_comp.loss.xent.learning_weight:
            message += f"xentWeight: {tr_comp.loss.xent.uncertainty.mean().item():.4f}, "

        for loss_name in engine.state.metrics.keys():
            message += f"{loss_name}: {engine.state.metrics[loss_name]:.4f}, "

        logger.info(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 80)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED(every=cfg.EVAL.EPOCH_PERIOD),
                saver=saver)
    def log_validation_results(engine, saver):
        logger.info(f"Valid - Epoch: {engine.state.epoch}")

        sum_result = evaler.eval_multi_dataset()

        if saver.best_result < sum_result:
            logger.info(f'Save best: {sum_result:.4f}')
            saver.save_best_value(sum_result)
            saver.best_checkpointer(engine, saver.checkpoint_params)
            saver.best_result = sum_result
        else:
            logger.info(
                f"Not best: {saver.best_result:.4f} > {sum_result:.4f}")
        logger.info('-' * 80)

    if tb_log:
        tb_log.attach_handler(trainer, tr_comp.model, tr_comp.optimizer)
    # self.tb_logger.attach(
    #     validation_evaluator,
    #     log_handler=ReIDOutputHandler(tag="valid", metric_names=["r1_mAP"], another_engine=trainer),
    #     event_name=Events.EPOCH_COMPLETED,
    # )
    trainer.run(train_loader, max_epochs=cfg.TRAIN.MAX_EPOCHS)
    if tb_log:
        tb_log.close()
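Run(n), used as the output_transform when attaching the running averages above, is not defined in this snippet; it is presumably a small callable that extracts the value named n from each training-step output. A hypothetical equivalent is sketched below; the class name MetricPicker and the dict-shaped step output are assumptions.

from ignite.metrics import RunningAverage

class MetricPicker:
    """Callable output_transform that picks one named value out of the
    dict assumed to be returned by the training step."""

    def __init__(self, name):
        self.name = name

    def __call__(self, output):
        return output[self.name]

# Usage, mirroring the loop in run() above (trainer not defined here):
# for n in ["Acc", "Loss"]:
#     RunningAverage(output_transform=MetricPicker(n)).attach(trainer, n)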