def _test_warning():
    def compute_fn(y_preds, y_targets):
        return 0.0

    with pytest.warns(
            RuntimeWarning,
            match="EpochMetric class does not support distributed setting"):
        EpochMetric(compute_fn)
Example #2
def _test_distrib_integration(device=None):

    if device is None:
        device = idist.device() if idist.device().type != "xla" else "cpu"

    rank = idist.get_rank()
    torch.manual_seed(12)

    n_iters = 60
    s = 16
    n_classes = 7

    offset = n_iters * s
    y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),), device=device)
    y_preds = torch.rand(offset * idist.get_world_size(), n_classes, device=device)

    def update(engine, i):
        return (
            y_preds[i * s + rank * offset:(i + 1) * s + rank * offset, :],
            y_true[i * s + rank * offset:(i + 1) * s + rank * offset],
        )

    engine = Engine(update)

    def assert_data_fn(all_preds, all_targets):
        assert all_preds.equal(y_preds), f"{all_preds.shape} vs {y_preds.shape}"
        assert all_targets.equal(y_true), f"{all_targets.shape} vs {y_true.shape}"
        return (all_preds.argmax(dim=1) == all_targets).sum().item()

    ep_metric = EpochMetric(assert_data_fn,
                            check_compute_fn=False,
                            device=device)
    ep_metric.attach(engine, "epm")

    data = list(range(n_iters))
    engine.run(data=data, max_epochs=3)
    assert engine.state.metrics["epm"] == (y_preds.argmax(dim=1) == y_true).sum().item()
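The helper above only contains the per-rank logic; it is normally driven by a backend-specific wrapper that sets up the process group first. A minimal sketch, assuming a distributed pytest fixture such as ignite's own `distributed_context_single_node_gloo` is available:

import pytest
import ignite.distributed as idist


@pytest.mark.distributed
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    # The fixture (assumed here) initializes a single-node gloo process group,
    # so idist.device() and idist.get_rank() resolve correctly inside the helper.
    _test_distrib_integration()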
def test_check_compute_fn():
    def compute_fn(y_preds, y_targets):
        raise Exception

    em = EpochMetric(compute_fn, check_compute_fn=True)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    with pytest.warns(
            EpochMetricWarning,
            match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output1)

    em = EpochMetric(compute_fn, check_compute_fn=False)
    em.update(output1)
def test_bad_compute_fn():

    def compute_fn(y_preds, y_targets):
        # Following will raise the error: Expected object of type torch.FloatTensor but found type
        # torch.LongTensor for argument #3 'other'
        return torch.mean(y_preds - y_targets).item()

    em = EpochMetric(compute_fn)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    with pytest.warns(RuntimeWarning):
        em.update(output1)
def test_bad_compute_fn():
    def compute_fn(y_preds, y_targets):
        # Following will raise the error:
        # The size of tensor a (3) must match the size of tensor b (4)
        # at non-singleton dimension 1
        return torch.mean(y_preds - y_targets).item()

    em = EpochMetric(compute_fn)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 4), dtype=torch.long))
    with pytest.warns(EpochMetricWarning, match=r"Probably, there can be a problem with `compute_fn`"):
        em.update(output1)
Example #6
def infer(images_array: np.ndarray, list_names: list,
          dir_dataset: Path, dir_model: Path, list_models: str, seed: int):

    test_csv = pd.read_csv(dir_dataset / 'test.csv')
    test_ids = list_names

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = EnsembleModels(dir_model, list_models, device)

    test_dataset = BengaliDataset(None, list_images=[f'{s}.png' for s in test_ids], images_array=images_array)

    batch_size = 8
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=4, shuffle=False)

    metrics = {
        'output': EpochMetric(bengali_argmax, output_transform=lambda x: (x[0], x[1][0])),
    }

    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

    evaluator.run(test_loader)
    output = evaluator.state.metrics['output']

    output_df = pd.DataFrame(
        output.cpu().numpy(), index=test_ids,
        columns=['grapheme_root', 'vowel_diacritic', 'consonant_diacritic']
    )
    print(output_df)

    test_csv = test_csv.query('image_id in @list_names').reset_index(drop=True)

    submission = pd.DataFrame(test_csv['row_id'], columns=['row_id'])
    submission_target = np.zeros(len(test_csv), dtype=np.int32)

    for i, row in test_csv.iterrows():
        submission_target[i] = output_df.loc[row['image_id'], row['component']]

    submission['target'] = submission_target
    return submission
def __init__(self, compute_fn, output_transform=lambda x: x):
    EpochMetric.__init__(self,
                         compute_fn=compute_fn,
                         output_transform=output_transform)
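For context, the __init__ above is a fragment of a thin EpochMetric subclass. A minimal sketch of the surrounding class (the class name is an assumption; only the constructor appears in the source):

from ignite.metrics import EpochMetric


class CustomEpochMetric(EpochMetric):
    # Hypothetical wrapper: buffering of predictions/targets and the epoch-level
    # compute() are inherited from EpochMetric; only the constructor is defined.
    def __init__(self, compute_fn, output_transform=lambda x: x):
        EpochMetric.__init__(self,
                             compute_fn=compute_fn,
                             output_transform=output_transform)

An instance built this way attaches to an engine like any other ignite metric, e.g. CustomEpochMetric(compute_fn).attach(engine, "my_metric").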
def test_epoch_metric():
    def compute_fn(y_preds, y_targets):
        return 0.0

    em = EpochMetric(compute_fn)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)

    assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
    assert torch.equal(em._predictions[0], output1[0])
    assert torch.equal(em._predictions[1], output2[0])
    assert torch.equal(em._targets[0], output1[1])
    assert torch.equal(em._targets[1], output2[1])
    assert em.compute() == 0.0

    # test when y and y_pred are (batch_size, 1) that are squeezed to (batch_size, )
    em.reset()
    output1 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
    em.update(output2)

    assert all([t.device.type == "cpu" for t in em._predictions + em._targets])
    assert torch.equal(em._predictions[0], output1[0][:, 0])
    assert torch.equal(em._predictions[1], output2[0][:, 0])
    assert torch.equal(em._targets[0], output1[1][:, 0])
    assert torch.equal(em._targets[1], output2[1][:, 0])
    assert em.compute() == 0.0
def test_mse_epoch_metric():
    def compute_fn(y_preds, y_targets):
        return torch.mean(((y_preds - y_targets.type_as(y_preds))**2)).item()

    em = EpochMetric(compute_fn)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)
    output3 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output3)

    preds = torch.cat([output1[0], output2[0], output3[0]], dim=0)
    targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)

    result = em.compute()
    assert result == compute_fn(preds, targets)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)
    output3 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output3)

    preds = torch.cat([output1[0], output2[0], output3[0]], dim=0)
    targets = torch.cat([output1[1], output2[1], output3[1]], dim=0)

    result = em.compute()
    assert result == compute_fn(preds, targets)
def test_epoch_metric_wrong_setup_or_input():

    # Wrong compute function
    with pytest.raises(TypeError,
                       match=r"Argument compute_fn should be callable."):
        EpochMetric(12345)

    def compute_fn(y_preds, y_targets):
        return 0.0

    em = EpochMetric(compute_fn)

    # Wrong input dims
    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
        output = (torch.tensor(0), torch.tensor(0))
        em.update(output)

    # Wrong input dims
    with pytest.raises(ValueError, match=r"Targets should be of shape"):
        output = (torch.rand(4, 3), torch.rand(4, 3, 1))
        em.update(output)

    # Wrong input dims
    with pytest.raises(ValueError, match=r"Predictions should be of shape"):
        output = (torch.rand(4, 3, 1), torch.rand(4, 3))
        em.update(output)

    # Target is not binary
    with pytest.raises(ValueError, match=r"Targets should be binary"):
        output = (torch.rand(4, 3), torch.randint(0, 5, size=(4, 3)))
        em.update(output)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)

    with pytest.raises(
            ValueError,
            match=r"Incoherent types between input y_pred and stored predictions"):
        output2 = (torch.randint(0, 5, size=(4, 3)), torch.randint(0, 2, size=(4, 3)))
        em.update(output2)

    with pytest.raises(
            ValueError,
            match=r"Incoherent types between input y and stored targets"):
        output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3)).to(torch.int32))
        em.update(output2)
Example #11
def test_epoch_metric():

    # Wrong compute function
    with pytest.raises(TypeError):
        EpochMetric(12345)

    def compute_fn(y_preds, y_targets):
        return 0.0

    em = EpochMetric(compute_fn)

    # Wrong input dims
    with pytest.raises(ValueError):
        output = (torch.tensor(0), torch.tensor(0))
        em.update(output)

    # Wrong input dims
    with pytest.raises(ValueError):
        output = (torch.rand(4, 3), torch.rand(4, 3, 1))
        em.update(output)

    # Wrong input dims
    with pytest.raises(ValueError):
        output = (torch.rand(4, 3, 1), torch.rand(4, 3))
        em.update(output)

    # Target is not binary
    with pytest.raises(ValueError):
        output = (torch.rand(4, 3), torch.randint(0, 5, size=(4, 3)))
        em.update(output)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)

    assert em._predictions.device.type == 'cpu' and em._targets.device.type == 'cpu'
    assert torch.equal(em._predictions[:4, :], output1[0])
    assert torch.equal(em._predictions[4:, :], output2[0])
    assert torch.equal(em._targets[:4, :], output1[1])
    assert torch.equal(em._targets[4:, :], output2[1])
    assert em.compute() == 0.0

    # test when y and y_pred are (batch_size, 1) that are squeezed to (batch_size, )
    em.reset()
    output1 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 1), torch.randint(0, 2, size=(4, 1), dtype=torch.long))
    em.update(output2)

    assert em._predictions.device.type == 'cpu' and em._targets.device.type == 'cpu'
    assert torch.equal(em._predictions[:4], output1[0][:, 0])
    assert torch.equal(em._predictions[4:], output2[0][:, 0])
    assert torch.equal(em._targets[:4], output1[1][:, 0])
    assert torch.equal(em._targets[4:], output2[1][:, 0])
    assert em.compute() == 0.0
Example #12
def RocAucMetric(**kwargs):
    return EpochMetric(roc_auc_compute_fn, **kwargs)
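Such a factory attaches to an evaluation engine like any other ignite metric. A minimal, self-contained sketch, assuming the engine's output is a (probability scores, binary targets) pair; eval_step and the toy data below are illustrative only:

import torch
from ignite.engine import Engine


def eval_step(engine, batch):
    # Hypothetical evaluation step: each batch already holds (scores, targets).
    y_pred, y = batch
    return y_pred, y


evaluator = Engine(eval_step)
RocAucMetric().attach(evaluator, "roc_auc")

# Toy data: two batches of probability scores and binary targets.
data = [
    (torch.rand(8), torch.randint(0, 2, (8,), dtype=torch.long)),
    (torch.rand(8), torch.randint(0, 2, (8,), dtype=torch.long)),
]
state = evaluator.run(data)
print(state.metrics["roc_auc"])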
def main(parser_args):
    """Main function to create trainer engine, add handlers to train and validation engines.
    Then runs train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()

    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels,
                         parser_args.depth, parser_args.laplacian_type,
                         parser_args.kernel_size)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)

    def trainer(engine, batch):
        """Train Function to define train engine.
        Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            :obj:`torch.tensor` : train loss for that batch and epoch
        """
        unet.train()
        data, labels = batch
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)

        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]

        loss = criterion(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)

    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform)

    engine_train.add_event_handler(
        Events.EPOCH_STARTED,
        lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED,
                                   TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1]
            },
            engine_train.state.epoch,
        )
        writer.close()

    step_scheduler = StepLR(optimizer,
                            step_size=parser_args.steplr_step_size,
                            gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train)
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train,
                    optimizer,
                    unet,
                    log_dir=parser_args.tensorboard_path)

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

    torch.save(unet.state_dict(),
               parser_args.model_save_path + "unet_state.pt")
Example #14
def test_epoch_metric():
    def compute_fn(y_preds, y_targets):
        return 0.0

    em = EpochMetric(compute_fn)

    # Wrong input dims
    with pytest.raises(AssertionError):
        output = (torch.tensor(0), torch.tensor(0))
        em.update(output)

    # Wrong input dims
    with pytest.raises(AssertionError):
        output = (torch.rand(4, 3, 1), torch.rand(4, 3))
        em.update(output)

    # Target is not binary
    with pytest.raises(AssertionError):
        output = (torch.rand(4, 3), torch.randint(0, 5, size=(4, 3)))
        em.update(output)

    em.reset()
    output1 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output1)
    output2 = (torch.rand(4, 3), torch.randint(0, 2, size=(4, 3), dtype=torch.long))
    em.update(output2)

    assert em._predictions.device.type == 'cpu' and em._targets.device.type == 'cpu'
    assert torch.equal(em._predictions[:4, :], output1[0])
    assert torch.equal(em._predictions[4:, :], output2[0])
    assert torch.equal(em._targets[:4, :], output1[1])
    assert torch.equal(em._targets[4:, :], output2[1])
    assert em.compute() == 0.0
Example #15
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.",
    )
    parser.add_argument(
        "-s",
        "--subreddit",
        type=str,
        action="append",
        default=[],
        help="Limit the subreddits you train on",
    )
    parser.add_argument(
        "--model_checkpoint",
        type=str,
        default="gpt2",
        help="Path, url or short name of the model",
    )
    parser.add_argument(
        "--num_candidates",
        type=int,
        default=2,
        help=
        "Number of candidates for training. Larger numbers may not fit on your GPU",
    )
    parser.add_argument(
        "--max_history",
        type=int,
        default=4,
        help="Number of previous exchanges to keep in history",
    )
    parser.add_argument("--max_epoch_length",
                        type=int,
                        default=100000000000,
                        help="Limit epoch length")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for validation")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=32,
        help="Accumulate gradients on several steps",
    )
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action="store_true",
        help="If true start with a first evaluation before training",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device (cuda or cpu)",
    )
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation). Try O2. Note first char is the letter 'oh'",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=1024,
        help="Max length size, same or smaller than n_ctx in model",
    )
    parser.add_argument(
        "--mimic_op",
        type=bool,
        default=None,
        help=
        "Whether training should train only on replies where the original poster is author (in contrast False means only on non OP replies). Default none will do all replies",
    )

    args = parser.parse_args()

    ts = datetime.datetime.utcnow().strftime("%Y%m%d_%H-%M-%S")
    model_type_name = "gpt2" if "gpt2" in args.model_checkpoint else "gpt"
    logdir = Path(f"runs/{ts}_{model_type_name}")
    logdir.mkdir()

    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
        # format='[{%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(
                filename=f"{logdir}/train_{args.local_rank}.log"),
            logging.StreamHandler(sys.stdout),
        ],
    )
    coloredlogs.install(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = args.local_rank != -1
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl",
                                             init_method="env://")

    logger.info("Prepare tokenizer - add special tokens for fine-tuning")
    tokenizer_class = (GPT2Tokenizer if "gpt2" in args.model_checkpoint else
                       OpenAIGPTTokenizer)
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    logger.info(
        "Prepare pretrained model and optimizer - add special tokens for fine-tuning"
    )
    model_class = (GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else
                   OpenAIGPTDoubleHeadsModel)
    model = model_class.from_pretrained(args.model_checkpoint)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    t_total = len(train_loader) // args.gradient_accumulation_steps * args.n_epochs
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)
    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = ((lm_loss * args.lm_coef + mc_loss * args.mc_coef) /
                args.gradient_accumulation_steps / args.train_batch_size)
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch, log_output=False):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = (
                model_outputs[0],
                model_outputs[1],
            )  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = (lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1)))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)

            # Every now and again sample I guess I should make a custom engine for this
            if log_output:
                input_text = tokenizer.decode(
                    input_ids[0, -1, :].cpu().tolist()).rstrip("<pad>")
                output_text = tokenizer.decode(
                    lm_logits[0, -1, :].argmax(-1).cpu().tolist()).strip()[:200]
                logger.info("inputs : %s", input_text)
                logger.info("outputs: %s", output_text)
            return dict(
                lm_logits_flat_shifted=lm_logits_flat_shifted,
                mc_logits=mc_logits,
                lm_labels_flat_shifted=lm_labels_flat_shifted,
                mc_labels=mc_labels,
                lr=torch.Tensor([optimizer.get_lr()[0]]),
            )

    evaluator = Engine(inference)
    exampler = Engine(functools.partial(inference, log_output=True))

    trainer.add_event_handler(Events.EPOCH_STARTED, lambda _: clear_mem())
    evaluator.add_event_handler(Events.EPOCH_STARTED, lambda _: clear_mem())

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    # After eval, run a short engine that will log some examples
    evaluator.add_event_handler(
        # Events.EPOCH_COMPLETED, lambda _: exampler.run([next(iter(val_loader))])
        Events.EPOCH_COMPLETED,
        lambda _: exampler.run(itertools.islice(val_loader, 2)),
    )
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch),
        )
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch),
        )

    # Learning rate warms up then linearly decreases
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, 0), (int(t_total * 0.2), args.lr),
                                 (t_total, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(
            torch.nn.CrossEntropyLoss(ignore_index=-1),
            output_transform=lambda x: (
                x["lm_logits_flat_shifted"],
                x["lm_labels_flat_shifted"],
            ),
        ),
        # Display the lr for each epoch, using the metrics api
        "lr":
        EpochMetric(
            output_transform=lambda x: (x["lr"], x["lr"]),
            compute_fn=lambda x, y: x[0].mean(),
        ),
    }
    # Meta metrics
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])

    # Only add accuracy if are using distractors
    if args.num_candidates > 1:
        metrics["accuracy"] = Accuracy(
            output_transform=lambda x: (x["mc_logits"], x["mc_labels"]))
        metrics.update({
            "average_accuracy":
            MetricsLambda(average_distributed_scalar, metrics["accuracy"],
                          args)
        })
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: logger.info("Validation: %s" % pformat(evaluator.state.metrics)),
        )

        tb_logger = TensorboardLogger(log_dir=logdir)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="training", metric_names=["loss"]),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer),
            event_name=Events.ITERATION_STARTED,
        )
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="validation",
                metric_names=list(metrics.keys()),
                another_engine=trainer,
            ),
            event_name=Events.EPOCH_COMPLETED,
        )

        checkpoint_handler = ModelCheckpoint(logdir,
                                             "checkpoint",
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            checkpoint_handler,
            {"mymodel": getattr(model, "module", model)},
        )  # "getattr" take care of distributed encapsulation

        torch.save(args, logdir / "model_training_args.bin")
        getattr(model, "module",
                model).config.to_json_file(os.path.join(logdir, CONFIG_NAME))
        tokenizer.save_vocabulary(logdir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1], logdir / WEIGHTS_NAME
        )  # TODO (huggingface): PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()