Example #1
def main(args, local_rank=0):

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)  # assumed module-level in the full source; defined here so the excerpt is self-contained

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    logger.info(args)
    for name in vocabs:
        logger.info("vocab %s, size %d, coverage %.3f", name,
                    vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    logger.info("start building model")
    logger.info("building retriever")
    if args.add_retrieval_loss:
        retriever, another_model = Retriever.from_pretrained(
            args.num_retriever_heads,
            vocabs,
            args.retriever,
            args.nprobe,
            args.topk,
            local_rank,
            load_response_encoder=True)
        matchingmodel = MatchingModel(retriever.model, another_model)
        matchingmodel = matchingmodel.to(device)
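        # matchingmodel presumably scores (context, response) pairs for the
        # auxiliary retrieval loss that args.add_retrieval_loss enables; it is
        # built but not exercised anywhere in this excerpt.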
    else:
        retriever = Retriever.from_pretrained(args.num_retriever_heads, vocabs,
                                              args.retriever, args.nprobe,
                                              args.topk, local_rank)

    logger.info("building retriever + generator")
    model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                               args.embed_dim, args.ff_embed_dim,
                               args.num_heads, args.dropout, args.mem_dropout,
                               args.enc_layers, args.dec_layers,
                               args.mem_enc_layers, args.label_smoothing)

    model = model.to(device)

    model.eval()
    dev_data = DataLoader(vocabs,
                          args.dev_data,
                          args.dev_batch_size,
                          for_train=False)
    bleu = validate(device,
                    model,
                    dev_data,
                    beam_size=5,
                    alpha=0.6,
                    max_time_step=10)
Example #2
def load_retriever(args,
                   device,
                   task_tokenizer,
                   retriever_tokenizer,
                   finetuned_path=None,
                   stored_index=None,
                   train_use_idx=None):
    print(
        f"\nLoading retriever: {finetuned_path if finetuned_path is not None else args.retrieval_model}\n"
    )
    config = AutoConfig.from_pretrained(args.retrieval_model)
    config.__dict__.update(args.__dict__)  # fold the CLI args into the HF config so Retriever can read them as config fields
    model = Retriever.from_pretrained(args.retrieval_model,
                                      config=config,
                                      cache_dir=args.cache_dir,
                                      task_tokenizer=task_tokenizer,
                                      retriever_tokenizer=retriever_tokenizer,
                                      stored_index=stored_index,
                                      train_use_idx=train_use_idx)
    model.resize_token_embeddings(len(retriever_tokenizer))
    if args.reinitialize_retriever and finetuned_path is None:
        model.init_weights()
    if finetuned_path is not None:
        model_state_dict = torch.load(
            finetuned_path, map_location=lambda storage, loc: storage
        )  # map tensors to CPU on load so checkpoints do not pile up on one GPU when several ranks load at once
        utils.rectify_mismatched_embeddings(model, model_state_dict,
                                            retriever_tokenizer)
        model.load_state_dict(model_state_dict)
    model = model.to(device)
    return model
Example #3
    if model_args.arch == 'vanilla':
        model = Generator(vocabs, model_args.embed_dim,
                          model_args.ff_embed_dim, model_args.num_heads,
                          model_args.dropout, model_args.enc_layers,
                          model_args.dec_layers, model_args.label_smoothing)
    elif model_args.arch == 'mem':
        model = MemGenerator(vocabs, model_args.embed_dim,
                             model_args.ff_embed_dim, model_args.num_heads,
                             model_args.dropout, model_args.mem_dropout,
                             model_args.enc_layers, model_args.dec_layers,
                             model_args.mem_enc_layers,
                             model_args.label_smoothing,
                             model_args.use_mem_score)
    elif model_args.arch == 'rg':
        retriever = Retriever.from_pretrained(
            model_args.num_retriever_heads,
            vocabs,
            args.index_path if args.index_path else model_args.retriever,
            model_args.nprobe,
            model_args.topk,
            args.device,
            use_response_encoder=(model_args.rebuild_every > 0))
        model = RetrieverGenerator(
            vocabs, retriever, model_args.share_encoder, model_args.embed_dim,
            model_args.ff_embed_dim, model_args.num_heads, model_args.dropout,
            model_args.mem_dropout, model_args.enc_layers,
            model_args.dec_layers, model_args.mem_enc_layers,
            model_args.label_smoothing)

        if args.hot_index is not None:
            model.retriever.drop_index()
            torch.cuda.empty_cache()
            model.retriever.update_index(args.hot_index, model_args.nprobe)
Example #4
def main(args, local_rank):

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger = logging.getLogger(__name__)  # assumed module-level in the full source; defined here so the excerpt is self-contained

    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)

    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    if args.arch == 'vanilla':
        model = Generator(vocabs, args.embed_dim, args.ff_embed_dim,
                          args.num_heads, args.dropout, args.enc_layers,
                          args.dec_layers, args.label_smoothing)
    elif args.arch == 'mem':
        model = MemGenerator(vocabs, args.embed_dim, args.ff_embed_dim,
                             args.num_heads, args.dropout, args.mem_dropout,
                             args.enc_layers, args.dec_layers,
                             args.mem_enc_layers, args.label_smoothing,
                             args.use_mem_score)
    elif args.arch == 'rg':
        logger.info("start building model")
        logger.info("building retriever")
        retriever = Retriever.from_pretrained(
            args.num_retriever_heads,
            vocabs,
            args.retriever,
            args.nprobe,
            args.topk,
            local_rank,
            use_response_encoder=(args.rebuild_every > 0))

        logger.info("building retriever + generator")
        model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                                   args.embed_dim, args.ff_embed_dim,
                                   args.num_heads, args.dropout,
                                   args.mem_dropout, args.enc_layers,
                                   args.dec_layers, args.mem_enc_layers,
                                   args.label_smoothing)

    global_step = 0  # checkpoints store only args and weights (see the save below), so the step count restarts on resume
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt)['model'])

    if args.world_size > 1:
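        # offset the seed per rank so distributed workers draw different random streams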
        set_seed(19940117 + dist.get_rank())

    model = model.to(device)

    retriever_params = [
        v for k, v in model.named_parameters() if k.startswith('retriever.')
    ]
    other_params = [
        v for k, v in model.named_parameters()
        if not k.startswith('retriever.')
    ]

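    # Two parameter groups: the pretrained retriever gets a 10x smaller peak
    # learning rate than the freshly initialized generator. embed_dim ** -0.5
    # is the Transformer ("Noam") peak-lr convention, paired with the
    # inverse-sqrt warmup schedule below.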
    optimizer = Adam(
        [{'params': retriever_params, 'lr': args.embed_dim ** -0.5 * 0.1},
         {'params': other_params, 'lr': args.embed_dim ** -0.5}],
        betas=(0.9, 0.98),
        eps=1e-9)
    lr_schedule = get_inverse_sqrt_schedule_with_warmup(
        optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs,
                            args.train_data,
                            args.per_gpu_train_batch_size,
                            for_train=True,
                            rank=local_rank,
                            num_replica=args.world_size)

    model.eval()
    #dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
    #bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=10)

    step, epoch = 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()

    best_dev_bleu = 0.
    while global_step <= args.total_train_steps:
        for batch in train_data:
            #step_start = time.time()
            batch = move_to_device(batch, device)
            if args.arch == 'rg':
                loss, acc = model(
                    batch,
                    update_mem_bias=(global_step >
                                     args.update_retriever_after))
            else:
                loss, acc = model(batch)

            tr_stat.update({
                'loss': loss.item() * batch['tgt_num_tokens'],
                'tokens': batch['tgt_num_tokens'],
                'acc': acc
            })
            tr_stat.step()
            loss.backward()
            #step_cost = time.time() - step_start
            #print ('step_cost', step_cost)
            step += 1
            # run the optimizer only every gradient_accumulation_steps micro-batches
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue

            if args.world_size > 1:
                average_gradients(model)

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            if args.world_size == 1 or (dist.get_rank() == 0):
                if (global_step + 1) % args.print_every == 0:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['tokens'],
                                tr_stat['acc'] / tr_stat['tokens'])
                    tr_stat = Statistics()
                if (global_step + 1) % args.eval_every == 0:
                    model.eval()
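                    # decode only 5 steps until well past warmup, presumably to
                    # keep early (mostly untrained) evaluations cheap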
                    max_time_step = 256 if global_step > 2 * args.warmup_steps else 5
                    bleus = []
                    for cur_dev_data in args.dev_data:
                        dev_data = DataLoader(vocabs,
                                              cur_dev_data,
                                              args.dev_batch_size,
                                              for_train=False)
                        bleu = validate(device,
                                        model,
                                        dev_data,
                                        beam_size=5,
                                        alpha=0.6,
                                        max_time_step=max_time_step)
                        bleus.append(bleu)
                    bleu = sum(bleus) / len(bleus)
                    logger.info("epoch %d, step %d, dev bleu %.2f", epoch,
                                global_step, bleu)
                    if bleu > best_dev_bleu:
                        testbleus = []
                        for cur_test_data in args.test_data:
                            test_data = DataLoader(vocabs,
                                                   cur_test_data,
                                                   args.dev_batch_size,
                                                   for_train=False)
                            testbleu = validate(device,
                                                model,
                                                test_data,
                                                beam_size=5,
                                                alpha=0.6,
                                                max_time_step=max_time_step)
                            testbleus.append(testbleu)
                        testbleu = sum(testbleus) / len(testbleus)
                        logger.info("epoch %d, step %d, test bleu %.2f", epoch,
                                    global_step, testbleu)
                        torch.save({
                            'args': args,
                            'model': model.state_dict()
                        }, '%s/best.pt' % (args.ckpt, ))
                        if not args.only_save_best:
                            torch.save(
                                {
                                    'args': args,
                                    'model': model.state_dict()
                                },
                                '%s/epoch%d_batch%d_devbleu%.2f_testbleu%.2f' %
                                (args.ckpt, epoch, global_step, bleu,
                                 testbleu))
                        best_dev_bleu = bleu
                    model.train()

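            # Periodic index rebuild: as the response encoder trains, the stored
            # embeddings go stale. Rank 0 writes a fresh index to disk and every
            # rank reloads it (nprobe hints at an IVF/FAISS-style index, though
            # the excerpt does not confirm the backend).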
            if args.rebuild_every > 0 and (global_step + 1) % args.rebuild_every == 0:
                model.retriever.drop_index()
                torch.cuda.empty_cache()
                next_index_dir = '%s/batch%d' % (args.ckpt, global_step)
                if args.world_size == 1 or dist.get_rank() == 0:
                    model.retriever.rebuild_index(next_index_dir)
                if args.world_size > 1:
                    dist.barrier()  # all ranks wait until rank 0 has written the new index
                model.retriever.update_index(next_index_dir, args.nprobe)

            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finished training after %d steps', local_rank,
                global_step)
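
main(args, local_rank) calls dist.get_rank() whenever args.world_size > 1, so each process must join a process group before it runs. A minimal single-node launcher sketch under that assumption (the rendezvous address, port, and parse_args helper are illustrative, not taken from the excerpt):

import torch.distributed as dist
import torch.multiprocessing as mp

def run(local_rank, args):
    if args.world_size > 1:
        dist.init_process_group('nccl',
                                init_method='tcp://127.0.0.1:29500',  # illustrative rendezvous
                                world_size=args.world_size,
                                rank=local_rank)
    main(args, local_rank)

if __name__ == '__main__':
    args = parse_args()  # hypothetical helper building the namespace main() expects
    mp.spawn(run, args=(args,), nprocs=args.world_size)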