def train_epoch(iter_cnt, encoders, classifiers, attn_mats, train_loader_dst,
                args, optim_model, epoch):
    """Train the attention-weighted mixture of source experts for one epoch.

    For every target-domain batch, each source encoder and its classifier
    produce a prediction, the per-source attention modules score the
    sources, the scores are normalised with the file's ``softmax`` helper and
    used to mix the per-source class probabilities.  The optimised loss is
    ``args.lambda_moe * NLL(log(mixture)) + args.lambda_entropy * H(alphas)``
    where ``H`` is ``HLoss`` applied to the (batch, n_sources) weights.

    Args:
        iter_cnt: running iteration counter; incremented once per batch.
        encoders: pair ``(source_encoders, encoder_dst)``; every encoder is
            called with the batch inputs and returns ``(_, hidden)``.
        classifiers: per-source classifiers, aligned with ``source_encoders``.
        attn_mats: per-source attention modules, aligned with
            ``source_encoders``.
        train_loader_dst: loader yielding ``(b1, b2, label)`` when
            ``args.base_model == "cnn"`` or ``(b1, b2, b3, b4, label)`` when
            ``args.base_model == "rnn"``.
        args: namespace; reads ``base_model``, ``cuda``, ``attn_type``,
            ``lambda_moe`` and ``lambda_entropy``.
        optim_model: optimizer, stepped once per batch.
        epoch: epoch index passed as the step to ``writer.add_scalar``.

    Returns:
        The updated ``iter_cnt``.

    Raises:
        NotImplementedError: unknown ``args.base_model`` or ``args.attn_type``.
    """
    encoders, encoder_dst = encoders
    if args.base_model not in ("cnn", "rnn"):
        raise NotImplementedError

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.train(), ...)`` never
    # ran; iterate explicitly so training mode is really enabled.
    for module in classifiers + encoders + attn_mats:
        module.train()

    moe_criterion = nn.NLLLoss()  # the log of the mixture is taken explicitly
    entropy_criterion = HLoss()
    loss_total = 0
    n_batch = 0
    n_sources = len(encoders)
    source_ids = range(n_sources)
    support_ids = list(source_ids)  # every source acts as an expert

    for batch in train_loader_dst:
        # cnn batches carry 2 input tensors, rnn batches 4; the label is last.
        *inputs, label = batch
        bs = len(label)
        iter_cnt += 1
        n_batch += 1
        if args.cuda:
            inputs = [x.cuda() for x in inputs]
            label = label.cuda()

        _, hidden_from_dst_enc = encoder_dst(*inputs)

        outputs_dst_transfer = []
        hidden_from_src_enc = []
        one_hot_sources = []
        for src_i in source_ids:
            _, cur_hidden = encoders[src_i](*inputs)
            hidden_from_src_enc.append(cur_hidden)
            outputs_dst_transfer.append(classifiers[src_i](cur_hidden))
            # Built on the label's device so "onehot" attention also works
            # when the attention modules live on the GPU.
            cur_one_hot = torch.zeros(size=(bs, n_sources),
                                      device=label.device)
            cur_one_hot[:, src_i] = 1
            one_hot_sources.append(cur_one_hot)

        optim_model.zero_grad()

        if args.attn_type == "onehot":
            source_alphas = [attn_mats[j](one_hot_sources[j]).squeeze()
                             for j in source_ids]
        elif args.attn_type == "cor":
            source_alphas = [
                attn_mats[j](hidden_from_src_enc[j],
                             hidden_from_dst_enc).squeeze()
                for j in source_ids
            ]
        else:
            raise NotImplementedError

        support_alphas = softmax([source_alphas[x] for x in support_ids])
        source_alphas = softmax(source_alphas)
        if args.cuda:
            source_alphas = [alpha.cuda() for alpha in source_alphas]
        # Stack to (n_sources, bs) then transpose to (bs, n_sources).
        source_alphas = torch.stack(source_alphas, dim=0).permute(1, 0)
        loss_entropy = entropy_criterion(source_alphas)

        # Alpha-weighted average of per-source class probabilities; each
        # alpha is broadcast over the class dimension.
        output_moe = sum(
            alpha.unsqueeze(1) * F.softmax(outputs_dst_transfer[idx], dim=1)
            for alpha, idx in zip(support_alphas, support_ids))
        loss_moe = moe_criterion(torch.log(output_moe), label)
        loss = args.lambda_moe * loss_moe
        loss += args.lambda_entropy * loss_entropy
        loss_total += loss.item()
        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            say("{} MOE loss: {:.4f}, Entropy loss: {:.4f}, "
                "loss: {:.4f}\n".format(iter_cnt, loss_moe.item(),
                                        loss_entropy.item(), loss.item()))

    # Mean loss per batch; guard against an empty loader.
    loss_total /= max(n_batch, 1)
    writer.add_scalar('training_loss', loss_total, epoch)
    say("\n")
    return iter_cnt
def train_moe_deep_stack(args):
    """Load trained source experts and run them over the target training set.

    Loads the ``(classifiers, attn_mats)`` pair saved for ``args.test`` and
    one pretrained encoder per comma-separated source name in ``args.train``,
    builds train/valid/test loaders for ``args.test`` and, when
    ``args.base_model == "cnn"``, computes every source's hidden state,
    classifier output and normalised attention weight for each training
    batch.

    NOTE(review): the per-batch ``alphas`` / ``outputs_dst_transfer`` are not
    consumed yet (``meta_features`` / ``meta_labels`` are never filled and the
    valid/test loaders are unused), and the "rnn" path stops after building
    the loaders -- the stacking step looks unfinished.

    Args:
        args: namespace; reads ``test``, ``train``, ``base_model``, ``cuda``,
            ``batch_size`` and the encoder hyper-parameters.

    Raises:
        NotImplementedError: for an unknown ``args.base_model``.
    """
    # On CPU-only runs remap tensors that were saved from the GPU; ``None``
    # keeps torch.load's default behaviour.
    map_location = None if args.cuda else torch.device('cpu')

    save_model_dir = os.path.join(settings.OUT_DIR, args.test)
    classifiers, attn_mats = torch.load(
        os.path.join(
            save_model_dir,
            "{}_{}_moe_best_now.mdl".format(args.test, args.base_model)),
        map_location=map_location)
    print("base model", args.base_model)
    print("classifier", classifiers[0])

    source_train_sets = args.train.split(',')
    pretrain_emb = torch.load(
        os.path.join(settings.OUT_DIR, "rnn_init_word_emb.emb"),
        map_location=map_location)

    encoders_src = []
    for source_name in source_train_sets:
        cur_model_dir = os.path.join(settings.OUT_DIR, source_name)
        if args.base_model == "cnn":
            encoder = CNNMatchModel(
                input_matrix_size1=args.matrix_size1,
                input_matrix_size2=args.matrix_size2,
                mat1_channel1=args.mat1_channel1,
                mat1_kernel_size1=args.mat1_kernel_size1,
                mat1_channel2=args.mat1_channel2,
                mat1_kernel_size2=args.mat1_kernel_size2,
                mat1_hidden=args.mat1_hidden,
                mat2_channel1=args.mat2_channel1,
                mat2_kernel_size1=args.mat2_kernel_size1,
                mat2_hidden=args.mat2_hidden)
        elif args.base_model == "rnn":
            encoder = BiLSTM(pretrain_emb=pretrain_emb,
                             vocab_size=args.max_vocab_size,
                             embedding_size=args.embedding_size,
                             hidden_size=args.hidden_size,
                             dropout=args.dropout)
        else:
            raise NotImplementedError
        encoder.load_state_dict(
            torch.load(
                os.path.join(cur_model_dir,
                             "{}-match-best-now.mdl".format(args.base_model)),
                map_location=map_location))
        encoders_src.append(encoder)

    # ``map`` is lazy in Python 3, so ``map(lambda m: m.eval(), ...)`` and
    # ``map(lambda m: m.cuda(), ...)`` never ran; iterate explicitly.
    for module in encoders_src + classifiers + attn_mats:
        module.eval()
        if args.cuda:
            module.cuda()

    if args.base_model == "cnn":
        dataset_cls = ProcessedCNNInputDataset
    elif args.base_model == "rnn":
        dataset_cls = ProcessedRNNInputDataset
    else:
        raise NotImplementedError
    train_dataset_dst = dataset_cls(args.test, "train")
    valid_dataset = dataset_cls(args.test, "valid")
    test_dataset = dataset_cls(args.test, "test")

    train_loader_dst = data.DataLoader(train_dataset_dst,
                                       batch_size=args.batch_size,
                                       shuffle=False,
                                       num_workers=0)
    valid_loader = data.DataLoader(valid_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=0)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=0)
    say("Corpus loaded.\n")

    meta_features = np.empty(shape=(0, 192 + 2 * 8))  # not filled yet
    meta_labels = []  # not filled yet
    encoders = encoders_src
    n_sources = len(encoders)
    source_ids = range(n_sources)

    if args.base_model == "cnn":
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch1, batch2, label in train_loader_dst:
                if args.cuda:
                    batch1 = batch1.cuda()
                    batch2 = batch2.cuda()
                    label = label.cuda()
                outputs_dst_transfer = []
                hidden_from_src_enc = []
                for src_i in source_ids:
                    _, cur_hidden = encoders[src_i](batch1, batch2)
                    hidden_from_src_enc.append(cur_hidden)
                    outputs_dst_transfer.append(
                        classifiers[src_i](cur_hidden))
                # NOTE(review): attention is called with the source hidden
                # state only, whereas the train/eval code in this file calls
                # it with a one-hot source vector or (src_hidden, dst_hidden).
                # Confirm which attention module this expects.
                source_alphas = [
                    attn_mats[j](hidden_from_src_enc[j]).squeeze()
                    for j in source_ids
                ]
                alphas = softmax(source_alphas)
def evaluate_cross(encoder, classifiers, mats, loaders, return_best_thrs,
                   args, thr=None):
    """Evaluate each source expert on that source's validation set.

    For validation set ``src_i`` the mixture weight is a one-hot vector that
    selects classifier ``src_i`` (the metric-based weights are computed and
    then replaced -- see the NOTE inside the loop), so the reported numbers
    are per-source single-expert results.

    Args:
        encoder: shared encoder, called as ``encoder(batch1, batch2)`` and
            returning ``(_, hidden)``.
        classifiers: per-source classifiers.
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loaders_src)``; ``source_loaders``
            is passed to ``domain_encoding``.
        return_best_thrs: if true, also search each source's F1-optimal
            threshold on the class-1 score.
        args: namespace; reads ``metric``, ``cuda`` and ``eval_only``.
        thr: optional per-source thresholds; when given, predictions are
            ``score > thr[src_i]`` instead of the argmax.

    Returns:
        ``(thresholds, metrics, alphas_weights)``: the best thresholds (empty
        unless ``return_best_thrs``), an array with one ``[auc, prec, rec,
        f1]`` row per source, and the ``(n_sources, n_sources)`` matrix of
        mean mixture weights.
    """
    # ``map`` is lazy in Python 3; loop explicitly so eval mode is applied.
    for module in [encoder] + classifiers:
        module.eval()
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats
    source_loaders, valid_loaders_src = loaders

    n_sources = len(valid_loaders_src)
    source_ids = range(n_sources)
    support_ids = list(source_ids)  # every source acts as an expert
    thresholds = []
    metrics = []
    # Sized from the data rather than a fixed 4 x 4.
    alphas_weights = np.zeros(shape=(n_sources, n_sources))

    with torch.no_grad():
        domain_encs = domain_encoding(source_loaders, args, encoder)
    if args.metric != "biaffine":
        # Ps / Ns only exist for the non-biaffine metric.
        cur_domain_encs = [domain_encs[x] for x in support_ids]
        cur_Us = [Us[x] for x in support_ids]
        cur_Ps = [Ps[x] for x in support_ids]
        cur_Ns = [Ns[x] for x in support_ids]

    for src_i, valid_loader in enumerate(valid_loaders_src):
        y_true = []
        y_pred = []
        y_score = []
        cur_alpha_weights_stack = np.empty(shape=(0, len(support_ids)))

        with torch.no_grad():
            for batch1, batch2, label in valid_loader:
                if args.cuda:
                    batch1 = batch1.cuda()
                    batch2 = batch2.cuda()
                    label = label.cuda()
                bs = batch1.size(0)
                _, hidden = encoder(batch1, batch2)

                if args.metric == "biaffine":
                    metric_alphas = [
                        biaffine_metric_fast(hidden, mu[0], Us[0])
                        for mu in domain_encs
                    ]
                else:
                    metric_alphas = [
                        mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P,
                                                mu[2], N)
                        for (mu, U, P, N) in zip(cur_domain_encs, cur_Us,
                                                 cur_Ps, cur_Ns)
                    ]
                metric_alphas = softmax(metric_alphas)
                # NOTE: the metric-based weights above are discarded; the
                # mixture weight is a one-hot vector selecting expert src_i.
                alphas = [torch.zeros(size=(bs, )) for _ in support_ids]
                alphas[src_i] = torch.ones(size=(bs, ))

                # Per-sample weights, shape (bs, n_support), kept on the CPU.
                alpha_cat = torch.stack(alphas, dim=1)
                cur_alpha_weights_stack = np.concatenate(
                    (cur_alpha_weights_stack, alpha_cat.numpy()))

                if args.cuda:
                    alphas = [alpha.cuda() for alpha in alphas]
                outputs = [
                    F.softmax(classifiers[j](hidden), dim=1)
                    for j in support_ids
                ]
                # Alpha broadcast over the class dimension.
                output = sum(alpha.unsqueeze(1) * output_i
                             for alpha, output_i in zip(alphas, outputs))
                pred = output.max(dim=1)[1]

                if args.eval_only:
                    oracle_eq = compute_oracle(outputs, label, args)
                    for i in range(bs):
                        for j in range(len(alphas)):
                            say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                                alphas[j][i], outputs[j][i][0],
                                outputs[j][i][1]))
                        oracle_TF = "T" if oracle_eq[i] == 1 else colored(
                            "F", 'red')
                        say("gold: {}, pred: {}, oracle: {}\n".format(
                            label[i], pred[i], oracle_TF))
                    say("\n")

                y_true += label.tolist()
                y_pred += pred.tolist()
                y_score += output[:, 1].tolist()

        alphas_weights[src_i, support_ids] = np.mean(cur_alpha_weights_stack,
                                                     axis=0)
        if thr is not None:
            print("using threshold %.4f" % thr[src_i])
            y_score = np.array(y_score)
            y_pred = np.zeros_like(y_score)
            y_pred[y_score > thr[src_i]] = 1

        prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                           y_pred,
                                                           average="binary")
        auc = roc_auc_score(y_true, y_score)
        print("source {}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
              format(src_i, auc * 100, prec * 100, rec * 100, f1 * 100))
        metrics.append([auc, prec, rec, f1])

        if return_best_thrs:
            precs, recs, thrs = precision_recall_curve(y_true, y_score)
            f1s = 2 * precs * recs / (precs + recs)
            f1s = f1s[:-1]  # the last P/R point has no threshold
            thrs = thrs[~np.isnan(f1s)]
            f1s = f1s[~np.isnan(f1s)]
            best_thr = thrs[np.argmax(f1s)]
            print("best threshold={:.4f}, f1={:.4f}".format(
                best_thr, np.max(f1s)))
            thresholds.append(best_thr)

    print("source domain weight matrix\n", alphas_weights)
    metrics = np.array(metrics)
    return thresholds, metrics, alphas_weights
def evaluate(epoch, encoders, classifiers, attn_mats, data_loader,
             return_best_thrs, args, thr=None):
    """Evaluate the attention-weighted mixture of source experts.

    Every batch is encoded by the target encoder and by each source encoder.
    Per-source attention scores are normalised with the file's ``softmax``
    helper and used to mix the per-source class probabilities.  The mean
    mixture weight per source is printed.

    Args:
        epoch: step passed to ``writer.add_scalar``.
        encoders: pair ``(source_encoders, encoder_dst)``; encoders return
            ``(_, hidden)``.
        classifiers: per-source classifiers.
        attn_mats: per-source attention modules.
        data_loader: yields ``(b1, b2, label)`` for ``"cnn"`` or
            ``(b1, b2, b3, b4, label)`` for ``"rnn"``.
        return_best_thrs: if true, search the F1-optimal threshold and log
            ``val_loss``; otherwise log ``test_f1``.
        args: namespace; reads ``base_model``, ``cuda`` and ``attn_type``.
        thr: optional threshold on the class-1 score; when given, predictions
            are ``score > thr`` instead of the argmax.

    Returns:
        ``(best_thr, [loss, auc, prec, rec, f1])``; ``best_thr`` is ``None``
        unless ``return_best_thrs``.

    Raises:
        NotImplementedError: unknown ``args.base_model`` or ``args.attn_type``.
    """
    encoders, encoder_dst = encoders
    if args.base_model not in ("cnn", "rnn"):
        raise NotImplementedError
    # ``map`` is lazy in Python 3; loop explicitly so eval mode is applied.
    for module in encoders + classifiers + attn_mats:
        module.eval()

    y_true = []
    y_pred = []
    y_score = []
    loss = 0.
    tot_cnt = 0
    n_sources = len(encoders)
    source_ids = range(n_sources)
    cur_alpha_weights_stack = np.empty(shape=(0, n_sources))

    with torch.no_grad():
        for batch in data_loader:
            # cnn batches carry 2 input tensors, rnn batches 4.
            *inputs, label = batch
            if args.cuda:
                inputs = [x.cuda() for x in inputs]
                label = label.cuda()
            bs = len(inputs[0])

            _, hidden_from_dst_enc = encoder_dst(*inputs)
            outputs_dst_transfer = []
            hidden_from_src_enc = []
            one_hot_sources = []
            for src_i in source_ids:
                _, cur_hidden = encoders[src_i](*inputs)
                hidden_from_src_enc.append(cur_hidden)
                outputs_dst_transfer.append(classifiers[src_i](cur_hidden))
                # Built on the label's device so "onehot" attention also works
                # when the attention modules live on the GPU.
                cur_one_hot = torch.zeros(size=(bs, n_sources),
                                          device=label.device)
                cur_one_hot[:, src_i] = 1
                one_hot_sources.append(cur_one_hot)

            if args.attn_type == "onehot":
                source_alphas = [attn_mats[j](one_hot_sources[j]).squeeze()
                                 for j in source_ids]
            elif args.attn_type == "cor":
                source_alphas = [
                    attn_mats[j](hidden_from_src_enc[j],
                                 hidden_from_dst_enc).squeeze()
                    for j in source_ids
                ]
            else:
                raise NotImplementedError

            alphas = softmax(source_alphas)
            if args.cuda:
                alphas = [alpha.cuda() for alpha in alphas]
            outputs = [F.softmax(out, dim=1) for out in outputs_dst_transfer]

            # Per-sample mixture weights, (bs, n_sources), as float32 on CPU.
            alpha_cat = torch.stack([a.float().cpu() for a in alphas], dim=1)
            cur_alpha_weights_stack = np.concatenate(
                (cur_alpha_weights_stack, alpha_cat.numpy()))

            # Alpha broadcast over the class dimension.
            output = sum(alpha.unsqueeze(1) * output_i
                         for alpha, output_i in zip(alphas, outputs))
            pred = output.max(dim=1)[1]
            loss += bs * F.nll_loss(torch.log(output), label).item()

            y_true += label.tolist()
            y_pred += pred.tolist()
            y_score += output[:, 1].tolist()
            tot_cnt += output.size(0)

    alpha_weights = np.mean(cur_alpha_weights_stack, axis=0)
    print("alpha weights", alpha_weights)

    if thr is not None:
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    loss /= tot_cnt
    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("Loss: {:.4f}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
          format(loss, auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [loss, auc, prec, rec, f1]
    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]  # the last P/R point has no threshold
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(
            best_thr, np.max(f1s)))
        writer.add_scalar('val_loss', loss, epoch)
    else:
        writer.add_scalar('test_f1', f1, epoch)
    return best_thr, metric
def evaluate(epoch, encoders, classifiers, mats, loaders, return_best_thrs,
             args, thr=None):
    """Evaluate the metric-weighted mixture of source experts on target data.

    Each target batch is encoded by ``encoder_dst``.  Per-source weights come
    from a biaffine or Mahalanobis metric between that hidden state and the
    ``domain_encoding`` of each source, normalised with the file's
    ``softmax`` helper.  The prediction scored here is the alpha-weighted
    average of the per-source class probabilities (the same mixture the
    training loop optimises in its MoE loss).

    Args:
        epoch: step passed to ``writer.add_scalar``.
        encoders: pair ``(source_encoders, encoder_dst)``; encoders return
            ``(_, hidden)``.
        classifiers: triple ``(source_classifiers, classifier_dst,
            classifier_mix)``.
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loader)``; ``valid_loader`` yields
            ``(batch1, batch2, label)``.
        return_best_thrs: if true, search the F1-optimal threshold and log
            ``val_loss``; otherwise log ``test_f1``.
        args: namespace; reads ``metric`` and ``cuda``.
        thr: optional threshold on the class-1 score; when given, predictions
            are ``score > thr`` instead of the argmax.

    Returns:
        ``(best_thr, [auc, prec, rec, f1])``; ``best_thr`` is ``None`` unless
        ``return_best_thrs``.
    """
    encoders, encoder_dst = encoders
    classifiers, classifier_dst, classifier_mix = classifiers
    # ``map`` is lazy in Python 3; loop explicitly so eval mode is applied.
    for module in encoders + classifiers + [encoder_dst, classifier_dst,
                                            classifier_mix]:
        module.eval()
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats
    source_loaders, valid_loader = loaders

    y_true = []
    y_pred = []
    y_score = []
    loss = 0.
    tot_cnt = 0

    with torch.no_grad():
        domain_encs = domain_encoding(source_loaders, args, encoders)
        for batch1, batch2, label in valid_loader:
            if args.cuda:
                batch1 = batch1.cuda()
                batch2 = batch2.cuda()
                label = label.cuda()
            bs = len(batch1)

            _, hidden_dst = encoder_dst(batch1, batch2)
            outputs_dst_transfer = []
            for src_i in range(len(source_loaders)):
                _, cur_hidden = encoders[src_i](batch1, batch2)
                outputs_dst_transfer.append(classifiers[src_i](cur_hidden))

            if args.metric == "biaffine":
                alphas = [biaffine_metric_fast(hidden_dst, mu[0], Us[0])
                          for mu in domain_encs]
            else:
                alphas = [
                    mahalanobis_metric_fast(hidden_dst, mu[0], U, mu[1], P,
                                            mu[2], N)
                    for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)
                ]
            alphas = softmax(alphas)
            if args.cuda:
                alphas = [alpha.cuda() for alpha in alphas]

            # Mix class *probabilities* so the result is a distribution;
            # F.nll_loss below then receives proper log-probabilities.
            output_moe = sum(
                alpha.unsqueeze(1) * F.softmax(output_i, dim=1)
                for alpha, output_i in zip(alphas, outputs_dst_transfer))
            # NOTE(review): only the MoE mixture is scored; the combined head
            # softmax(classifier_dst(h)) + classifier_mix.multp * output_moe
            # used by training's ``loss_all`` is not evaluated here.
            pred = output_moe.max(dim=1)[1]
            loss += bs * F.nll_loss(torch.log(output_moe), label).item()

            y_true += label.tolist()
            y_pred += pred.tolist()
            y_score += output_moe[:, 1].tolist()
            tot_cnt += output_moe.size(0)

    if thr is not None:
        print("using threshold %.4f" % thr)
        y_score = np.array(y_score)
        y_pred = np.zeros_like(y_score)
        y_pred[y_score > thr] = 1

    loss /= tot_cnt
    prec, rec, f1, _ = precision_recall_fscore_support(y_true,
                                                       y_pred,
                                                       average="binary")
    auc = roc_auc_score(y_true, y_score)
    print("Loss: {:.4f}, AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".
          format(loss, auc * 100, prec * 100, rec * 100, f1 * 100))

    best_thr = None
    metric = [auc, prec, rec, f1]
    if return_best_thrs:
        precs, recs, thrs = precision_recall_curve(y_true, y_score)
        f1s = 2 * precs * recs / (precs + recs)
        f1s = f1s[:-1]  # the last P/R point has no threshold
        thrs = thrs[~np.isnan(f1s)]
        f1s = f1s[~np.isnan(f1s)]
        best_thr = thrs[np.argmax(f1s)]
        print("best threshold={:.4f}, f1={:.4f}".format(
            best_thr, np.max(f1s)))
        writer.add_scalar('val_loss', loss, epoch)
    else:
        writer.add_scalar('test_f1', f1, epoch)
    return best_thr, metric
def train_epoch(iter_cnt, encoders, classifiers, critic, mats, data_loaders,
                args, optim_model, epoch):
    """Train source experts, the target model and the MoE mixture for one epoch.

    Each step draws one batch from every source loader, one target batch and
    one unlabeled batch, and minimises::

        lambda_mtl * mean_i NLL(classifier_i(source_i))
        + lambda_moe * NLL(log(MoE mixture on target))
        + lambda_entropy * H(alphas)
        + lambda_dst * NLL(classifier_dst(target))
        + lambda_all * NLL(log_softmax(p_dst + classifier_mix.multp * MoE))
        + lambda_critic * critic loss            (only if lambda_critic > 0)

    Args:
        iter_cnt: running iteration counter; incremented once per step.
        encoders: pair ``(source_encoders, encoder_dst)``; encoders return
            ``(_, hidden)``.
        classifiers: triple ``(source_classifiers, classifier_dst,
            classifier_mix)``.
        critic: domain critic; ``ClassificationD`` instances are fed the
            concatenated hidden states, others are called with
            ``(source_hidden, unlabeled_hidden)``.
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)``.
        data_loaders: ``(train_loaders, train_loader_dst, unl_loader,
            valid_loader)``; ``valid_loader`` is not used here.
        args: namespace; reads ``metric``, ``cuda`` and the ``lambda_*``
            weights listed above.
        optim_model: optimizer, stepped once per step.
        epoch: step passed to ``writer.add_scalar``.

    Returns:
        The updated ``iter_cnt``.
    """
    encoders, encoder_dst = encoders
    classifiers, classifier_dst, classifier_mix = classifiers
    # ``map`` is lazy in Python 3; loop explicitly so training mode is set.
    for module in (encoders +
                   [encoder_dst, classifier_dst, critic, classifier_mix] +
                   classifiers):
        module.train()
    train_loaders, train_loader_dst, unl_loader, valid_loader = data_loaders

    mtl_criterion = nn.NLLLoss()  # inputs are log-probabilities
    moe_criterion = nn.NLLLoss()  # applied to log(mixture)
    entropy_criterion = HLoss()
    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    loss_total = 0.
    total = 0
    for batches, batches_dst, unl_batch in zip(zip(*train_loaders),
                                               train_loader_dst, unl_loader):
        train_batches1, train_batches2, train_labels = zip(*batches)
        unl_critic_batch1, unl_critic_batch2, unl_critic_label = unl_batch
        batches1_dst, batches2_dst, labels_dst = batches_dst
        bs_dst = len(batches1_dst)
        total += bs_dst
        iter_cnt += 1
        if args.cuda:
            train_batches1 = [batch.cuda() for batch in train_batches1]
            train_batches2 = [batch.cuda() for batch in train_batches2]
            train_labels = [label.cuda() for label in train_labels]
            batches1_dst = batches1_dst.cuda()
            batches2_dst = batches2_dst.cuda()
            labels_dst = labels_dst.cuda()
            unl_critic_batch1 = unl_critic_batch1.cuda()
            unl_critic_batch2 = unl_critic_batch2.cuda()
            unl_critic_label = unl_critic_label.cuda()

        optim_model.zero_grad()

        # Target model on the target batch.
        _, hidden_dst = encoder_dst(batches1_dst, batches2_dst)
        cur_output_dst_mem = torch.softmax(classifier_dst(hidden_dst), dim=1)
        loss_train_dst = mtl_criterion(torch.log(cur_output_dst_mem),
                                       labels_dst)

        # Every source expert applied to the target batch.
        outputs_dst_transfer = []
        for i in range(len(train_batches1)):
            _, cur_hidden = encoders[i](batches1_dst, batches2_dst)
            outputs_dst_transfer.append(classifiers[i](cur_hidden))

        loss_mtl = []
        loss_dan = []
        hiddens = []
        hidden_corresponding_labels = []
        for i, (batch1, batch2, label) in enumerate(
                zip(train_batches1, train_batches2, train_labels)):
            _, hidden = encoders[i](batch1, batch2)
            hiddens.append(hidden)
            hidden_corresponding_labels.append(label)
            # Output j = source batch i scored by classifier j; only the
            # matching classifier (j == i) is supervised.
            outputs = [torch.log_softmax(classifier(hidden), dim=1)
                       for classifier in classifiers]
            loss_mtl.append(mtl_criterion(outputs[i], label))

            if args.lambda_critic > 0:
                critic_label = torch.cat(
                    [1 - unl_critic_label, unl_critic_label])
                # Encoders return (_, hidden); the critic needs the hidden.
                _, unl_hidden = encoders[i](unl_critic_batch1,
                                            unl_critic_batch2)
                if isinstance(critic, ClassificationD):
                    critic_output = critic(torch.cat([hidden, unl_hidden]))
                    loss_dan.append(
                        critic.compute_loss(critic_output, critic_label))
                else:
                    loss_dan.append(critic(hidden, unl_hidden))
        if args.lambda_critic <= 0:
            loss_dan = torch.zeros(1)

        source_ids = range(len(train_batches1))
        support_ids = list(source_ids)  # every source acts as an expert
        if args.metric == "biaffine":
            # The biaffine metric uses a single shared set of matrices.
            source_alphas = [
                metric(hidden_dst, hiddens[j].detach(), Us[0], Ws[0], Vs[0],
                       args) for j in source_ids
            ]
        else:
            source_alphas = [
                metric(hidden_dst, hiddens[j].detach(),
                       hidden_corresponding_labels[j], Us[j], Ps[j], Ns[j],
                       args) for j in source_ids
            ]
        support_alphas = softmax([source_alphas[x] for x in support_ids])
        source_alphas = softmax(source_alphas)
        if args.cuda:
            source_alphas = [alpha.cuda() for alpha in source_alphas]
        # Stack to (n_sources, bs) then transpose to (bs, n_sources).
        source_alphas = torch.stack(source_alphas, dim=0).permute(1, 0)
        loss_entropy = entropy_criterion(source_alphas)

        # Alpha-weighted average of per-source class probabilities.
        output_moe_i = sum(
            alpha.unsqueeze(1) * F.softmax(outputs_dst_transfer[idx], dim=1)
            for alpha, idx in zip(support_alphas, support_ids))
        loss_moe = moe_criterion(torch.log(output_moe_i), labels_dst)
        upper_out = cur_output_dst_mem + classifier_mix.multp * output_moe_i
        loss_all = mtl_criterion(torch.log_softmax(upper_out, dim=1),
                                 labels_dst)

        loss_mtl = sum(loss_mtl) / len(source_ids)
        loss = args.lambda_mtl * loss_mtl + args.lambda_moe * loss_moe
        loss += args.lambda_entropy * loss_entropy
        loss += loss_train_dst * args.lambda_dst
        loss += loss_all * args.lambda_all
        # Accumulate a Python float: adding the tensor itself would keep
        # every step's autograd graph alive until the end of the epoch.
        loss_total += loss.item() * bs_dst
        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan
        loss.backward()
        optim_model.step()

        if iter_cnt % 5 == 0:
            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}\n".format(iter_cnt, loss_mtl.item(),
                                        loss_moe.item(), loss_dan.item(),
                                        loss.item()))

    # Sample-weighted mean loss (critic term excluded); guard empty epochs.
    writer.add_scalar('training_loss', loss_total / max(total, 1), epoch)
    say("\n")
    return iter_cnt
def evaluate(encoder, classifiers, mats, loaders, args):
    """Evaluate a shared-encoder mixture of experts on a validation loader.

    Per-source weights come from a biaffine or Mahalanobis metric between
    ``encoder(batch)`` and the ``domain_encoding`` of each source, normalised
    with the file's ``softmax`` helper, and are used to mix the per-source
    class probabilities.

    Args:
        encoder: shared encoder, called as ``encoder(batch)``.
        classifiers: per-source classifiers.
        mats: ``(Us, Ws, Vs)`` when ``args.metric == "biaffine"``, otherwise
            ``(Us, Ps, Ns)``.
        loaders: ``(source_loaders, valid_loader)``; ``valid_loader`` yields
            ``(batch, label)``.
        args: namespace; reads ``metric``, ``cuda`` and ``eval_only``.

    Returns:
        ``((acc, oracle_acc), confusion_matrix(y_true, y_pred))``.
    """
    # ``map`` is lazy in Python 3; loop explicitly so eval mode is applied.
    for module in [encoder] + classifiers:
        module.eval()
    if args.metric == "biaffine":
        Us, Ws, Vs = mats
    else:
        Us, Ps, Ns = mats
    source_loaders, valid_loader = loaders

    oracle_correct = 0
    correct = 0
    tot_cnt = 0
    y_true = []
    y_pred = []
    with torch.no_grad():
        domain_encs = domain_encoding(source_loaders, args, encoder)
        for batch, label in valid_loader:
            if args.cuda:
                batch = batch.cuda()
                label = label.cuda()
            hidden = encoder(batch)

            if args.metric == "biaffine":
                alphas = [biaffine_metric_fast(hidden, mu[0], Us[0])
                          for mu in domain_encs]
            else:
                alphas = [
                    mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P,
                                            mu[2], N)
                    for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)
                ]
            alphas = softmax(alphas)
            if args.cuda:
                alphas = [alpha.cuda() for alpha in alphas]

            outputs = [F.softmax(classifier(hidden), dim=1)
                       for classifier in classifiers]
            # Alpha broadcast over the class dimension.
            output = sum(alpha.unsqueeze(1) * output_i
                         for alpha, output_i in zip(alphas, outputs))
            pred = output.max(dim=1)[1]
            oracle_eq = compute_oracle(outputs, label, args)

            if args.eval_only:
                for i in range(batch.shape[0]):
                    for j in range(len(alphas)):
                        say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                            alphas[j][i], outputs[j][i][0],
                            outputs[j][i][1]))
                    oracle_TF = "T" if oracle_eq[i] == 1 else colored(
                        "F", 'red')
                    say("gold: {}, pred: {}, oracle: {}\n".format(
                        label[i], pred[i], oracle_TF))
                say("\n")

            y_true += label.tolist()
            y_pred += pred.tolist()
            correct += pred.eq(label).sum()
            oracle_correct += oracle_eq.sum()
            tot_cnt += output.size(0)

    acc = float(correct) / tot_cnt
    oracle_acc = float(oracle_correct) / tot_cnt
    return (acc, oracle_acc), confusion_matrix(y_true, y_pred)
def train_epoch(iter_cnt, encoder, classifiers, critic, mats, data_loaders, args, optim_model):
    '''
    Run one training epoch of the multi-source MOE model.

    Each step draws one batch from every source training loader plus one
    unlabeled batch (used only by the critic) and minimises

        loss_mtl + lambda_moe * loss_moe + lambda_entropy * loss_entropy
        (+ lambda_critic * loss_dan, when args.lambda_critic > 0)

    where
      - loss_mtl: cross-entropy of source i's batch under classifier i;
      - loss_moe: NLL (on the log of the mixed probabilities) of source i's
        batch under the mixture of all *other* classifiers, weighted by the
        softmax-normalised metric scores between source i and each expert;
      - loss_entropy: HLoss over the normalised weights of all sources;
      - loss_dan: critic loss between each source's hidden states and the
        hidden states of the unlabeled batch.
    Every 30 iterations the model is evaluated via `evaluate()` and the
    losses plus dev accuracy/oracle accuracy are reported through `say`.

    Args:
        iter_cnt: global iteration counter at the start of the epoch.
        encoder: shared encoder module.
        classifiers: list of per-source classifier modules.
        critic: domain critic; either a ClassificationD (scored on the
            concatenated hidden states) or a module called as
            critic(source_hidden, unlabeled_hidden). Only used when
            args.lambda_critic > 0.
        mats: [Us, Ws, Vs] for args.metric == "biaffine", else [Us, Ps, Ns].
        data_loaders: (train_loaders, unl_loader, valid_loader).
        args: hyper-parameters (metric, cuda, lambda_moe, lambda_entropy,
            lambda_critic, ...).
        optim_model: optimizer over the model parameters.

    Returns:
        The updated iteration counter.
    '''
    train_modules = [m for m in [encoder, critic] + list(classifiers) if m is not None]

    def _set_train_mode():
        # Explicit loop: `map()` is lazy in Python 3, so a bare
        # `map(lambda m: m.train(), ...)` never actually calls train().
        for module in train_modules:
            module.train()

    _set_train_mode()

    train_loaders, unl_loader, valid_loader = data_loaders
    # evaluate() receives this deep copy of the source loaders.
    dup_train_loaders = deepcopy(train_loaders)

    mtl_criterion = nn.CrossEntropyLoss()
    moe_criterion = nn.NLLLoss()  # applied to torch.log of the mixed probabilities
    kl_criterion = nn.MSELoss()
    entropy_criterion = HLoss()

    if args.metric == "biaffine":
        metric = biaffine_metric
        Us, Ws, Vs = mats
    else:
        metric = mahalanobis_metric
        Us, Ps, Ns = mats

    for batches, unl_batch in zip(zip(*train_loaders), unl_loader):
        train_batches, train_labels = zip(*batches)
        unl_critic_batch, unl_critic_label = unl_batch

        iter_cnt += 1
        if args.cuda:
            train_batches = [batch.cuda() for batch in train_batches]
            train_labels = [label.cuda() for label in train_labels]
            unl_critic_batch = unl_critic_batch.cuda()
            unl_critic_label = unl_critic_label.cuda()

        train_batches = [Variable(batch) for batch in train_batches]
        train_labels = [Variable(label) for label in train_labels]
        unl_critic_batch = Variable(unl_critic_batch)
        unl_critic_label = Variable(unl_critic_label)

        optim_model.zero_grad()
        loss_mtl = []
        loss_moe = []
        loss_kl = []
        loss_entropy = []
        loss_dan = []

        # ms_outputs[i][j]: output of source i's batch under classifier j.
        ms_outputs = []
        hiddens = []
        hidden_corresponding_labels = []
        for i, (batch, label) in enumerate(zip(train_batches, train_labels)):
            hidden = encoder(batch)
            hiddens.append(hidden)
            ms_outputs.append([classifier(hidden) for classifier in classifiers])
            hidden_corresponding_labels.append(label)
            # multi-task loss: source i is supervised through its own classifier i
            loss_mtl.append(mtl_criterion(ms_outputs[i][i], label))

            if args.lambda_critic > 0:
                critic_label = torch.cat([1 - unl_critic_label, unl_critic_label])
                if isinstance(critic, ClassificationD):
                    # torch.cat expects a sequence of tensors; passing the
                    # second tensor positionally would be read as `dim`.
                    critic_output = critic(
                        torch.cat([hidden, encoder(unl_critic_batch)]))
                    loss_dan.append(critic.compute_loss(critic_output, critic_label))
                else:
                    critic_output = critic(hidden, encoder(unl_critic_batch))
                    loss_dan.append(critic_output)

        source_ids = range(len(train_batches))
        for i in source_ids:
            support_ids = [x for x in source_ids if x != i]  # experts: all sources but i

            if args.metric == "biaffine":
                # for the biaffine metric a single shared matrix set is used
                source_alphas = [
                    metric(hiddens[i], hiddens[j].detach(),
                           Us[0], Ws[0], Vs[0], args)
                    for j in source_ids
                ]
            else:
                source_alphas = [
                    metric(hiddens[i],  # i-th source
                           hiddens[j].detach(),
                           hidden_corresponding_labels[j],
                           Us[j], Ps[j], Ns[j], args)
                    for j in source_ids
                ]

            support_alphas = softmax([source_alphas[x] for x in support_ids])

            # meta-supervision: compare the weights over all sources with a
            # one-hot target that is 1 at source i.
            source_alphas = softmax(source_alphas)
            source_labels = [torch.FloatTensor([x == i]) for x in source_ids]
            if args.cuda:
                source_alphas = [alpha.cuda() for alpha in source_alphas]
                source_labels = [label.cuda() for label in source_labels]

            source_labels = Variable(torch.stack(source_labels, dim=0))
            source_alphas = torch.stack(source_alphas, dim=0)
            source_labels = source_labels.expand_as(source_alphas).permute(1, 0)
            source_alphas = source_alphas.permute(1, 0)
            loss_kl.append(kl_criterion(source_alphas, source_labels))

            # entropy loss over the normalised source weights
            loss_entropy.append(entropy_criterion(source_alphas))

            # Mixture of the other experts' class distributions; (batch, 1)
            # weights broadcast over the class dimension.
            output_moe_i = sum(
                alpha.unsqueeze(1) * F.softmax(ms_outputs[i][j], dim=1)
                for alpha, j in zip(support_alphas, support_ids)
            )
            loss_moe.append(moe_criterion(torch.log(output_moe_i), train_labels[i]))

        loss_mtl = sum(loss_mtl)
        loss_moe = sum(loss_moe)
        lambda_moe = args.lambda_moe
        lambda_entropy = args.lambda_entropy
        loss = loss_mtl + lambda_moe * loss_moe
        # NOTE(review): loss_kl is accumulated but not added to the objective.
        loss_kl = sum(loss_kl)
        loss_entropy = sum(loss_entropy)
        loss += lambda_entropy * loss_entropy

        if args.lambda_critic > 0:
            loss_dan = sum(loss_dan)
            loss += args.lambda_critic * loss_dan
        else:
            # placeholder so the progress line below can call .item()
            loss_dan = Variable(torch.FloatTensor([0]))

        loss.backward()
        optim_model.step()

        if iter_cnt % 30 == 0:
            eval_mats = [Us, Ws, Vs] if args.metric == "biaffine" else [Us, Ps, Ns]
            (curr_dev, oracle_curr_dev), _ = evaluate(
                encoder, classifiers, eval_mats,
                [dup_train_loaders, valid_loader], args)
            # evaluate() may leave modules in eval mode; resume in train mode.
            _set_train_mode()
            say("{} MTL loss: {:.4f}, MOE loss: {:.4f}, DAN loss: {:.4f}, "
                "loss: {:.4f}, dev acc/oracle: {:.4f}/{:.4f}\n".format(
                    iter_cnt, loss_mtl.item(), loss_moe.item(),
                    loss_dan.item(), loss.item(), curr_dev, oracle_curr_dev))

    say("\n")
    return iter_cnt
def evaluate(encoders, classifiers, mats, loaders, return_best_thrs, args, thr=None):
    '''
    Evaluate model using MOE (mixture of experts) with binary ranking metrics.

    Each validation example (batch1, batch2) is encoded by encoders[0]; the
    per-source classifiers' softmax outputs are mixed with weights obtained by
    scoring the hidden state against each source domain encoding and
    normalising the scores with the file's `softmax` helper. The mixture's
    class-1 probability is used as the ranking score.

    Args:
        encoders: sequence of encoder modules; only encoders[0] encodes the
            validation data, the whole sequence is passed to
            `domain_encoding`.
        classifiers: per-source classifier modules (the experts).
        mats: [Us, Ws, Vs] when args.metric == "biaffine" (only Us[0] is used
            here), otherwise [Us, Ps, Ns] with one entry per source.
        loaders: (source_loaders, valid_loader); valid_loader yields
            (batch1, batch2, label) triples.
        return_best_thrs: if true, search the precision/recall curve for the
            score threshold maximising F1 and return it.
        args: reads args.metric, args.cuda and args.eval_only.
        thr: optional decision threshold; when given, predictions become
            score > thr instead of the argmax of the mixture.

    Returns:
        (best_thr, [auc, prec, rec, f1]); best_thr is None unless
        return_best_thrs is true.

    The train/eval flag of every encoder and classifier is set to eval for the
    duration of the call and restored afterwards.
    '''
    # An explicit loop is required: `map()` is lazy in Python 3, so the
    # former `map(lambda m: m.eval(), ...)` never ran.
    modules = list(encoders) + list(classifiers)
    prev_modes = [m.training for m in modules]
    for m in modules:
        m.eval()

    try:
        if args.metric == "biaffine":
            Us, Ws, Vs = mats
        else:
            Us, Ps, Ns = mats
        source_loaders, valid_loader = loaders
        domain_encs = domain_encoding(source_loaders, args, encoders)

        y_true = []
        y_pred = []
        y_score = []

        # Evaluation never back-propagates; skip building autograd graphs.
        with torch.no_grad():
            for batch1, batch2, label in valid_loader:
                if args.cuda:
                    batch1 = batch1.cuda()
                    batch2 = batch2.cuda()
                    label = label.cuda()
                batch1 = Variable(batch1)
                batch2 = Variable(batch2)
                _, hidden = encoders[0](batch1, batch2)

                if args.metric == "biaffine":
                    alphas = [biaffine_metric_fast(hidden, mu[0], Us[0])
                              for mu in domain_encs]
                else:
                    alphas = [mahalanobis_metric_fast(hidden, mu[0], U, mu[1], P, mu[2], N)
                              for (mu, U, P, N) in zip(domain_encs, Us, Ps, Ns)]
                alphas = softmax(alphas)
                if args.cuda:
                    alphas = [alpha.cuda() for alpha in alphas]
                alphas = [Variable(alpha) for alpha in alphas]

                outputs = [
                    F.softmax(classifier(hidden), dim=1)
                    for classifier in classifiers
                ]
                # (batch, 1) * (batch, n_classes) broadcasts each example's
                # weight over all classes (same as .repeat(1, 2) for 2 classes).
                output = sum(alpha.unsqueeze(1) * output_i
                             for alpha, output_i in zip(alphas, outputs))
                pred = output.data.max(dim=1)[1]

                if args.eval_only:
                    # The oracle flags are only needed for this per-example dump.
                    oracle_eq = compute_oracle(outputs, label, args)
                    for i in range(batch1.shape[0]):
                        for j in range(len(alphas)):
                            say("{:.4f}: [{:.4f}, {:.4f}], ".format(
                                alphas[j].data[i], outputs[j].data[i][0],
                                outputs[j].data[i][1]))
                        oracle_TF = "T" if oracle_eq[i] == 1 else colored("F", 'red')
                        say("gold: {}, pred: {}, oracle: {}\n".format(
                            label[i], pred[i], oracle_TF))
                    say("\n")

                y_true += label.tolist()
                y_pred += pred.tolist()
                # mixture probability of class 1, used as the ranking score
                y_score += output[:, 1].data.tolist()

        if thr is not None:
            print("using threshold %.4f" % thr)
            y_score = np.array(y_score)
            y_pred = np.zeros_like(y_score)
            y_pred[y_score > thr] = 1

        prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="binary")
        auc = roc_auc_score(y_true, y_score)
        print("AUC: {:.2f}, Prec: {:.2f}, Rec: {:.2f}, F1: {:.2f}".format(
            auc * 100, prec * 100, rec * 100, f1 * 100))

        best_thr = None
        metric = [auc, prec, rec, f1]

        if return_best_thrs:
            precs, recs, thrs = precision_recall_curve(y_true, y_score)
            # precs + recs can be 0; the resulting NaNs are filtered below,
            # so silence the expected divide/invalid warnings.
            with np.errstate(divide="ignore", invalid="ignore"):
                f1s = 2 * precs * recs / (precs + recs)
            # the curve has one more (prec, rec) point than thresholds
            f1s = f1s[:-1]
            valid = ~np.isnan(f1s)
            thrs = thrs[valid]
            f1s = f1s[valid]
            best_thr = thrs[np.argmax(f1s)]
            print("best threshold={:.4f}, f1={:.4f}".format(best_thr, np.max(f1s)))

        return best_thr, metric
    finally:
        # Restore whatever mode each module was in before evaluation.
        for m, was_training in zip(modules, prev_modes):
            m.train(was_training)