def test_piecewiselinear_asserts():

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    with pytest.raises(ValueError):
        PiecewiseLinear(optimizer, "lr", milestones_values=[])

    with pytest.raises(ValueError):
        PiecewiseLinear(
            optimizer,
            "lr",
            milestones_values=[
                (0.5, ),
            ],
        )

    with pytest.raises(ValueError):
        PiecewiseLinear(optimizer,
                        "lr",
                        milestones_values=[(10, 0.5), (0.6, )])

    with pytest.raises(ValueError):
        PiecewiseLinear(optimizer,
                        "lr",
                        milestones_values=[(10, 0.5), (5, 0.6)])

    def _test(milestones_as_np_int):
        tensor = torch.zeros([1], requires_grad=True)
        optimizer = torch.optim.SGD([tensor], lr=0)

        milestones_values = [(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0),
                             (40, 0.5)]
        if milestones_as_np_int:
            milestones_values = [(np.int64(t), v)
                                 for t, v in milestones_values]

        scheduler = PiecewiseLinear(optimizer,
                                    'lr',
                                    milestones_values=milestones_values)
        state_dict = scheduler.state_dict()

        def save_lr(engine):
            lrs.append(optimizer.param_groups[0]['lr'])

        trainer = Engine(lambda engine, batch: None)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

        for _ in range(2):
            lrs = []
            trainer.run([0] * 25, max_epochs=2)

            assert lrs == list(
                map(pytest.approx, [
                    0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75,
                    0.8, 0.85, 0.9, 0.95, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4,
                    0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8,
                    0.9, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5,
                    0.5, 0.5, 0.5, 0.5
                ]))
            scheduler.load_state_dict(state_dict)

    _test(milestones_as_np_int=False)
    _test(milestones_as_np_int=True)
Example #3
    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ):

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
            if num_iter > max_iter:
                warnings.warn(
                    "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                    "({} epochs)".format(num_iter, max_iter,
                                         trainer.state.max_epochs),
                    UserWarning,
                )

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._log_lr_and_loss, output_transform,
                                      smooth_f, diverge_th)

        self.logger.debug(
            "Running LR finder for {} iterations".format(num_iter))
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            self._lr_schedule = LRScheduler(
                _ExponentialLR(optimizer, end_lr, num_iter))
        else:
            start_lr = optimizer.param_groups[0]["lr"]
            self._lr_schedule = PiecewiseLinear(optimizer,
                                                param_name="lr",
                                                milestones_values=[
                                                    (0, start_lr),
                                                    (num_iter, end_lr)
                                                ])
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._lr_schedule, num_iter)
Example #4
def test_piecewiselinear():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler = PiecewiseLinear(optimizer,
                                'lr',
                                milestones_values=[(5, 0.5), (15, 1.0),
                                                   (25, 0.0), (35, 1.0),
                                                   (40, 0.5)])
    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run([0] * 25, max_epochs=2)

    assert lrs == list(
        map(pytest.approx, [
            0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
            0.85, 0.9, 0.95, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1,
            0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.9, 0.8,
            0.7, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5
        ]))
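# The expected values above follow from plain linear interpolation between
# consecutive milestones, with the value held constant before the first milestone
# and after the last one. A minimal, ignite-independent sketch of that interpolation
# (a reconstruction for illustration, not the library code) reproduces the same 50 numbers:
def piecewise_linear_sketch(milestones_values, event_index):
    """Interpolated parameter value for a given event index."""
    first_milestone, first_value = milestones_values[0]
    if event_index <= first_milestone:
        return first_value
    for (m0, v0), (m1, v1) in zip(milestones_values, milestones_values[1:]):
        if event_index <= m1:
            return v0 + (v1 - v0) * (event_index - m0) / (m1 - m0)
    return milestones_values[-1][1]

milestones = [(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)]
print([round(piecewise_linear_sketch(milestones, i), 2) for i in range(50)])
# 0.5 up to event 5, rises to 1.0 at 15, falls to 0.0 at 25, back to 1.0 at 35, 0.5 at 40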
Example #5
def test_lr_scheduling_on_non_torch_optimizers():
    # tests https://github.com/pytorch/ignite/issues/1162
    optimizer = MagicMock()
    optimizer.param_groups = [{"params": 0}]
    FakeParamScheduler(optimizer, "lr")

    tensor = torch.zeros([1], requires_grad=True)
    base_optimizer = torch.optim.SGD([tensor], lr=0)
    optimizer = MockFP16DeepSpeedZeroOptimizer(base_optimizer)

    milestones_values = [(5, 0.5), (15, 1.0)]

    scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values)

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    lrs = []
    trainer.run([0] * 15, max_epochs=1)

    assert lrs == list(
        map(pytest.approx, [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95,],)
    )
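# FakeParamScheduler and MockFP16DeepSpeedZeroOptimizer are test fixtures defined
# elsewhere in the test module and are not shown here. A rough sketch of what such
# helpers could look like (hypothetical shapes, not the actual fixtures):
from ignite.contrib.handlers.param_scheduler import ParamScheduler  # import path is an assumption

class FakeParamScheduler(ParamScheduler):
    # minimal concrete subclass: issue #1162 is about accepting optimizer-like
    # objects that only expose `param_groups` instead of torch.optim.Optimizer
    def get_param(self):
        return 0.0

class MockFP16DeepSpeedZeroOptimizer:
    # wraps a real optimizer but is not a torch.optim.Optimizer subclass,
    # similar in spirit to DeepSpeed's FP16 ZeRO optimizer wrapper
    def __init__(self, optimizer):
        self.optimizer = optimizer
        self.param_groups = optimizer.param_groups

    def step(self, closure=None):
        self.optimizer.step(closure)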
Example #6
def test_scheduler_with_param_groups():
    def _test(lr_scheduler, optimizer):
        num_iterations = 10
        max_epochs = 20

        state_dict = lr_scheduler.state_dict()

        trainer = Engine(lambda engine, batch: None)

        @trainer.on(Events.ITERATION_COMPLETED)
        def save_lr():
            lrs.append((optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"]))

        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

        data = [0] * num_iterations

        for _ in range(2):
            lrs = []
            trainer.run(data, max_epochs=max_epochs)
            assert [lr[0] for lr in lrs] == pytest.approx([lr[1] for lr in lrs])
            lr_scheduler.load_state_dict(state_dict)

    t1 = torch.zeros([1], requires_grad=True)
    t2 = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

    lr_scheduler = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
    _test(lr_scheduler, optimizer)

    lr_scheduler = PiecewiseLinear(
        optimizer, "lr", milestones_values=[(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)]
    )
    _test(lr_scheduler, optimizer)

    lr_scheduler = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)
    _test(lr_scheduler, optimizer)

    torch_lr_scheduler = ExponentialLR(optimizer, gamma=0.98)
    _test(LRScheduler(torch_lr_scheduler), optimizer)

    torch_lr_scheduler = StepLR(optimizer, step_size=50, gamma=0.5)
    _test(LRScheduler(torch_lr_scheduler), optimizer)
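# The schedulers above drive the "lr" of every parameter group at once, which is
# exactly what the equality assertion checks. When groups need independent schedules,
# a `param_group_index` argument can be passed to the scheduler (treat its
# availability as an assumption about the installed ignite version). Minimal sketch:
t1 = torch.zeros([1], requires_grad=True)
t2 = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([{"params": t1, "lr": 0.1}, {"params": t2, "lr": 0.1}])

scheduler_g0 = PiecewiseLinear(optimizer, "lr", milestones_values=[(0, 0.1), (10, 0.0)], param_group_index=0)
scheduler_g1 = PiecewiseLinear(optimizer, "lr", milestones_values=[(0, 0.1), (10, 0.2)], param_group_index=1)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_g0)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_g1)
trainer.run([0] * 10, max_epochs=1)
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"])  # the two groups now differ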
Example #7
def test_piecewiselinear_asserts():

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    with pytest.raises(TypeError, match=r"Argument milestones_values should be a list or tuple"):
        PiecewiseLinear(optimizer, "lr", milestones_values=None)

    with pytest.raises(ValueError, match=r"Argument milestones_values should be with at least one value"):
        PiecewiseLinear(optimizer, "lr", milestones_values=[])

    with pytest.raises(ValueError, match=r"Argument milestones_values should be a list of pairs"):
        PiecewiseLinear(optimizer, "lr", milestones_values=[(0.5,)])

    with pytest.raises(ValueError, match=r"Argument milestones_values should be a list of pairs"):
        PiecewiseLinear(optimizer, "lr", milestones_values=[(10, 0.5), (0.6,)])

    with pytest.raises(ValueError, match=r"Milestones should be increasing integers"):
        PiecewiseLinear(optimizer, "lr", milestones_values=[(10, 0.5), (5, 0.6)])

    with pytest.raises(TypeError, match=r"Value of a milestone should be integer"):
        PiecewiseLinear(optimizer, "lr", milestones_values=[(0.5, 1)])
Example #8
class FastaiLRFinder:
    """Learning rate finder handler for supervised trainers.

    While attached, the handler increases the learning rate in between two
    boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning
    rates and what can be an optimal learning rate.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import FastaiLRFinder

        trainer = ...
        model = ...
        optimizer = ...

        lr_finder = FastaiLRFinder()
        to_save = {"model": model, "optimizer": optimizer}

        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)

        # Get lr_finder results
        lr_finder.get_results()

        # Plot lr_finder results (requires matplotlib)
        lr_finder.plot()

        # get lr_finder suggestion for lr
        lr_finder.lr_suggestion()


    Note:
        When context manager is exited all LR finder's handlers are removed.

    Note:
        Please, also keep in mind that all other handlers attached to the trainer will be executed during LR finder's run.

    Note:
        This class may require `matplotlib` package to be installed to plot learning rate range test:

        .. code-block:: bash

            pip install matplotlib


    References:

        Cyclical Learning Rates for Training Neural Networks:
        https://arxiv.org/abs/1506.01186

        fastai/lr_find: https://github.com/fastai/fastai
    """

    def __init__(self):
        self._diverge_flag = False
        self._history = None
        self._best_loss = None
        self._lr_schedule = None
        self.logger = logging.getLogger(__name__)

    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ):

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs
            if num_iter > max_iter:
                warnings.warn(
                    "Desired num_iter {} is unreachable with the current run setup of {} iteration "
                    "({} epochs)".format(num_iter, max_iter, trainer.state.max_epochs),
                    UserWarning,
                )

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, self._log_lr_and_loss, output_transform, smooth_f, diverge_th
            )

        self.logger.debug("Running LR finder for {} iterations".format(num_iter))
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            self._lr_schedule = LRScheduler(_ExponentialLR(optimizer, end_lr, num_iter))
        else:
            start_lr = optimizer.param_groups[0]["lr"]
            self._lr_schedule = PiecewiseLinear(
                optimizer, param_name="lr", milestones_values=[(0, start_lr), (num_iter, end_lr)]
            )
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, self._lr_schedule, num_iter)

    def _reset(self, trainer: Engine):
        self.logger.debug("Completed LR finder run")
        trainer.remove_event_handler(self._lr_schedule, Events.ITERATION_COMPLETED)
        trainer.remove_event_handler(self._log_lr_and_loss, Events.ITERATION_COMPLETED)
        trainer.remove_event_handler(self._reached_num_iterations, Events.ITERATION_COMPLETED)

    def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable, smooth_f: float, diverge_th: float):
        output = trainer.state.output
        loss = output_transform(output)
        lr = self._lr_schedule.get_param()
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
            self._best_loss = loss
        else:
            if smooth_f > 0:
                loss = smooth_f * loss + (1 - smooth_f) * self._history["loss"][-1]
            if loss < self._best_loss:
                self._best_loss = loss
        self._history["loss"].append(loss)

        # Check if the loss has diverged; if it has, stop the trainer
        if self._history["loss"][-1] > diverge_th * self._best_loss:
            self._diverge_flag = True
            self.logger.info("Stopping early, the loss has diverged")
            trainer.terminate()

    def _reached_num_iterations(self, trainer: Engine, num_iter: int):
        if trainer.state.iteration > num_iter:
            trainer.terminate()

    def _warning(self, _):
        if not self._diverge_flag:
            warnings.warn(
                "Run completed without loss diverging, increase end_lr, decrease diverge_th or look"
                " at lr_finder.plot()",
                UserWarning,
            )

    def _detach(self, trainer: Engine):
        """
        Detaches lr_finder from trainer.

        Args:
            trainer: the trainer to detach from.
        """

        if trainer.has_event_handler(self._run, Events.STARTED):
            trainer.remove_event_handler(self._run, Events.STARTED)
        if trainer.has_event_handler(self._warning, Events.COMPLETED):
            trainer.remove_event_handler(self._warning, Events.COMPLETED)
        if trainer.has_event_handler(self._reset, Events.COMPLETED):
            trainer.remove_event_handler(self._reset, Events.COMPLETED)

    def get_results(self):
        """
        Returns: dictionary with loss and lr logs from the previous run
        """
        return self._history

    def plot(self, skip_start: int = 10, skip_end: int = 5, log_lr: bool = True):
        """Plots the learning rate range test.

        This method requires `matplotlib` package to be installed:

        .. code-block:: bash

            pip install matplotlib

        Args:
            skip_start (int, optional): number of batches to trim from the start.
                Default: 10.
            skip_end (int, optional): number of batches to trim from the end.
                Default: 5.
            log_lr (bool, optional): True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale. Default: True.
        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            raise RuntimeError(
                "This method requires matplotlib to be installed. "
                "Please install it with command: \n pip install matplotlib"
            )

        if self._history is None:
            raise RuntimeError("learning rate finder didn't run yet so results can't be plotted")

        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
            raise ValueError("skip_end cannot be negative")

        # Get the data to plot from the history dictionary. Also, handle skip_end=0
        # properly so the behaviour is the expected

        lrs = self._history["lr"]
        losses = self._history["loss"]
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]

        # Plot loss as a function of the learning rate
        plt.plot(lrs, losses)
        if log_lr:
            plt.xscale("log")
        plt.xlabel("Learning rate")
        plt.ylabel("Loss")
        plt.show()

    def lr_suggestion(self):
        """
        Returns: learning rate at the minimum numerical gradient
        """
        if self._history is None:
            raise RuntimeError("learning rate finder didn't run yet so lr_suggestion can't be returned")
        loss = self._history["loss"]
        grads = torch.tensor([loss[i] - loss[i - 1] for i in range(1, len(loss))])
        min_grad_idx = grads.argmin() + 1
        return self._history["lr"][int(min_grad_idx)]

    @contextlib.contextmanager
    def attach(
        self,
        trainer: Engine,
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
        end_lr: float = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
    ):
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer (Engine): lr_finder is attached to this trainer. Please, keep in mind that all attached handlers
                will be executed.
            to_save (Mapping): dictionary with optimizer and other objects that needs to be restored after running
                the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
                implement `state_dict` and `load_state_dict` methods.
            output_transform (callable, optional): function that transforms the trainer's `state.output` after each
                iteration. It must return the loss of that iteration.
            num_iter (int, optional): number of iterations for lr schedule between base lr and end_lr. Default, it will
                run for `trainer.state.epoch_length * trainer.state.max_epochs`.
            end_lr (float, optional): upper bound for lr search. Default, 10.0.
            step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's initial
                lr to `end_lr`. Default, "exp".
            smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05
            diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
                Default, 5.0.

        Note:
            lr_finder cannot be attached to more than one trainer at a time.

        Returns:
            trainer_with_lr_finder: trainer used for finding the lr
        """
        if not isinstance(to_save, Mapping):
            raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise TypeError(
                "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1]")
        if diverge_th < 1:
            raise ValueError("diverge_th should be larger than 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
        if num_iter is not None:
            if not isinstance(num_iter, int):
                raise TypeError("if provided, num_iter should be an integer, but give {}".format(num_iter))
            if num_iter <= 0:
                raise ValueError("if provided, num_iter should be positive, but give {}".format(num_iter))

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
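# A hedged, end-to-end sketch of driving the finder above on a toy problem. The model,
# data and hyper-parameters are invented for illustration; only the lr_finder calls
# follow the docstring of the class.
import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import create_supervised_trainer

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()
trainer = create_supervised_trainer(model, optimizer, criterion)

# 256 random regression samples, batch size 32
dataloader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randn(256, 1)), batch_size=32)

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}
with lr_finder.attach(trainer, to_save=to_save, end_lr=1.0) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader, max_epochs=3)

print(lr_finder.lr_suggestion())  # lr at the steepest decrease of the (smoothed) loss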
Example #9
def main():
    args = get_args()
    if 'e-SNLI-VE' in args.data_path:
        args.no_image = False
    else:
        args.no_image = True
    if not args.no_image:
        args.no_premise = True
    args.with_expl = True

    '''Setup'''
    t = datetime.today()
    output_dir = os.path.join(args.output_folder,
                              f"{t.month}_{t.day}_{t.hour}_{t.minute}_{t.second}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(filename=os.path.join(output_dir, 'app.log'),
                        filemode='a',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning(f"Running process {args.local_rank}")
    logger.info(f"Arguments: {pformat(args)}")
    logger.info(f'Image not used:{args.no_image}')
    logger.info(f'Premise not used:{args.no_premise}')
    logger.info(f'Explanations used:{args.with_expl}')

    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    if args.no_image:
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        import image_gpt2_291
        model = image_gpt2_291.GPT2LMHeadModel.from_pretrained(
            args.model_checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer)

    '''Training function and trainer'''
    def train(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        if args.no_image:
            input_ids, lm_label, label, input_mask = batch
        else:
            image, input_ids, lm_label, label, input_mask = batch

        if args.no_image:
            output = model(input_ids=input_ids,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        else:
            output = model(input_ids=input_ids,
                           images=image,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        loss, logits, _ = output

        loss = loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        if not args.with_expl:
            lbl_accuracy = torch.eq(label, logits.argmax(
                dim=1)).float().sum() / len(label)
            return {
                'loss': loss.item(),
                'lbl_accuracy': lbl_accuracy.item()
            }
        else:
            if engine.state.iteration % (args.gradient_accumulation_steps * 500) == 0:
                input_output = list(zip(input_ids, logits))
                random_item = random.choice(input_output)
                in_sent = tokenizer.decode(list(filter(
                    lambda x: x != tokenizer.eos_token_id,
                    random_item[0])))
                out_expl = tokenizer.decode(random_item[1].argmax(dim=1),
                                            skip_special_tokens=True)
                logger.info(f'MODEL INPUT: {in_sent}')
                logger.info(f'GEN. EXPL {out_expl}')
                logger.info('--------------------------------')
            return {
                'loss': loss.item(),
            }

    '''Validation function and validator (validator output is the input of the metrics)'''
    def validation(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device)
                          for input_tensor in batch)
            if args.no_image:
                input_ids, lm_label, label, input_mask = batch
            else:
                image, input_ids, lm_label, label, input_mask = batch

            if args.no_image:
                output = model(input_ids=input_ids,
                               #    attention_mask=input_mask
                               )
            else:
                output = model(input_ids=input_ids,
                               images=image,
                               #    attention_mask=input_mask
                               )
            logits, _ = output

            logits_shifted = logits[..., :-1, :].contiguous().view(-1,
                                                                   logits.size(-1))
            labels_shifted = lm_label[..., 1:].contiguous().view(-1)
            return logits_shifted, labels_shifted

    '''Engines'''
    trainer = Engine(train)
    validator = Engine(validation)

    # t_total = len(
    #     train_loader) // args.gradient_accumulation_steps * args.n_epochs
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    '''Linearly decrease the learning rate from lr to zero'''
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    '''
    Attach validation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: validator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: validator.run(val_loader))

    '''Prepare metrics - note how we compute distributed metrics'''
    RunningAverage(output_transform=lambda x: x['loss']).attach(
        trainer, "loss")
    RunningAverage(output_transform=lambda x: math.exp(
        average_distributed_scalar(x['loss'], args))).attach(trainer, "ppl")
    if not args.with_expl:
        RunningAverage(output_transform=lambda x: 100 * x['lbl_accuracy']).attach(
            trainer, "lbl_accuracy")

    metrics = {}
    metrics["lbl_loss"] = Loss(torch.nn.CrossEntropyLoss(),
                               output_transform=lambda x: (x[0], x[1]))
    metrics["loss"] = MetricsLambda(
        lambda l, a: average_distributed_scalar(
            l / a.gradient_accumulation_steps, a), metrics["lbl_loss"], args)
    metrics["ppl"] = MetricsLambda(math.exp, metrics["loss"])
    if not args.with_expl:
        metrics["lbl_accuracy"] = 100 * \
            Accuracy(output_transform=lambda x: (x[0], x[1]))
    for name, metric in metrics.items():
        metric.attach(validator, name)

    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer,
                    metric_names=["loss", 'ppl'] if args.with_expl else ["loss", 'lbl_accuracy', 'ppl'])
        validator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message(
                                        "Validation: %s" % pformat(validator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=output_dir)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["ppl"] if args.with_expl else ["lbl_accuracy", "ppl"]),
                         event_name=Events.EPOCH_COMPLETED)

        tb_logger.attach(validator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=[
                                 'ppl', 'loss'] if args.with_expl else ['ppl', 'loss', 'lbl_accuracy'],
                             global_step_transform=lambda *args, **kwargs: trainer.state.iteration),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(output_dir,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})

        # "getattr" take care of distributed encapsulation
        torch.save(args, os.path.join(output_dir, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(output_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(output_dir)

    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
Example #10
class FastaiLRFinder:
    """Learning rate finder handler for supervised trainers.

    While attached, the handler increases the learning rate in between two
    boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning
    rates and what can be an optimal learning rate.

    Examples:

    .. code-block:: python

        from ignite.contrib.handlers import FastaiLRFinder

        trainer = ...
        model = ...
        optimizer = ...

        lr_finder = FastaiLRFinder()
        to_save = {"model": model, "optimizer": optimizer}

        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)

        # Get lr_finder results
        lr_finder.get_results()

        # Plot lr_finder results (requires matplotlib)
        lr_finder.plot()

        # get lr_finder suggestion for lr
        lr_finder.lr_suggestion()


    Note:
        When context manager is exited all LR finder's handlers are removed.

    Note:
        Please, also keep in mind that all other handlers attached to the trainer will be executed during LR finder's run.

    Note:
        This class may require `matplotlib` package to be installed to plot learning rate range test:

        .. code-block:: bash

            pip install matplotlib


    References:

        Cyclical Learning Rates for Training Neural Networks:
        https://arxiv.org/abs/1506.01186

        fastai/lr_find: https://github.com/fastai/fastai
    """
    def __init__(self) -> None:
        self._diverge_flag = False
        self._history = {}  # type: Dict[str, List[Any]]
        self._best_loss = None
        self._lr_schedule = None  # type: Optional[Union[LRScheduler, PiecewiseLinear]]
        self.logger = logging.getLogger(__name__ + "." +
                                        self.__class__.__name__)

    def _run(
        self,
        trainer: Engine,
        optimizer: Optimizer,
        output_transform: Callable,
        num_iter: int,
        end_lr: float,
        step_mode: str,
        smooth_f: float,
        diverge_th: float,
    ) -> None:

        self._history = {"lr": [], "loss": []}
        self._best_loss = None
        self._diverge_flag = False

        # attach LRScheduler to trainer.
        if num_iter is None:
            num_iter = trainer.state.epoch_length * trainer.state.max_epochs
        else:
            max_iter = trainer.state.epoch_length * trainer.state.max_epochs  # type: ignore[operator]
            if num_iter > max_iter:
                warnings.warn(
                    f"Desired num_iter {num_iter} is unreachable with the current run setup of {max_iter} iteration "
                    f"({trainer.state.max_epochs} epochs)",
                    UserWarning,
                )

        if not trainer.has_event_handler(self._reached_num_iterations):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._reached_num_iterations, num_iter)

        # attach loss and lr logging
        if not trainer.has_event_handler(self._log_lr_and_loss):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._log_lr_and_loss, output_transform,
                                      smooth_f, diverge_th)

        self.logger.debug(f"Running LR finder for {num_iter} iterations")
        # Initialize the proper learning rate policy
        if step_mode.lower() == "exp":
            self._lr_schedule = LRScheduler(
                _ExponentialLR(optimizer, end_lr, num_iter))
        else:
            start_lr = optimizer.param_groups[0]["lr"]
            self._lr_schedule = PiecewiseLinear(optimizer,
                                                param_name="lr",
                                                milestones_values=[
                                                    (0, start_lr),
                                                    (num_iter, end_lr)
                                                ])
        if not trainer.has_event_handler(self._lr_schedule):
            trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                      self._lr_schedule, num_iter)

    def _reset(self, trainer: Engine) -> None:
        self.logger.debug("Completed LR finder run")
        trainer.remove_event_handler(
            self._lr_schedule,
            Events.ITERATION_COMPLETED)  # type: ignore[arg-type]
        trainer.remove_event_handler(self._log_lr_and_loss,
                                     Events.ITERATION_COMPLETED)
        trainer.remove_event_handler(self._reached_num_iterations,
                                     Events.ITERATION_COMPLETED)

    def _log_lr_and_loss(self, trainer: Engine, output_transform: Callable,
                         smooth_f: float, diverge_th: float) -> None:
        output = trainer.state.output
        loss = output_transform(output)
        lr = self._lr_schedule.get_param()  # type: ignore[union-attr]
        self._history["lr"].append(lr)
        if trainer.state.iteration == 1:
            self._best_loss = loss
        else:
            if smooth_f > 0:
                loss = smooth_f * loss + (1 -
                                          smooth_f) * self._history["loss"][-1]
            if loss < self._best_loss:
                self._best_loss = loss
        self._history["loss"].append(loss)

        # Check if the loss has diverged; if it has, stop the trainer
        if self._history["loss"][
                -1] > diverge_th * self._best_loss:  # type: ignore[operator]
            self._diverge_flag = True
            self.logger.info("Stopping early, the loss has diverged")
            trainer.terminate()

    def _reached_num_iterations(self, trainer: Engine, num_iter: int) -> None:
        if trainer.state.iteration > num_iter:
            trainer.terminate()

    def _warning(self, _: Any) -> None:
        if not self._diverge_flag:
            warnings.warn(
                "Run completed without loss diverging, increase end_lr, decrease diverge_th or look"
                " at lr_finder.plot()",
                UserWarning,
            )

    def _detach(self, trainer: Engine) -> None:
        """
        Detaches lr_finder from trainer.

        Args:
            trainer: the trainer to detach from.
        """

        if trainer.has_event_handler(self._run, Events.STARTED):
            trainer.remove_event_handler(self._run, Events.STARTED)
        if trainer.has_event_handler(self._warning, Events.COMPLETED):
            trainer.remove_event_handler(self._warning, Events.COMPLETED)
        if trainer.has_event_handler(self._reset, Events.COMPLETED):
            trainer.remove_event_handler(self._reset, Events.COMPLETED)

    def get_results(self) -> Dict[str, List[Any]]:
        """
        Returns:
            Dictionary with loss and lr logs from the previous run
        """
        return self._history

    def plot(
        self,
        skip_start: int = 10,
        skip_end: int = 5,
        log_lr: bool = True,
        display_suggestion: bool = True,
        ax: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        """Plots the learning rate range test.

        This method requires ``matplotlib`` package to be installed:

        .. code-block:: bash

            pip install matplotlib

        Args:
            skip_start: number of batches to trim from the start.
                Default: 10.
            skip_end: number of batches to trim from the end.
                Default: 5.
            log_lr: True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale. Default: True.
            display_suggestion: if True, red dot shows the suggested learning rate.
            ax: Pre-existing axes for the plot. Default: None.
            kwargs: optional kwargs passed to ``plt.subplots`` if ``ax`` is not provided.

        .. code-block:: python

            ax = lr_finder.plot(skip_end=0)
            ax.figure.savefig("output.jpg")

        """
        try:
            from matplotlib import pyplot as plt
        except ImportError:
            raise RuntimeError(
                "This method requires matplotlib to be installed. "
                "Please install it with command: \n pip install matplotlib")
        if not self._history:
            raise RuntimeError(
                "learning rate finder didn't run yet so results can't be plotted"
            )

        if skip_start < 0:
            raise ValueError("skip_start cannot be negative")
        if skip_end < 0:
            raise ValueError("skip_end cannot be negative")

        # Get the data to plot from the history dictionary.
        lrs = self._history["lr"]
        losses = self._history["loss"]

        num_groups = len(lrs[0]) if isinstance(lrs[0], list) else 1
        legends = [
            f"suggested lr for param_groups {i}" for i in range(num_groups)
        ]

        if ax is None:
            fig, ax = plt.subplots(**kwargs)

        # Check to show the suggested learning rate
        if display_suggestion:
            sug_lr = self.lr_suggestion()
            idx = self._history["lr"].index(sug_lr)

            if skip_start >= idx:
                warnings.warn(
                    "skip_start is larger than the suggested LR found"
                    " and it will not be visible on the plot. Please, make the value smaller.",
                    UserWarning,
                )

            corresponding_loss = self._history["loss"][int(idx)]

            # Check if optimizer has multiple param_groups
            if not isinstance(sug_lr, list):
                sug_lr = [
                    sug_lr,
                ]
            for lr in sug_lr:
                ax.scatter(
                    lr,
                    corresponding_loss,
                    color="red" if len(sug_lr) == 1 else None,
                    s=75,
                    marker="o",
                    zorder=3,
                )

        # handle skip_end=0 properly
        if skip_end == 0:
            lrs = lrs[skip_start:]
            losses = losses[skip_start:]
        else:
            lrs = lrs[skip_start:-skip_end]
            losses = losses[skip_start:-skip_end]

        plt.legend(legends)
        # Plot loss as a function of the learning rate
        ax.plot(lrs, losses)
        if log_lr:
            ax.set_xscale("log")
        lr_min = min(lrs[0]) if isinstance(lrs[0], list) else lrs[0]
        lr_max = max(lrs[-1]) if isinstance(lrs[-1], list) else lrs[-1]
        ax.set_xlim([lr_min, lr_max])
        ax.set_xlabel("Learning rate")
        ax.set_ylabel("Loss")
        plt.show()
        return ax

    def lr_suggestion(self) -> Any:
        """
        Returns:
            Learning rate at the minimum numerical gradient
            (ignoring the increasing part of the curve)
        """
        if not self._history:
            raise RuntimeError(
                "learning rate finder didn't run yet so lr_suggestion can't be returned"
            )
        loss = self._history["loss"]
        min_loss_idx = torch.tensor(loss).argmin()
        # Ignore the increasing part of the curve
        decreasing_losses = self._history["loss"][:int(min_loss_idx.item()) +
                                                  1]
        if len(decreasing_losses) < 3:
            raise RuntimeError(
                "FastaiLRFinder got unexpected curve shape, the curve should be somehow U-shaped"
            )
        losses = torch.tensor(decreasing_losses)
        grads = torch.tensor([
            0.5 * (losses[i + 1] - losses[i - 1])
            for i in range(1,
                           len(losses) - 1)
        ])
        min_grad_idx = grads.argmin() + 1
        return self._history["lr"][int(min_grad_idx)]

    def apply_suggested_lr(self, optimizer: Optimizer) -> None:
        """
        Applying the suggested learning rate(s) on the given optimizer.

        Note:
            The given optimizer should be the same one used earlier to find the suggested learning rate.

        Args:
            optimizer: the optimizer to apply the suggested learning rate(s) on.

        """
        sug_lr = self.lr_suggestion()
        if not isinstance(sug_lr, list):
            sug_lr = [
                sug_lr,
            ]

        if len(sug_lr) != len(optimizer.param_groups):
            raise RuntimeError(
                "The number of parameter groups does not match between "
                "given optimizer and the one used for estimating the "
                f"learning rate: {len(sug_lr)} vs {len(optimizer.param_groups)}"
            )

        for i, lr in enumerate(sug_lr):
            optimizer.param_groups[i]["lr"] = lr

    @contextlib.contextmanager
    def attach(
        self,
        trainer: Engine,
        to_save: Mapping,
        output_transform: Callable = lambda output: output,
        num_iter: Optional[int] = None,
        end_lr: float = 10.0,
        step_mode: str = "exp",
        smooth_f: float = 0.05,
        diverge_th: float = 5.0,
    ) -> Any:
        """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

        Usage:

        .. code-block:: python

            to_save = {"model": model, "optimizer": optimizer}
            with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
                trainer_with_lr_finder.run(dataloader)

        Args:
            trainer: lr_finder is attached to this trainer. Please, keep in mind that all attached handlers
                will be executed.
            to_save: dictionary with optimizer and other objects that needs to be restored after running
                the LR finder. For example, ``to_save={'optimizer': optimizer, 'model': model}``. All objects should
                implement ``state_dict`` and ``load_state_dict`` methods.
            output_transform: function that transforms the trainer's ``state.output`` after each
                iteration. It must return the loss of that iteration.
            num_iter: number of iterations for lr schedule between base lr and end_lr. Default, it will
                run for ``trainer.state.epoch_length * trainer.state.max_epochs``.
            end_lr: upper bound for lr search. Default, 10.0.
            step_mode: "exp" or "linear", which way should the lr be increased from optimizer's initial
                lr to ``end_lr``. Default, "exp".
            smooth_f: loss smoothing factor in range ``[0, 1)``. Default, 0.05
            diverge_th: Used for stopping the search when ``current loss > diverge_th * best_loss``.
                Default, 5.0.

        Returns:
            trainer_with_lr_finder (trainer used for finding the lr)

        Note:
            lr_finder cannot be attached to more than one trainer at a time.
        """
        if not isinstance(to_save, Mapping):
            raise TypeError(
                f"Argument to_save should be a mapping, but given {type(to_save)}"
            )

        Checkpoint._check_objects(to_save, "state_dict")
        Checkpoint._check_objects(to_save, "load_state_dict")

        if "optimizer" not in to_save:
            raise ValueError("Mapping to_save should contain 'optimizer' key")

        if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
            raise TypeError(
                f"Object to_save['optimizer'] should be torch optimizer, but given {type(to_save['optimizer'])}"
            )

        if smooth_f < 0 or smooth_f >= 1:
            raise ValueError("smooth_f is outside the range [0, 1]")
        if diverge_th < 1:
            raise ValueError("diverge_th should be larger than 1")
        if step_mode not in ["exp", "linear"]:
            raise ValueError(
                f"step_mode should be 'exp' or 'linear', but given {step_mode}"
            )
        if num_iter is not None:
            if not isinstance(num_iter, int):
                raise TypeError(
                    f"if provided, num_iter should be an integer, but give {num_iter}"
                )
            if num_iter <= 0:
                raise ValueError(
                    f"if provided, num_iter should be positive, but give {num_iter}"
                )

        # store to_save
        with tempfile.TemporaryDirectory() as tmpdirname:
            obj = {k: o.state_dict() for k, o in to_save.items()}
            # add trainer
            obj["trainer"] = trainer.state_dict()
            cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
            torch.save(obj, cache_filepath.as_posix())

            optimizer = to_save["optimizer"]
            # Attach handlers
            if not trainer.has_event_handler(self._run):
                trainer.add_event_handler(
                    Events.STARTED,
                    self._run,
                    optimizer,
                    output_transform,
                    num_iter,
                    end_lr,
                    step_mode,
                    smooth_f,
                    diverge_th,
                )
            if not trainer.has_event_handler(self._warning):
                trainer.add_event_handler(Events.COMPLETED, self._warning)
            if not trainer.has_event_handler(self._reset):
                trainer.add_event_handler(Events.COMPLETED, self._reset)

            yield trainer
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
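# Compared with the earlier copy, this version adds apply_suggested_lr and lets plot
# draw into an existing matplotlib axes. A short sketch of that flow, assuming a
# trainer, model, optimizer and dataloader set up as in the sketch at the end of Example #8:
lr_finder = FastaiLRFinder()
with lr_finder.attach(trainer, to_save={"model": model, "optimizer": optimizer}) as finder_trainer:
    finder_trainer.run(dataloader, max_epochs=3)

# push the suggested lr(s) back into the optimizer before the real training run
lr_finder.apply_suggested_lr(optimizer)

# plot into an existing axes (requires matplotlib)
# import matplotlib.pyplot as plt
# fig, ax = plt.subplots(figsize=(6, 4))
# lr_finder.plot(ax=ax, skip_end=0)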
Example #11
def train():
    args = get_args()
    '''Setup'''
    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path, exist_ok=True)
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))
    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)
    '''Training function and trainer'''
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat([
            torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
            tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids
        ],
                                   dim=1)
        video_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[video_mask, input_mask],
                           mode="video")[0]
        reply_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[reply_mask, input_mask],
                           mode="reply")[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    '''Evaluation function and evaluator (evaluator output is the input of the metrics)'''

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat([
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda() *
                tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids
            ],
                                       dim=1)
            model_outputs = model(input_embs,
                                  token_type_ids=token_type_ids,
                                  attention_mask=[reply_mask, input_mask])[0]

            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    '''Engines'''
    trainer = Engine(update)
    evaluator = Engine(inference)
    '''
    Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(args.log_path,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})
        # "getattr" take care of distributed encapsulation

        torch.save(args, args.log_path + 'model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)
    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
    '''
    On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    '''
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        # TODO: PR in ignite to have better access to saved file paths (cleaner)
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(args.log_path, WEIGHTS_NAME))
        tb_logger.close()