def score_fn(x, y=None):
    if args.score_fn == "px":
        return f(x).detach().cpu()
    elif args.score_fn == "py":
        return nn.functional.softmax(f.classify(x), dim=1).max(1)[0].detach().cpu()
    elif args.score_fn == "pxgrad":
        return -torch.log(grad_norm(x).detach().cpu())
    elif args.score_fn == "pxsim":
        assert args.pxycontrast > 0
        dist = smooth_one_hot(y, args.n_classes, args.smoothing)
        output, target, ce_output, neg_num = f.joint(img=x, dist=dist, evaluation=True)
        simloss = nn.CrossEntropyLoss(reduction="none")(output, target)
        simloss = simloss - np.log(neg_num)
        simloss = -1.0 * simloss
        return simloss.detach().cpu()
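# score_fn relies on a smooth_one_hot helper defined elsewhere in this repo. The
# sketch below assumes the standard label-smoothing construction (the true class
# gets 1 - smoothing, the remaining mass is spread over the other classes); the
# name smooth_one_hot_sketch is illustrative and the actual helper may differ.
def smooth_one_hot_sketch(labels, n_classes, smoothing=0.0):
    assert 0.0 <= smoothing < 1.0
    confidence = 1.0 - smoothing
    # start every row at smoothing / (n_classes - 1), then place the confidence mass
    dist = torch.full((labels.size(0), n_classes), smoothing / (n_classes - 1),
                      device=labels.device)
    dist.scatter_(1, labels.long().unsqueeze(1), confidence)
    return dist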
def sample_q(f, replay_buffer, y=None, n_steps=args.n_steps, contrast=False):
    """this func takes in replay_buffer now so we have the option to sample from scratch
    (i.e. replay_buffer==[]). See test_wrn_ebm.py for example.
    """
    f.eval()
    # get batch size
    bs = args.sgld_batch_size if y is None else y.size(0)
    # generate initial samples and buffer inds of those samples (if buffer is used)
    init_sample, buffer_inds = sample_p_0(replay_buffer, bs=bs, y=y)
    x_k = init_sample.clone()
    x_k.requires_grad = True
    # sgld
    for k in range(n_steps):
        if not contrast:
            energy = f(x_k, y=y).sum()
        else:
            if y is not None:
                dist = smooth_one_hot(y, args.n_classes, args.smoothing)
            else:
                dist = torch.ones((bs, args.n_classes)).to(device)
            output, target, ce_output, neg_num = f.joint(img=x_k, dist=dist, evaluation=True)
            energy = -1.0 * nn.CrossEntropyLoss(reduction="mean")(output, target)
        f_prime = torch.autograd.grad(energy, [x_k], retain_graph=True)[0]
        x_k.data += args.sgld_lr * f_prime + args.sgld_std * torch.randn_like(x_k)
    f.train()
    final_samples = x_k.detach()
    # update replay buffer
    if len(replay_buffer) > 0:
        replay_buffer[buffer_inds] = final_samples.cpu()
    return final_samples
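# A minimal usage sketch for sample_q (not part of the original script), following
# its docstring: passing an empty replay buffer makes SGLD start from fresh noise,
# and contrast=True switches the energy to the contrastive head. The helper name
# and the surrounding call pattern are illustrative assumptions.
def sample_q_demo(f, replay_buffer):
    x_buf = sample_q(f, replay_buffer)                 # warm-start from the persistent buffer
    x_fresh = sample_q(f, [])                          # cold-start: SGLD from noise only
    y = torch.randint(0, args.n_classes, (args.sgld_batch_size,)).to(device)
    x_contrast = sample_q(f, [], y=y, contrast=True)   # class-conditional, contrastive energy
    return x_buf, x_fresh, x_contrast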
def train(config, fold, model, dict_loader, optimizer, scheduler, list_dir_save_model, dir_pyplot,
          Validation=True, Test_flag=True):
    train_loader = dict_loader['train']
    val_loader = dict_loader['val']
    test_loader = dict_loader['test']

    """ loss """
    # criterion_cls = nn.CrossEntropyLoss()
    # criterion_cls = ut.FocalLoss(gamma=st.focal_gamma, alpha=st.focal_alpha, size_average=True)
    # kdloss = ut.KDLoss(4.0)
    criterion_KL = nn.KLDivLoss(reduction="sum")
    criterion_cls = nn.BCELoss()
    # criterion_L1 = nn.L1Loss(reduction='sum').cuda()
    # criterion_L2 = nn.MSELoss(reduction='mean').cuda()
    # criterion_gdl = gdl_loss(pNorm=2).cuda()

    EMS = ut.eval_metric_storage()
    list_selected_EMS = []
    list_ES = []
    for i_tmp in range(len(st.list_standard_eval_dir)):
        list_selected_EMS.append(ut.eval_selected_metirc_storage())
        list_ES.append(ut.EarlyStopping(delta=0, patience=st.early_stopping_patience, verbose=True))

    loss_tmp = [0] * 5
    loss_tmp_total = 0

    print('training')
    optimizer.zero_grad()
    optimizer.step()

    """ epoch """
    num_data = len(train_loader.dataset)
    for epoch in range(1, config.num_epochs + 1):
        scheduler.step()
        print(" ")
        print("--------------- epoch {} ----------------".format(epoch))

        """ print learning rate """
        for param_group in optimizer.param_groups:
            print('current LR : {}'.format(param_group['lr']))

        """ batch """
        for i, data_batch in enumerate(train_loader):
            # start = time.time()
            model.train()
            with torch.no_grad():
                """ input """
                datas = Variable(data_batch['data'].float()).cuda()
                # labels = Variable(data_batch['label'].long()).cuda()
                labels = Variable(data_batch['label'].float()).cuda()

                """ data augmentation """
                ##TODO : flip
                # flip_flag_list = np.random.normal(size=datas.shape[0]) > 0
                # datas[flip_flag_list] = datas[flip_flag_list].flip(-3)
                ##TODO : translation, cropping
                dict_result = ut.data_augmentation(datas=datas, cur_epoch=epoch)
                datas = dict_result['datas']
                translation_list = dict_result['translation_list']
                # aug_dict_result = ut.data_augmentation(datas=aug_datas, cur_epoch=epoch)
                # aug_datas = aug_dict_result['datas']

                """ minmax norm """
                if st.list_data_norm_type[st.data_norm_type_num] == 'minmax':
                    tmp_datas = datas.view(datas.size(0), -1)
                    tmp_datas -= tmp_datas.min(1, keepdim=True)[0]
                    tmp_datas /= tmp_datas.max(1, keepdim=True)[0]
                    datas = tmp_datas.view_as(datas)

                """ gaussian noise """
                # Gaussian_dist = torch.distributions.normal.Normal(loc=torch.tensor([0.0]), scale=torch.tensor([0.01]))
                # Gaussian_dist = torch.distributions.normal.Normal(loc=torch.tensor([0.0]), scale=torch.FloatTensor(1).uniform_(0, 0.01))
                # Gaussian_noise = Gaussian_dist.sample(datas.size()).squeeze(-1)
                # datas = datas + Gaussian_noise.cuda()

            """ forward propagation """
            dict_result = model(datas, translation_list)
            output_1 = dict_result['logits']
            output_2 = dict_result['Aux_logits']
            output_3 = dict_result['logitMap']
            output_4 = dict_result['l1_norm']

            loss_list_1 = []
            count_loss = 0
            if fst.flag_loss_1 == True:
                s_labels = ut.smooth_one_hot(labels, config.num_classes, smoothing=st.smoothing_img)
                loss_2 = criterion_cls(output_1, s_labels) * st.lambda_major[0] / st.iter_to_update
                loss_list_1.append(loss_2)
                loss_tmp[count_loss] += loss_2.data.cpu().numpy()
                if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                    EMS.train_aux_loss[count_loss].append(loss_tmp[count_loss])
                    loss_tmp[count_loss] = 0
                count_loss += 1

            if fst.flag_loss_2 == True:
                for i_tmp in range(len(output_2)):
                    s_labels = ut.smooth_one_hot(labels, config.num_classes, smoothing=st.smoothing_roi)
                    loss_2 = criterion_cls(output_2[i_tmp], s_labels) * st.lambda_aux[i_tmp] / st.iter_to_update
                    loss_list_1.append(loss_2)
                    loss_tmp[count_loss] += loss_2.data.cpu().numpy()
                    if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                        EMS.train_aux_loss[count_loss].append(loss_tmp[count_loss])
                        loss_tmp[count_loss] = 0
                    count_loss += 1

            if fst.flag_loss_3 == True:  # patch
                list_loss_tmp = []
                for tmp_j in range(len(output_4)):  # type, i.e., patch, roi
                    loss_2 = 0
                    for tmp_i in range(len(output_4[tmp_j])):  # batch
                        tmp_shape = output_4[tmp_j][tmp_i].shape
                        logits = output_4[tmp_j][tmp_i].view(tmp_shape[0], tmp_shape[1], -1)
                        # loss_2 += torch.norm(logits, p=1)
                        loss_2 += torch.norm(logits, p=1) / (logits.view(-1).size(0))
                    list_loss_tmp.append((loss_2 / len(output_4[tmp_j]) * st.l1_reg_norm) / st.iter_to_update)
                loss_list_1.append(sum(list_loss_tmp))
                loss_tmp[count_loss] += sum(list_loss_tmp).data.cpu().numpy()
                if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                    EMS.train_aux_loss[count_loss].append(loss_tmp[count_loss])
                    loss_tmp[count_loss] = 0
                count_loss += 1

            """ L1 reg """
            # norm = torch.FloatTensor([0]).cuda()
            # for parameter in model.parameters():
            #     norm += torch.norm(parameter, p=1)
            # loss_list_1.append(norm * st.l1_reg)

            loss = sum(loss_list_1)
            loss.backward()
            torch.cuda.empty_cache()
            loss_tmp_total += loss.data.cpu().numpy()

            ##TODO : optimize the model param
            if (EMS.total_train_iter + 1) % st.iter_to_update == 0:
                optimizer.step()
                optimizer.zero_grad()

                """ pyplot """
                EMS.total_train_step += 1
                EMS.train_step.append(EMS.total_train_step)
                EMS.train_loss.append(loss_tmp_total)

                """ print the train loss and tensorboard """
                if (EMS.total_train_step) % 10 == 0:
                    # print('time : ', time.time() - start)
                    print('Epoch [%d/%d], Step [%d/%d], Loss: %.4f'
                          % (epoch, config.num_epochs, (i + 1), (num_data // (config.batch_size)), loss_tmp_total))
                loss_tmp_total = 0
            EMS.total_train_iter += 1
            # scheduler.step(epoch + i / len(train_loader))

        """ val """
        if Validation == True:
            print("------------------ val --------------------------")
            if fst.flag_cropping == True and fst.flag_eval_cropping == True:
                dict_result = ut.eval_classification_model_cropped_input(config, fold, val_loader, model, criterion_cls)
            elif fst.flag_translation == True and fst.flag_eval_translation == True:
                dict_result = ut.eval_classification_model_esemble(config, fold, val_loader, model, criterion_cls)
            elif fst.flag_MC_dropout == True:
                dict_result = ut.eval_classification_model_MC_dropout(config, fold, val_loader, model, criterion_cls)
            else:
                dict_result = ut.eval_classification_model(config, fold, val_loader, model, criterion_cls)
            val_loss = dict_result['Loss']
            acc = dict_result['Acc']
            auc = dict_result['AUC']
            print('Fold : %d, Epoch [%d/%d] val Loss = %f val Acc = %f' % (fold, epoch, config.num_epochs, val_loss, acc))

            """ save the metric """
            EMS.dict_val_metric['val_loss'].append(val_loss)
            EMS.dict_val_metric['val_acc'].append(acc)
            if fst.flag_loss_2 == True:
                for tmp_i in range(len(st.lambda_aux)):
                    EMS.dict_val_metric['val_acc_aux'][tmp_i].append(dict_result['Acc_aux'][tmp_i])
            EMS.dict_val_metric['val_auc'].append(auc)
            EMS.val_step.append(EMS.total_train_step)

            n_stacking_loss_for_selection = 5
            if len(EMS.dict_val_metric['val_loss_queue']) > n_stacking_loss_for_selection:
                EMS.dict_val_metric['val_loss_queue'].popleft()
            EMS.dict_val_metric['val_loss_queue'].append(val_loss)
            EMS.dict_val_metric['val_mean_loss'].append(np.mean(EMS.dict_val_metric['val_loss_queue']))

            """ save model """
            for i_tmp in range(len(list_selected_EMS)):
                save_flag = ut.model_save_through_validation(fold, epoch, EMS=EMS,
                                                             selected_EMS=list_selected_EMS[i_tmp],
                                                             ES=list_ES[i_tmp],
                                                             model=model,
                                                             dir_save_model=list_dir_save_model[i_tmp],
                                                             metric_1=st.list_standard_eval[i_tmp],
                                                             metric_2='',
                                                             save_flag=False)

        if Test_flag == True:
            print("------------------ test : test dataset --------------------------")
            """ load data """
            if fst.flag_cropping == True and fst.flag_eval_cropping == True:
                print("eval : cropping")
                dict_result = ut.eval_classification_model_cropped_input(config, fold, test_loader, model, criterion_cls)
            elif fst.flag_translation == True and fst.flag_eval_translation == True:
                print("eval : ensemble")
                dict_result = ut.eval_classification_model_esemble(config, fold, test_loader, model, criterion_cls)
            elif fst.flag_MC_dropout == True:
                dict_result = ut.eval_classification_model_MC_dropout(config, fold, test_loader, model, criterion_cls)
            else:
                print("eval : whole image")
                dict_result = ut.eval_classification_model(config, fold, test_loader, model, criterion_cls)
            acc = dict_result['Acc']
            test_loss = dict_result['Loss']

            """ pyplot """
            EMS.test_acc.append(acc)
            if fst.flag_loss_2 == True:
                for tmp_i in range(len(st.lambda_aux)):
                    EMS.test_acc_aux[tmp_i].append(dict_result['Acc_aux'][tmp_i])
            EMS.test_loss.append(test_loss)
            EMS.test_step.append(EMS.total_train_step)

            print('number of test samples : {}'.format(len(test_loader.dataset)))
            print('Fold : %d, Epoch [%d/%d] test Loss = %f test Acc = %f' % (fold, epoch, config.num_epochs, test_loss, acc))

        """ learning rate decay """
        EMS.LR.append(optimizer.param_groups[0]['lr'])
        # scheduler.step()
        # scheduler.step(val_loss)

        """ plot the chart """
        if epoch % 1 == 0:
            ut.plot_training_info_1(fold, dir_pyplot, EMS, flag='percentile', flag_match=False)

        ##TODO : early stop only when every metric has stopped
        tmp_count = 0
        for i in range(len(list_ES)):
            if list_ES[i].early_stop == True:
                tmp_count += 1
        if tmp_count == len(list_ES):
            break

    """ release the model """
    del model, EMS
    torch.cuda.empty_cache()
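# train() stops only when every tracker in list_ES has set early_stop. The class
# below is a minimal sketch of such a tracker with the usual patience/delta
# semantics; the project's actual ut.EarlyStopping implementation may differ.
class EarlyStoppingSketch:
    def __init__(self, delta=0, patience=10, verbose=False):
        self.delta = delta
        self.patience = patience
        self.verbose = verbose
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        # count validation rounds without an improvement larger than delta
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print('EarlyStopping counter: {}/{}'.format(self.counter, self.patience))
            if self.counter >= self.patience:
                self.early_stop = True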
def main(args):
    # Setup datasets
    dload_train, dload_train_labeled, dload_valid, dload_test = get_data(args)

    # Model and buffer
    sample_q = get_sample_q(args)
    f, replay_buffer = get_model_and_buffer(args, sample_q)

    # Setup optimizer
    params = f.class_output.parameters() if args.clf_only else f.parameters()
    if args.optimizer == "adam":
        optim = torch.optim.Adam(params, lr=args.lr, betas=[0.9, 0.999], weight_decay=args.weight_decay)
    else:
        optim = torch.optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)

    best_valid_acc = 0.0
    cur_iter = 0
    for epoch in range(args.start_epoch, args.n_epochs):
        # Decay lr
        if epoch in args.decay_epochs:
            for param_group in optim.param_groups:
                new_lr = param_group["lr"] * args.decay_rate
                param_group["lr"] = new_lr

        # Load data
        for i, (x_p_d, _) in tqdm(enumerate(dload_train)):
            # Warmup
            if cur_iter <= args.warmup_iters:
                lr = args.lr * cur_iter / float(args.warmup_iters)
                for param_group in optim.param_groups:
                    param_group["lr"] = lr

            x_p_d = x_p_d.to(device)
            x_lab, y_lab = dload_train_labeled.__next__()
            x_lab, y_lab = x_lab.to(device), y_lab.to(device)

            # Label smoothing
            dist = smooth_one_hot(y_lab, args.n_classes, args.smoothing)

            L = 0.0

            # log p(y|x) cross entropy loss
            if args.pyxce > 0:
                logits = f.classify(x_lab)
                l_pyxce = KHotCrossEntropyLoss()(logits, dist)
                if cur_iter % args.print_every == 0:
                    acc = (logits.max(1)[1] == y_lab).float().mean()
                    print("p(y|x)CE {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".format(
                        epoch, cur_iter, l_pyxce.item(), acc.item()))
                    logger.record_dict({
                        "l_pyxce": l_pyxce.cpu().data.item(),
                        "acc_pyxce": acc.item()
                    })
                L += args.pyxce * l_pyxce

            # log p(x) using sgld
            if args.pxsgld > 0:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes, (args.sgld_batch_size,)).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                else:
                    x_q = sample_q(f, replay_buffer)  # sample from log-sumexp
                fp_all = f(x_p_d)
                fq_all = f(x_q)
                fp = fp_all.mean()
                fq = fq_all.mean()
                l_pxsgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print("p(x)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}".format(
                        epoch, i, l_pxsgld, fp, fq))
                    logger.record_dict({"l_pxsgld": l_pxsgld.cpu().data.item()})
                L += args.pxsgld * l_pxsgld

            # log p(x) using contrastive learning
            if args.pxcontrast > 0:
                # ones like dist to use all indexes
                ones_dist = torch.ones_like(dist).to(device)
                output, target, ce_output, neg_num = f.joint(img=x_lab, dist=ones_dist)
                l_pxcontrast = nn.CrossEntropyLoss(reduction="mean")(output, target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print("p(x)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".format(
                        epoch, cur_iter, l_pxcontrast.item(), acc.item()))
                    logger.record_dict({
                        "l_pxcontrast": l_pxcontrast.cpu().data.item(),
                        "acc_pxcontrast": acc.item()
                    })
                L += args.pxycontrast * l_pxcontrast

            # log p(x|y) using sgld
            if args.pxysgld > 0:
                x_q_lab = sample_q(f, replay_buffer, y=y_lab)
                fp, fq = f(x_lab).mean(), f(x_q_lab).mean()
                l_pxysgld = -(fp - fq)
                if cur_iter % args.print_every == 0:
                    print("p(x|y)SGLD | {}:{:>d} loss={:>14.9f} f(x_p_d)={:>14.9f} f(x_q)={:>14.9f}".format(
                        epoch, i, l_pxysgld.item(), fp, fq))
                    logger.record_dict({"l_pxysgld": l_pxysgld.cpu().data.item()})
                L += args.pxsgld * l_pxysgld

            # log p(x|y) using contrastive learning
            if args.pxycontrast > 0:
                output, target, ce_output, neg_num = f.joint(img=x_lab, dist=dist)
                l_pxycontrast = nn.CrossEntropyLoss(reduction="mean")(output, target)
                if cur_iter % args.print_every == 0:
                    acc = (ce_output.max(1)[1] == y_lab).float().mean()
                    print("p(x|y)Contrast {}:{:>d} loss={:>14.9f}, acc={:>14.9f}".format(
                        epoch, cur_iter, l_pxycontrast.item(), acc.item()))
                    logger.record_dict({
                        "l_pxycontrast": l_pxycontrast.cpu().data.item(),
                        "acc_pxycontrast": acc.item()
                    })
                L += args.pxycontrast * l_pxycontrast

            # SGLD training of log q(x) may diverge;
            # break here and record information to restart
            if L.abs().item() > 1e8:
                print("restart epoch: {}".format(epoch))
                print("save dir: {}".format(args.log_dir))
                print("id: {}".format(args.id))
                print("steps: {}".format(args.n_steps))
                print("seed: {}".format(args.seed))
                print("exp prefix: {}".format(args.exp_prefix))
                sys.stdout = sys.__stdout__
                sys.stderr = sys.__stderr__
                print("restart epoch: {}".format(epoch))
                print("save dir: {}".format(args.log_dir))
                print("id: {}".format(args.id))
                print("steps: {}".format(args.n_steps))
                print("seed: {}".format(args.seed))
                print("exp prefix: {}".format(args.exp_prefix))
                assert False, "loss exploded, aborting so the run can be restarted"

            optim.zero_grad()
            L.backward()
            optim.step()
            cur_iter += 1

        if epoch % args.plot_every == 0:
            if args.plot_uncond:
                if args.class_cond_p_x_sample:
                    assert not args.uncond, "can only draw class-conditional samples if EBM is class-cond"
                    y_q = torch.randint(0, args.n_classes, (args.sgld_batch_size,)).to(device)
                    x_q = sample_q(f, replay_buffer, y=y_q)
                    plot("{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, y=y_q, contrast=True)
                        plot("{}/contrast_x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q)
                else:
                    x_q = sample_q(f, replay_buffer)
                    plot("{}/x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q)
                    if args.plot_contrast:
                        x_q = sample_q(f, replay_buffer, contrast=True)
                        plot("{}/contrast_x_q_{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q)
            if args.plot_cond:
                # generate class-conditional samples
                y = torch.arange(0, args.n_classes)[None].repeat(args.n_classes, 1).transpose(1, 0).contiguous().view(-1).to(device)
                x_q_y = sample_q(f, replay_buffer, y=y)
                plot("{}/x_q_y{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q_y)
                if args.plot_contrast:
                    y = torch.arange(0, args.n_classes)[None].repeat(args.n_classes, 1).transpose(1, 0).contiguous().view(-1).to(device)
                    x_q_y = sample_q(f, replay_buffer, y=y, contrast=True)
                    plot("{}/contrast_x_q_y_{}_{:>06d}.png".format(args.log_dir, epoch, i), x_q_y)

        if args.ckpt_every > 0 and epoch % args.ckpt_every == 0:
            checkpoint(f, replay_buffer, f"ckpt_{epoch}.pt", args)

        if epoch % args.eval_every == 0:
            # Validation set
            correct, val_loss = eval_classification(f, dload_valid)
            if correct > best_valid_acc:
                best_valid_acc = correct
                print("Best Valid!: {}".format(correct))
                checkpoint(f, replay_buffer, "best_valid_ckpt.pt", args)
            # Test set
            correct, test_loss = eval_classification(f, dload_test)
            print("Epoch {}: Valid Loss {}, Valid Acc {}".format(epoch, val_loss, correct))
            print("Epoch {}: Test Loss {}, Test Acc {}".format(epoch, test_loss, correct))
            f.train()
            logger.record_dict({
                "Epoch": epoch,
                "Valid Loss": val_loss,
                "Valid Acc": correct.detach().cpu().numpy(),
                "Test Loss": test_loss,
                "Test Acc": correct.detach().cpu().numpy(),
                "Best Valid": best_valid_acc.detach().cpu().numpy(),
                "Loss": L.cpu().data.item(),
            })
        checkpoint(f, replay_buffer, "last_ckpt.pt", args)
        logger.dump_tabular()
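# KHotCrossEntropyLoss is defined elsewhere in this repo. As a point of reference,
# the sketch below assumes the usual soft-target cross entropy, i.e. the batch mean
# of -sum(dist * log_softmax(logits)); the name KHotCrossEntropyLossSketch and this
# exact reduction are assumptions, not the repo's implementation.
class KHotCrossEntropyLossSketch(nn.Module):
    def forward(self, logits, dist):
        # dist is a smoothed / k-hot target distribution over classes
        log_probs = nn.functional.log_softmax(logits, dim=1)
        return -(dist * log_probs).sum(dim=1).mean()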