def prepare_incremental_input(self, step_seq):
    conc = ListsToTensor(step_seq, self.vocabs['concept'])
    conc_char = ListsofStringToTensor(step_seq, self.vocabs['concept_char'])
    conc, conc_char = move_to_device(conc, self.device), move_to_device(conc_char, self.device)
    return conc, conc_char
def train_loop(model, dl, batch_size: int, epoch, epochs, optimizer, verbose, max_iter_count, device):
    running_metrics = []
    print(f"running epoch [{epoch+1}/{epochs}]")
    model.train()
    for iter_count, (batch, targets) in enumerate(dl):
        batch, targets = move_to_device(batch, targets, device=device)
        optimizer.zero_grad()
        metrics = model.training_step(batch, targets)
        metrics['loss'].backward()
        optimizer.step()
        metrics['loss'] = metrics['loss'].item()
        running_metrics.append(metrics)
        if (iter_count + 1) % verbose == 0:
            means = caclulate_means(running_metrics)
            running_metrics = []
            log = []
            for k, v in means.items():
                log.append(f"{k}: {v:.04f}")
            log = "\t".join(log)
            log += f"\titer[{iter_count}/{max_iter_count}]"
            print(log)
def record_batch(model, batch, data):
    batch = move_to_device(batch, model.device)
    attn = model.encoder_attn(batch)  # nlayers x tgt_len x src_len x bsz x num_heads
    for i, x in enumerate(data):
        L = len(x['concept']) + 1
        x['attn'] = attn[:, :L, :L, i, :].cpu()
    return data
def generate_batch(model, batch, beam_size, alpha, max_time_step):
    batch = move_to_device(batch, model.device)
    res = dict()
    token_batch, score_batch = [], []
    beams = model.work(batch, beam_size, max_time_step)
    for beam in beams:
        best_hyp = beam.get_k_best(1, alpha)[0]
        predicted_token = [token for token in best_hyp.seq[1:-1]]
        token_batch.append(predicted_token)
        score_batch.append(best_hyp.score)
    res['token'] = token_batch
    res['score'] = score_batch
    return res
def work(self, inp, allow_hit):
    src_tokens = inp['src_tokens']
    src_feat, src, src_mask = self.model(src_tokens, return_src=True)
    num_heads, bsz, dim = src_feat.size()
    assert num_heads == self.num_heads
    topk = self.topk
    vecsq = src_feat.reshape(num_heads * bsz, -1).detach().cpu().numpy()
    #retrieval_start = time.time()
    vecsq = augment_query(vecsq)
    D, I = self.mips.search(vecsq, topk + 1)
    D = l2_to_ip(D, vecsq, self.mips_max_norm) / (self.mips_max_norm * self.mips_max_norm)
    # I, D: (bsz * num_heads) x (topk + 1)
    indices = torch.zeros(topk, num_heads, bsz, dtype=torch.long)
    for i, (Ii, Di) in enumerate(zip(I, D)):
        bid, hid = i % bsz, i // bsz
        tmp_list = []
        for pred, _ in zip(Ii, Di):
            if allow_hit or self.mem_pool[pred] != inp['tgt_raw_sents'][bid]:
                tmp_list.append(pred)
        tmp_list = tmp_list[:topk]
        assert len(tmp_list) == topk
        indices[:, hid, bid] = torch.tensor(tmp_list)
    #retrieval_cost = time.time() - retrieval_start
    #print('retrieval_cost', retrieval_cost)
    # convert to tensors:
    # all_mem_tokens -> seq_len x (topk * num_heads * bsz)
    # all_mem_feats -> topk * num_heads * bsz x dim
    all_mem_tokens = []
    for idx in indices.view(-1).tolist():
        # TODO: self.mem_pool[idx] + [EOS]
        all_mem_tokens.append([BOS] + self.mem_pool[idx])
    all_mem_tokens = ListsToTensor(all_mem_tokens, self.vocabs['tgt'])
    # to avoid GPU OOM, truncate the memory to at most 1.5 x the length of src_tokens
    max_mem_len = int(1.5 * src_tokens.shape[0])
    all_mem_tokens = move_to_device(all_mem_tokens[:max_mem_len, :], inp['src_tokens'].device)
    if torch.is_tensor(self.mem_feat_or_feat_maker):
        all_mem_feats = self.mem_feat_or_feat_maker[indices].to(src_feat.device)
    else:
        all_mem_feats = self.mem_feat_or_feat_maker(all_mem_tokens).view(topk, num_heads, bsz, dim)
    # all_mem_scores -> topk x num_heads x bsz
    all_mem_scores = torch.sum(src_feat.unsqueeze(0) * all_mem_feats, dim=-1) / (self.mips_max_norm ** 2)
    mem_ret = {}
    indices = indices.view(-1, bsz).transpose(0, 1).tolist()
    mem_ret['retrieval_raw_sents'] = [[self.mem_pool[idx] for idx in ind] for ind in indices]
    mem_ret['all_mem_tokens'] = all_mem_tokens
    mem_ret['all_mem_scores'] = all_mem_scores
    return src, src_mask, mem_ret
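# The block below is an illustrative sketch (not part of the project code) of the standard
# reduction from maximum inner product search (MIPS) to L2 nearest-neighbour search that
# helpers such as augment_query / l2_to_ip above are assumed to implement: every index vector
# x gets an extra coordinate sqrt(M^2 - ||x||^2), where M is the max norm over the index, and
# every query gets a 0, so the smallest augmented L2 distance corresponds to the largest inner
# product, and the squared distances returned by the index can be converted back to scores.
import numpy as np

def _augment_index_sketch(xs: np.ndarray, max_norm: float) -> np.ndarray:
    # append sqrt(M^2 - ||x||^2) so every augmented index vector has norm M
    extra = np.sqrt(np.maximum(max_norm ** 2 - (xs ** 2).sum(axis=1, keepdims=True), 0.0))
    return np.hstack([xs, extra]).astype(np.float32)

def _augment_query_sketch(qs: np.ndarray) -> np.ndarray:
    # append a zero column so query norms (and inner products) are unchanged
    return np.hstack([qs, np.zeros((qs.shape[0], 1), dtype=qs.dtype)]).astype(np.float32)

def _l2_to_ip_sketch(sq_l2_dists: np.ndarray, qs: np.ndarray, max_norm: float) -> np.ndarray:
    # ||q' - x'||^2 = M^2 + ||q||^2 - 2<q, x>  =>  <q, x> = (M^2 + ||q||^2 - ||q' - x'||^2) / 2
    q_norms = (qs ** 2).sum(axis=1, keepdims=True)
    return (max_norm ** 2 + q_norms - sq_l2_dists) / 2.0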
def get_features(batch_size, norm_th, vocab, model, used_data, used_ids, max_norm=None, max_norm_cf=1.0):
    vecs, ids = [], []
    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.eval()
    data_loader = DataLoader(used_data, vocab, batch_size)
    cur, tot = 0, len(used_data)
    for batch in asynchronous_load(data_loader):
        batch = move_to_device(batch, torch.device('cuda', 0)).t()
        bsz = batch.size(0)
        cur_vecs = model(batch, batch_first=True).detach().cpu().numpy()
        valid = np.linalg.norm(cur_vecs, axis=1) <= norm_th
        vecs.append(cur_vecs[valid])
        # slice by the actual batch size so the boolean mask lines up (the last batch may be smaller)
        ids.append(used_ids[cur:cur + bsz][valid])
        cur += bsz
        logger.info("%d / %d", cur, tot)
    vecs = np.concatenate(vecs, 0)
    ids = np.concatenate(ids, 0)
    out, max_norm = augment_data(vecs, max_norm, max_norm_cf)
    return out, ids, max_norm
def validate(model, dev_data, device):
    model.eval()
    q_list = []
    r_list = []
    for batch in dev_data:
        batch = move_to_device(batch, device)
        q = model.query_encoder(batch['src_tokens'])
        r = model.response_encoder(batch['tgt_tokens'])
        q_list.append(q)
        r_list.append(r)
    q = torch.cat(q_list, dim=0)
    r = torch.cat(r_list, dim=0)
    bsz = q.size(0)
    scores = torch.mm(q, r.t())  # bsz x bsz
    gold = torch.arange(bsz, device=scores.device)
    _, pred = torch.max(scores, -1)
    acc = torch.sum(torch.eq(gold, pred).float()) / bsz
    return acc
def validate(device, model, test_data, beam_size=5, alpha=0.6, max_time_step=100, dump_path=None):
    """For Development Only"""
    ref_stream = []
    sys_stream = []
    topk_sys_retr_stream = []
    for batch in test_data:
        batch = move_to_device(batch, device)
        res, _ = generate_batch(model, batch, beam_size, alpha, max_time_step)
        sys_stream.extend(res)
        ref_stream.extend(batch['tgt_raw_sents'])
        sys_retr = batch.get('retrieval_raw_sents', None)
        if sys_retr:
            topk_sys_retr_stream.extend(sys_retr)
    assert len(sys_stream) == len(ref_stream)
    sys_stream = [re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o)) for o in sys_stream]
    ref_stream = [re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o)) for o in ref_stream]
    ref_streams = [ref_stream]
    bleu = sacrebleu.corpus_bleu(sys_stream, ref_streams,
                                 force=True, lowercase=False, tokenize='none').score
    sys_retr_streams = []
    if topk_sys_retr_stream:
        assert len(topk_sys_retr_stream) == len(ref_stream)
        topk = len(topk_sys_retr_stream[0])
        for i in range(topk):
            sys_retr_stream = [re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(o[i])) for o in topk_sys_retr_stream]
            lratio = []
            for aa, bb in zip(sys_retr_stream, ref_stream):
                laa = len(aa.split())
                lbb = len(bb.split())
                lratio.append(max(laa / lbb, lbb / laa))
            bleu_retr = sacrebleu.corpus_bleu(sys_retr_stream, ref_streams,
                                              force=True, lowercase=False, tokenize='none').score
            sys_retr_streams.append(sys_retr_stream)
            logger.info("Retrieval top%d bleu %.2f length ratio %.2f", i + 1, bleu_retr, sum(lratio) / len(lratio))
    # logger.info("show some examples >>>")
    # for sample_id in [5, 6, 11, 22, 33, 44, 55, 66, 555, 666]:
    #     retrieval = ["%d: %s" % (i, sys_retr_streams[i][sample_id]) for i in range(topk)]
    #     logger.info("%d: %s###\n generation: %s###\nretrieval:\n %s", sample_id,
    #                 ref_stream[sample_id], sys_stream[sample_id], '\n'.join(retrieval))
    # logger.info("<<< show some examples")
    if dump_path is not None:
        results = {
            'sys_stream': sys_stream,
            'ref_stream': ref_stream,
            'sys_retr_streams': sys_retr_streams
        }
        json.dump(results, open(dump_path, 'w'))
    return bleu
bleu = validate(device,
                model,
                test_data,
                beam_size=args.beam_size,
                alpha=args.alpha,
                max_time_step=args.max_time_step,
                dump_path=args.dump_path)
logger.info("%s %s %.2f", test_model, args.test_data, bleu)
if args.output_path is not None:
    start_time = time.time()
    TOT = len(test_data)
    DONE = 0
    logger.info("%d/%d", DONE, TOT)
    outs, indices = [], []
    for batch in test_data:
        batch = move_to_device(batch, device)
        res, ind = generate_batch(model, batch, args.beam_size, args.alpha, args.max_time_step)
        for out_tokens, index in zip(res, ind):
            if args.retain_bpe:
                out_line = ' '.join(out_tokens)
            else:
                out_line = re.sub(r'(@@ )|(@@ ?$)', '', ' '.join(out_tokens))
            DONE += 1
            if DONE % 10000 == -1 % 10000:
                logger.info("%d/%d", DONE, TOT)
            outs.append(out_line)
            indices.append(index)
    end_time = time.time()
    logger.info("Time elapsed: %f", end_time - start_time)
def main(args, local_rank):
    vocabs = dict()
    vocabs['tok'] = Vocab(args.tok_vocab, 5, [CLS])
    vocabs['lem'] = Vocab(args.lem_vocab, 5, [CLS])
    vocabs['pos'] = Vocab(args.pos_vocab, 5, [CLS])
    vocabs['ner'] = Vocab(args.ner_vocab, 5, [CLS])
    vocabs['predictable_concept'] = Vocab(args.predictable_concept_vocab, 10, [DUM, END])
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [DUM, END])
    vocabs['rel'] = Vocab(args.rel_vocab, 50, [NIL])
    vocabs['word_char'] = Vocab(args.word_char_vocab, 100, [CLS, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [CLS, END])
    lexical_mapping = LexicalMap(args.lexical_mapping)
    if args.pretrained_word_embed is not None:
        vocab, pretrained_embs = load_pretrained_word_embed(args.pretrained_word_embed)
        vocabs['glove'] = vocab
    else:
        pretrained_embs = None
    for name in vocabs:
        print((name, vocabs[name].size))
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    device = torch.device('cuda', local_rank)
    #print(device)
    #exit()
    model = Parser(vocabs, args.word_char_dim, args.word_dim, args.pos_dim, args.ner_dim,
                   args.concept_char_dim, args.concept_dim, args.cnn_filters,
                   args.char2word_dim, args.char2concept_dim, args.embed_dim,
                   args.ff_embed_dim, args.num_heads, args.dropout, args.snt_layers,
                   args.graph_layers, args.inference_layers, args.rel_dim,
                   pretrained_embs, device=device)
    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())
    model = model.cuda(local_rank)
    train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
    dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=True)
    train_data.set_unk_rate(args.unk_rate)
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    optimizer = AdamWeightDecayOptimizer(grouped_params, lr=args.lr, betas=(0.9, 0.999), eps=1e-6)
    batches_acm, loss_acm, concept_loss_acm, arc_loss_acm, rel_loss_acm = 0, 0, 0, 0, 0
    #model.load_state_dict(torch.load('./ckpt/epoch297_batch49999')['model'])
    discarded_batches_acm = 0
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc, args=(train_data, queue))
    train_data_generator.start()
    used_batches = 0
    if args.resume_ckpt:
        ckpt = torch.load(args.resume_ckpt)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
        batches_acm = ckpt['batches_acm']
        del ckpt
    model.train()
    epoch = 0
    while True:
        batch = queue.get()
        #print("epoch", epoch)
        #print("batches_acm", batches_acm)
        #print("used_batches", used_batches)
        if isinstance(batch, str):
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
        else:
            batch = move_to_device(batch, model.device)
            concept_loss, arc_loss, rel_loss = model(batch)
            loss = (concept_loss + arc_loss + rel_loss) / args.batches_per_update
            loss_value = loss.item()
            concept_loss_value = concept_loss.item()
            arc_loss_value = arc_loss.item()
            rel_loss_value = rel_loss.item()
            if batches_acm > args.warmup_steps and arc_loss_value > 5. * (arc_loss_acm / batches_acm):
                discarded_batches_acm += 1
                print('abnormal', concept_loss.item(), arc_loss.item(), rel_loss.item())
                continue
            loss_acm += loss_value
            concept_loss_acm += concept_loss_value
            arc_loss_acm += arc_loss_value
            rel_loss_acm += rel_loss_value
            loss.backward()
            used_batches += 1
            if not (used_batches % args.batches_per_update == -1 % args.batches_per_update):
                continue
            batches_acm += 1
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
            optimizer.step()
            optimizer.zero_grad()
            if args.world_size == 1 or (dist.get_rank() == 0):
                if batches_acm % args.print_every == -1 % args.print_every:
                    print('Train Epoch %d, Batch %d, Discarded Batch %d, conc_loss %.3f, arc_loss %.3f, rel_loss %.3f' % (
                        epoch, batches_acm, discarded_batches_acm,
                        concept_loss_acm / batches_acm,
                        arc_loss_acm / batches_acm,
                        rel_loss_acm / batches_acm))
                    model.train()
                if batches_acm % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    torch.save({'args': args,
                                'model': model.state_dict(),
                                'batches_acm': batches_acm,
                                'optimizer': optimizer.state_dict()},
                               '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                    model.train()
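# Note on the `counter % N == -1 % N` pattern used by the training loops above and below:
# in Python, -1 % N == N - 1, so the test fires once every N increments of the counter.
# A quick sanity check:
assert -1 % 500 == 499
assert [i for i in range(1, 16) if i % 5 == -1 % 5] == [4, 9, 14]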
vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
lexical_mapping = LexicalMap()
train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
epoch_idx = 0
batch_idx = 0
last = 0
while True:
    st = time.time()
    for d in train_data:
        d = move_to_device(d, torch.device('cpu'))
        batch_idx += 1
        #if d['concept'].size(0) > 5:
        #    continue
        print(epoch_idx, batch_idx, d['concept'].size(), d['token_in'].size())
        c_len, bsz = d['concept'].size()
        t_len, bsz = d['token_in'].size()
        print(bsz, c_len * bsz, t_len * bsz)
        #print(d['relation_bank'].size())
        #print(d['relation'].size())
        #_back_to_txt_for_check(d['concept'], vocabs['concept'])
        #for x in d['concept_depth'].t().tolist():
        #    print(x)
        #_back_to_txt_for_check(d['token_in'], vocabs['token'])
        #_back_to_txt_for_check(d['token_out'], vocabs['predictable_token'], d['local_idx2token'])
def main(args, local_rank):
    vocabs = dict()
    vocabs['concept'] = Vocab(args.concept_vocab, 5, [CLS])
    vocabs['token'] = Vocab(args.token_vocab, 5, [STR, END])
    vocabs['predictable_token'] = Vocab(args.predictable_token_vocab, 5, [END])
    vocabs['token_char'] = Vocab(args.token_char_vocab, 100, [STR, END])
    vocabs['concept_char'] = Vocab(args.concept_char_vocab, 100, [STR, END])
    vocabs['relation'] = Vocab(args.relation_vocab, 5, [CLS, rCLS, SEL, TL])
    lexical_mapping = LexicalMap()
    for name in vocabs:
        print((name, vocabs[name].size, vocabs[name].coverage))
    torch.manual_seed(19940117)
    torch.cuda.manual_seed_all(19940117)
    random.seed(19940117)
    #device = torch.device('cpu')
    device = torch.device('cuda', local_rank)
    model = Generator(vocabs, args.token_char_dim, args.token_dim, args.concept_char_dim,
                      args.concept_dim, args.cnn_filters, args.char2word_dim,
                      args.char2concept_dim, args.rel_dim, args.rnn_hidden_size,
                      args.rnn_num_layers, args.embed_dim, args.ff_embed_dim,
                      args.num_heads, args.dropout, args.snt_layers, args.graph_layers,
                      args.inference_layers, args.pretrained_file, device)
    if args.world_size > 1:
        torch.manual_seed(19940117 + dist.get_rank())
        torch.cuda.manual_seed_all(19940117 + dist.get_rank())
        random.seed(19940117 + dist.get_rank())
    model = model.to(device)
    train_data = DataLoader(vocabs, lexical_mapping, args.train_data, args.train_batch_size, for_train=True)
    #dev_data = DataLoader(vocabs, lexical_mapping, args.dev_data, args.dev_batch_size, for_train=False)
    train_data.set_unk_rate(args.unk_rate)
    weight_decay_params = []
    no_weight_decay_params = []
    for name, param in model.named_parameters():
        if name.endswith('bias') or 'layer_norm' in name:
            no_weight_decay_params.append(param)
        else:
            weight_decay_params.append(param)
    grouped_params = [{'params': weight_decay_params, 'weight_decay': 1e-4},
                      {'params': no_weight_decay_params, 'weight_decay': 0.}]
    optimizer = AdamWeightDecayOptimizer(grouped_params, lr=args.lr, betas=(0.9, 0.999), eps=1e-6)
    batches_acm, loss_acm = 0, 0
    discarded_batches_acm = 0
    queue = mp.Queue(10)
    train_data_generator = mp.Process(target=data_proc, args=(train_data, queue))
    train_data_generator.start()
    model.train()
    epoch = 0
    while batches_acm < args.total_train_steps:
        batch = queue.get()
        if isinstance(batch, str):
            epoch += 1
            print('epoch', epoch, 'done', 'batches', batches_acm)
            continue
        batch = move_to_device(batch, device)
        loss = model(batch)
        # exit(0)  # leftover debug exit; keeping it would stop training after the first batch
        loss_value = loss.item()
        if batches_acm > args.warmup_steps and loss_value > 5. * (loss_acm / batches_acm):
            discarded_batches_acm += 1
            print('abnormal', loss_value)
            continue
        loss_acm += loss_value
        batches_acm += 1
        loss.backward()
        if args.world_size > 1:
            average_gradients(model)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        update_lr(optimizer, args.embed_dim, batches_acm, args.warmup_steps)
        optimizer.step()
        optimizer.zero_grad()
        #------------
        if args.world_size == 1 or (dist.get_rank() == 0):
            if batches_acm % args.print_every == -1 % args.print_every:
                print('Train Epoch %d, Batch %d, Discarded Batch %d, loss %.3f' % (
                    epoch, batches_acm, discarded_batches_acm, loss_acm / batches_acm))
                model.train()
            if batches_acm > args.warmup_steps and batches_acm % args.eval_every == -1 % args.eval_every:
                #model.eval()
                #bleu, chrf = validate(model, dev_data)
                #print("epoch", "batch", "bleu", "chrf")
                #print(epoch, batches_acm, bleu, chrf)
                torch.save({'args': args, 'model': model.state_dict()},
                           '%s/epoch%d_batch%d' % (args.ckpt, epoch, batches_acm))
                model.train()
def main(args, local_rank):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])
    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name, vocabs[name].size, vocabs[name].coverage)
    set_seed(19940117)
    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)
    if args.arch == 'vanilla':
        model = Generator(vocabs, args.embed_dim, args.ff_embed_dim, args.num_heads,
                          args.dropout, args.enc_layers, args.dec_layers, args.label_smoothing)
    elif args.arch == 'mem':
        model = MemGenerator(vocabs, args.embed_dim, args.ff_embed_dim, args.num_heads,
                             args.dropout, args.mem_dropout, args.enc_layers, args.dec_layers,
                             args.mem_enc_layers, args.label_smoothing, args.use_mem_score)
    elif args.arch == 'rg':
        logger.info("start building model")
        logger.info("building retriever")
        retriever = Retriever.from_pretrained(args.num_retriever_heads, vocabs, args.retriever,
                                              args.nprobe, args.topk, local_rank,
                                              use_response_encoder=(args.rebuild_every > 0))
        logger.info("building retriever + generator")
        model = RetrieverGenerator(vocabs, retriever, args.share_encoder, args.embed_dim,
                                   args.ff_embed_dim, args.num_heads, args.dropout,
                                   args.mem_dropout, args.enc_layers, args.dec_layers,
                                   args.mem_enc_layers, args.label_smoothing)
    if args.resume_ckpt:
        model.load_state_dict(torch.load(args.resume_ckpt)['model'])
    else:
        global_step = 0
    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())
    model = model.to(device)
    retriever_params = [v for k, v in model.named_parameters() if k.startswith('retriever.')]
    other_params = [v for k, v in model.named_parameters() if not k.startswith('retriever.')]
    optimizer = Adam([{'params': retriever_params, 'lr': args.embed_dim ** -0.5 * 0.1},
                      {'params': other_params, 'lr': args.embed_dim ** -0.5}],
                     betas=(0.9, 0.98), eps=1e-9)
    lr_schedule = get_inverse_sqrt_schedule_with_warmup(optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs, args.train_data, args.per_gpu_train_batch_size,
                            for_train=True, rank=local_rank, num_replica=args.world_size)
    model.eval()
    #dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
    #bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=10)
    step, epoch = 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()
    best_dev_bleu = 0.
    while global_step <= args.total_train_steps:
        for batch in train_data:
            #step_start = time.time()
            batch = move_to_device(batch, device)
            if args.arch == 'rg':
                loss, acc = model(batch, update_mem_bias=(global_step > args.update_retriever_after))
            else:
                loss, acc = model(batch)
            tr_stat.update({'loss': loss.item() * batch['tgt_num_tokens'],
                            'tokens': batch['tgt_num_tokens'],
                            'acc': acc})
            tr_stat.step()
            loss.backward()
            #step_cost = time.time() - step_start
            #print('step_cost', step_cost)
            step += 1
            if not (step % args.gradient_accumulation_steps == -1 % args.gradient_accumulation_steps):
                continue
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1
            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['tokens'],
                                tr_stat['acc'] / tr_stat['tokens'])
                    tr_stat = Statistics()
                if global_step % args.eval_every == -1 % args.eval_every:
                    model.eval()
                    max_time_step = 256 if global_step > 2 * args.warmup_steps else 5
                    bleus = []
                    for cur_dev_data in args.dev_data:
                        dev_data = DataLoader(vocabs, cur_dev_data, args.dev_batch_size, for_train=False)
                        bleu = validate(device, model, dev_data, beam_size=5, alpha=0.6, max_time_step=max_time_step)
                        bleus.append(bleu)
                    bleu = sum(bleus) / len(bleus)
                    logger.info("epoch %d, step %d, dev bleu %.2f", epoch, global_step, bleu)
                    if bleu > best_dev_bleu:
                        testbleus = []
                        for cur_test_data in args.test_data:
                            test_data = DataLoader(vocabs, cur_test_data, args.dev_batch_size, for_train=False)
                            testbleu = validate(device, model, test_data, beam_size=5, alpha=0.6, max_time_step=max_time_step)
                            testbleus.append(testbleu)
                        testbleu = sum(testbleus) / len(testbleus)
                        logger.info("epoch %d, step %d, test bleu %.2f", epoch, global_step, testbleu)
                        torch.save({'args': args, 'model': model.state_dict()}, '%s/best.pt' % (args.ckpt,))
                        if not args.only_save_best:
                            torch.save({'args': args, 'model': model.state_dict()},
                                       '%s/epoch%d_batch%d_devbleu%.2f_testbleu%.2f' % (
                                           args.ckpt, epoch, global_step, bleu, testbleu))
                        best_dev_bleu = bleu
                    model.train()
            if args.rebuild_every > 0 and (global_step % args.rebuild_every == -1 % args.rebuild_every):
                model.retriever.drop_index()
                torch.cuda.empty_cache()
                next_index_dir = '%s/batch%d' % (args.ckpt, global_step)
                if args.world_size == 1 or (dist.get_rank() == 0):
                    model.retriever.rebuild_index(next_index_dir)
                    dist.barrier()
                else:
                    dist.barrier()
                model.retriever.update_index(next_index_dir, args.nprobe)
            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank, global_step)
def validation_loop(model, dl, batch_size: int, epoch: int, device: str):
    # start validation
    total_val_iter = int(len(dl.dataset) / batch_size)
    model.eval()
    print("running validation...")
    all_detections = []
    all_losses = []
    for batch, targets in tqdm(dl, total=total_val_iter):
        batch, targets = move_to_device(batch, targets, device=device)
        detections, losses = model.validation_step(batch, targets)
        all_losses.append(losses)
        all_detections.append(detections)
    # evaluate RPN
    iou_thresholds = torch.arange(0.5, 1.0, 0.05)
    rpn_predictions = []
    rpn_ground_truths = []
    for dets in all_detections:
        rpn_predictions += dets['rpn']['predictions']
        rpn_ground_truths += dets['rpn']['ground_truths']
    rpn_recalls = roi_recalls(rpn_predictions, rpn_ground_truths, iou_thresholds=iou_thresholds)
    # evaluate FastRCNN head
    head_predictions = []
    head_ground_truths = []
    for dets in all_detections:
        head_predictions += dets['head']['predictions']
        head_ground_truths += dets['head']['ground_truths']
    head_predictions = [pred[:, :5] for pred in head_predictions]
    head_ground_truths = [gt[:, :4] for gt in head_ground_truths]
    AP50 = calculate_AP(head_predictions, head_ground_truths, iou_threshold=0.5)
    AP75 = calculate_AP(head_predictions, head_ground_truths, iou_threshold=0.75)
    AP90 = calculate_AP(head_predictions, head_ground_truths, iou_threshold=0.90)
    AP = (AP50 + AP75 + AP90) / 3
    means = caclulate_means(all_losses)
    print(f"-- validation results for epoch {epoch+1} --")
    print("RPN mean recall at IoU thresholds:")
    for iou_threshold, rpn_recall in zip(iou_thresholds.cpu().numpy(), rpn_recalls.cpu().numpy() * 100):
        print(f"IoU={iou_threshold:.02f} recall={int(rpn_recall)}")
    print(f"HEAD AP IoU=.5    : {AP50.item()*100:.02f}")
    print(f"HEAD AP IoU=.75   : {AP75.item()*100:.02f}")
    print(f"HEAD AP IoU=.90   : {AP90.item()*100:.02f}")
    print(f"HEAD AP IoU=.5:.95: {AP.item()*100:.02f}")
    for k, v in means.items():
        print(f"{k}: {v:.4f}")
    print("--------------------------------------------")
def prepare_incremental_input(self, step_seq):
    token = ListsToTensor(step_seq, self.vocabs['token'])
    token_char = ListsofStringToTensor(step_seq, self.vocabs['token_char'])
    token, token_char = move_to_device(token, self.device), move_to_device(token_char, self.device)
    return token, token_char
def main(args):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.info('Loading model...')
    device = torch.device('cuda', 0)
    vocab = Vocab(args.vocab_path, 0, [BOS, EOS])
    model_args = torch.load(args.args_path)
    model = ProjEncoder.from_pretrained(vocab, model_args, args.ckpt_path)
    model.to(device)
    logger.info('Collecting data...')
    data_r = []
    with open(args.index_file) as f:
        for line in f.readlines():
            r = line.strip()
            data_r.append(r)
    data_q = []
    data_qr = []
    with open(args.input_file, 'r') as f:
        for line in f.readlines():
            q, r = line.strip().split('\t')
            data_q.append(q)
            data_qr.append(r)
    logger.info('Collected %d instances', len(data_q))
    textq, textqr, textr = data_q, data_qr, data_r
    data_loader = DataLoader(data_q, vocab, args.batch_size)
    mips = MIPS.from_built(args.index_path, nprobe=args.nprobe)
    max_norm = torch.load(os.path.dirname(args.index_path) + '/max_norm.pt')
    mips.to_gpu()
    model.cuda()
    model = torch.nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count())))
    model.eval()
    logger.info('Start search')
    cur, tot = 0, len(data_q)
    with open(args.output_file, 'w') as fo:
        for batch in asynchronous_load(data_loader):
            with torch.no_grad():
                q = move_to_device(batch, torch.device('cuda')).t()
                bsz = q.size(0)
                vecsq = model(q, batch_first=True).detach().cpu().numpy()
                vecsq = augment_query(vecsq)
                D, I = mips.search(vecsq, args.topk + 1)
                D = l2_to_ip(D, vecsq, max_norm) / (max_norm * max_norm)
                for i, (Ii, Di) in enumerate(zip(I, D)):
                    item = [textq[cur + i], textqr[cur + i]]
                    for pred, s in zip(Ii, Di):
                        if args.allow_hit or textr[pred] != textqr[cur + i]:
                            item.append(textr[pred])
                            item.append(str(float(s)))
                    item = item[:2 + 2 * args.topk]
                    assert len(item) == 2 + 2 * args.topk
                    fo.write('\t'.join(item) + '\n')
                cur += bsz
                logger.info('finished %d / %d', cur, tot)
def main(args, local_rank):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    vocabs = dict()
    vocabs['src'] = Vocab(args.src_vocab, 0, [BOS, EOS])
    vocabs['tgt'] = Vocab(args.tgt_vocab, 0, [BOS, EOS])
    if args.world_size == 1 or (dist.get_rank() == 0):
        logger.info(args)
        for name in vocabs:
            logger.info("vocab %s, size %d, coverage %.3f", name, vocabs[name].size, vocabs[name].coverage)
    set_seed(19940117)
    #device = torch.device('cpu')
    torch.cuda.set_device(local_rank)
    device = torch.device('cuda', local_rank)
    if args.resume_ckpt:
        model = MatchingModel.from_pretrained(vocabs, args.resume_ckpt)
    else:
        model = MatchingModel.from_params(vocabs, args.layers, args.embed_dim, args.ff_embed_dim,
                                          args.num_heads, args.dropout, args.output_dim, args.bow)
    if args.world_size > 1:
        set_seed(19940117 + dist.get_rank())
    model = model.to(device)
    if args.resume_ckpt:
        dev_data = DataLoader(vocabs, args.dev_data, args.dev_batch_size, addition=args.additional_negs)
        acc = validate(model, dev_data, device)
        logger.info("initialize from %s, initial acc %.2f", args.resume_ckpt, acc)
    optimizer = Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.98), eps=1e-9)
    lr_schedule = get_linear_schedule_with_warmup(optimizer, args.warmup_steps, args.total_train_steps)
    train_data = DataLoader(vocabs, args.train_data, args.per_gpu_train_batch_size,
                            worddrop=args.worddrop, addition=args.additional_negs)
    global_step, step, epoch = 0, 0, 0
    tr_stat = Statistics()
    logger.info("start training")
    model.train()
    while global_step <= args.total_train_steps:
        for batch in train_data:
            batch = move_to_device(batch, device)
            loss, acc, bsz = model(batch['src_tokens'], batch['tgt_tokens'], args.label_smoothing)
            tr_stat.update({'loss': loss.item() * bsz, 'nsamples': bsz, 'acc': acc * bsz})
            tr_stat.step()
            loss.backward()
            step += 1
            if not (step % args.gradient_accumulation_steps == -1 % args.gradient_accumulation_steps):
                continue
            if args.world_size > 1:
                average_gradients(model)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_schedule.step()
            optimizer.zero_grad()
            global_step += 1
            if args.world_size == 1 or (dist.get_rank() == 0):
                if global_step % args.print_every == -1 % args.print_every:
                    logger.info("epoch %d, step %d, loss %.3f, acc %.3f",
                                epoch, global_step,
                                tr_stat['loss'] / tr_stat['nsamples'],
                                tr_stat['acc'] / tr_stat['nsamples'])
                    tr_stat = Statistics()
                if global_step > args.warmup_steps and global_step % args.eval_every == -1 % args.eval_every:
                    dev_data = DataLoader(vocabs, args.dev_data, args.dev_batch_size, addition=args.additional_negs)
                    acc = validate(model, dev_data, device)
                    logger.info("epoch %d, step %d, dev, dev acc %.2f", epoch, global_step, acc)
                    save_path = '%s/epoch%d_batch%d_acc%.2f' % (args.ckpt, epoch, global_step, acc)
                    model.save(args, save_path)
                    model.train()
            if global_step > args.total_train_steps:
                break
        epoch += 1
    logger.info('rank %d, finish training after %d steps', local_rank, global_step)