Example #1
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="t5-small",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history",
                        type=int,
                        default=7,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=10,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=10,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=12,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6e-4, help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--save_name", type=str, default="")
    parser.add_argument("--mask_ratio", type=float, default=0.15)
    parser.add_argument("--objective",
                        type=str,
                        default="span_denosing",
                        help="response_generation, span_denosing, both")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer = T5Tokenizer.from_pretrained(args.model_checkpoint)
    model = T5ForConditionalGeneration.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    def collate_fn(data):
        batch = {
            "corrupted_context": [],
            "context": [],
            "target": [],
            "response": []
        }
        padded_dataset = {}
        batch_size = len(data)
        resp_sos, context_sos = tokenizer.convert_tokens_to_ids([
            "<go_r>",
            "<go_b>",
        ])
        for x in data:
            corrupted_context = ["fill : "]
            target = []
            length = len(x["context_words"])
            mask_bool = random_spans_noise_mask(length=length,
                                                noise_density=args.mask_ratio,
                                                mean_noise_span_length=3.0)
            mask_id = 0
            #print(mask_bool)
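            # Replace each contiguous masked span with a single <extra_id_k>
            # sentinel in the corrupted context; the target lists the sentinel
            # followed by the original span tokens (T5 span-denoising format).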
            for i in range(length):
                if mask_bool[i]:
                    if i > 0 and mask_bool[i - 1]:
                        target.append(x["context_words"][i])
                    else:
                        target.append(f"<extra_id_{mask_id}>")
                        target.append(x["context_words"][i])
                        corrupted_context.append(f"<extra_id_{mask_id}>")
                        mask_id += 1
                else:
                    corrupted_context.append(x["context_words"][i])
            target.append("<eos_b>")
            batch["context"].append(
                tokenizer.encode("response : " + " ".join(x["context_words"])))
            batch["corrupted_context"].append(
                tokenizer.encode(" ".join(corrupted_context)))
            batch["target"].append(tokenizer.encode(" ".join(target)))
            batch["response"].append(tokenizer.encode(x["response"]))
            # print(" ".join(x["context_words"]))
            # print(" ".join(corrupted_context))
            # print(" ".join(target))
            # print("")

            # print(tokenizer.decode(batch["corrupted_context"][-1]))
            # print(tokenizer.decode(batch["target"][-1]))
            # print(tokenizer.decode(batch["response"][-1]))
            # print("")
        context_ids, context_masks = padInput(batch["context"])
        input_ids, masks = padInput(batch["corrupted_context"])
        target_ids, target_inputs = padOutput(batch["target"])
        response_ids, response_inputs = padOutput(batch["response"])
        #inputs
        padded_dataset["input_ids"] = torch.tensor(input_ids, dtype=torch.long)
        padded_dataset["masks"] = torch.tensor(masks, dtype=torch.long)
        padded_dataset["context_ids"] = torch.tensor(context_ids,
                                                     dtype=torch.long)
        padded_dataset["context_masks"] = torch.tensor(context_masks,
                                                       dtype=torch.long)
        padded_dataset["target_ids"] = torch.tensor(target_ids,
                                                    dtype=torch.long)
        padded_dataset["response_ids"] = torch.tensor(response_ids,
                                                      dtype=torch.long)
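        # Build decoder inputs by prepending the start token and dropping the
        # last position (teacher forcing: targets shifted right by one).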
        padded_dataset["target_inputs"] = torch.tensor(np.concatenate((np.ones(
            (batch_size, 1)) * context_sos, target_inputs[:, :-1]),
                                                                      axis=1),
                                                       dtype=torch.long)
        padded_dataset["response_inputs"] = torch.tensor(np.concatenate(
            (np.ones((batch_size, 1)) * resp_sos, response_inputs[:, :-1]),
            axis=1),
                                                         dtype=torch.long)

        return padded_dataset

    logger.info("Prepare datasets")
    train_dataset, valid_dataset, train_sampler, valid_sampler = get_data(
        args, tokenizer)

    train_loader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              shuffle=(not args.distributed),
                              collate_fn=collate_fn,
                              num_workers=4)
    val_loader = DataLoader(valid_dataset,
                            sampler=valid_sampler,
                            batch_size=args.valid_batch_size,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=4)

    logger.info("Train dataset length: {}".format(len(train_dataset)))
    logger.info("Valid dataset length: {}".format(len(valid_dataset)))

    # for batch in train_loader:
    #     #print(batch)
    #     exit(0)
    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(batch[input_name].to(args.device)
                      for input_name in MODEL_INPUTS)
        input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch
        # print("input")
        # print(tokenizer.decode(input_ids[0, :].tolist()))
        # print("context_ids")
        # print(tokenizer.decode(context_ids[0, :].tolist()))
        # print("target")
        # print(tokenizer.decode(target_ids[0, :].tolist()))
        # print("target In")
        # print(tokenizer.decode(target_inputs[0, :].tolist()))
        # print("response_ids")
        # print(tokenizer.decode(response_ids[0, :].tolist()))
        # print("response_inputs")
        # print(tokenizer.decode(response_inputs[0, :].tolist()))
        #exit(0)
        outputs = model(input_ids,
                        attention_mask=masks,
                        decoder_input_ids=target_inputs,
                        lm_labels=target_ids)
        context_loss = outputs[0]

        outputs = model(context_ids,
                        attention_mask=context_masks,
                        decoder_input_ids=response_inputs,
                        lm_labels=response_ids)

        resp_loss = outputs[0]

        if args.objective == "response_generation":
            loss = resp_loss / args.gradient_accumulation_steps
        elif args.objective == "both":
            loss = (context_loss + resp_loss) / args.gradient_accumulation_steps
        else:  # "span_denoising" (default)
            loss = context_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(batch[input_name].to(args.device)
                          for input_name in MODEL_INPUTS)
            input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch

            outputs = model(
                input_ids,
                attention_mask=masks,
                decoder_input_ids=target_inputs  #, lm_labels=target_ids
            )

            context_logits = outputs[0]
            outputs = model(
                context_ids,
                attention_mask=context_masks,
                decoder_input_ids=response_inputs,
                #lm_labels=response_ids
            )
            resp_logits = outputs[0]

            context_logits_flat_shifted = context_logits.view(
                -1, context_logits.size(-1))
            context_labels_flat_shifted = target_ids.view(-1)

            resp_logits_flat_shifted = resp_logits.view(
                -1, resp_logits.size(-1))
            resp_labels_flat_shifted = response_ids.view(-1)

            return (context_logits_flat_shifted,
                    resp_logits_flat_shifted), (context_labels_flat_shifted,
                                                resp_labels_flat_shifted)
            #return (context_logits_flat_shifted, context_logits_flat_shifted), (context_labels_flat_shifted, context_labels_flat_shifted)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    # if args.eval_before_start:
    #     trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "span":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "response":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_span":
        MetricsLambda(average_distributed_scalar, metrics["span"], args),
        "average_response":
        MetricsLambda(average_distributed_scalar, metrics["response"], args)
    })
    metrics["average_response"] = MetricsLambda(math.exp,
                                                metrics["average_response"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        if not os.path.exists(f"pretrained_model/{args.save_name}"):
            os.makedirs(f"pretrained_model/{args.save_name}")
        log_dir = f"pretrained_model/{args.save_name}"
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
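
The snippet above relies on helpers that are not shown here (random_spans_noise_mask, padInput, padOutput). The following is a minimal, illustrative sketch of what they might look like; the exact sampling procedure and padding conventions are assumptions, not the original implementation.

import numpy as np


def random_spans_noise_mask(length, noise_density=0.15, mean_noise_span_length=3.0):
    # Simplified T5-style span-corruption mask: True marks tokens to corrupt.
    # This approximates, but does not reproduce, the original T5 sampler.
    num_noise_tokens = max(1, int(round(length * noise_density)))
    num_spans = max(1, int(round(num_noise_tokens / mean_noise_span_length)))
    mask = np.zeros(length, dtype=bool)
    for _ in range(num_spans):
        span_length = max(1, int(round(num_noise_tokens / num_spans)))
        start = np.random.randint(0, max(1, length - span_length + 1))
        mask[start:start + span_length] = True
        if mask.sum() >= num_noise_tokens:
            break
    return mask


def padInput(sequences, pad_id=0):
    # Pad encoder inputs to the batch maximum; return ids and attention masks.
    max_len = max(len(s) for s in sequences)
    ids = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    masks = np.zeros((len(sequences), max_len), dtype=np.int64)
    for i, s in enumerate(sequences):
        ids[i, :len(s)] = s
        masks[i, :len(s)] = 1
    return ids, masks


def padOutput(sequences, pad_id=0, label_pad_id=-100):
    # Pad decoder targets: labels use -100 so CrossEntropyLoss ignores padding;
    # the second array is what the collate_fn shifts right into decoder inputs.
    max_len = max(len(s) for s in sequences)
    labels = np.full((len(sequences), max_len), label_pad_id, dtype=np.int64)
    inputs = np.full((len(sequences), max_len), pad_id, dtype=np.int64)
    for i, s in enumerate(sequences):
        labels[i, :len(s)] = s
        inputs[i, :len(s)] = s
    return labels, inputs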
Example #2
                                {'_': mude})

    trainer.add_event_handler(Events.ITERATION_COMPLETED, nan_handler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, coslr)

    GpuInfo().attach(trainer, name='gpu')
    pbar.attach(trainer,
                output_transform=lambda output: {'loss': output['loss']},
                metric_names=[f"gpu:{args.gpu} mem(%)"])

    # FIRE
    tb_logger = TensorboardLogger(log_dir=TENSORBOARD_RUN_LOG_DIR_PATH)
    tb_logger.attach(
        trainer,
        log_handler=OutputHandler(
            tag='training',
            output_transform=lambda output: {'loss': output['loss']}),
        event_name=Events.ITERATION_COMPLETED(
            every=LOG_TRAINING_PROGRESS_EVERY_N))
    tb_logger.attach(
        evaluator,
        log_handler=OutputHandler(
            tag='validation',
            metric_names='all',
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(opt),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(mude),
Example #3
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    if sys.version_info > (3, ):
        from ignite.contrib.metrics.gpu_info import GpuInfo

        try:
            GpuInfo().attach(trainer)
        except RuntimeError:
            print(
                "INFO: By default, in this example it is possible to log GPU information (used memory, utilization). "
                "As there is no pynvml python package installed, GPU information won't be logged. Otherwise, please "
                "install it : `pip install pynvml`")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
        metric_names="all",
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        tb_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    tb_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    tb_logger.attach(trainer,
                     log_handler=WeightsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=WeightsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsScalarHandler(model),
                     event_name=Events.ITERATION_COMPLETED(every=100))

    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint,
                                           {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    tb_logger.close()
Example #4
def train(args):
    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer, _, vocab = get_kogpt2_tokenizer()
    model = get_kogpt2_model()
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer, vocab)

    def update(engine, batch):
        model.train()

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, labels, token_type_ids = batch

        loss, *_ = model(input_ids,
                         token_type_ids=token_type_ids,
                         labels=labels)
        loss = loss / args.gradient_accumulation_steps

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item()

    trainer = Engine(update)

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, labels, token_type_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            logits, *_ = model(input_ids, token_type_ids=token_type_ids)
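            # Shift so that tokens < n predict token n: drop the last logit
            # and the first label before flattening for the metrics below.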
            logits_flat_shifted = logits[..., :-1, :].contiguous().view(
                -1, logits.size(-1))
            labels_flat_shifted = labels[..., 1:].contiguous().view(-1)
            return (logits_flat_shifted), (labels_flat_shifted)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0], x[1])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0], x[1]))
    }
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=["loss"])
    evaluator.add_event_handler(
        Events.COMPLETED, lambda _: pbar.log_message(
            "Validation: %s" % pformat(evaluator.state.metrics)))

    log_dir = make_logdir("kogpt2_personachat")
    tb_logger = TensorboardLogger(log_dir)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=["loss"]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(
        evaluator,
        log_handler=OutputHandler(
            tag="validation",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.EPOCH_COMPLETED)

    checkpoint_handler = ModelCheckpoint(log_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=3)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpoint_handler,
        {'mymodel': getattr(model, 'module', model)
         })  # "getattr" takes care of distributed encapsulation

    torch.save(args, log_dir + '/model_training_args.bin')
    getattr(model, 'module',
            model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
    # tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    # TODO: PR in ignite to have better access to saved file paths (cleaner)
    os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
              os.path.join(log_dir, WEIGHTS_NAME))
    tb_logger.close()
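
Several of the examples above use the same ignite pattern: an update function with gradient accumulation and clipping wrapped in an Engine. Below is a minimal, self-contained sketch of that pattern with a toy model and random data; everything here is illustrative and not taken from any of the examples.

import torch
from torch.utils.data import DataLoader, TensorDataset
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage

accumulation_steps = 4
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
loader = DataLoader(TensorDataset(torch.randn(64, 10),
                                  torch.randint(0, 2, (64,))),
                    batch_size=8)


def update(engine, batch):
    model.train()
    x, y = batch
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    # Step the optimizer only every `accumulation_steps` iterations.
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()


trainer = Engine(update)
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")


@trainer.on(Events.EPOCH_COMPLETED)
def log_running_loss(engine):
    print("epoch %d: running loss %.4f" %
          (engine.state.epoch, engine.state.metrics["loss"]))


trainer.run(loader, max_epochs=2)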
Example #5
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path", type=str, default='data/spolin-train-acl.json', help="Set data path")    
    parser.add_argument("--valid_path", type=str, default='data/spolin-valid.json', help="Set data path")     

    parser.add_argument("--correct_bias", type=bool, default=False, help="Set to true to correct bias for Adam optimizer")
    parser.add_argument("--lr", type=float, default=2e-5, help="Set learning rate")
    parser.add_argument("--n_epochs", type=int, default=4, help="Set number of epochs")
    parser.add_argument("--num_warmup_steps", type=float, default=1000, help="Set number of warm-up steps")
    parser.add_argument("--num_total_steps", type=float, default=10000, help="Set number of total steps")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Set maximum gradient normalization.")
    parser.add_argument("--pretrained_path", type=str, default='bert-base-uncased', help="Choose which pretrained model to use (bert-base-uncased, roberta-base, roberta-large, roberta-large-mnli)")    
    parser.add_argument("--batch_size", type=int, default=32, help="Provide the batch size")    
    parser.add_argument("--random_seed", type=int, default=42, help="Set the random seed")
    parser.add_argument("--test", action='store_true', help="If true, run with small dataset for testing code")
    parser.add_argument("--base", action='store_true', help="If true, run with base experiment configuration (training with spont only) for comparison")

    args = parser.parse_args() 

    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: {}".format(pformat(args)))

    if 'roberta' in args.pretrained_path: 
        # initialize tokenizer and model 
        logger.info("Initialize model and tokenizer.")
        tokenizer = RobertaTokenizer.from_pretrained(args.pretrained_path, cache_dir = '../pretrained_models')
        model = RobertaForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

        ### START MODEL MODIFICATION
        # Pretrained model was not trained with token type ids. 
        # fix token type embeddings for finetuning. Without this, the model can only take 0s as valid input for token_type_ids 
        model.config.type_vocab_size = 2 
        model.roberta.embeddings.token_type_embeddings = torch.nn.Embedding(2, model.config.hidden_size)
        model.roberta.embeddings.token_type_embeddings.weight.data.normal_(mean=0.0, std=model.config.initializer_range)

        ### END MOD
    elif 'bert' in args.pretrained_path: 
        model = BertForSequenceClassification.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')
        tokenizer = BertTokenizer.from_pretrained(args.pretrained_path, cache_dir='../pretrained_models')

    model.to(args.device)

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']

    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01}, 
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters,
                        lr=args.lr,
                        correct_bias = args.correct_bias)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.num_warmup_steps, t_total=args.num_total_steps) 

    logger.info("Prepare datasets")
    logger.info("Loading train set...")

    train_data = get_data(args.train_path)
    valid_data = get_data(args.valid_path)

    cornell_valid_data = {k: {'cornell': valid_data[k]['cornell']} for k in valid_data.keys()}
    spont_valid_data = {k: {'spont': valid_data[k]['spont']} for k in valid_data.keys()}

    train_loader, train_sampler = get_data_loaders(args, train_data, args.train_path, tokenizer)
    logger.info("Loading validation set...")
    valid_p = Path(args.valid_path)
    cornell_valid_loader, cornell_valid_sampler = get_data_loaders(args, cornell_valid_data, f"{str(valid_p.parent)}/cornell_{valid_p.name}",  tokenizer)
    spont_valid_loader, spont_valid_sampler = get_data_loaders(args, spont_valid_data, f"{str(valid_p.parent)}/spont_{valid_p.name}", tokenizer)


    # Training function and trainer 
    def update(engine, batch): 
        model.train() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch

        optimizer.zero_grad()
        #roberta has issues with token_type_ids 
        loss, logits = model(b_input_ids, token_type_ids=b_input_segment, attention_mask=b_input_mask, labels=b_labels)
        # loss, logits = model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask, labels=b_labels)


        loss.backward() 
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        
        optimizer.step() 
        scheduler.step() 

        return loss.item(), logits, b_labels

    trainer = Engine(update)     

    # Evaluation function and evaluator 
    def inference(engine, batch): 
        model.eval() 

        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        b_input_ids, b_input_mask, b_input_segment, b_labels = batch
        
        with torch.no_grad(): 
            #roberta has issues with token_type_ids 
            # loss, logits = model(b_input_ids, token_type_ids = None, attention_mask=b_input_mask, labels=b_labels)
            loss, logits = model(b_input_ids, token_type_ids = b_input_segment, attention_mask=b_input_mask, labels=b_labels)
            label_ids = b_labels

        return logits, label_ids, loss.item()
    cornell_evaluator = Engine(inference)
    spont_evaluator = Engine(inference)


    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: cornell_evaluator.run(cornell_valid_loader))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: spont_evaluator.run(spont_valid_loader))


    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss") 
    RunningAverage(Accuracy(output_transform=lambda x: (x[1], x[2]))).attach(trainer, "accuracy")
    if torch.cuda.is_available(): 
        GpuInfo().attach(trainer, name='gpu')

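    # ignite metric objects support arithmetic, so F1 can be composed directly
    # from the precision and recall metrics below.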
    recall = Recall(output_transform=lambda x: (x[0], x[1]))
    precision = Precision(output_transform=lambda x: (x[0], x[1]))
    F1 = (precision * recall * 2 / (precision + recall)).mean()
    accuracy = Accuracy(output_transform=lambda x: (x[0], x[1]))
    metrics = {"recall": recall, "precision": precision, "f1": F1, "accuracy": accuracy, "loss": Average(output_transform=lambda x: x[2])}

    for name, metric in metrics.items(): 
        metric.attach(cornell_evaluator, name) 
        metric.attach(spont_evaluator, name) 


    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss', 'accuracy'])
    pbar.attach(trainer, metric_names=['gpu:0 mem(%)', 'gpu:0 util(%)'])
    
    cornell_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Cornell validation metrics:\n %s" % pformat(cornell_evaluator.state.metrics)))
    spont_evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Spont validation metrics:\n %s" % pformat(spont_evaluator.state.metrics)))


    tb_logger = TensorboardLogger(log_dir=None)
    tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
    tb_logger.attach(cornell_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(spont_evaluator, log_handler=OutputHandler(tag="valid", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)


    # tb_logger.writer.log_dir -> tb_logger.writer.logdir (this is the correct attribute name as seen in: https://tensorboardx.readthedocs.io/en/latest/_modules/tensorboardX/writer.html#SummaryWriter)
    checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=5)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

    torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
    getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
    tokenizer.save_vocabulary(tb_logger.writer.logdir)

    trainer.run(train_loader, max_epochs = args.n_epochs)

    if args.n_epochs > 0: 
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #6
def train():
    parser = ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    device = torch.device("cuda" if torch.cuda.device_count() > 1 else "cpu")
    model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    DISTRIBUTED = args.local_rank != -1

    if DISTRIBUTED and torch.distributed.is_available():
        print("Distributed")
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        #BATCH_SIZE *= 2

    def average_distributed_scalar(scalar):
        if (not DISTRIBUTED):
            return scalar
        scalar_t = torch.tensor(
            scalar, dtype=torch.float,
            device=device) / torch.distributed.get_world_size()
        torch.distributed.all_reduce(scalar_t,
                                     op=torch.distributed.ReduceOp.SUM)
        return scalar_t.item()

    optimizer = AdamW(model.parameters(), lr=6.25e-5)

    ds = dataloader.Conv_GPT2_DataClass(tokenizer)
    v_ds = dataloader.Conv_GPT2_DataClass(tokenizer, dev=True)
    orig_added_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(
        dataloader.ATTR_SPECIAL_TOKENS)
    if (num_added_tokens > 0):
        model.resize_token_embeddings(new_num_tokens=orig_added_tokens +
                                      num_added_tokens)
    model = model.to(device)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        ds) if DISTRIBUTED else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        v_ds) if DISTRIBUTED else None

    dl = DataLoader(ds,
                    sampler=train_sampler,
                    batch_size=BATCH_SIZE,
                    shuffle=not DISTRIBUTED)
    v_dl = DataLoader(v_ds, sampler=valid_sampler, shuffle=False)

    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"]),
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])

    def update(engine, batch):
        model.train()
        batch = tuple(t.to(device) for t in batch)
        lm_loss, *_ = model(batch[0],
                            token_type_ids=batch[1],
                            lm_labels=batch[2])
        loss = lm_loss / ITERATION_STEP
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if engine.state.iteration % ITERATION_STEP == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(t.to(device) for t in batch)
            input_ids, token_type_ids, lm_labels = batch
            model_outputs = model(input_ids, token_type_ids=token_type_ids)
            lm_logits = model_outputs[0]
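            # Standard causal-LM shift: the logit at position i predicts the
            # label at position i + 1, so drop the last logit and first label.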
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    trainer = Engine(update)
    evaluator = Engine(inference)

    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, 6.25e-5),
                                 (EPOCHS * len(ds) // BATCH_SIZE, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(v_dl))

    if DISTRIBUTED:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        #evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    if (args.local_rank in [0, -1]):
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        #evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir='./logs')
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        #tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint('./checkpoint',
                                             '_checkpoint',
                                             n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'gpt2_qg': getattr(model, 'module', model)})

        getattr(model, 'module', model).config.to_json_file(
            os.path.join('./checkpoint', 'config'))
        tokenizer.save_pretrained('./checkpoint')

    trainer.run(dl, max_epochs=EPOCHS)

    if (args.local_rank in [0, -1]):
        tb_logger.close()
Example #7
def trainer(
    train_batch,
    evaluate_batch,
    evaluate_data_loaders,
    metrics,
    optimizers,
):
    '''
    Create standard trainer with evaluators.

    Parameters
    ----------
    train_batch : function
        function that trains on given batch
    evaluate_batch : function
        function that evaluates a given batch
    evaluate_data_loaders: list
        data loaders that yield batches to evaluate on
    metrics : dict
        dict with one dict each for 'train' and evaluate data loader. Wrap a
        metric with trainer.Progress to show in progress bar.
    optimizers : dict
        dict with optimizers for logging

    Returns
    -------
    tuple
        trainer engine
        list of evaluator engines
        tensorboard logger
    '''

    trainer = ignite.engine.Engine(train_batch)

    for name, metric in metrics.get(PROGRESS_DESC, dict()).items():
        metric.attach(trainer, name)

    for name, metric in metrics.get(TRAIN_DESC, dict()).items():
        metric.attach(trainer, name)

    evaluators = {
        evaluator_name: ignite.engine.Engine(evaluate_batch)
        for evaluator_name in evaluate_data_loaders.keys()
    }

    for evaluator_name, evaluator in evaluators.items():
        for metric_name, metric in metrics[evaluator_name].items():
            metric.attach(evaluator, metric_name)

    tensorboard_logger = TensorboardLogger(log_dir='tb')

    EpochLogger().attach(trainer)

    # Order of attaching progress bars is important for vscode / atom
    ProgressBar(desc=TRAIN_DESC).attach(trainer,
                                        metric_names=list(
                                            metrics.get(PROGRESS_DESC,
                                                        dict()).keys()))
    tensorboard_logger.attach(
        trainer,
        OutputHandler(
            tag=PROGRESS_DESC,
            metric_names=list(metrics.get(PROGRESS_DESC, dict()).keys()),
        ),
        Events.ITERATION_COMPLETED,
    )

    MetricsLogger(TRAIN_DESC).attach(trainer,
                                     metrics.get(TRAIN_DESC, dict()).keys())
    tensorboard_logger.attach(
        trainer,
        OutputHandler(
            tag=TRAIN_DESC,
            metric_names=list(metrics.get(TRAIN_DESC, dict()).keys()),
        ),
        Events.ITERATION_COMPLETED,
    )

    def run_evaluator(evaluator_desc):
        return lambda engine: evaluators[evaluator_desc].run(
            evaluate_data_loaders[evaluator_desc])

    for evaluator_desc, evaluator in evaluators.items():
        evaluator_metric_names = list(metrics[evaluator_desc].keys())

        trainer.add_event_handler(
            Events.EPOCH_COMPLETED,
            run_evaluator(evaluator_desc),
        )

        ProgressBar(desc=evaluator_desc).attach(evaluator)
        MetricsLogger(evaluator_desc).attach(evaluator, evaluator_metric_names)
        tensorboard_logger.attach(
            evaluator,
            OutputHandler(
                tag=evaluator_desc,
                metric_names=evaluator_metric_names,
                global_step_transform=global_step_from_engine(trainer),
            ),
            Events.EPOCH_COMPLETED,
        )

    if type(optimizers) is not dict:
        optimizers = dict(optimizer=optimizers)

    for name, optimizer in optimizers.items():
        tensorboard_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(
                tag=f'{TRAIN_DESC}/{name}',
                param_name='lr',
                optimizer=optimizer,
            ),
            event_name=Events.ITERATION_COMPLETED,
        )

    return trainer, evaluators, tensorboard_logger
Example #8
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model",
                        type=str,
                        default="",
                        help="Model type, one of: %s" %
                        ', '.join(MODELS.keys()))
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="",
                        help="Path, url or short name of a pretrained model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--adv_coef",
                        type=float,
                        default=1.0,
                        help="Adversarial dataset prediction loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    #parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    #parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=-1,
        help="If set, use this to manually restrict the sequence length. "
        "This might be helpful to save resources (memory). "
        "If not set, this is looked up from the model config (n_ctx value).")
    parser.add_argument(
        "--adversarial_dataset_prediction",
        action='store_true',
        help="Set to train with adversarial dataset prediction")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help='set random seed')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    if args.seed is not None:
        torch.manual_seed(args.seed)

    args.distributed = (args.local_rank != -1)

    logger.info("Prepare tokenizer and data")
    if not args.model:
        logger.warning(
            '"model" parameter is not set! This is deprecated. Please use one of: %s. '
            'To mimic deprecated behaviour, "model_checkpoint" will be used as "model"'
            % ', '.join(MODELS.keys()))
        args.model = args.model_checkpoint
    if args.model not in MODELS:
        raise NotImplementedError(
            'model "%s" not implemented. use one of: %s' %
            (args.model, ', '.join(MODELS.keys())))
    config_class, tokenizer_class, model_class, _ = MODELS[args.model]
    if not args.model_checkpoint:
        args.model_checkpoint = args.model

    model_config = config_class.from_pretrained(args.model_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    additional_special_tokens = [TYPE_BACKGROUND, TYPE_BOT, TYPE_USER]
    # for adversarial training (dataset prediction)
    dataset_labels = None
    if args.adversarial_dataset_prediction:
        dataset_labels = [
            get_dataset_label(dataset_path)
            for dataset_path in args.dataset_path.split(',')
        ]
        #additional_special_tokens.extend(dataset_labels)
        #if model_class not in ADV_MODELS.values():
        assert model_class in ADV_MODELS, f'no adversarial model implemented for model class: {model_class.__name__}'
        model_class = ADV_MODELS[model_class]
        if not hasattr(model_config, 'cls'):
            model_config.cls = {}
        if 'dataset_labels' in model_config.cls:
            assert all([dl in model_config.cls['dataset_labels']['labels'] for dl in dataset_labels]), \
                f'loaded dataset_labels [{model_config.cls["dataset_labels"]["labels"]}] do not contain all ' \
                f'current dataset_labels [{dataset_labels}]'
            dataset_labels = model_config.cls['dataset_labels']['labels']
        else:
            model_config.cls['dataset_labels'] = {
                'labels': dataset_labels,
                'is_adversarial': True
            }
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "dataset_labels", "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "cl_loss_0", "lm_logits", "mc_logits",
            "cl_logits_0", "presents"
        ]
    else:
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "lm_logits", "mc_logits", "presents"
        ]

    tokenizer.add_special_tokens({
        'bos_token':
        TYPE_BOS,
        'eos_token':
        TYPE_EOS,
        'pad_token':
        TYPE_PAD,
        'additional_special_tokens':
        additional_special_tokens
    })

    logger.info("Prepare datasets")
    max_sequence_length = model_config.n_ctx if args.max_sequence_length <= 0 else args.max_sequence_length
    assert max_sequence_length <= model_config.n_ctx, 'max_sequence_length [%i] was set to a value higher than ' \
                                                      'supported by the model (config.n_ctx [%i]). Please use a lower ' \
                                                      'value or do not set it [-1] to use the highest supported one.' \
                                                      % (max_sequence_length, model_config.n_ctx)
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args=args,
        tokenizer=tokenizer,
        model_input_names=model_input_names,
        max_sequence_length=max_sequence_length,
        dataset_labels=dataset_labels)

    logger.info(
        "Prepare pretrained model and optimizer - add special tokens for fine-tuning"
    )

    # Initialize distributed training if needed
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    #model = model_class.from_pretrained(args.model_checkpoint, num_cl_labels=len(dataset_ids))    # for GPT2DoubleHeadsModelwithAdversarial
    model = model_class.from_pretrained(args.model_checkpoint,
                                        config=model_config)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    ####################################################################################################################

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
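    # Following the usual transformer fine-tuning convention, weight decay is applied to all parameters
    # except biases and LayerNorm weights, which is why the second parameter group uses weight_decay=0.0.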
    #optimizer = OpenAIAdam(model.parameters(), lr=args.lr)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
    # scheduler is set below (see ignite)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                            num_training_steps=len(train_loader) // args.train_batch_size + 1)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_checkpoint, 'optimizer.pt')) and os.path.isfile(
                os.path.join(args.model_checkpoint, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        # TODO: this needs to be dumped somewhere
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_checkpoint, 'optimizer.pt')))
        #scheduler.load_state_dict(torch.load(os.path.join(args.model_checkpoint, 'scheduler.pt')))

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = {
            model_input_names[i]: input_tensor.to(args.device)
            for i, input_tensor in enumerate(batch)
        }
        model_output = model(**batch)
        losses = (model_output[:3] if args.adversarial_dataset_prediction
                  else model_output[:2])
        if args.n_gpu > 1:  # mean() to average on multi-gpu.
            losses = list(losses)
            for i in range(len(losses)):
                losses[i] = losses[i].mean()
        lm_loss, mc_loss = losses[0], losses[1]
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps

        # handle adversarial loss
        loss_wo_adv = loss.clone()
        if args.adversarial_dataset_prediction:
            adv_loss = model_output[2]
            loss += (adv_loss *
                     args.adv_coef) / args.gradient_accumulation_steps
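        # Overall, the objective minimized per accumulation step is roughly
        #   (lm_coef * lm_loss + mc_coef * mc_loss [+ adv_coef * adv_loss]) / gradient_accumulation_steps,
        # where the adversarial term is only present when args.adversarial_dataset_prediction is set.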

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            #scheduler.step()  # Update learning rate schedule # already DONE below!
            optimizer.zero_grad()
        return loss_wo_adv.item(), loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            if args.adversarial_dataset_prediction:
                input_ids, mc_token_ids, lm_labels, mc_labels, dataset_labels, token_type_ids = batch
            else:
                input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch

            logger.debug(
                tokenizer.decode(input_ids[0, -1, :].tolist()).replace(
                    TYPE_PAD, ''))
            model_outputs = model(input_ids=input_ids,
                                  mc_token_ids=mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
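            # The logits at position t predict the token at position t+1, so the logits are truncated by
            # one position and the labels are shifted left by one before both are flattened for the
            # cross-entropy/NLL metric below.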
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero (scheduler)
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
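    # With these two milestones the learning rate decays linearly from args.lr at the first iteration to
    # 0.0 after T = args.n_epochs * len(train_loader) iterations, i.e. roughly lr(t) = args.lr * (1 - t / T).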

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    if args.adversarial_dataset_prediction:
        RunningAverage(output_transform=lambda x: x[1]).attach(
            trainer, "loss_w/_adv")
        RunningAverage(output_transform=lambda x: x[1] - x[0]).attach(
            trainer, "loss_only_adv")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        if args.adversarial_dataset_prediction:
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training", metric_names=["loss_w/_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 metric_names=["loss_only_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        logger.info('save checkpoints to: %s' % tb_logger.writer.log_dir)
        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(tb_logger.writer.log_dir)

        #logger.debug("Saving optimizer and scheduler states to %s", tb_logger.writer.log_dir)
        #torch.save(optimizer.state_dict(), os.path.join(tb_logger.writer.log_dir, 'optimizer.pt'))
        #torch.save(scheduler.state_dict(), os.path.join(tb_logger.writer.log_dir, 'scheduler.pt'))

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
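    # Usage note (assumption): because the last checkpoint is renamed to WEIGHTS_NAME and the config and
    # tokenizer are saved in the same log directory, the fine-tuned model can typically be reloaded later
    # with model_class.from_pretrained(<log_dir>) together with the matching tokenizer class.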
Example #9
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO)
    logger.info("Arguments: %s", pformat(args))

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained("gpt2")

    model_class = GPT2DoubleHeadsModel
    model = model_class.from_pretrained("gpt2")
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)  ### TODO add our own special tokens
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)  ### TODO load data ourselves

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    trainer.add_event_handler(Events.STARTED,
                              lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=["loss"])
    evaluator.add_event_handler(
        Events.COMPLETED, lambda _: pbar.log_message(
            "Validation: %s" % pformat(evaluator.state.metrics)))

    log_dir = make_logdir("gpt2")
    tb_logger = TensorboardLogger(log_dir)

    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               metric_names=["loss"]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag="validation",
                                               metric_names=list(
                                                   metrics.keys()),
                                               another_engine=trainer),
                     event_name=Events.EPOCH_COMPLETED)

    checkpoint_handler = ModelCheckpoint(log_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=3)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpoint_handler,
        {'mymodel': getattr(model, 'module', model)
         })  # "getattr" takes care of distributed encapsulation

    torch.save(args, log_dir + '/model_training_args.bin')
    getattr(model, 'module',
            model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
    tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #10
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default=DATA_FOLDER, help="Path of the dataset.")
    parser.add_argument("--image_path", type=str, default=IMG_FOLDER, help="Path of the images.")
    parser.add_argument("--images_feature_path", type=str, default=IMG_FEATURE_FOLDER, help="Path of the images.")
    parser.add_argument("--dataset_cache", type=str, default=DATA_CACHE, help="Path of the dataset cache_no_pretrained")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2", help="Path, url or short name of the model")
    parser.add_argument('--dhead_gpt2', action='store_true', default=False, help="use double head gpt2")
    parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step")
    parser.add_argument('--pretrained', action='store_true', default=True, help="If False train from scratch")
    parser.add_argument("--num_candidates", type=int, default=1, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=3, help="Number of previous turns to keep in history")
    parser.add_argument("--max_length", type=int, default=256, help="Max length of input sentence")
    parser.add_argument("--train_batch_size", type=int, default=58, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=32, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=9, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--scheduler", type=str, default="linear", choices=['noam', 'linear'], help="method of optim")
    parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)")
    parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps")
    parser.add_argument("--lm_coef", type=float, default=2.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument("--num_workers", type=int, default=0, help="Number of subprocesses for data loading")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="O1", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = BertTokenizer
    config_class = GPT2Config  # GPT2Config if "gpt2" in args.model_checkpoint else OpenAIGPTConfig
    model_class = GPT2LMHeadModel  # GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    if args.pretrained:
        tokenizer = tokenizer_class.from_pretrained(MODEL_CHECKPOINT, do_lower_case=False)
        # tokenizer = tokenizer_class(vocab_file=VOCAB_PATH, do_lower_case=True)
        model = model_class.from_pretrained(MODEL_CHECKPOINT)
    else:
        tokenizer = tokenizer_class(vocab_file=VOCAB_PATH, do_lower_case=False)
        tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)
        config = config_class.from_json_file(CONFIG_PATH)
        model = model_class(config)
    model.to(args.device)
    # Add special tokens if they are not already added
    # add_special_tokens_(model, tokenizer)
    # optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
    optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = build_dataloader(args, tokenizer, logger)

    def update(engine, batch):
        model.train()
        batch = tuple(torch.tensor(input_data).to(args.device) if idx not in [2, 3] else input_data for idx, input_data in enumerate(batch))
        input_ids, token_type_ids, input_images, image_ids, lm_labels, mc_token_ids, mc_labels = batch
        if args.dhead_gpt2:
            (lm_loss), (mc_loss), *_ = model(input_ids,
                                             token_type_ids=token_type_ids,
                                             mc_token_ids=mc_token_ids,
                                             mc_labels=mc_labels,
                                             lm_labels=lm_labels)
            loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        else:
            (lm_loss), *_ = model(input_ids,
                                  labels=lm_labels,
                                  token_type_ids=token_type_ids,
                                  input_images=input_images,
                                  image_ids=image_ids)
            loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item() #, optimizer.param_groups[0]['lr']
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            if args.dhead_gpt2:
                lm_logits, mc_logits, *_ = model(
                    input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
                )
                lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
                lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
                return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
            else:
                lm_logits, *_ = model(input_ids, token_type_ids=token_type_ids)
                lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
                lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
                return lm_logits_flat_shifted, lm_labels_flat_shifted
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    # trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: a linear decrease from lr to zero (default) or Noam warmup/decay
    model_size = args.n_emd
    noam_lambda = lambda step: (
            model_size ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5)))
    noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step)
    scheduler = LRScheduler(noam_scheduler)
    if args.scheduler == "linear":
        scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
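    # The Noam schedule multiplies the base learning rate by
    #   n_emd**(-0.5) * min((step + 1)**(-0.5), (step + 1) * warmup_steps**(-1.5)),
    # i.e. it warms up linearly for the first warmup_steps iterations and then decays as 1/sqrt(step).
    # Passing last_epoch=args.from_step lets the schedule resume from a given step, which is why the
    # optimizer's parameter group was created above with an explicit 'initial_lr'.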

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        # tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', n_saved=None)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Example #11
0
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument(
        "--task",
        type=str,
        default="dialogue",
        help="one of task from [dialogue, qa, mt, nlg, summarization]")
    parser.add_argument("--emb_only",
                        action='store_true',
                        help="fine tune only task embeddings")
    parser.add_argument("--linear_perturb",
                        action='store_true',
                        help="fine tune only task embeddings")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=1,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--perturbation_layers",
                        type=int,
                        default=0,
                        help="number of perturbation layers")
    parser.add_argument("--self_copy",
                        action='store_true',
                        help="add self copy ")
    parser.add_argument("--adapter_bottleneck",
                        type=int,
                        default=0,
                        help="adapter layer bottleneck")
    parser.add_argument("--random_init",
                        action='store_true',
                        help="don't use GPT-2 pre-trained model ")
    parser.add_argument("--distillation", action='store_true')
    parser.add_argument("--outputlayer_only", action='store_true')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # can't use AutoTokenizer because checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    if not args.random_init:
        model = model_class.from_pretrained(
            args.model_checkpoint,
            perturbation_layers=args.perturbation_layers,
            self_copy=args.self_copy,
            adapter_bottleneck=args.adapter_bottleneck)
    else:
        config = GPT2Config()
        model = model_class(config,
                            perturbation_layers=args.perturbation_layers,
                            self_copy=args.self_copy,
                            adapter_bottleneck=args.adapter_bottleneck)
    model.to(args.device)

    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)

    if args.adapter_bottleneck > 0:
        parameters_to_update = [
            p for n, p in model.named_parameters() if "adapter" in str(n)
        ] + [model.transformer.wte.weight]
        optimizer = AdamW(parameters_to_update, lr=args.lr, correct_bias=True)
    elif args.outputlayer_only:
        parameters_to_update = [model.transformer.wte.weight]
        optimizer = AdamW(parameters_to_update, lr=args.lr, correct_bias=True)
    else:
        optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)
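    # Depending on the flags, either only the adapter layers plus the token-embedding matrix (which GPT-2
    # ties to the output layer), only that embedding matrix, or all model parameters are handed to the optimizer.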

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update_emb(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            param_to_check = []
            for n, p in model.named_parameters():
                if (n != "transformer.wte.weight"):
                    param_to_check.append(p)
            a = list(param_to_check)[0].clone()
            model.transformer.wte.weight.grad[:50257, :] = 0
            model.transformer.wte.weight.data.add_(
                -args.lr, model.transformer.wte.weight.grad.data)
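            # Gradients for GPT-2's original vocabulary rows (indices < 50257) are zeroed, so this manual
            # SGD step (weight += -lr * grad) only updates the embeddings of the newly added special
            # tokens; the check below asserts that all other parameters remain unchanged.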
            optimizer.zero_grad()
            param_to_check = []
            for n, p in model.named_parameters():
                if (n != "transformer.wte.weight"):
                    param_to_check.append(p)

            b = list(param_to_check)[0].clone()
            assert torch.equal(a.data, b.data)
        return loss.item()

    # Training function and trainer
    def update_linear_perturbation(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:

            model.transformer.wte.weight.grad[:50257, :] = 0
            # model.transformer.wte.weight.data.add_(-args.lr,model.transformer.wte.weight.grad.data)
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    # Training function and trainer
    def update_all(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels,
                              self_copy=args.self_copy)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    if args.emb_only:
        trainer = Engine(update_emb)
    elif (args.linear_perturb or args.adapter_bottleneck > 0):
        trainer = Engine(update_linear_perturbation)
    else:
        trainer = Engine(update_all)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, lm_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint,
                              task=args.task,
                              lr=args.lr,
                              layer=args.perturbation_layers,
                              self_copy=args.self_copy,
                              n_epochs=args.n_epochs,
                              adapter=args.adapter_bottleneck,
                              random_init=args.random_init)
        if args.distillation:
            log_dir += "_distillation"
        if args.outputlayer_only:
            log_dir += "_outputlayer_only"
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def train(experiment_id, ds_train, ds_val, model, optimizer, hyperparams,
          num_workers, device, debug):

    train_loader = torch.utils.data.DataLoader(ds_train,
                                               batch_size=hyperparams['bs'],
                                               shuffle=True,
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(ds_val,
                                             batch_size=hyperparams['bs'],
                                             shuffle=True,
                                             num_workers=num_workers)

    criterion = nn.CrossEntropyLoss().to(device)

    metrics = {
        'loss': Loss(criterion),
        'accuracy': Accuracy(),
    }

    trainer = create_supervised_trainer(model, optimizer, criterion, device)

    if hyperparams['pretrained']:

        @trainer.on(Events.EPOCH_STARTED)
        def turn_on_layers(engine):
            epoch = engine.state.epoch
            if epoch == 1:
                print()
                temp = next(model.named_children())[1]
                for name, child in temp.named_children():
                    if (name == 'mlp') or (name == 'classifier'):
                        print(name + ' is unfrozen')
                        for param in child.parameters():
                            param.requires_grad = True
                    else:
                        for param in child.parameters():
                            param.requires_grad = False

            if epoch == 3:
                print()
                print('Turn on all the layers')
                for name, child in model.named_children():
                    for param in child.parameters():
                        param.requires_grad = True
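        # A simple gradual-unfreezing policy: during epochs 1-2 only the 'mlp' and 'classifier' children
        # are trainable, and from epoch 3 onward every layer of the pretrained model is fine-tuned.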

    pbar = ProgressBar(bar_format='')
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})

    val_evaluator = create_supervised_evaluator(model, metrics, device)

    if hyperparams['early_stopping']:

        def score_function(engine):
            return engine.state.metrics['accuracy']

        handler = EarlyStopping(patience=hyperparams['patience'],
                                score_function=score_function,
                                trainer=trainer)
        val_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.STARTED)
    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_and_display_val_metrics(engine):
        epoch = engine.state.epoch
        metrics = val_evaluator.run(val_loader).metrics

        if (epoch == 0) or (metrics['accuracy'] > engine.state.best_acc):
            engine.state.best_acc = metrics['accuracy']
            print('New best accuracy! Accuracy: ' +
                  str(engine.state.best_acc) + '\nModel saved!')
            if not os.path.exists('models/'):
                os.makedirs('models/')
            path = 'models/best_model_' + experiment_id + '.pth'
            torch.save(model.state_dict(), path)

        print('Validation Results - Epoch: {} | '
              'Average Loss: {:.4f} | Accuracy: {:.4f}'.format(
                  engine.state.epoch, metrics['loss'], metrics['accuracy']))

    if hyperparams['scheduler']:
        lr_scheduler = CosineAnnealingLR(optimizer,
                                         hyperparams['nb_epochs'],
                                         eta_min=hyperparams['lr'] / 100,
                                         last_epoch=-1)

        @trainer.on(Events.EPOCH_COMPLETED)
        def update_lr_scheduler(engine):
            lr_scheduler.step()

    tb_logger = TensorboardLogger('board/' + experiment_id)

    def output_transform(loss):
        return {'loss': loss}

    log_handler = OutputHandler(tag='training',
                                output_transform=output_transform)
    tb_logger.attach(trainer,
                     log_handler,
                     event_name=Events.ITERATION_COMPLETED)
    log_handler = OutputHandler(tag='validation',
                                metric_names=['accuracy', 'loss'],
                                another_engine=trainer)
    tb_logger.attach(val_evaluator, log_handler, event_name=Events.STARTED)
    tb_logger.attach(val_evaluator,
                     log_handler,
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.ITERATION_STARTED)
    tb_logger.attach(trainer,
                     log_handler=GradsHistHandler(model),
                     event_name=Events.EPOCH_COMPLETED)
    trainer.run(train_loader, max_epochs=hyperparams['nb_epochs'])

    # Close the TensorBoard logger only after training has finished
    tb_logger.close()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_directory", type=Path)
    parser.add_argument("--generator-weights", type=Path)
    parser.add_argument("--discriminator-weights", type=Path)

    args = parser.parse_args()

    generator = Generator(GENERATOR_FILTERS)
    if args.generator_weights is not None:
        LOGGER.info(f"Loading generator weights: {args.generator_weights}")
        generator.load_state_dict(torch.load(args.generator_weights))
    else:
        generator.weight_init(mean=0.0, std=0.02)

    discriminator = Discriminator(DISCRIMINATOR_FILTERS)
    if args.discriminator_weights is not None:
        LOGGER.info(
            f"Loading discriminator weights: {args.discriminator_weights}")
        discriminator.load_state_dict(torch.load(args.discriminator_weights))
    else:
        discriminator.weight_init(mean=0.0, std=0.02)

    dataset = XView2Dataset(args.data_directory, )
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [len(dataset) - 10, 10])

    # Create a dev train dataset with just 10 samples
    # train_dataset, _ = torch.utils.data.random_split(train_dataset, [10, len(train_dataset) - 10])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=TRAIN_BATCH_SIZE)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=TEST_BATCH_SIZE)

    generator.cuda()
    discriminator.cuda()

    generator.train()
    discriminator.train()

    BCE_loss = nn.BCELoss().cuda()
    L1_loss = nn.L1Loss().cuda()

    generator_optimizer = optim.Adam(generator.parameters(),
                                     lr=GENERATOR_LR,
                                     betas=(BETA_1, BETA_2))
    discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                         lr=DISCRIMINATOR_LR,
                                         betas=(BETA_1, BETA_2))

    def step(engine, batch):
        x, y = batch
        x = x.cuda()
        y = y.cuda()

        discriminator.zero_grad()
        discriminator_result = discriminator(x, y).squeeze()
        discriminator_real_loss = BCE_loss(
            discriminator_result,
            torch.ones(discriminator_result.size()).cuda())

        generator_result = generator(x)
        discriminator_result = discriminator(x, generator_result).squeeze()

        discriminator_fake_loss = BCE_loss(
            discriminator_result,
            torch.zeros(discriminator_result.size()).cuda())
        discriminator_train_loss = (discriminator_real_loss +
                                    discriminator_fake_loss) * 0.5
        discriminator_train_loss.backward()
        discriminator_optimizer.step()

        generator.zero_grad()
        generator_result = generator(x)
        # TODO Work out if the below time saving technique impacts training.
        #generator_result = generator_result.detach()
        discriminator_result = discriminator(x, generator_result).squeeze()

        l1_loss = L1_loss(generator_result, y)
        bce_loss = BCE_loss(discriminator_result,
                            torch.ones(discriminator_result.size()).cuda())

        G_train_loss = bce_loss + L1_LAMBDA * l1_loss
        G_train_loss.backward()
        generator_optimizer.step()

        return {
            'generator_train_loss': G_train_loss.item(),
            'discriminator_real_loss': discriminator_real_loss.item(),
            'discriminator_fake_loss': discriminator_fake_loss.item(),
        }

    trainer = Engine(step)

    tb_logger = TensorboardLogger(log_dir=f"tensorboard/logdir/{uuid4()}")
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(
                         tag="training",
                         output_transform=lambda out: out,
                         metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def add_generated_images(engine):
        def min_max(image):
            return (image - image.min()) / (image.max() - image.min())

        for idx, (x, y) in enumerate(test_loader):
            generated = min_max(generator(x.cuda()).squeeze().cpu())
            real = min_max(y.squeeze())

            tb_logger.writer.add_image(
                f"generated_test_image_{idx}",
                # Concatenate the images into a single tiled image
                torch.cat([x.squeeze(), generated, real], 2),
                global_step=engine.state.epoch)

    checkpoint_handler = ModelCheckpoint("checkpoints/",
                                         "pix2pix",
                                         n_saved=1,
                                         require_empty=False,
                                         save_interval=1)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'generator': generator,
                                  'discriminator': discriminator
                              })

    timer = Timer(average=True)
    timer.attach(trainer,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}] Duration[{}] Losses: {}".format(
            engine.state.epoch, engine.state.iteration, timer.value(),
            engine.state.output))

    trainer.run(train_loader, max_epochs=TRAIN_EPOCHS)

    tb_logger.close()
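
Regarding the TODO in the step function above about detaching the generator output: a common pix2pix-style variant detaches the fake image for the discriminator's fake pass, so that the discriminator update never backpropagates through the generator. The modules below are toy stand-ins, not the example's Generator/Discriminator.

import torch
from torch import nn

# Toy stand-ins for the example's Generator / Discriminator.
generator = nn.Sequential(nn.Conv2d(3, 3, 3, padding=1), nn.Tanh())
discriminator = nn.Sequential(nn.Conv2d(6, 1, 3, padding=1), nn.Sigmoid())
bce = nn.BCELoss()

x = torch.randn(2, 3, 16, 16)

# Detach the generated image: the fake loss then only updates the discriminator.
fake = generator(x).detach()
d_fake = discriminator(torch.cat([x, fake], dim=1))
d_fake_loss = bce(d_fake, torch.zeros_like(d_fake))
d_fake_loss.backward()
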
Beispiel #14
0
def attach_handlers(run, model, optimizer, trainer, train_evaluator, evaluator,
                    train_loader, val_loader, params):
    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluators
    train_evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED,
                           train_loader)
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, data=val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              'max',
                                                              verbose=True,
                                                              patience=5,
                                                              factor=0.5)
    evaluator.engine.add_event_handler(
        Events.COMPLETED,
        lambda engine: lr_scheduler.step(engine.state.metrics['accuracy']))

    # Early stopping
    es_handler = EarlyStopping(
        patience=15,
        score_function=lambda engine: engine.state.metrics['accuracy'],
        trainer=trainer.engine,
        cumulative_delta=True,
        min_delta=0.0001)
    if 'train_all' in params and params['train_all']:
        train_evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)
    else:
        evaluator.engine.add_event_handler(Events.COMPLETED, es_handler)

    es_handler.logger.setLevel(logging.DEBUG)

    # Model checkpoints
    name = run.replace('/', '-')
    mc_handler = ModelCheckpoint(
        config.MODELS_DIR,
        name,
        n_saved=1,
        create_dir=True,
        require_empty=False,
        score_name='acc',
        score_function=lambda engine: engine.state.metrics['accuracy'],
        global_step_transform=global_step_from_engine(trainer.engine))
    evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler,
                                       {'m': model})

    # TensorBoard logger
    tb_logger = TensorboardLogger(
        log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {'hparam/dummy': 0})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        train_evaluator.engine,
        event_name=Events.COMPLETED,
        tag="train",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine,
                     log_handler=OptimizerParamsHandler(optimizer),
                     event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['linear1', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['linear1', 'batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return es_handler, tb_logger
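
A compact, self-contained sketch of the scheduling and early-stopping wiring used above: the evaluator's COMPLETED event drives both ReduceLROnPlateau and EarlyStopping from the same validation metric. The tiny engines and the fixed 'accuracy' value are placeholders.

import torch
from torch import nn, optim
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max',
                                                          patience=5, factor=0.5)

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

@evaluator.on(Events.COMPLETED)
def fake_metric_and_schedule(engine):
    engine.state.metrics['accuracy'] = 0.5  # placeholder metric value
    lr_scheduler.step(engine.state.metrics['accuracy'])

# Stop training when validation accuracy has not improved for 15 evaluations.
es_handler = EarlyStopping(patience=15,
                           score_function=lambda engine: engine.state.metrics['accuracy'],
                           trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, es_handler)
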
def attach_handlers(run, model, optimizer, learning_rule, trainer, evaluator, train_loader, val_loader, params):
    # Metrics
    UnitConvergence(model[0], learning_rule.norm).attach(trainer.engine, 'unit_conv')

    # Tqdm logger
    pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT)
    pbar.attach(trainer.engine, metric_names='all')
    tqdm_logger = TqdmLogger(pbar=pbar)
    # noinspection PyTypeChecker
    tqdm_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        global_step_transform=global_step_from_engine(trainer.engine),
    )

    # Evaluator
    evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED(every=100), train_loader, val_loader)

    # Learning rate scheduling
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer,
                                                     lr_lambda=lambda epoch: 1 - epoch / params['epochs'])
    lr_scheduler = LRScheduler(lr_scheduler)
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, lr_scheduler)

    # Early stopping
    mc_handler = ModelCheckpoint(config.MODELS_DIR, run.replace('/', '-'), n_saved=1, create_dir=True,
                                 require_empty=False,
                                 global_step_transform=global_step_from_engine(trainer.engine))
    trainer.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model})

    # Create a TensorBoard logger
    tb_logger = TensorboardLogger(log_dir=os.path.join(config.TENSORBOARD_DIR, run))
    images, labels = next(iter(train_loader))
    tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images)
    tb_logger.writer.add_hparams(params, {})

    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        evaluator.engine,
        event_name=Events.COMPLETED,
        tag="validation",
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer.engine),
    )
    # noinspection PyTypeChecker
    tb_logger.attach_output_handler(
        trainer.engine,
        event_name=Events.EPOCH_COMPLETED,
        tag="train",
        metric_names=["unit_conv"]
    )
    input_shape = tuple(next(iter(train_loader))[0].shape[1:])
    tb_logger.attach(trainer.engine,
                     log_handler=WeightsImageHandler(model, input_shape),
                     event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsScalarHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=WeightsHistHandler(model, layer_names=['linear1', 'linear2']),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsHistHandler(model, layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=NumActivationsScalarHandler(model, layer_names=['repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.mean,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer.engine,
    #                  log_handler=ActivationsScalarHandler(model, reduction=torch.std,
    #                                                       layer_names=['batch_norm', 'repu']),
    #                  event_name=Events.ITERATION_COMPLETED)

    return tb_logger
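
The LambdaLR used above decays the learning rate linearly to zero over the configured number of epochs. A minimal sketch of that wiring, with a throwaway model and trainer and an assumed value for params['epochs']:

import torch
from torch import nn, optim
from ignite.engine import Engine, Events
from ignite.contrib.handlers import LRScheduler

epochs = 10  # assumed stand-in for params['epochs']
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)

# The multiplicative factor goes from 1.0 down to 0.0 across `epochs` epochs.
torch_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: 1 - epoch / epochs)
scheduler = LRScheduler(torch_scheduler)

trainer = Engine(lambda engine, batch: None)
trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)
trainer.run(range(3), max_epochs=epochs)
print(optimizer.param_groups[0]['lr'])  # close to zero after the last epoch
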
Beispiel #16
0
        ]

        logger.writer.add_images(
            'samples',
            np.stack([np.concatenate([
                np.concatenate([
                    np.array(sample.representation())
                    for sample in samples
                ], axis=1)
                for samples in std_samples
            ], axis=0)]) / 255,
            trainer.state.epoch,
            dataformats='NHWC',
        )

    trainer = ignite.engine.Engine(update_batch_norms)

    tensorboard_logger = TensorboardLogger(log_dir='tb')
    tensorboard_logger.attach(
        trainer,
        log_examples,
        ignite.engine.Events.EPOCH_COMPLETED,
    )

    ProgressBar(desc='update_batch_norms').attach(trainer)

    trainer.run(
        range(config['n_batches_per_epoch']),
        max_epochs=config['n_epochs']
    )
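
The fragment above passes a plain function (log_examples) to TensorboardLogger.attach. In ignite, such a custom handler is invoked with the engine, the logger and the event name, so it can reach logger.writer directly. A minimal, self-contained sketch of that calling convention, logging a toy scalar rather than the image grid from the example:

from ignite.engine import Engine, Events
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger

trainer = Engine(lambda engine, batch: float(batch))

def log_examples(engine, logger, event_name):
    # `logger` is the TensorboardLogger itself, so its writer is available here.
    logger.writer.add_scalar("custom/last_output", engine.state.output,
                             engine.state.epoch)

tb_logger = TensorboardLogger(log_dir="tb")
tb_logger.attach(trainer, log_examples, Events.EPOCH_COMPLETED)
trainer.run([1.0, 2.0], max_epochs=2)
tb_logger.close()
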
    def setup(self, training_metrics):
        def metric_name(n) -> str:
            if n.endswith('Accuracy'):
                n = 'acc'
            else:
                n = n[:-6] if n.endswith('Metric') else n
            return n

        def print_metrics(metrics) -> str:
            rv = ''
            metric_keys = sorted(k for k in metrics)
            for k in metric_keys:
                if k == 'Accuracy':
                    rv += f'{metric_name(k)}: {metrics[k]:.3} '
                else:
                    rv += f'{metric_name(k)}: {metrics[k]:.6} '
            return rv.strip()

        if self.seed:
            set_seed_everywhere(self.seed, self.cuda)

        pbar = ProgressBar()

        names = []
        for k, v in training_metrics.items():
            name = f'r{k}'
            names.append(name)
            RunningAverage(v).attach(self.trainer, name)
        RunningAverage(
            None,
            output_transform=lambda x: x[-1] * self.loss_accumulation_steps
        ).attach(self.trainer, 'rloss')
        names.append('rloss')
        pbar.attach(self.trainer, names)

        pbar = ProgressBar()
        pbar.attach(self.evaluator)

        # A few event handlers. To add or modify event handlers, extend the __init__ method of RunnerABC.
        # Ignite provides the necessary abstractions and a well-stocked set of useful tools.

        @self.trainer.on(Events.EPOCH_COMPLETED)
        def log_validation_results(trainer):
            self.evaluator.run(self.dataset_splits.val_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(
                f"Validation Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}"
            )

            if self.scheduler:
                self.scheduler.step(
                    metrics[self.loss_metric.__class__.__name__])

        @self.trainer.on(Events.COMPLETED)
        def log_test_results(trainer):
            self.evaluator.run(self.dataset_splits.test_data_loader())
            metrics = self.evaluator.state.metrics
            logger.info(
                f"Test Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}"
            )

        if self.tensorboard_logs:
            tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
            tb_logger.attach(self.trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 output_transform=lambda loss: {'loss': loss}),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.evaluator,
                             log_handler=OutputHandler(
                                 tag="validation",
                                 metric_names=["LossMetric"],
                                 another_engine=self.trainer),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=OptimizerParamsHandler(
                                 self.optimizer),
                             event_name=Events.ITERATION_STARTED)
            tb_logger.attach(self.trainer,
                             log_handler=WeightsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=WeightsHistHandler(self.model),
                             event_name=Events.EPOCH_COMPLETED)
            tb_logger.attach(self.trainer,
                             log_handler=GradsScalarHandler(self.model),
                             event_name=Events.ITERATION_COMPLETED)

            # This is important to close the tensorboard file logger
            @self.trainer.on(Events.COMPLETED)
            def end_tensorboard(trainer):
                logger.info("Training completed")
                tb_logger.close()

        if self.embeddings_name:

            @self.trainer.on(Events.COMPLETED)
            def log_embeddings(trainer):
                if hasattr(self.model, self.embeddings_name) and hasattr(
                        self.dataset_splits, "vectorizer"):
                    logger.info(
                        f"Logging embeddings ({self.embeddings_name}) to Tensorboard!"
                    )
                    embeddings = getattr(self.model,
                                         self.embeddings_name).weight.data
                    metadata = [
                        str(self.dataset_splits.vectorizer.data_vocab.
                            _id2token[token_index]).encode('utf-8')
                        for token_index in range(embeddings.shape[0])
                    ]
                    self.writer.add_embedding(
                        mat=embeddings,
                        metadata=metadata,
                        global_step=self.trainer.state.epoch)
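
The setup above uses RunningAverage with an output_transform to undo the division by loss_accumulation_steps before the loss is displayed. A self-contained sketch of that pattern, with made-up numbers and a dummy engine in place of the real trainer:

from ignite.engine import Engine
from ignite.metrics import RunningAverage
from ignite.contrib.handlers import ProgressBar

accumulation_steps = 4  # stand-in for self.loss_accumulation_steps

# The engine output is assumed to end with the already-divided loss,
# so the transform multiplies it back up for readability.
trainer = Engine(lambda engine, batch: (float(batch) / accumulation_steps,))
RunningAverage(output_transform=lambda out: out[-1] * accumulation_steps).attach(trainer, "rloss")

ProgressBar(persist=True).attach(trainer, metric_names=["rloss"])
trainer.run([1.0, 2.0, 3.0], max_epochs=1)
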
Beispiel #18
0
def main():
    run_id = str(uuid4())

    print("Initialising embedding network.")
    embedding_model = ResNetEmbedding(16)
    embedding_model.cuda()

    print("Initialising triplet network.")
    triplet_model = ResNetTriplet(embedding_model)
    triplet_model.cuda()

    print("Initialising training dataset.")

    train_dataset = OsmTileDataset(samples=[
        sample
        for sample in load_samples(Path("data/extents/train_1500000.json"))
        if random.random() > 0.99 or sample.anchor.entropy > 1.7
    ],
                                   cache_dir=CACHE_DIR)

    print("Initialising testing dataset.")
    test_dataset = OsmTileDataset(samples=[
        sample for sample in load_samples(Path("data/extents/test_15000.json"))
        if random.random() > 0.99 or sample.anchor.entropy > 1.7
    ],
                                  cache_dir=CACHE_DIR)

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        pin_memory=True,
        num_workers=4,
    )
    test_loader = DataLoader(test_dataset, batch_size=1, num_workers=4)

    triplet_loss = TripletLoss(margin=1)

    optimizer = optim.Adam(embedding_model.parameters(), lr=1e-4)
    lr_scheduler = ExponentialLR(optimizer, 0.99)

    def train_step(engine, batch):
        embedding_model.train()
        triplet_model.train()

        anchor, positive, negative = batch
        anchor = anchor.cuda()
        positive = positive.cuda()
        negative = negative.cuda()

        optimizer.zero_grad()

        anchor_embedding, positive_embedding, negative_embedding = triplet_model(
            anchor, positive, negative)
        loss = triplet_loss(anchor_embedding, positive_embedding,
                            negative_embedding).cuda()

        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        return {
            'loss': loss.item(),
        }

    trainer = Engine(train_step)

    tb_logger = TensorboardLogger(log_dir=f"tensorboard/{run_id}")
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(
                         tag="training",
                         output_transform=lambda out: out,
                         metric_names='all'),
                     event_name=Events.ITERATION_COMPLETED)

    for idx, sample in enumerate(test_loader):
        if idx > 9:
            break
        anchor, positive, negative = sample
        anchor = DENORMALIZE(anchor.squeeze())
        positive = DENORMALIZE(positive.squeeze())
        negative = DENORMALIZE(negative.squeeze())

        tb_logger.writer.add_image(f"test_image_{idx}",
                                   torch.cat([anchor, positive, negative], 2),
                                   global_step=0)

    @trainer.on(Events.EPOCH_COMPLETED)
    def test(engine):
        embedding_model.eval()
        triplet_model.eval()

        with torch.no_grad():
            embeddings = []
            images = []
            loss_total = 0
            for idx, sample in enumerate(test_loader):
                anchor, positive, negative = sample
                anchor = anchor.cuda()
                positive = positive.cuda()
                negative = negative.cuda()

                anchor_embedding, positive_embedding, negative_embedding = triplet_model(
                    anchor, positive, negative)
                loss = triplet_loss(anchor_embedding, positive_embedding,
                                    negative_embedding)
                loss_total += loss.item()

                # 300 is a good number of images to plot
                if len(embeddings) < 300:
                    embeddings.append(
                        anchor_embedding.squeeze().detach().cpu().numpy())
                    images.append(
                        DENORMALIZE(
                            anchor.squeeze()).detach().cpu().numpy().transpose(
                                1, 2, 0))

        fig = plt.gcf()
        fig.clf()
        fig.set_size_inches(9, 9)
        ax = plt.subplot(111)

        embeddings = TSNE(n_components=2).fit_transform(embeddings)
        for embedding_idx, image in enumerate(images):
            offset_image = OffsetImage(image, zoom=.2)
            ab = AnnotationBbox(offset_image,
                                embeddings[embedding_idx],
                                xybox=(30.0, -30.0),
                                xycoords='data',
                                boxcoords="offset points",
                                frameon=False)
            ax.add_artist(ab)

        plt.axis("off")
        plt.xlim((embeddings[:, 0].min(), embeddings[:, 0].max()))
        plt.ylim((embeddings[:, 1].min(), embeddings[:, 1].max()))
        plt.draw()

        tb_logger.writer.add_figure(f"test_embeddings",
                                    fig,
                                    global_step=engine.state.epoch)
        tb_logger.writer.add_scalar(
            f"test_loss",
            # Assumes test loader has batch size of 1
            loss_total / len(test_loader),
            global_step=engine.state.epoch)

    print("Initialising checkpoint handler.")
    checkpoint_handler = ModelCheckpoint(
        "checkpoints/",
        f"{triplet_model.__class__.__name__}-{train_dataset.__class__.__name__}-{run_id}",
        n_saved=10,
        require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'embedding': embedding_model,
                              })

    print("Initialising timer.")
    timer = Timer(average=True)
    timer.attach(trainer,
                 resume=Events.ITERATION_STARTED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        print("Epoch[{}] Iteration[{}] Duration[{}] Losses: {}".format(
            engine.state.epoch, engine.state.iteration, timer.value(),
            engine.state.output))

    print("Running trainer.")
    trainer.run(train_loader, max_epochs=TRAIN_EPOCHS)

    tb_logger.close()
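
As an alternative to the manual TSNE + matplotlib figure in the test handler above, TensorBoard's projector can take the raw embeddings plus thumbnails and do the dimensionality reduction interactively. This is a generic sketch with random tensors, not the OSM tiles from the example:

import torch
from torch.utils.tensorboard import SummaryWriter

embeddings = torch.randn(300, 128)        # assumed embedding size
thumbnails = torch.rand(300, 3, 64, 64)   # NCHW thumbnails for the projector

writer = SummaryWriter(log_dir="tensorboard/projector-sketch")
# The projector applies PCA / t-SNE in the browser, so no offline reduction is needed.
writer.add_embedding(mat=embeddings, label_img=thumbnails, global_step=0)
writer.close()
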
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default='wikitext-2',
        help="One of ('wikitext-103', 'wikitext-2') or a dict of splits paths."
    )
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")

    parser.add_argument("--embed_dim",
                        type=int,
                        default=410,
                        help="Embeddings dim")
    parser.add_argument("--hidden_dim",
                        type=int,
                        default=2100,
                        help="Hidden dimension")
    parser.add_argument("--num_max_positions",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=16,
                        help="NUmber of layers")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--initializer_range",
                        type=float,
                        default=0.02,
                        help="Dropout")

    parser.add_argument("--train_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for validation")
    parser.add_argument("--lr",
                        type=float,
                        default=2.5e-4,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=200,
                        help="Number of training epochs")
    parser.add_argument("--n_warmup",
                        type=float,
                        default=1000,
                        help="Number of warmup iterations")
    parser.add_argument("--eval_every",
                        type=int,
                        default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Accumulate gradient")

    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log on main process only, logger.warning => log on all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(
        args))  # This is a logger.info: only printed on the first process

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-cased',
        do_lower_case=False)  # Let's use a pre-defined tokenizer
    args.num_embeddings = len(
        tokenizer.vocab
    )  # We need this to create the model on the next line (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = batch.transpose(0, 1).contiguous().to(
            args.device)  # to shape [seq length, batch]
        logits, loss = model(batch, labels=batch)
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = batch.transpose(0, 1).contiguous().to(
                args.device)  # to shape [seq length, batch]
            logits = model(batch)
            shift_logits = logits[:-1].view(-1, logits.size(-1))
            shift_labels = batch[1:].view(-1)
            return shift_logits, shift_labels

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(val_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.n_epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr,
                                                args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    metrics["average_word_ppl"] = MetricsLambda(
        lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words),
        metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)

        @evaluator.on(Events.COMPLETED)  # Log evaluator metrics on tensorboard
        def tb_log_metrics(engine):
            for name in metrics.keys():
                tb_logger.writer.add_scalar(name, engine.state.metrics[name],
                                            trainer.state.iteration)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(args, os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
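
The metric block above chains ignite metrics so that perplexity is computed as exp of the averaged negative log-likelihood. A self-contained version of that chain with a dummy evaluator (random logits and labels standing in for real model output):

import math
import torch
from ignite.engine import Engine
from ignite.metrics import Loss, MetricsLambda

# The evaluator is assumed to return (logits, labels) pairs.
evaluator = Engine(lambda engine, batch: (torch.randn(8, 100),
                                          torch.randint(0, 100, (8,))))

nll = Loss(torch.nn.CrossEntropyLoss())
ppl = MetricsLambda(math.exp, nll)          # perplexity = exp(cross-entropy)
nll.attach(evaluator, "nll")
ppl.attach(evaluator, "ppl")

state = evaluator.run([None] * 5)
print(state.metrics["nll"], state.metrics["ppl"])
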
Beispiel #20
0
    def custom_setup(self):

        if self.tensorboard_logs:
            tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs)
            tb_logger.attach(self.trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 output_transform=lambda loss: {'loss': loss}),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(self.evaluator,
                             log_handler=OutputHandler(
                                 tag="validation",
                                 metric_names=["LossMetric"],
                                 another_engine=self.trainer),
                             event_name=Events.EPOCH_COMPLETED)

            if self.optional_tensorboard_features:
                tb_logger.attach(self.trainer,
                                 log_handler=OptimizerParamsHandler(
                                     self.optimizer),
                                 event_name=Events.ITERATION_STARTED)
                tb_logger.attach(self.trainer,
                                 log_handler=WeightsScalarHandler(self.model),
                                 event_name=Events.ITERATION_COMPLETED)
                tb_logger.attach(self.trainer,
                                 log_handler=WeightsHistHandler(self.model),
                                 event_name=Events.EPOCH_COMPLETED)
                tb_logger.attach(self.trainer,
                                 log_handler=GradsScalarHandler(self.model),
                                 event_name=Events.ITERATION_COMPLETED)

            # This is important to close the tensorboard file logger
            @self.trainer.on(Events.COMPLETED)
            def end_tensorboard(trainer):
                logger.info("Training completed")
                tb_logger.close()

        if self.embeddings_name:

            @self.trainer.on(Events.COMPLETED)
            def log_embeddings(trainer):
                if hasattr(self.model, self.embeddings_name) and hasattr(
                        self.dataset_splits, "vectorizer") and TENSORBOARD:
                    logger.info(
                        f"Logging embeddings ({self.embeddings_name}) to Tensorboard!"
                    )
                    embeddings = getattr(self.model,
                                         self.embeddings_name).weight.data
                    metadata = [
                        str(self.dataset_splits.vectorizer.data_vocab.
                            _id2token[token_index]).encode('utf-8')
                        for token_index in range(embeddings.shape[0])
                    ]
                    self.writer.add_embedding(
                        mat=embeddings,
                        metadata=metadata,
                        global_step=self.trainer.state.epoch)
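
The OutputHandler calls above (and in the earlier setup method) pass another_engine=self.trainer so validation points are plotted against the trainer's step counter. Newer ignite releases replace that argument with global_step_transform; a hedged sketch of the equivalent wiring, using throwaway engines:

from ignite.engine import Engine, Events
from ignite.handlers import global_step_from_engine
from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, OutputHandler

trainer = Engine(lambda engine, batch: batch)
evaluator = Engine(lambda engine, batch: batch)

tb_logger = TensorboardLogger(log_dir="tb-sketch")
# global_step_from_engine(trainer) reports the trainer's step instead of the
# evaluator's own (always small) iteration count.
tb_logger.attach(evaluator,
                 log_handler=OutputHandler(
                     tag="validation",
                     metric_names=["LossMetric"],
                     global_step_transform=global_step_from_engine(trainer)),
                 event_name=Events.EPOCH_COMPLETED)
tb_logger.close()
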
Beispiel #21
0
def train():
    os.environ['CUDA_VISIBLE_DEVICES'] = '7'

    parser = ArgumentParser()
    parser.add_argument('--gpt2', action='store_true', help="use gpt2")
    parser.add_argument("--model_checkpoint", type=str, default="uer/gpt2-chinese-cluecorpussmall", help="Path or URL of the model")
    parser.add_argument("--from_step", type=int, default=-1, help="Init learning rate from this step")
    parser.add_argument('--pretrained', action='store_true', help="If False train from scratch")
    parser.add_argument("--data_path", type=str, default="data/autocloze.json",
                        help="Path or url of the dataset. ")
    parser.add_argument("--train_path", type=str, default="data/toy_train.txt",
                        help="Path of the train dataset for dist dataset. ")
    parser.add_argument("--valid_path", type=str, default="data/toy_valid.txt",
                        help="Path of the valid dataset for dist dataset. ")
    #--------------------------------------------------------------
    parser.add_argument("--dataset_cache", type=str, default="dataset_zh",
                        help="Path or url of the dataset cache")
    parser.add_argument('--log_file', '-log_file', type=str, default="", help="Output logs to a file under this path")
    parser.add_argument("--num_workers", type=int, default=8, help="Number of subprocesses for data loading")
    parser.add_argument("--n_epochs", type=int, default=40, help="Number of training epochs")
    parser.add_argument("--train_batch_size", type=int, default=1, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
    parser.add_argument("--max_history", type=int, default=15, help="Number of previous exchanges to keep in history")
    parser.add_argument("--scheduler", type=str, default="noam", choices=['noam', 'linear'], help="method of optim")
    parser.add_argument("--n_emd", type=int, default=768, help="Number of n_emd in config file (for noam)")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--warmup_steps", type=int, default=5000, help="Warm up steps")
    parser.add_argument("--valid_steps", type=int, default=5000, help="Perfom validation every X steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=64,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()
    print('cuda:', torch.cuda.is_available())
    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    '''if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    '''
    args.device = torch.device("cuda")
    print('device:', args.device)
    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    #model_class = OpenAIGPTLMHeadModel if not args.gpt2 else GPT2LMHeadModel
    #config_class = OpenAIGPTConfig if not args.gpt2 else GPT2Config
    model_class = GPT2LMHeadModel
    config_class = GPT2Config
    tokenizer_class = BertTokenizer
    print('pretrained:',args.pretrained)
    if args.pretrained:
        print("----------------pretrained")
        tokenizer = BertTokenizer.from_pretrained(args.model_checkpoint, do_lower_case=True)
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
        model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall",from_tf=True)
        #print('generate')
        #print(text_generator("这是很久之前的事情了", max_length=100, do_sample=True))

    #args.device=torch.device("cuda", 2)
    
    model.to(args.device)
    
    optimizer = AdamW([{'params': model.parameters(), 'initial_lr': args.lr}], lr=args.lr, correct_bias=True)

    logger.info("Prepare datasets")
    loader_class = build_dist_loaders if not args.data_path else build_dataloaders
    train_loader, val_loader, train_sampler, valid_sampler = loader_class(args, tokenizer, logger)

    logger.info("Prepare datasets ends")
    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
        model=model.module
    #if isinstance(model,torch.nn.DataParallel):
    
    #print('params:',params_count(model))

    #tokens_embed = model.transformer.get_input_embeddings()
    # Training function and trainer
    def update(engine, batch):
        input_ids, token_type_ids, lm_labels = tuple(input_tensor.to(args.device) for input_tensor in batch)
        
        #for i in range(input_ids.size()[0]):
        #    for j in range(input_ids.size()[1]):
        #        if input_ids[i,j]==-1:
        #            input_ids[i,j]=-100
        #        if lm_labels[i,j]==-1:
        #            lm_labels[i,j]=-100
        #one=torch.tensor(-100)
        #input_ids=torch.where(input_ids==-1,one,input_ids)
        #lm_labels=torch.where(lm_labels==-1,one,lm_labels)
        #print('traindata',input_ids,lm_labels)

        #lm_labels=input_ids
        r'''input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        inputs_embeds = tokens_embed(input_ids) * math.sqrt(tokens_embed.embedding_dim)'''

        model.train()
        #(lm_loss), *_ = model(inputs_embeds=inputs_embeds, labels=lm_labels,return_dict=0)
        (lm_loss), *_ = model(input_ids=input_ids, labels=lm_labels,return_dict=False)
        #print('lm_loss',lm_loss)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item(), optimizer.param_groups[0]['lr']

    trainer = Engine(update)
    

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            input_ids, token_type_ids, lm_labels = tuple(input_tensor.to(args.device) for input_tensor in batch)
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            #one = torch.tensor(-100)
            #input_ids=torch.where(input_ids==-1,one,input_ids)
            #print('validdata',input_ids,lm_labels)
            #lm_labels=input_ids
            r'''input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            inputs_embeds = tokens_embed(input_ids) * math.sqrt(tokens_embed.embedding_dim)'''
            

            #lm_logits, *_ = model(inputs_embeds=inputs_embeds,return_dict=0)
            lm_logits, *_ = model(input_ids=input_ids,return_dict=False)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Evaluation during training
    @trainer.on(Events.ITERATION_STARTED)
    def log_iterations(engine):
        # if engine.state.iteration % max(int(0.1 * len(train_loader)), 1) == 0:
        if engine.state.iteration % args.valid_steps == 0:
            evaluator.run(val_loader)

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Noam schedule for the learning rate
    # model_size = model.config.n_embd
    model_size = args.n_emd
    noam_lambda = lambda step: (
            model_size ** (-0.5) * min((step + 1) ** (-0.5), (step + 1) * args.warmup_steps ** (-1.5)))
    noam_scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda, last_epoch=args.from_step)
    scheduler = LRScheduler(noam_scheduler)
    if args.scheduler == "linear":
        scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "lr")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0], x[1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints
    # And save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True, mininterval=2)
        pbar.attach(trainer, metric_names=["loss", "lr"])
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                              another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint', save_interval=1, n_saved=6)
        # save model after evaluation
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
            'mymodel': getattr(model, 'module', model)})
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
            'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.logdir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)
    
    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.logdir,
                               WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
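
The Noam schedule above is just a LambdaLR whose multiplier warms up linearly and then decays with the inverse square root of the step. A standalone sketch using this example's default n_emd and warmup_steps values (the tiny model and loop are placeholders):

from torch import nn, optim
from torch.optim.lr_scheduler import LambdaLR

model_size, warmup_steps, base_lr = 768, 5000, 5e-5
model = nn.Linear(4, 4)
optimizer = optim.AdamW(model.parameters(), lr=base_lr)

# lr multiplier: linear warm-up for `warmup_steps`, then ~ 1/sqrt(step) decay.
noam_lambda = lambda step: model_size ** (-0.5) * min(
    (step + 1) ** (-0.5), (step + 1) * warmup_steps ** (-1.5))
scheduler = LambdaLR(optimizer, lr_lambda=noam_lambda)

for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, optimizer.param_groups[0]['lr'])
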
Beispiel #22
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path",
                        type=str,
                        default="",
                        help="Path or url of the dataset.")
    parser.add_argument("--use_adapter",
                        default=False,
                        action='store_true',
                        help="Use adapter or not")
    parser.add_argument("--keyword_Module",
                        type=str,
                        default="",
                        help="add, attention, ")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="bertGpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--bert_model_path",
                        default="./",
                        type=str,
                        help="Bert pre-trained model path")
    parser.add_argument(
        "--vocab_file",
        default="./vocab.korean.rawtext.list",
        type=str,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    #tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # cant use Autotokenizer because checkpoint could be a Path
    #tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load KoBERT model and tokenizer
    bert_tokenizer = BertTokenizer.from_pretrained(
        args.vocab_file, do_lower_case=args.do_lower_case)
    bert_model = BertModel.from_pretrained(args.bert_model_path)
    bert_model.to(args.device)

    # Load KoGPT2 model and tokenizer
    tok_path = get_tokenizer()
    gpt_model, gpt_vocab = get_pytorch_conkogpt2_model2(
        keyword_Module=args.keyword_Module, use_adapter=args.use_adapter)
    gpt_tokenizer = SentencepieceTokenizer(tok_path)
    gpt_model.to(args.device)

    model = Seq2Seq(bert_model, gpt_model, gpt_vocab, args)

    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, bert_tokenizer, gpt_tokenizer, gpt_vocab)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        source_ids, target_ids, lm_labels = batch

        #(lm_loss), *_ = model(input_ids, token_type_ids=token_type_ids, labels=lm_labels)
        (lm_loss), *_ = model(source_ids, target_ids, lm_labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            source_ids, target_ids, lm_labels = batch

            #lm_logits, *_ = model(input_ids, token_type_ids=token_type_ids,)
            lm_logits, *_ = model(source_ids, target_ids)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted), (lm_labels_flat_shifted)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint, args.dataset_path,
                              args.keyword_Module)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=2)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        #getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        #tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
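The metric pipelines in these examples pass average_distributed_scalar to MetricsLambda, but the helper itself never appears in the excerpts. A minimal sketch of the usual implementation, assuming the (scalar, args) call signature used here: return the value unchanged in non-distributed runs, otherwise all-reduce it and divide by the world size.

import torch

def average_distributed_scalar(scalar, args):
    """Sketch: average a scalar over distributed processes; a no-op in single-process runs."""
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()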
Beispiel #23
0
def main():
    args = get_args()
    if 'e-SNLI-VE' in args.data_path:
        args.no_image = False
    else:
        args.no_image = True
    if not args.no_image:
        args.no_premise = True
    args.with_expl = True

    '''Setup'''
    t = datetime.today()
    output_dir = os.path.join(args.output_folder,
                              f"{t.month}_{t.day}_{t.hour}_{t.minute}_{t.second}")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(filename=os.path.join(output_dir, 'app.log'),
                        filemode='a',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning(f"Running process {args.local_rank}")
    logger.info(f"Arguments: {pformat(args)}")
    logger.info(f'Image not used: {args.no_image}')
    logger.info(f'Premise not used: {args.no_premise}')
    logger.info(f'Explanations used: {args.with_expl}')

    '''Initialize distributed training if needed'''
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    if args.no_image:
        model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint)
    else:
        import image_gpt2_291
        model = image_gpt2_291.GPT2LMHeadModel.from_pretrained(
            args.model_checkpoint)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    '''
    Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    '''
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        model = model.module

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders(args, tokenizer)

    '''Training function and trainer'''
    def train(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        if args.no_image:
            input_ids, lm_label, label, input_mask = batch
        else:
            image, input_ids, lm_label, label, input_mask = batch

        if args.no_image:
            output = model(input_ids=input_ids,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        else:
            output = model(input_ids=input_ids,
                           images=image,
                           #    attention_mask=input_mask,
                           labels=lm_label)
        loss, logits, _ = output

        loss = loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(
                amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        if not args.with_expl:
            lbl_accuracy = torch.eq(label, logits.argmax(
                dim=1)).float().sum() / len(label)
            return {
                'loss': loss.item(),
                'lbl_accuracy': lbl_accuracy.item()
            }
        else:
            if engine.state.iteration % (args.gradient_accumulation_steps * 500) == 0:
                input_output = list(zip(input_ids, logits))
                random_item = random.choice(input_output)
                in_sent = tokenizer.decode(list(filter(
                    lambda x: x != tokenizer.eos_token_id,
                    random_item[0])))
                out_expl = tokenizer.decode(random_item[1].argmax(dim=1),
                                            skip_special_tokens=True)
                logger.info(f'MODEL INPUT: {in_sent}')
                logger.info(f'GEN. EXPL {out_expl}')
                logger.info('--------------------------------')
            return {
                'loss': loss.item(),
            }

    '''Validation function and validator (validator output is the input of the metrics)'''
    def validation(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device)
                          for input_tensor in batch)
            if args.no_image:
                input_ids, lm_label, label, input_mask = batch
            else:
                image, input_ids, lm_label, label, input_mask = batch

            if args.no_image:
                output = model(input_ids=input_ids,
                               #    attention_mask=input_mask
                               )
            else:
                output = model(input_ids=input_ids,
                               images=image,
                               #    attention_mask=input_mask
                               )
            logits, _ = output

            logits_shifted = logits[..., :-1, :].contiguous().view(-1,
                                                                   logits.size(-1))
            labels_shifted = lm_label[..., 1:].contiguous().view(-1)
            return logits_shifted, labels_shifted

    '''Engines'''
    trainer = Engine(train)
    validator = Engine(validation)

    # t_total = len(
    #     train_loader) // args.gradient_accumulation_steps * args.n_epochs
    # scheduler = get_linear_schedule_with_warmup(
    #     optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    '''Linearly decrease the learning rate from lr to zero'''
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    '''
    Attach validation to trainer: we evaluate when we start the training and at the end of each epoch
    '''
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: validator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: validator.run(val_loader))

    '''Prepare metrics - note how we compute distributed metrics'''
    RunningAverage(output_transform=lambda x: x['loss']).attach(
        trainer, "loss")
    RunningAverage(output_transform=lambda x: math.exp(
        average_distributed_scalar(x['loss'], args))).attach(trainer, "ppl")
    if not args.with_expl:
        RunningAverage(output_transform=lambda x: 100 * x['lbl_accuracy']).attach(
            trainer, "lbl_accuracy")

    metrics = {}
    metrics["lbl_loss"] = Loss(torch.nn.CrossEntropyLoss(),
                               output_transform=lambda x: (x[0], x[1]))
    metrics["loss"] = MetricsLambda(
        lambda l, a: average_distributed_scalar(
            l / a.gradient_accumulation_steps, a), metrics["lbl_loss"], args)
    metrics["ppl"] = MetricsLambda(math.exp, metrics["loss"])
    if not args.with_expl:
        metrics["lbl_accuracy"] = 100 * \
            Accuracy(output_transform=lambda x: (x[0], x[1]))
    for name, metric in metrics.items():
        metric.attach(validator, name)

    '''
    On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    '''
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer,
                    metric_names=["loss", 'ppl'] if args.with_expl else ["loss", 'lbl_accuracy', 'ppl'])
        validator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message(
                                        "Validation: %s" % pformat(validator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=output_dir)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["ppl"] if args.with_expl else ["lbl_accuracy", "ppl"]),
                         event_name=Events.EPOCH_COMPLETED)

        tb_logger.attach(validator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=[
                                 'ppl', 'loss'] if args.with_expl else ['ppl', 'loss', 'lbl_accuracy'],
                             global_step_transform=lambda *args, **kwargs: trainer.state.iteration),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(output_dir,
                                             'checkpoint',
                                             n_saved=8,
                                             require_empty=False)
        trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1),
                                  checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})

        # "getattr" take care of distributed encapsulation
        torch.save(args, os.path.join(output_dir, 'model_training_args.bin'))
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(output_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(output_dir)

    '''Run the training'''
    trainer.run(train_loader, max_epochs=args.n_epochs)
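The validation functions in these examples score next-token prediction by shifting logits and labels against each other. A small toy illustration (not part of the original example) of why the shift is needed: the logits at position i are compared with the token at position i+1.

import torch

logits = torch.randn(1, 5, 100)            # (batch, seq_len, vocab)
labels = torch.randint(0, 100, (1, 5))     # (batch, seq_len)

shifted_logits = logits[..., :-1, :].contiguous().view(-1, logits.size(-1))  # positions 0..3
shifted_labels = labels[..., 1:].contiguous().view(-1)                       # tokens    1..4
loss = torch.nn.functional.cross_entropy(shifted_logits, shifted_labels)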
Beispiel #24
0
    tag="training",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
# Logging epoch validation metrics
tb_logger.attach_output_handler(
    engine=evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["loss", "accuracy", "precision", "recall", "f1", "topKCatAcc"],
    global_step_transform=global_step_from_engine(trainer),
)
# Attach the logger to the trainer to log model's weights as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=WeightsHistHandler(model)
)
# Attach the logger to the trainer to log model's gradients as a histogram after each epoch
tb_logger.attach(
    trainer,
    event_name=Events.EPOCH_COMPLETED,
    log_handler=GradsHistHandler(model)
)
print('Tensorboard Logging...', end='')
print('done')

## SETUP CALLBACKS
print('[INFO] Creating callback functions for training loop...', end='')
# Early Stopping - stops training if the validation loss does not decrease after 5 epochs
handler = EarlyStopping(patience=early_stopping_patience, score_function=score_function_loss, trainer=trainer)
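The EarlyStopping handler above is built from score_function_loss and early_stopping_patience, neither of which appears in the excerpt. A minimal sketch of what such a score function typically looks like; the names below are assumptions, not code from the original. Since ignite's EarlyStopping stops when the score stops increasing, a loss has to be negated.

early_stopping_patience = 5  # "5 epochs" as stated in the comment above

def score_function_loss(engine):
    # Higher is better for EarlyStopping, so return the negated validation loss.
    return -engine.state.metrics["loss"]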
Beispiel #25
0
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(), output_transform=lambda x: (x[0][0], x[1][0]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args, args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Beispiel #26
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path",
                        type=str,
                        default="",
                        help="Path or url of the dataset.")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=64,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=64,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-4,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=15,
                        help="Number of training epochs")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--gpt2_model_name",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=4000)

    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true')
    parser.add_argument('-proj_share_weight', action='store_true')
    parser.add_argument('-label_smoothing', action='store_true')
    args = parser.parse_args()
    args.d_word_vec = args.d_model

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")

    tokenizer_class = GPT2Tokenizer if "gpt2" in args.gpt2_model_name else OpenAIGPTTokenizer  # cant use Autotokenizer because checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.gpt2_model_name)

    num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(
        ATTR_TO_SPECIAL_TOKEN)  # doesn't add if they are already there

    model = Transformer(
        num_tokens + num_added_tokens,
        num_tokens + num_added_tokens,
        src_pad_idx=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]),
        trg_pad_idx=tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-1]),
        trg_emb_prj_weight_sharing=args.proj_share_weight,
        emb_src_trg_weight_sharing=args.embs_share_weight,
        d_k=args.d_k,
        d_v=args.d_v,
        d_model=args.d_model,
        d_word_vec=args.d_word_vec,
        d_inner=args.d_inner_hid,
        n_layers=args.n_layers,
        n_head=args.n_head,
        dropout=args.dropout).to(args.device)

    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        source_ids, target_ids, lm_labels = batch

        (lm_loss), *_ = model(source_ids, target_ids, labels=lm_labels)

        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            source_ids, target_ids, lm_labels = batch
            #logger.info(tokenizer.decode(target_ids[0].tolist()))

            lm_logits, *_ = model(source_ids, target_ids)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.gpt2_model_name, args.dataset_path)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=4)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        #getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
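Several of these examples write their checkpoints and TensorBoard logs under a directory returned by make_logdir, which is never shown. A minimal sketch of such a helper; the exact signature varies between the excerpts (some pass the dataset path or other tags as extra arguments), so the variadic form below is an assumption.

import os
import socket
from datetime import datetime

def make_logdir(model_name, *extra_tags):
    """Sketch: build a unique run directory such as runs/Sep22_19-45-59_host_gpt2."""
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    parts = [current_time, socket.gethostname(), str(model_name)] + [str(t) for t in extra_tags if t]
    return os.path.join('runs', '_'.join(parts))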
Beispiel #27
0
def train():
    config_file = "configs/train_daily_dialog_emotion_action_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" take care of distributed encapsulation

        torch.save(config,
                   tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
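The example above drives everything from Config.from_json_file instead of argparse, but the Config class is not included in the excerpt. A minimal sketch of a JSON-backed config object with attribute access, which is all the code above relies on:

import json

class Config:
    """Sketch: load a JSON file and expose its keys as attributes (config.lr, config.local_rank, ...)."""
    def __init__(self, **entries):
        self.__dict__.update(entries)

    @classmethod
    def from_json_file(cls, path):
        with open(path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))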
Beispiel #28
0
def run_training(model, optimizer, scheduler, output_path,
                 train_loader, val_loader, epochs, patience,
                 epochs_pretrain, mixed_precision, classes_weights):

    # trainer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if classes_weights is not None:
        classes_weights = classes_weights.to(device)
    crit = nn.CrossEntropyLoss(weight=classes_weights)
    metrics = {"accuracy": Accuracy(), "loss": Loss(crit)}
    trainer = create_supervised_trainer_with_pretraining(
        model, optimizer, crit, device=device, epochs_pretrain=epochs_pretrain,
        mixed_precision=mixed_precision)
    train_evaluator = create_supervised_evaluator(
        model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(
        model, metrics=metrics, device=device)

    # Out paths
    path_ckpt = os.path.join(output_path, "model_ckpt")
    log_dir = os.path.join(output_path, "log_dir")
    os.makedirs(log_dir, exist_ok=True)

    # tensorboard
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach(train_evaluator, log_handler=OutputHandler(tag="training", metric_names=[
        "accuracy", "loss"], another_engine=trainer), event_name=Events.EPOCH_COMPLETED)
    tb_logger.attach(val_evaluator, log_handler=OutputHandler(tag="validation", metric_names=[
        "accuracy", "loss"], another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

    # training progress
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names="all")

    # @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        val_evaluator.run(val_loader)
        train_loss = train_evaluator.state.metrics["loss"]
        val_loss = val_evaluator.state.metrics["loss"]
        train_acc = train_evaluator.state.metrics["accuracy"]
        val_acc = val_evaluator.state.metrics["accuracy"]
        pbar.log_message(
            "Training Results - Epoch: {}  Loss: {:.6f}  Accuracy: {:.6f}".format(engine.state.epoch, train_loss, train_acc))
        pbar.log_message(
            "Validation Results - Epoch: {}  Loss: {:.6f}  Accuracy: {:.6f}".format(engine.state.epoch, val_loss, val_acc))

        pbar.n = pbar.last_print_n = 0

    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_training_results)

    # def get_val_loss(engine):
    # 	return -engine.state.metrics['loss']
    def get_val_acc(engine):
        return engine.state.metrics['accuracy']

    # checkpoint and early stopping
    checkpointer = ModelCheckpoint(
        path_ckpt, "model", score_function=get_val_acc, score_name="accuracy", require_empty=False)
    early_stopper = EarlyStopping(patience, get_val_acc, trainer)

    to_save = {'optimizer': optimizer, 'model': model}
    if scheduler is not None:
        to_save["scheduler"] = scheduler
    val_evaluator.add_event_handler(Events.COMPLETED, checkpointer, to_save)
    val_evaluator.add_event_handler(Events.COMPLETED, early_stopper)
    if scheduler is not None:
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # free resources
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, lambda _: _empty_cache())
    train_evaluator.add_event_handler(
        Events.ITERATION_COMPLETED, lambda _: _empty_cache())
    val_evaluator.add_event_handler(
        Events.ITERATION_COMPLETED, lambda _: _empty_cache())

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()

    # Evaluation with best model
    model.load_state_dict(torch.load(
        glob.glob(os.path.join(path_ckpt, "*.pth"))[0])["model"])
    train_evaluator = create_supervised_evaluator(
        model, metrics=metrics, device=device)
    val_evaluator = create_supervised_evaluator(
        model, metrics=metrics, device=device)

    train_evaluator.run(train_loader)
    val_evaluator.run(val_loader)

    _pretty_print("Evaluating best model")
    pbar.log_message(
        "Best model on training set - Loss: {:.6f}  Accuracy: {:.6f}"
        .format(train_evaluator.state.metrics["loss"], train_evaluator.state.metrics["accuracy"]))
    pbar.log_message(
        "Best model on validation set - Loss: {:.6f}  Accuracy: {:.6f}"
        .format(val_evaluator.state.metrics["loss"], val_evaluator.state.metrics["accuracy"]))

    return model, train_evaluator.state.metrics, val_evaluator.state.metrics
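run_training above calls _empty_cache after every iteration and _pretty_print before the final evaluation, but neither helper is shown in the excerpt. Minimal sketches under the obvious assumptions (free cached CUDA memory, print a banner line):

import torch

def _empty_cache():
    """Sketch: release cached CUDA memory between iterations; a no-op on CPU-only runs."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def _pretty_print(msg):
    """Sketch: banner-style log line used before the final evaluation."""
    print("=" * 40)
    print(msg)
    print("=" * 40)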
Beispiel #29
0
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true', help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="", help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics 
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
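The next example delegates special-token handling to add_special_tokens_(model, tokenizer), which is not part of the excerpt. A minimal sketch of such a helper, assuming it follows the pattern used elsewhere in these examples: add ATTR_TO_SPECIAL_TOKEN to the tokenizer, then resize the embeddings only if new tokens were actually added.

def add_special_tokens_(model, tokenizer):
    """Sketch: add special tokens to the tokenizer and resize model embeddings if needed."""
    orig_num_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(ATTR_TO_SPECIAL_TOKEN)  # returns 0 if already present
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_num_tokens + num_added_tokens)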
Beispiel #30
0
def train(dataset_path,
          dataset_cache='./dataset_cache',
          model_checkpoint='gpt2',
          num_candidates=2,
          max_history=2,
          train_batch_size=4,
          valid_batch_size=4,
          gradient_accumulation_steps=8,
          lr=6.25e-5,
          lm_coef=1.0,
          mc_coef=1.0,
          max_norm=1.0,
          n_epochs=3,
          personality_permutations=1,
          eval_before_start=False,
          device="cuda" if torch.cuda.is_available() else "cpu",
          fp16='',
          path_prefix='',
          log_dir='',
          local_rank=-1):
    args = {**locals()}

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", local_rank)
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    distributed = (local_rank != -1)
    args['distributed'] = distributed

    if distributed:
        torch.cuda.set_device(local_rank)
        device = torch.device("cuda", local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    # cant use Autotokenizer because checkpoint could be a Path
    tokenizer_class = GPT2Tokenizer if "gpt2" in model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(model_checkpoint)
    model.to(device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order matters: the distributed wrapper must come last)
    if fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=fp16)
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        output_device=local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        dataset_path, dataset_cache, num_candidates, personality_permutations,
        max_history, train_batch_size, valid_batch_size, distributed,
        tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        lm_loss, mc_loss, *_ = model(input_ids,
                                     token_type_ids=token_type_ids,
                                     mc_token_ids=mc_token_ids,
                                     mc_labels=mc_labels,
                                     lm_labels=lm_labels)
        loss = (lm_loss * lm_coef + mc_loss * mc_coef) / \
            gradient_accumulation_steps
        if fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        if engine.state.iteration % gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
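    # Illustrative arithmetic (not part of the original example): optimizer.step() only fires
    # every gradient_accumulation_steps iterations, so with the defaults above
    # (train_batch_size=4, gradient_accumulation_steps=8) the effective batch size per process
    # is 4 * 8 = 32 samples, multiplied again by the number of processes when distributed.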

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            # quick sanity check: decode and log the last candidate of the first item in the batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
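            # shift so that position t predicts token t+1: drop the last logit and the
            # first label before flattening them for CrossEntropyLoss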
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to the trainer: evaluate at the end of each epoch (and before training starts when eval_before_start is set)
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, lr), (n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
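    # Illustrative: PiecewiseLinear interpolates "lr" linearly between the two milestones,
    # so with the default lr=6.25e-5 the rate starts at 6.25e-5 at iteration 0, is roughly
    # 3.1e-5 halfway through, and reaches 0.0 at iteration n_epochs * len(train_loader).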

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], local_rank,
                      device),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"],
                      local_rank, device)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = log_dir if log_dir else make_logdir(model_checkpoint,
                                                      path=path_prefix)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if local_rank in [-1, 0] and n_epochs > 0:
        # TODO: PR in ignite to have better access to saved file paths (cleaner)
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(log_dir, WEIGHTS_NAME))
        tb_logger.close()
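A minimal single-process call of this function-style variant might look like the sketch below (illustrative only: the dataset path is hypothetical and every other keyword keeps the default shown in the signature above):

if __name__ == '__main__':
    # hypothetical local dataset path; small batch and short run just for illustration
    train(dataset_path='data/personachat_self_original.json',
          n_epochs=1,
          train_batch_size=2,
          gradient_accumulation_steps=4)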