def finetune(model):
    print("Start finetuning")
    dataset = Dataset(split='gold', use_gold=True)
    loader = torch.utils.data.DataLoader(dataset, 128, shuffle=True,
                                         num_workers=0, collate_fn=collate_samples)
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    optim = torch.optim.Adam(model.parameters())
    model.train()
    vat_loss = VATLoss()
    for _, batch in enumerate(loader):
        move_batch(batch)
        optim.zero_grad()
        xs = model(**batch, feature_only=True)
        lds = vat_loss(model, xs)
        logits = model(xs=xs)
        pos_weight = 1. + (batch['ys'] * 8.)
        loss = (pos_weight * criterion(logits, batch['ys'])).mean() + .1 * lds
        loss.backward()
        optim.step()
        del batch
    validate(model)
def train(model, device, loader, optimizer, config):
    model.train()
    correct = 0
    total_loss = 0
    for batch_idx, sample in enumerate(loader):
        data = sample['image']
        target = sample['target']

        if config.vat:
            vat_loss = VATLoss(xi=config.xi, eps=config.eps, ip=config.ip)
            ul_data = sample['ul_image']
            ul_data = ul_data.to(device)
            lds = vat_loss(model, ul_data)

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss_fun = torch.nn.CrossEntropyLoss()

        if config.null_space_tuning:
            ul_data1 = sample['ul_img1']
            ul_data2 = sample['ul_img2']
            ul_data1, ul_data2 = ul_data1.to(device), ul_data2.to(device)
            ul_out1 = model(ul_data1)
            ul_out2 = model(ul_data2)
            null_loss_fun = torch.nn.MSELoss()
            loss = loss_fun(output, target) + config.alpha * null_loss_fun(ul_out1, ul_out2)
        elif config.vat:
            loss = loss_fun(output, target) + config.alpha * lds
        else:
            loss = loss_fun(output, target)

        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(target.view(-1, 1)).sum().item()
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_loss = total_loss / (batch_idx + 1)
    accuracy = 100 * (correct / len(loader.dataset))
    print('\tTraining set: Average loss: {:.4f}, Accuracy: {:.0f}%'.format(avg_loss, accuracy))
    return avg_loss, accuracy
def train(args, model, device, data_iterators, optimizer):
    model.train()
    for i in tqdm(range(args.iters)):
        # reset
        if i % args.log_interval == 0:
            ce_losses = utils.AverageMeter()
            vat_losses = utils.AverageMeter()
            prec1 = utils.AverageMeter()

        x_l, y_l = next(data_iterators['labeled'])
        x_ul, _ = next(data_iterators['unlabeled'])
        x_l, y_l = x_l.to(device), y_l.to(device)
        x_ul = x_ul.to(device)

        optimizer.zero_grad()

        vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip)
        cross_entropy = nn.CrossEntropyLoss()

        lds = vat_loss(model, x_ul)
        output = model(x_l)
        classification_loss = cross_entropy(output, y_l)
        loss = classification_loss + args.alpha * lds
        loss.backward()
        optimizer.step()

        acc = utils.accuracy(output, y_l)
        ce_losses.update(classification_loss.item(), x_l.shape[0])
        vat_losses.update(lds.item(), x_ul.shape[0])
        prec1.update(acc.item(), x_l.shape[0])

        if i % args.log_interval == 0:
            print(f'\nIteration: {i}\t'
                  f'CrossEntropyLoss {ce_losses.val:.4f} ({ce_losses.avg:.4f})\t'
                  f'VATLoss {vat_losses.val:.4f} ({vat_losses.avg:.4f})\t'
                  f'Prec@1 {prec1.val:.3f} ({prec1.avg:.3f})')
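# ---------------------------------------------------------------------------
# The training loops in this file construct a `VATLoss` object and call it as
# `lds = vat_loss(model, x)` without defining it. The class below is a minimal
# sketch of such a module, written in the spirit of the standard virtual
# adversarial training recipe (Miyato et al., 2018): estimate the perturbation
# direction that most changes the model's prediction with `ip` power-iteration
# steps of size `xi`, then penalize the KL divergence between predictions on
# the clean input and on the input perturbed by radius `eps`. It is an
# assumption that the excerpted repositories use an implementation of exactly
# this form; some of them also apply it to intermediate features rather than
# raw inputs.
# ---------------------------------------------------------------------------
import contextlib

import torch
import torch.nn as nn
import torch.nn.functional as F


@contextlib.contextmanager
def _disable_tracking_bn_stats(model):
    # Temporarily stop BatchNorm layers from updating their running statistics
    # during the extra forward passes needed to compute the VAT perturbation.
    def switch_attr(m):
        if hasattr(m, 'track_running_stats'):
            m.track_running_stats ^= True

    model.apply(switch_attr)
    yield
    model.apply(switch_attr)


def _l2_normalize(d):
    # Normalize the perturbation per sample so that only its direction matters.
    d_flat = d.view(d.size(0), -1)
    norm = d_flat.norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + 1e-8)


class VATLoss(nn.Module):
    def __init__(self, xi=10.0, eps=1.0, ip=1):
        super().__init__()
        self.xi = xi    # small step used during power iteration
        self.eps = eps  # norm of the final virtual adversarial perturbation
        self.ip = ip    # number of power-iteration steps

    def forward(self, model, x):
        with torch.no_grad():
            pred = F.softmax(model(x), dim=1)  # current prediction, used as the "label"

        # Power iteration: estimate the direction that changes the prediction the most.
        d = _l2_normalize(torch.randn_like(x))
        with _disable_tracking_bn_stats(model):
            for _ in range(self.ip):
                d.requires_grad_()
                pred_hat = model(x + self.xi * d)
                adv_distance = F.kl_div(F.log_softmax(pred_hat, dim=1), pred,
                                        reduction='batchmean')
                adv_distance.backward()
                d = _l2_normalize(d.grad)
                model.zero_grad()

            # LDS: divergence between predictions on clean and adversarially perturbed input.
            r_adv = d * self.eps
            pred_hat = model(x + r_adv)
            lds = F.kl_div(F.log_softmax(pred_hat, dim=1), pred, reduction='batchmean')
        return lds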
def scAdapt(args, data_set):
    ## prepare data
    batch_size = args.batch_size
    kwargs = {'num_workers': 0, 'pin_memory': True}

    source_name = args.source_name  # "TM_baron_mouse_for_baron"
    target_name = args.target_name  # "baron_human"

    domain_to_indices = np.where(data_set['accessions'] == source_name)[0]
    train_set = {'features': data_set['features'][domain_to_indices],
                 'labels': data_set['labels'][domain_to_indices],
                 'accessions': data_set['accessions'][domain_to_indices]}
    domain_to_indices = np.where(data_set['accessions'] == target_name)[0]
    test_set = {'features': data_set['features'][domain_to_indices],
                'labels': data_set['labels'][domain_to_indices],
                'accessions': data_set['accessions'][domain_to_indices]}
    print('source labels:', np.unique(train_set['labels']),
          ' target labels:', np.unique(test_set['labels']))
    test_set_eval = {'features': data_set['features'][domain_to_indices],
                     'labels': data_set['labels'][domain_to_indices],
                     'accessions': data_set['accessions'][domain_to_indices]}
    print(train_set['features'].shape, test_set['features'].shape)

    data = torch.utils.data.TensorDataset(
        torch.FloatTensor(train_set['features']),
        torch.LongTensor(matrix_one_hot(train_set['labels'],
                                        int(max(train_set['labels']) + 1)).long()))
    source_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                                shuffle=True, drop_last=True, **kwargs)
    data = torch.utils.data.TensorDataset(
        torch.FloatTensor(test_set['features']),
        torch.LongTensor(matrix_one_hot(test_set['labels'],
                                        int(max(train_set['labels']) + 1)).long()))
    target_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                                shuffle=True, drop_last=True, **kwargs)
    target_test_loader = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                                     shuffle=False, drop_last=False, **kwargs)

    class_num = max(train_set['labels']) + 1
    class_num_test = max(test_set['labels']) + 1

    ### re-weighting the classifier
    # Normalized weights based on inverse number of effective data per class,
    # from https://github.com/YyzHarry/imbalanced-semi-self/blob/master/train.py
    # 2019 Learning Imbalanced Datasets with Label-Distribution-Aware Margin Loss
    # 2020 Rethinking the Value of Labels for Improving Class-Imbalanced Learning
    cls_num_list = [np.sum(train_set['labels'] == i) for i in range(class_num)]
    beta = 0.9999
    effective_num = 1.0 - np.power(beta, cls_num_list)
    per_cls_weights = (1.0 - beta) / np.array(effective_num)
    per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
    per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

    ## set base network
    embedding_size = args.embedding_size
    base_network = FeatureExtractor(num_inputs=train_set['features'].shape[1],
                                    embed_size=embedding_size).cuda()
    label_predictor = LabelPredictor(base_network.output_num(), class_num).cuda()
    total_model = nn.Sequential(base_network, label_predictor)

    center_loss = CenterLoss(num_classes=class_num, feat_dim=embedding_size, use_gpu=True)
    optimizer_centloss = torch.optim.SGD([{'params': center_loss.parameters()}], lr=0.5)

    print("output size of FeatureExtractor and LabelPredictor: ",
          base_network.output_num(), class_num)
    ad_net = scAdversarialNetwork(base_network.output_num(), 1024).cuda()

    ## set optimizer
    config_optimizer = {"lr_type": "inv",
                        "lr_param": {"lr": 0.001, "gamma": 0.001, "power": 0.75}}
    parameter_list = (base_network.get_parameters() + ad_net.get_parameters()
                      + label_predictor.get_parameters())
    optimizer = optim.SGD(parameter_list, lr=1e-3, weight_decay=5e-4,
                          momentum=0.9, nesterov=True)
    schedule_param = config_optimizer["lr_param"]
    lr_scheduler = lr_schedule.schedule_dict[config_optimizer["lr_type"]]

    ## train
    len_train_source = len(source_loader)
    len_train_target = len(target_loader)
    transfer_loss_value = classifier_loss_value = total_loss_value = 0.0
    epoch_global = 0.0
    hit = False
    s_global_centroid = torch.zeros(class_num, embedding_size).cuda()
    t_global_centroid = torch.zeros(class_num, embedding_size).cuda()

    for epoch in range(args.num_iterations):
        if epoch % 2500 == 0 and epoch != 0:
            feature_target = base_network(torch.FloatTensor(test_set['features']).cuda())
            output_target = label_predictor.forward(feature_target)
            softmax_out = nn.Softmax(dim=1)(output_target)
            predict_prob_arr, predict_label_arr = torch.max(softmax_out, 1)
            if epoch == args.epoch_th:
                data = torch.utils.data.TensorDataset(torch.FloatTensor(test_set['features']),
                                                      predict_label_arr.cpu())
                target_loader_align = torch.utils.data.DataLoader(data, batch_size=batch_size,
                                                                  shuffle=True, drop_last=True,
                                                                  **kwargs)

            result_path = args.result_path  # "../results/"
            # create the output directory before writing the checkpoint
            if not os.path.exists(result_path):
                os.makedirs(result_path)
            model_file = result_path + 'final_model_' + str(epoch) + source_name + target_name + '.ckpt'
            torch.save({'base_network': base_network.state_dict(),
                        'label_predictor': label_predictor.state_dict()}, model_file)

            with torch.no_grad():
                code_arr_s = base_network(Variable(torch.FloatTensor(train_set['features']).cuda()))
                code_arr_t = base_network(Variable(torch.FloatTensor(test_set_eval['features']).cuda()))
                code_arr = np.concatenate((code_arr_s.cpu().data.numpy(),
                                           code_arr_t.cpu().data.numpy()), 0)

            digit_label_dict = pd.read_csv(args.dataset_path + 'digit_label_dict.csv')
            digit_label_dict = pd.DataFrame(zip(digit_label_dict.iloc[:, 0], digit_label_dict.index),
                                            columns=['digit', 'label'])
            digit_label_dict = digit_label_dict.to_dict()['label']
            # transform digit label to cell type name
            y_pred_label = [digit_label_dict[x] if x in digit_label_dict else x
                            for x in predict_label_arr.cpu().data.numpy()]

            pred_labels_file = (result_path + 'pred_labels_' + source_name + "_" + target_name
                                + "_" + str(epoch) + ".csv")
            pd.DataFrame([predict_prob_arr.cpu().data.numpy(), y_pred_label],
                         index=["pred_probability", "pred_label"]).to_csv(pred_labels_file, sep=',')
            embedding_file = (result_path + 'embeddings_' + source_name + "_" + target_name
                              + "_" + str(epoch) + ".csv")
            pd.DataFrame(code_arr).to_csv(embedding_file, sep=',')

            #### only for evaluation
            # acc_by_label = np.zeros(class_num_test)
            # all_label = test_set['labels']
            # for i in range(class_num_test):
            #     acc_by_label[i] = np.sum(predict_label_arr.cpu().data.numpy()[all_label == i] == i) / np.sum(all_label == i)
            # np.set_printoptions(suppress=True)
            # print('iter:', epoch, "average acc over all test cell types: ", round(np.nanmean(acc_by_label), 3))
            # print("acc of each test cell type: ", acc_by_label)
            # div_score, div_score_all, ent_score, sil_score = evaluate_multibatch(code_arr, train_set, test_set_eval, epoch)
            # results_file = result_path + source_name + "_" + target_name + "_" + str(epoch) + "_acc_div_sil.csv"
            # evel_res = [np.nanmean(acc_by_label), div_score, div_score_all, ent_score, sil_score]
            # pd.DataFrame(evel_res, index=["acc", "div_score", "div_score_all", "ent_score", "sil_score"], columns=["values"]).to_csv(results_file, sep=',')
            # pred_labels_file = result_path + source_name + "_" + target_name + "_" + str(epoch) + "_pred_labels.csv"
            # pd.DataFrame([predict_label_arr.cpu().data.numpy(), all_label], index=["pred_label", "true_label"]).to_csv(pred_labels_file, sep=',')

        ## train one iter
        base_network.train(True)
        ad_net.train(True)
        label_predictor.train(True)
        optimizer = lr_scheduler(optimizer, epoch, **schedule_param)
        optimizer.zero_grad()
        optimizer_centloss.zero_grad()

        if epoch % len_train_source == 0:
            iter_source = iter(source_loader)
            epoch_global = epoch_global + 1
        if epoch % len_train_target == 0:
            if epoch < args.epoch_th:
                iter_target = iter(target_loader)
            else:
                hit = True
                iter_target = iter(target_loader_align)

        inputs_source, labels_source = next(iter_source)
        inputs_target, labels_target = next(iter_target)
        inputs_source, inputs_target = inputs_source.cuda(), inputs_target.cuda()
        labels_source, labels_target = labels_source.cuda(), labels_target.cuda()

        feature_source = base_network(inputs_source)
        feature_target = base_network(inputs_target)
        features = torch.cat((feature_source, feature_target), dim=0)
        output_source = label_predictor.forward(feature_source)
        output_target = label_predictor.forward(feature_target)

        ######## VAT and BNM loss
        # LDS should be calculated before the forward for cross entropy
        vat_loss = VATLoss(xi=args.xi, eps=args.eps, ip=args.ip)
        lds_loss = vat_loss(total_model, inputs_target)

        softmax_tgt = nn.Softmax(dim=1)(output_target[:, 0:class_num])
        _, s_tgt, _ = torch.svd(softmax_tgt)
        BNM_loss = -torch.mean(s_tgt)

        ######## domain alignment loss
        if args.method == 'DANN':
            domain_prob_discriminator_1_source = ad_net.forward(feature_source)
            domain_prob_discriminator_1_target = ad_net.forward(feature_target)
            adv_loss = loss_utility.BCELossForMultiClassification(
                label=torch.ones_like(domain_prob_discriminator_1_source),
                predict_prob=domain_prob_discriminator_1_source)  # domain matching
            adv_loss += loss_utility.BCELossForMultiClassification(
                label=torch.ones_like(domain_prob_discriminator_1_target),
                predict_prob=1 - domain_prob_discriminator_1_target)
            transfer_loss = adv_loss
        elif args.method == 'mmd':
            base = 1.0  # sigma for MMD
            sigma_list = [1, 2, 4, 8, 16]
            sigma_list = [sigma / base for sigma in sigma_list]
            transfer_loss = loss_utility.mix_rbf_mmd2(feature_source, feature_target, sigma_list)

        ###### CrossEntropyLoss
        classifier_loss = nn.CrossEntropyLoss(weight=per_cls_weights)(
            output_source, torch.max(labels_source, dim=1)[1])
        # classifier_loss = loss_utility.CrossEntropyLoss(labels_source.float(),
        #                                                 nn.Softmax(dim=1)(output_source))

        ###### semantic_loss and center loss
        cell_th = args.cell_th
        epoch_th = args.epoch_th
        if epoch < args.epoch_th or hit == False:
            semantic_loss = torch.FloatTensor([0.0]).cuda()
            center_loss_src = torch.FloatTensor([0.0]).cuda()
            sum_dist_loss = torch.FloatTensor([0.0]).cuda()
            # center_loss.centers = feature_source[torch.max(labels_source, dim=1)[1] == 0].mean(dim=0, keepdim=True)
            pass
        elif hit == True:
            center_loss_src = center_loss(feature_source, labels=torch.max(labels_source, dim=1)[1])
            s_global_centroid = center_loss.centers
            semantic_loss, s_global_centroid, t_global_centroid = loss_utility.semant_use_s_center(
                class_num, s_global_centroid, t_global_centroid, feature_source, feature_target,
                torch.max(labels_source, dim=1)[1], labels_target, 0.7, cell_th)  # softmax_tgt
        if epoch > epoch_th:
            lds_loss = torch.FloatTensor([0.0]).cuda()

        if epoch <= args.num_iterations:
            progress = epoch / args.epoch_th  # args.num_iterations
        else:
            progress = 1
        lambd = 2 / (1 + math.exp(-10 * progress)) - 1

        total_loss = (classifier_loss
                      + lambd * args.DA_coeff * transfer_loss
                      + lambd * args.BNM_coeff * BNM_loss
                      + lambd * args.alpha * lds_loss
                      + args.semantic_coeff * semantic_loss
                      + args.centerloss_coeff * center_loss_src)
        total_loss.backward()
        optimizer.step()

        # multiply by (1. / centerloss_coeff) to remove the effect of centerloss_coeff on updating the centers
        if args.centerloss_coeff > 0 and center_loss_src > 0:
            for param in center_loss.parameters():
                param.grad.data *= (1. / args.centerloss_coeff)
            optimizer_centloss.step()  # optimize the centers in center loss
# calculate Vst
Vst = Classification_loss + args.lambda_val * Adversarial_DA_loss

t_class_output, _, domain_output_mt, _ = my_net(t_imgs, alpha=alpha)

# calculate Vmt
if args.dataset_name == "OfficeHome":
    Vmt = loss_class(domain_output_mt, t_labels)
else:
    t_labels = torch.unsqueeze(t_labels, 1)
    t_labels = t_labels.float()
    Vmt = loss_domain(domain_output_mt, t_labels)

# calculate Lent
Lent = loss_entropy(t_class_output)

# calculate Lvir
vat_loss = VATLoss(xi=10.0, eps=1.0, ip=1)
s_vat_loss = vat_loss(my_net, s_imgs)
t_vat_loss = vat_loss(my_net, t_imgs)
Lvir = s_vat_loss + args.rho_val * t_vat_loss

loss = Vst + args.gamma_val * Vmt + args.beta_val * Lent + Lvir
loss.backward()
optim_net.step()

if iter_count % 10 == 0:
    tqdm.write("iter is : {} Classification loss is :{:.3f} Lent is :{:.3f} "
               "Adversarial DA loss is :{:.3f} Vmt is:{:.3f} gamma_val is:{:.3f}".format(
                   iter_count, Classification_loss.item(), Lent.item(),
                   Adversarial_DA_loss.item(), Vmt.item(), args.gamma_val))
    print_log(iter_count + 1,
              Classification_loss.item(),
              Lent.item(),
writer.add_figure(nametag, fig, epoch)

### Model
input_size = 512
num_feature = 1024
class_num = 10

generator_g = net.Encoder().cuda()
classifier_c = net.Classifier(512, class_num).cuda()
classifier_j = net.Classifier(512, class_num * 2).cuda()

### Loss & Optimizers
# Loss functions
criterion_CE = torch.nn.CrossEntropyLoss().cuda()
criterion_VAT = VATLoss().cuda()

# Optimizers
generator_g_optim = torch.optim.Adam(generator_g.parameters(), lr=opt.lr)
classifier_c_optim = torch.optim.Adam(classifier_c.parameters(), lr=opt.lr)
classifier_j_optim = torch.optim.Adam(classifier_j.parameters(), lr=opt.lr)

# ----------
# Training
# ----------
print('')
niter = 0
epoch = 0
best_test = 0
def train_vat(fold=2, disable_progress=False):
    directory = './data/'
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_df = pd.read_csv('folds.csv')
    train_path = os.path.join(directory, 'train')
    train_ids = train_df[train_df.fold != fold]['id'].values
    val_ids = train_df[train_df.fold == fold]['id'].values

    dataset = TGSSaltDataset(train_path, train_ids, augment=True)
    dataset_val = TGSSaltDataset(train_path, val_ids)

    test_path = os.path.join(directory, 'test')
    test_file_list = glob.glob(os.path.join(test_path, 'images', '*.png'))
    test_file_list = [f.split('/')[-1].split('.')[0] for f in test_file_list]
    dataset_test = TGSSaltDataset(test_path, test_file_list, is_test=True)

    model = UnetModel()
    model.train()
    model.to(device)
    # load_auto(f'auto-new-{num_filters}-s2-best.pth', model, device)
    # set_encoder_train(model, False)

    epoch = 100
    learning_rate = 5e-3
    alpha = 1.0

    # loss_fn = torch.nn.BCEWithLogitsLoss()
    vat_loss = VATLoss(eps=1.0, ip=1)
    empty_loss_fn = nn.BCEWithLogitsLoss()
    ent_loss_fn = EntropyLoss()
    # mask = get_mask().to(device)
    loss_fn = LovaszLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate,
                                momentum=0.9, nesterov=True, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.5)
    try:
        load_checkpoint(f'unet-fold{fold}-s2-best.pth', model, optimizer)
    except:
        print('oops no file')

    train_iter = data.DataLoader(dataset, batch_size=16, shuffle=True)
    test_iter = data.DataLoader(dataset_test, batch_size=16, shuffle=True)

    best_iou = 0
    for e in range(epoch):
        train_loss = []
        smooth_loss = []
        unlabeled_loss = []
        # for sample in tqdm(data.DataLoader(dataset, batch_size=30, shuffle=True)):
        for sample, test_sample in zip(tqdm(train_iter), test_iter):
            image, mask = sample['image'], sample['mask']
            image = image.type(torch.float).to(device)
            test_image = test_sample['image']
            test_image = test_image.type(torch.float).to(device)

            lds = 0       # vat_loss(model, image)
            lds_test = 0  # vat_loss(model, test_image)

            y_pred, y_pred_empty, _ = model(image)
            test_pred, test_pred_empty, _ = model(test_image)

            ul_loss = ent_loss_fn(test_pred_empty)
            direct_loss = loss_fn(y_pred, mask.to(device))
            class_loss = empty_loss_fn(y_pred_empty, empty_mask(mask).to(device))
            loss = direct_loss + 0.05 * (class_loss + alpha * (lds + lds_test) + ul_loss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss.append(direct_loss.item())
            smooth_loss.append(0)  # lds.item() + lds_test.item()
            unlabeled_loss.append(ul_loss.item())

        val_loss = []
        val_iou = []
        for sample in data.DataLoader(dataset_val, batch_size=16, shuffle=False):
            image, mask = sample['image'], sample['mask']
            image = image.to(device)
            y_pred, _, _ = model(image)
            loss = loss_fn(y_pred, mask.to(device))
            val_loss.append(loss.item())
            iou = iou_pytorch((y_pred > 0).int(), mask.int().to(device)).cpu()
            val_iou.append(iou)

        avg_iou = np.mean(val_iou)
        scheduler.step(np.mean(val_loss))
        print("Epoch: %d, Train: %.3f, Smooth: %.3f, UL: %.3f, Val: %.3f, IoU: %.3f"
              % (e, np.mean(train_loss), np.mean(smooth_loss),
                 np.mean(unlabeled_loss), np.mean(val_loss), avg_iou))

        if avg_iou > best_iou:
            print('saving new best')
            save_checkpoint(f'unet-fold{fold}-vat-best.pth', model, optimizer)
            best_iou = avg_iou

    print('Best IoU: %.3f' % best_iou)
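# ---------------------------------------------------------------------------
# `train_vat` above calls an `EntropyLoss` on the unlabeled "empty mask" logits
# that is not defined in this excerpt. A minimal sketch of such an
# entropy-minimization term is given below. Because the labeled branch treats
# `y_pred_empty` as a single logit per image (BCEWithLogitsLoss), this sketch
# uses the binary (sigmoid) entropy; the repository's actual class may differ,
# so treat this as an assumption rather than its real implementation.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class EntropyLoss(nn.Module):
    """Mean binary entropy of sigmoid(logits); minimizing it pushes predictions
    on unlabeled images toward confident 0/1 decisions."""

    def forward(self, logits):
        p = torch.sigmoid(logits)
        eps = 1e-7  # numerical guard against log(0)
        entropy = -(p * torch.log(p + eps) + (1 - p) * torch.log(1 - p + eps))
        return entropy.mean()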
def train():
    LOG_FREQ = 10
    VAL_FREQ = 2000
    QUERY_FREQ = 400
    QUERY_SIZE = 20

    with open('data/priors.pkl', 'rb') as f:
        priors = pickle.load(f)
    mean_prior = np.mean(list(priors.values()))

    dataset = Dataset(use_gold=True)
    loader = torch.utils.data.DataLoader(dataset, 128, shuffle=True,
                                         num_workers=0, collate_fn=collate_samples)
    model = Model(num_classes=len(dataset.type_to_id),
                  num_props=len(dataset.prop_to_id),
                  num_chars=len(dataset.char_to_id)).to(device)
    criterion = nn.BCELoss(reduction='none')
    optim = torch.optim.Adam(model.parameters())
    vat_loss = VATLoss()

    total_it = 0
    for ep in range(3):
        model.train()
        running_loss = 0
        running_acc = 0
        for i, batch in enumerate(loader):
            if total_it % QUERY_FREQ == 0:
                print(f"Querying {len(gold_label_cache)} - {len(gold_label_cache) + QUERY_SIZE}")
                run_query(model, QUERY_SIZE)
                for _ in range(2):
                    finetune(model)  # fine-tune on annotated pairs

            move_batch(batch)
            batch_priors = [0.5 + priors.get(n, mean_prior) / 2 for n in batch['names']]
            # keep the priors on the same device as the gold-label mask below
            batch_priors = (torch.tensor(batch_priors, dtype=torch.float)
                            .view(-1, 1).repeat(1, model.num_classes).to(device))
            is_gold_label = torch.tensor([x in gold_label_cache for x in batch['names']],
                                         dtype=torch.float).to(device)
            is_gold_label = is_gold_label.view(-1, 1).repeat(1, model.num_classes)
            batch_priors = torch.max(0.8 * batch_priors, 2. * is_gold_label)

            total_it += 1
            optim.zero_grad()
            xs = model(**batch, feature_only=True)
            lds = vat_loss(model, xs)
            logits = model(xs=xs)

            # MAP noise model
            probs = logits.sigmoid()
            retain_probs = model.simple_transition.sigmoid()
            retain_probs = torch.max(retain_probs, is_gold_label)
            adjusted_probs = retain_probs * probs + \
                (1 - retain_probs) * (1 - probs)

            # noise model + prior
            loss = (batch_priors * criterion(adjusted_probs, batch['ys'])).mean() + .1 * lds
            # noise model
            # loss = (criterion(adjusted_probs, batch['ys'])).mean()
            # vanilla
            # loss = (criterion(probs, batch['ys'])).mean() + .1 * lds

            loss.backward()
            optim.step()
            running_loss += loss.item()
            del batch

            if i % LOG_FREQ == LOG_FREQ - 1:
                print(f"Train {i:05d}/{ep:05d} Loss {running_loss / LOG_FREQ:.4f} "
                      f"Acc {running_acc / LOG_FREQ: .4f}")
                running_acc = 0
                running_loss = 0.

            if i % VAL_FREQ == VAL_FREQ - 1:
                validate(model)

        validate(model)
        save_path = f'checkpoints/{model_name}_{ep}.pyt'
        print(f'Save to {save_path}')
        model.save(save_path)