def main(args, local_rank=0):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    logger.info(args)
    for name in vocabs:
        logger.info("vocab %s, size %d, coverage %.3f", name,
                    vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)
    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    logger.info("start building model")
    logger.info("building retriever")
    if args.add_retrieval_loss:
        # Load the retriever together with a response encoder so a matching
        # model can score (query, response) pairs for the retrieval loss.
        retriever, another_model = Retriever.from_pretrained(
            args.num_retriever_heads, vocabs, args.retriever, args.nprobe,
            args.topk, local_rank, load_response_encoder=True)
        matchingmodel = MatchingModel(retriever.model, another_model)
        matchingmodel = matchingmodel.to(device)
    else:
        retriever = Retriever.from_pretrained(args.num_retriever_heads,
                                              vocabs, args.retriever,
                                              args.nprobe, args.topk,
                                              local_rank)

    logger.info("building retriever + generator")
    model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                               args.embed_dim, args.ff_embed_dim,
                               args.num_heads, args.dropout, args.mem_dropout,
                               args.enc_layers, args.dec_layers,
                               args.mem_enc_layers, args.label_smoothing)
    model = model.to(device)

    model.eval()
    dev_data = DataLoader(vocabs, args.dev_data, args.dev_batch_size,
                          for_train=False)
    bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6,
                    max_time_step=10)
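
# `set_seed` is assumed to be the usual seed-everything helper; a minimal
# sketch (the actual implementation may also set cudnn determinism flags):
import random
import numpy as np
import torch

def set_seed(seed):
    # Seed all RNGs that can affect data order and weight initialization.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)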
def load_retriever(args, device, task_tokenizer, retriever_tokenizer,
                   finetuned_path=None, stored_index=None, train_use_idx=None):
    print(
        f"\nLoading retriever: {finetuned_path if finetuned_path is not None else args.retrieval_model}\n"
    )
    config = AutoConfig.from_pretrained(args.retrieval_model)
    config.__dict__.update(args.__dict__)
    model = Retriever.from_pretrained(args.retrieval_model,
                                      config=config,
                                      cache_dir=args.cache_dir,
                                      task_tokenizer=task_tokenizer,
                                      retriever_tokenizer=retriever_tokenizer,
                                      stored_index=stored_index,
                                      train_use_idx=train_use_idx)
    model.resize_token_embeddings(len(retriever_tokenizer))
    if args.reinitialize_retriever and finetuned_path is None:
        model.init_weights()
    if finetuned_path is not None:
        # Load the checkpoint onto CPU first; this prevents each rank from
        # transiently materializing the full state dict on GPU 0.
        model_state_dict = torch.load(
            finetuned_path, map_location=lambda storage, loc: storage)
        utils.rectify_mismatched_embeddings(model, model_state_dict,
                                            retriever_tokenizer)
        model.load_state_dict(model_state_dict)
    model = model.to(device)
    return model
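
# The map_location lambda above loads every tensor onto CPU, so a multi-GPU
# job never stages the whole checkpoint on GPU 0. A minimal self-contained
# equivalent (load_checkpoint_on_cpu and the path are illustrative):
import torch

def load_checkpoint_on_cpu(path):
    # Identical in effect to torch.load(path, map_location='cpu'); the model
    # is moved to its target device only after the state dict is loaded.
    return torch.load(path, map_location=lambda storage, loc: storage)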
if model_args.arch == 'vanilla':
    model = Generator(vocabs, model_args.embed_dim, model_args.ff_embed_dim,
                      model_args.num_heads, model_args.dropout,
                      model_args.enc_layers, model_args.dec_layers,
                      model_args.label_smoothing)
elif model_args.arch == 'mem':
    model = MemGenerator(vocabs, model_args.embed_dim,
                         model_args.ff_embed_dim, model_args.num_heads,
                         model_args.dropout, model_args.mem_dropout,
                         model_args.enc_layers, model_args.dec_layers,
                         model_args.mem_enc_layers,
                         model_args.label_smoothing,
                         model_args.use_mem_score)
elif model_args.arch == 'rg':
    retriever = Retriever.from_pretrained(
        model_args.num_retriever_heads, vocabs,
        args.index_path if args.index_path else model_args.retriever,
        model_args.nprobe, model_args.topk, args.device,
        use_response_encoder=(model_args.rebuild_every > 0))
    model = RetrieverGenerator(vocabs, retriever, model_args.share_encoder,
                               model_args.embed_dim, model_args.ff_embed_dim,
                               model_args.num_heads, model_args.dropout,
                               model_args.mem_dropout, model_args.enc_layers,
                               model_args.dec_layers,
                               model_args.mem_enc_layers,
                               model_args.label_smoothing)

if args.hot_index is not None:
    # Swap in a freshly built ("hot") index: free the old one before loading
    # the replacement so both never reside in memory at once.
    model.retriever.drop_index()
    torch.cuda.empty_cache()
    model.retriever.update_index(args.hot_index, model_args.nprobe)
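
# drop_index / update_index imply an ANN index that can be replaced at
# runtime. A minimal sketch of such a wrapper using FAISS, with assumed
# names (SwappableIndex is illustrative, not the repo's Retriever internals):
import faiss

class SwappableIndex:
    def __init__(self, index_path, nprobe):
        self.load(index_path, nprobe)

    def load(self, index_path, nprobe):
        self.index = faiss.read_index(index_path)
        self.index.nprobe = nprobe  # number of IVF cells probed per query

    def drop(self):
        self.index = None  # release the old index before loading a new one

    def search(self, queries, topk):
        # queries: float32 array of shape (n, d); returns (distances, ids).
        return self.index.search(queries, topk)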
def main(args, local_rank):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])

    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name,
                        vocabs[name].size, vocabs[name].coverage)

    set_seed(19940117)
    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)

    if args.arch == 'vanilla':
        model = Generator(vocabs, args.embed_dim, args.ff_embed_dim,
                          args.num_heads, args.dropout, args.enc_layers,
                          args.dec_layers, args.label_smoothing)
    elif args.arch == 'mem':
        model = MemGenerator(vocabs, args.embed_dim, args.ff_embed_dim,
                             args.num_heads, args.dropout, args.mem_dropout,
                             args.enc_layers, args.dec_layers,
                             args.mem_enc_layers, args.label_smoothing,
                             args.use_mem_score)
    elif args.arch == 'rg':
        logger.info("start building model")
        logger.info("building retriever")
        retriever = Retriever.from_pretrained(
            args.num_retriever_heads, vocabs, args.retriever, args.nprobe,
            args.topk, local_rank,
            use_response_encoder=(args.rebuild_every > 0))
        logger.info("building retriever + generator")
        model = RetrieverGenerator(vocabs, retriever, args.share_encoder,
                                   args.embed_dim, args.ff_embed_dim,
                                   args.num_heads, args.dropout,
                                   args.mem_dropout, args.enc_layers,
                                   args.dec_layers, args.mem_enc_layers,
                                   args.label_smoothing)

    # Initialize unconditionally so resuming from a checkpoint cannot leave
    # global_step undefined in the training loop below.
    global_step = 0
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt)['model'])

    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())
    model = model.to(device)

    # The retriever gets a 10x smaller learning rate than the generator.
    retriever_params = [
        v for k, v in model.named_parameters() if k.startswith('retriever.')
    ]
    other_params = [
        v for k, v in model.named_parameters()
        if not k.startswith('retriever.')
    ]
    optimizer = Adam([{
        'params': retriever_params,
        'lr': args.embed_dim**-0.5 * 0.1
    }, {
        'params': other_params,
        'lr': args.embed_dim**-0.5
    }],
                     betas=(0.9, 0.98),
                     eps=1e-9)
    # A sketch of this schedule helper appears after the training loop below.
    lr_schedule = get_inverse_sqrt_schedule_with_warmup(
        optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs, args.train_data,
                            args.per_gpu_train_batch_size, for_train=True,
                            rank=local_rank, num_replica=args.world_size)

    model.eval()
    #dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
    #bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=10)

    step, epoch = 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()
    best_dev_bleu = 0.
    while global_step <= args.total_train_steps:
        for batch in train_data:
            #step_start = time.time()
            batch = move_to_device(batch, device)
            if args.arch == 'rg':
                loss, acc = model(
                    batch,
                    update_mem_bias=(global_step > args.update_retriever_after))
            else:
                loss, acc = model(batch)
            tr_stat.update({
                'loss': loss.item() * batch['tgt_num_tokens'],
                'tokens': batch['tgt_num_tokens'],
                'acc': acc
            })
            tr_stat.step()
            loss.backward()
            #step_cost = time.time() - step_start
            #print ('step_cost', step_cost)

            step += 1
            # Gradient accumulation: only update parameters once every
            # args.gradient_accumulation_steps micro-batches
            # (x % k == -1 % k is equivalent to x % k == k - 1).
            if not (step % args.gradient_accumulation_steps ==
                    -1 % args.gradient_accumulation_steps):
                continue
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1

            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['tokens'],
                                tr_stat['acc'] / tr_stat['tokens'])
                    tr_stat = Statistics()
                if global_step % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    # Use a short decoding budget early in training to keep
                    # evaluation cheap before the model is worth decoding.
                    max_time_step = 256 if global_step > 2 * args.warmup_steps else 5
                    bleus = []
                    for cur_dev_data in args.dev_data:
                        dev_data = DataLoader(vocabs, cur_dev_data,
                                              args.dev_batch_size,
                                              for_train=False)
                        bleu = validate(device, model, dev_data, beam_size=5,
                                        alpha=0.6,
                                        max_time_step=max_time_step)
                        bleus.append(bleu)
                    bleu = sum(bleus) / len(bleus)
                    logger.info("epoch %d, step %d, dev bleu %.2f", epoch,
                                global_step, bleu)
                    if bleu > best_dev_bleu:
                        testbleus = []
                        for cur_test_data in args.test_data:
                            test_data = DataLoader(vocabs, cur_test_data,
                                                   args.dev_batch_size,
                                                   for_train=False)
                            testbleu = validate(device, model, test_data,
                                                beam_size=5, alpha=0.6,
                                                max_time_step=max_time_step)
                            testbleus.append(testbleu)
                        testbleu = sum(testbleus) / len(testbleus)
                        logger.info("epoch %d, step %d, test bleu %.2f",
                                    epoch, global_step, testbleu)
                        torch.save({
                            'args': args,
                            'model': model.state_dict()
                        }, '%s/best.pt' % (args.ckpt, ))
                        if not args.only_save_best:
                            torch.save(
                                {
                                    'args': args,
                                    'model': model.state_dict()
                                },
                                '%s/epoch%d_batch%d_devbleu%.2f_testbleu%.2f' %
                                (args.ckpt, epoch, global_step, bleu,
                                 testbleu))
                        best_dev_bleu = bleu
                    model.train()

            if args.rebuild_every > 0 and (
                    global_step % args.rebuild_every == -1 % args.rebuild_every):
                # Periodically rebuild the retrieval index with the current
                # retriever weights: rank 0 rebuilds, all ranks then reload.
                model.retriever.drop_index()
                torch.cuda.empty_cache()
                next_index_dir = '%s/batch%d' % (args.ckpt, global_step)
                if args.world_size == 1 or (dist.get_rank() == 0):
                    model.retriever.rebuild_index(next_index_dir)
                    if args.world_size > 1:
                        # Guarded so single-GPU runs without an initialized
                        # process group do not crash on the barrier.
                        dist.barrier()
                else:
                    dist.barrier()
                model.retriever.update_index(next_index_dir, args.nprobe)

            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank,
                global_step)
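
# The LR schedule referenced above is assumed to be linear warmup followed by
# inverse-square-root decay; a minimal sketch matching the call site
# get_inverse_sqrt_schedule_with_warmup(optimizer, warmup_steps, total_steps):
from torch.optim.lr_scheduler import LambdaLR

def get_inverse_sqrt_schedule_with_warmup(optimizer, num_warmup_steps,
                                          num_training_steps, last_epoch=-1):
    # num_training_steps is accepted to match the call site but is unused by
    # the inverse-sqrt decay itself.
    def lr_lambda(step):
        if step < num_warmup_steps:
            # Linear warmup from 0 to the base learning rate.
            return float(step) / float(max(1, num_warmup_steps))
        # Decay proportional to 1/sqrt(step), continuous at the warmup boundary.
        return (num_warmup_steps ** 0.5) * float(max(1, step)) ** -0.5

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)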
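
# `average_gradients` is assumed to be the standard all-reduce helper; a
# minimal sketch under that assumption:
import torch.distributed as dist

def average_gradients(model):
    # Sum gradients across all ranks, then divide by the world size so each
    # rank holds the mean gradient before optimizer.step().
    world_size = float(dist.get_world_size())
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
            p.grad.data /= world_size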