Example #1
def main(
    model_name,
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
):

    vis = visdom.Visdom()
    env = "{}_{}".format(model_name, dataset)

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: multi-class y-conditioning is unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    if model_name == "Glow":
        model = Glow(
            image_shape,
            hidden_channels,
            K,
            L,
            actnorm_scale,
            flow_permutation,
            flow_coupling,
            LU_decomposed,
            num_classes,
            learn_top,
            y_condition,
        )
    elif model_name == "VAE":
        model = VAE(
            image_shape,
            hidden_channels,
        )

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    train_loss_window = create_plot_window(vis, env, '#Iterations', 'Loss',
                                           'Training Loss')
    val_avg_loss_window = create_plot_window(vis, env, '#Epochs', 'Loss',
                                             'Validation Average Loss')
    train_image_window = create_image_window(vis, env, 'Training Images')

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
            im = None
        else:
            z, nll, y_logits, im = model(x)
            losses = compute_loss(nll)
        if engine.state.iteration % 250 == 1:
            vis.line(X=np.array([engine.state.iteration]),
                     Y=np.array([losses["total_loss"].item()]),
                     win=train_loss_window,
                     update='append',
                     env=env)
            # reconstructed images are only returned in the unconditional branch above
            if im is not None:
                vis.images(postprocess(im),
                           nrow=16,
                           win=train_image_window,
                           env=env)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            else:
                z, nll, y_logits, im = model(x)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         model_name,
                                         save_interval=1,
                                         n_saved=5,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
                model(init_batches, init_targets)
            else:
                init_targets = None
                model(init_batches)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])
        vis.line(X=np.array([engine.state.epoch]),
                 Y=np.array([metrics["total_loss"]]),
                 win=val_avg_loss_window,
                 update='append',
                 env=env)
        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
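The visdom helpers `create_plot_window` and `create_image_window` used above are not shown in this example. Below is a minimal sketch, assuming only the call signatures seen here; the real implementations may differ.

import numpy as np

def create_plot_window(vis, env, xlabel, ylabel, title):
    # Create an empty line plot and return its window handle, which is later
    # passed to vis.line(..., win=..., update='append').
    return vis.line(X=np.array([1]), Y=np.array([np.nan]), env=env,
                    opts=dict(xlabel=xlabel, ylabel=ylabel, title=title))

def create_image_window(vis, env, title):
    # Create a placeholder image grid; vis.images(..., win=...) overwrites it each time.
    return vis.images(np.zeros((1, 3, 32, 32)), env=env, opts=dict(caption=title))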
Example #2
def do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        criterion,
        num_query,
        start_epoch
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline")
    logger.info("Start training")

    trainer = create_supervised_trainer(model, optimizer, criterion, cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)

    if cfg.TEST.PARTIAL_REID == 'off':
        evaluator = create_supervised_evaluator(model, metrics={
            'r1_mAP_mINP': r1_mAP_mINP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    else:
        evaluator_reid = create_supervised_evaluator(model,
                                                     metrics={'r1_mAP_mINP': r1_mAP_mINP(300, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
                                                     device=device)
        evaluator_ilids = create_supervised_evaluator(model,
                                                      metrics={'r1_mAP_mINP': r1_mAP_mINP(119, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)},
                                                      device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period, n_saved=10, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {'model': model,
                                                                     'optimizer': optimizer['model'],
                                                                     'center_param': criterion['center'],
                                                                     'optimizer_center': optimizer['center']})
    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1

        if ITER % log_period == 0:
            logger.info("Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                        .format(engine.state.epoch, ITER, len(data_loader['train']),
                                engine.state.metrics['avg_loss'], engine.state.metrics['avg_acc'],
                                scheduler.get_lr()[0]))
        if len(data_loader['train']) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
                    .format(engine.state.epoch, timer.value() * timer.step_count,
                            data_loader['train'].batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            if cfg.TEST.PARTIAL_REID == 'off':
                evaluator.run(data_loader['eval'])
                cmc, mAP, mINP = evaluator.state.metrics['r1_mAP_mINP']
                logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
                logger.info("mINP: {:.1%}".format(mINP))
                logger.info("mAP: {:.1%}".format(mAP))
                for r in [1, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
            else:
                evaluator_reid.run(data_loader['eval_reid'])
                cmc, mAP, mINP = evaluator_reid.state.metrics['r1_mAP_mINP']
                logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
                logger.info("mINP: {:.1%}".format(mINP))
                logger.info("mAP: {:.1%}".format(mAP))
                for r in [1, 3, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))


                evaluator_ilids.run(data_loader['eval_ilids'])
                cmc, mAP, mINP = evaluator_ilids.state.metrics['r1_mAP_mINP']
                logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
                logger.info("mINP: {:.1%}".format(mINP))
                logger.info("mAP: {:.1%}".format(mAP))
                for r in [1, 3, 5, 10]:
                    logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(data_loader['train'], max_epochs=epochs)
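`create_supervised_trainer` here is not ignite's built-in factory: it also receives a criterion and a center-loss weight, and `optimizer` is a dict with 'model' and 'center' entries. The following is only a sketch of what such a factory could look like; the `criterion['total']` key and the `score, feat = model(img)` output convention are assumptions, not the actual reid_baseline implementation.

from ignite.engine import Engine

def create_supervised_trainer(model, optimizer, criterion, center_loss_weight, device=None):
    def _update(engine, batch):
        model.train()
        img, target = batch
        img, target = img.to(device), target.to(device)
        optimizer['model'].zero_grad()
        optimizer['center'].zero_grad()
        score, feat = model(img)
        loss = criterion['total'](score, feat, target)
        loss.backward()
        optimizer['model'].step()
        # undo the weighting applied to the center-loss gradients before stepping its optimizer
        for param in criterion['center'].parameters():
            param.grad.data *= (1.0 / center_loss_weight)
        optimizer['center'].step()
        acc = (score.max(1)[1] == target).float().mean()
        # the (loss, acc) output feeds the RunningAverage transforms x[0] and x[1] below
        return loss.item(), acc.item()
    return Engine(_update)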
Example #3
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path", type=str, default='data/spolin-train-acl.json', help="Set data path")    
    parser.add_argument("--valid_path", type=str, default='data/spolin-valid.json', help="Set data path")     

    parser.add_argument("--correct_bias", type=bool, default=False, help="Set to true to correct bias for Adam optimizer")
    parser.add_argument("--lr", type=float, default=2e-5, help="Set learning rate")
    parser.add_argument("--n_epochs", type=int, default=4, help="Set number of epochs")
    parser.add_argument("--num_warmup_steps", type=float, default=1000, help="Set number of warm-up steps")
    parser.add_argument("--num_total_steps", type=float, default=10000, help="Set number of total steps")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Set maximum gradient normalization.")
    parser.add_argument("--pretrained_path", type=str, default='bert-base-uncased', help="Choose which pretrained model to use (bert-base-uncased, roberta-base, roberta-large, roberta-large-mnli)")    
    parser.add_argument("--batch_size", type=int, default=32, help="Provide the batch size")    
    parser.add_argument("--random_seed", type=int, default=42, help="Set the random seed")
    parser.add_argument("--test", action='store_true', help="If true, run with small dataset for testing code")
    parser.add_argument("--base", action='store_true', help="If true, run with base experiment configuration (training with spont only) for comparison")

    args = parser.parse_args() 

    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: {}".format(pformat(args)))

    if 'roberta' in args.pretrained_path: 
        # initialize tokenizer and model 
        logger.info("Initialize model and tokenizer.")
        tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_path, cache_dir = '../pretrained_models')
        model = RobertaForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

        ### START MODEL MODIFICATION
        # The pretrained model was not trained with token type ids,
        # so fix the token type embeddings for fine-tuning. Without this, the model can only take 0s as valid input for token_type_ids.
        model.config.type_vocab_size = 2 
        model.roberta.embeddings.token_type_embeddings = torch.nn.Embedding(2, model.config.hidden_size)
        model.roberta.embeddings.token_type_embeddings.weight.data.normal_(mean=0.0, std=model.config.initializer_range)

        ### END MOD
    elif 'bert' in args.pretrained_path: 
        model = BertForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')
        tokenizer = BertTokenizer.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

    model.to(args.device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                        lr=args.lr,
                        correct_bias = args.correct_bias)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.num_warmup_steps, t_total=args.num_total_steps) 

    logger.info("Prepare datasets")
    logger.info("Loading train set...")

    train_data = get_data(args.train_path)
    valid_data = get_data(args.valid_path)

    cornell_valid_data = {k: {'cornell': valid_data[k]['cornell']} for k in valid_data.keys()}
    spont_valid_data = {k: {'spont': valid_data[k]['spont']} for k in valid_data.keys()}

    train_loader, train_sampler = get_data_loaders(args, train_data, args.train_path, tokenizer)
    logger.info("Loading validation set...")
    valid_p = Path(args.valid_path)
    cornell_valid_loader, cornell_valid_sampler = get_data_loaders(args, cornell_valid_data, f"{str(valid_p.parent)}/cornell_{valid_p.name}",  tokenizer)
    spont_valid_loader, spont_valid_sampler = get_data_loaders(args, spont_valid_data, f"{str(valid_p.parent)}/spont_{valid_p.name}", tokenizer)


    # Training function and trainer 
    def update(engine, batch): 
        model.train() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch

        optimizer.zero_grad()
        # RoBERTa has issues with token_type_ids
        loss, logits = model(b_input_ids, token_type_ids=b_input_segment, attention_mask=b_input_mask, labels=b_labels)
        # loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)


        loss.backward() 
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        
        optimizer.step() 
        scheduler.step() 

        return loss.item(), logits, b_labels

    trainer = Engine(update)     

    # Evaluation function and evaluator 
    def inference(engine, batch): 
        model.eval() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch
        
        with torch.no_grad(): 
            # RoBERTa has issues with token_type_ids
            # loss, logits = model(b_input_ids, token_type_ids = None, attention_mask=b_input_mask, labels=b_labels)
            loss, logits = model(b_input_ids, token_type_ids = b_input_segment, attention_mask=b_input_mask, labels=b_labels)
            label_ids = b_labels

        return logits, label_ids, loss.item()
    cornell_evaluator = Engine(inference)
    spont_evaluator = Engine(inference)


    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: cornell_evaluator.run(cornell_valid_loader))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: spont_evaluator.run(spont_valid_loader))


    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") 
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, "accuracy")
    if torch.cuda.is_available(): 
        GpuInfo().attach(trainer, name='gpu')

    recall = Recall(output_transform=lambda x: (x[0], x[1]))
    precision = Precision(output_transform=lambda x: (x[0], x[1]))
    F1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy(output_transform=lambda x: (x[0], x[1]))
    metrics = {"recall": recall, "precision": precision, "f1": F1, "accuracy": accuracy, "loss": Average(output_transform=lambda x: x[2])}

    for name, metric in metrics.items(): 
        metric.attach(cornell_evaluator, name) 
        metric.attach(spont_evaluator, name) 


    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss', 'accuracy'])
    pbar.attach(trainer, metric_names=['gpu:0 mem(%)', 'gpu:0 util(%)'])
    
    cornell_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Cornell validation metrics:\n %s" % pformat(cornell_evaluator.state.metrics)))
    spont_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Spont validation metrics:\n %s" % pformat(spont_evaluator.state.metrics)))


    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
    tb_logger.attach(cornell_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(spont_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)


    # tb_logger.writer.log_dir -> tb_logger.writer.logdir (this is the correct attribute name as seen in: https://tensorboardx.readthedocs.io/en/latest/_modules/tensorboardX/writer.html#SummaryWriter)
    checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=5)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

    torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
    getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
    tokenizer.save_vocabulary(tb_logger.writer.logdir)

    trainer.run(train_loader, max_epochs = args.n_epochs)

    if args.n_epochs > 0: 
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
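The F1 metric above is built with ignite's metric arithmetic: combining Metric objects with *, / and + yields a MetricsLambda that is computed from the attached precision and recall. A self-contained sketch of the same composition:

from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
# per-class F1, then averaged; attaching f1 to an engine implicitly attaches
# precision and recall as its dependencies
f1 = (precision * recall * 2 / (precision + recall)).mean()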
Example #4
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # can't use AutoTokenizer because the checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, position_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels,
                                         position_ids=position_ids)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
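`average_distributed_scalar` is not defined in this example; below is a minimal sketch of the helper it appears to expect, assuming `args` carries `local_rank` and `device` as parsed above.

import torch

def average_distributed_scalar(scalar, args):
    # Average a scalar over all processes when training is distributed,
    # otherwise return it unchanged.
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()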
Example #5
def train(model,
          train_loader,
          eval_loaders,
          optimizer,
          loss_fn,
          n_it_max,
          patience,
          split_names,
          select_metric='Val accuracy_0',
          select_mode='max',
          viz=None,
          device='cpu',
          lr_scheduler=None,
          name=None,
          log_steps=None,
          log_epoch=False,
          _run=None,
          prepare_batch=_prepare_batch,
          single_pass=False,
          n_ep_max=None):

    # print(model)

    if not log_steps and not log_epoch:
        logger.warning('/!\\ No logging during training /!\\')

    if log_steps is None:
        log_steps = []

    epoch_steps = len(train_loader)
    if log_epoch:
        log_steps.append(epoch_steps)

    if single_pass:
        max_epoch = 1
    elif n_ep_max is None:
        assert n_it_max is not None
        max_epoch = int(n_it_max / epoch_steps) + 1
    else:
        assert n_it_max is None
        max_epoch = n_ep_max

    all_metrics = defaultdict(dict)
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device,
                                        prepare_batch=prepare_batch)

    if hasattr(model, 'new_epoch_hook'):
        trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook)
    if hasattr(model, 'new_iter_hook'):
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  model.new_iter_hook)

    trainer._logger.setLevel(logging.WARNING)

    # trainer output is in the format (x, y, y_pred, loss, optionals)
    train_loss = RunningAverage(output_transform=lambda out: out[3].item(),
                                epoch_bound=True)
    train_loss.attach(trainer, 'Trainer loss')
    if hasattr(model, 's'):
        met = Average(output_transform=lambda _: float('nan')
                      if model.s is None else model.s)
        met.attach(trainer, 'cur_s')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'cur_s')

    if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0:
        met = Average(output_transform=lambda _: float('nan')
                      if model.cur_split is None else model.cur_split)
        met.attach(trainer, 'Trainer split')
        trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed,
                                  'Trainer split')
        # trainer.add_event_handler(Events.EPOCH_STARTED, met.started)
        all_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_avg'].item())
        all_ent.attach(trainer, 'Trainer all entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  all_ent.completed, 'Trainer all entropy')
        train_ent = Average(
            output_transform=lambda out: out[-1]['arch_entropy_sample'].item())
        train_ent.attach(trainer, 'Trainer sampling entropy')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_ent.completed,
                                  'Trainer sampling entropy')
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, lambda engine: model.check_arch_freezing(
                ent=train_ent.compute(),
                epoch=engine.state.iteration / (epoch_steps * max_epoch)))

        def log_always(engine, name):
            val = engine.state.output[-1][name]
            all_metrics[name][engine.state.iteration /
                              epoch_steps] = val.mean().item()

        def log_always_dict(engine, name):
            for node, val in engine.state.output[-1][name].items():
                all_metrics['node {} {}'.format(
                    node, name)][engine.state.iteration /
                                 epoch_steps] = val.mean().item()

        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='arch_probas')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always_dict,
                                  name='node_grads')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='task all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='arch all_loss')
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  log_always,
                                  name='entropy all_loss')

    if n_it_max is not None:
        StopAfterIterations([n_it_max]).attach(trainer)
    # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                          persist=True, disable=not (_run or viz))
    # epoch_pbar.attach(trainer, metric_names=['Train loss'])
    #
    # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name,
    #                             persist=True, disable=not (_run or viz))
    # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED,
    #                      closing_event_name=Events.COMPLETED)
    total_time = Timer(average=False)
    eval_time = Timer(average=False)
    eval_time.pause()
    data_time = Timer(average=False)
    forward_time = Timer(average=False)
    forward_time.attach(trainer,
                        start=Events.EPOCH_STARTED,
                        pause=Events.ITERATION_COMPLETED,
                        resume=Events.ITERATION_STARTED,
                        step=Events.ITERATION_COMPLETED)
    epoch_time = Timer(average=False)
    epoch_time.attach(trainer,
                      start=Events.EPOCH_STARTED,
                      pause=Events.EPOCH_COMPLETED,
                      resume=Events.EPOCH_STARTED,
                      step=Events.EPOCH_COMPLETED)

    def get_loss(y_pred, y):
        l = loss_fn(y_pred, y)
        if not torch.is_tensor(l):
            l, *l_details = l
        return l.mean()

    def get_member(x, n=0):
        if isinstance(x, (list, tuple)):
            return x[n]
        return x

    eval_metrics = {'loss': Loss(get_loss)}

    for i in range(model.n_out):
        out_trans = get_attr_transform(i)

        def extract_ys(out, out_trans=out_trans):
            # bind out_trans at definition time so each output index keeps its own transform
            x, y, y_pred, loss, _ = out
            return out_trans((y_pred, y))

        train_acc = Accuracy(extract_ys)
        train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i))
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  train_acc.completed,
                                  'Trainer accuracy_{}'.format(i))
        eval_metrics['accuracy_{}'.format(i)] = \
            Accuracy(output_transform=out_trans)
        # if isinstance(model, SSNWrapper):
        #     model.arch_sampler.entropy().mean()

    evaluator = create_supervised_evaluator(model,
                                            metrics=eval_metrics,
                                            device=device,
                                            prepare_batch=prepare_batch)
    last_iteration = 0
    patience_counter = 0

    best = {
        'value': float('inf') * (1 if select_mode == 'min' else -1),
        'iter': -1,
        'state_dict': None
    }

    def is_better(new, old):
        if select_mode == 'min':
            return new < old
        else:
            return new > old

    def log_results(evaluator, data_loader, iteration, split_name):
        evaluator.run(data_loader)
        metrics = evaluator.state.metrics

        log_metrics = {}

        for metric_name, metric_val in metrics.items():
            log_name = '{} {}'.format(split_name, metric_name)
            if viz:
                first = iteration == 0 and split_name == split_names[0]
                viz.line(
                    [metric_val],
                    X=[iteration],
                    win=metric_name,
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'iterations'
                    })
                viz.line(
                    [metric_val],
                    X=[iteration / epoch_steps],
                    win='{}epoch'.format(metric_name),
                    name=log_name,
                    update=None if first else 'append',
                    opts={
                        'title': metric_name,
                        'showlegend': True,
                        'width': 500,
                        'xlabel': 'epoch'
                    })
            if _run:
                _run.log_scalar(log_name, metric_val, iteration)
            log_metrics[log_name] = metric_val
            all_metrics[log_name][iteration] = metric_val

        return log_metrics

    if lr_scheduler is not None:

        @trainer.on(Events.EPOCH_COMPLETED)
        def step(_):
            lr_scheduler.step()
            # logger.warning('current lr {:.5e}'.format(
            #     optimizer.param_groups[0]['lr']))

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_event(trainer):
        iteration = trainer.state.iteration if trainer.state else 0
        nonlocal last_iteration, patience_counter, best

        if not log_steps or not \
                (iteration in log_steps or iteration % log_steps[-1] == 0):
            return
        epoch_time.pause()
        eval_time.resume()
        all_metrics['training_epoch'][iteration] = iteration / epoch_steps
        all_metrics['training_iteration'][iteration] = iteration
        if hasattr(model, 'arch_sampler'):
            all_metrics['training_archs'][iteration] = \
                model.arch_sampler().squeeze().detach()
        # if hasattr(model, 'distrib_gen'):
        #     entropy = model.distrib_gen.entropy()
        #     all_metrics['entropy'][iteration] = entropy.mean().item()
        # if trainer.state and len(trainer.state.metrics) > 1:
        #     raise ValueError(trainer.state.metrics)
        all_metrics['data time'][iteration] = data_time.value()
        all_metrics['data time_ps'][iteration] = data_time.value() / max(
            data_time.step_count, 1.)
        all_metrics['forward time'][iteration] = forward_time.value()
        all_metrics['forward time_ps'][iteration] = forward_time.value() / max(
            forward_time.step_count, 1.)
        all_metrics['epoch time'][iteration] = epoch_time.value()
        all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(
            epoch_time.step_count, 1.)

        if trainer.state:
            # logger.warning(trainer.state.metrics)
            for metric, value in trainer.state.metrics.items():
                all_metrics[metric][iteration] = value
                if viz:
                    viz.line(
                        [value],
                        X=[iteration],
                        win=metric.split()[-1],
                        name=metric,
                        update=None if iteration == 0 else 'append',
                        opts={
                            'title': metric,
                            'showlegend': True,
                            'width': 500,
                            'xlabel': 'iterations'
                        })

        iter_this_step = iteration - last_iteration
        for d_loader, name in zip(eval_loaders, split_names):
            if name == 'Train':
                if iteration == 0:
                    all_metrics['Trainer loss'][iteration] = float('nan')
                    all_metrics['Trainer accuracy_0'][iteration] = float('nan')
                    if hasattr(model, 'arch_sampler'):
                        all_metrics['Trainer all entropy'][iteration] = float(
                            'nan')
                        all_metrics['Trainer sampling entropy'][
                            iteration] = float('nan')
                        # if hasattr(model, 'cur_split'):
                        all_metrics['Trainer split'][iteration] = float('nan')
                continue
            split_metrics = log_results(evaluator, d_loader, iteration, name)
            if select_metric not in split_metrics:
                continue
            if is_better(split_metrics[select_metric], best['value']):
                best['value'] = split_metrics[select_metric]
                best['iter'] = iteration
                best['state_dict'] = copy.deepcopy(model.state_dict())
                if patience > 0:
                    patience_counter = 0
            elif patience > 0:
                patience_counter += iter_this_step
                if patience_counter >= patience:
                    logger.info('#####')
                    logger.info('# Early stopping Run')
                    logger.info('#####')
                    trainer.terminate()
        last_iteration = iteration
        eval_time.pause()
        eval_time.step()
        all_metrics['eval time'][iteration] = eval_time.value()
        all_metrics['eval time_ps'][iteration] = eval_time.value(
        ) / eval_time.step_count
        all_metrics['total time'][iteration] = total_time.value()
        epoch_time.resume()

    log_event(trainer)

    #
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def log_epoch(trainer):
    #     iteration = trainer.state.iteration if trainer.state else 0
    #     epoch = iteration/epoch_steps
    #     fw_t = forward_time.value()
    #     fw_t_ps = fw_t / forward_time.step_count
    #     d_t = data_time.value()
    #     d_t_ps = d_t / data_time.step_count
    #     e_t = epoch_time.value()
    #     e_t_ps = e_t / epoch_time.step_count
    #     ev_t = eval_time.value()
    #     ev_t_ps = ev_t / eval_time.step_count
    #     logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), '
    #                    'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),'
    #                    ' Eval: {:.3f}s({:.3f}), Total: '
    #                    '{:.3f}s)'.format(type(model).__name__, epoch,
    #                                      max_epoch, fw_t, fw_t_ps, d_t, d_t_ps,
    #                                      e_t, e_t_ps, ev_t, ev_t_ps,
    #                                      total_time.value()))

    data_time.attach(trainer,
                     start=Events.STARTED,
                     pause=Events.ITERATION_STARTED,
                     resume=Events.ITERATION_COMPLETED,
                     step=Events.ITERATION_STARTED)

    if hasattr(model, 'iter_per_epoch'):
        model.iter_per_epoch = len(train_loader)
    trainer.run(train_loader, max_epochs=max_epoch)
    return trainer.state.iteration, all_metrics, best
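`StopAfterIterations` is a project-specific handler that is attached but not shown. A minimal sketch matching the way it is used above (terminate the run once the last listed iteration count is reached); the actual class may do more, such as logging at each listed step.

from ignite.engine import Events

class StopAfterIterations:
    def __init__(self, stop_iterations):
        # stop_iterations is a list of iteration counts; training terminates at the largest one
        self.last_iteration = max(stop_iterations)

    def __call__(self, engine):
        if engine.state.iteration >= self.last_iteration:
            engine.terminate()

    def attach(self, engine):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self)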
Example #6
    def _test(metric_device):
        data = list(range(n_iters))
        np.random.seed(12)
        all_y_true_batch_values = np.random.randint(
            0,
            n_classes,
            size=(idist.get_world_size(), n_epochs * n_iters, batch_size))
        all_y_pred_batch_values = np.random.rand(idist.get_world_size(),
                                                 n_epochs * n_iters,
                                                 batch_size, n_classes)

        y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
        y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

        def update_fn(engine, batch):
            y_true_batch = next(y_true_batch_values)
            y_pred_batch = next(y_pred_batch_values)
            return torch.from_numpy(y_pred_batch), torch.from_numpy(
                y_true_batch)

        trainer = Engine(update_fn)
        alpha = 0.98

        acc_metric = RunningAverage(Accuracy(
            output_transform=lambda x: [x[0], x[1]], device=metric_device),
                                    alpha=alpha,
                                    epoch_bound=False)
        acc_metric.attach(trainer, "running_avg_accuracy")

        running_avg_acc = [
            None,
        ]
        true_acc_metric = Accuracy(device=metric_device)

        @trainer.on(Events.ITERATION_COMPLETED)
        def manual_running_avg_acc(engine):
            i = engine.state.iteration - 1

            true_acc_metric.reset()
            for j in range(idist.get_world_size()):
                output = (
                    torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                    torch.from_numpy(all_y_true_batch_values[j, i, :]),
                )
                true_acc_metric.update(output)

            batch_acc = true_acc_metric._num_correct.item(
            ) * 1.0 / true_acc_metric._num_examples

            if running_avg_acc[0] is None:
                running_avg_acc[0] = batch_acc
            else:
                running_avg_acc[0] = running_avg_acc[0] * alpha + (
                    1.0 - alpha) * batch_acc
            engine.state.running_avg_acc = running_avg_acc[0]

        @trainer.on(Events.ITERATION_COMPLETED)
        def assert_equal_running_avg_acc_values(engine):
            assert (
                engine.state.running_avg_acc ==
                engine.state.metrics["running_avg_accuracy"]
            ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}"

        trainer.run(data, max_epochs=3)
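For reference, a minimal usage sketch of the pattern the test exercises: RunningAverage with epoch_bound=False keeps the exponential moving average v <- alpha*v + (1 - alpha)*x across the whole run, which is what the manual handler above recomputes from the raw batches.

from ignite.engine import Engine
from ignite.metrics import RunningAverage

trainer = Engine(lambda engine, batch: {"loss": float(batch)})
RunningAverage(output_transform=lambda out: out["loss"], alpha=0.98,
               epoch_bound=False).attach(trainer, "running_loss")
trainer.run([1.0, 2.0, 3.0], max_epochs=2)
print(trainer.state.metrics["running_loss"])  # EMA of the six observed loss values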
Example #7
def train(args):
    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer, _, vocab = get_kogpt2_tokenizer()
    model = get_kogpt2_model()
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer, vocab)

    def update(engine, batch):
        model.train()

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, labels, token_type_ids = batch

        loss, *_ = model(input_ids,
                         token_type_ids=token_type_ids,
                         labels=labels)
        loss = loss / args.gradient_accumulation_steps

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item()

    trainer = Engine(update)

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, labels, token_type_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            logits, *_ = model(input_ids, token_type_ids=token_type_ids)
            logits_flat_shifted = logits[..., :-1, :].contiguous().view(
                -1, logits.size(-1))
            labels_flat_shifted = labels[..., 1:].contiguous().view(-1)
            return (logits_flat_shifted), (labels_flat_shifted)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0], x[1])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0], x[1]))
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=["loss"])
    evaluator.add_event_handler(
        Events.COMPLETED, lambda _: pbar.log_message(
            "Validation: %s" % pformat(evaluator.state.metrics)))

    log_dir = make_logdir("kogpt2_personachat")
    tb_logger = TensorboardLogger(log_dir)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=["loss"]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(
        evaluator,
        log_handler=OutputHandler(
            tag="validation",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.EPOCH_COMPLETED)

    checkpoint_handler = ModelCheckpoint(log_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=3)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpoint_handler,
        {'mymodel': getattr(model, 'module', model)
         })  # "getattr" takes care of distributed encapsulation

    torch.save(args, log_dir + '/model_training_args.bin')
    getattr(model, 'module',
            model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
    # tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    # TODO: PR in ignite to have better access to saved file paths (cleaner)
    os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
              os.path.join(log_dir, WEIGHTS_NAME))
    tb_logger.close()
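A minimal, self-contained sketch of the shifted next-token loss computed in the inference step above; the vocabulary size and dummy tensors below are made up purely for illustration.

import torch

vocab_size = 8
logits = torch.randn(2, 5, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 5))   # (batch, seq_len)
labels[0, :3] = -100                            # e.g. context tokens not used as LM targets

# predict token t+1 from position t, then flatten for CrossEntropyLoss
logits_flat_shifted = logits[..., :-1, :].contiguous().view(-1, vocab_size)
labels_flat_shifted = labels[..., 1:].contiguous().view(-1)

loss = torch.nn.CrossEntropyLoss(ignore_index=-100)(logits_flat_shifted,
                                                    labels_flat_shifted)
print(loss.item())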
Exemplo n.º 8
0
def main(
    model_type,
    dataset,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    flow_embed_dim,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
):
    def build_model():
        if model_type == "glow":
            model = Glow(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "glow_large":
            model = GlowLarge(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                flow_embed_dim,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflow_naive":
            model = NanoFlowNaive(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflow_decomp":
            model = NanoFlowDecomp(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflow":
            model = NanoFlow(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                flow_embed_dim,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflowalt_naive":
            model = NanoFlowAltNaive(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflowalt_decomp":
            model = NanoFlowAltDecomp(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                num_classes,
                learn_top,
                y_condition,
            )
        elif model_type == "nanoflowalt":
            model = NanoFlowAlt(
                image_shape,
                hidden_channels,
                K,
                L,
                actnorm_scale,
                flow_permutation,
                flow_coupling,
                LU_decomposed,
                flow_embed_dim,
                num_classes,
                learn_top,
                y_condition,
            )
        else:
            raise ValueError("uknown --model_type")
        return model

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    logger = SummaryWriter(output_dir)

    model = build_model()

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", save_interval=1, n_saved=10, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model)["model"])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)["optimizer"])

        file_name, ext = os.path.splitext(saved_model)
        resume_iteration = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.iteration = resume_iteration
            engine.state.epoch = engine.state.iteration // len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    # @trainer.on(Events.ITERATION_COMPLETED)
    # def log_tensorboard(engine):
    #     logger.add_scalar('train_loss', 1, engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")
        logger.add_scalar("validation/loss", metrics["total_loss"], engine.state.epoch)
        logger.flush()

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
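A toy, self-contained sketch (dummy dataset, not part of the snippet above) of how the init handler gathers the first n_init_batches with islice for the data-dependent actnorm initialization pass.

from itertools import islice

import torch
from torch.utils import data

toy = data.TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))
loader = data.DataLoader(toy, batch_size=16, shuffle=True, drop_last=True)

n_init_batches = 2
# take exactly the first n_init_batches batches and concatenate them
init_batches = torch.cat([x for x, _ in islice(loader, n_init_batches)])
assert init_batches.shape[0] == n_init_batches * 16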
Exemplo n.º 9
0
        outputs = model(input_ids=input_ids, mc_token_ids=mc_token_ids, mc_labels=mc_labels,
                        lm_labels=lm_labels, token_type_ids=token_type_ids)
        lm_loss, mc_loss = outputs[0], outputs[1]
        lm_coef = 2.0
        mc_coef = 1.0
        total_loss = lm_loss * lm_coef + mc_loss * mc_coef
    return lm_loss.item(), mc_loss.item(), total_loss.item()


trainer = Engine(process_function)
evaluator = Engine(evaluate_function)

training_history = {'lm_loss': [], 'mc_loss': [], 'total_loss': []}
validation_history = {'lm_loss': [], 'mc_loss': [], 'total_loss': []}

RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'lm_loss')
RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'mc_loss')
RunningAverage(output_transform=lambda x: x[2]).attach(trainer, 'total_loss')


RunningAverage(output_transform=lambda x: x[0]).attach(evaluator, 'lm_loss')
RunningAverage(output_transform=lambda x: x[1]).attach(evaluator, 'mc_loss')
RunningAverage(output_transform=lambda x: x[2]).attach(evaluator, 'total_loss')


@trainer.on(Events.ITERATION_COMPLETED(every=50))
def print_trainer_logs(engine):
    # try:
    #   start
    # except:
    #   start = timeit.default_timer()
Exemplo n.º 10
0
def train_sequence(
        model: AMOCNet,
        contrast_criterion: torch.nn.Module,
        class_criterion_A: torch.nn.Module,
        class_criterion_B: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        test_loader: torch.utils.data.DataLoader,
        opt,
        printInds: Iterable = None) -> (torch.nn.Module, dict, dict):
    optimizer = optim.Adam(model.parameters(), lr=opt.learningRate)
    timer = Timer()
    confusion_logger = VisdomLogger('heatmap',
                                    port=8097,
                                    opts={
                                        'title':
                                        'simMat',
                                        'columnnames':
                                        list(range(len(train_loader.dataset))),
                                        'rownames':
                                        list(range(len(train_loader.dataset)))
                                    })
    epoch = 0
    if opt.pretrained or opt.motionnet_pretrained:
        model, optimizer, epoch = load_dicts(
            model, optimizer, opt.pretrained or opt.motionnet_pretrained)

    if printInds is None:
        printInds = list(range(10))

    def iterate_func(engine, batch):
        optimizer.zero_grad()
        inputA, inputB, target, personA, personB, ind, _, _ = batch
        if len(inputA.shape) == len(inputB.shape) == 4:
            inputA = torch.unsqueeze(inputA, 0)
            inputB = torch.unsqueeze(inputB, 0)
        assert inputA.shape[1] == inputB.shape[1] == opt.sampleSeqLength, \
            ValueError(f"ind: {ind}, inputA {inputA.shape}, inputB {inputB.shape}, required seq lenth {opt.sampleSeqLength}")
        if torch.cuda.is_available():
            inputA = inputA.float().cuda()
            inputB = inputB.float().cuda()
            target = target.float().cuda()
            personA = personA.long().cuda()
            personB = personB.long().cuda()
        distance, outputA, outputB = model(inputA, inputB)
        contrast_loss = contrast_criterion(distance, target)
        class_loss_A = class_criterion_A(outputA, personA)
        class_loss_B = class_criterion_B(outputB, personB)
        loss = contrast_loss + class_loss_A + class_loss_B
        loss.backward()

        clip_grad_value_(model.parameters(),
                         clip_value=opt.gradClip or sys.maxsize)
        optimizer.step()
        return loss.item(), contrast_loss.item(), class_loss_A.item(
        ), class_loss_B.item()

    trainer = Engine(iterate_func)
    train_history = {'cnst': [], 'ceA': [], 'ceB': [], 'ttl': []}
    val_history = {'avgSame': [], 'avgDiff': [], 'cmc': [], 'simMat': []}
    RunningAverage(alpha=1,
                   output_transform=lambda x: x[0]).attach(trainer, 'ttl')
    RunningAverage(alpha=1,
                   output_transform=lambda x: x[1]).attach(trainer, 'cnst')
    RunningAverage(alpha=1,
                   output_transform=lambda x: x[2]).attach(trainer, 'ceA')
    RunningAverage(alpha=1,
                   output_transform=lambda x: x[3]).attach(trainer, 'ceB')
    train_loss_logger = VisdomPlotLogger("line", name="train")
    val_loss_logger = VisdomPlotLogger("line", name="val")

    score_func = lambda engine: -engine.state.metrics['ttl']
    checkpoint_handler = ModelCheckpointSaveBest(
        opt.checkpoint_path,
        filename_prefix=opt.saveFileName,
        score_function=score_func,
        require_empty=False,
        save_as_state_dict=True)
    # stop_handler = EarlyStopping(patience=30, trainer=trainer,
    #                              score_function=score_func)

    @trainer.on(Events.STARTED)
    def resume_training(engine):
        engine.state.iteration = epoch * len(engine.state.dataloader)
        engine.state.epoch = epoch
        checkpoint_handler._iteration = epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def trainer_log(engine: Engine):
        avg_ttl = engine.state.metrics['ttl']
        avg_cnst = engine.state.metrics['cnst']
        avg_ceA = engine.state.metrics['ceA']
        avg_ceB = engine.state.metrics['ceB']
        lr = optimizer.param_groups[0]['lr']
        print(
            f"Epoch[{engine.state.epoch}]\tlr={lr:.2e}\telapsed:{timer.value():.2f}s:\t"
            f"TTL={avg_ttl:.3f}\tContrast={avg_cnst:04.3f}\t"
            f"CrossEntA={avg_ceA:04.3f}\tCrossEntB={avg_ceB:04.3f}")
        train_loss_logger.log(engine.state.epoch,
                              avg_ttl,
                              name="avg_total_loss")
        train_loss_logger.log(engine.state.epoch,
                              avg_cnst,
                              name="avg_contrast")
        train_loss_logger.log(engine.state.epoch,
                              avg_ceA,
                              name="avg_CrossEnt_A")
        train_loss_logger.log(engine.state.epoch,
                              avg_ceB,
                              name="avg_CrossEnt_B")

    @trainer.on(Events.ITERATION_COMPLETED)
    def adjust_lr(engine):
        # learning rate decay
        if engine.state.iteration >= 20000:
            lr = opt.learningRate * (0.1**min(
                (engine.state.iteration - 10000) // opt.lr_decay, 5))
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

    def on_complete(engine, dataloader, mode, history_dict):
        if not engine.state.epoch % opt.samplingEpochs:
            cmc, simMat, _, avgSame, avgDiff = compute_cmc(
                dataloader.dataset, printInds, model, opt.sampleSeqLength)

            metrics = {
                "cmc": cmc,
                "simMat": simMat,
                "avgSame": avgSame,
                "avgDiff": avgDiff
            }

            outString = ' '.join((str(np.floor(cmc[c])) for c in printInds))

            print(
                f"{mode} Result: Epoch[{engine.state.epoch}]- Avg Same={avgSame:.3f}\tAvg Diff={avgDiff:.3f}"
            )
            print(outString)

            confusion_logger.log(simMat)
            val_loss_logger.log(trainer.state.epoch, avgSame, name="avg_same")
            val_loss_logger.log(trainer.state.epoch, avgDiff, name="avg_diff")
            if mode == "Validation":
                for key in val_history.keys():
                    history_dict[key].append(metrics[key])

    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_complete,
                              train_loader, 'Training', train_history)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_complete, test_loader,
                              'Validation', val_history)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
    # trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_handler)
    checkpoint_handler.attach(trainer,
                              model_dict={
                                  "model": model,
                                  "optimizer": optimizer
                              })

    trainer.run(train_loader, max_epochs=opt.nEpochs)

    return model, train_history, val_history
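A standalone sketch of the step decay applied in adjust_lr above, written as a pure function so the schedule can be inspected; the function name and the sample values are illustrative only.

def decayed_lr(base_lr, iteration, lr_decay_every,
               start_iter=20000, offset_iter=10000, max_drops=5):
    # no decay before start_iter; afterwards drop by 10x at most max_drops times
    if iteration < start_iter:
        return base_lr
    return base_lr * 0.1 ** min((iteration - offset_iter) // lr_decay_every, max_drops)

for it in (0, 20000, 40000, 80000):
    print(it, decayed_lr(1e-3, it, lr_decay_every=10000))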
Exemplo n.º 11
0
def main(
    dataset,
    dataroot,
    z_dim,
    g_filters,
    d_filters,
    batch_size,
    epochs,
    learning_rate,
    beta_1,
    saved_G,
    saved_D,
    seed,
    n_workers,
    device,
    alpha,
    output_dir,
):

    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(trainer, "errG")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = ["iteration",] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration),] + [str(round(value, 5)) for value in engine.state.metrics.values()]

        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)

        message = "[{epoch}/{max_epoch}][{i}/{max_i}]".format(
            epoch=engine.state.epoch, max_epoch=epochs, i=(engine.state.iteration % len(loader)), max_i=len(loader)
        )
        for name, value in zip(columns, values):
            message += " | {name}: {value}".format(name=name, value=value)

        pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={"netG": netG, "netD": netD}
    )

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]")
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl

            mpl.use("agg")

            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt

        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            create_plots(engine)
            checkpoint_handler(engine, {"netG_exception": netG, "netD_exception": netD})

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
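A standalone sketch (temporary file, hypothetical values) of the append-with-header-once pattern that print_logs above relies on: f.tell() == 0 is only true the first time the log file is opened for appending.

import os
import tempfile

fname = os.path.join(tempfile.mkdtemp(), "logs.tsv")
for it, loss in [(1, 0.9), (2, 0.7)]:
    with open(fname, "a") as f:
        if f.tell() == 0:                 # fresh file: write the header first
            print("iteration\tloss", file=f)
        print(f"{it}\t{loss}", file=f)

print(open(fname).read())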
Exemplo n.º 12
0
def train_motion_net(model: MotionNet, criterion: torch.nn.SmoothL1Loss,
                     train_loader: torch.utils.data.DataLoader,
                     test_loader: torch.utils.data.DataLoader, opt):
    optimizer = optim.Adam(model.parameters(), lr=opt.learningRate)
    model = init_weights(model)
    of_logger = VisdomLogger('image',
                             win="of",
                             port=8097,
                             opts={"caption": "output"})
    gt_of_logger = VisdomLogger('image',
                                win="gt",
                                port=8097,
                                opts={"caption": "gt"})
    loss_weight = [0.01, 0.02, 0.08]
    epoch = 0
    if opt.pretrained:
        model, optimizer, epoch = load_dicts(model, optimizer, opt.pretrained)

    def iterate_func(engine, batch):
        model.train()
        inputA, inputB, _, _, _, ind, ofA, ofB = batch
        if len(inputA.shape) == len(inputB.shape) == 4:
            inputA = inputA.unsqueeze(0)
            inputB = inputB.unsqueeze(0)
        assert inputA.shape[1] == inputB.shape[1] == opt.sampleSeqLength, \
            ValueError(f"ind: {ind}, inputA {inputA.shape}, inputB {inputB.shape}, required seq lenth {opt.sampleSeqLength}")
        if torch.cuda.is_available():
            inputA = inputA.float().cuda()
            inputB = inputB.float().cuda()
            ofA = ofA.float().cuda()
            ofB = ofB.float().cuda()

        def _iterate(input_, of):
            """
            single passthrough of training of MotionNet
            :param input: two consecutive frames concatenated along axis 0: [1, 6, W, H]
            :param of: target feature map of output of MotionNet: [1, 2, W, H]
            :return:
            """
            optimizer.zero_grad()
            outs = list(model(input_))
            losses = []
            for i, out in enumerate(outs):
                factor = of.shape[2] // out.shape[2]
                gt = AvgPool2d(factor, factor)(of).detach().data
                losses += [criterion(out, gt) * loss_weight[i]]
            loss = sum(losses)
            loss.backward()
            optimizer.step()
            return loss.item()

        for i in range(inputA.shape[1] - 1):
            consecutive_frame = torch.cat(
                (inputA[:, i, ...], inputA[:, i + 1, ...]), 1)
            _iterate(consecutive_frame, ofA[:, i, ...])

        for i in range(inputB.shape[1] - 1):
            consecutive_frame = torch.cat(
                (inputB[:, i, ...], inputB[:, i + 1, ...]), 1)
            losses = _iterate(consecutive_frame, ofB[:, i, ...])
        return losses

    def eval_func(engine, batch):
        cnt = 1
        model.eval()
        with torch.no_grad():
            inputA, inputB, _, _, _, ind, ofA_, ofB_ = batch
            if len(inputA.shape) == len(inputB.shape) == 4:
                inputA = inputA.unsqueeze(0)
                inputB = inputB.unsqueeze(0)
            assert inputA.shape[1] == inputB.shape[1] == opt.sampleSeqLength, \
                ValueError(f"ind: {ind}, inputA {inputA.shape}, inputB {inputB.shape}, required seq lenth {opt.sampleSeqLength}")
            if torch.cuda.is_available():
                inputA = inputA.float().cuda()
                inputB = inputB.float().cuda()
                ofA = ofA_.float().cuda()
                ofB = ofB_.float().cuda()

            def _iterate(input_, of):
                outs = list(model(input_))
                loss = []
                for i, out in enumerate(outs):
                    factor = of.shape[2] // out.shape[2]
                    gt = AvgPool2d(factor, factor)(of).detach().data
                    loss += [criterion(out, gt) * loss_weight[i]]
                return sum(loss).item(), outs[-1]

            for i in range(inputA.shape[1] - 1):
                consecutive_frame = torch.cat(
                    (inputA[:, i, ...], inputA[:, i + 1, ...]), 1)
                _, out = _iterate(consecutive_frame, ofA[:, i, ...])
                if cnt:
                    cnt -= 1
                    of_logger.log(vis_of(out.cpu()))
                    gt_of_logger.log(vis_of(ofA_[:, i, ...]))

            for i in range(inputB.shape[1] - 1):
                consecutive_frame = torch.cat(
                    (inputB[:, i, ...], inputB[:, i + 1, ...]), 1)
                losses, _ = _iterate(consecutive_frame, ofB[:, i, ...])
            return losses

    trainer = Engine(iterate_func)
    evaluator = Engine(eval_func)
    train_history = {'loss': []}
    val_history = {'loss': []}
    RunningAverage(alpha=1,
                   output_transform=lambda x: x).attach(trainer, 'loss')
    RunningAverage(alpha=1,
                   output_transform=lambda x: x).attach(evaluator, 'loss')
    score_func = lambda engine: -engine.state.metrics['loss']
    checkpoint_handler = ModelCheckpointSaveBest(
        opt.checkpoint_path,
        filename_prefix=opt.saveFileName,
        score_function=score_func,
        require_empty=False,
        save_as_state_dict=True)
    stop_handler = EarlyStopping(patience=30,
                                 trainer=trainer,
                                 score_function=score_func)

    @trainer.on(Events.STARTED)
    def resume_training(engine):
        engine.state.iteration = epoch * len(engine.state.dataloader)
        engine.state.epoch = epoch
        checkpoint_handler._iteration = epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def trainer_log(engine: Engine):
        loss = engine.state.metrics['loss']
        lr = optimizer.param_groups[0]['lr']
        print("-" * 50)
        print(
            f"Epoch[{engine.state.epoch}] lr={lr:.2E}:\t\tAvg Loss={loss:.4f}")

    @trainer.on(Events.ITERATION_COMPLETED)
    def adjust_lr(engine):
        # learning rate decay
        lr = opt.learningRate * (0.1**(engine.state.iteration // opt.lr_decay))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def on_complete(engine, dataloader, mode, history_dict):
        evaluator.run(dataloader)
        loss = evaluator.state.metrics["loss"]
        print(
            f"{mode} Result: Epoch[{engine.state.epoch}]:\tAvg Loss={loss:.4f}"
        )

        if mode == "Validation":
            for key in val_history.keys():
                history_dict[key].append(loss)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_complete,
                              train_loader, 'Training', train_history)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, on_complete, test_loader,
                              'Validation', val_history)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, stop_handler)
    checkpoint_handler.attach(trainer,
                              model_dict={
                                  "model": model,
                                  "optimizer": optimizer
                              })

    trainer.run(train_loader, max_epochs=opt.nEpochs)
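A standalone sketch (made-up shapes) of the multi-scale loss in _iterate above: each coarse prediction is compared against the target flow average-pooled down to the same resolution.

import torch
from torch.nn import AvgPool2d, SmoothL1Loss

criterion = SmoothL1Loss()
of = torch.randn(1, 2, 64, 64)                            # full-resolution target flow
outs = [torch.randn(1, 2, s, s) for s in (16, 32, 64)]    # coarse-to-fine predictions
loss_weight = [0.01, 0.02, 0.08]

losses = []
for i, out in enumerate(outs):
    factor = of.shape[2] // out.shape[2]
    gt = AvgPool2d(factor, factor)(of)                    # downsample target to match
    losses.append(criterion(out, gt) * loss_weight[i])
print(sum(losses).item())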
Exemplo n.º 13
0
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any
):
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:

        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save, save_handler, filename_prefix="training", **kwargs)
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    "It should either mapping or sequence, but given {}".format(type(x))
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
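A standalone sketch of the partial-based output_transform dispatch used above, applied to a dict-shaped and a tuple-shaped engine output; the sample outputs are made up.

from functools import partial

def output_transform(x, index, name):
    # pick a named entry from a mapping, or a positional entry from a sequence
    if isinstance(x, dict):
        return x[name]
    if isinstance(x, (list, tuple)):
        return x[index]
    return x

pick_nll = partial(output_transform, index=1, name="nll")
print(pick_nll({"total_loss": 0.5, "nll": 2.3}))  # -> 2.3 (by name)
print(pick_nll((0.5, 2.3)))                       # -> 2.3 (by position)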
Exemplo n.º 14
0
def do_train(
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,  # modify for using self trained model
        loss_fn,
        num_query,
        start_epoch,  # add for using self trained model
        clustering_loader):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS
    clustering_period = cfg.CLUSTERING.PERIOD
    clustering_stop = cfg.CLUSTERING.STOP
    with_arm = cfg.TEST.WITH_ARM

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    if with_arm:
        evaluator = create_supervised_evaluator(
            model,
            metrics={
                'r1_mAP':
                R1_mAP_arm(num_query,
                           max_rank=50,
                           feat_norm=cfg.TEST.FEAT_NORM)
            },
            device=device,
            with_arm=with_arm)
    else:
        evaluator = create_supervised_evaluator(
            model,
            metrics={
                'r1_mAP':
                R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
            },
            device=device,
            with_arm=with_arm)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=10,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model,
        'optimizer': optimizer
    })

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_mask_pseudo_labels(engine):

        if engine.state.epoch % clustering_period == 1 and engine.state.epoch <= clustering_stop:
            #if False:

            torch.cuda.empty_cache()
            feats, pseudo_labels_paths, pids, shape = compute_features(
                clustering_loader, model, device, with_arm)
            torch.cuda.empty_cache()

            cluster_begin = time.time()
            logger.info('clustering and pseudo-label adjustment begin...')

            pid_label = set(pids)
            for label in pid_label:
                indexs = [i for i in range(len(pids)) if pids[i] == label]
                feats_I = feats[indexs]
                pseudo_labels_paths_I = [
                    pseudo_labels_paths[i] for i in indexs
                ]
                cluster_for_each_identity(cfg, feats_I, pseudo_labels_paths_I,
                                          shape)

            logger.info(
                'mask adjustment took {0:.0f} s'.format(time.time() -
                                                        cluster_begin))

            #evaluate the pseudo-part-labels
            if cfg.DATASETS.NAMES == 'market1501':
                pred_dir = os.path.join(cfg.DATASETS.ROOT_DIR, 'Market-1501',
                                        cfg.DATASETS.PSEUDO_LABEL_SUBDIR)
                gt_dir = os.path.join(cfg.DATASETS.ROOT_DIR, 'Market-1501',
                                      cfg.DATASETS.PREDICTED_GT_SUBDIR)
                compute_IoU(pred_dir, gt_dir, cfg.CLUSTERING.PART_NUM)

            elif cfg.DATASETS.NAMES == 'dukemtmc':
                pred_dir = os.path.join(cfg.DATASETS.ROOT_DIR, 'DukeMTMC-reID',
                                        cfg.DATASETS.PSEUDO_LABEL_SUBDIR)
                gt_dir = os.path.join(cfg.DATASETS.ROOT_DIR, 'DukeMTMC-reID',
                                      cfg.DATASETS.PREDICTED_GT_SUBDIR)
                compute_IoU(pred_dir, gt_dir, cfg.CLUSTERING.PART_NUM)

            elif cfg.DATASETS.NAMES == 'cuhk03_np_labeled':
                pred_dir = os.path.join(cfg.DATASETS.ROOT_DIR,
                                        'cuhk03-np/labeled',
                                        cfg.DATASETS.PSEUDO_LABEL_SUBDIR)
                gt_dir = os.path.join(cfg.DATASETS.ROOT_DIR,
                                      'cuhk03-np/labeled',
                                      cfg.DATASETS.PREDICTED_GT_SUBDIR)
                compute_IoU(pred_dir, gt_dir, cfg.CLUSTERING.PART_NUM)

            elif cfg.DATASETS.NAMES == 'cuhk03_np_detected':
                pred_dir = os.path.join(cfg.DATASETS.ROOT_DIR,
                                        'cuhk03-np/detected',
                                        cfg.DATASETS.PSEUDO_LABEL_SUBDIR)
                gt_dir = os.path.join(cfg.DATASETS.ROOT_DIR,
                                      'cuhk03-np/detected',
                                      cfg.DATASETS.PREDICTED_GT_SUBDIR)
                compute_IoU(pred_dir, gt_dir, cfg.CLUSTERING.PART_NUM)

            torch.cuda.empty_cache()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1

        if iter % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, iter, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0 or engine.state.epoch > 110:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))
            torch.cuda.empty_cache()

    trainer.run(train_loader, max_epochs=epochs)
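A standalone sketch (assumed numbers) of the iteration arithmetic in log_training_loss above, which recovers the within-epoch iteration index from ignite's global iteration counter.

len_train_loader = 100   # assumed number of batches per epoch
for global_iter in (1, 100, 101, 250):
    within_epoch_iter = (global_iter - 1) % len_train_loader + 1
    print(global_iter, "->", within_epoch_iter)
# 1 -> 1, 100 -> 100, 101 -> 1, 250 -> 50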
Exemplo n.º 15
0
    def eval_step(engine: Engine, batch: Batch) -> StepOutput:
        model.eval()
        images, targets = batch
        images, targets = images.to(device), targets.to(device)
        with torch.no_grad():
            preds = model(images)
        loss = F.cross_entropy(preds, targets)
        return {
            'preds': preds,
            'targets': targets,
            'cross_entropy': loss.item()
        }

    train_metrics = {
        'Loss':
        RunningAverage(output_transform=lambda x: x['cross_entropy']),
        'Accuracy':
        RunningAverage(
            Accuracy(output_transform=lambda x: (x['preds'], x['targets'])))
    }

    eval_metrics = {
        'Loss': Average(output_transform=lambda x: x['cross_entropy']),
        'Accuracy':
        Accuracy(output_transform=lambda x: (x['preds'], x['targets']))
    }

    train(args.run_name, model, train_set, test_set, train_step, eval_step,
          train_metrics, eval_metrics, args.n_iterations, args.batch_size)

    predictions = []
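A toy, self-contained sketch (hypothetical step function, assumes ignite is installed) of how dict-keyed metrics like the ones above attach to an evaluator whose step returns {'preds', 'targets', 'cross_entropy'}.

import torch
from ignite.engine import Engine
from ignite.metrics import Accuracy, Average

def toy_eval_step(engine, batch):
    preds = torch.tensor([[2.0, 0.1], [0.2, 1.5]])   # dummy logits
    targets = torch.tensor([0, 1])
    return {"preds": preds, "targets": targets, "cross_entropy": 0.3}

evaluator = Engine(toy_eval_step)
Average(output_transform=lambda x: x["cross_entropy"]).attach(evaluator, "Loss")
Accuracy(output_transform=lambda x: (x["preds"], x["targets"])).attach(evaluator, "Accuracy")

state = evaluator.run([None], max_epochs=1)
print(state.metrics)   # e.g. {'Loss': 0.3, 'Accuracy': 1.0}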
Exemplo n.º 16
0
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, metrics):

    device = cfg['device_ids'][0] if torch.cuda.is_available() else 'cpu'  # default to the first configured GPU
    max_epochs = cfg['max_epochs']

    # create trainer
    if cfg['multi_gpu']:  # with multiple GPUs, there is no need to pass loss_fn in
        trainer = create_supervised_dp_trainer(model.train(),
                                               optimizer,
                                               device=device)
    else:
        trainer = create_supervised_trainer(model.train(),
                                            optimizer,
                                            loss_fn,
                                            device=device)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'avg_loss')

    # create pbar
    len_train_loader = len(train_loader)
    pbar = tqdm(total=len_train_loader)

    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    # log train_loss once every log_period * len(train_loader) iterations
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        log_period = cfg['log_period']
        log_per_iter = int(log_period * len_train_loader) if int(
            log_period * len_train_loader) >= 1 else 1  # compute the logging interval (in iterations)
        current_iter = (engine.state.iteration - 1) % len_train_loader + 1 + (
            engine.state.epoch - 1) * len_train_loader  # compute the current global iteration

        lr = optimizer.state_dict()['param_groups'][0]['lr']

        if current_iter % log_per_iter == 0:
            pbar.write("Epoch[{}] Iteration[{}] lr {:.7f} Loss {:.7f}".format(
                engine.state.epoch, current_iter, lr,
                engine.state.metrics['avg_loss']))
            pbar.update(log_per_iter)

    # lr_scheduler
    @trainer.on(Events.ITERATION_COMPLETED)
    def adjust_lr_scheduler(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_swa(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            if cfg['enable_swa']:
                swa_period = 2 * cfg['lr_scheduler']['step_size_up']
                current_iter = (
                    engine.state.iteration - 1) % len_train_loader + 1 + (
                        engine.state.epoch - 1) * len_train_loader  # compute the current global iteration
                if current_iter % swa_period == 0:
                    optimizer.update_swa()

    @trainer.on(Events.ITERATION_COMPLETED)
    def update_bn(engine):
        if isinstance(scheduler, lr_scheduler.CyclicLR):
            save_period = 2 * cfg['lr_scheduler']['step_size_up']
            current_iter = (
                engine.state.iteration - 1) % len_train_loader + 1 + (
                    engine.state.epoch - 1) * len_train_loader  # compute the current global iteration
            if current_iter % save_period == 0 and current_iter >= save_period * 2:  # start saving from the 4th period onward

                save_dir = cfg['save_dir']
                if not os.path.isdir(save_dir):
                    os.makedirs(save_dir)
                if cfg['enable_swa']:
                    optimizer.swap_swa_sgd()
                    optimizer.bn_update(train_loader, model, device=device)
                model_name = os.path.join(
                    save_dir, cfg['model']['type'] + '_' + cfg['tag'] + "_" +
                    str(current_iter) + ".pth")
                if cfg['multi_gpu']:
                    save_pth = {
                        'model': model.module.model.state_dict(),
                        'cfg': cfg
                    }
                    torch.save(save_pth, model_name)
                else:
                    save_pth = {'model': model.state_dict(), 'cfg': cfg}
                    torch.save(save_pth, model_name)

    ##########################################################################################
    ##################               Events.EPOCH_COMPLETED                    ###############
    ##########################################################################################
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_temp_epoch(engine):
        save_dir = cfg['save_dir']
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        epoch = engine.state.epoch
        if epoch % 1 == 0:
            model_name = os.path.join(
                save_dir,
                cfg['model']['type'] + '_' + cfg['tag'] + "_temp.pth")
            if cfg['multi_gpu']:
                save_pth = {
                    'model': model.module.model.state_dict(),
                    'cfg': cfg
                }
                torch.save(save_pth, model_name)
            else:
                save_pth = {'model': model.state_dict(), 'cfg': cfg}
                torch.save(save_pth, model_name)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_pbar(engine):
        pbar.reset()

    trainer.run(train_loader, max_epochs=max_epochs)
    pbar.close()
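A standalone sketch (toy model, temporary path) of the {'model': state_dict, 'cfg': cfg} checkpoint format that save_temp_epoch and update_bn above write, and how such a file can be read back.

import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)
cfg = {"model": {"type": "toy"}, "tag": "demo"}
path = os.path.join(tempfile.mkdtemp(), "toy_demo_temp.pth")

torch.save({"model": model.state_dict(), "cfg": cfg}, path)   # write checkpoint

ckpt = torch.load(path)                                       # read it back
model.load_state_dict(ckpt["model"])
print(ckpt["cfg"]["tag"])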
Exemplo n.º 17
0
def test_output_is_tensor():

    m = RunningAverage(output_transform=lambda x: x)
    m.update(torch.rand(10, requires_grad=True).mean())
    v = m.compute()
    assert isinstance(v, torch.Tensor)
    assert not v.requires_grad

    m.update(torch.rand(10, requires_grad=True).mean())
    v = m.compute()
    assert isinstance(v, torch.Tensor)
    assert not v.requires_grad

    m.update(torch.rand(10, requires_grad=True).mean())
    v = m.compute()
    assert isinstance(v, torch.Tensor)
    assert not v.requires_grad
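A plain-Python sketch of the exponential moving average that RunningAverage maintains between updates (ignite's default alpha is 0.98); the values here are arbitrary.

alpha = 0.98
values = [1.0, 2.0, 3.0]

running = None
for v in values:
    # the first update seeds the average; later updates blend in the new value
    running = v if running is None else alpha * running + (1.0 - alpha) * v
print(running)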
Exemplo n.º 18
0
def main(
    dataset,
    augment,
    batch_size,
    eval_batch_size,
    epochs,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    extra_condition,
    sp_condition,
    d_condition,
    yd_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    missing,
    saved_optimizer,
    warmup,
):

    print(output_dir)
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"
    print(device)
    check_manual_seed(seed)
    print("augmenting?", augment)
    train_dataset, test_dataset = check_dataset(dataset, augment, missing,
                                                seed)
    image_shape = (32, 32, 3)

    multi_class = False

    if yd_condition:
        #num_classes = 10*2
        num_classes = 10 + 2
        multi_class = True
    elif d_condition:
        num_classes = 10
    else:
        num_classes = 2
    print("num classes", num_classes)

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, extra_condition, sp_condition,
                 d_condition, yd_condition)

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y, d, yd = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        elif d_condition:
            d = d.to(device)
            z, nll, d_logits = model(x, d)
            d_weight = y_weight
            # multi_class false as only using 2 domains at the moment
            losses = compute_loss_y(nll, d_logits, d_weight, d, multi_class)
        elif yd_condition:
            yd = yd.to(device)
            z, nll, yd_logits = model(x, yd)
            yd_weight = y_weight
            losses = compute_loss_y(nll, yd_logits, yd_weight, yd, multi_class)
        else:
            print("none")
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y, d, yd = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll,
                                        y_logits,
                                        y_weight,
                                        y,
                                        multi_class,
                                        reduction="none")
            elif d_condition:
                d = d.to(device)
                z, nll, d_logits = model(x, d)
                d_weight = y_weight
                losses = compute_loss_y(nll,
                                        d_logits,
                                        d_weight,
                                        d,
                                        multi_class,
                                        reduction="none")
            elif yd_condition:
                yd = yd.to(device)
                z, nll, yd_logits = model(x, yd)
                yd_weight = y_weight
                losses = compute_loss_y(nll,
                                        yd_logits,
                                        yd_weight,
                                        yd,
                                        multi_class,
                                        reduction="none")
            else:

                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         "glow",
                                         save_interval=1,
                                         n_saved=2,
                                         require_empty=False)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {
            "model": model,
            "optimizer": optimizer
        },
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition or d_condition or yd_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(
            trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x:
            (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(
                engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
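        # Run a few batches through the model once, without gradients, so that
        # data-dependent layers (e.g. Glow's ActNorm) initialise from real activations.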
        model.train()

        init_batches = []
        init_targets = []
        init_domains = []
        init_yds = []

        with torch.no_grad():
            for batch, target, domain, yd in islice(train_loader, None,
                                                    n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
                init_domains.append(domain)
                init_yds.append(yd)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
                model(init_batches, init_targets)
            elif d_condition:
                init_domains = torch.cat(init_domains).to(device)
                model(init_batches, init_domains)
            elif yd_condition:
                init_yds = torch.cat(init_yds).to(device)
                model(init_batches, init_yds)
            else:
                init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join(
            [f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    def score_function(engine):
        val_loss = engine.state.metrics['total_loss']

        return -val_loss

    name = "best_"

    val_handler = ModelCheckpoint(output_dir,
                                  name,
                                  score_function=score_function,
                                  score_name="val_loss",
                                  n_saved=1,
                                  require_empty=False)

    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        val_handler,
        {"model": model},
    )

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    trainer.run(train_loader, epochs)
Exemplo n.º 19
0
def test_integration():

    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return loss_value, torch.from_numpy(y_pred_batch), torch.from_numpy(
            y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
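        # Reproduce the exponential moving average computed by the attached
        # RunningAverage metric: new = alpha * old + (1 - alpha) * batch_acc.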
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (
            engine.state.running_avg_acc ==
            engine.state.metrics["running_avg_accuracy"]
        ), f"{engine.state.running_avg_acc} vs {engine.state.metrics['running_avg_accuracy']}"

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert (
            engine.state.running_avg_output ==
            engine.state.metrics["running_avg_output"]
        ), f"{engine.state.running_avg_output} vs {engine.state.metrics['running_avg_output']}"

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
Exemplo n.º 20
0
            "Episode %d: reward=%s, steps=%s, speed=%.3f frames/s, elapsed=%s"
            % (trainer.state.episode, trainer.state.episode_reward,
               trainer.state.episode_steps,
               trainer.state.metrics.get('avg_fps', 0),
               timedelta(seconds=trainer.state.metrics.get('time_passed', 0))))

    @engine.on(ptan_ignite.EpisodeEvents.BOUND_REWARD_REACHED)
    def game_solved(trainer: Engine):
        print("Game solved in %s, after %d episodes and %d iterations!" %
              (timedelta(seconds=trainer.state.metrics['time_passed']),
               trainer.state.episode, trainer.state.iteration))
        trainer.should_terminate = True

    logdir = f"runs/{datetime.now().isoformat(timespec='minutes')}-{params.run_name}-{NAME}"
    tb = tb_logger.TensorboardLogger(log_dir=logdir)
    RunningAverage(output_transform=lambda v: v['loss']).attach(
        engine, "avg_loss")

    episode_handler = tb_logger.OutputHandler(
        tag="episodes", metric_names=['reward', 'steps', 'avg_reward'])
    tb.attach(engine,
              log_handler=episode_handler,
              event_name=ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)

    # write to tensorboard every 100 iterations
    ptan_ignite.PeriodicEvents().attach(engine)
    handler = tb_logger.OutputHandler(tag="train",
                                      metric_names=['avg_loss', 'avg_fps'],
                                      output_transform=lambda a: a)
    tb.attach(engine,
              log_handler=handler,
              event_name=ptan_ignite.PeriodEvents.ITERS_100_COMPLETED)
Exemplo n.º 21
0
def main():
    args = get_args()
    if 'e-SNLI-VE' in args.data_path:
        args.no_image = False
    else:
        args.no_image = True
    if not args.no_image:
        args.no_premise = True
    args.with_expl = True

    '''Setup'''
    t = datetime.today()
    output_dir = os.path.join(args.output_folder,
                              f"{t.month}_{t.day}_{t.hour}_{t.minute}_{t.second}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(filename=os.path.join(output_dir, 'app.log'),
                        filemode='a',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning(f"Running process {args.local_rank}")
    logger.info(f"Arguments: {pformat(args)}")
    logger.info(f'Image not used: {args.no_image}')
    logger.info(f'Premise not used: {args.no_premise}')
    logger.info(f'Explanations used: {args.with_expl}')

    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    if args.no_image:
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        import image_gpt2_291
        model = image_gpt2_291.GPT2LMHeadModel.from_pretrained(
            args.model_checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer)

    '''Training function and trainer'''
    def train(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        if args.no_image:
            input_ids, lm_label, label, input_mask = batch
        else:
            image, input_ids, lm_label, label, input_mask = batch

        if args.no_image:
            output = model(input_ids=input_ids,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        else:
            output = model(input_ids=input_ids,
                           images=image,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        loss, logits, _ = output

        loss = loss / args.gradient_accumulation_steps
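        # Gradient accumulation: the loss is scaled so that gradients summed over
        # `gradient_accumulation_steps` iterations match a single larger batch;
        # optimizer.step() runs only every that many iterations (below).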
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        if not args.with_expl:
            lbl_accuracy = torch.eq(label, logits.argmax(
                dim=1)).float().sum() / len(label)
            return {
                'loss': loss.item(),
                'lbl_accuracy': lbl_accuracy.item()
            }
        else:
            if engine.state.iteration % (args.gradient_accumulation_steps * 500) == 0:
                input_output = list(zip(input_ids, logits))
                random_item = random.choice(input_output)
                in_sent = tokenizer.decode(list(filter(
                    lambda x: x != tokenizer.eos_token_id,
                    random_item[0])))
                out_expl = tokenizer.decode(random_item[1].argmax(dim=1),
                                            skip_special_tokens=True)
                logger.info(f'MODEL INPUT: {in_sent}')
                logger.info(f'GEN. EXPL {out_expl}')
                logger.info('--------------------------------')
            return {
                'loss': loss.item(),
            }

    '''Validation function and validator (validator output is the input of the metrics)'''
    def validation(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device)
                          for input_tensor in batch)
            if args.no_image:
                input_ids, lm_label, label, input_mask = batch
            else:
                image, input_ids, lm_label, label, input_mask = batch

            if args.no_image:
                output = model(input_ids=input_ids,
                               #    attention_mask=input_mask
                               )
            else:
                output = model(input_ids=input_ids,
                               images=image,
                               #    attention_mask=input_mask
                               )
            logits, _ = output
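            # Standard language-modelling shift (next two statements): logits at
            # position i are scored against the label at position i + 1.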

            logits_shifted = logits[..., :-1, :].contiguous().view(-1,
                                                                   logits.size(-1))
            labels_shifted = lm_label[..., 1:].contiguous().view(-1)
            return logits_shifted, labels_shifted

    '''Engines'''
    trainer = Engine(train)
    validator = Engine(validation)

    # t_total = len(
    #     train_loader) // args.gradient_accumulation_steps * args.n_epochs
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    '''Linearly decrease the learning rate from lr to zero'''
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    '''
    Attach validation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: validator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: validator.run(val_loader))

    '''Prepare metrics - note how we compute distributed metrics'''
    RunningAverage(output_transform=lambda x: x['loss']).attach(
        trainer, "loss")
    RunningAverage(output_transform=lambda x: math.exp(
        average_distributed_scalar(x['loss'], args))).attach(trainer, "ppl")
    if not args.with_expl:
        RunningAverage(output_transform=lambda x: 100 * x['lbl_accuracy']).attach(
            trainer, "lbl_accuracy")

    metrics = {}
    metrics["lbl_loss"] = Loss(torch.nn.CrossEntropyLoss(),
                               output_transform=lambda x: (x[0], x[1]))
    metrics["loss"] = MetricsLambda(
        lambda l, a: average_distributed_scalar(
            l / a.gradient_accumulation_steps, a), metrics["lbl_loss"], args)
    metrics["ppl"] = MetricsLambda(math.exp, metrics["loss"])
    if not args.with_expl:
        metrics["lbl_accuracy"] = 100 * \
            Accuracy(output_transform=lambda x: (x[0], x[1]))
    for name, metric in metrics.items():
        metric.attach(validator, name)

    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer,
                    metric_names=["loss", 'ppl'] if args.with_expl else ["loss", 'lbl_accuracy', 'ppl'])
        validator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message(
                                        "Validation: %s" % pformat(validator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=output_dir)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["ppl"] if args.with_expl else ["lbl_accuracy", "ppl"]),
                         event_name=Events.EPOCH_COMPLETED)

        tb_logger.attach(validator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=[
                                 'ppl', 'loss'] if args.with_expl else ['ppl', 'loss', 'lbl_accuracy'],
                             global_step_transform=lambda *args, **kwargs: trainer.state.iteration),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(output_dir,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})

        # "getattr" take care of distributed encapsulation
        torch.save(args, os.path.join(output_dir, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(output_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(output_dir)

    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
Exemplo n.º 22
0
def main(
    dataset,
    dataset2,
    dataroot,
    download,
    augment,
    batch_size,
    eval_batch_size,
    nlls_batch_size,
    epochs,
    nb_step,
    saved_model,
    seed,
    hidden_channels,
    K,
    L,
    actnorm_scale,
    flow_permutation,
    flow_coupling,
    LU_decomposed,
    learn_top,
    y_condition,
    y_weight,
    max_grad_clip,
    max_grad_norm,
    lr,
    lr_test,
    n_workers,
    cuda,
    n_init_batches,
    output_dir,
    saved_optimizer,
    warmup,
    every_epoch,
):

    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    ds2 = check_dataset(dataset2, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds
    image_shape2, num_classes2, train_dataset_2, test_dataset_2 = ds2

    assert image_shape == image_shape2
    data1 = []
    data2 = []
    for k in range(nlls_batch_size):
        dataaux, targetaux = test_dataset[k]
        data1.append(dataaux)
        dataaux, targetaux = test_dataset_2[k]
        data2.append(dataaux)
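    # data1/data2 hold the first nlls_batch_size test samples of each dataset;
    # they are only consumed by the eval_likelihood handler (commented out below).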


    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=n_workers,
        drop_last=True,
    )
    test_loader = data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=n_workers,
        drop_last=False,
    )

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )

    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(
                    nll, y_logits, y_weight, y, multi_class, reduction="none"
                )
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction="none")

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(
        output_dir, "glow", n_saved=2, require_empty=False
    )

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss"
    )

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(
        lambda x, y: torch.mean(x),
        output_transform=lambda x: (
            x["total_loss"],
            torch.empty(x["total_loss"].shape[0]),
        ),
    ).attach(evaluator, "total_loss")

    if y_condition:
        monitoring_metrics.extend(["nll"])
        RunningAverage(output_transform=lambda x: x["nll"]).attach(trainer, "nll")

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(
            lambda x, y: torch.mean(x),
            output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])),
        ).attach(evaluator, "nll")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model)['model'])
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer)['opt'])

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split("_")[-1])/1e3

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        scheduler.step()
        metrics = evaluator.state.metrics

        losses = ", ".join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f"Validation Results - Epoch: {engine.state.epoch} {losses}")

    timer = Timer(average=True)
    timer.attach(
        trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]"
        )
        timer.reset()

    # @trainer.on(Events.EPOCH_COMPLETED)
    # def eval_likelihood(engine):
    #     global_nlls(output_dir, engine.state.epoch, data1, data2, model, dataset1_name = dataset, dataset2_name = dataset2, nb_step = nb_step, every_epoch = every_epoch, optim_default = partial(optim.SGD, lr=1e-5, momentum = 0.))


    trainer.run(train_loader, epochs)
Exemplo n.º 23
0
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)

    # Use TPU device
    device = xm.xla_device()

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)

    # Create trainer and evaluator
    trainer = create_supervised_trainer(
        model, optimizer, F.nll_loss, device=device, output_transform=lambda x, y, y_pred, loss: [loss.item(),]
    )

    evaluator = create_supervised_evaluator(
        model, metrics={"accuracy": Accuracy(), "nll": Loss(F.nll_loss)}, device=device
    )

    tracker = xm.RateTracker()

    # Add RateTracker as an output of the training step
    @trainer.on(Events.ITERATION_COMPLETED)
    def add_rate_tracker(engine):
        tracker.add(len(engine.state.batch))
        engine.state.output.append(tracker.global_rate())

    # Setup output values of the training step as EMA metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "batch_loss")
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "global_rate")

    # Let's log the EMA metrics every `log_interval` iterations
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        writer.add_scalar("training/batch_loss", engine.state.metrics["batch_loss"], engine.state.iteration)
        writer.add_scalar("training/global_rate", engine.state.metrics["global_rate"], engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            "Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(
            "Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}".format(
                engine.state.epoch, avg_accuracy, avg_nll
            )
        )
        writer.add_scalar("valdation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("valdation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    writer.close()
Exemplo n.º 24
0
def get_trainer(config, device=torch.device("cuda"), mobilenet=False):

    if mobilenet:
        cfg = config["model"]["generator1"]
        generator1 = MobileGenerator1(3 + 18, cfg["num_repeat"],
                                      cfg["middle_features_dim"],
                                      cfg["channels_base"], cfg["image_size"])
        generator1.to(device)
        generator1.load_state_dict(
            torch.load(cfg["pretrained_path"], map_location="cpu"))

        cfg = config["model"]["generator2"]
        generator2 = MobileGenerator2(3 + 3,
                                      cfg["channels_base"],
                                      cfg["num_repeat"],
                                      cfg["num_skip_out_connect"],
                                      weight_init_way=cfg["weight_init_way"])
        generator2.to(device)
        print(generator2)
    else:
        cfg = config["model"]["generator1"]
        generator1 = Generator1(3 + 18, cfg["num_repeat"],
                                cfg["middle_features_dim"],
                                cfg["channels_base"], cfg["image_size"])
        generator1.to(device)
        generator1.load_state_dict(
            torch.load(cfg["pretrained_path"], map_location="cpu"))

        cfg = config["model"]["generator2"]
        generator2 = Generator2(3 + 3,
                                cfg["channels_base"],
                                cfg["num_repeat"],
                                cfg["num_skip_out_connect"],
                                weight_init_way=cfg["weight_init_way"])
        generator2.to(device)
        print(generator2)

    discriminator = Discriminator(
        weight_init_way=config["model"]["discriminator"]["weight_init_way"])
    discriminator.to(device)
    print(discriminator)

    cfg = config["train"]["generator2"]
    generator2_optimizer = optim.Adam(generator2.parameters(),
                                      lr=cfg["lr"],
                                      betas=(cfg["beta1"], cfg["beta2"]))
    cfg = config["train"]["discriminator"]
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=cfg["lr"],
                                         betas=(cfg["beta1"], cfg["beta2"]))

    mask_l1_loss = MaskL1Loss(config["loss"]["mask_l1"]["mask_ratio"])
    mask_l1_loss.to(device)
    adversarial_loss = torch.nn.BCEWithLogitsLoss()
    adversarial_loss.to(device)

    real_labels = torch.ones((config["train"]["batch_size"], 1), device=device)
    fake_labels = torch.zeros((config["train"]["batch_size"], 1),
                              device=device)

    def _step(engine, batch):
        batch = convert_tensor(batch, device)
        with torch.no_grad():
            generated_img_1 = generator1(batch["condition_img"],
                                         batch["target_bone"])
        generated_img = generated_img_1 + generator2(batch["condition_img"],
                                                     generated_img_1)
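        # Two-stage generation: generator1 (run under no_grad, not optimised here)
        # produces a coarse image and generator2 predicts a residual refinement.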

        generator2_optimizer.zero_grad()
        g2_gan_loss = adversarial_loss(discriminator(generated_img),
                                       real_labels)
        g2_mask_l1_loss = mask_l1_loss(generated_img, batch["target_img"],
                                       batch["target_mask"])
        g2_loss = config["loss"]["mask_l1"]["weight"] * g2_mask_l1_loss + \
                  config["loss"]["gan"]["weight"] * g2_gan_loss
        g2_loss.backward()
        generator2_optimizer.step()

        discriminator_optimizer.zero_grad()
        d_real_loss = adversarial_loss(discriminator(batch["target_img"]),
                                       real_labels)
        d_fake_loss = adversarial_loss(discriminator(generated_img.detach()),
                                       fake_labels)
        d_loss_1 = (d_fake_loss + d_real_loss) / 2

        d_real_loss = adversarial_loss(discriminator(batch["target_img"]),
                                       real_labels)
        d_fake_loss = adversarial_loss(discriminator(batch["condition_img"]),
                                       fake_labels)
        d_loss_2 = (d_fake_loss + d_real_loss) / 2
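        # d_loss_2 additionally treats the (real) condition image as a fake sample,
        # so the discriminator is also penalised for accepting it as real.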

        d_loss = (d_loss_1 + d_loss_2) / 2
        d_loss.backward()
        discriminator_optimizer.step()

        return {
            "loss": {
                "g2_mask_l1_loss": g2_mask_l1_loss.item(),
                "g2_gan_loss": g2_gan_loss.item(),
                "g2_loss": g2_loss.item(),
                "d_loss": d_loss.item(),
                "d_loss_1": d_loss_1.item(),
                "d_loss_2": d_loss_2.item(),
            },
            "img": {
                "mask_img": batch["target_mask"].detach(),
                "condition_img": batch["condition_img"].detach(),
                "target_img": batch["target_img"].detach(),
                "generated_img_1": generated_img_1.detach(),
                "generated_img": generated_img.detach(),
            }
        }

    trainer = Engine(_step)

    RunningAverage(
        output_transform=lambda x: x["loss"]['g2_mask_l1_loss']).attach(
            trainer, 'g2_mask_l1_loss')
    RunningAverage(output_transform=lambda x: x["loss"]['g2_gan_loss']).attach(
        trainer, 'g2_gan_loss')
    RunningAverage(output_transform=lambda x: x["loss"]['g2_loss']).attach(
        trainer, 'g2_loss')
    RunningAverage(output_transform=lambda x: x["loss"]['d_loss_1']).attach(
        trainer, 'd_loss_1')
    RunningAverage(output_transform=lambda x: x["loss"]['d_loss']).attach(
        trainer, 'd_loss')
    RunningAverage(output_transform=lambda x: x["loss"]['d_loss_2']).attach(
        trainer, 'd_loss_2')

    ProgressBar(ncols=0).attach(trainer, ["g2_loss", "d_loss"])

    mcp = ModelCheckpoint(
        config["output"],
        "network",
        save_interval=config["log"]["model_checkpoint"]["save_interval"],
        n_saved=config["log"]["model_checkpoint"]["n_saved"],
        require_empty=False,
        save_as_state_dict=True,
        create_dir=True)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mcp,
                              to_save={
                                  "G2": generator2,
                                  "D": discriminator
                              })

    check_cpe = CustomPeriodicEvent(n_iterations=config["log"]["check_freq"])
    check_cpe.attach(trainer)
    CHECK_EVENT = getattr(
        check_cpe.Events,
        "ITERATIONS_{}_COMPLETED".format(config["log"]["check_freq"]))

    loss_cpe = CustomPeriodicEvent(n_iterations=config["log"]["loss_freq"])
    loss_cpe.attach(trainer)
    LOSS_EVENT = getattr(
        loss_cpe.Events,
        "ITERATIONS_{}_COMPLETED".format(config["log"]["loss_freq"]))

    tb_logger = TensorboardLogger(config["output"])
    tb_writer = tb_logger.writer

    loss_gst = custom_global_step_transform(config["log"]["loss_freq"])
    check_gst = custom_global_step_transform(config["log"]["check_freq"])

    check_handlers = [
        (OutputHandler(
            tag="G2",
            metric_names=["g2_mask_l1_loss", "g2_gan_loss", "g2_loss"],
            global_step_transform=loss_gst), LOSS_EVENT),
        (OutputHandler(tag="D",
                       metric_names=["d_loss_1", "d_loss_2", "d_loss"],
                       global_step_transform=loss_gst), LOSS_EVENT),
        (OptimizerParamsHandler(discriminator_optimizer,
                                param_name="lr",
                                tag="D",
                                global_step_transform=check_gst), CHECK_EVENT),
        (OptimizerParamsHandler(generator2_optimizer,
                                param_name="lr",
                                tag="G2",
                                global_step_transform=check_gst), CHECK_EVENT),
        (WeightsHistHandler(generator2,
                            tag="G2",
                            global_step_transform=check_gst), CHECK_EVENT),
        (WeightsHistHandler(discriminator,
                            tag="D",
                            global_step_transform=check_gst), CHECK_EVENT),
    ]

    for ch, e in check_handlers:
        tb_logger.attach(trainer, log_handler=ch, event_name=e)

    val_data_pair = get_val_data_pairs(config)
    val_data_pair = convert_tensor(val_data_pair, device)

    @trainer.on(CHECK_EVENT)
    def log(engine):
        # from python3.7 dict will keep order so that .values() will result in same output
        tb_writer.add_image('Train/image',
                            make_2d_grid(engine.state.output["img"].values()),
                            engine.state.iteration)
        with torch.no_grad():
            generator1.eval()
            generator2.eval()
            generated_img_1 = generator1(val_data_pair["condition_img"],
                                         val_data_pair["target_bone"])
            generated_img = generator2(val_data_pair["condition_img"],
                                       generated_img_1) + generated_img_1
            output_imgs = [
                val_data_pair["target_mask"], val_data_pair["condition_img"],
                val_data_pair["target_img"], generated_img_1, generated_img
            ]
            tb_writer.add_image('Test/image', make_2d_grid(output_imgs),
                                engine.state.iteration)
            generator1.train()
            generator2.train()

    return trainer
Exemplo n.º 25
0
def run(train_loader, val_loader, epochs, lr, momentum, weight_decay, lr_step,
        k1, k2, es_patience, log_dir):
    model = Vgg16()

    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'
    model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=lr,
                          momentum=momentum,
                          weight_decay=weight_decay)

    lr_scheduler = ExponentialLR(optimizer, gamma=0.975)

    # criterion = VAELoss(k1=k1, k2=k2).to(device)

    def update_fn(engine, batch):
        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()

        optimizer.zero_grad()

        output = model(x)

        # Compute loss
        loss = F.nll_loss(output, y)

        loss.backward()

        optimizer.step()

        return {
            "batchloss": loss.item(),
        }

    trainer = Engine(update_fn)

    try:
        GpuInfo().attach(trainer)
    except RuntimeError:
        print(
            "INFO: By default, this example can log GPU information (used memory, utilization). "
            "Since the pynvml python package is not installed, GPU information won't be logged. "
            "To enable it, install it with: `pip install pynvml`")

    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=lr_step),
                              lambda engine: lr_scheduler.step())

    metric_names = [
        'batchloss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False,
                       device=device).attach(trainer, n)

    exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
    log_path = log_dir + "/vgg_vae/{}".format(exp_name)

    tb_logger = TensorboardLogger(log_dir=log_path)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=metric_names),
                     event_name=Events.ITERATION_COMPLETED)

    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer, "lr"),
                     event_name=Events.ITERATION_STARTED)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)
    ProgressBar(persist=False, bar_format="").attach(trainer,
                                                     metric_names=metric_names)

    # val process definition
    def loss_output_transform(output):
        return output

    def acc_output_transform(output):
        return output

    customed_loss = Loss(loss_fn=F.nll_loss,
                         output_transform=loss_output_transform,
                         device=device)
    customed_accuracy = Accuracy(output_transform=acc_output_transform,
                                 device=device)

    metrics = {'Loss': customed_loss, 'Accuracy': customed_accuracy}

    def val_update_fn(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = _prepare_batch(batch, device=device, non_blocking=True)
            output = model(x)
            return output, y

    val_evaluator = Engine(val_update_fn)

    for name, metric in metrics.items():
        metric.attach(val_evaluator, name)

    def run_evaluation(engine):
        val_evaluator.run(val_loader)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)
    trainer.add_event_handler(Events.COMPLETED, run_evaluation)

    ProgressBar(persist=False, desc="Train evaluation").attach(val_evaluator)

    # Log val metrics:
    tb_logger.attach(val_evaluator,
                     log_handler=OutputHandler(tag="val",
                                               metric_names=list(
                                                   metrics.keys()),
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)

    # trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    # Store the best model
    def default_score_fn(engine):
        score = engine.state.metrics['Accuracy']
        return score

    best_model_handler = ModelCheckpoint(dirname=log_path,
                                         filename_prefix="best",
                                         n_saved=3,
                                         score_name="val_acc",
                                         score_function=default_score_fn)
    val_evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {
        'model': model,
    })

    # Add early stopping
    es_handler = EarlyStopping(patience=es_patience,
                               score_function=default_score_fn,
                               trainer=trainer)
    val_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    setup_logger(es_handler._logger)
    setup_logger(logging.getLogger("ignite.engine.engine.Engine"))

    def empty_cuda_cache(engine):
        torch.cuda.empty_cache()
        import gc
        gc.collect()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
    val_evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)

    trainer.run(train_loader, max_epochs=epochs)
Exemplo n.º 26
0
def train():
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'

    parser = ArgumentParser()
    parser.add_argument('--gpt2', action='store_true', help="use gpt2")
    parser.add_argument("--model_checkpoint", type=str, default="uer/gpt2-chinese-cluecorpussmall", help="Path or URL of the model")
    parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step")
    parser.add_argument('--pretrained', action='store_true', help="If False train from scratch")
    parser.add_argument("--data_path", type=str, default="data/autocloze.json",
                        help="Path or url of the dataset. ")
    parser.add_argument("--train_path", type=str, default="data/toy_train.txt",
                        help="Path of the train dataset for dist dataset. ")
    parser.add_argument("--valid_path", type=str, default="data/toy_valid.txt",
                        help="Path of the valid dataset for dist dataset. ")
    #--------------------------------------------------------------
    parser.add_argument("--dataset_cache", type=str, default="dataset_zh",
                        help="Path or url of the dataset cache")
    parser.add_argument('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of subprocesses for data loading")
    parser.add_argument("--n_epochs", type=int, default=40, help="Number of training epochs")
    parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
    parser.add_argument("--max_history", type=int, default=15, help="Number of previous exchanges to keep in history")
    parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear'], help="method of optim")
    parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps")
    parser.add_argument("--valid_steps", type=int, default=5000, help="Perfom validation every X steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=64,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()
    print('cuda ',torch.cuda.is_available())
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    '''if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    '''
    args.device = torch.device("cuda")
    print('device ',args.device)
    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    #model_class = OpenAIGPTLMHeadModel if not args.gpt2 else GPT2LMHeadModel
    #config_class = OpenAIGPTConfig if not args.gpt2 else GPT2Config
    model_class = GPT2LMHeadModel
    config_class = GPT2Config
    tokenizer_class = BertTokenizer
    print('pretrained:',args.pretrained)
    if args.pretrained:
        print("----------------pretrained")
        tokenizer = BertTokenizer.from_pretrained(args.model_checkpoint, do_lower_case=True)
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
        model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall",from_tf=True)
        #print('generate')
        #print(text_generator("这是很久之前的事情了", max_length=100, do_sample=True))

    #args.device=torch.device("cuda", 2)
    
    model.to(args.device)
    
    optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, correct_bias=True)

    logger.info("Prepare datasets")
    loader_class = build_dist_loaders if not args.data_path else build_dataloaders
    train_loader, val_loader, train_sampler, valid_sampler = loader_class(args, tokenizer, logger)

    logger.info("Prepare datasets ends")
    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
        model=model.module
    #if isinstance(model,torch.nn.DataParallel):
    
    #print('params:',params_count(model))

    #tokens_embed = model.transformer.get_input_embeddings()
    # Training function and trainer
    def update(engine, batch):
        input_ids, token_type_ids, lm_labels = tuple(input_tensor.to(args.device) for input_tensor in batch)
        
        #for i in range(input_ids.size()[0]):
        #    for j in range(input_ids.size()[1]):
        #        if input_ids[i,j]==-1:
        #            input_ids[i,j]=-100
        #        if lm_labels[i,j]==-1:
        #            lm_labels[i,j]=-100
        #one=torch.tensor(-100)
        #input_ids=torch.where(input_ids==-1,one,input_ids)
        #lm_labels=torch.where(lm_labels==-1,one,lm_labels)
        #print('traindata',input_ids,lm_labels)

        #lm_labels=input_ids
        r'''input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        inputs_embeds = tokens_embed(input_ids) * math.sqrt(tokens_embed.embedding_dim)'''

        model.train()
        #(lm_loss), *_ = model(inputs_embeds=inputs_embeds, labels=lm_labels,return_dict=0)
        (lm_loss), *_ = model(input_ids=input_ids, labels=lm_labels,return_dict=False)
        #print('lm_loss',lm_loss)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item(), optimizer.param_groups[0]['lr']

    trainer = Engine(update)
    

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            input_ids, token_type_ids, lm_labels = tuple(input_tensor.to(args.device) for input_tensor in batch)
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            #one = torch.tensor(-100)
            #input_ids=torch.where(input_ids==-1,one,input_ids)
            #print('validdata',input_ids,lm_labels)
            #lm_labels=input_ids
            r'''input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = tokens_embed(input_ids) * math.sqrt(tokens_embed.embedding_dim)'''
            

            #lm_logits, *_ = model(inputs_embeds=inputs_embeds,return_dict=0)
            lm_logits, *_ = model(input_ids=input_ids,return_dict=False)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Evaluation during training
    @trainer.on(Events.ITERATION_STARTED)
    def log_iterations(engine):
        # if engine.state.iteration % max(int(0.1 * len(train_loader)), 1) == 0:
        if engine.state.iteration % args.valid_steps == 0:
            evaluator.run(val_loader)

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # noam decrease the learning rate
    # model_size = model.config.n_embd
    model_size = args.n_emd
    noam_lambda = lambda step: (
            model_size ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5)))
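    # The factor grows roughly linearly for the first warmup_steps updates and then
    # decays as step ** -0.5, scaled by model_size ** -0.5 (the "Noam" schedule).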
    noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step)
    scheduler = LRScheduler(noam_scheduler)
    if args.scheduler == "linear":
        scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "lr")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0], x[1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints
    # And save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True, mininterval=2)
        pbar.attach(trainer, metric_names=["loss", "lr"])
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                              another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=6)
        # save model after evaluation
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
            'mymodel': getattr(model, 'module', model)})
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
            'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of the distributed (DataParallel) wrapper

        torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.logdir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)
    
    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.logdir,
                               WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Exemplo n.º 27
0
def train():
    config_file = "configs/train_emotion_recognition_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", config.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = OpenAIGPTDoubleHeadLMEmotionRecognitionModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = tuple(input_tensor.to(config.device) for input_tensor in batch)
        #token_emotion_ids = None
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids)
        loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps
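        # Dividing by gradient_accumulation_steps keeps the accumulated gradient equivalent
        # to that of a single large batch.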
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch
            #token_emotion_ids = None
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, token_emotion_ids=token_emotion_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
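            # Shift by one position so the logits at position t predict the token at t+1,
            # then flatten to (batch * seq_len, vocab) for the cross-entropy metric.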
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}

    metrics.update({"precision": Precision(output_transform=lambda x: (x[0][1], x[1][1])),
                    "recall": Recall(output_transform=lambda x: (x[0][1], x[1][1]))})

    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)})

    metrics.update({"confusion_matrix": ConfusionMatrix(num_classes=6, output_transform=lambda x: (x[0][1], x[1][1]))})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of the distributed (DataParallel) wrapper

        torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
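
A minimal, self-contained sketch (with made-up tensor sizes) of the shift-and-flatten step used in inference() above, showing why the flattened pair can be fed to a cross-entropy loss with ignore_index=-1:

import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))
lm_labels[:, :3] = -1  # e.g. context tokens whose loss should be ignored (ignore_index=-1)

# Logits at position t predict the token at position t+1, so drop the last logit
# and the first label before flattening.
lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)

loss = F.cross_entropy(lm_logits_flat_shifted, lm_labels_flat_shifted, ignore_index=-1)
print(loss.item())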
Exemplo n.º 28
0
        optimizer.step()
        return loss.item()

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            text, visual, acoustic = batch.text, batch.visual, batch.acoustic
            y = batch.label
            y_pred = model(text, visual, acoustic)
            return y_pred, y

    trainer = Engine(update)
    train_evaluator = Engine(inference)
    validation_evaluator = Engine(inference)
    # add metric
    RunningAverage(output_transform=lambda x: x).attach(trainer,
                                                        'loss')  # for pbar
    Accuracy().attach(train_evaluator, 'acc')
    Loss(loss_fn).attach(train_evaluator, 'bce')
    Accuracy().attach(validation_evaluator, 'acc')
    Loss(loss_fn).attach(validation_evaluator, 'bce')

    # add Progress Bar
    pbar = ProgressBar(persist=True, bar_format='')
    pbar.attach(trainer, ['loss'])

    # add early stopping
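    # EarlyStopping keeps the highest score, so the score function returns the negated
    # validation BCE loss.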
    def score_fn_1(engine):
        val_loss = engine.state.metrics['bce']
        return -val_loss

    # The remaining arguments are assumed from ignite's standard EarlyStopping signature;
    # the handler is attached to the validation evaluator so it monitors the validation BCE.
    early_stop = EarlyStopping(patience=10,
                               score_function=score_fn_1,
                               trainer=trainer)
    validation_evaluator.add_event_handler(Events.COMPLETED, early_stop)
Exemplo n.º 29
0
def attach_running_average(engine, metric_name):
    RunningAverage(output_transform=lambda x: x[metric_name]).attach(
        engine,
        metric_name,
    )
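
A minimal usage sketch for the helper above, assuming an Ignite engine whose step function returns a dict of scalars (the engine, keys and values here are hypothetical):

from ignite.engine import Engine
from ignite.metrics import RunningAverage

def step(engine, batch):
    # Dummy step returning per-iteration scalars keyed the way the helper expects.
    return {"loss": float(batch), "acc": 1.0}

trainer = Engine(step)
attach_running_average(trainer, "loss")  # running average exposed as trainer.state.metrics["loss"]
attach_running_average(trainer, "acc")
trainer.run([0.5, 1.0, 1.5], max_epochs=2)
print(trainer.state.metrics["loss"], trainer.state.metrics["acc"])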
Exemplo n.º 30
0
def do_train(
    cfg,
    model,
    train_loader,
    val_loader,
    optimizer,
    scheduler,
    loss_fn,
    num_query,
    start_epoch,
):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(cfg,
                                        model,
                                        optimizer,
                                        loss_fn,
                                        device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query,
                             max_rank=50,
                             feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir,
                                   cfg.MODEL.NAME,
                                   checkpoint_period,
                                   n_saved=10,
                                   require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict()
    })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)
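    # With average=True, timer.value() is the mean per-iteration time, so
    # value() * step_count in print_times() below is the total measured batch time for
    # the epoch, and batch_size / value() approximates the training throughput.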

    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')
    if cfg.MODEL.METRIC_LOSS_TYPE == 'ours' or cfg.MODEL.METRIC_LOSS_TYPE == 'ours_triplet':
        RunningAverage(output_transform=lambda x: x[2]).attach(
            trainer, 'avg_proxypos')
        RunningAverage(output_transform=lambda x: x[3]).attach(
            trainer, 'avg_proxyneg')
        RunningAverage(output_transform=lambda x: x[4]).attach(
            trainer, 'avg_possim')
        RunningAverage(output_transform=lambda x: x[5]).attach(
            trainer, 'avg_negsim')
    map_list = []

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        if cfg.MODEL.METRIC_LOSS_TYPE == 'ours' or cfg.MODEL.METRIC_LOSS_TYPE == 'triplets':
            pass
        else:
            scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iter = (engine.state.iteration - 1) % len(train_loader) + 1
        if iter % log_period == 0:
            if cfg.MODEL.METRIC_LOSS_TYPE == 'ours' or cfg.MODEL.METRIC_LOSS_TYPE == 'ours_triplet':
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}\tAcc: {:.3f}\nProxyPos: {:.3f}\tProxyNeg: {:.3f}\tPosSim {:.3f}\tNegSim {:.3f}\tBase Lr: {:.2e}"
                    .format(engine.state.epoch, iter, len(train_loader),
                            engine.state.metrics['avg_loss'],
                            engine.state.metrics['avg_acc'],
                            engine.state.metrics['avg_proxypos'],
                            engine.state.metrics['avg_proxyneg'],
                            engine.state.metrics['avg_possim'],
                            engine.state.metrics['avg_negsim'],
                            optimizer.param_groups[0]['lr']))
            else:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(engine.state.epoch, iter, len(train_loader),
                            engine.state.metrics['avg_loss'],
                            engine.state.metrics['avg_acc'],
                            scheduler.get_lr()[0]))

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Total batch time: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        #import pdb; pdb.set_trace()
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(
                engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(
                    r, cmc[r - 1]))
            map_list.append(mAP)
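            # If the best mAP over the last `tolerance` evaluations does not beat the
            # earlier best, lower the learning rate; the tolerance depends on the current lr.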
            if cfg.MODEL.METRIC_LOSS_TYPE == 'ours' or cfg.MODEL.METRIC_LOSS_TYPE == 'triplets':
                current_lr = optimizer.param_groups[0]['lr']
                if current_lr == 3.5e-4 or current_lr == 1e-4:
                    tolerance = 3
                    #if engine.state.epoch > 20:
                    #    tolerance = 1
                elif current_lr == 7.0e-5:
                    tolerance = 3
                elif current_lr == 1.4e-5:
                    tolerance = 3
                elif current_lr == 3.5e-5:
                    tolerance = 6
                else:
                    tolerance = 1000
                #map_list.append(mAP)
                if len(map_list) > tolerance and max(
                        map_list[-tolerance:]) < max(map_list[:-tolerance]):
                    adjust_learning_rate_auto(cfg.MODEL.ADJUST_LR, optimizer)

            #logger.info(map_list)
            logger.info('The best mAP so far is {:.1%}'.format(max(map_list)))
            logger.info('It was reached at evaluation #{} (evaluations run every {} epoch(s))'.format(
                map_list.index(max(map_list)), eval_period))

    trainer.run(train_loader, max_epochs=epochs)
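
adjust_learning_rate_auto is not defined in this snippet; a minimal sketch of what such a plateau handler could look like (the decay factor, the floor and the behaviour are assumptions, not the original implementation):

def adjust_learning_rate_auto(adjust_lr_factor, optimizer, min_lr=1e-7):
    # Hypothetical helper: shrink every parameter group's learning rate by a fixed
    # factor (e.g. cfg.MODEL.ADJUST_LR = 0.1) once validation mAP stops improving.
    for param_group in optimizer.param_groups:
        param_group['lr'] = max(param_group['lr'] * adjust_lr_factor, min_lr)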