Example 1
def add_pytorch_transformers_vocab(vocab, tokenizer_name):
    """Add vocabulary from tokenizers in pytorch_transformers for use with pre-tokenized data.

    These tokenizers have a convert_tokens_to_ids method, but this doesn't do
    anything special, so we can just use the standard indexers.
    """
    do_lower_case = "uncased" in tokenizer_name

    if tokenizer_name.startswith("bert-"):
        tokenizer = BertTokenizer.from_pretrained(tokenizer_name,
                                                  do_lower_case=do_lower_case)
    elif tokenizer_name.startswith("roberta-"):
        tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name)
    elif tokenizer_name.startswith("xlnet-"):
        tokenizer = XLNetTokenizer.from_pretrained(tokenizer_name,
                                                   do_lower_case=do_lower_case)
    elif tokenizer_name.startswith("openai-gpt"):
        tokenizer = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
    elif tokenizer_name.startswith("gpt2"):
        tokenizer = GPT2Tokenizer.from_pretrained(tokenizer_name)
    elif tokenizer_name.startswith("transfo-xl-"):
        tokenizer = TransfoXLTokenizer.from_pretrained(tokenizer_name)
    elif tokenizer_name.startswith("xlm-"):
        tokenizer = XLMTokenizer.from_pretrained(tokenizer_name)

    if (tokenizer_name.startswith("openai-gpt")
            or tokenizer_name.startswith("gpt2")
            or tokenizer_name.startswith("transo-xl-")):
        tokenizer.add_special_tokens({
            "bos_token": "<start>",
            "sep_token": "<delim>",
            "cls_token": "<extract>"
        })
    # TODO: this is another place that can be simplified by the "model-before-preprocess" reorganization;
    # we can pass the tokenizer created in the model here, see issue <TBD>

    vocab_size = len(tokenizer)
    # do not use tokenizer.vocab_size, it does not include newly added tokens
    if tokenizer_name.startswith("roberta-"):
        if tokenizer.convert_ids_to_tokens(vocab_size - 1) is None:
            vocab_size -= 1
        else:
            log.info("Time to delete vocab_size-1 in preprocess.py !!!")
    # due to a quirk in huggingface's file, the last token of RobertaTokenizer is None, remove
    # this when they fix the problem

    ordered_vocab = tokenizer.convert_ids_to_tokens(range(vocab_size))
    log.info("Added pytorch_transformers vocab (%s): %d tokens",
             tokenizer_name, len(ordered_vocab))
    for word in ordered_vocab:
        vocab.add_token_to_namespace(
            word, input_module_tokenizer_name(tokenizer_name))
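A minimal usage sketch (an assumption, not part of the original: it presumes the AllenNLP-style Vocabulary and the input_module_tokenizer_name helper that the surrounding codebase imports):

from allennlp.data import Vocabulary

vocab = Vocabulary()
add_pytorch_transformers_vocab(vocab, "transfo-xl-wt103")
# tokens were added under the tokenizer-specific namespace used above
print(vocab.get_vocab_size(input_module_tokenizer_name("transfo-xl-wt103")))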
Example 2
    def __init__(self,
                 chunck_size=64,
                 max_length=35,
                 device=torch.device('cuda:0')):
        super(XLClient, self).__init__()
        self.chunck_size = chunck_size
        self.tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')
        self.max_length = max_length
        # load the model
        self.model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
        self.model.eval()
        self.device = device
        # move model to device
        self.model.to(self.device)
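Only the constructor is shown; a hypothetical encode helper (an assumption, not part of the original) would use the loaded tokenizer and model roughly like this:

    def encode(self, text):
        # hypothetical sketch: tokenize, truncate to max_length, run a forward pass on self.device
        ids = self.tokenizer.convert_tokens_to_ids(self.tokenizer.tokenize(text))
        ids = torch.tensor([ids[:self.max_length]], device=self.device)
        with torch.no_grad():
            hidden_states, mems = self.model(ids)
        return hidden_states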
Example 3
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    prefix_sampler = torch.load(args.prefix_file)
    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    model.cuda()

    sampler = SampleBatch(model, tokenizer, prefix_sampler)
    for _ in tqdm(range(args.num_samples)):
        print(sampler.simple_sample(pair=args.paired))
Example 4
def test_transformer_xl_embeddings():
    transfo_model = 'transfo-xl-wt103'
    tokenizer = TransfoXLTokenizer.from_pretrained(transfo_model)
    model = TransfoXLModel.from_pretrained(
        pretrained_model_name_or_path=transfo_model, output_hidden_states=True)
    model.to(flair.device)
    model.eval()
    s = 'Berlin and Munich have a lot of puppeteer to see .'
    with torch.no_grad():
        tokens = tokenizer.tokenize(s + '<eos>')
        print(tokens)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens_tensor = tokens_tensor.to(flair.device)
        hidden_states = model(tokens_tensor)[-1]
        first_layer = hidden_states[1][0]
    assert len(first_layer) == len(tokens)

    def embed_sentence(sentence: str,
                       layers: str = '1',
                       use_scalar_mix: bool = False) -> Sentence:
        embeddings = TransformerXLEmbeddings(
            pretrained_model_name_or_path=transfo_model,
            layers=layers,
            use_scalar_mix=use_scalar_mix)
        flair_sentence = Sentence(sentence)
        embeddings.embed(flair_sentence)
        return flair_sentence

    sentence = embed_sentence(sentence=s)
    first_token_embedding_ref = first_layer[0].tolist()
    first_token_embedding_actual = sentence.tokens[0].embedding.tolist()
    puppeteer_embedding_ref = first_layer[7].tolist()
    puppeteer_embedding_actual = sentence.tokens[7].embedding.tolist()
    assert first_token_embedding_ref == first_token_embedding_actual
    assert puppeteer_embedding_ref == puppeteer_embedding_actual
    sentence_mult_layers = embed_sentence(sentence='Munich', layers='1,2,3,4')
    ref_embedding_size = 4 * model.d_embed
    actual_embedding_size = len(sentence_mult_layers.tokens[0].embedding)
    assert ref_embedding_size == actual_embedding_size
    sentence_mult_layers_scalar_mix = embed_sentence(sentence='Berlin',
                                                     layers='1,2,3,4',
                                                     use_scalar_mix=True)
    ref_embedding_size = 1 * model.d_embed
    actual_embedding_size = len(
        sentence_mult_layers_scalar_mix.tokens[0].embedding)
    assert ref_embedding_size == actual_embedding_size
Example 5
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)
    sd = torch.load(args.cache_file)

    if args.model_type == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2_model)
        encode = tokenizer.encode
        decode = lambda x: tokenizer.decoder[x]
    else:
        tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
        encode = lambda x: [tokenizer.get_idx(x.lower().strip().split()[0])]
        decode = tokenizer.get_sym
    train_ds, _, _ = sd['splits']

    train_loader = tud.DataLoader(train_ds, batch_size=1, shuffle=True)
    token_ids_lst = []
    for batch in train_loader:
        _, sentences = batch
        token_ids_lst.extend(map(encode, sentences))
    sampler = PrefixSampler.from_token_ids(decode, token_ids_lst)
    torch.save(sampler, args.save)
Example 6
import torch
from pytorch_transformers import TransfoXLTokenizer, TransfoXLModel, TransfoXLLMHeadModel

# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows
import logging
logging.basicConfig(level=logging.INFO)

# Load pre-trained model tokenizer (vocabulary from wikitext 103)
tokenizer = TransfoXLTokenizer.from_pretrained('transfo-xl-wt103')

# Tokenized input
text_1 = "Who was Jim Henson ?"
text_2 = "Jim Henson was a puppeteer"
tokenized_text_1 = tokenizer.tokenize(text_1)
tokenized_text_2 = tokenizer.tokenize(text_2)

# Convert token to vocabulary indices
indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)

# Convert inputs to PyTorch tensors
tokens_tensor_1 = torch.tensor([indexed_tokens_1])
tokens_tensor_2 = torch.tensor([indexed_tokens_2])

# Load pre-trained model (weights)
model = TransfoXLModel.from_pretrained('transfo-xl-wt103')
model.eval()

# If you have a GPU, put everything on cuda
tokens_tensor_1 = tokens_tensor_1.to('cuda')
tokens_tensor_2 = tokens_tensor_2.to('cuda')
model.to('cuda')
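The snippet stops before the forward pass; a minimal continuation (a sketch, assuming the pytorch_transformers behaviour seen in the other examples, where TransfoXLModel returns the hidden states together with its memory cells):

# Predict hidden states for each segment
with torch.no_grad():
    hidden_states_1, mems_1 = model(tokens_tensor_1)
    # the memory from the first segment lets the second call attend to a longer context
    hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)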
Example 7
def test_transformer_xl_embeddings():
    transfo_model: str = "transfo-xl-wt103"

    tokenizer = TransfoXLTokenizer.from_pretrained(transfo_model)
    model = TransfoXLModel.from_pretrained(
        pretrained_model_name_or_path=transfo_model, output_hidden_states=True)
    model.to(flair.device)
    model.eval()

    s: str = "Berlin and Munich have a lot of puppeteer to see ."

    with torch.no_grad():
        tokens = tokenizer.tokenize(s + "<eos>")

        print(tokens)

        indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
        tokens_tensor = torch.tensor([indexed_tokens])
        tokens_tensor = tokens_tensor.to(flair.device)

        hidden_states = model(tokens_tensor)[-1]

        first_layer = hidden_states[1][0]

    assert len(first_layer) == len(tokens)

    #     0       1        2        3     4     5      6        7        8      9     10     11
    #
    # 'Berlin', 'and', 'Munich', 'have', 'a', 'lot', 'of', 'puppeteer', 'to', 'see', '.', '<eos>'
    #     |       |        |        |     |     |      |        |        |      |     |
    #  Berlin    and    Munich    have    a    lot    of    puppeteer    to    see    .
    #
    #     0       1        2        3     4     5      6        7        8      9     10

    def embed_sentence(sentence: str,
                       layers: str = "1",
                       use_scalar_mix: bool = False) -> Sentence:
        embeddings = TransformerXLEmbeddings(model=transfo_model,
                                             layers=layers,
                                             use_scalar_mix=use_scalar_mix)
        flair_sentence = Sentence(sentence)
        embeddings.embed(flair_sentence)

        return flair_sentence

    sentence = embed_sentence(sentence=s)

    first_token_embedding_ref = first_layer[0].tolist()
    first_token_embedding_actual = sentence.tokens[0].embedding.tolist()

    puppeteer_embedding_ref = first_layer[7].tolist()
    puppeteer_embedding_actual = sentence.tokens[7].embedding.tolist()

    assert first_token_embedding_ref == first_token_embedding_actual
    assert puppeteer_embedding_ref == puppeteer_embedding_actual

    # Check embedding dimension when using multiple layers
    sentence_mult_layers = embed_sentence(sentence="Munich", layers="1,2,3,4")

    ref_embedding_size = 4 * model.d_embed
    actual_embedding_size = len(sentence_mult_layers.tokens[0].embedding)

    assert ref_embedding_size == actual_embedding_size

    # Check embedding dimension when using multiple layers and scalar mix
    sentence_mult_layers_scalar_mix = embed_sentence(sentence="Berlin",
                                                     layers="1,2,3,4",
                                                     use_scalar_mix=True)

    ref_embedding_size = 1 * model.d_embed
    actual_embedding_size = len(
        sentence_mult_layers_scalar_mix.tokens[0].embedding)

    assert ref_embedding_size == actual_embedding_size
Example 8
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    if args.prefix_file: prefix_sampler = torch.load(args.prefix_file)
    if args.transfo:
        tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
        model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    elif args.bert:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        model = BertForMaskedLM.from_pretrained(args.bert_model)
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2_model)
        model = GPT2LMHeadModel.from_pretrained(args.gpt2_model)
        init_sos(model)
    if args.resume:
        model.load_state_dict(
            torch.load(args.resume, map_location=lambda s, l: s))
    if not args.simple_sample: model = nn.DataParallel(model)
    model.cuda()

    if args.bert:
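        # mask roughly 20% of the whitespace-separated tokens in each sentence (never a
        # tab-containing token) and let the masked LM rewrite them, printing the augmented text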
        text_batches = list(split(list(sys.stdin), 128))
        for text_batch in tqdm(text_batches, desc='Augmenting'):
            for _ in range(args.num_samples):
                mtext_batch = [
                    ' '.join('[MASK]' if (
                        random.random() < 0.2 and '\t' not in x) else x
                             for x in sent.split(' ')) for sent in text_batch
                ]
                print('\n'.join(
                    x.replace('[SEP]', '\t').strip() for x in augment_texts(
                        model, tokenizer, mtext_batch, max_len=args.msl)))
                sys.stdout.flush()
        return

    sample_batches = [
        SampleBatch(model, tokenizer, prefix_sampler)
        for _ in range(args.num_buffers)
    ]
    if args.simple_sample:
        for _ in tqdm(range(args.num_samples)):
            print(sample_batches[0].simple_sample(pair=args.paired,
                                                  transfo=args.transfo))
            sys.stdout.flush()
        return

    n_output = 0
    pbar = tqdm(total=args.num_samples, desc='Generating')
    while n_output < args.num_samples:
        try:
            sample_batch = random.choice(sample_batches)
            sample_batch.try_add_sample()
            fin_texts = sample_batch.step(pair=args.paired)
        except ValueError:
            sample_batch.try_add_sample()
            continue
        for fin_text in fin_texts:
            if n_output >= args.num_samples:
                return
            print(fin_text.replace(EOS_TOKEN, '').replace('<eos>', '\t'))
            sys.stdout.flush()
            pbar.update(1)
            n_output += 1
            if (n_output + 1) % args.balance_every == 0:
                pbar.set_postfix(dict(last_balance=n_output))
                SampleBatch.balance(sample_batches)
Example 9
def main():
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                        help='pretrained model name')
    parser.add_argument('--split', type=str, default='test',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=128,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=1600,
                        help='length of the retained previous hidden states (memory)')
    parser.add_argument('--clamp_len', type=int, default=1000,
                        help='max positional embedding index')
    parser.add_argument('--no_cuda', action='store_true',
                        help='Do not use CUDA even though CUDA is available')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset
    # You can also build the corpus yourself using TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies to prepare the Adaptive input and SoftMax
    # and tokenizing the dataset
    # The pre-processed corpus is a conversion (using the conversion script)
    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                ret = model(data, target, mems)
                loss, mems = ret
                loss = loss.mean()
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
                total_time, 1000 * total_time / (idx+1)))
        return total_loss / total_len

    # Run on test data.
    if args.split == 'all':
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == 'valid':
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == 'test':
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
        return log_str

    log_str = ''
    if valid_loss is not None:
        log_str += format_log(valid_loss, 'valid')
    if test_loss is not None:
        log_str += format_log(test_loss, 'test')

    logger.info('=' * 100)
    logger.info(log_str)
    logger.info('=' * 100)
Example 10
    def __init__(self, args):
        self.acc_NA = Accuracy()
        self.acc_not_NA = Accuracy()
        self.acc_total = Accuracy()
        self.data_path = 'prepro_data'

        self.use_gpu = True
        self.is_training = True
        self.max_length = 512 #+160
        self.pos_num = 2 * self.max_length
        self.entity_num = self.max_length
        self.relation_num = 97

        self.use_bag = False

        self.coref_size = 20
        self.entity_type_size = 20
        self.max_epoch = 20
        self.opt_method = 'Adam'
        self.optimizer = None

        self.checkpoint_dir = CHECKPOINT_DIR
        self.fig_result_dir = './fig_result'
        self.test_epoch = 5
        self.checkpoint_epoch = 50
        self.pretrain_model = None

        self.word_size = 100
        self.epoch_range = None
        self.cnn_drop_prob = 0.5  # for cnn
        self.keep_prob = 0.8  # for lstm

        self.period = 50
        self.period_test = 20

        self.batch_size = args.batch_size
        # self.test_batch_size = 40
        self.h_t_limit = 1800

        self.test_batch_size = self.batch_size
        self.test_relation_limit = 1800
        self.char_limit = 16
        self.sent_limit = 25
        # self.combined_sent_limit = 200
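        # dis2idx buckets pairwise token distances on a log2 scale (1, 2-3, 4-7, ..., 256+ map to 1..9)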
        self.dis2idx = np.zeros((self.max_length), dtype='int64')
        self.dis2idx[1] = 1
        self.dis2idx[2:] = 2
        self.dis2idx[4:] = 3
        self.dis2idx[8:] = 4
        self.dis2idx[16:] = 5
        self.dis2idx[32:] = 6
        self.dis2idx[64:] = 7
        self.dis2idx[128:] = 8
        self.dis2idx[256:] = 9
        self.dis_size = 20

        self.learning_rate = args.learning_rate
        self.train_prefix = args.train_prefix
        self.test_prefix = args.test_prefix


        self.use_bert = False
        if args.model_name in ["BERT","T-REX"]:
            self.use_bert = True

        if not os.path.exists("log"):
            os.mkdir("log")

        self.batch_keys = ['context_idxs', 'context_pos', 'context_ner',
                           'relation_label', 'ht_pair_pos', 'pos_idx',
                           'input_lengths', 'context_char_idxs']
        self.batch_keys_float = ['h_mapping', 't_mapping', 'relation_multi_label', 'relation_mask']

        if self.use_bert:
            self.batch_keys += ["context_starts",]
            self.batch_keys_float += ["context_masks"]
            if "text_encoder" not in args or args.text_encoder =="bert":
                self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            else:
                self.tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

            self.bert_word2id = {word: wid for wid, word in self.tokenizer.vocab.items()}
            self.bert_id2word= {wid: word for word,wid in self.bert_word2id.items()}
        self.epoch = 0

        self.save_name = args.save_name

        self.data_word_vec = np.load(os.path.join(self.data_path, 'vec.npy'))
        self.data_char_vec = np.load(os.path.join(self.data_path, 'char_vec.npy'))
        self.rel2id = json.load(open(os.path.join(self.data_path, 'rel2id.json')))
        self.id2rel = {v: k for k, v in self.rel2id.items()}
Example 11
def main():
    def evaluate(data_source, split_encode=False):
        model.eval()
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in data_source:
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(tokenizer, queries, sos_idx, split_encode=split_encode, 
                    condition_model=args.conditioned_model)
            except KeyError:
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()

            with torch.no_grad():
                output = model(queries[:, :-1])[0].permute(0, 2, 1)
            targets = queries[:, 1:]
            crit = criterion(output, targets)
            mask_tot = mask[:, 1:].sum()
            raw_loss = (crit * mask[:, 1:]).sum() / mask_tot
            loss = raw_loss

            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
            # print(total_loss / (math.log(2) * total_n))

        cur_loss = total_loss / total_n
        elapsed = time.time() - start_time
        word_ppl = math.exp(total_loss / total_words)
        dual_print('-' * 89)
        dual_print('| end of epoch {:3d} | lr {:05.5f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
            epoch, optimizer.param_groups[0]['lr'],
            elapsed * 1000 / args.log_interval, cur_loss, word_ppl, cur_loss / math.log(2)))
        dual_print('-' * 89)
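        # the accumulated cross-entropy is in nats; dividing by ln(2) converts it to bits per character (bpc)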
        return cur_loss / math.log(2)

    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)
    sd = torch.load(args.cache_file)

    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    if args.reset: model.apply(model.init_weights)
    sos_idx = None
    if not args.use_sos: sos_idx = None
    train_ds, dev_ds, test_ds = sd['splits']
    criterion = nn.CrossEntropyLoss(reduction='none')

    train_loader = tud.DataLoader(train_ds, batch_size=args.train_batch_size, shuffle=True, drop_last=args.drop_last)
    dev_loader = tud.DataLoader(dev_ds, batch_size=args.eval_batch_size, shuffle=False, drop_last=args.drop_last)
    test_loader = tud.DataLoader(test_ds, batch_size=args.eval_batch_size, shuffle=False, drop_last=args.drop_last)

    no_decay = ['bias']
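    # keep weight decay off the bias parameters, as in the usual transformer fine-tuning recipe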
    params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    num_train_optimization_steps = args.num_train_epochs * len(train_loader)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
                                     t_total=num_train_optimization_steps)

    if args.resume:
        model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    if args.test_eval:
        while True:
            query = input("> ")
            print(sample_query(model, tokenizer, query))

    model = nn.DataParallel(model).cuda()
    start_time = time.time()
    best_bpc = 1000000

    if not args.do_train:
        evaluate(test_loader, split_encode=False)
        return

    for epoch in range(args.num_train_epochs):
        epoch += 1
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in train_loader:
            model.train()
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(tokenizer, queries, sos_idx, split_encode=args.split_encode, 
                    condition_model=args.conditioned_model)
            except KeyError:
                dual_print('Skipped batch')
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()
            optimizer.zero_grad()

            output = model(queries[:, :-1])[0].permute(0, 2, 1)
            targets = queries[:, 1:]
            crit = criterion(output, targets)
            mask_tot = mask[:, 1:].sum()
            raw_loss = (crit * mask[:, 1:]).sum() / mask_tot

            loss = raw_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
            if batch_idx % args.log_interval == 0 and batch_idx > 0:
                cur_loss = total_loss / total_n
                word_ppl = math.exp(total_loss / total_words)
                total_words = 0
                elapsed = time.time() - start_time
                dual_print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                        'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                    epoch, batch_idx, len(train_loader), optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, cur_loss, word_ppl, cur_loss / math.log(2)))
                total_loss = 0
                total_n = 0
                start_time = time.time()
            batch_idx += 1
        bpc = evaluate(dev_loader)
        if bpc < best_bpc:
            best_bpc = bpc
            torch.save(model.module.state_dict(), args.save)
    evaluate(test_loader)