def running_evaluate_for_val2(model):
    """Evaluate one minibatch from the running val2 iterator.

    Uses the module-level iterator ``runinng_val2_data_iter`` (sic — name kept
    for compatibility with other call sites), re-creating it from
    ``val2_loader`` whenever it is exhausted or not yet initialized.

    :param model: model exposing ``eval()`` and ``__call__(context, target)``
        returning ``(_, _, loss, info)``.
    :return: scalar loss, normalized per sample and scaled by ``num_modalities``.
    """
    global runinng_val2_data_iter
    model.eval()
    with torch.no_grad():
        # Fetch one minibatch. StopIteration: iterator exhausted; NameError:
        # iterator never created yet (first call). Both restart from the loader.
        # (Was a bare `except:` with Py2-style `.next()` calls.)
        try:
            _, eval_context, eval_target = next(runinng_val2_data_iter)
        except (StopIteration, NameError):
            runinng_val2_data_iter = iter(val2_loader)
            _, eval_context, eval_target = next(runinng_val2_data_iter)
        # move batch to the target device
        eval_context = batch_to_device(eval_context, device)
        eval_target = batch_to_device(eval_target, device)
        batch_size, _ = get_batch_size(eval_target)
        # forward; only the scalar loss is consumed here
        _, _, loss, _ = model(eval_context, eval_target)
        # per-sample loss scaled by the number of modalities
        total_loss = loss.item() / batch_size * num_modalities
    return total_loss
def eval_vae(epoch, args, trainer, eval_data):
    """Generate questions over the eval set and score them with BLEU.

    :param epoch: current epoch (unused here; kept for interface compatibility).
    :param args: namespace providing ``bert_model`` and ``device``.
    :param trainer: trainer whose ``model.generate(c_ids, a_ids)`` produces
        question token ids.
    :param eval_data: tuple ``(eval_loader, eval_examples, eval_features)``.
    :return: BLEU score from ``eval_qg`` comparing generated vs. real questions.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    # eval_examples is unused; only the loader and per-example features matter.
    eval_loader, _, eval_features = eval_data
    qg_results = {}   # unique_id -> generated question text
    res_dict = {}     # unique_id -> reference (real) question text
    example_index = -1
    for batch in tqdm(eval_loader, desc="Eval iter", leave=False, position=3):
        c_ids, q_ids, a_ids = batch_to_device(batch, args.device)
        batch_size = c_ids.size(0)
        batch_q_ids = q_ids.cpu().tolist()
        generated_q_ids = trainer.model.generate(c_ids, a_ids)
        generated_q_ids = generated_q_ids.cpu().tolist()
        # Batches are assumed to arrive in the same order as eval_features,
        # so a running example_index lines the two up.
        for i in range(batch_size):
            example_index += 1
            eval_feature = eval_features[example_index]
            unique_id = int(eval_feature.unique_id)
            real_question = to_string(batch_q_ids[i], tokenizer)
            generated_question = to_string(generated_q_ids[i], tokenizer)
            qg_results[unique_id] = generated_question
            res_dict[unique_id] = real_question
    bleu = eval_qg(res_dict, qg_results)
    return bleu
def _train_epoch(self, epoch):
    """
    Training logic for an epoch

    :param epoch: Integer, current training epoch.
    :return: A log that contains average loss and metric in this epoch.
    """
    self.model.train()
    self.train_metrics.reset()
    tqdm_bar = tqdm(self.data_loader, desc='Train Epoch : {}'.format(epoch))
    for batch_idx, batch in enumerate(tqdm_bar):
        # batch_to_device selects the tensors named by self.model_inputs from
        # the batch and moves them to the training device.
        data, target = batch_to_device(self.model_inputs, batch, self.device)
        # data, target = data.to(self.device), target.to(self.device)
        self.optimizer.zero_grad()
        output = self.model(data)
        loss = self.criterion(output, target)
        loss.backward()
        self.optimizer.step()
        # one global step per batch so tensorboard curves line up across epochs
        self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
        self.train_metrics.update('loss', loss.item())
        for met in self.metric_ftns:
            if 'accuracy_diff' not in met.__name__:
                self.train_metrics.update(met.__name__, met(output, target))
            else:
                # accuracy_diff-style metrics additionally consume the raw
                # batch's 'q_level_logic' field and return multiple values.
                self.train_metrics.update(
                    met.__name__, *met(output, target, batch['q_level_logic']))
        if batch_idx % self.log_step == 0 or batch_idx == self.len_epoch - 1:
            tqdm_bar.set_description(
                'Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
            """
            self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                epoch, self._progress(batch_idx), loss.item()))
            """
        #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        # early exit when len_epoch caps the number of batches per epoch
        if batch_idx == self.len_epoch:
            break
    log = self.train_metrics.result()
    if self.do_validation:
        val_log = self._valid_epoch(epoch)
        # validation metrics are merged in under a 'val_' prefix
        log.update(**{'val_' + k: v for k, v in val_log.items()})
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()
    return log
def main(args):
    """Train a VAE QA/QG model; after epoch 10, evaluate BLEU/EM/F1 each
    epoch and checkpoint the best-F1 and best-BLEU weights."""
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(
        tokenizer, args.train_dir, shuffle=True, args=args)
    eval_data = get_squad_data_loader(
        tokenizer, args.dev_dir, shuffle=False, args=args)
    args.device = torch.cuda.current_device()
    trainer = VAETrainer(args)

    # fixed-row tqdm bars used as live status lines
    loss_log1 = tqdm(total=0, bar_format='{desc}', position=2)
    loss_log2 = tqdm(total=0, bar_format='{desc}', position=3)
    eval_log = tqdm(total=0, bar_format='{desc}', position=5)
    best_eval_log = tqdm(total=0, bar_format='{desc}', position=6)

    print("MODEL DIR: " + args.model_dir)

    best_bleu, best_em, best_f1 = 0.0, 0.0, 0.0
    for epoch in trange(int(args.epochs), desc="Epoch", position=0):
        for batch in tqdm(train_loader, desc="Train iter", leave=False,
                          position=1):
            c_ids, q_ids, a_ids, start_positions, end_positions = \
                batch_to_device(batch, args.device)
            trainer.train(c_ids, q_ids, a_ids, start_positions, end_positions)

            rec_msg = 'Q REC : {:06.4f} A REC : {:06.4f}'.format(
                float(trainer.loss_q_rec), float(trainer.loss_a_rec))
            kl_msg = 'ZQ KL : {:06.4f} ZA KL : {:06.4f} INFO : {:06.4f}'.format(
                float(trainer.loss_zq_kl), float(trainer.loss_za_kl),
                float(trainer.loss_info))
            loss_log1.set_description_str(rec_msg)
            loss_log2.set_description_str(kl_msg)

        # evaluation is skipped for the first epochs to save time
        if epoch > 10:
            metric_dict, bleu, _ = eval_vae(epoch, args, trainer, eval_data)
            f1 = metric_dict["f1"]
            em = metric_dict["exact_match"]
            bleu = bleu * 100

            status = '{}-th Epochs BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            eval_log.set_description_str(status.format(epoch, bleu, em, f1))

            if em > best_em:
                best_em = em
            if f1 > best_f1:
                best_f1 = f1
                trainer.save(os.path.join(args.model_dir, "best_f1_model.pt"))
            if bleu > best_bleu:
                best_bleu = bleu
                trainer.save(os.path.join(args.model_dir, "best_bleu_model.pt"))

            best = 'BEST BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
            best_eval_log.set_description_str(
                best.format(best_bleu, best_em, best_f1))
def main(args):
    """Train the question-generation model; evaluate BLEU each epoch and
    checkpoint the best-BLEU weights.

    Fixes vs. original: local variable named ``str`` shadowed the builtin
    (renamed to ``msg``); unused ``stack`` counter removed.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(tokenizer, args.train_dir,
                                               shuffle=True, args=args)
    eval_data = get_squad_data_loader(tokenizer, args.dev_dir,
                                      shuffle=False, args=args)
    args.device = torch.cuda.current_device()
    trainer = Trainer(args)

    # tensorboard logs go under a per-host subdirectory of the model dir
    log_dir = os.path.join(args.model_dir, socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)

    # fixed-row tqdm bars used as live status lines
    loss_log = tqdm(total=0, bar_format='{desc}', position=2)
    eval_log = tqdm(total=0, bar_format='{desc}', position=4)
    best_eval_log = tqdm(total=0, bar_format='{desc}', position=5)

    print("MODEL DIR: " + args.model_dir)

    niter = 0
    best_bleu = 0.0
    for epoch in trange(int(args.epochs), desc="Epoch", position=0):
        for batch in tqdm(train_loader, desc="Train iter", leave=False,
                          position=1):
            c_ids, q_ids, a_ids = batch_to_device(batch, args.device)
            trainer.train(c_ids, q_ids, a_ids)

            niter += 1
            writer.add_scalars('data/loss_group',
                               {'loss_q_rec': trainer.loss_q_rec}, niter)
            msg = 'Q REC : {:06.4f}'.format(float(trainer.loss_q_rec))
            loss_log.set_description_str(msg)

        bleu = eval_vae(epoch, args, trainer, eval_data)
        bleu = bleu * 100
        msg = '{}-th Epochs BLEU : {:02.2f}'.format(epoch, bleu)
        eval_log.set_description_str(msg)
        writer.add_scalars('data/performance', {'bleu': bleu}, epoch)

        if bleu > best_bleu:
            best_bleu = bleu
            trainer.save(os.path.join(args.model_dir, "best_bleu_model.pt"))

        msg = 'BEST BLEU : {:02.2f}'.format(best_bleu)
        best_eval_log.set_description_str(msg)
def _valid_epoch(self, epoch):
    """
    Validate after training an epoch

    :param epoch: Integer, current training epoch.
    :return: A log that contains information about validation
    """
    self.model.eval()
    self.valid_metrics.reset()
    with torch.no_grad():
        tqdm_bar = tqdm(self.valid_data_loader,
                        desc='Valid Epoch: {}'.format(epoch))
        for batch_idx, batch in enumerate(tqdm_bar):
            # batch_to_device selects the tensors named by self.model_inputs
            # from the batch and moves them to the device.
            data, target = batch_to_device(self.model_inputs, batch,
                                           self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # log under the 'valid' tensorboard mode at a matching global step
            self.writer.set_step(
                (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                'valid')
            self.valid_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                if 'accuracy_diff' not in met.__name__:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
                else:
                    # accuracy_diff-style metrics additionally consume the
                    # batch's 'q_level_logic' field and return multiple values.
                    self.valid_metrics.update(
                        met.__name__,
                        *met(output, target, batch['q_level_logic']))
            #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

    # add histogram of model parameters to the tensorboard
    if self.config['add_histogram']:
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
    return self.valid_metrics.result()
def evaluate(eval_loader, test=False):
    """Evaluate per-modality log-probabilities over a grid of
    (image-context size, haptic-context size) combinations.

    For every combination it estimates the target log-likelihood with
    NUM_Z_SAMPLES latent samples via a log-sum-exp average (appears to be an
    importance-sampling style estimate — the math below computes
    log(1/N * sum(exp(w)))), normalizes per data dimension, accumulates the
    totals, then logs a text table and tensorboard scalars.

    :param eval_loader: yields (info, context, target) batches.
    :param test: selects the 'test' vs 'val' name tag for logging.
    :return: nested list total_mod_logprobs_per_sources indexed as
        [img_ctx_idx][hpt_ctx_idx][modality]; entries are per-dim averages,
        or None where no data was accumulated.
    """
    # Turn on evaluation mode which disables dropout.
    name='test' if test else 'val'
    model.eval()
    start_time = time.time()
    NUM_TEST = 5
    NUM_CONTEXTS = [0, 1, 5, 10]
    indices_NUM_CONTEXTS = {0:0, 1:1, 5:2, 10:3}
    # every episode must provide enough views for context + held-out test
    assert (NUM_TEST + NUM_CONTEXTS[-1]) <= dataset_info['nviews']
    NUM_Q_SAMPLES = opt.num_q_samples
    NUM_Z_SAMPLES = opt.num_z_samples
    NUM_ITERS = min(opt.num_iters, len(eval_loader))
    total_loss = 0.
    total_batch_size = 0
    # accumulators indexed [img_ctx][hpt_ctx][modality]
    # (note: both comprehension levels reuse the name num_context)
    total_mod_batch_sizes_per_sources = [[
        [0 for i in range(num_modalities)]
        for num_context in NUM_CONTEXTS] for num_context in NUM_CONTEXTS]
    total_mod_logprobs_per_sources = [[
        [0 for i in range(num_modalities)]
        for num_context in NUM_CONTEXTS] for num_context in NUM_CONTEXTS]
    with torch.no_grad():
        for i_query in range(NUM_Q_SAMPLES):
            for batch_idx, (eval_info, eval_context, eval_target) in enumerate(eval_loader):
                # init batch
                eval_context = batch_to_device(eval_context, device)
                eval_target = batch_to_device(eval_target, device)
                # get merged context / target
                eval_all = merge_two_batch(eval_context, eval_target)
                num_episodes = len(eval_all)
                ''' temporary '''
                # closure over eval_all / accumulators; evaluates one
                # (img_num_context, hpt_num_context) cell of the grid
                def new_eval(img_num_context=0, hpt_num_context=0):
                    # get context
                    new_eval_context, new_eval_target = binary_trim_context_target(eval_all, img_num_context=img_num_context, hpt_num_context=hpt_num_context, num_modalities=num_modalities)
                    # skip the cell entirely when no modality has target data
                    if sum([int(new_eval_target[0][j*2] is None) for j in range(num_modalities)]) == num_modalities:
                        return
                    # select test: keep only the last NUM_TEST views per entry
                    _new_eval_target = []
                    for target in new_eval_target:
                        _target = tuple([target[i][-NUM_TEST:] for i in range(len(target))])
                        _new_eval_target += [_target]
                    new_eval_target = _new_eval_target
                    # get batch size
                    _, mod_batch_sizes = get_batch_size(new_eval_target)
                    # loss
                    loss_mod_logprobs = [None]*num_modalities
                    for i in range(num_modalities):
                        # NOTE(review): elsewhere this check is written as
                        # new_eval_target[0][i*2]; here the [0] is missing, so
                        # this indexes an episode tuple, not a modality slot —
                        # looks like a latent bug; confirm intent.
                        if new_eval_target[i*2] is None:
                            pass
                        else:
                            # mask out every modality except i in the target
                            newnew_eval_target = [
                                tuple([target[j] if j//2 == i else None for j in range(num_modalities*2)])
                                for target in new_eval_target
                            ]
                            # get batch size
                            batch_size, _ = get_batch_size(newnew_eval_target)
                            assert batch_size == mod_batch_sizes[i]
                            # get dim size per episode
                            dim_per_eps = get_dim_size(newnew_eval_target, is_grayscale=opt.grayscale)
                            # forward: one logprob sample per latent draw
                            logprobs = []
                            for j in range(NUM_Z_SAMPLES):
                                # forward
                                _, _, logprob, info = model.predict(new_eval_context, newnew_eval_target, is_grayscale=opt.grayscale, use_uint8=opt.uint8)
                                # append to loss_logprobs
                                logprobs += [logprob.unsqueeze(1)]
                            # concat
                            logprobs = torch.cat(logprobs, dim=1)
                            # numerically-stable log-mean-exp over z samples
                            _logprobs_max, _ = torch.max(logprobs, dim=1, keepdim=True)
                            _logprobs = logprobs - _logprobs_max # w - \hat(w)
                            _logprobs = torch.log(torch.sum(_logprobs.exp(), dim=1, keepdim=True)) # log sum(exp(w - \hat(w)))
                            logprobs = -math.log(float(NUM_Z_SAMPLES)) + _logprobs_max + _logprobs # log(1/NUM_Z_SAMPLES) + w + log sum(exp(w - \hat(w)))
                            # get logprob per dimension
                            for j in range(num_episodes):
                                logprobs[j:j+1] /= float(dim_per_eps[j])
                            # add to total_mod_logprobs
                            loss_mod_logprobs[i] = torch.sum(logprobs)
                    # add to total_loss (0 contribution where modality absent)
                    for i in range(num_modalities):
                        total_mod_logprobs_per_sources[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i] += loss_mod_logprobs[i].item() if loss_mod_logprobs[i] is not None else 0
                        total_mod_batch_sizes_per_sources[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i] += mod_batch_sizes[i]
                # run new_eval over the full (img, hpt) context grid
                for hpt_num_context in NUM_CONTEXTS:
                    for img_num_context in NUM_CONTEXTS:
                        new_eval(img_num_context, hpt_num_context)
                if (batch_idx+1) % opt.vis_interval == 0:
                    elapsed = time.time() - start_time
                    print('[', i_query+1, '/', NUM_Q_SAMPLES, ']', ' ', batch_idx+1, '/', NUM_ITERS, 'elapsed: {:.3f} ms'.format(elapsed*1000/opt.vis_interval))
                    start_time = time.time()
                if (batch_idx+1) == NUM_ITERS:
                    break
    # normalize each accumulated total by (samples * dims); None when empty
    for hpt_num_context in NUM_CONTEXTS:
        for img_num_context in NUM_CONTEXTS:
            for i, (channels, height, width, nc_query, mtype) in enumerate(dataset_info['dims']):
                batch_size = total_mod_batch_sizes_per_sources[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i]
                if mtype == 'image' and opt.grayscale:
                    dim = 1*height*width
                else:
                    dim = channels*height*width
                if batch_size > 0:
                    total_mod_logprobs_per_sources[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i] /= float(batch_size*dim)
                else:
                    total_mod_logprobs_per_sources[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i] = None
    # print: text table + tensorboard scalars; i selects modality
    # (0 = image section, 1 = haptic section)
    for loss_func, loss_name in zip([total_mod_logprobs_per_sources], ['logprob']):
        i = 0
        logging('', path=opt.new_path)
        logging('', path=opt.new_path)
        logging('', path=opt.new_path)
        logging('--------------------', path=opt.new_path)
        logging('({}) predict = img (per pixel/dim)'.format(loss_name), path=opt.new_path)
        logging(''.join(['hpt: V, img: > '] + [' {:4d}'.format(img_num_context) for img_num_context in NUM_CONTEXTS]), path=opt.new_path)
        for hpt_num_context in NUM_CONTEXTS:
            txt = '{:4d}'.format(hpt_num_context)
            for img_num_context in NUM_CONTEXTS:
                loss = loss_func[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i]
                writer.add_scalar('{}/pred_img/nch{}/'.format(loss_name, hpt_num_context), loss if loss is not None else -float("inf"), img_num_context)
                writer.add_scalar('{}/pred_img/nci{}/'.format(loss_name, img_num_context), loss if loss is not None else -float("inf"), hpt_num_context)
                txt += ' {:5.8f}'.format(loss if loss is not None else -float("inf"))
            logging(txt, path=opt.new_path)
        logging('--------------------', path=opt.new_path)
        i = 1
        logging('', path=opt.new_path)
        logging('--------------------', path=opt.new_path)
        logging('({}) predict = hpt (per pixel/dim)'.format(loss_name), path=opt.new_path)
        logging(''.join(['hpt: V, img: > '] + [' {:4d}'.format(img_num_context) for img_num_context in NUM_CONTEXTS]), path=opt.new_path)
        for hpt_num_context in NUM_CONTEXTS:
            txt = '{:4d}'.format(hpt_num_context)
            for img_num_context in NUM_CONTEXTS:
                loss = loss_func[indices_NUM_CONTEXTS[img_num_context]][indices_NUM_CONTEXTS[hpt_num_context]][i]
                writer.add_scalar('{}/pred_hpt/nch{}/'.format(loss_name, hpt_num_context), loss if loss is not None else -float("inf"), img_num_context)
                writer.add_scalar('{}/pred_hpt/nci{}/'.format(loss_name, img_num_context), loss if loss is not None else -float("inf"), hpt_num_context)
                txt += ' {:5.8f}'.format(loss if loss is not None else -float("inf"))
            logging(txt, path=opt.new_path)
        logging('--------------------', path=opt.new_path)
    return total_mod_logprobs_per_sources
def evaluate(eval_loader, name='val'):
    """Episode-classification accuracy over a grid of (modality mask,
    context size) combinations.

    For each mask and context size, each episode's target is scored against
    every episode's context via model.predict (log-mean-exp over
    NUM_Z_SAMPLES latent draws, normalized per dimension); top-1/top-5
    accuracy of matching targets to their own episode is accumulated both
    per-mask and averaged per modality count, then logged as text tables
    and tensorboard scalars.

    :param eval_loader: yields (info, context, target) batches.
    :param name: tag used in log lines and tensorboard keys.
    :return: (total_acc1 / total_batch_size, total_acc5 / total_batch_size)
        — NOTE(review): these are leftover loop variables from the *last*
        printed mask/context cell, and total_acc1/total_acc5 were already
        divided by total_batch_size above, so the return divides twice;
        looks suspicious — confirm intended semantics before relying on it.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    start_time = time.time()
    #NUM_Q_SAMPLES = opt.num_q_samples
    NUM_Z_SAMPLES = opt.num_z_samples
    NUM_ITERS = min(opt.num_iters, len(eval_loader))
    MOD_STEP = opt.mod_step  #5 #3
    # modality counts to sweep: 1, then every MOD_STEP-th, plus the max
    NUM_MODS = sorted(
        list(
            set([1] +
                [n_mods for n_mods in range(2, num_modalities + 1, MOD_STEP)] +
                [num_modalities])))
    NUM_CONTEXTS = [0, 1, 2, 3, 4, 5]
    NUM_TARGET = 5
    MASK_STEP = opt.mask_step  #10 #5
    # subsample the exponential mask space: every MASK_STEP-th mask + the last
    all_masks = []
    for n_mods in NUM_MODS:
        masks = get_masks(num_modalities, min_modes=n_mods, max_modes=n_mods)
        masks = list(set(masks[::MASK_STEP] + [masks[-1]]))
        all_masks += masks
    # mask string -> row index in the per-mask accumulators
    m_indices = dict(
        zip([get_str_from_mask(mask) for mask in all_masks],
            [i for i in range(len(all_masks))]))
    logging('num mods : {}'.format(NUM_MODS), path=opt.path)
    logging('num ctxs : {}'.format(NUM_CONTEXTS), path=opt.path)
    logging('num tgt : {}'.format(NUM_TARGET), path=opt.path)
    logging('mask step: {}'.format(MASK_STEP), path=opt.path)
    logging('masks : {}'.format(m_indices), path=opt.path)
    # accumulators: *_avg_* indexed by modality count, the rest by mask index
    total_avg_batch_sizes_per_nmod_nctx = [
        [0 for i in range(len(NUM_CONTEXTS))] for j in range(num_modalities)
    ]
    total_avg_acc1_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                    for j in range(num_modalities)]
    total_avg_acc5_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                    for j in range(num_modalities)]
    total_batch_sizes_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                       for j in range(len(all_masks))]
    total_acc1_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                for j in range(len(all_masks))]
    total_acc5_per_nmod_nctx = [[0 for i in range(len(NUM_CONTEXTS))]
                                for j in range(len(all_masks))]
    with torch.no_grad():
        for batch_idx, (eval_info, eval_context, eval_target) in enumerate(eval_loader):
            # init batch
            eval_context = batch_to_device(eval_context, device)
            eval_target = batch_to_device(eval_target, device)
            eval_all = merge_two_batch(eval_context, eval_target)
            num_episodes = len(eval_context)
            # select target: keep only the last NUM_TARGET entries per slot
            _new_eval_target = []
            for target in eval_all:
                _target = tuple([
                    target[i][-NUM_TARGET:] if target[i] is not None else None
                    for i in range(len(target))
                ])
                _new_eval_target += [_target]
            eval_target = _new_eval_target
            # forward
            for n_mods in NUM_MODS:
                # regenerate the same subsampled mask list as above
                masks = get_masks(num_modalities,
                                  min_modes=n_mods,
                                  max_modes=n_mods)
                masks = list(set(masks[::MASK_STEP] + [masks[-1]]))
                for mask in masks:
                    n_mods = sum(mask)
                    avg_m_idx = n_mods - 1
                    m_idx = m_indices[get_str_from_mask(mask)]
                    for c_idx, num_context in enumerate(NUM_CONTEXTS):
                        # select context: first num_context entries, only for
                        # modalities enabled in the mask
                        _new_eval_context = []
                        for context in eval_all:
                            _context = tuple([
                                context[i][:num_context]
                                if context[i] is not None and num_context > 0
                                and mask[i // 2] else None
                                for i in range(len(context))
                            ])
                            _new_eval_context += [_context]
                        eval_context = _new_eval_context
                        # get labels: target i should match context/episode i
                        eval_label = torch.Tensor(
                            [i for i in range(num_episodes)]).long().to(device)
                        # infer: score each episode's target against all contexts
                        logprobs_per_batch = []
                        for i_ep in range(num_episodes):
                            new_eval_target = [eval_target[i_ep]
                                               ] * num_episodes
                            # get dim size per episode
                            dim_per_eps = get_dim_size(
                                new_eval_target, is_grayscale=opt.grayscale)
                            # forward
                            logprobs = []
                            for j in range(NUM_Z_SAMPLES):
                                # forward
                                _, _, logprob, info = model.predict(
                                    eval_context,
                                    new_eval_target,
                                    is_grayscale=opt.grayscale,
                                    use_uint8=False)
                                # append to loss_logprobs
                                logprobs += [logprob.unsqueeze(1)]
                            # concat
                            logprobs = torch.cat(logprobs, dim=1)
                            # log-mean-exp over the z samples
                            logprobs = logprob_logsumexp(logprobs).detach()
                            # get logprob per dimension
                            for i in range(num_episodes):
                                logprobs[i:i + 1] /= float(dim_per_eps[i])
                            # append
                            logprobs_per_batch += [logprobs.unsqueeze(1)]
                        # concat: rows = contexts, cols = targets
                        logprobs_per_batch = torch.cat(logprobs_per_batch,
                                                       dim=1)
                        # get acc
                        acc1, acc5 = accuracy(logprobs_per_batch,
                                              eval_label,
                                              topk=(1, 5))
                        cur_acc1 = acc1[0].item()
                        cur_acc5 = acc5[0].item()
                        # weight by episode count for later averaging
                        total_avg_acc1_per_nmod_nctx[avg_m_idx][
                            c_idx] += cur_acc1 * num_episodes
                        total_avg_acc5_per_nmod_nctx[avg_m_idx][
                            c_idx] += cur_acc5 * num_episodes
                        total_avg_batch_sizes_per_nmod_nctx[avg_m_idx][
                            c_idx] += num_episodes
                        total_acc1_per_nmod_nctx[m_idx][
                            c_idx] += cur_acc1 * num_episodes
                        total_acc5_per_nmod_nctx[m_idx][
                            c_idx] += cur_acc5 * num_episodes
                        total_batch_sizes_per_nmod_nctx[m_idx][
                            c_idx] += num_episodes
            # plot
            if (batch_idx + 1) % opt.vis_interval == 0 or (
                    batch_idx + 1) == len(eval_loader):
                elapsed = time.time() - start_time
                start_time = time.time()
                # print (cur_acc1/cur_acc5 are from the last cell evaluated)
                logging('| {} '
                        '| {:5d}/{:5d} '
                        '| sec/step {:5.2f} '
                        '| acc (top1) {:.3f} '
                        '| acc (top5) {:.3f} '.format(
                            name,
                            batch_idx + 1,
                            len(eval_loader),
                            elapsed / opt.vis_interval,
                            cur_acc1,
                            cur_acc5,
                        ),
                        path=opt.path)
            if (batch_idx + 1) == NUM_ITERS:
                break
    # print per-mask table (acc1 | acc5 across context sizes)
    logging(''.join(
        ['masks V / # of context > '] +
        [' {:4d}'.format(num_context) for num_context in NUM_CONTEXTS]),
            path=opt.new_path)
    logging('=' * 17 + ' acc1 ' + '=' * 17 + ' | ' + '=' * 17 + ' acc5 ' +
            '=' * 17,
            path=opt.new_path)
    for mask in all_masks:
        mask_str = get_str_from_mask(mask)
        m_idx = m_indices[mask_str]
        txt = ' {} |'.format(mask_str)
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_batch_size = total_batch_sizes_per_nmod_nctx[m_idx][c_idx]
            total_acc1 = total_acc1_per_nmod_nctx[m_idx][
                c_idx] / total_batch_size
            writer.add_scalar('mask{}/{}/acc1'.format(mask_str, name),
                              total_acc1, num_context)
            txt += ' {:3.1f}'.format(total_acc1)
        txt += ' | '
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_batch_size = total_batch_sizes_per_nmod_nctx[m_idx][c_idx]
            total_acc5 = total_acc5_per_nmod_nctx[m_idx][
                c_idx] / total_batch_size
            writer.add_scalar('mask{}/{}/acc5'.format(mask_str, name),
                              total_acc5, num_context)
            txt += ' {:3.1f}'.format(total_acc5)
        logging(txt, path=opt.new_path)
    # print per-modality-count table
    logging('', path=opt.new_path)
    logging('', path=opt.new_path)
    logging(''.join(
        ['# of mods V / # of context > '] +
        [' {:4d}'.format(num_context) for num_context in NUM_CONTEXTS]),
            path=opt.new_path)
    logging('=' * 17 + ' acc1 ' + '=' * 17 + ' | ' + '=' * 17 + ' acc5 ' +
            '=' * 17,
            path=opt.new_path)
    for n_mods in NUM_MODS:
        avg_m_idx = n_mods - 1
        txt = ' {} |'.format(n_mods)
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_avg_batch_size = total_avg_batch_sizes_per_nmod_nctx[
                avg_m_idx][c_idx]
            total_avg_acc1 = total_avg_acc1_per_nmod_nctx[avg_m_idx][
                c_idx] / total_avg_batch_size
            writer.add_scalar('M{}/{}/acc1'.format(n_mods, name),
                              total_avg_acc1, num_context)
            writer.add_scalar('C{}/{}/acc1'.format(num_context, name),
                              total_avg_acc1, n_mods)
            txt += ' {:3.1f}'.format(total_avg_acc1)
        txt += ' | '
        for c_idx, num_context in enumerate(NUM_CONTEXTS):
            total_avg_batch_size = total_avg_batch_sizes_per_nmod_nctx[
                avg_m_idx][c_idx]
            total_avg_acc5 = total_avg_acc5_per_nmod_nctx[avg_m_idx][
                c_idx] / total_avg_batch_size
            writer.add_scalar('M{}/{}/acc5'.format(n_mods, name),
                              total_avg_acc5, num_context)
            writer.add_scalar('C{}/{}/acc5'.format(num_context, name),
                              total_avg_acc5, n_mods)
            txt += ' {:3.1f}'.format(total_avg_acc5)
        logging(txt, path=opt.new_path)
    return total_acc1 / total_batch_size, total_acc5 / total_batch_size
def main(args):
    """Train the VAE QA/QG model; evaluate BLEU/EM/F1 every epoch, keep the
    best-F1 and best-BLEU checkpoints, and log z-embeddings to tensorboard.

    Fixes vs. original: local variable named ``str`` shadowed the builtin
    (renamed to ``msg``/``meta``); unused ``stack`` and ``best_avg_qa_loss``
    removed; the log header before the posterior samples wrongly said
    "prior" — corrected to "posterior".
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(tokenizer, args.train_dir,
                                               shuffle=True, args=args)
    eval_data = get_squad_data_loader(tokenizer, args.dev_dir,
                                      shuffle=False, args=args)
    args.device = torch.cuda.current_device()
    trainer = VAETrainer(args)

    # tensorboard logs go under a per-host subdirectory of the model dir
    log_dir = os.path.join(args.model_dir, socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)

    # fixed-row tqdm bars used as live status lines
    loss_log = tqdm(total=0, bar_format='{desc}', position=2)
    eval_log = tqdm(total=0, bar_format='{desc}', position=4)
    best_eval_log = tqdm(total=0, bar_format='{desc}', position=5)

    print("MODEL DIR: " + args.model_dir)

    niter = 0
    best_bleu, best_em, best_f1 = 0.0, 0.0, 0.0
    for epoch in trange(int(args.epochs), desc="Epoch", position=0):
        for batch in tqdm(train_loader, desc="Train iter", leave=False,
                          position=1):
            c_ids, q_ids, a_ids, start_positions, end_positions = \
                batch_to_device(batch, args.device)
            trainer.train(c_ids, q_ids, a_ids, start_positions, end_positions)

            niter += 1
            writer.add_scalars(
                'data/loss_group', {
                    'loss_q_rec': trainer.loss_q_rec,
                    'loss_a_rec': trainer.loss_a_rec,
                    'loss_kl': trainer.loss_kl,
                    'loss_info': trainer.loss_info
                }, niter)
            msg = 'Q REC : {:06.4f} A REC : {:06.4f} KL : {:06.4f} INFO : {:06.4f}'
            msg = msg.format(float(trainer.loss_q_rec),
                             float(trainer.loss_a_rec),
                             float(trainer.loss_kl),
                             float(trainer.loss_info))
            loss_log.set_description_str(msg)

        metric_dict, bleu, all_results = \
            eval_vae(epoch, args, trainer, eval_data)
        f1 = metric_dict["f1"]
        em = metric_dict["exact_match"]
        bleu = bleu * 100
        msg = '{}-th Epochs BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
        eval_log.set_description_str(msg.format(epoch, bleu, em, f1))
        writer.add_scalars('data/performance', {
            'bleu': bleu,
            'em': em,
            'f1': f1
        }, epoch)

        if em > best_em:
            best_em = em
        if f1 > best_f1:
            best_f1 = f1
            trainer.save(os.path.join(args.model_dir, "best_f1_model.pt"))
        if bleu > best_bleu:
            best_bleu = bleu
            trainer.save(os.path.join(args.model_dir, "best_bleu_model.pt"))

        msg = 'BEST BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
        best_eval_log.set_description_str(
            msg.format(best_bleu, best_em, best_f1))

        # embedding projector: posterior and prior z vectors with readable
        # metadata strings; every 100th example also dumped to stdout
        mat = []
        metadata = []
        for j in range(len(all_results)):
            mat.append(all_results[j].posterior_z_prob.view(-1))
            meta = "[{}] [Pos] Real Q: {} Real A: {} Pos Q: {} Pos A: {}"
            meta = meta.format(j, all_results[j].real_question,
                               all_results[j].real_answer,
                               all_results[j].posterior_question,
                               all_results[j].posterior_answer)
            if j % 100 == 0:
                print('###################### real questions\n')
                print(all_results[j].real_question)
                print(all_results[j].real_answer)
                print('###################### generated posterior questions\n')
                print(all_results[j].posterior_question)
                print(all_results[j].posterior_answer)
                print('###################### generated prior questions\n')
                print(all_results[j].prior_question)
                print(all_results[j].prior_answer)
            metadata.append(meta)

            mat.append(all_results[j].prior_z_prob.view(-1))
            meta = "[{}] [Pri] Pri Q: {} Pri A: {}"
            meta = meta.format(j, all_results[j].prior_question,
                               all_results[j].prior_answer)
            metadata.append(meta)
        mat = torch.stack(mat, dim=0)
        writer.add_embedding(mat=mat, metadata=metadata, global_step=epoch)
def evaluate(eval_loader, test=False):
    """Qualitative evaluation: reconstruct and generate image/haptic targets
    under different modality masks and context sizes, and push grids of
    context/target/reconstruction/generation images (plus haptic diff
    scalars) to tensorboard.

    :param eval_loader: yields (info, context, target); batch size is
        asserted to be 1 episode.
    :param test: selects the 'test' vs 'val' tag used in tensorboard keys.
    :return: None (side effects only: writer images/scalars, stdout progress).
    """
    # Turn on evaluation mode which disables dropout.
    name = 'test' if test else 'val'
    model.eval()
    transform = get_transform()
    NUM_ITERS = min(opt.num_iters, len(eval_loader))
    NUM_TEST = 5
    # context sizes to sweep; fall back to defaults when not given on CLI
    NUM_CONTEXTS = sorted([nc for nc in opt.n_context])
    if len(NUM_CONTEXTS) == 0:
        NUM_CONTEXTS = [0, 1, 5, 10]  #[0, 1, 5, 10, 15]
    assert (NUM_TEST + NUM_CONTEXTS[-1]) <= dataset_info['nviews']
    # modality counts to sweep; default is every count up to num_modalities
    NUM_MODS = sorted([n_mods for n_mods in opt.n_mods])
    if len(NUM_MODS) == 0:
        NUM_MODS = [n_mod for n_mod in range(1, num_modalities + 1)]
    assert NUM_MODS[-1] <= num_modalities
    assert NUM_MODS[0] > 0
    # all modality masks with exactly n_mods modalities enabled
    all_masks = []
    for n_mods in NUM_MODS:
        masks = get_masks(num_modalities, min_modes=n_mods, max_modes=n_mods)
        all_masks += masks
    logging('num mods : {}'.format(NUM_MODS), path=opt.path)
    logging('num ctxs : {}'.format(NUM_CONTEXTS), path=opt.path)
    logging('num tgt : {}'.format(NUM_TEST), path=opt.path)
    logging('masks : {}'.format(all_masks), path=opt.path)
    # per-mask-string state filled in by draw_img_gen below:
    #   hpt_tgt_gen: num_context -> (haptic targets, haptic generations)
    #   avg_diffs / num_datas: running haptic diff accumulators
    hpt_tgt_gen = {}
    avg_diffs = {}
    num_datas = {}
    with torch.no_grad():
        for i_sample in range(1, opt.num_samples + 1):
            # only the first episode of each class is visualized per sample
            did_plot = [False] * num_classes
            for batch_idx, (eval_info, eval_context,
                            eval_target) in enumerate(eval_loader):
                # init batch
                eval_context = batch_to_device(eval_context, device)
                eval_target = batch_to_device(eval_target, device)
                eval_all = merge_two_batch(eval_context, eval_target)
                num_episodes = len(eval_context)
                # get img_queries (extra camera poses from the dataset info)
                img_queries = torch.from_numpy(
                    np.array(eval_info[0]['add_cameras'])).float()
                # get true_images and hand_images
                true_images = load_images(eval_info[0]['add_images'],
                                          transform)
                _true_images = get_grid_image(true_images, 16, 3, 64, 64,
                                              nrow=4, pad_value=0)
                hand_images = load_images(eval_info[0]['hand_images'],
                                          transform)
                _hand_images = get_grid_image(hand_images, 15, 3, 64, 64,
                                              nrow=15, pad_value=0)
                # collect the raw image-modality data of the first episode
                _data_images = []
                for idx, (nchannels, nheight, nwidth, _,
                          mtype) in enumerate(dataset_info['dims']):
                    if mtype == 'image':
                        _data_images += [eval_all[0][idx * 2]]
                _data_images = get_combined_visualization_image_data(
                    opt.dataset, dataset_info['dims'], _data_images,
                    dataset_info['nviews'], min(4, num_episodes), nrow=15,
                    pad_value=0)[0]
                ''' temporary '''
                assert len(eval_context) == 1
                assert len(eval_target) == 1
                cls = eval_info[0]['class']
                ''' per class '''
                # visualize per class
                if not did_plot[cls]:
                    # change flag
                    did_plot[cls] = True
                    # draw true_images and hand_images
                    writer.add_image(
                        '{}/gt-img-cls{}-i{}'.format(name, cls, i_sample),
                        _true_images, 0)
                    writer.add_image(
                        '{}/hand-img-cls{}-i{}'.format(name, cls, i_sample),
                        _hand_images, 0)
                    writer.add_image(
                        '{}/data-img-cls{}-i{}'.format(name, cls, i_sample),
                        _data_images, 0)
                    for num_context in NUM_CONTEXTS:
                        _hand_images = get_grid_image(
                            hand_images[:num_context], num_context, 3, 64,
                            64, nrow=5, pad_value=0)
                        writer.add_image(
                            '{}/ctx-hand-img-cls{}-i{}-nc{}'.format(
                                name, cls, i_sample, num_context),
                            _hand_images, 0)

                # Closure over the current episode and the module-level
                # accumulators: reconstructs + generates targets for one
                # (mask, num_context) setting and logs the image grids.
                def draw_img_gen(mask, num_context=0):
                    # get mask index
                    m_idx = sum(mask) - 1
                    # get context
                    new_eval_context, new_eval_target = trim_context_target(
                        eval_all,
                        num_context=num_context,
                        mask=mask,
                        num_modalities=num_modalities)
                    # nothing to do when no modality has target data
                    if sum([
                            int(new_eval_target[0][j * 2] is None)
                            for j in range(num_modalities)
                    ]) == num_modalities:
                        return
                    # select test: haptic slots keep only the last NUM_TEST
                    # entries; image slots are kept whole
                    _new_eval_target = []
                    for i in range(num_episodes):
                        _target = []
                        for idx, (nchannels, nheight, nwidth, _,
                                  mtype) in enumerate(dataset_info['dims']):
                            data, query = new_eval_target[i][
                                idx * 2], new_eval_target[i][idx * 2 + 1]
                            if mtype == 'haptic':
                                _target += [
                                    data[-NUM_TEST:]
                                    if data is not None else None
                                ]
                                _target += [
                                    query[-NUM_TEST:]
                                    if data is not None else None
                                ]
                            else:
                                _target += [data]
                                _target += [query]
                        _new_eval_target += [tuple(_target)]
                    new_eval_target = _new_eval_target
                    # get batch size
                    batch_size, mod_batch_sizes = get_batch_size(
                        new_eval_target)
                    # get queries
                    mod_queries, num_mod_queries = get_queries(
                        new_eval_target,
                        device,
                        num_hpt_queries=NUM_TEST,
                        img_queries=img_queries)
                    # forward (reconstruction of the given targets)
                    outputs, _, _, _ = model(new_eval_context,
                                             new_eval_target,
                                             is_grayscale=opt.grayscale)
                    # generate (free generation at the query poses)
                    gens, _ = model.generate(new_eval_context,
                                             tuple(mod_queries),
                                             is_grayscale=opt.grayscale)
                    # visualize: collect per-modality ctx/tgt/rec/gen tensors
                    img_ctxs, img_tgts, img_outputs, img_gens = [], [], [], []
                    hpt_ctxs, hpt_tgts, hpt_outputs, hpt_gens = [], [], [], []
                    for idx, (nchannels, nheight, nwidth, _,
                              mtype) in enumerate(dataset_info['dims']):
                        # get output and gen
                        output = outputs[idx]
                        gen = gens[idx]
                        _num_mod_queries = num_mod_queries[idx]
                        # visualize
                        if mtype == 'image':
                            # grayscale: expand single channel to nchannels
                            # so grids can be concatenated with RGB data
                            if opt.grayscale:
                                if output.size(0) > 0:
                                    output = output.expand(
                                        output.size(0), nchannels, nheight,
                                        nwidth)
                                gen = gen.expand(gen.size(0), nchannels,
                                                 nheight, nwidth)
                            # get ctx, tgt (zero-padded to nviews rows so all
                            # grids share the same layout)
                            if num_context > 0 and mask[idx]:
                                sz = new_eval_context[0][idx * 2].size()[1:]
                                ctx = torch.cat([
                                    new_eval_context[0][idx * 2],
                                    gen.new_zeros(
                                        dataset_info['nviews'] - num_context,
                                        *sz)
                                ], dim=0)
                                num_target = new_eval_target[0][
                                    idx * 2].size(0) if new_eval_target[0][
                                        idx * 2] is not None else 0
                                assert num_target == output.size(0)
                                if num_target > 0:
                                    tgt = torch.cat([
                                        gen.new_zeros(
                                            dataset_info['nviews'] -
                                            num_target, *sz),
                                        new_eval_target[0][idx * 2],
                                    ], dim=0)
                                    output = torch.cat([
                                        gen.new_zeros(
                                            dataset_info['nviews'] -
                                            num_target, *sz),
                                        output,
                                    ], dim=0)
                                else:
                                    tgt = gen.new_zeros(
                                        dataset_info['nviews'] *
                                        num_episodes, *sz)
                                    output = gen.new_zeros(
                                        dataset_info['nviews'] *
                                        num_episodes, *sz)
                            else:
                                ctx = gen.new_zeros(
                                    dataset_info['nviews'] * num_episodes,
                                    nchannels, nheight, nwidth)
                                tgt = new_eval_target[0][idx * 2]
                            # append to list
                            img_gens += [gen]
                            img_outputs += [output]
                            img_ctxs += [ctx]
                            img_tgts += [tgt]
                            num_img_queries = _num_mod_queries
                        elif mtype == 'haptic':
                            ctx = new_eval_context[0][idx * 2]
                            tgt = new_eval_target[0][idx * 2]
                            # append to list
                            hpt_gens += [gen]
                            hpt_outputs += [output]
                            hpt_ctxs += [ctx]
                            hpt_tgts += [tgt]
                            num_hpt_queries = _num_mod_queries
                        else:
                            raise NotImplementedError
                    # combine haptic: lazily create per-mask accumulators,
                    # then stash (targets, generations) for this num_context
                    if not get_str_from_mask(mask) in hpt_tgt_gen:
                        hpt_tgt_gen[get_str_from_mask(mask)] = {}
                        avg_diffs[get_str_from_mask(mask)] = np.zeros(
                            len(NUM_CONTEXTS))
                        num_datas[get_str_from_mask(mask)] = 0
                    hpt_tgts = torch.cat(hpt_tgts, dim=1)
                    hpt_gens = torch.cat(hpt_gens, dim=1)
                    hpt_tgt_gen[get_str_from_mask(mask)][num_context] = (
                        hpt_tgts, hpt_gens)
                    # visualize combined image
                    xgs = get_combined_visualization_image_data(
                        opt.dataset, dataset_info['dims'], img_gens,
                        num_img_queries, min(4, num_episodes), nrow=4,
                        pad_value=0)
                    xos = get_combined_visualization_image_data(
                        opt.dataset, dataset_info['dims'], img_outputs,
                        dataset_info['nviews'], min(4, num_episodes), nrow=4,
                        pad_value=0)
                    xcs = get_combined_visualization_image_data(
                        opt.dataset, dataset_info['dims'], img_ctxs,
                        dataset_info['nviews'], min(4, num_episodes), nrow=4,
                        pad_value=0)
                    _xcs = get_combined_visualization_image_data(
                        opt.dataset, dataset_info['dims'], img_ctxs,
                        num_context, min(4, num_episodes), nrow=5,
                        pad_value=0)
                    xts = get_combined_visualization_image_data(
                        opt.dataset, dataset_info['dims'], img_tgts,
                        dataset_info['nviews'], min(4, num_episodes), nrow=4,
                        pad_value=0)
                    for i, (xc, xt, xo, xg) in enumerate(
                            zip(xcs, xts, xos, xgs)):
                        writer.add_image(
                            '{}/ctx-cls{}-M{}-mask{}-i{}-nc{}/img'.format(
                                name, cls, m_idx + 1,
                                get_str_from_mask(mask), i, num_context),
                            _xcs[i], 0)
                        writer.add_image(
                            '{}/gen-cls{}-M{}-mask{}-i{}-nc{}/img'.format(
                                name, cls, m_idx + 1,
                                get_str_from_mask(mask), i, num_context),
                            xg, 0)
                        # side-by-side: context | target | reconstruction | gen
                        x = torch.cat([xc, xt, xo, xg], dim=2)
                        writer.add_image(
                            '{}/ctx-tgt-rec-gen-cls{}-M{}-mask{}-i{}-nc{}/img'
                            .format(name, cls, m_idx + 1,
                                    get_str_from_mask(mask), i, num_context),
                            x, 0)

                # run vis: sweep all masks and context sizes, then plot the
                # combined haptic target-vs-generation curves per mask
                for mask in all_masks:
                    for num_context in NUM_CONTEXTS:
                        draw_img_gen(mask=mask, num_context=num_context)
                    # visualize combined haptic
                    m_idx = sum(mask) - 1  # get mask index
                    xs, diffs = get_combined_visualization_haptic_data(
                        hpt_tgt_gen[get_str_from_mask(mask)],
                        title='mask: {}'.format(get_str_from_mask(mask)))
                    # mean diff across haptic channels, per context size
                    _diffs = [
                        np.mean([diffs[i][j] for i in range(len(diffs))])
                        for j in range(len(NUM_CONTEXTS))
                    ]
                    num_datas[get_str_from_mask(mask)] += 1
                    for j, diff in enumerate(_diffs):
                        avg_diffs[get_str_from_mask(mask)][j:j + 1] += diff
                        num_context = NUM_CONTEXTS[j]
                        writer.add_scalar(
                            '{}/diff-cls{}-M{}-mask{}-all/hpt'.format(
                                name, cls, m_idx + 1,
                                get_str_from_mask(mask)), diff, num_context)
                    for i, x in enumerate(xs):
                        writer.add_image(
                            '{}/tgt-gen-cls{}-M{}-mask{}-i{}/hpt'.format(
                                name, cls, m_idx + 1,
                                get_str_from_mask(mask), i),
                            convert_npimage_torchimage(x), 0)
                        for j, diff in enumerate(diffs[i]):
                            num_context = NUM_CONTEXTS[j]
                            writer.add_scalar(
                                '{}/diff-cls{}-M{}-mask{}-i{}/hpt'.format(
                                    name, cls, m_idx + 1,
                                    get_str_from_mask(mask), i), diff,
                                num_context)
                if (batch_idx + 1) % 1 == 0:
                    print(batch_idx + 1, '/', NUM_ITERS, ' [',
                          len(eval_loader), ']')
                if (batch_idx + 1) == NUM_ITERS:
                    break
    return
def train(self):
    """Run the full training loop over ``self.train_epochs`` epochs.

    Per step: forward/backward/optimizer update on ``self.train_dataloader``.
    Every ``self.eval_every`` steps the accumulated training predictions are
    evaluated and (if dev data exists) the dev set is evaluated with early
    stopping on ``self.patience``; at the end of each epoch a summary eval is
    logged (and, when ``eval_every == -1``, the dev eval runs there instead).

    Returns 0 on early stop (out of patience); otherwise returns None after
    all epochs and closes the tensorboard writer.
    """
    global_step = 0
    running_loss = 0.0   # loss accumulated since the last eval_every report
    eval_steps = 0       # steps since the last eval_every report
    for epoch in trange(self.train_epochs):
        train_preds = []
        train_labels = []
        logger.info(f"start training epoch {epoch + 1}")
        logger.info(f"training using device={self.device}")
        logger.info("\n*************hyperparam_dict**********\n")
        logger.info(json.dumps(self.hyperparam_dict, indent=2))
        epoch_train_loss = 0.0
        pbar = tqdm(enumerate(self.train_dataloader),
                    total=len(self.train_dataloader))
        for step, (X, y) in (pbar):
            self.optimizer.zero_grad()
            self.model.train()
            features, labels = batch_to_device(X, y, device=self.device)
            # cross-entropy requires integer class labels
            if self.loss_fn == "ce":
                labels = labels.long()
            # zero the parameter gradients
            outputs = self.model(features.float())
            loss = self.criterion(outputs, labels)
            # total_loss=loss + self.model.reg_loss# add reg_loss to avoid overfitting
            loss.backward()
            self.optimizer.step()
            pbar.set_description(
                f"training epoch {epoch + 1}/{self.train_epochs} iter {step}: train loss {loss.item():.5f}. lr {self.lr:e}"
            )
            # NOTE: conditional expression used as a statement for its side
            # effect — regression keeps raw outputs, classification keeps argmax
            train_preds.extend(outputs.tolist(
            )) if self.model.task == "reg" else train_preds.extend(
                torch.argmax(outputs, -1).tolist())
            train_labels.extend(labels.tolist())
            # print statistics
            running_loss += loss.item()
            epoch_train_loss += loss.item()
            eval_steps += 1
            if eval_steps == self.eval_every:
                logger.info(
                    f'\n*****************[epoch: {epoch + 1}, global step: {global_step + 1}] eval training set based on eval_every={self.eval_every}***************'
                )
                train_eval_metrics = self.eval_train_during_training(
                    train_labels, train_preds)
                train_eval_metrics[
                    "train_loss"] = running_loss / eval_steps
                logger.info(json.dumps(train_eval_metrics, indent=2))
                self.tensorboard_logging(train_eval_metrics, global_step)
                # wandb logging train
                if is_wandb_available() and self.use_wandb:
                    wandb.log(train_eval_metrics, step=global_step)
                # reset the running-window accumulators
                running_loss = 0.0
                eval_steps = 0
                # evaluate on dev data when it is available
                if self.dev_data is not None:
                    is_out_of_patience = self.evaluate_dev_data(
                        epoch, global_step)
                    if is_out_of_patience:
                        logger.info(
                            f" run out of patience={self.patience} and save model before exit"
                        )
                        # save a final checkpoint before early-stopping out
                        self.save_checkpoint(
                            os.path.join(self.save_path,
                                         f"ck_{global_step + 1}"))
                        return 0
                # periodic checkpoint after each eval window
                self.save_checkpoint(
                    os.path.join(self.save_path, f"ck_{global_step + 1}"))
            global_step += 1
        logger.info(
            f'\n*****************[epoch: {epoch + 1}, global step: {global_step + 1}] eval training set at end of epoch***************'
        )
        eval_metrics = self.eval_train_during_training(
            train_labels, train_preds)
        eval_metrics["train_loss"] = epoch_train_loss / len(
            self.train_dataloader)
        logger.info(json.dumps(eval_metrics, indent=2))
        # when eval_every is disabled (-1), dev evaluation happens once per epoch
        if self.eval_every == -1:
            if self.dev_data is not None:
                is_out_of_patience = self.evaluate_dev_data(
                    epoch, global_step)
                if is_out_of_patience:
                    logger.info(
                        f" run out of patience={self.patience} and save model before exit"
                    )
                    self.save_checkpoint(
                        os.path.join(self.save_path,
                                     f"ck_{global_step + 1}"))
                    return 0
            # end-of-epoch checkpoint
            self.save_checkpoint(
                os.path.join(self.save_path, f"ck_{global_step + 1}"))
    self.tb_writer.close()
def evaluate(eval_loader, test=False):
    """Visualize per-class generations of the multimodal (image/haptic) model.

    For each of ``opt.num_samples`` sampling rounds, walks ``eval_loader``
    (batch size is asserted to be 1) and, for the first batch of each class,
    writes ground-truth grids and generated reconstructions to the global
    tensorboard ``writer`` under every combination of
    (use_first_img, use_img, use_hpt, num_context).

    :param eval_loader: loader yielding (eval_info, eval_context, eval_target).
    :param test: when True, tensorboard tags are prefixed 'test', else 'val'.
    Relies on module globals: model, opt, device, writer, dataset_info,
    num_classes, and the various helper functions. Returns None.
    """
    # Turn on evaluation mode which disables dropout.
    name = 'test' if test else 'val'
    model.eval()
    transform = get_transform()
    with torch.no_grad():
        for i_sample in range(1, opt.num_samples + 1):
            # one flag per class so each class is visualized only once per round
            did_plot = [False] * num_classes
            for batch_idx, (eval_info, eval_context, eval_target) in enumerate(eval_loader):
                # init batch
                eval_context = batch_to_device(eval_context, device)
                eval_target = batch_to_device(eval_target, device)
                eval_all = merge_two_batch(eval_context, eval_target)
                num_episodes = len(eval_context)
                batch_size, mod_batch_sizes = get_batch_size(eval_target)
                # get img_queries (extra camera poses supplied by the dataset)
                img_queries = torch.from_numpy(np.array(eval_info[0]['add_cameras'])).float()
                # get true_images and hand_images; grid sizes (16/15 tiles of
                # 3x64x64) match this dataset's view counts — confirm if reused
                true_images = load_images(eval_info[0]['add_images'], transform)
                _true_images = get_grid_image(true_images, 16, 3, 64, 64, nrow=4, pad_value=0)
                hand_images = load_images(eval_info[0]['hand_images'], transform)
                _hand_images = get_grid_image(hand_images, 15, 3, 64, 64, nrow=15, pad_value=0)
                _fst_image = get_grid_image(eval_all[0][0][:1], 1, 3, 64, 64, nrow=1, pad_value=0)
                _data_image = get_grid_image(eval_all[0][0], 15, 3, 64, 64, nrow=15, pad_value=0)
                ''' temporary '''
                # the code below indexes eval_info[0]/eval_all[0] only, so the
                # batch must contain exactly one episode
                assert len(eval_context) == 1
                assert len(eval_target) == 1
                cls = eval_info[0]['class']
                if (batch_idx + 1) % 1 == 0:
                    print(batch_idx + 1, '/', len(eval_loader))

                ''' per class '''
                # visualize per class
                if not did_plot[cls]:
                    # change flag
                    did_plot[cls] = True
                    # draw true_images and hand_images
                    writer.add_image('{}/gt-img-cls{}-i{}'.format(name, cls, i_sample), _true_images, 0)
                    writer.add_image('{}/hand-img-cls{}-i{}'.format(name, cls, i_sample), _hand_images, 0)
                    writer.add_image('{}/fst-img-cls{}-i{}'.format(name, cls, i_sample), _fst_image, 0)
                    writer.add_image('{}/data-img-cls{}-i{}'.format(name, cls, i_sample), _data_image, 0)
                    # init queries: one query set per modality
                    mod_queries, num_mod_queries = [], []
                    for idx, (_, _, _, _, mtype) in enumerate(dataset_info['dims']):
                        # get queries
                        if mtype == 'image':
                            # image queries (use the dataset's extra cameras)
                            _mod_queries, _num_mod_queries = get_visualization_queries_with_predefined_dist(num_episodes, device, img_queries)
                        elif mtype == 'haptic':
                            # haptic queries
                            _mod_queries, _num_mod_queries = get_visualization_queries_for_haptic(num_episodes, device)
                        # append to list
                        mod_queries += [_mod_queries]
                        num_mod_queries += [_num_mod_queries]

                    def draw_img_gen(num_context=0, use_img=True, use_hpt=True, use_first_img=True):
                        """Render reconstructions/generations for one context
                        configuration; closes over eval_all, mod_queries, etc.

                        :param num_context: number of context observations kept.
                        :param use_img/use_hpt: keep image / haptic modality.
                        :param use_first_img: always condition on the first image.
                        """
                        # use_first_img: split off the first image observation
                        # (indices j//2 == 0 address the image modality pair)
                        if use_first_img:
                            first_image = [tuple([
                                eval_all[i][j][:1] if j // 2 == 0 else None
                                for j in range(len(eval_all[i]))
                            ]) for i in range(num_episodes)]
                            new_eval_all = [tuple([
                                eval_all[i][j][1:] if j // 2 == 0 else eval_all[i][j]
                                for j in range(len(eval_all[i]))
                            ]) for i in range(num_episodes)]
                        else:
                            new_eval_all = eval_all
                        # get context
                        new_eval_context, new_eval_target = trim_context_target(new_eval_all, num_context=num_context, use_img=use_img, use_hpt=use_hpt)
                        # nothing left to predict for this configuration
                        if new_eval_target[0][0] is None and new_eval_target[0][2] is None:
                            return
                        if use_first_img:
                            new_eval_context = merge_two_batch(new_eval_context, first_image)
                        # forward (reconstruction of the target)
                        outputs, _, _, _ = model(new_eval_context, new_eval_target, is_grayscale=opt.grayscale)
                        # generate (sampling at the visualization queries)
                        gens, _ = model.generate(new_eval_context, tuple(mod_queries), is_grayscale=opt.grayscale)
                        # visualize
                        img_gens = []
                        for idx, (nchannels, nheight, nwidth, _, mtype) in enumerate(dataset_info['dims']):
                            # get output and gen
                            output = outputs[idx]
                            gen = gens[idx]
                            _num_mod_queries = num_mod_queries[idx]
                            # visualize
                            if mtype == 'image':
                                # grayscale: expand single channel to nchannels for display
                                if opt.grayscale:
                                    output = output.expand(output.size(0), nchannels, nheight, nwidth)
                                    gen = gen.expand(gen.size(0), nchannels, nheight, nwidth)
                                _gen = get_grid_image(gen, 16, 3, 64, 64, nrow=4, pad_value=0)
                                writer.add_image('{}/m{}-gen-cls{}-uimg{}-uhpt{}-ufimg{}-i{}-nc{}/img'.format(name, idx, cls, int(use_img), int(use_hpt), int(use_first_img), i_sample, num_context), _gen, 0)
                                # visualize predictions (image)
                                xs = get_visualization_image_data(idx, nchannels, nheight, nwidth, device, new_eval_context, new_eval_target, output, gen, _num_mod_queries, dataset_info['nviews'], nrow=4)
                                for i, x in enumerate(xs):
                                    writer.add_image('{}/m{}-cond-target-recon-gen-cls{}-uimg{}-uhpt{}-ufimg{}-i{}-nc{}/img'.format(name, idx, cls, int(use_img), int(use_hpt), int(use_first_img), i_sample, num_context), x, 0)
                                img_gens += [gen]
                                num_img_queries = _num_mod_queries
                            elif mtype == 'haptic':
                                # visualize predictions (haptic)
                                xs = get_visualization_haptic_data(idx, nchannels, nheight, device, new_eval_context, new_eval_target, output, gen, _num_mod_queries)
                                for i, x in enumerate(xs):
                                    writer.add_image('{}/m{}-cond-target-recon-gen-cls{}-uimg{}-uhpt{}-ufimg{}-i{}-nc{}/hpt'.format(name, idx, cls, int(use_img), int(use_hpt), int(use_first_img), i_sample, num_context), x, 0)
                            else:
                                raise NotImplementedError

                    # run vis over every context/modality configuration;
                    # skip a modality axis entirely if the data lacks it
                    use_imgs = [True, False] if eval_all[0][0] is not None else [False]
                    use_hpts = [True, False] if eval_all[0][2] is not None else [False]
                    for use_first_img in [False, True]:
                        for use_img in use_imgs:
                            for use_hpt in use_hpts:
                                for num_context in [0, 1, 2, 3, 4, 5, 10, 15]:
                                    draw_img_gen(num_context, use_img, use_hpt, use_first_img)
    return
def train(train_loader, model, optimizer, epoch, start_batch_idx=0):
    """Train the multimodal generative model for one epoch.

    Runs the joint-observation ELBO backward pass (plus optional
    per-modality passes when ``opt.add_mod``), anneals beta/std over
    episodes, logs running losses and validation losses every
    ``opt.log_interval`` batches, visualizes generations and checkpoints
    every ``opt.vis_interval`` batches and at epoch end.

    :param train_loader: loader yielding (_, train_context, train_target).
    :param model: multimodal model; called with context/target, returns
        (outputs, latent, loss, info).
    :param optimizer: torch optimizer stepped once per batch.
    :param epoch: 1-based epoch index (used for the episode counter).
    :param start_batch_idx: offset when resuming mid-epoch from a checkpoint.
    Uses module globals: opt, device, writer, dataset_info, num_modalities,
    run_val2, best_val1_loss, logging, save_checkpoint, get_* helpers.
    """
    # init
    start_time = time.time()
    model.train()
    total_loss = 0.
    total_likelihood = 0.
    total_mod_likelihoods = [0.] * num_modalities
    total_kl = 0.
    total_batch_size = 0
    total_mod_batch_sizes = [0] * num_modalities
    for _batch_idx, (_, train_context, train_target) in enumerate(train_loader):
        # init batch_idx (shifted when resuming mid-epoch)
        batch_idx = _batch_idx + start_batch_idx
        i_episode = (epoch - 1) * len(train_loader) + batch_idx
        # init beta and std: linear annealing from *_init to *_fin over
        # *_annealing episodes, then held constant
        beta = opt.beta_init + (opt.beta_fin - opt.beta_init) / float(opt.beta_annealing) * float(min(opt.beta_annealing, i_episode))
        std = opt.std_init + (opt.std_fin - opt.std_init) / float(opt.std_annealing) * float(min(opt.std_annealing, i_episode)) if opt.std_annealing is not None else None
        # init batch
        train_context = batch_to_device(train_context, device)
        train_target = batch_to_device(train_target, device)
        # add additional datasets: optionally also train on the swapped
        # (target->context) direction
        _train_context = []
        _train_target = []
        if opt.add_opposite:
            _train_context += train_target
            _train_target += train_context
        train_context += _train_context
        train_target += _train_target
        # init numbers
        num_episodes = len(train_context)
        batch_size, mod_batch_sizes = get_batch_size(train_target)
        # init grad
        model.zero_grad()

        ''' ELBO '''
        # forward (joint observation); std is only passed when annealed
        outputs, latent, loss, info = \
            model(train_context, train_target, beta=beta) if opt.std_annealing is None \
            else model(train_context, train_target, beta=beta, std=std)
        # backward (joint observation)
        loss.backward()
        # forward (module-specific observation): extra gradient contribution
        # per modality, using only that modality's data/query pair (m*2, m*2+1)
        if opt.add_mod:
            for m in range(num_modalities):
                # check target is not empty
                is_not_empty = True in [train_target[i][m*2] is not None for i in range(num_episodes)]
                if is_not_empty:
                    # fetch module-specific data
                    mod_train_context = []
                    mod_train_target = []
                    for i in range(num_episodes):
                        if train_target[i][m*2] is not None:
                            _mod_train_context = [None, None] * num_modalities
                            _mod_train_context[m*2] = train_context[i][m*2]
                            _mod_train_context[m*2+1] = train_context[i][m*2+1]
                            _mod_train_context = tuple(_mod_train_context)
                            mod_train_context += [_mod_train_context]
                            _mod_train_target = [None, None] * num_modalities
                            _mod_train_target[m*2] = train_target[i][m*2]
                            _mod_train_target[m*2+1] = train_target[i][m*2+1]
                            _mod_train_target = tuple(_mod_train_target)
                            mod_train_target += [_mod_train_target]
                    # forward (module-specific observation)
                    _, _, mod_loss, _ = \
                        model(mod_train_context, mod_train_target, beta=beta) if opt.std_annealing is None \
                        else model(mod_train_context, mod_train_target, beta=beta, std=std)
                    # backward (module-specific observation); gradients
                    # accumulate on top of the joint pass
                    if mod_loss is not None:
                        mod_loss.backward()
        # unpack info
        loss_likelihood, loss_kl = info['likelihood'], info['kl']
        loss_mod_likelihoods = info['mod_likelihoods']
        # `clip_grad_norm` helps prevent the exploding gradient problem in
        # continuous data with gaussian likelihood
        if opt.clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
        # update (applies joint + per-modality gradients together)
        optimizer.step()
        # add to total loss
        cur_loss = loss.item()
        cur_likelihood = loss_likelihood.item()
        cur_kl = loss_kl.item()
        cur_mod_likelihoods = [loss_mod_likelihood.item() if loss_mod_likelihood is not None else None for loss_mod_likelihood in loss_mod_likelihoods]
        total_loss += cur_loss * num_modalities  #/ batch_size * num_episodes
        total_likelihood += cur_likelihood * num_modalities  #/ batch_size * num_episodes
        total_kl += cur_kl * num_modalities  #/ batch_size * num_episodes
        total_batch_size += batch_size
        for i in range(num_modalities):
            #total_mod_likelihoods[i] += cur_mod_likelihoods[i] / mod_batch_sizes[i] * num_episodes if cur_mod_likelihoods[i] is not None else 0
            total_mod_likelihoods[i] += cur_mod_likelihoods[i] if cur_mod_likelihoods[i] is not None else 0
            total_mod_batch_sizes[i] += mod_batch_sizes[i]

        # print: periodic running log with quick validation probes
        if (batch_idx + 1) % opt.log_interval == 0:
            # plot running val
            val1_loss = running_evaluate_for_val1(model)
            val2_loss = running_evaluate_for_val2(model) if run_val2 else -1.
            model.train()
            # set log info
            elapsed = time.time() - start_time
            lr_min, lr_max = get_lrs(optimizer)
            # print
            logging('| epoch {:3d} | {:5d}/{:5d} '
                    '| lr_min {:02.4f} | lr_max {:02.4f} | ms/step {:5.2f} '
                    '| beta {:02.4f} '
                    '| loss {:5.8f} | lk+kl {:5.8f} | likelihood {:5.8f} | kl {:5.8f} '
                    '{}'
                    '| val1 loss (lk+kl) {:5.8f} '
                    '| val2 loss (lk+kl) {:5.8f} '
                    .format(
                        epoch,
                        batch_idx + 1, len(train_loader),
                        lr_min, lr_max,
                        elapsed * 1000 / opt.log_interval,
                        beta,
                        cur_loss / batch_size * num_modalities,
                        (cur_likelihood + cur_kl) / batch_size * num_modalities,
                        cur_likelihood / batch_size * num_modalities,
                        cur_kl / batch_size * num_modalities,
                        ''.join(['| m{}_{}_lk {} '.format(
                            i, mtype,
                            '{:5.8f}'.format(cur_mod_likelihoods[i] / mod_batch_sizes[i]) if cur_mod_likelihoods[i] is not None else '-.--------',
                        ) for i, (_, _, _, _, mtype) in enumerate(dataset_info['dims'])]),
                        val1_loss,
                        val2_loss,
                    ),
                    path=opt.path)
            # write to tensorboard
            writer.add_scalar('train/loss/step', cur_loss / batch_size * num_modalities, i_episode)
            writer.add_scalar('train/lk+kl/step', (cur_likelihood + cur_kl) / batch_size * num_modalities, i_episode)
            writer.add_scalar('train/likelihood/step', cur_likelihood / batch_size * num_modalities, i_episode)
            for i, (_, _, _, _, mtype) in enumerate(dataset_info['dims']):
                if cur_mod_likelihoods[i] is not None:
                    writer.add_scalar('train/m{}_{}_lk/step'.format(i, 'img' if mtype == 'image' else 'hpt'), cur_mod_likelihoods[i] / mod_batch_sizes[i], i_episode)
            writer.add_scalar('train/kl/step', cur_kl / batch_size * num_modalities, i_episode)
            writer.add_scalar('train/beta', beta, i_episode)
            writer.add_scalar('val1/loss/step', val1_loss, i_episode)
            writer.add_scalar('val1/lk+kl/step', val1_loss, i_episode)
            writer.add_scalar('val2/loss/step', val2_loss, i_episode)
            writer.add_scalar('val2/lk+kl/step', val2_loss, i_episode)
            if std is not None:
                writer.add_scalar('train/std', std, i_episode)
            # reset log info
            start_time = time.time()

        # epoch-level summary on the last batch
        if batch_idx + 1 == len(train_loader):
            # print
            logging('| epoch {:3d} | {:5d}/{:5d} batches '
                    '| loss {:5.8f} | lk+kl {:5.8f} | likelihood {:5.8f} | kl {:5.8f} '
                    '{}'
                    .format(
                        epoch,
                        batch_idx + 1, len(train_loader),
                        total_loss / total_batch_size,  #len(train_loader.dataset),
                        (total_likelihood + total_kl) / total_batch_size,  #/ len(train_loader.dataset),
                        total_likelihood / total_batch_size,  #len(train_loader.dataset),
                        total_kl / total_batch_size,  #len(train_loader.dataset),
                        ''.join(['| m{}_{}_lk {:5.8f} '.format(i, mtype, total_mod_likelihoods[i] / total_mod_batch_sizes[i]) for i, (_, _, _, _, mtype) in enumerate(dataset_info['dims']) if total_mod_batch_sizes[i] > 0])
                    ),
                    path=opt.path)
            # write to tensorboard
            writer.add_scalar('train/loss', total_loss / total_batch_size, epoch)  #len(train_loader.dataset), epoch)
            writer.add_scalar('train/likelihood', total_likelihood / total_batch_size, epoch)  #len(train_loader.dataset), epoch)
            for i, (_, _, _, _, mtype) in enumerate(dataset_info['dims']):
                if total_mod_batch_sizes[i] > 0:
                    writer.add_scalar('train/m{}_{}_lk/step'.format(i, 'img' if mtype == 'image' else 'hpt'), total_mod_likelihoods[i] / total_mod_batch_sizes[i], epoch)  #len(train_loader.dataset), epoch)
            writer.add_scalar('train/kl', total_kl / total_batch_size, epoch)  #len(train_loader.dataset), epoch)
            writer.add_scalar('train/lk+kl', (total_likelihood + total_kl) / total_batch_size, epoch)  #len(train_loader.dataset), epoch)

        # periodic visualization + checkpoint
        if (batch_idx + 1) % opt.vis_interval == 0 or (batch_idx + 1 == len(train_loader)):
            # generate image with shuffled queries and random queries
            model.eval()
            with torch.no_grad():
                # init queries
                mod_queries, num_mod_queries = [], []
                for idx, (_, _, _, _, mtype) in enumerate(dataset_info['dims']):
                    # get queries
                    if mtype == 'image':
                        # image queries
                        _mod_queries, _num_mod_queries = get_visualization_queries_with_predefined_dist(num_episodes, device)
                    elif mtype == 'haptic':
                        # haptic queries
                        _mod_queries, _num_mod_queries = get_visualization_queries_from_data(idx, train_target, num_episodes)
                    # append to list
                    mod_queries += [_mod_queries]
                    num_mod_queries += [_num_mod_queries]
                # generate
                gens, latent = model.generate(train_context, tuple(mod_queries))
            model.train()
            # visualize
            img_gens = []
            for idx, (nchannels, nheight, nwidth, _, mtype) in enumerate(dataset_info['dims']):
                # get output and gen
                output = outputs[idx]
                gen = gens[idx]
                _num_mod_queries = num_mod_queries[idx]
                # visualize
                if mtype == 'image':
                    # visualize predictions (image)
                    xs = get_visualization_image_data(idx, nchannels, nheight, nwidth, device, train_context, train_target, output, gen, _num_mod_queries, dataset_info['nviews'])
                    for i, x in enumerate(xs):
                        writer.add_image('train/m{}-cond-target-recon-gensh-genrd-i{}/img'.format(idx, i), x, i_episode)
                    # temporary
                    img_gens += [gen]
                    num_img_queries = _num_mod_queries
                elif mtype == 'haptic':
                    # visualize predictions (haptic)
                    xs = get_visualization_haptic_data(idx, nchannels, nheight, device, train_context, train_target, output, gen, _num_mod_queries)
                    for i, x in enumerate(xs):
                        writer.add_image('train/m{}-cond-target-recon-gensh-genrd-i{}/hpt'.format(idx, i), x, i_episode)
                else:
                    raise NotImplementedError
            # visualize combined image
            xs = get_combined_visualization_image_data(opt.dataset, dataset_info['dims'], img_gens, num_img_queries, min(4, len(train_context)))
            for i, x in enumerate(xs):
                writer.add_image('train/cond-target-recon-gensh-genrd-i{}/img'.format(i), x, i_episode)
            # save model (whole module plus a resumable checkpoint)
            with open(os.path.join(opt.path, 'model.pt'), 'wb') as f:
                torch.save(model, f)
            save_checkpoint({
                'epoch': epoch + 1 if (batch_idx + 1) == len(train_loader) else epoch,
                'batch_idx': (batch_idx + 1) % len(train_loader),
                'model': opt.model,
                'state_dict': model.state_dict(),
                'best_val1_loss': best_val1_loss,
                'optimizer': optimizer.state_dict(),
            }, opt, False)
            # flush writer
            writer.flush()
        if batch_idx + 1 == len(train_loader):
            writer.flush()
            break
def evaluate(eval_loader, test=False):
    """Evaluate the multimodal model on ``eval_loader`` and return the loss.

    Accumulates the ELBO loss and per-modality likelihoods over the whole
    loader; on the final batch also writes reconstruction/generation
    visualizations to the global tensorboard ``writer``.

    :param eval_loader: loader yielding (_, eval_context, eval_target).
    :param test: when True, tensorboard tags are prefixed 'test', else 'val'.
    :return: total loss normalized by the accumulated batch size.
    Uses module globals: model, device, writer, opt, dataset_info,
    num_modalities, epoch, and the get_* helpers.
    """
    # Turn on evaluation mode which disables dropout.
    name = 'test' if test else 'val'
    model.eval()
    total_loss = 0.
    total_batch_size = 0
    total_mod_likelihoods = [0] * num_modalities
    total_mod_batch_sizes = [0] * num_modalities
    latents = []
    with torch.no_grad():
        for batch_idx, (_, eval_context, eval_target) in enumerate(eval_loader):
            # init batch
            eval_context = batch_to_device(eval_context, device)
            eval_target = batch_to_device(eval_target, device)
            num_episodes = len(eval_context)
            batch_size, mod_batch_sizes = get_batch_size(eval_target)
            # forward
            outputs, latent, loss, info = model(eval_context, eval_target)
            # unpack info
            loss_likelihood, loss_kl = info['likelihood'], info['kl']
            loss_mod_likelihoods = info['mod_likelihoods']
            # add to latents (collected but not otherwise used here)
            latents += [latent] if latent is not None else []
            # add to total_loss
            total_loss += loss.item() * num_modalities  #/ batch_size * num_episodes
            total_batch_size += batch_size
            for i in range(num_modalities):
                #total_mod_likelihoods[i] += loss_mod_likelihoods[i].item() / mod_batch_sizes[i] * num_episodes if loss_mod_likelihoods[i] is not None else 0
                total_mod_likelihoods[i] += loss_mod_likelihoods[i].item() if loss_mod_likelihoods[i] is not None else 0
                total_mod_batch_sizes[i] += mod_batch_sizes[i]
            # visualize prediction (only on the last batch of the loader)
            if batch_idx + 1 == len(eval_loader):
                # init queries: one query set per modality
                mod_queries, num_mod_queries = [], []
                for idx, (_, _, _, _, mtype) in enumerate(dataset_info['dims']):
                    # get queries
                    if mtype == 'image':
                        # image queries
                        _mod_queries, _num_mod_queries = get_visualization_queries_with_predefined_dist(num_episodes, device)
                    elif mtype == 'haptic':
                        # haptic queries
                        _mod_queries, _num_mod_queries = get_visualization_queries_from_data(idx, eval_target, num_episodes)
                    # append to list
                    mod_queries += [_mod_queries]
                    num_mod_queries += [_num_mod_queries]
                # generate
                gens, latent = model.generate(eval_context, tuple(mod_queries))
                # visualize
                img_gens = []
                for idx, (nchannels, nheight, nwidth, _, mtype) in enumerate(dataset_info['dims']):
                    # get output and gen
                    output = outputs[idx]
                    gen = gens[idx]
                    _num_mod_queries = num_mod_queries[idx]
                    # visualize
                    if mtype == 'image':
                        # visualize predictions (image)
                        xs = get_visualization_image_data(idx, nchannels, nheight, nwidth, device, eval_context, eval_target, output, gen, _num_mod_queries, dataset_info['nviews'])
                        for i, x in enumerate(xs):
                            writer.add_image('{}/m{}-cond-target-recon-gensh-genrd-b{}-i{}/img'.format(name, idx, batch_idx, i), x, epoch)
                        # temporary
                        img_gens += [gen]
                        num_img_queries = _num_mod_queries
                    elif mtype == 'haptic':
                        # visualize predictions (haptic)
                        xs = get_visualization_haptic_data(idx, nchannels, nheight, device, eval_context, eval_target, output, gen, _num_mod_queries)
                        for i, x in enumerate(xs):
                            writer.add_image('{}/m{}-cond-target-recon-gensh-genrd-b{}-i{}/hpt'.format(name, idx, batch_idx, i), x, epoch)
                    else:
                        raise NotImplementedError
                # visualize combined image
                xs = get_combined_visualization_image_data(opt.dataset, dataset_info['dims'], img_gens, num_img_queries, min(4, len(eval_context)))
                for i, x in enumerate(xs):
                    writer.add_image('{}/cond-target-recon-gensh-genrd-b{}-i{}/img'.format(name, batch_idx, i), x, epoch)
    return total_loss / total_batch_size
def main(cfg):
    """Preview DEC cluster assignments on the training set.

    Loads the pretrained autoencoder weights into a DEC model, initializes
    the cluster centers from ``init_clusters.npz``, fills the target
    distribution buffer, and writes one preview grid image per training
    batch into the run directory.

    Fixes over the previous revision: removed a dangling no-op expression
    statement (``model.autoencoder``) after ``load_state_dict``, and the
    npz archive is now opened once instead of twice (the unused ``preds``
    load is dropped).

    :param cfg: configuration namespace; reads out_root, run_dir, cuda,
        n_clusters, roi_spatial_scale, alpha, in_root, train_dir,
        batch_size, n_workers, tgt_update_period, lambda_, n_edges.
    :return: None (early return when the autoencoder checkpoint is missing
        or a DEC checkpoint already exists).
    """
    run_path = pjoin(cfg.out_root, cfg.run_dir)
    device = torch.device('cuda' if cfg.cuda else 'cpu')

    model = DEC(cfg.n_clusters,
                roi_size=1,
                roi_scale=cfg.roi_spatial_scale,
                alpha=cfg.alpha)

    path_cp = pjoin(run_path, 'checkpoints', 'checkpoint_autoenc.pth.tar')
    if os.path.exists(path_cp):
        print('loading checkpoint {}'.format(path_cp))
        # map_location keeps tensors on CPU so loading works without a GPU
        state_dict = torch.load(path_cp,
                                map_location=lambda storage, loc: storage)
        model.autoencoder.load_state_dict(state_dict)
    else:
        print(
            'checkpoint {} not found. Train autoencoder first'.format(path_cp))
        return

    transf, transf_normal = im_utils.make_data_aug(cfg)

    dl_train = Loader(pjoin(cfg.in_root, 'Dataset' + cfg.train_dir),
                      normalization=transf_normal)
    dataloader_train = DataLoader(dl_train,
                                  batch_size=cfg.batch_size,
                                  shuffle=True,
                                  collate_fn=dl_train.collate_fn,
                                  drop_last=True,
                                  num_workers=cfg.n_workers)
    dataloader_prev = DataLoader(dl_train,
                                 batch_size=1,
                                 collate_fn=dl_train.collate_fn)
    dataloaders = {'train': dataloader_train, 'prev': dataloader_prev}

    # do not redo work if the DEC stage already produced a checkpoint
    check_cp_exist = pjoin(run_path, 'checkpoints', 'checkpoint_dec.pth.tar')
    if os.path.exists(check_cp_exist):
        print('found checkpoint at {}. Skipping.'.format(check_cp_exist))
        return

    # open the archive once (previously it was re-read per key, and the
    # 'preds' array was loaded but never used)
    init_clusters_path = pjoin(run_path, 'init_clusters.npz')
    init_clusters = np.load(init_clusters_path,
                            allow_pickle=True)['clusters']
    init_clusters = torch.tensor(init_clusters,
                                 dtype=torch.float,
                                 requires_grad=True)
    if cfg.cuda:
        init_clusters = init_clusters.cuda(non_blocking=True)

    with torch.no_grad():
        # initialise the cluster centers in-place
        model.state_dict()['assignment.cluster_centers'].copy_(init_clusters)
    model.to(device)

    distrib_buff = DistribBuffer(cfg.tgt_update_period)
    distrib_buff.maybe_update(model, dataloaders['prev'], device)

    criterion_clust = PairwiseConstrainedClustering(cfg.lambda_, cfg.n_edges)

    # NOTE(review): this second maybe_update call repeats the one above with
    # identical arguments — confirm DistribBuffer.maybe_update is idempotent
    # before removing it (kept to preserve behavior).
    distrib_buff.maybe_update(model, dataloaders['prev'], device)

    for i, data in enumerate(dataloaders['train']):
        data = utls.batch_to_device(data, device)
        res = model(data)
        distribs, targets = distrib_buff[data['frame_idx']]
        loss, edges_pw = criterion_clust(data['graph'], res['feats'],
                                         distribs, targets)
        prev = im_utils.make_grid_samples(data, edges_pw, cfg.n_clusters)
        io.imsave(pjoin(run_path, 'prev_{:04d}.png'.format(i)), prev)
def eval_vae(epoch, args, trainer, eval_data):
    """Evaluate the question/answer VAE on a SQuAD-style dev set.

    For each batch: decodes posterior and prior question/answer samples,
    collects QA span logits, then writes an n-best prediction file, scores
    it with the official SQuAD ``evaluate``, and computes question-generation
    BLEU against the reference questions.

    :param epoch: current epoch (unused in the body; kept for the caller's
        signature).
    :param args: needs bert_model, device, model_dir, dev_dir.
    :param trainer: provides generate_posterior, generate_answer_logits,
        generate_prior.
    :param eval_data: (eval_loader, eval_examples, eval_features) triple.
    :return: (squad_metrics, bleu, all_results) where all_results is a list
        of ``Result`` records (defined elsewhere in this module).
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    RawResult = collections.namedtuple(
        "RawResult", ["unique_id", "start_logits", "end_logits"])
    eval_loader, eval_examples, eval_features = eval_data
    all_results = []   # rich per-example records for inspection
    qa_results = []    # RawResult logits fed to write_predictions
    qg_results = {}    # unique_id -> generated (posterior) question
    res_dict = {}      # unique_id -> reference question
    example_index = -1
    for batch in tqdm(eval_loader, desc="Eval iter", leave=False, position=4):
        c_ids, q_ids, a_ids, start, end = batch_to_device(batch, args.device)
        batch_size = c_ids.size(0)
        batch_c_ids = c_ids.cpu().tolist()
        batch_q_ids = q_ids.cpu().tolist()
        batch_start = start.cpu().tolist()
        batch_end = end.cpu().tolist()

        # sample question + answer span from the posterior
        batch_posterior_q_ids, \
            batch_posterior_start, batch_posterior_end, \
            posterior_z_prob = trainer.generate_posterior(c_ids, q_ids, a_ids)
        # span logits for the QA metric
        batch_start_logits, batch_end_logits \
            = trainer.generate_answer_logits(c_ids, q_ids, a_ids)
        batch_posterior_q_ids, \
            batch_posterior_start, batch_posterior_end = \
            batch_posterior_q_ids.cpu().tolist(), \
            batch_posterior_start.cpu().tolist(), \
            batch_posterior_end.cpu().tolist()
        posterior_z_prob = posterior_z_prob.cpu()

        # sample question + answer span from the prior (context only)
        batch_prior_q_ids, \
            batch_prior_start, batch_prior_end, \
            prior_z_prob = trainer.generate_prior(c_ids)
        batch_prior_q_ids, \
            batch_prior_start, batch_prior_end = \
            batch_prior_q_ids.cpu().tolist(), \
            batch_prior_start.cpu().tolist(), \
            batch_prior_end.cpu().tolist()
        prior_z_prob = prior_z_prob.cpu()

        for i in range(batch_size):
            # example_index tracks position in eval_features across batches
            example_index += 1
            start_logits = batch_start_logits[i].detach().cpu().tolist()
            end_logits = batch_end_logits[i].detach().cpu().tolist()
            eval_feature = eval_features[example_index]
            unique_id = int(eval_feature.unique_id)
            context = to_string(batch_c_ids[i], tokenizer)
            real_question = to_string(batch_q_ids[i], tokenizer)
            posterior_question = to_string(batch_posterior_q_ids[i], tokenizer)
            prior_question = to_string(batch_prior_q_ids[i], tokenizer)
            # answers are decoded as inclusive token spans of the context
            real_answer = to_string(
                batch_c_ids[i][batch_start[i]:(batch_end[i] + 1)], tokenizer)
            posterior_answer = to_string(
                batch_c_ids[i][batch_posterior_start[i]:(
                    batch_posterior_end[i] + 1)], tokenizer)
            prior_answer = to_string(
                batch_c_ids[i][batch_prior_start[i]:(batch_prior_end[i] + 1)],
                tokenizer)
            all_results.append(
                Result(context=context,
                       real_question=real_question,
                       posterior_question=posterior_question,
                       prior_question=prior_question,
                       real_answer=real_answer,
                       posterior_answer=posterior_answer,
                       prior_answer=prior_answer,
                       posterior_z_prob=posterior_z_prob[i],
                       prior_z_prob=prior_z_prob[i]))
            qg_results[unique_id] = posterior_question
            res_dict[unique_id] = real_question
            qa_results.append(
                RawResult(unique_id=unique_id,
                          start_logits=start_logits,
                          end_logits=end_logits))
    # write n-best predictions, then score them with the SQuAD evaluator
    output_prediction_file = os.path.join(args.model_dir, "pred.json")
    write_predictions(eval_examples, eval_features, qa_results,
                      n_best_size=20, max_answer_length=30,
                      do_lower_case=True,
                      output_prediction_file=output_prediction_file,
                      verbose_logging=False,
                      version_2_with_negative=False,
                      null_score_diff_threshold=0,
                      noq_position=True)
    with open(args.dev_dir) as f:
        dataset_json = json.load(f)
        dataset = dataset_json["data"]
    with open(os.path.join(args.model_dir, "pred.json")) as prediction_file:
        predictions = json.load(prediction_file)
    ret = evaluate(dataset, predictions)
    bleu = eval_qg(res_dict, qg_results)
    return ret, bleu, all_results
def main(args):
    """Train the question/answer VAE with resume support and best-model saving.

    Builds data loaders and a VAETrainer, optionally resumes from a
    checkpoint (skipping already-finished epochs), trains one epoch at a
    time, evaluates F1/EM/BLEU after each trained epoch, logs to mlflow,
    and saves best-F1 / best-BLEU models plus a rolling checkpoint.

    :param args: needs bert_model, train_dir, dev_dir, epochs, model_dir,
        load_checkpoint; args.device is set here from the current CUDA device.
    """
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)
    train_loader, _, _ = get_squad_data_loader(tokenizer, args.train_dir,
                                               shuffle=True, args=args)
    eval_data = get_squad_data_loader(tokenizer, args.dev_dir,
                                      shuffle=False, args=args)

    args.device = torch.cuda.current_device()
    trainer = VAETrainer(args)

    # fixed-position tqdm bars used as persistent status lines
    loss_log1 = tqdm(total=0, bar_format='{desc}', position=2)
    loss_log2 = tqdm(total=0, bar_format='{desc}', position=3)
    eval_log = tqdm(total=0, bar_format='{desc}', position=5)
    best_eval_log = tqdm(total=0, bar_format='{desc}', position=6)

    # Load checkpoint when resuming a previous run
    if args.load_checkpoint:
        # NOTE(review): `loadd` looks like a typo for `load` — confirm
        # against VAETrainer's API before renaming
        epochs = trainer.loadd(args.model_dir)
        best_f1, best_bleu, best_em = VAETrainer.load_measures(args.model_dir)
        print(
            f"The current best measures are: F1 = {best_f1}, BLEU = {best_bleu} and EM = {best_em}."
        )
    else:
        epochs = -1
        best_bleu, best_em, best_f1 = 0.0, 0.0, 0.0

    print("MODEL DIR: " + args.model_dir)
    mlflow_logger = init_mlflow(args, f"{args.model_dir}/mlruns")

    for epoch in trange(int(args.epochs), desc="Epoch", position=0):
        # epochs already completed in a resumed run are skipped
        if epoch <= epochs:
            print(f"jumping epoch {epoch}...")
        else:
            for batch in tqdm(train_loader, desc="Train iter", leave=False,
                              position=1):
                c_ids, q_ids, a_ids, start_positions, end_positions \
                    = batch_to_device(batch, args.device)
                trainer.train(c_ids, q_ids, a_ids, start_positions,
                              end_positions)

                str1 = 'Q REC : {:06.4f} A REC : {:06.4f}'
                str2 = 'ZQ KL : {:06.4f} ZA KL : {:06.4f} INFO : {:06.4f}'
                str1 = str1.format(float(trainer.loss_q_rec),
                                   float(trainer.loss_a_rec))
                str2 = str2.format(float(trainer.loss_zq_kl),
                                   float(trainer.loss_za_kl),
                                   float(trainer.loss_info))
                loss_log1.set_description_str(str1)
                loss_log2.set_description_str(str2)

            # evaluate and track best metrics after each trained epoch
            if epoch >= 0:
                f1, em, bleu, _str = eval_measures(epoch, args, trainer,
                                                   eval_data)
                eval_log.set_description_str(_str)
                result = {"epoch": epoch, "em": em, "f1": f1, "bleu": bleu}
                mlflow_logger.on_result(result)
                if em > best_em:
                    best_em = em
                if f1 > best_f1:
                    best_f1 = f1
                    trainer.save(
                        os.path.join(args.model_dir, "best_f1_model.pt"),
                        epoch, f1, bleu, em)
                if bleu > best_bleu:
                    best_bleu = bleu
                    trainer.save(
                        os.path.join(args.model_dir, "best_bleu_model.pt"),
                        epoch, f1, bleu, em)
                # rolling checkpoint for resume
                trainer.save(os.path.join(args.model_dir, "checkpoint.pt"),
                             epoch, f1, bleu, em)
                mlflow_logger.on_checkpoint(
                    f"{args.model_dir}/mlruns/checkpoint")

                _str = 'BEST BLEU : {:02.2f} EM : {:02.2f} F1 : {:02.2f}'
                _str = _str.format(best_bleu, best_em, best_f1)
                best_eval_log.set_description_str(_str)