Example 1
def main():
    args = get_args()
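    # e-SNLI-VE pairs each hypothesis with an image instead of a textual premise,
    # so when images are used the textual premise is dropped.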
    if 'e-SNLI-VE' in args.data_path:
        args.no_image = False
    else:
        args.no_image = True
    if not args.no_image:
        args.no_premise = True
    args.with_expl = True

    '''Setup'''
    t = datetime.today()
    output_dir = os.path.join(args.output_folder,
                              f"{t.month}_{t.day}_{t.hour}_{t.minute}_{t.second}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # Logging level is INFO for the main process and WARN for auxiliary processes:
    # logger.info logs only on the main process, logger.warning logs on all processes.
    logging.basicConfig(filename=os.path.join(output_dir, 'app.log'),
                        filemode='a',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning(f"Running process {args.local_rank}")
    logger.info(f"Arguments: {pformat(args)}")
    logger.info(f'Image not used: {args.no_image}')
    logger.info(f'Premise not used: {args.no_premise}')
    logger.info(f'Explanations used: {args.with_expl}')

    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
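        # Each process is typically launched via torch.distributed.launch, which supplies --local_rank;
        # the env:// init method reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE from the environment.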
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    if args.no_image:
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        import image_gpt2_291
        model = image_gpt2_291.GPT2LMHeadModel.from_pretrained(
            args.model_checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
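    # apex amp patches the model and optimizer in place, so it must run before the model
    # is wrapped in DistributedDataParallel (hence the order noted above).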
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer)

    '''Training function and trainer'''
    def train(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        if args.no_image:
            input_ids, lm_label, label, input_mask = batch
        else:
            image, input_ids, lm_label, label, input_mask = batch

        if args.no_image:
            output = model(input_ids=input_ids,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        else:
            output = model(input_ids=input_ids,
                           images=image,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        loss, logits, _ = output

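        # Scale the loss so that gradients accumulated over several mini-batches
        # are equivalent to one update over the full effective batch.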
        loss = loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        if not args.with_expl:
            lbl_accuracy = torch.eq(label, logits.argmax(
                dim=1)).float().sum() / len(label)
            return {
                'loss': loss.item(),
                'lbl_accuracy': lbl_accuracy.item()
            }
        else:
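            # With explanations there is no classification head to score, so every
            # 500 optimizer steps decode one random training example to monitor generation quality.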
            if engine.state.iteration % (args.gradient_accumulation_steps * 500) == 0:
                input_output = list(zip(input_ids, logits))
                random_item = random.choice(input_output)
                in_sent = tokenizer.decode(list(filter(
                    lambda x: x != tokenizer.eos_token_id,
                    random_item[0])))
                out_expl = tokenizer.decode(random_item[1].argmax(dim=1),
                                            skip_special_tokens=True)
                logger.info(f'MODEL INPUT: {in_sent}')
                logger.info(f'GEN. EXPL {out_expl}')
                logger.info('--------------------------------')
            return {
                'loss': loss.item(),
            }

    '''Validation function and validator (validator output is the input of the metrics)'''
    def validation(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device)
                          for input_tensor in batch)
            if args.no_image:
                input_ids, lm_label, label, input_mask = batch
            else:
                image, input_ids, lm_label, label, input_mask = batch

            if args.no_image:
                output = model(input_ids=input_ids,
                               #    attention_mask=input_mask
                               )
            else:
                output = model(input_ids=input_ids,
                               images=image,
                               #    attention_mask=input_mask
                               )
            logits, _ = output

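            # Shift so that the logits at position i are scored against the token at
            # position i+1 (standard next-token language-model evaluation).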
            logits_shifted = logits[..., :-1, :].contiguous().view(-1,
                                                                   logits.size(-1))
            labels_shifted = lm_label[..., 1:].contiguous().view(-1)
            return logits_shifted, labels_shifted

    '''Engines'''
    trainer = Engine(train)
    validator = Engine(validation)

    # t_total = len(
    #     train_loader) // args.gradient_accumulation_steps * args.n_epochs
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    '''Linearly decrease the learning rate from lr to zero'''
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
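    # PiecewiseLinear interpolates the learning rate between (iteration, value) milestones:
    # here from args.lr at step 0 down to 0.0 at the last training iteration.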
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    '''
    Attach validation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: validator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: validator.run(val_loader))

    '''Prepare metrics - note how we compute distributed metrics'''
    RunningAverage(output_transform=lambda x: x['loss']).attach(
        trainer, "loss")
    RunningAverage(output_transform=lambda x: math.exp(
        average_distributed_scalar(x['loss'], args))).attach(trainer, "ppl")
    if not args.with_expl:
        RunningAverage(output_transform=lambda x: 100 * x['lbl_accuracy']).attach(
            trainer, "lbl_accuracy")

    metrics = {}
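    # MetricsLambda composes metrics lazily: the per-process cross-entropy is divided by the
    # accumulation steps, averaged across workers, then exponentiated into a perplexity.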
    metrics["lbl_loss"] = Loss(torch.nn.CrossEntropyLoss(),
                               output_transform=lambda x: (x[0], x[1]))
    metrics["loss"] = MetricsLambda(
        lambda l, a: average_distributed_scalar(
            l / a.gradient_accumulation_steps, a), metrics["lbl_loss"], args)
    metrics["ppl"] = MetricsLambda(math.exp, metrics["loss"])
    if not args.with_expl:
        metrics["lbl_accuracy"] = 100 * \
            Accuracy(output_transform=lambda x: (x[0], x[1]))
    for name, metric in metrics.items():
        metric.attach(validator, name)

    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer,
                    metric_names=["loss", 'ppl'] if args.with_expl else ["loss", 'lbl_accuracy', 'ppl'])
        validator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message(
                                        "Validation: %s" % pformat(validator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=output_dir)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["ppl"] if args.with_expl else ["lbl_accuracy", "ppl"]),
                         event_name=Events.EPOCH_COMPLETED)

        tb_logger.attach(validator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=[
                                 'ppl', 'loss'] if args.with_expl else ['ppl', 'loss', 'lbl_accuracy'],
                             global_step_transform=lambda *args, **kwargs: trainer.state.iteration),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(output_dir,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})

        # "getattr" take care of distributed encapsulation
        torch.save(args, os.path.join(output_dir, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(output_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(output_dir)

    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
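
The script relies on a get_args() helper that is not shown. The sketch below only reconstructs the assumed interface: the flag names come from the attributes accessed in main(), while defaults and help strings are illustrative.

import argparse

def get_args():
    # Hypothetical reconstruction of the CLI assumed by Example 1 (not the original parser).
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, default="data/e-SNLI-VE/")
    parser.add_argument("--output_folder", type=str, default="runs/")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2")
    parser.add_argument("--lr", type=float, default=6.25e-5)
    parser.add_argument("--n_epochs", type=int, default=3)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--max_norm", type=float, default=1.0)
    parser.add_argument("--fp16", type=str, default="",
                        help="apex amp opt level (O0-O3); empty string disables fp16")
    parser.add_argument("--eval_before_start", action="store_true")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="set automatically by torch.distributed.launch")
    return parser.parse_args()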
Example 2
def train():
    args = get_args()
    '''Setup'''
    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path, exist_ok=True)
    # Logging level is INFO for the main process and WARN for auxiliary processes:
    # logger.info logs only on the main process, logger.warning logs on all processes.
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))
    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)
    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module
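        # Unwrap to the underlying module so that attributes such as model.transformer.wte
        # and model.video_ff remain directly accessible below.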

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)
    '''Training function and trainer'''
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
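        # i3d holds precomputed video features; video_ff projects them into the embedding
        # space so they can be prepended to the token embeddings along the sequence dimension.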
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat([
            torch.ones((i3d.size(0), i3d.size(1))).long().cuda()
            * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
            token_type_ids
        ], dim=1)
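        # Two forward passes over the same inputs: one with mode="video" and the video attention
        # mask, one with mode="reply" and the reply mask; their losses are summed below.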
        video_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[video_mask, input_mask],
                           mode="video")[0]
        reply_loss = model(input_embs,
                           token_type_ids=token_type_ids,
                           labels=(labels, i3d),
                           attention_mask=[reply_mask, input_mask],
                           mode="reply")[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    '''Evaluation function and evaluator (evaluator output is the input of the metrics)'''

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat([
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda()
                * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids
            ], dim=1)
            model_outputs = model(input_embs,
                                  token_type_ids=token_type_ids,
                                  attention_mask=[reply_mask, input_mask])[0]

            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    '''Engines'''
    trainer = Engine(update)
    evaluator = Engine(inference)
    '''
    Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
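    # CrossEntropyLoss with ignore_index=-1 skips positions labelled -1
    # (typically padding and context tokens) when computing the NLL.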
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(args.log_path,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})
        # "getattr" take care of distributed encapsulation

        torch.save(args, os.path.join(args.log_path, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)
    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
    '''
    On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    '''
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        # TODO: PR in ignite to have better access to saved file paths (cleaner)
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(args.log_path, WEIGHTS_NAME))
        tb_logger.close()
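
Both examples call an average_distributed_scalar helper that is not shown. A minimal sketch of the assumed helper, modelled on the one in HuggingFace's transfer-learning-conv-ai training script, is below; it averages a Python scalar across processes during distributed evaluation.

import torch

def average_distributed_scalar(scalar, args):
    # Assumed helper (sketch): average a scalar over all workers when training is distributed.
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()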