def train_predictor_x_multiple(x, label, augmenting=False, n_aug=0, label_list=None,
                               gcl=None, encoder=None, decoder=None, opt_dec=None):
    gcl.eval()
    encoder.eval()
    z = encoder(x)
    s = gcl.hidden(z)
    opt_dec.zero_grad()
    if augmenting:
        # start from the observed latent codes and labels, then append
        # n_aug synthetic samples for every class in label_list
        aug_s = [s]
        aug_label = [label]
        for label_idx in label_list:
            s_aug = label_augmenting_s(z, label, label_idx, n_aug, gcl)
            aug_s.append(s_aug)
            aug_label.append(torch.tensor(np.repeat(label_idx, n_aug)))
        pred_label = decoder(torch.cat(aug_s))
        CE_loss = cross_entropy(pred_label, torch.cat(aug_label))
    else:
        pred_label = decoder(s)
        CE_loss = cross_entropy(pred_label, label)
    CE_loss.backward()
    opt_dec.step()
def train_predictor_s_aug(s, label, s_aug, label_aug, class_weight, class_weight_aug,
                          decoder=None, opt_dec=None, device='cpu'):
    if s_aug is None:
        CE_loss = train_predictor_s(s, label, class_weight,
                                    decoder=decoder, opt_dec=opt_dec)
        return CE_loss
    opt_dec.zero_grad()
    pred_label = decoder(s.to(device))
    pred_label_aug = decoder(s_aug.to(device))
    CE_loss_raw = cross_entropy(pred=pred_label, label=label,
                                sample_weight=class_weight, class_acc=False)
    CE_loss_aug = cross_entropy(pred=pred_label_aug, label=label_aug,
                                sample_weight=class_weight_aug, class_acc=False)
    CE_loss = CE_loss_raw + CE_loss_aug
    CE_loss.backward()
    opt_dec.step()
    return CE_loss
def train_predictor_s(s, label, class_weight,
                      decoder=None, opt_dec=None):
    opt_dec.zero_grad()
    pred_label = decoder(s)
    CE_loss = cross_entropy(pred_label, label, class_weight)
    CE_loss.backward()
    opt_dec.step()
    # the loss is returned because train_predictor_s_aug falls back to this
    # function and propagates its return value
    return CE_loss
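# The cross_entropy helper called throughout this file is defined elsewhere in
# the repo. The sketch below is an illustrative stand-in, not the project's
# implementation: it only assumes what the call sites above imply, namely
# logits input, an optional per-class sample_weight, and a per-class accuracy
# output when class_acc=True.
def cross_entropy_sketch(pred, label, sample_weight=None, class_acc=False):
    log_prob = torch.nn.functional.log_softmax(pred, dim=1)
    # per-sample negative log-likelihood of the true class
    nll = -log_prob[torch.arange(len(label), device=pred.device), label]
    if sample_weight is not None:
        nll = nll * sample_weight[label]  # weight each sample by its class weight
    loss = nll.mean()
    if not class_acc:
        return loss
    pred_cls = pred.argmax(dim=1)
    # per-class accuracy, one entry per class present in the batch
    acc = [(pred_cls[label == c] == c).float().mean().item() for c in torch.unique(label)]
    return loss, acc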
def eval_predictor_x(x, label, gcl=None, encoder=None, decoder=None):
    gcl.eval()
    encoder.eval()
    decoder.eval()
    z = encoder(x)
    s = gcl.hidden(z)
    pred_label = decoder(s)
    CE_loss, acc_ = cross_entropy(pred_label, label, class_acc=True)
    return pred_label, CE_loss, acc_
def train_predictor_x(x, label, gcl=None,
                      encoder=None, decoder=None, opt_dec=None, device='cpu'):
    gcl.eval()
    encoder.eval()
    z = encoder(x.to(device))
    s = gcl.hidden(z.to(device))
    opt_dec.zero_grad()
    pred_label = decoder(s.to(device))
    CE_loss = cross_entropy(pred_label, label)
    CE_loss.backward()
    opt_dec.step()
    return CE_loss.item()
# NOTE: this definition shadows the train_predictor_x above; keep only the
# variant that the calling code actually uses.
def train_predictor_x(x, label, augmenting=False, n_aug=0, label_idx=None,
                      gcl=None, encoder=None, decoder=None, opt_dec=None):
    gcl.eval()
    encoder.eval()
    z = encoder(x)
    s = gcl.hidden(z)
    opt_dec.zero_grad()
    if augmenting:
        # append n_aug synthetic latent samples for class label_idx
        s_aug = label_augmenting_s(z, label, label_idx, n_aug, gcl)
        aug_s = torch.cat((s, s_aug), 0)
        aug_label = torch.cat((label, torch.tensor(np.repeat(label_idx, n_aug))), 0)
        pred_label = decoder(aug_s)
        CE_loss = cross_entropy(pred_label, aug_label)
    else:
        pred_label = decoder(s)
        CE_loss = cross_entropy(pred_label, label)
    CE_loss.backward()
    opt_dec.step()
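# label_augmenting_s is defined elsewhere in the repo; the stand-in below only
# illustrates the role it plays here (drawing n_aug synthetic latent codes for
# class label_idx). Every modeling choice in it, e.g. sampling around the
# empirical class statistics in s-space, is hypothetical, not the project's
# actual sampler.
def label_augmenting_s_sketch(z, label, label_idx, n_aug, gcl):
    s = gcl.hidden(z)                 # latent codes of the current batch
    s_cls = s[label == label_idx]     # codes that belong to the target class
    mu = s_cls.mean(dim=0)
    std = s_cls.std(dim=0)
    # draw n_aug synthetic codes around the class mean
    return mu + std * torch.randn(n_aug, s.shape[1], device=s.device)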
def train_predictor_pretrain(x, label=None, trainable=True, predictionNN=None, opt_enc=None):
    if trainable:
        predictionNN.train()
        opt_enc.zero_grad()
        pred = predictionNN(x)
        if label is None:
            # auto-encoder setup: reconstruct the input
            label = x
            criterion = nn.MSELoss()
            loss_ = criterion(pred, label)
        else:
            # encoder-decoder setup: predict the label
            loss_ = cross_entropy(pred, label)
        loss_.backward()
        nn.utils.clip_grad_norm_(predictionNN.parameters(), 1e-3)
        opt_enc.step()
        return loss_
    else:
        predictionNN.eval()
        pred = predictionNN(x)
        if label is None:
            # auto-encoder setup
            label = x
            criterion = nn.MSELoss()
            loss_ = criterion(pred, label)
            return loss_
        else:
            # encoder-decoder setup
            loss_, acc_ = cross_entropy(pred, label, class_acc=True)
            return loss_, acc_
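# Minimal usage sketch for the pretraining helper; the wrapper name and
# n_pretrain_epochs are assumptions, not names from this file. Passing
# label=None trains predictionNN as an auto-encoder; passing the batch labels
# instead switches to the classification objective.
def pretrain_sketch(train_loader, predictionNN, opt_enc, n_pretrain_epochs=10):
    for epoch in range(n_pretrain_epochs):
        for batched_x, _ in train_loader:
            loss_ = train_predictor_pretrain(batched_x, label=None, trainable=True,
                                             predictionNN=predictionNN, opt_enc=opt_enc)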
def CRT_result(encoder, gcl, decoder, flow_path, dec_path, valid_loader, device):
    # load the best encoder
    # encoder = torch.nn.Identity()
    # load the best gcl
    gcl.load_state_dict(torch.load(flow_path))
    # load the best decoder
    decoder.load_state_dict(torch.load(dec_path))
    encoder.eval()
    gcl.eval()
    decoder.eval()
    pred_risk = []
    true_label = []
    for batched_x, batched_label in valid_loader:
        batched_x = batched_x.to(device).float()
        batched_risk = decoder(gcl.hidden(encoder(batched_x)))
        pred_risk.append(batched_risk)
        true_label.append(batched_label)
    pred_risk = torch.cat(pred_risk)
    true_label = torch.cat(true_label)
    loss_, acc_ = cross_entropy(pred_risk, true_label, class_acc=True)
    # top-1 accuracy
    acc_top1 = acc_
    # top-5 accuracy
    pred_risk = pred_risk.detach().cpu().numpy()
    val_top5 = [(pred_risk[i].argsort()[-5:][::-1]).tolist() for i in range(len(true_label))]
    true_y = true_label.numpy()
    acc_top5 = np.mean([true_y[i] in val_top5[i] for i in range(len(true_label))])
    print('====> CRT CE loss: {:.3f} \t'.format(loss_.item()))
    print('====> CRT top1 acc: {:.3f} \t top 5 acc: {:.3f} \t'.format(acc_top1, acc_top5))
    return loss_.item(), acc_top1, acc_top5
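# Example call for CRT_result; the wrapper and checkpoint paths are
# illustrative and must point at files saved by the training run in main().
def crt_result_example(encoder, gcl, decoder, valid_loader, device):
    flow_path = 'checkpoints/CRT_gcl.pt'   # hypothetical path
    dec_path = 'checkpoints/CRT_dec.pt'    # hypothetical path
    return CRT_result(encoder, gcl, decoder, flow_path, dec_path, valid_loader, device)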
def main(minority_label):
    # set hyper-parameters
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_path = args.dir_path + '/' + args.dataset
    plot_path = model_path + '/plot'
    data_path = model_path + '/data'  # assumed location for the simulated data
    z_dim = 2

    # the ablation setting is passed in and mirrored onto args
    args.minority_label = minority_label
    if args.minority_label == 0:
        args.minority_label = [0]
        n_sample_list = [200] + [2000] * 6
    else:
        args.minority_label = [0, 1, 2]
        n_sample_list = [200] * 3 + [2000] * 4
    # simulate the imbalanced dataset before deriving ncat/ncov from it
    x_, label_ = simulated_data_ablation(n_sample_list=n_sample_list)
    ncat = len(np.unique(label_))
    ncov = x_.shape[1]
    model_name = 'CRT_' + str(len(args.minority_label))

    Path(model_path).mkdir(parents=True, exist_ok=True)
    Path(plot_path).mkdir(parents=True, exist_ok=True)
    Path(data_path).mkdir(parents=True, exist_ok=True)

    ## permute the sources: 2/3 train, 1/3 validation
    np.random.seed(12)
    n_samples = len(x_)
    permuted_idx = np.random.permutation(n_samples)
    train_idx = permuted_idx[:int(2 * n_samples / 3)]
    valid_idx = permuted_idx[int(2 * n_samples / 3):n_samples]
    train_ = SimpleDataset(x_[train_idx, :], label_[train_idx], transform=None)
    valid_ = SimpleDataset(x_[valid_idx, :], label_[valid_idx], transform=None)

    # initiate parameters
    # encoder-decoder setup
    predictionNN = EncoderDecoderNN(input_size=ncov, z_dim=z_dim, ncat=ncat, h_dim=[512, 512])
    majority_label = list(set(range(ncat)) - set(args.minority_label))
    args.majority_label = majority_label
    major_num = ncat - len(args.minority_label)
    gcl = GeneralizedContrastiveICAModel(dim=z_dim, n_label=ncat, n_steps=4,
                                         hidden_size=128, n_hidden=2, major_num=major_num)
    encoder = predictionNN.encoder
    decoder = DecoderNN(input_size=z_dim, ncat=ncat, h_dim=[32, 32])

    # model paths
    flow_path = model_path + '/' + model_name + '_gcl.pt'
    enc_path = model_path + '/' + model_name + '_enc.pt'
    dec_path = model_path + '/' + model_name + '_dec.pt'

    if training:
        train_loader = DataLoader(train_, batch_size=args.batch_size,
                                  num_workers=args.workers, shuffle=True)
        valid_loader = DataLoader(valid_, batch_size=args.batch_size,
                                  num_workers=args.workers, shuffle=True, drop_last=True)
        class_weight = torch.Tensor(np.ones(ncat))
        # train_ is kept alive: the balanced_loader below still needs it

        opt_enc = optim.Adam(predictionNN.parameters(), lr=args.enc_lr)
        opt_flow = optim.Adam(gcl.parameters(), lr=args.gcl_lr)
        opt_dec = optim.Adam(decoder.parameters(), lr=args.dec_lr)
        predictionNN.to(device)
        gcl.to(device)
        decoder.to(device)

        # identity map for the toy dataset, so no encoder pre-training is needed
        encoder = torch.nn.Identity()
        # pre-training the encoder
        # train_encoder(predictionNN, opt_enc, enc_path, train_loader, valid_loader, ncat, device)
        # load pre-trained encoder
        # encoder.load_state_dict(torch.load(enc_path))

        # contrastive learning of the invertible flow
        train_gcl(args.minority_label, args.majority_label, encoder, gcl, opt_flow,
                  flow_path, train_loader, valid_loader, device,
                  reg_weight=args.reg_weight, plot_path=plot_path, args=args)
        # load the best gcl (the identity encoder has no checkpoint to load)
        gcl.load_state_dict(torch.load(flow_path))

        # define minority-class augmentation strength (one flag per class)
        n_flag = np.zeros(ncat)
        n_flag[args.minority_label] = 1
        args.cls_weight = torch.tensor(np.array([1] * ncat) - (1 - args.lmbda) * n_flag).float()
        args.cls_weight_aug = torch.tensor(np.array([1] * ncat) - args.lmbda * n_flag).float()
        args.cls_flag = torch.tensor(n_flag).float().to(device)

        balanced_loader = torch.utils.data.DataLoader(
            train_, batch_size=args.batch_size,
            sampler=ImbalancedDatasetSampler(train_,
                                             indices=list(range(train_.data.shape[0])),
                                             callback_get_label=callback_get_label))

        # load s from all minority classes and augment each of them
        minority_ = get_minority_s(train_loader, args.minority_label,
                                   encoder.to(device), gcl.to(device), device)
        n_aug_per_class = 500
        s_aug, s_label = label_augmenting_multiple(
            minority_['s'], minority_['label'],
            label_list=args.minority_label,
            n_aug=[n_aug_per_class] * len(args.minority_label))
        aug_ds = SimpleDataset(s_aug.cpu(), s_label)
        aug_loader = DataLoader(aug_ds, batch_size=args.batch_size,
                                num_workers=args.workers, shuffle=True)

        # train the predictor
        train_decoder(encoder, gcl, decoder, opt_dec, dec_path, balanced_loader,
                      aug_loader, valid_loader, device, args)

    #############################################
    # Report performance on the validation dataset
    # load the best gcl and decoder (the identity encoder has no checkpoint)
    gcl.load_state_dict(torch.load(flow_path))
    decoder.load_state_dict(torch.load(dec_path))
    encoder.eval()
    gcl.eval()
    decoder.eval()
    pred_risk = []
    true_label = []
    for batched_x, batched_label in valid_loader:
        batched_x = batched_x.to(device).float().view(batched_x.shape[0], -1)
        batched_risk = decoder(gcl.hidden(encoder(batched_x)))
        pred_risk.append(batched_risk)
        true_label.append(batched_label)
    pred_risk = torch.cat(pred_risk)
    true_label = torch.cat(true_label)
    test_CE_loss, class_acc = cross_entropy(pred_risk, true_label, class_acc=True)
    # top-1 accuracy
    pred_risk = pred_risk.detach().cpu()
    acc_top1 = np.mean((true_label.numpy() == pred_risk.argmax(dim=1).numpy()) * 1)
    # top-3 accuracy
    pred_risk = pred_risk.numpy()
    val_top3 = [(pred_risk[i].argsort()[-3:][::-1]).tolist() for i in range(len(true_label))]
    true_y = true_label.numpy()
    acc_top3 = np.mean([true_y[i] in val_top3[i] for i in range(len(true_label))])
    print('====> Test CE loss: {:.3f} \tper-class acc: {} \t'.format(test_CE_loss.item(), class_acc))
    print('====> Test top1 acc: {:.3f} \t top 3 acc: {:.3f} \t'.format(acc_top1, acc_top3))
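# Standard entry point, assuming this script is run directly; main() reads the
# remaining configuration from the global args, so only the ablation flag is
# passed through.
if __name__ == '__main__':
    main(args.minority_label)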