def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path,
                                             transfo_xl_dataset_file):
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo)
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term)
        pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_FILES_NAMES[
            'pretrained_vocab_file']
        print("Save vocabulary to {}".format(pytorch_vocab_dump_path))
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

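        # Drop the vocabulary from the corpus dict so it is not duplicated in the dataset dump.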
        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop('vocab', None)
        pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME
        print("Save dataset to {}".format(pytorch_dataset_dump_path))
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)

        print("Converting Transformer XL checkpoint from {} with config at {}".
              format(tf_path, config_path))
        # Initialise PyTorch model
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
        print("Building PyTorch model from configuration: {}".format(
            str(config)))
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch-model
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path,
                                                 WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path,
                                                CONFIG_NAME)
        print("Save PyTorch model to {}".format(
            os.path.abspath(pytorch_weights_dump_path)))
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print("Save configuration file to {}".format(
            os.path.abspath(pytorch_config_dump_path)))
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
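
A minimal sketch of how the converter above could be invoked from the command line, assuming argparse; the flag names and the `__main__` guard below are illustrative additions, not part of the original example.

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--tf_checkpoint_path", default="", type=str,
                        help="Optional path to the TensorFlow checkpoint to convert.")
    parser.add_argument("--transfo_xl_config_file", default="", type=str,
                        help="Optional TransfoXLConfig JSON file; empty means the default config.")
    parser.add_argument("--pytorch_dump_folder_path", required=True, type=str,
                        help="Folder where the PyTorch weights and config are written.")
    parser.add_argument("--transfo_xl_dataset_file", default="", type=str,
                        help="Optional pickled corpus to convert alongside the model.")
    args = parser.parse_args()
    convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                             args.transfo_xl_config_file,
                                             args.pytorch_dump_folder_path,
                                             args.transfo_xl_dataset_file)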

def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
    # Test helper (lifted from a model-tester class, hence `self`): runs the LM head
    # twice, feeding the memories (`mems`) returned by the first pass into the second.
    model = TransfoXLLMHeadModel(config)
    model.eval()

    lm_logits_1, mems_1 = model(input_ids_1)
    loss_1, _, mems_1 = model(input_ids_1, labels=lm_labels)
    lm_logits_2, mems_2 = model(input_ids_2, mems=mems_1)
    loss_2, _, mems_2 = model(input_ids_2, labels=lm_labels, mems=mems_1)

    outputs = {
        "loss_1": loss_1,
        "mems_1": mems_1,
        "lm_logits_1": lm_logits_1,
        "loss_2": loss_2,
        "mems_2": mems_2,
        "lm_logits_2": lm_logits_2,
    }
    return outputs
Example #3
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    prefix_sampler = torch.load(args.prefix_file)
    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    model.cuda()

    sampler = SampleBatch(model, tokenizer, prefix_sampler)
    for _ in tqdm(range(args.num_samples)):
        print(sampler.simple_sample(pair=args.paired))
Example #4
def main():
    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)

    if args.prefix_file: prefix_sampler = torch.load(args.prefix_file)
    if args.transfo:
        tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model)
        model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model)
    elif args.bert:
        tokenizer = BertTokenizer.from_pretrained(args.bert_model)
        model = BertForMaskedLM.from_pretrained(args.bert_model)
    else:
        tokenizer = GPT2Tokenizer.from_pretrained(args.gpt2_model)
        model = GPT2LMHeadModel.from_pretrained(args.gpt2_model)
        init_sos(model)
    if args.resume:
        model.load_state_dict(
            torch.load(args.resume, map_location=lambda s, l: s))
    if not args.simple_sample: model = nn.DataParallel(model)
    model.cuda()

    if args.bert:
        text_batches = list(split(list(sys.stdin), 128))
        for text_batch in tqdm(text_batches, desc='Augmenting'):
            for _ in range(args.num_samples):
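                # Mask roughly 20% of tokens, skipping any token that contains a tab separator.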
                mtext_batch = [
                    ' '.join('[MASK]' if (
                        random.random() < 0.2 and '\t' not in x) else x
                             for x in sent.split(' ')) for sent in text_batch
                ]
                print('\n'.join(
                    x.replace('[SEP]', '\t').strip() for x in augment_texts(
                        model, tokenizer, mtext_batch, max_len=args.msl)))
                sys.stdout.flush()
        return

    sample_batches = [
        SampleBatch(model, tokenizer, prefix_sampler)
        for _ in range(args.num_buffers)
    ]
    if args.simple_sample:
        for _ in tqdm(range(args.num_samples)):
            print(sample_batches[0].simple_sample(pair=args.paired,
                                                  transfo=args.transfo))
            sys.stdout.flush()
        return

    n_output = 0
    pbar = tqdm(total=args.num_samples, desc='Generating')
    while n_output < args.num_samples:
        try:
            sample_batch = random.choice(sample_batches)
            sample_batch.try_add_sample()
            fin_texts = sample_batch.step(pair=args.paired)
        except ValueError:
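            # step() failed; seed the chosen batch with a fresh sample and retry on the next iteration.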
            sample_batch.try_add_sample()
            continue
        for fin_text in fin_texts:
            if n_output >= args.num_samples:
                return
            print(fin_text.replace(EOS_TOKEN, '').replace('<eos>', '\t'))
            sys.stdout.flush()
            pbar.update(1)
            n_output += 1
            if (n_output + 1) % args.balance_every == 0:
                pbar.set_postfix(dict(last_balance=n_output))
                SampleBatch.balance(sample_batches)
Example #5
def main():
    parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
    parser.add_argument('--model_name', type=str, default='transfo-xl-wt103',
                        help='pretrained model name')
    parser.add_argument('--split', type=str, default='test',
                        choices=['all', 'valid', 'test'],
                        help='which split to evaluate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size')
    parser.add_argument('--tgt_len', type=int, default=128,
                        help='number of tokens to predict')
    parser.add_argument('--ext_len', type=int, default=0,
                        help='length of the extended context')
    parser.add_argument('--mem_len', type=int, default=1600,
                        help='length of the retained previous heads')
    parser.add_argument('--clamp_len', type=int, default=1000,
                        help='max positional embedding index')
    parser.add_argument('--no_cuda', action='store_true',
                        help='Do not use CUDA even though CUDA is available')
    parser.add_argument('--work_dir', type=str, required=True,
                        help='path to the work_dir')
    parser.add_argument('--no_log', action='store_true',
                        help='do not log the eval result')
    parser.add_argument('--same_length', action='store_true',
                        help='set same length attention with masking')
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, 'extended context length must be non-negative'

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    logger.info("device: {}".format(device))

    # Load a pre-processed dataset
    # You can also build the corpus yourself using the TransfoXLCorpus methods
    # The pre-processing involves computing word frequencies (to prepare the adaptive input and softmax)
    # and tokenizing the dataset
    # The pre-processed corpus is a conversion (using the conversion script) of the original dataset
    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
        device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
        args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:
        model.clamp_len = args.clamp_len
    if args.same_length:
        model.same_length = True

    ###############################################################################
    # Evaluation code
    ###############################################################################
    def evaluate(eval_iter):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.
        start_time = time.time()
        with torch.no_grad():
            mems = None
            for idx, (data, target, seq_len) in enumerate(eval_iter):
                ret = model(data, target, mems)
                loss, mems = ret
                loss = loss.mean()
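                # Weight the mean per-token loss by the segment length so the final
                # average over the whole split is per token.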
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
                total_time, 1000 * total_time / (idx+1)))
        return total_loss / total_len

    # Run on test data.
    if args.split == 'all':
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == 'valid':
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == 'test':
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
            split, loss, math.exp(loss))
        return log_str
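    # For example, a loss of 3.18 nats/token corresponds to a perplexity of exp(3.18) ≈ 24.0.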

    log_str = ''
    if valid_loss is not None:
        log_str += format_log(valid_loss, 'valid')
    if test_loss is not None:
        log_str += format_log(test_loss, 'test')

    logger.info('=' * 100)
    logger.info(log_str)
    logger.info('=' * 100)
Example #6
def main():
    def evaluate(data_source, split_encode=False):
        model.eval()
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in data_source:
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(tokenizer, queries, sos_idx, split_encode=split_encode, 
                    condition_model=args.conditioned_model)
            except KeyError:
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()

            with torch.no_grad():
                output = model(queries[:, :-1])[0].permute(0, 2, 1)
            targets = queries[:, 1:]
            crit = criterion(output, targets)
            mask_tot = mask[:, 1:].sum()
            raw_loss = (crit * mask[:, 1:]).sum() / mask_tot
            loss = raw_loss

            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
            # print(total_loss / (math.log(2) * total_n))

        cur_loss = total_loss / total_n
        elapsed = time.time() - start_time
        word_ppl = math.exp(total_loss / total_words)
        dual_print('-' * 89)
        dual_print('| end of epoch {:3d} | lr {:05.5f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
            epoch, optimizer.param_groups[0]['lr'],
            elapsed * 1000 / args.log_interval, cur_loss, word_ppl, cur_loss / math.log(2)))
        dual_print('-' * 89)
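        # Dividing the per-character cross-entropy (in nats) by ln(2) converts it to bits per character.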
        return cur_loss / math.log(2)

    parser = argparse.ArgumentParser()
    add_dict_options(parser, ARGS)
    args = parser.parse_args()
    set_seed(args.seed)
    sd = torch.load(args.cache_file)

    tokenizer = TransfoXLTokenizer.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    model = TransfoXLLMHeadModel.from_pretrained(args.transfo_model, cache_dir='transfo-model')
    if args.reset: model.apply(model.init_weights)
    sos_idx = None
    if not args.use_sos: sos_idx = None
    train_ds, dev_ds, test_ds = sd['splits']
    criterion = nn.CrossEntropyLoss(reduction='none')

    train_loader = tud.DataLoader(train_ds, batch_size=args.train_batch_size, shuffle=True, drop_last=args.drop_last)
    dev_loader = tud.DataLoader(dev_ds, batch_size=args.eval_batch_size, shuffle=False, drop_last=args.drop_last)
    test_loader = tud.DataLoader(test_ds, batch_size=args.eval_batch_size, shuffle=False, drop_last=args.drop_last)

    no_decay = ['bias']
    params = list(model.named_parameters())
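    # Apply weight decay to all parameters except biases.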
    optimizer_grouped_parameters = [
        {'params': [p for n, p in params if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in params if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    num_train_optimization_steps = args.num_train_epochs * len(train_loader)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, correct_bias=False)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=int(args.warmup_proportion * num_train_optimization_steps),
                                     t_total=num_train_optimization_steps)

    if args.resume:
        model.load_state_dict(torch.load(args.resume, map_location=lambda s, l: s))
    if args.test_eval:
        while True:
            query = input("> ")
            print(sample_query(model, tokenizer, query))

    model = nn.DataParallel(model).cuda()
    start_time = time.time()
    best_bpc = 1000000

    if not args.do_train:
        evaluate(test_loader, split_encode=False)
        return

    for epoch in range(args.num_train_epochs):
        epoch += 1
        total_loss = 0
        total_words = 0
        total_n = 0
        batch_idx = 0
        for batch in train_loader:
            model.train()
            _, queries = batch
            try:
                queries, mask, total_chars, words = transfo_encode(tokenizer, queries, sos_idx, split_encode=args.split_encode, 
                    condition_model=args.conditioned_model)
            except KeyError:
                dual_print('Skipped batch')
                continue
            total_words += words
            mask = torch.Tensor(mask).cuda()
            queries = torch.LongTensor(queries).cuda()
            optimizer.zero_grad()

            output = model(queries[:, :-1])[0].permute(0, 2, 1)
            targets = queries[:, 1:]
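            # Per-token cross-entropy, weighted by the mask and averaged over the kept positions.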
            crit = criterion(output, targets)
            mask_tot = mask[:, 1:].sum()
            raw_loss = (crit * mask[:, 1:]).sum() / mask_tot

            loss = raw_loss
            loss.backward()
            optimizer.step()
            scheduler.step()

            total_loss += raw_loss.item() * mask_tot.item()
            total_n += total_chars
            if batch_idx % args.log_interval == 0 and batch_idx > 0:
                cur_loss = total_loss / total_n
                word_ppl = math.exp(total_loss / total_words)
                total_words = 0
                elapsed = time.time() - start_time
                dual_print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                        'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                    epoch, batch_idx, len(train_loader), optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, cur_loss, word_ppl, cur_loss / math.log(2)))
                total_loss = 0
                total_n = 0
                start_time = time.time()
            batch_idx += 1
        bpc = evaluate(dev_loader)
        if bpc < best_bpc:
            best_bpc = bpc
            torch.save(model.module.state_dict(), args.save)
    evaluate(test_loader)