def supervise_evaluate_loop(agent, dev_it, dataset='iwslt', pair='fr_en'):
    dev_metrics = Metrics('s2p_dev', 'nll')
    with torch.no_grad():
        agent.eval()
        trg_corpus, hyp_corpus = [], []
        for j, dev_batch in enumerate(dev_it):
            if dataset == "iwslt" or dataset == 'iwslt_small':
                src, src_len = dev_batch.src
                trg, trg_len = dev_batch.trg
                trg_field = dev_batch.dataset.fields['trg']
            elif dataset == "multi30k":
                src_lang, trg_lang = pair.split("_")
                src, src_len = dev_batch.__dict__[src_lang]
                trg, trg_len = dev_batch.__dict__[trg_lang]
                trg_field = dev_batch.dataset.fields[trg_lang]

            logits, _ = agent(src[:, 1:], src_len - 1, trg[:, :-1])
            nll = F.cross_entropy(logits, trg[:, 1:].contiguous().view(-1),
                                  reduction='mean', ignore_index=0)
            num_trg = (trg_len - 1).sum().item()
            dev_metrics.accumulate(num_trg, **{'nll': nll})

            hyp = agent.decode(src, src_len, 'greedy', 0)
            trg_corpus.extend(trg_field.reverse(trg, unbpe=True))
            hyp_corpus.extend(trg_field.reverse(hyp, unbpe=True))

        bleu = computeBLEU(hyp_corpus, trg_corpus, corpus=True)
    return dev_metrics, bleu
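# The loops in this file rely on a `Metrics` accumulator with a reset/accumulate
# interface, attribute-style access to running averages, and a `metrics` dict in
# its __dict__ (see get_fr_en_imitate_stats below). The real class lives elsewhere
# in the repo; the sketch below only illustrates the assumed contract (weighted
# running averages keyed by metric name) and is not the actual implementation.
class MetricsSketch:
    """Hypothetical stand-in for Metrics: weighted running averages per metric."""

    def __init__(self, name, *metric_names, data_type="avg"):
        self.name = name
        self.metric_names = list(metric_names)
        self.data_type = data_type
        self.reset()

    def reset(self):
        # Per-metric weighted sums, total weights, and current averages.
        self.sums = {n: 0.0 for n in self.metric_names}
        self.counts = {n: 0.0 for n in self.metric_names}
        self.metrics = {n: 0.0 for n in self.metric_names}

    def accumulate(self, weight, *values, **named_values):
        # Accept positional values (declaration order) or keyword values.
        updates = dict(zip(self.metric_names, values))
        updates.update(named_values)
        for name, value in updates.items():
            self.sums[name] = self.sums.get(name, 0.0) + float(value) * weight
            self.counts[name] = self.counts.get(name, 0.0) + weight
            self.metrics[name] = self.sums[name] / self.counts[name]

    def __getattr__(self, name):
        metrics = self.__dict__.get('metrics', {})
        if name in metrics:
            return metrics[name]
        raise AttributeError(name)

    def __str__(self):
        body = ", ".join("{}={:.4f}".format(k, v) for k, v in self.metrics.items())
        return "{}: {}".format(self.name, body)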
def get_fr_en_imitate_stats(args, model, dev_it, monitor_names, extra_input):
    """ En BLEU, LM score and img prediction """
    model.eval()
    eval_metrics = Metrics('dev_loss', *monitor_names, data_type="avg")
    eval_metrics.reset()
    with torch.no_grad():
        unbpe = True
        en_corpus = []
        en_hyp = []
        for j, dev_batch in enumerate(dev_it):
            en_corpus.extend(args.EN.reverse(dev_batch.en[0], unbpe=unbpe))
            en_msg, en_msg_len = model.fr_en_speak(dev_batch, is_training=False)
            en_hyp.extend(args.EN.reverse(en_msg, unbpe=unbpe))
            results, _ = model.get_grounding(
                en_msg, en_msg_len, dev_batch,
                en_lm=extra_input["en_lm"],
                all_img=extra_input["img"]['multi30k'][1],
                ranker=extra_input["ranker"])
            # Get entropy
            neg_Hs = model.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
            neg_Hs = neg_Hs.mean()  # (1,)
            results["neg_Hs"] = neg_Hs
            if len(monitor_names) > 0:
                eval_metrics.accumulate(len(dev_batch), **results)
            if args.debug:
                break

    bleu_en = computeBLEU(en_hyp, en_corpus, corpus=True)
    stats = eval_metrics.__dict__['metrics']
    stats['bleu_en'] = bleu_en[0]
    return stats
def valid_model(args, model, valid_img_feats, valid_caps, valid_lens):
    model.eval()
    batch_size = 32
    start = 0
    val_metrics = Metrics('val_loss', 'loss', data_type="avg")
    with torch.no_grad():
        # Iterate over the validation images in fixed-size chunks.
        while start < valid_img_feats.shape[0]:
            cap_id = random.randint(0, 4)  # sample one of the five captions per image
            end = start + batch_size
            batch_img_feat = cuda(valid_img_feats[start:end])
            batch_ens = cuda(valid_caps[cap_id][start:end])
            batch_lens = cuda(valid_lens[cap_id][start:end])
            R = model(batch_ens[:, 1:], batch_lens - 1, batch_img_feat)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError("unknown img_pred_loss: {}".format(args.img_pred_loss))
            val_metrics.accumulate(batch_size, R['loss'])
            start = end
    return val_metrics
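# `cuda(...)` above is a small device-placement helper defined elsewhere in the
# repo. A minimal sketch of the assumed behaviour (move a tensor to the GPU when
# one is available, otherwise return it unchanged); the real helper may instead
# pin a specific device via args.gpu.
import torch


def cuda_sketch(tensor, device=None):
    """Hypothetical stand-in for cuda(): move `tensor` to a GPU if available."""
    if torch.cuda.is_available():
        return tensor.cuda(device)
    return tensor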
def evaluate_communication(self):
    """ Use greedy decoding and check scores like BLEU, language model and grounding """
    eval_metrics = Metrics('dev_loss', *self.monitor_names, data_type="avg")
    eval_metrics.reset()
    with torch.no_grad():
        unbpe = True
        self.model.eval()
        en_corpus, de_corpus = [], []
        en_hyp, de_hyp = [], []
        for j, dev_batch in enumerate(self.dev_it):
            en_corpus.extend(self.args.EN.reverse(dev_batch.en[0], unbpe=unbpe))
            de_corpus.extend(self.args.DE.reverse(dev_batch.de[0], unbpe=unbpe))
            en_msg, de_msg, en_msg_len, _ = self.model.decode(dev_batch)
            en_hyp.extend(self.args.EN.reverse(en_msg, unbpe=unbpe))
            de_hyp.extend(self.args.DE.reverse(de_msg, unbpe=unbpe))
            results, _ = self.model.get_grounding(
                en_msg, en_msg_len, dev_batch,
                en_lm=self.extra_input["en_lm"],
                all_img=self.extra_input["img"]['multi30k'][1],
                ranker=self.extra_input["ranker"])
            # Get entropy
            neg_Hs = self.model.fr_en.dec.neg_Hs  # (batch_size, en_msg_len)
            neg_Hs = neg_Hs.mean()  # (1,)
            results["neg_Hs"] = neg_Hs
            if len(self.monitor_names) > 0:
                eval_metrics.accumulate(len(dev_batch), **results)

    bleu_en = computeBLEU(en_hyp, en_corpus, corpus=True)
    bleu_de = computeBLEU(de_hyp, de_corpus, corpus=True)
    self.args.logger.info(eval_metrics)
    self.args.logger.info("Fr-En {} : {}".format('valid', print_bleu(bleu_en)))
    self.args.logger.info("En-De {} : {}".format('valid', print_bleu(bleu_de)))
    return eval_metrics, bleu_en, bleu_de, en_corpus, en_hyp, de_hyp
def train_model(args, model, iterators):
    (train_it, dev_it) = iterators
    if not args.debug:
        decoding_path = Path(join(args.decoding_path, args.id_str))
        decoding_path.mkdir(parents=True, exist_ok=True)
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(join(args.event_path, args.id_str))

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    extra_loss_names = []
    train_metrics = Metrics('train_loss', 'nll', *extra_loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss', 'nll', *extra_loss_names, data_type="avg")
    best = Best(max, 'dev_bleu', 'iters', model=model, opt=opt,
                path=join(args.model_path, args.id_str),
                gpu=args.gpu, debug=args.debug)

    for iters, train_batch in enumerate(train_it):
        if iters >= args.max_training_steps:
            args.logger.info('stopping training after {} training steps'.format(
                args.max_training_steps))
            break

        if not args.debug and hasattr(args, 'save_every') and iters % args.save_every == 0:
            args.logger.info('save (back-up) checkpoints at iters={}'.format(iters))
            with torch.cuda.device(args.gpu):
                torch.save(model.state_dict(),
                           '{}_iter={}.pt'.format(args.model_path + args.id_str, iters))
                torch.save([iters, opt.state_dict()],
                           '{}_iter={}.pt.states'.format(args.model_path + args.id_str, iters))

        if iters % args.eval_every == 0:
            dev_metrics.reset()
            dev_bleu = valid_model(args, model, dev_it, dev_metrics, 'argmax')
            if not args.debug:
                write_tb(writer, ['nll'], [dev_metrics.nll], iters, prefix="dev/")
                write_tb(writer,
                         ['bleu', *("p_1 p_2 p_3 p_4".split()), 'bp', 'len_ref', 'len_hyp'],
                         dev_bleu, iters, prefix="bleu/")
            best.accumulate(dev_bleu[0], iters)
            args.logger.info(best)
            """
            if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                args.logger.info("Early stopping.")
                break
            """

        model.train()

        def get_lr_anneal(iters):
            lr_end = 1e-5
            return max(0, (args.lr - lr_end) * (args.linear_anneal_steps - iters)
                       / args.linear_anneal_steps) + lr_end

        if args.lr_anneal == "linear":
            opt.param_groups[0]['lr'] = get_lr_anneal(iters)

        opt.zero_grad()

        batch_size = len(train_batch)
        if args.dataset == "iwslt" or args.dataset == 'iwslt_small':
            src, src_len = train_batch.src
            trg, trg_len = train_batch.trg
        elif args.dataset == "multi30k":
            src_lang, trg_lang = args.pair.split("_")
            src, src_len = train_batch.__dict__[src_lang]
            trg, trg_len = train_batch.__dict__[trg_lang]
        else:
            raise ValueError

        # NOTE encoder never receives <BOS> token
        # because during communication, Agent A's decoder will never output <BOS>
        logits, _ = model(src[:, 1:], src_len - 1, trg[:, :-1])
        nll = F.cross_entropy(logits, trg[:, 1:].contiguous().view(-1),
                              reduction='mean', ignore_index=0)
        num_trg = (trg_len - 1).sum().item()
        train_metrics.accumulate(num_trg, nll.item())

        # Backpropagate before clipping: gradients must exist when clip_grad_norm_ runs.
        nll.backward()
        if args.grad_clip > 0:
            total_norm = nn.utils.clip_grad_norm_(params, args.grad_clip)
        opt.step()

        if iters % args.print_every == 0:
            args.logger.info("update {} : {}".format(iters, str(train_metrics)))
            if not args.debug:
                write_tb(writer, ['nll', 'lr'],
                         [train_metrics.nll, opt.param_groups[0]['lr']],
                         iters, prefix="train/")
            train_metrics.reset()
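# The get_lr_anneal closure above (and its siblings in the later training loops,
# which use args.lr_min instead of the hard-coded 1e-5) implements a linear decay
# from args.lr down to a floor, then holds. A standalone version for reference;
# the numbers in the comment are illustrative only, not values from this repo.
def linear_lr_anneal(iters, lr_start, lr_end, anneal_steps):
    """Linearly decay from lr_start to lr_end over anneal_steps updates, then hold."""
    return max(0.0, (lr_start - lr_end) * (anneal_steps - iters) / anneal_steps) + lr_end

# e.g. lr_start=3e-4, lr_end=1e-5, anneal_steps=100000:
#   iters=0       -> 3.0e-4
#   iters=50000   -> ~1.55e-4 (halfway between start and floor)
#   iters>=100000 -> 1.0e-5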
def train_model(args, model):
    # Frozen ResNet-152 (minus the final fc layer) extracts image features on the fly.
    resnet = torchvision.models.resnet152(pretrained=True)
    resnet = nn.Sequential(*list(resnet.children())[:-1])
    resnet = nn.DataParallel(resnet).cuda()
    resnet.eval()

    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    # Train dataset
    args.logger.info("Loading train imgs...")
    train_dataset = ImageFolderWithPaths(os.path.join(args.data_dir, 'flickr30k'), preprocess_rc)
    train_imgs = open(os.path.join(args.data_dir, 'flickr30k/train.txt'), 'r').readlines()
    train_imgs = [x.strip() for x in train_imgs if x.strip() != ""]
    train_dataset.samples = [x for x in train_dataset.samples
                             if x[0].split("/")[-1] in train_imgs]
    train_dataset.imgs = [x for x in train_dataset.imgs
                          if x[0].split("/")[-1] in train_imgs]
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size,
                                               shuffle=True, num_workers=8, pin_memory=False)
    args.logger.info("Train loader built!")

    en_vocab = TextVocab(counter=torch.load(os.path.join(args.data_dir, 'bpe/vocab.en.pth')))
    word2idx = en_vocab.stoi
    train_en = [open(os.path.join(args.data_dir, 'flickr30k/caps',
                                  'train.{}.bpe'.format(idx + 1))).readlines()
                for idx in range(5)]
    train_en = [[["<bos>"] + sentence.strip().split() + ["<eos>"]
                 for sentence in doc if sentence.strip() != ""]
                for doc in train_en]
    train_en = [[[word2idx[word] for word in sentence] for sentence in doc]
                for doc in train_en]
    args.logger.info("Train corpus built!")

    # Valid dataset
    valid_img_feats = torch.tensor(torch.load(os.path.join(args.data_dir, 'flickr30k/val_feat.pth')))
    valid_ens = []
    valid_en_lens = []
    for idx in range(5):
        valid_en = []
        with open(os.path.join(args.data_dir, 'flickr30k/caps', 'val.{}.bpe'.format(idx + 1))) as f:
            for line in f:
                line = line.strip()
                if line == "":
                    continue
                words = ["<bos>"] + line.split() + ["<eos>"]
                words = [word2idx[word] for word in words]
                valid_en.append(words)
        # Pad
        valid_en_len = [len(sent) for sent in valid_en]
        valid_en = [np.lib.pad(xx, (0, max(valid_en_len) - len(xx)), 'constant',
                               constant_values=(0, 0)) for xx in valid_en]
        valid_ens.append(torch.tensor(valid_en).long())
        valid_en_lens.append(torch.tensor(valid_en_len).long())
    args.logger.info("Valid corpus built!")

    iters = -1
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break
        for idx, (train_img, lab, path) in enumerate(train_loader):
            iters += 1
            if iters > args.max_training_steps:
                should_stop = True
                break

            if iters % args.eval_every == 0:
                res = get_retrieve_result(args, model, valid_caps=valid_ens,
                                          valid_lens=valid_en_lens,
                                          valid_img_feats=valid_img_feats)
                val_metrics = valid_model(args, model, valid_img_feats=valid_img_feats,
                                          valid_caps=valid_ens, valid_lens=valid_en_lens)
                args.logger.info("[VALID] update {} : {}".format(iters, str(val_metrics)))
                if not args.debug:
                    write_tb(writer, monitor_names, res, iters, prefix="dev/")
                    write_tb(writer, loss_names,
                             [val_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="dev/")
                best.accumulate((res[0] + res[3]) / 2, iters)
                args.logger.info('model:' + args.prefix + args.hp_str)
                args.logger.info('epoch {} iters {}'.format(epoch, iters))
                args.logger.info(best)
                if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                    args.logger.info("Early stopping.")
                    return

            model.train()

            def get_lr_anneal(iters):
                lr_end = args.lr_min
                return max(0, (args.lr - lr_end) * (args.linear_anneal_steps - iters)
                           / args.linear_anneal_steps) + lr_end

            if args.lr_anneal == "linear":
                opt.param_groups[0]['lr'] = get_lr_anneal(iters)

            opt.zero_grad()

            batch_size = len(path)
            path = [p.split("/")[-1] for p in path]
            sentence_idx = [train_imgs.index(p) for p in path]
            # Sample one of the five reference captions per image, then pad to batch max length.
            en = [train_en[random.randint(0, 4)][sentence_i] for sentence_i in sentence_idx]
            en_len = [len(x) for x in en]
            en = [np.lib.pad(xx, (0, max(en_len) - len(xx)), 'constant',
                             constant_values=(0, 0)) for xx in en]
            en = cuda(torch.LongTensor(np.array(en).tolist()))
            en_len = cuda(torch.LongTensor(en_len))

            with torch.no_grad():
                train_img = resnet(train_img).view(batch_size, -1)

            R = model(en[:, 1:], en_len - 1, train_img)
            if args.img_pred_loss == "vse":
                R['loss'] = R['loss'].sum()
            elif args.img_pred_loss == "mse":
                R['loss'] = R['loss'].mean()
            else:
                raise ValueError("unknown img_pred_loss: {}".format(args.img_pred_loss))

            total_loss = 0
            for loss_name in loss_names:
                total_loss += R[loss_name] * loss_cos[loss_name]
            train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

            total_loss.backward()
            if args.plot_grad:
                plot_grad(writer, model, iters)
            if args.grad_clip > 0:
                nn.utils.clip_grad_norm_(params, args.grad_clip)
            opt.step()

            if iters % args.eval_every == 0:
                args.logger.info("update {} : {}".format(iters, str(train_metrics)))
            if iters % args.eval_every == 0 and not args.debug:
                write_tb(writer, loss_names,
                         [train_metrics.__getattr__(name) for name in loss_names],
                         iters, prefix="train/")
                write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                train_metrics.reset()
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators
    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["nll"], {"nll": 1.0}
    monitor_names = ["nll_rnd"]
    """
    if args.rep_pen_co > 0.0:
        loss_names.append("nll_cur")
        loss_cos["nll_cur"] = -1 * args.rep_pen_co
    else:
        monitor_names.append("nll_cur")
    """
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    dev_metrics = Metrics('dev_loss', *loss_names, *monitor_names, data_type="avg")
    best = Best(min, 'loss', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    iters = 0
    should_stop = False
    for epoch in range(999999999):
        if should_stop:
            break
        for dataset in args.dataset.split("_"):
            if should_stop:
                break
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                if iters >= args.max_training_steps:
                    args.logger.info('stopping training after {} training steps'.format(
                        args.max_training_steps))
                    should_stop = True
                    break

                if iters % args.eval_every == 0:
                    dev_metrics.reset()
                    valid_model(args, model, dev_its['multi30k'], dev_metrics, iters,
                                loss_names, monitor_names, extra_input)
                    if not args.debug:
                        write_tb(writer, loss_names,
                                 [dev_metrics.__getattr__(name) for name in loss_names],
                                 iters, prefix="dev/")
                        write_tb(writer, monitor_names,
                                 [dev_metrics.__getattr__(name) for name in monitor_names],
                                 iters, prefix="dev/")
                    best.accumulate(dev_metrics.nll, iters)
                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(epoch, dataset, iters))
                    args.logger.info(best)
                    if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) * (args.linear_anneal_steps - iters)
                               / args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img_input = None if args.no_img else cuda(
                    extra_input["img"][dataset][0].index_select(
                        dim=0, index=train_batch.idx.cpu()))
                if dataset == "coco":
                    en, en_len = train_batch.__dict__["_" + str(random.randint(1, 5))]
                elif dataset == "multi30k":
                    en, en_len = train_batch.en

                decoded = model(en, img_input)
                R = {}
                R["nll"] = F.cross_entropy(decoded, en[:, 1:].contiguous().view(-1),
                                           ignore_index=0)
                #R["nll_cur"] = F.cross_entropy(decoded, en[:, :-1].contiguous().view(-1), ignore_index=0)

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]
                train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)
                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)
                opt.step()
                iters += 1

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(iters, str(train_metrics)))
                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                    train_metrics.reset()
def train_model(args, model, iterators, extra_input):
    (train_its, dev_its) = iterators
    if not args.debug:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter(args.event_path + args.id_str)

    params = [p for p in model.parameters() if p.requires_grad]
    if args.optimizer == 'Adam':
        opt = torch.optim.Adam(params, betas=(0.9, 0.98), eps=1e-9, lr=args.lr)
    else:
        raise NotImplementedError

    loss_names, loss_cos = ["loss"], {"loss": 1.0}
    monitor_names = "cap_r1 cap_r5 cap_r10 img_r1 img_r5 img_r10".split()
    train_metrics = Metrics('train_loss', *loss_names, data_type="avg")
    best = Best(max, 'r1', 'iters', model=model, opt=opt,
                path=args.model_path + args.id_str,
                gpu=args.gpu, debug=args.debug)

    iters = 0
    for epoch in range(999999999):
        for dataset in args.dataset.split("_"):
            train_it = train_its[dataset]
            for _, train_batch in enumerate(train_it):
                iters += 1

                if iters % args.eval_every == 0:
                    R = valid_model(args, model, dev_its['multi30k'], extra_input)
                    if not args.debug:
                        write_tb(writer, monitor_names, R, iters, prefix="dev/")
                    best.accumulate((R[0] + R[3]) / 2, iters)
                    args.logger.info('model:' + args.prefix + args.hp_str)
                    args.logger.info('epoch {} dataset {} iters {}'.format(epoch, dataset, iters))
                    args.logger.info(best)
                    if args.early_stop and (iters - best.iters) // args.eval_every > args.patience:
                        args.logger.info("Early stopping.")
                        return

                model.train()

                def get_lr_anneal(iters):
                    lr_end = args.lr_min
                    return max(0, (args.lr - lr_end) * (args.linear_anneal_steps - iters)
                               / args.linear_anneal_steps) + lr_end

                if args.lr_anneal == "linear":
                    opt.param_groups[0]['lr'] = get_lr_anneal(iters)

                opt.zero_grad()

                batch_size = len(train_batch)
                img = extra_input["img"][dataset][0].index_select(
                    dim=0, index=train_batch.idx.cpu())  # (batch_size, D_img)
                en, en_len = train_batch.__dict__["_" + str(random.randint(1, 5))]
                R = model(en[:, 1:], en_len - 1, cuda(img))
                R['loss'] = R['loss'].mean()

                total_loss = 0
                for loss_name in loss_names:
                    total_loss += R[loss_name] * loss_cos[loss_name]
                train_metrics.accumulate(batch_size, *[R[name].item() for name in loss_names])

                total_loss.backward()
                if args.plot_grad:
                    plot_grad(writer, model, iters)
                if args.grad_clip > 0:
                    nn.utils.clip_grad_norm_(params, args.grad_clip)
                opt.step()

                if iters % args.eval_every == 0:
                    args.logger.info("update {} : {}".format(iters, str(train_metrics)))
                if iters % args.eval_every == 0 and not args.debug:
                    write_tb(writer, loss_names,
                             [train_metrics.__getattr__(name) for name in loss_names],
                             iters, prefix="train/")
                    write_tb(writer, ['lr'], [opt.param_groups[0]['lr']], iters, prefix="train/")
                    train_metrics.reset()
def start(self):
    # Prepare Metrics
    train_metrics = Metrics('train_loss', *list(self.loss_cos.keys()),
                            *self.monitor_names, data_type="avg")
    best = Best(max, 'de_bleu', 'en_bleu', 'iters', model=self.model, opt=self.opt,
                path=self.args.model_path + self.args.id_str,
                gpu=self.args.gpu, debug=self.args.debug)

    # Starting iteration (resume the counter if this is a restarted iterated-learning run)
    iters = self.extra_input['resume']['iters'] if self.resume else 0
    self.args.logger.info('Start Training at iters={}'.format(iters))
    try:
        train_it = iter(self.train_it)
        while iters < self.args.max_training_steps:
            train_batch = next(train_it)
            if iters >= self.args.max_training_steps:
                self.args.logger.info('stopping training after {} training steps'.format(
                    self.args.max_training_steps))
                break

            self._maybe_save(iters)

            if iters % self.args.eval_every == 0:
                self.model.eval()
                self.evaluate(iters, best)
                self.supervise_evaluate(iters)
                if self.args.plot_grad:
                    self._plot_grad(iters)

            self.model.train()
            self.train_step(iters, train_batch, train_metrics)

            if iters % self.args.eval_every == 0:
                self.args.logger.info("update {} : {}".format(iters, str(train_metrics)))
                train_stats = {}
                train_stats.update({name: train_metrics.__getattr__(name)
                                    for name in self.loss_cos})
                train_stats.update({name: train_metrics.__getattr__(name)
                                    for name in self.monitor_names})
                train_stats['lr'] = self.opt.param_groups[0]['lr']
                write_tb(self.writer, train_stats, iters, prefix="train/")
                train_metrics.reset()

            iters += 1
    except (InterruptedError, KeyboardInterrupt):
        # End Gracefully
        self.end_gracefully(iters)
    self.writer.flush()
    self.writer.close()
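# write_tb is called with two conventions in this code base: parallel lists of
# names and values (the train_model loops above) and a single stats dict (start()
# above). It presumably wraps tensorboardX's SummaryWriter.add_scalar; the sketch
# below only illustrates that assumed behaviour and is not the repo's actual helper.
def write_tb_sketch(writer, names_or_stats, values_or_iters, iters=None, prefix=""):
    """Hypothetical stand-in for write_tb: log scalar metrics under `prefix`."""
    if isinstance(names_or_stats, dict):
        # write_tb(writer, stats_dict, iters, prefix=...)
        stats, step = names_or_stats, values_or_iters
    else:
        # write_tb(writer, names, values, iters, prefix=...)
        stats, step = dict(zip(names_or_stats, values_or_iters)), iters
    for name, value in stats.items():
        writer.add_scalar(prefix + name, float(value), step)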