Ejemplo n.º 1
0
def train_predictor_x_multiple(x, label, augmenting=False, n_aug=0, label_list=None,
                               gcl=None, encoder=None, decoder=None, opt_dec=None):
    """Run one decoder update on a batch, optionally adding synthetic codes for several classes.

    The encoder and the flow (``gcl``) are only used in eval mode to produce
    latent codes ``s = gcl.hidden(encoder(x))``; only ``decoder`` is updated
    (via ``opt_dec``).

    Args:
        x: input batch passed to ``encoder``.
        label: class labels of ``x`` (a tensor; its device/dtype is reused for
            the synthetic labels).
        augmenting: when True, append ``n_aug`` synthetic codes per class in
            ``label_list`` produced by ``label_augmenting_s``.
        n_aug: number of synthetic codes per augmented class.
        label_list: iterable of class indices to augment (None = none).
        gcl, encoder, decoder, opt_dec: flow, encoder, decoder and the
            decoder's optimizer.

    Returns:
        float: the cross-entropy loss of this step.
    """
    gcl.eval()
    encoder.eval()
    # The decoder may have been left in eval mode (e.g. by eval_predictor_x);
    # this is a training step, so switch it back.
    decoder.train()
    z = encoder(x)
    s = gcl.hidden(z)
    opt_dec.zero_grad()

    if augmenting:
        aug_s = [s]
        aug_label = [label]
        for label_idx in (label_list or ()):
            aug_s.append(label_augmenting_s(z, label, label_idx, n_aug, gcl))
            # Build the synthetic labels on the same device/dtype as ``label``
            # so the concatenation below also works when ``label`` is on GPU.
            aug_label.append(torch.full((n_aug,), int(label_idx),
                                        dtype=label.dtype, device=label.device))
        dec_input = torch.cat(aug_s)
        target = torch.cat(aug_label)
    else:
        dec_input = s
        target = label

    # Only the decoder is optimised here: detach so backward does not
    # accumulate gradients in the (frozen) encoder / flow parameters.
    pred_label = decoder(dec_input.detach())
    CE_loss = cross_entropy(pred_label, target)

    CE_loss.backward()
    opt_dec.step()
    return CE_loss.item()
Ejemplo n.º 2
0
def train_predictor_s_aug(s, label, s_aug, label_aug, class_weight, class_weight_aug,
                          decoder=None, opt_dec=None, device='cpu'):
    """Run one decoder update on real latent codes plus optional augmented codes.

    The loss is the sum of two weighted cross-entropies: one on the real codes
    ``s`` (weights ``class_weight``) and one on the augmented codes ``s_aug``
    (weights ``class_weight_aug``). When ``s_aug`` is None only the first term
    is used, delegating to ``train_predictor_s``.

    Args:
        s, label: real latent codes and their labels.
        s_aug, label_aug: augmented codes and labels, or ``s_aug=None``.
        class_weight, class_weight_aug: weights passed to ``cross_entropy`` as
            ``sample_weight`` for the real / augmented terms.
        decoder, opt_dec: decoder module and its optimizer.
        device: device the codes are moved to before the decoder call.

    Returns:
        The value returned by the update (the loss tensor in the augmented path).
    """
    if s_aug is None:
        # Move the codes to ``device`` exactly as the augmented path does,
        # otherwise the fallback fails when the decoder lives on GPU.
        return train_predictor_s(s.to(device),
                                 label,
                                 class_weight,
                                 decoder=decoder,
                                 opt_dec=opt_dec)

    # The decoder may have been left in eval mode by an evaluation helper.
    decoder.train()
    opt_dec.zero_grad()
    pred_label = decoder(s.to(device))
    pred_label_aug = decoder(s_aug.to(device))
    CE_loss_raw = cross_entropy(pred=pred_label,
                                label=label,
                                sample_weight=class_weight,
                                class_acc=False)
    CE_loss_aug = cross_entropy(pred=pred_label_aug,
                                label=label_aug,
                                sample_weight=class_weight_aug,
                                class_acc=False)

    CE_loss = CE_loss_raw + CE_loss_aug
    CE_loss.backward()
    opt_dec.step()
    return CE_loss
Ejemplo n.º 3
0
def train_predictor_s(s, label, class_weight,
                      decoder=None, opt_dec=None):
    """Run one weighted cross-entropy update of ``decoder`` on latent codes ``s``.

    Args:
        s: latent codes fed directly to ``decoder``.
        label: target labels.
        class_weight: weights passed as the third argument of ``cross_entropy``.
        decoder, opt_dec: decoder module and its optimizer.

    Returns:
        The loss tensor of this step (callers such as ``train_predictor_s_aug``
        return this value to their own callers).
    """
    # The decoder may have been left in eval mode by an evaluation helper.
    decoder.train()
    opt_dec.zero_grad()

    pred_label = decoder(s)
    CE_loss = cross_entropy(pred_label, label, class_weight)

    CE_loss.backward()
    opt_dec.step()
    return CE_loss
Ejemplo n.º 4
0
def eval_predictor_x(x, label, gcl=None, encoder=None, decoder=None):
    """Evaluate the encoder -> flow -> decoder pipeline on one batch.

    All three modules are switched to eval mode (and are left in it).

    Returns:
        ``(pred_label, CE_loss, acc_)`` where ``pred_label`` is the decoder
        output and ``CE_loss, acc_`` are what ``cross_entropy(...,
        class_acc=True)`` returns.
    """
    gcl.eval()
    encoder.eval()
    decoder.eval()

    # Pure evaluation: no autograd graph is needed, which saves memory and time.
    with torch.no_grad():
        z = encoder(x)
        s = gcl.hidden(z)
        pred_label = decoder(s)
        CE_loss, acc_ = cross_entropy(pred_label, label, class_acc=True)

    return pred_label, CE_loss, acc_
Ejemplo n.º 5
0
def train_predictor_x(x, label, gcl=None,
                      encoder=None, decoder=None, opt_dec=None, device='cpu'):
    """Run one decoder update on a raw-input batch.

    The encoder and flow (``gcl``) are frozen feature extractors here: they run
    in eval mode without gradient tracking, and only ``decoder`` is updated.

    Args:
        x: input batch (moved to ``device``).
        label: target labels passed to ``cross_entropy``.
        gcl, encoder, decoder, opt_dec: flow, encoder, decoder and the
            decoder's optimizer.
        device: device the inputs are moved to.

    Returns:
        float: the cross-entropy loss of this step.
    """
    gcl.eval()
    encoder.eval()
    # The decoder may have been left in eval mode by an evaluation helper.
    decoder.train()

    # Frozen feature extraction: no need to build (or backprop through) a graph
    # for the encoder / flow, whose parameters are not updated here.
    with torch.no_grad():
        s = gcl.hidden(encoder(x.to(device)))

    opt_dec.zero_grad()
    pred_label = decoder(s)
    CE_loss = cross_entropy(pred_label, label)

    CE_loss.backward()
    opt_dec.step()
    return CE_loss.item()
Ejemplo n.º 6
0
def train_predictor_x(x, label, augmenting=False, n_aug=0, label_idx=None,
                      gcl=None, encoder=None, decoder=None, opt_dec=None):
    """Run one decoder update, optionally adding synthetic codes for one class.

    Args:
        x: input batch passed to ``encoder``.
        label: class labels of ``x`` (a tensor; its device/dtype is reused for
            the synthetic labels).
        augmenting: when True, append ``n_aug`` synthetic codes of class
            ``label_idx`` produced by ``label_augmenting_s``.
        n_aug: number of synthetic codes.
        label_idx: class index to augment.
        gcl, encoder, decoder, opt_dec: flow, encoder, decoder and the
            decoder's optimizer.

    Returns:
        float: the cross-entropy loss of this step.
    """
    gcl.eval()
    encoder.eval()
    # The decoder may have been left in eval mode by an evaluation helper.
    decoder.train()
    z = encoder(x)
    s = gcl.hidden(z)
    opt_dec.zero_grad()

    if augmenting:
        s_aug = label_augmenting_s(z, label, label_idx, n_aug, gcl)
        dec_input = torch.cat((s, s_aug), 0)
        # Synthetic labels on the same device/dtype as ``label`` so the
        # concatenation also works on GPU.
        synthetic = torch.full((n_aug,), int(label_idx),
                               dtype=label.dtype, device=label.device)
        target = torch.cat((label, synthetic), 0)
    else:
        dec_input = s
        target = label

    # Only the decoder is optimised: detach so backward does not accumulate
    # gradients in the frozen encoder / flow parameters.
    pred_label = decoder(dec_input.detach())
    CE_loss = cross_entropy(pred_label, target)

    CE_loss.backward()
    opt_dec.step()
    return CE_loss.item()
Ejemplo n.º 7
0
def train_predictor_pretrain(x,
                             label=None,
                             trainable=True,
                             predictionNN=None,
                             opt_enc=None,
                             max_grad_norm=1e-3):
    """Run one pre-training step, or one evaluation pass, of ``predictionNN``.

    Two set-ups are supported:

    * auto-encoder (``label is None``): MSE between the output and ``x``;
    * encoder-decoder (``label`` given): ``cross_entropy`` against ``label``.

    Args:
        x: input batch.
        label: target labels, or None for the auto-encoder set-up.
        trainable: True for a training step, False for evaluation.
        predictionNN: the network to train/evaluate.
        opt_enc: optimizer for ``predictionNN`` (only used when trainable).
        max_grad_norm: global gradient-norm clipping threshold applied before
            the optimizer step (defaults to the previous hard-coded 1e-3).

    Returns:
        * training: the loss tensor;
        * evaluation, auto-encoder: the loss tensor;
        * evaluation, encoder-decoder: ``(loss, acc)`` from
          ``cross_entropy(..., class_acc=True)``.
    """
    if trainable:
        predictionNN.train()
        opt_enc.zero_grad()
        pred = predictionNN(x)
        if label is None:
            # Auto-encoder set-up: reconstruct the input itself.
            loss_ = nn.MSELoss()(pred, x)
        else:
            # Encoder-decoder set-up: classify into ``label``.
            loss_ = cross_entropy(pred, label)

        loss_.backward()
        nn.utils.clip_grad_norm_(predictionNN.parameters(), max_grad_norm)
        opt_enc.step()
        return loss_

    predictionNN.eval()
    # Evaluation only: skip building the autograd graph.
    with torch.no_grad():
        pred = predictionNN(x)
        if label is None:
            # Auto-encoder set-up.
            return nn.MSELoss()(pred, x)
        # Encoder-decoder set-up.
        loss_, acc_ = cross_entropy(pred, label, class_acc=True)
    return loss_, acc_
Ejemplo n.º 8
0
def CRT_result(encoder, gcl, decoder, flow_path, dec_path, valid_loader, device):
    """Load the saved flow/decoder checkpoints and report validation metrics.

    The encoder is used as passed in; its checkpoint is NOT reloaded here.

    Args:
        encoder, gcl, decoder: encoder, flow and decoder modules.
        flow_path, dec_path: checkpoint files for ``gcl`` and ``decoder``.
        valid_loader: iterable of ``(x, label)`` batches.
        device: device the inputs (and loaded checkpoints) are mapped to.

    Returns:
        ``(loss, acc, top5_acc)``: the cross-entropy loss as a float, the
        accuracy value returned by ``cross_entropy(..., class_acc=True)``, and
        the fraction of samples whose true label is among the five
        highest-scoring classes.
    """
    gcl.load_state_dict(torch.load(flow_path, map_location=device))
    decoder.load_state_dict(torch.load(dec_path, map_location=device))

    encoder.eval()
    gcl.eval()
    decoder.eval()

    pred_risk = []
    true_label = []
    # Inference only: without no_grad the graph of every batch would stay alive
    # until the concatenated predictions are released.
    with torch.no_grad():
        for batched_x, batched_label in valid_loader:
            batched_x = batched_x.to(device).float()
            pred_risk.append(decoder(gcl.hidden(encoder(batched_x))))
            true_label.append(batched_label)

    pred_risk = torch.cat(pred_risk)
    true_label = torch.cat(true_label)

    loss_, acc_ = cross_entropy(pred_risk, true_label, class_acc=True)

    acc_top1 = acc_
    # Top-5 accuracy (or top-ncat when there are fewer than five classes).
    k = min(5, pred_risk.shape[1])
    top_k = torch.topk(pred_risk, k, dim=1).indices.cpu()
    hits = (top_k == true_label.cpu().view(-1, 1)).any(dim=1).numpy()
    acc_top5 = np.mean(hits)

    print('====> CRT CE loss: {:.3f} \t'.format(loss_.item()))
    print('====> CRT top1 acc: {:.3f} \t top 5 acc: {:.3f} \t'.format(acc_top1, acc_top5))

    return loss_.item(), acc_top1, acc_top5
Ejemplo n.º 9
0
def main(minority_label):
    """Optionally train, then evaluate, the CRT pipeline on simulated ablation data.

    NOTE(review): the ``minority_label`` argument is not read; the minority
    setting comes from ``args.minority_label`` (0 -> one minority class, any
    other value -> three). ``args``, ``training`` and ``data_path`` are
    expected to be module-level globals — confirm against the caller.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model_path = args.dir_path + '/' + args.dataset
    plot_path = model_path + '/plot'

    # Resolve the minority classes and the per-class sample sizes.
    if args.minority_label == 0:
        args.minority_label = [0]
        n_sample_list = [200] + [2000] * 6
    else:
        args.minority_label = [0, 1, 2]
        n_sample_list = [200] * 3 + [2000] * 4

    # Generate the data before anything that depends on it.
    x_, label_ = simulated_data_ablation(n_sample_list=n_sample_list)
    ncat = len(np.unique(label_))
    ncov = x_.shape[1]
    z_dim = 2

    # Uses the resolved list (the raw setting is an int, which has no len()).
    model_name = 'CRT_' + str(len(args.minority_label))

    Path(model_path).mkdir(parents=True, exist_ok=True)
    Path(plot_path).mkdir(parents=True, exist_ok=True)
    Path(data_path).mkdir(parents=True, exist_ok=True)

    # Fixed-seed 2/3 train, 1/3 validation split.
    np.random.seed(12)
    n_samples = len(x_)
    permuted_idx = np.random.permutation(n_samples)
    train_idx = permuted_idx[:int(2 * n_samples / 3)]
    valid_idx = permuted_idx[int(2 * n_samples / 3):n_samples]

    train_ = SimpleDataset(x_[train_idx, :], label_[train_idx], transform=None)
    valid_ = SimpleDataset(x_[valid_idx, :], label_[valid_idx], transform=None)

    # Needed for the final report even when ``training`` is False.
    # NOTE(review): drop_last=True also drops the last partial batch from the
    # reported validation metrics.
    valid_loader = DataLoader(valid_, batch_size=args.batch_size, num_workers=args.workers,
                              shuffle=True, drop_last=True)

    # Kept so parameter initialisation of the later models is unchanged; its
    # encoder is not used because pre-training is disabled for the toy data.
    predictionNN = EncoderDecoderNN(input_size=ncov, z_dim=z_dim, ncat=ncat, h_dim=[512, 512])

    majority_label = list(set(range(ncat)) - set(args.minority_label))
    args.majority_label = majority_label
    major_num = ncat - len(args.minority_label)

    gcl = GeneralizedContrastiveICAModel(dim=z_dim, n_label=ncat, n_steps=4, hidden_size=128,
                                         n_hidden=2, major_num=major_num)
    decoder = DecoderNN(input_size=z_dim, ncat=ncat, h_dim=[32, 32])

    # Toy dataset: identity encoder for both training and evaluation, so the
    # flow is evaluated on the same features it was trained on. There is no
    # encoder checkpoint to load (train_encoder / enc_path are disabled).
    encoder = torch.nn.Identity()

    flow_path = model_path + '/' + model_name + '_gcl.pt'
    dec_path = model_path + '/' + model_name + '_dec.pt'

    # Models must be on ``device`` in the evaluation-only path as well.
    gcl.to(device)
    decoder.to(device)

    if training:
        train_loader = DataLoader(train_, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)

        opt_flow = optim.Adam(gcl.parameters(), lr=args.gcl_lr)
        opt_dec = optim.Adam(decoder.parameters(), lr=args.dec_lr)

        # Contrastive learning of the invertible flow.
        train_gcl(args.minority_label, args.majority_label, encoder, gcl, opt_flow, flow_path,
                  train_loader, valid_loader, device, reg_weight=args.reg_weight,
                  plot_path=plot_path, args=args)

        # Reload the flow checkpoint written by train_gcl.
        gcl.load_state_dict(torch.load(flow_path, map_location=device))

        # One flag per class (indexed by class label), hence length ncat.
        n_flag = np.zeros(ncat)
        n_flag[args.minority_label] = 1

        args.cls_weight = torch.tensor(np.array([1] * ncat) - (1 - args.lmbda) * n_flag).float()
        args.cls_weight_aug = torch.tensor(np.array([1] * ncat) - (args.lmbda) * n_flag).float()
        args.cls_flag = torch.tensor(n_flag).float().to(device)

        balanced_loader = torch.utils.data.DataLoader(
            train_, batch_size=args.batch_size,
            sampler=ImbalancedDatasetSampler(train_, indices=list(range(train_.data.shape[0])),
                                             callback_get_label=callback_get_label))

        # Latent codes of the minority-class training samples.
        minority_ = get_minority_s(train_loader, args.minority_label, encoder.to(device), gcl.to(device), device)

        n_aug_per_class = 500
        s_aug, s_label = label_augmenting_multiple(minority_['s'], minority_['label'],
                                                   label_list=args.minority_label,
                                                   n_aug=[n_aug_per_class] * len(args.minority_label))

        aug_ds = SimpleDataset(s_aug.cpu(), s_label)
        aug_loader = DataLoader(aug_ds, batch_size=args.batch_size, num_workers=args.workers, shuffle=True)

        # Train the predictor.
        train_decoder(encoder, gcl, decoder, opt_dec, dec_path, balanced_loader, aug_loader,
                      valid_loader, device, args)

    #############################################
    # Report performance on the validation dataset with the saved checkpoints.
    gcl.load_state_dict(torch.load(flow_path, map_location=device))
    decoder.load_state_dict(torch.load(dec_path, map_location=device))

    encoder.eval()
    gcl.eval()
    decoder.eval()

    pred_risk = []
    true_label = []
    with torch.no_grad():
        for batched_x, batched_label in valid_loader:
            batched_x = batched_x.to(device).float().view(batched_x.shape[0], -1)
            pred_risk.append(decoder(gcl.hidden(encoder(batched_x))))
            true_label.append(batched_label)

    pred_risk = torch.cat(pred_risk)
    true_label = torch.cat(true_label)

    test_CE_loss, class_acc = cross_entropy(pred_risk, true_label, class_acc=True)

    # Move to host memory before converting to numpy (works for CUDA tensors too).
    pred_np = pred_risk.cpu().numpy()
    true_y = true_label.cpu().numpy()

    # Top-1 accuracy.
    acc_top1 = np.mean((true_y == pred_np.argmax(axis=1)) * 1)
    # Top-3 accuracy: true label among the three highest-scoring classes.
    top3 = np.argsort(pred_np, axis=1)[:, -3:]
    acc_top3 = np.mean((top3 == true_y.reshape(-1, 1)).any(axis=1))

    print('====> Test CE loss: {:.3f} \tper-class acc: {} \t'.format(test_CE_loss.item(), class_acc))
    print('====> Test top1 acc: {:.3f} \t top 3 acc: {:.3f} \t'.format(acc_top1, acc_top3))