def predict(context, top_n=5, normalize=False):
    """Score a context against the fixed candidate pool.

    Args:
        context: 1-D token tensor for a single context (a batch dimension
            is added here via ``unsqueeze(0)``).
        top_n: number of top-scoring candidates to return.
        normalize: forwarded to ``score_candidates`` (presumably cosine vs.
            dot-product scoring — confirm against its definition).

    Returns:
        A pair ``(response, outputs)`` where ``response`` is a list of
        ``(sentence, score)`` tuples and ``outputs`` names the source
        corpus of each candidate ("EmpChat", "DailyDialog", or "Reddit"),
        determined by where its index falls relative to the module-level
        breakpoints ``breakingpt`` / ``breakingpt2``.
    """
    with torch.no_grad():
        context = context.unsqueeze(0)
        candidates = fixed_candidates
        if args.cuda:
            context = context.cuda(non_blocking=True)
        ctx, _ = net(context, None)
        # top_indices was previously also named `index`, shadowing the
        # per-item loop variable below; renamed for clarity.
        scores, top_indices = score_candidates(ctx, cand_embs, top_n, normalize)
        response = []
        outputs = []
        for score, cand_index in zip(scores.squeeze(0), top_indices.squeeze(0)):
            response.append((stringify(candidates[cand_index]), float(score)))
            if cand_index < breakingpt:
                outputs.append("EmpChat")
            elif cand_index < breakingpt2:
                outputs.append("DailyDialog")
            else:
                outputs.append("Reddit")
        return response, outputs
def validate(
    epoch,
    model,
    data_loader,
    max_exs=100000,
    is_test=False,
    nb_candidates=100,
    shuffled_str="shuffled",
):
    """Evaluate ``model`` on ``data_loader`` and log P@k retrieval metrics.

    Batches are encoded into context/candidate embeddings, the in-batch
    ranking loss is accumulated, and afterwards the cached embeddings are
    re-grouped into pools of ``nb_candidates`` to compute P@1/P@3/P@10.

    Args:
        epoch: epoch number (logging only).
        model: a ``DataParallel``-wrapped encoder (``model.module.opt`` is read).
        data_loader: yields tuples of tensors; first field is the context batch.
        max_exs: cap on evaluated examples (enforced for the "reddit" dataset;
            on the reddit test set the first ``max_exs`` examples are skipped
            instead, so test and valid use disjoint slices).
        is_test: whether this is the test split.
        nb_candidates: candidate-pool size for the P@k computation.
        shuffled_str: label used in the log line.

    Returns:
        The average loss over evaluated batches, or the sentinel ``10``
        when no batch was evaluated.
    """
    model.eval()
    examples = 0
    eval_start = time.time()
    sum_losses = 0
    n_losses = 0
    correct = 0
    all_context = []
    all_cands = []
    n_skipped = 0
    dtype = model.module.opt.dataset_name
    # Evaluation only: disable autograd so the cached ctx/cands tensors do
    # not retain computation graphs (otherwise memory grows with each batch).
    with torch.no_grad():
        for ex in data_loader:
            batch_size = ex[0].size(0)
            if dtype == "reddit" and is_test and n_skipped < max_exs:
                n_skipped += batch_size
                continue
            # Move fields to GPU when enabled; None fields pass through.
            # Bug fix: the original ternary chain attempted `None.cuda()`
            # whenever opt.cuda was set and a field was None.
            params = [
                field.cuda(non_blocking=True)
                if (opt.cuda and field is not None)
                else field
                for field in ex
            ]
            ctx, cands = model(*params)
            all_context.append(ctx)
            all_cands.append(cands)
            loss, nb_ok = loss_fn(ctx, cands)
            sum_losses += loss
            correct += nb_ok
            n_losses += 1
            examples += batch_size
            if examples >= max_exs and dtype == "reddit":
                break
    n_examples = 0
    if len(all_context) > 0:
        logging.info("Processing candidate top-K")
        all_context = torch.cat(all_context, dim=0)  # [N, 2h]
        all_cands = torch.cat(all_cands, dim=0)  # [N, 2h]
        acc_ranges = [1, 3, 10]
        n_correct = {r: 0 for r in acc_ranges}
        # Drop the last (possibly ragged) group so every pool has exactly
        # nb_candidates entries.
        for context, cands in list(
                zip(all_context.split(nb_candidates),
                    all_cands.split(nb_candidates)))[:-1]:
            _, top_answers = score_candidates(context, cands)
            n_cands = cands.size(0)
            # The ground-truth answer for row i is candidate i.
            gt_index = torch.arange(n_cands, out=top_answers.new(n_cands, 1))
            for acc_range in acc_ranges:
                n_acc = (top_answers[:, :acc_range] == gt_index).float().sum()
                n_correct[acc_range] += n_acc
            n_examples += n_cands
        # max(..., 1) guards against division by zero when fewer than two
        # candidate groups were produced (identical result otherwise).
        accuracies = {
            r: 100 * n_acc / max(n_examples, 1)
            for r, n_acc in n_correct.items()
        }
        avg_loss = sum_losses / (n_losses + 0.00001)
        avg_acc = 100 * correct / (examples + 0.000001)
        valid_time = time.time() - eval_start
        logging.info(
            f"Valid ({shuffled_str}): Epoch = {epoch:d} | avg loss = {avg_loss:.3f} | "
            f"batch P@1 = {avg_acc:.2f} % | " +
            f" | ".join(f"P@{k},{nb_candidates} = {v:.2f}%"
                        for k, v in accuracies.items()) +
            f" | valid time = {valid_time:.2f} (s)")
        return avg_loss
    return 10
def validate(
    epoch,
    model,
    data_loader,
    is_shuffled,
    max_exs=100000,
    is_test=False,
    nb_candidates=100,
    shuffled_str="shuffled",
):
    """Evaluate ``model``, log P@k metrics, and dump candidate groupings.

    Like the basic validate, but each batch additionally carries the raw
    "next sentence" strings (``ex[2]``); the strings of every candidate
    pool are written, '|'-separated one pool per line, to
    ``{test|valid}_candidate_groupings_{shuffled|unshuffled}.txt`` in the
    current working directory.

    Args:
        epoch: epoch number (logging only).
        model: a ``DataParallel``-wrapped encoder (``model.module.opt`` is read).
        data_loader: yields tuples; fields 0-1 are tensors fed to the model,
            field 2 is a sequence of next-sentence strings.
        is_shuffled: selects the "shuffled"/"unshuffled" output filename.
        max_exs: cap on evaluated examples (reddit only; on the reddit test
            set the first ``max_exs`` examples are skipped instead).
        is_test: whether this is the test split.
        nb_candidates: candidate-pool size for the P@k computation.
        shuffled_str: label used in the log line.

    Returns:
        The average loss over evaluated batches, or the sentinel ``10``
        when no batch was evaluated.
    """
    model.eval()
    examples = 0
    eval_start = time.time()
    sum_losses = 0
    n_losses = 0
    correct = 0
    all_context = []
    all_cands = []
    all_next_sentences = []
    n_skipped = 0
    dtype = model.module.opt.dataset_name
    # Evaluation only: disable autograd so cached embeddings do not retain
    # computation graphs (otherwise memory grows with each batch).
    with torch.no_grad():
        for i, ex in enumerate(data_loader):
            if i == 0:
                # Debug output: inspect the first batch's raw tensors.
                print('First context tensor:')
                print(ex[0])
                print('First "next" tensor:')
                print(ex[1])
            batch_size = ex[0].size(0)
            if dtype == "reddit" and is_test and n_skipped < max_exs:
                n_skipped += batch_size
                continue
            # Move tensor fields to GPU when enabled; None fields pass
            # through. Bug fix: the original ternary chain attempted
            # `None.cuda()` whenever opt.cuda was set and a field was None.
            params = [
                field.cuda(non_blocking=True)
                if (opt.cuda and field is not None)
                else field
                for field in ex[:2]
            ]
            ctx, cands = model(*params)
            all_context.append(ctx)
            all_cands.append(cands)
            all_next_sentences.extend(ex[2])
            loss, nb_ok = loss_fn(ctx, cands)
            sum_losses += loss
            correct += nb_ok
            n_losses += 1
            examples += batch_size
            if examples >= max_exs and dtype == "reddit":
                break
    n_examples = 0
    if len(all_context) > 0:
        set_string = 'test' if is_test else 'valid'
        shuffle_string = 'shuffled' if is_shuffled else 'unshuffled'
        save_path = os.path.join(
            os.getcwd(),
            f'{set_string}_candidate_groupings_{shuffle_string}.txt')
        print(f'Saving candidate groupings to {save_path}.')
        logging.info("Processing candidate top-K")
        all_context = torch.cat(all_context, dim=0)  # [N, 2h]
        all_cands = torch.cat(all_cands, dim=0)  # [N, 2h]
        acc_ranges = [1, 3, 10]
        n_correct = {r: 0 for r in acc_ranges}
        # Bug fix: the file was opened with a bare open() and closed only on
        # the success path, leaking the handle (and buffered writes) if the
        # loop below raised. `with` guarantees it is closed.
        with open(save_path, 'w', encoding="utf8") as fout:
            # Drop the last (possibly ragged) group so every pool has
            # exactly nb_candidates entries.
            for i_grouping, (context, cands) in enumerate(list(
                    zip(all_context.split(nb_candidates),
                        all_cands.split(nb_candidates)))[:-1]):
                next_sentences = all_next_sentences[
                    nb_candidates * i_grouping:
                    nb_candidates * (i_grouping + 1)]
                fout.write('|'.join(next_sentences) + '\n')
                _, top_answers = score_candidates(context, cands)
                n_cands = cands.size(0)
                # The ground-truth answer for row i is candidate i.
                gt_index = torch.arange(
                    n_cands, out=top_answers.new(n_cands, 1))
                for acc_range in acc_ranges:
                    n_acc = (
                        top_answers[:, :acc_range] == gt_index
                    ).float().sum()
                    n_correct[acc_range] += n_acc
                n_examples += n_cands
        # max(..., 1) guards against division by zero when fewer than two
        # candidate groups were produced (identical result otherwise).
        accuracies = {
            r: 100 * n_acc / max(n_examples, 1)
            for r, n_acc in n_correct.items()
        }
        avg_loss = sum_losses / (n_losses + 0.00001)
        avg_acc = 100 * correct / (examples + 0.000001)
        valid_time = time.time() - eval_start
        logging.info(
            f"Valid ({shuffled_str}): Epoch = {epoch:d} | avg loss = {avg_loss:.3f} | "
            f"batch P@1 = {avg_acc:.2f} % | " +
            f" | ".join(f"P@{k},{nb_candidates} = {v:.2f}%"
                        for k, v in accuracies.items()) +
            f" | valid time = {valid_time:.2f} (s)")
        return avg_loss
    return 10