def generate_samples(args):
    """Use a pre-trained GPT-2 model to generate a set of samples from scratch."""
    # Set seed
    set_random_seeds(args.random_seed)

    # Select device
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Device: {}'.format(str(device)))

    # Load pre-trained network weights
    print('Loading pre-trained model...')
    config = GPT2Config.from_pretrained(args.gpt2_version)
    model = GPT2LMHeadModel(config)
    model.load_state_dict(torch.load(args.model_load_path, map_location=device))
    model = model.to(device)
    model.eval()

    # Create tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2_version)

    # Generate some samples
    print('Generating...')
    generated = generate_sequence(model,
                                  tokenizer,
                                  context=args.context,
                                  max_length=args.max_gen_len,
                                  num_samples=args.num_samples,
                                  top_k=args.sampling_top_k,
                                  device=device)
    print('Generated samples:')
    print(*generated, sep="\n---\n")
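The `generate_sequence` helper above is project-specific. As a rough sketch only, the same top-k sampling could be done with the stock `transformers` generation API; the argument names below simply mirror the ones above and are assumptions, not part of the original code.

# Sketch only: assumes `model` and `tokenizer` come from the surrounding script;
# `model.generate` is the standard transformers generation API.
import torch

def sample_with_top_k(model, tokenizer, context, max_length, num_samples, top_k, device):
    input_ids = tokenizer.encode(context, return_tensors='pt').to(device)
    with torch.no_grad():
        output_ids = model.generate(input_ids,
                                    do_sample=True,         # sample instead of greedy decoding
                                    top_k=top_k,            # restrict sampling to the top-k logits
                                    max_length=max_length,
                                    num_return_sequences=num_samples,
                                    pad_token_id=tokenizer.eos_token_id)
    return [tokenizer.decode(ids, skip_special_tokens=True) for ids in output_ids]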
Example #2
def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file,
                                       pytorch_dump_folder_path):
    """Convert an original TensorFlow GPT-2 checkpoint into a PyTorch weights file and config."""
    # Construct model
    if gpt2_config_file == "":
        config = GPT2Config()
    else:
        config = GPT2Config.from_json_file(gpt2_config_file)
    model = GPT2Model(config)

    # Load weights from numpy
    load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path)

    # Save pytorch-model
    pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME
    pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME
    print("Save PyTorch model to {}".format(pytorch_weights_dump_path))
    torch.save(model.state_dict(), pytorch_weights_dump_path)
    print("Save configuration file to {}".format(pytorch_config_dump_path))
    with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())
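A hypothetical invocation of the converter above; all three paths are placeholders, not values from the original source.

# Hypothetical usage sketch; adjust the placeholder paths to your own checkpoint.
convert_gpt2_checkpoint_to_pytorch(
    gpt2_checkpoint_path="./models/117M/model.ckpt",  # TF checkpoint prefix (placeholder)
    gpt2_config_file="",                              # empty -> fall back to default GPT2Config()
    pytorch_dump_folder_path="./gpt2_pytorch")        # output folder (placeholder)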
Example #3
    def __init__(self, model_name: str) -> None:
        super().__init__()
        config = GPT2Config.from_pretrained(model_name)
        self.input_dim = config.hidden_size
        self.output_dim = config.vocab_size
        # TODO(mattg): It's possible that we could use some kind of cache like we have in
        # allennlp.modules.token_embedders.bert_token_embedder.PretrainedBertModel.  That way, we
        # would only load the GPT2 weights once.  Though, it's not clear how to do that here, as we
        # need to load `GPT2LMHeadModel`, not just `GPT2Model`...
        gpt2_model = GPT2LMHeadModel.from_pretrained(model_name)
        self.gpt2_lm_head = gpt2_model.lm_head
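Keeping only `lm_head` works here because, in the stock `GPT2LMHeadModel`, the LM head is weight-tied to the input embedding matrix. A quick check of that assumption:

# Sanity check (assumes the standard pretrained 'gpt2' weights with tied embeddings).
from transformers import GPT2LMHeadModel

gpt2_model = GPT2LMHeadModel.from_pretrained('gpt2')
tied = (gpt2_model.lm_head.weight.data_ptr()
        == gpt2_model.transformer.wte.weight.data_ptr())
print('lm_head tied to wte:', tied)  # expected: True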
Example #4
    def __init__(self, model_path):
        super(OnmtGPT2Encoder, self).__init__()
        config = GPT2Config.from_json_file(
            os.path.join(model_path, "config.json"))
        pretrained_dict = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(pretrained_dict):
            model = GPT2Model.from_pretrained(
                pretrained_model_name_or_path=pretrained_dict, config=config)
            print("init GPT2 model with {} weights".format(
                len(model.state_dict())))
        else:
            model = GPT2Model(config)

        model.wte = expandEmbeddingByN(model.wte, 4)
        self.encoder = model

        #print(model)
        print("***" * 20)
Example #5
def zero_shot_gpt2(args):
    print('Get model')
    config = GPT2Config.from_pretrained('gpt2')
    # Load the pre-trained weights: GPT2LMHeadModel(config) alone would start from
    # randomly initialised parameters, which defeats zero-shot evaluation.
    model = GPT2LMHeadModel.from_pretrained('gpt2', config=config)
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    print("Evaluating Maslow test set with GPT2")
    path = ".data/stories/story_commonsense/torchtext_class/maslow/"
    src_query = " That made them"
    trg_query = " feel"  # need to split due to offset in loss
    ma_t_pred, ma_t_true = \
        evaluate_zero_shot(args, model, tokenizer, path, src_query, trg_query)

    # Maslow results
    t_acc = accuracy_score(ma_t_true, ma_t_pred)
    t_f1 = f1_score(ma_t_true, ma_t_pred, average='macro')
    t_p = precision_score(ma_t_true, ma_t_pred, average='macro')
    t_r = recall_score(ma_t_true, ma_t_pred, average='macro')
    print('Maslow')
    print(
        f'\t Test | acc: {t_acc:7.4f} | f1: {t_f1:7.4f} | prec: {t_p:7.4f} | rec: {t_r:7.4f}'
    )

    print("Evaluating Reiss test set with GPT2")
    path = ".data/stories/story_commonsense/torchtext_class/reiss/"
    src_query = " They did this to"
    trg_query = " to"  # need to split due to offset in loss
    re_t_true, re_t_pred = \
        evaluate_zero_shot(args, model, tokenizer, path, src_query, trg_query)

    # Reiss results
    t_acc = accuracy_score(re_t_true, re_t_pred)
    t_f1 = f1_score(re_t_true, re_t_pred, average='macro')
    t_p = precision_score(re_t_true, re_t_pred, average='macro')
    t_r = recall_score(re_t_true, re_t_pred, average='macro')
    print('Reiss')
    print(
        f'\t Test | acc: {t_acc:7.4f} | f1: {t_f1:7.4f} | prec: {t_p:7.4f} | rec: {t_r:7.4f}'
    )
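`evaluate_zero_shot` is project-specific. A common way to do zero-shot label prediction with a language model, sketched here as an assumption about what it does, is to score each candidate label as a continuation of the prompt and pick the most likely one:

import torch

def score_label(model, tokenizer, prompt, label, device='cpu'):
    """Return the negated LM loss of `prompt + label`; higher means more likely."""
    ids = tokenizer.encode(prompt + label, return_tensors='pt').to(device)
    with torch.no_grad():
        loss = model(ids, labels=ids)[0]   # mean next-token cross-entropy
    return -loss.item()

# e.g. predicted = max(candidate_labels,
#                      key=lambda lab: score_label(model, tokenizer, story + src_query, ' ' + lab))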
Example #6
        def prepare_config_and_inputs(self):
            input_ids = ids_tensor([self.batch_size, self.seq_length],
                                   self.vocab_size)

            token_type_ids = None
            if self.use_token_type_ids:
                token_type_ids = ids_tensor([self.batch_size, self.seq_length],
                                            self.type_vocab_size)

            sequence_labels = None
            token_labels = None
            choice_labels = None
            if self.use_labels:
                sequence_labels = ids_tensor([self.batch_size],
                                             self.type_sequence_label_size)
                token_labels = ids_tensor([self.batch_size, self.seq_length],
                                          self.num_labels)
                choice_labels = ids_tensor([self.batch_size], self.num_choices)

            config = GPT2Config(
                vocab_size_or_config_json_file=self.vocab_size,
                n_embd=self.hidden_size,
                n_layer=self.num_hidden_layers,
                n_head=self.num_attention_heads,
                # intermediate_size=self.intermediate_size,
                # hidden_act=self.hidden_act,
                # hidden_dropout_prob=self.hidden_dropout_prob,
                # attention_probs_dropout_prob=self.attention_probs_dropout_prob,
                n_positions=self.max_position_embeddings,
                n_ctx=self.max_position_embeddings
                # type_vocab_size=self.type_vocab_size,
                # initializer_range=self.initializer_range
            )

            head_mask = ids_tensor(
                [self.num_hidden_layers, self.num_attention_heads], 2)

            return config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
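`ids_tensor` is the usual test helper for random integer id tensors; a minimal sketch of what it is assumed to do:

import torch

def ids_tensor(shape, vocab_size):
    """Hypothetical minimal version: uniform random token ids of the given shape."""
    return torch.randint(low=0, high=vocab_size, size=tuple(shape), dtype=torch.long)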
Example #7
    def __init__(self):
        super(GPT2Generator, self).__init__()

        # TODO: can i make the outputs below large and the knowledge medium?
        self.gpt2_config = GPT2Config.from_pretrained('gpt2-large')
        self.lh_model = GPT2LMHeadModel.from_pretrained('gpt2-large')
Example #8
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="pretrained_model.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_probe",
                        action='store_true',
                        help="Whether to probe the representation we got.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        '--data_dir',
        type=str,
        default=
        '/home/xiongyi/dataxyz/repos/SemSynLSTM/word_language_model/data/wikitext-2/'
    )
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument('--warmup_proportion', type=float, default=0.002)
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)

    timenow = datetime.datetime.now().strftime("%b%d%H%M")
    model_option = 'adv'
    outdir = model_option + timenow

    args = parser.parse_args(
        ['--output_dir', outdir, '--do_probe', '--num_train_epochs', '50'])
    #args = parser.parse_args(['--output_dir', './tmp', '--do_eval', '--model_name', 'gpt2'])
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval and not args.do_probe:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_probe` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Compute the max input length for the Transformer
    # Todo: Where is this used?
    input_length = 128
    data_dir = '../SemSynLSTM/word_language_model/data/wikitext-2/' if args.data_dir is None else args.data_dir
    train_set, val_set, test_set, dictionary, pos_dictionary = load_tokenize_and_batchify(
        data_dir, input_length)

    # Prepare inputs tensors and dataloaders

    train_data = TensorDataset(*train_set)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=32)

    eval_data = TensorDataset(*val_set)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=32)

    # TODO: Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    #special_tokens = ['_start_', '_delimiter_']
    #special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)

    # TODO: Add config
    config = GPT2Config(n_positions=input_length,
                        n_ctx=input_length,
                        n_layer=6,
                        n_head=8,
                        n_embd=384)
    config.vocab_size = len(dictionary)
    config.pos_vocab_size = len(pos_dictionary)
    if args.model_name:
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
    else:
        model = GPT2_adverse(config=config)
    model.to(device)

    # TODO: Load and encode the datasets

    logger.info("Encoding dataset...")

    # Prepare optimizer
    if args.do_train:
        all_param = list(model.named_parameters())
        param_optimizer = [(n, p) for n, p in all_param
                           if 'pos_head_adv' not in n]
        param_optimizer_adv = [(n, p) for n, p in all_param
                               if 'pos_head_adv' in n]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
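        # Exclude biases and LayerNorm weights from weight decay; the adversarial
        # POS head ('pos_head_adv') is trained by its own optimizer further below.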
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer_adv_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer_adv
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params': [
                p for n, p in param_optimizer_adv
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]
        num_train_optimization_steps = len(
            train_dataloader) * args.num_train_epochs
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            #max_grad_norm=args.max_grad_norm,
            weight_decay=args.weight_decay)
        #t_total=num_train_optimization_steps)
        optimizer_adv = AdamW(
            optimizer_adv_grouped_parameters,
            lr=args.learning_rate,
            #max_grad_norm=args.max_grad_norm,
            weight_decay=args.weight_decay)

    if args.do_train:
        train_results = {}
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            ###eval on eval set
            model.eval()
            nb_eval_steps, nb_eval_examples = 0, 0
            perp = 0
            average_loss = np.asanyarray([0, 0, 0, 0], dtype='float')
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_pos_ids = batch

                with torch.no_grad():
                    # A single forward pass returns (loss, loss_syn, loss_sem, loss_lm);
                    # no need to run the same batch through the model four times.
                    outputs = model(input_ids,
                                    labels=input_ids,
                                    pos_ids=input_pos_ids)
                    loss = outputs[0].detach().cpu().numpy()
                    loss_syn = outputs[1].detach().cpu().numpy()
                    loss_sem = outputs[2].detach().cpu().numpy()
                    loss_lm = outputs[3].detach().cpu().numpy()
                    perp_batch = np.exp(loss_lm)
                    perp += perp_batch
                    average_loss += np.asanyarray(
                        [loss, loss_syn, loss_sem, loss_lm])
                nb_eval_steps += 1
            perp /= nb_eval_steps
            average_loss /= nb_eval_steps
            print('loss,loss_syn,loss_sem,loss_lm', average_loss, 'perp ',
                  perp, 'epoch ', epoch)
            train_results[epoch] = (perp, average_loss)

            model.train()

            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_pos_ids = batch
                loss = model(input_ids,
                             labels=input_ids,
                             pos_ids=input_pos_ids)[0]
                loss_lm = model(input_ids,
                                labels=input_ids,
                                pos_ids=input_pos_ids)[3]
                loss_sem = model(input_ids,
                                 labels=input_ids,
                                 pos_ids=input_pos_ids)[2]
                #breakpoint()
                #loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                loss_sem.backward()
                optimizer_adv.step()
                optimizer_adv.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = loss.item(
                ) if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item(
                )
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} sem: {:.2e} lm: {:.2e}".format(
                    exp_average_loss, loss_sem.item(), loss_lm.item())
        print(train_results)
    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        #tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = GPT2LMHeadModel.from_pretrained(args.output_dir)
        #tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
        model.to(device)

    if args.do_eval:
        model.eval()
        nb_eval_steps, nb_eval_examples = 0, 0
        log_probs_sum = 0
        perp = 0
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_pos_ids = batch

            with torch.no_grad():
                loss = model(input_ids,
                             labels=input_ids)[0].detach().cpu().numpy()
                perp_batch = np.exp(loss)
                perp += perp_batch
            nb_eval_steps += 1

        perp /= nb_eval_steps
        # perp_word = perp / 128
        print(perp)
        result = {'eval_perp': perp}
        logger.info("***** Eval results *****")
        logger.info("'eval_perp' = %s", str(result['eval_perp']))

    if args.do_probe:

        ##load model (how???)
        model_path = '/home/xiongyi/dataxyz/repos/pytorch-pretrained-BERT/examples/advJul232307/pytorch_model.bin'
        model.load_state_dict(torch.load(model_path))
        ##Add an MLP on top of the representation

        probe_model = ProbeModel(model, config)
        probe_model.to(device)
        ##train and eval
        all_param = list(probe_model.named_parameters())
        param_probe = [(n, p) for n, p in all_param if 'probe_cls' in n]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params':
            [p for n, p in param_probe if not any(nd in n for nd in no_decay)],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_probe if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # max_grad_norm=args.max_grad_norm,
            weight_decay=args.weight_decay)
        # t_total=num_train_optimization_steps)
        train_results = {}
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            ###eval on eval set
            probe_model.eval()
            nb_eval_steps, nb_eval_examples = 0, 0
            average_loss = 0
            average_acc = 0
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_pos_ids = batch

                with torch.no_grad():
                    # A single forward pass yields both the probe loss and the POS logits;
                    # no need to run the same batch through the probe twice.
                    outputs = probe_model(input_ids,
                                          labels=input_ids,
                                          pos_ids=input_pos_ids)
                    loss = outputs[0].detach().cpu().numpy()
                    pos_logits = outputs[1].detach().cpu().numpy()
                    predicted_labels = np.argmax(pos_logits, -1)
                    correct_rate = np.mean(predicted_labels == input_pos_ids.
                                           detach().cpu().numpy()[:, 1:])
                    average_acc += correct_rate
                    average_loss += loss
                nb_eval_steps += 1
            average_loss /= nb_eval_steps
            ##TODO Hard CODED!
            average_acc /= nb_eval_steps
            print('loss', average_loss, ' acc_rate ', average_acc, ' epoch ',
                  epoch)
            train_results[epoch] = (average_loss, average_acc)

            probe_model.train()

            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_pos_ids = batch
                loss = probe_model(input_ids,
                                   labels=input_ids,
                                   pos_ids=input_pos_ids)[0]

                # breakpoint()
                # loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = loss.item(
                ) if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item(
                )
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e}".format(
                    exp_average_loss)
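`ProbeModel` above is project-specific. A heavily hedged sketch of a linear POS probe over frozen GPT-2 hidden states follows; the class and attribute names are assumptions, chosen only to match the `probe_cls` parameter filter and the `(loss, pos_logits)` outputs used above.

import torch.nn as nn

class LinearPOSProbe(nn.Module):
    """Hypothetical probe: predict POS tags from frozen GPT-2 hidden states."""

    def __init__(self, gpt2_model, config):
        super().__init__()
        self.gpt2 = gpt2_model
        for p in self.gpt2.parameters():
            p.requires_grad = False                      # keep the language model frozen
        self.probe_cls = nn.Linear(config.n_embd, config.pos_vocab_size)

    def forward(self, input_ids, labels=None, pos_ids=None):
        hidden = self.gpt2.transformer(input_ids)[0]     # (batch, seq, n_embd)
        pos_logits = self.probe_cls(hidden[:, :-1, :])   # align position t with tag t+1
        loss = None
        if pos_ids is not None:
            loss = nn.functional.cross_entropy(
                pos_logits.reshape(-1, pos_logits.size(-1)),
                pos_ids[:, 1:].reshape(-1))
        return loss, pos_logits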
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="pretrained_model.")
    parser.add_argument("--model_path", type=str, help="pretrained_model.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--context_length", type=int, help="Whether to run eval on the dev set with limited context length.", default=30)
    parser.add_argument("--shuffle_pos", type=int, help="Shuffle words starting at a certain relative position.", default=30)
    parser.add_argument("--do_local_shuffle", action='store_true', help="Whether to run eval on the dev set with shuffled word order.")
    parser.add_argument("--do_global_shuffle", action='store_true', help="Whether to run eval on the dev set with shuffled word order.")
    parser.add_argument("--word_order_context_length", type=int, help="Whether to run eval on the dev set with shuffled word order.",default=None)
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    parser.add_argument('--data_dir', type=str, default='/home/xiongyi/dataxyz/repos/SemSynLSTM/word_language_model/data/wikitext-2/')
    parser.add_argument('--tokenized', action='store_true', help="Whether we have tokenized data ready.")
    parser.add_argument('--load_finetuned', action='store_true', help="Whether to load a finetuned model.")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=2)
    parser.add_argument('--eval_batch_size', type=int, default=10000)
    parser.add_argument('--sequence_length', type=int, default=512)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument('--warmup_proportion', type=float, default=0.002)
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)

    args = parser.parse_args(['--output_dir', './fine_tuned_model','--do_eval', '--num_train_epochs', '1',\
                              '--model_name', 'gpt2', '--tokenized','--load_finetuned', '--context_length',\
                              '300','--shuffle_pos','200', '--do_local_shuffle'])
    #args = parser.parse_args()
    #args = parser.parse_args(['--output_dir', './tmp', '--do_eval', '--model_name', 'gpt2'])
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.warning("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    if args.load_finetuned:
        config = GPT2Config.from_pretrained('gpt2')
        model = GPT2LMHeadModel(config)
        model.load_state_dict(torch.load(os.path.join(args.output_dir, 'gpt2_1epoch.bin')))

        # tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
        model.to(device)
    elif args.model_name:
        model = GPT2LMHeadModel.from_pretrained(args.model_name)
        config = model.config

    wandb.watch(model)

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

    # Compute the max input length for the Transformer
    # Todo: Where is this used?
    sequence_length = max(config.n_ctx, args.sequence_length)
    if not args.tokenized:
        data_dir = '../SemSynLSTM/word_language_model/data/wikitext-2/' if args.data_dir is None else args.data_dir
        corpus = TaggedGPT2Corpus(data_dir, tokenizer=tokenizer)
        torch.save(corpus.train[0], 'train_id.pt')
        torch.save(corpus.train[1], 'train_pos.pt')
        torch.save(corpus.valid[0], 'val_id.pt')
        torch.save(corpus.valid[1], 'val_pos.pt')
        train_set, val_set, test_set, dictionary, pos_dictionary = load_tokenize_and_batchify(tokenizer, corpus,
                                                                                              data_dir, sequence_length)
    else:
        train_id = torch.load('/home/xiongyi/dataxyz/data/corpora/wikitext-2/train_id.pt')
        train_pos = torch.load('/home/xiongyi/dataxyz/data/corpora/wikitext-2/train_pos.pt')
        train_set = (train_id, train_pos)
        val_set = torch.load('/home/xiongyi/dataxyz/data/corpora/wikitext-2/val_id.pt')[100000:110000]
        n_batch = len(train_set[0]) // sequence_length
        input_ids = train_set[0][: n_batch * sequence_length].reshape(n_batch, sequence_length)
        pos_ids = train_set[1][: n_batch * sequence_length].reshape(n_batch, sequence_length)
        train_set = (input_ids, pos_ids)
    #breakpoint()
    model.to(device)
    # Prepare inputs tensors and dataloaders

    train_data = TensorDataset(*train_set)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    eval_data = TensorDataset(val_set)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=len(val_set))

    # TODO: Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    #special_tokens = ['_start_', '_delimiter_']
    #special_tokens_ids = list(tokenizer.convert_tokens_to_ids(token) for token in special_tokens)

    # TODO: Add config



    # TODO: Load and encode the datasets

    logger.warning("Encoding dataset...")
    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        num_train_optimization_steps = len(train_dataloader) * args.num_train_epochs
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          #max_grad_norm=args.max_grad_norm,
                          weight_decay=args.weight_decay)
                          #t_total=num_train_optimization_steps)

    if args.do_train:
        train_results = {}
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            ###eval on eval set
            # model.eval()
            # nb_eval_steps, nb_eval_examples = 0, 0
            # log_probs_sum = 0
            # perp = 0
            # average_loss = 0
            # for batch in tqdm(eval_dataloader, desc="Evaluating"):
            #     batch = tuple(t.to(device) for t in batch)
            #     input_ids, input_pos_ids = batch
            #
            #     with torch.no_grad():
            #         loss = model(input_ids, labels=input_ids)[0].detach().cpu().numpy()
            #         perp_batch = np.exp(loss)
            #         perp += perp_batch
            #         average_loss += loss
            #     nb_eval_steps += 1
            # perp /= nb_eval_steps
            # average_loss /= nb_eval_steps
            # print('loss', average_loss,'perp ', perp, 'epoch ', epoch)
            # train_results[epoch]= (perp, average_loss)

            model.train()

            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_pos_ids = batch
                loss = model(input_ids, labels=input_ids)[0]
                #breakpoint()
                #loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} ".format(exp_average_loss)

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, args.model_name+'_epoch_' + str(args.num_train_epochs))
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        #tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned

        print (train_results)

    if args.do_eval:
        model.eval()
        with torch.no_grad():
            nb_eval_steps, nb_eval_examples = 0, 0
            perp = 0
            loss = 0
            processed_tokens = 0
            for batch in tqdm(eval_dataloader, desc="Evaluating"):
                dat = batch[0]
                #breakpoint()
                perp_batch = 0
                for i,token in enumerate(tqdm(dat)):
                    if i < args.context_length:
                        continue
                    if processed_tokens % 500 == 0 and processed_tokens:
                        print ('perp ', np.exp(loss/processed_tokens), 'processed_tokens ', processed_tokens )
                        logger.warning("'perp' = %s, 'processed_tokens = %s'", str(np.exp(loss/processed_tokens)), str(processed_tokens) )
                        wandb.log({"Eval perp": str(np.exp(loss/processed_tokens)), "Processed tokens": str(processed_tokens)})
                    input_ids = dat[i-args.context_length:i].to(device).unsqueeze(0)
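                    # Local shuffle permutes a 20-token window ending `shuffle_pos` tokens
                    # before the prediction target; global shuffle permutes the first
                    # `shuffle_pos` tokens of the context.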
                    if args.do_local_shuffle:
                        copy = input_ids[0,-args.shuffle_pos-20 : -args.shuffle_pos]
                        rand_ids = torch.randperm(len(copy))
                        copy = copy[rand_ids]
                        #random.shuffle(copy)
                        #copy.reverse()
                        input_ids[0,-args.shuffle_pos-20 : -args.shuffle_pos] = copy
                    elif args.do_global_shuffle:
                        copy = input_ids[0,:args.shuffle_pos]
                        rand_ids = torch.randperm(len(copy))
                        copy = copy[rand_ids]
                        #random.shuffle(copy)
                        #copy.reverse()
                        input_ids[0,:args.shuffle_pos]= copy

                    logits = model(input_ids)[0][0,-1,:].detach().cpu().numpy()
                    #pred_id = np.argmax(logits)
                    #pred_token = tokenizer.convert_ids_to_tokens([pred_id])[0]
                    #print (input_sent + ' ' + pred_token)
                    logprob = logits[token.item()] - logsumexp(logits)
                    #perp_tok = np.exp(-logprob)
                    #print (tokenizer.convert_ids_to_tokens([token.item()]), 'perp_tok ', perp_tok)
                    loss += -logprob
                    processed_tokens += 1
                nb_eval_steps += 1
                print ('processed ', processed_tokens)
                loss /= processed_tokens
                perp = np.exp(loss)
                # perp_word = perp / 128
                print (perp)
            result = {'eval_perp': perp}
            logger.warning("***** Eval results *****")
            logger.warning("'eval_perp' = %s", str(result['eval_perp']))
Example #10
def main():
    set_seed(cfg_gpt)

    device = torch.device(
        "cuda" if cfg_gpt.use_gpu and torch.cuda.is_available() else 'cpu')
    cfg_gpt.device = device
    local_dir = f"JointSentGPT_32_VAD7_seed{cfg_gpt.seed}_lr{cfg_gpt.learning_rate}_len{cfg_gpt.max_sequence_length}_batch{cfg_gpt.train_batch_size}_epoch{cfg_gpt.num_train_epochs}" \
        f"_warmup{cfg_gpt.warmup_proportion}_{cfg_gpt.emotion_cls}_alpha{cfg_gpt.alpha}_beta{cfg_gpt.beta}_hidLoss{cfg_gpt.dist_loss}_hinge{cfg_gpt.hinge_phi if cfg_gpt.use_hinge else 0}"
    cfg_gpt.save_dir = os.path.join(cfg_gpt.save_dir, local_dir)
    if not os.path.exists(cfg_gpt.save_dir):
        os.mkdir(cfg_gpt.save_dir)
    print("save_dir: ", cfg_gpt.save_dir)

    # prepare path for evaluation
    cfg_gpt.ref_file = os.path.join(cfg_gpt.data_dir, f'ref.txt')

    # prepare config
    config = GPT2Config.from_pretrained(cfg_gpt.model_path)
    emo_num = 7 if cfg_gpt.emotion_cls == 'coarse' else 32
    setattr(config, 'emotion_size', emo_num + 1)
    setattr(config, 'alpha', cfg_gpt.alpha)
    setattr(config, 'temperature', cfg_gpt.temperature)
    setattr(config, 'leak_emotion_step', cfg_gpt.leak_emotion_step)
    setattr(config, 'emotion_label_num', emo_num)
    print(config)

    # build tokenizer
    print("load tokenizer ...")
    tokenizer = GPT2Tokenizer.from_pretrained(cfg_gpt.model_path)
    tokenizer.add_special_tokens(OrderedDict(cfg_gpt.SPECIAL_tokens))
    cfg_gpt.special_id_list = get_special_token_ids(cfg_gpt, tokenizer)

    # build model
    print("load model ...")
    model = JointSentiGPT2Model.from_pretrained(cfg_gpt.model_path,
                                                config=config)
    # reshape vocab size
    new_vocab_size = len(tokenizer)
    model.resize_token_embeddings(new_vocab_size)
    model.to(cfg_gpt.device)

    if cfg_gpt.do_train:
        train_loader = load_cache_examples(cfg_gpt, 'train')
        dev_loader = load_cache_examples(cfg_gpt, 'valid')
        if cfg_gpt.parallel:
            print("use parallel")
            model_parallel = DataParallelModel(model)
            train(cfg_gpt, model_parallel, train_loader, dev_loader)
        else:
            train(cfg_gpt, model, train_loader, dev_loader)

    if cfg_gpt.do_eval:
        print("begin decoding ...")
        # load dialog context as test input
        test_file_path = os.path.join(
            cfg_gpt.cache_dir,
            f"cache_JointSentGPT_decode_Joint_latent_32_VAD7_test_len{cfg_gpt.max_sequence_length}.json"
        )
        with open(test_file_path, 'r', encoding='utf-8') as file_test:
            test_input_for_generate = json.load(file_test)

        # # load dialog context as train input
        # file_train = open(os.path.join(cfg_gpt.cache_dir, f"cache_decode_Joint_32_VAD7_train_len128_seed{cfg_gpt.seed}.json"), 'r', encoding='utf-8')
        # train_intput_for_generate = json.load(file_train)

        # # load dialog context as valid input
        # file_valid = open(os.path.join(cfg_gpt.cache_dir, f"cache_decode_Joint_32_VAD7_valid_len128_seed{cfg_gpt.seed}.json"), 'r', encoding='utf-8')
        # valid_intput_for_generate = json.load(file_valid)

        best_epoch = get_best_epoch(
            os.path.join(cfg_gpt.save_dir, "record.log"))
        cfg_gpt.best_epoch = best_epoch
        model.load_state_dict(
            torch.load(os.path.join(cfg_gpt.save_dir, f"epo{best_epoch}.pt")))

        # load cache test loader used for calculate perplexity
        test_dataloader = load_cache_examples(cfg_gpt, 'test')

        evaluation(cfg_gpt, model, tokenizer, test_input_for_generate,
                   test_dataloader, 'test')
Example #11
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument(
        "--task",
        type=str,
        default="dialogue",
        help="one of task from [dialogue, qa, mt, nlg, summarization]")
    parser.add_argument("--emb_only",
                        action='store_true',
                        help="fine tune only task embeddings")
    parser.add_argument("--linear_perturb",
                        action='store_true',
                        help="fine tune only task embeddings")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=1,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--perturbation_layers",
                        type=int,
                        default=0,
                        help="number of perturbation layers")
    parser.add_argument("--self_copy",
                        action='store_true',
                        help="add self copy ")
    parser.add_argument("--adapter_bottleneck",
                        type=int,
                        default=0,
                        help="adapter layer bottleneck")
    parser.add_argument("--random_init",
                        action='store_true',
                        help="don't use GPT-2 pre-trained model ")
    parser.add_argument("--distillation", action='store_true')
    parser.add_argument("--outputlayer_only", action='store_true')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # can't use AutoTokenizer because checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    if not args.random_init:
        model = model_class.from_pretrained(
            args.model_checkpoint,
            perturbation_layers=args.perturbation_layers,
            self_copy=args.self_copy,
            adapter_bottleneck=args.adapter_bottleneck)
    else:
        config = GPT2Config()
        model = model_class(config,
                            perturbation_layers=args.perturbation_layers,
                            self_copy=args.self_copy,
                            adapter_bottleneck=args.adapter_bottleneck)
    model.to(args.device)

    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)

    if args.adapter_bottleneck > 0:
        parameters_to_update = [
            p for n, p in model.named_parameters() if "adapter" in str(n)
        ] + [model.transformer.wte.weight]
        optimizer = AdamW(parameters_to_update, lr=args.lr, correct_bias=True)
    elif args.outputlayer_only:
        parameters_to_update = [model.transformer.wte.weight]
        optimizer = AdamW(parameters_to_update, lr=args.lr, correct_bias=True)
    else:
        optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update_emb(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            param_to_check = []
            for n, p in model.named_parameters():
                if (n != "transformer.wte.weight"):
                    param_to_check.append(p)
            a = list(param_to_check)[0].clone()
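            # Zero the gradient of the original 50257-row GPT-2 vocabulary and apply a
            # manual SGD step, so only the newly added special-token embeddings are
            # updated; `a`/`b` below sanity-check that no other parameter changed.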
            model.transformer.wte.weight.grad[:50257, :] = 0
            model.transformer.wte.weight.data.add_(
                model.transformer.wte.weight.grad.data, alpha=-args.lr)
            optimizer.zero_grad()
            param_to_check = []
            for n, p in model.named_parameters():
                if (n != "transformer.wte.weight"):
                    param_to_check.append(p)

            b = list(param_to_check)[0].clone()
            assert torch.equal(a.data, b.data)
        return loss.item()

    # Training function and trainer
    def update_linear_perturbation(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:

            model.transformer.wte.weight.grad[:50257, :] = 0
            # model.transformer.wte.weight.data.add_(-args.lr,model.transformer.wte.weight.grad.data)
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    # Training function and trainer
    def update_all(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, lm_labels, token_type_ids = batch
        (lm_loss), *_ = model(input_ids,
                              token_type_ids=token_type_ids,
                              labels=lm_labels,
                              self_copy=args.self_copy)
        loss = lm_loss / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    if args.emb_only:
        trainer = Engine(update_emb)
    elif (args.linear_perturb or args.adapter_bottleneck > 0):
        trainer = Engine(update_linear_perturbation)
    else:
        trainer = Engine(update_all)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, lm_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
            )
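            # Shift so that the logits at position t are scored against token t+1
            # (standard next-token language-modelling alignment).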
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint,
                              task=args.task,
                              lr=args.lr,
                              layer=args.perturbation_layers,
                              self_copy=args.self_copy,
                              n_epochs=args.n_epochs,
                              adapter=args.adapter_bottleneck,
                              random_init=args.random_init)
        if args.distillation:
            log_dir += "_distillation"
        if args.outputlayer_only:
            log_dir += "_outputlayer_only"
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()