Exemplo n.º 1
0
def main():
    """Train the calibration network, then calibrate, plot and save sample2.

    Training runs for an open-ended number of epochs.  In each epoch every
    batch of ``sample1_loader`` is paired with the next batch of
    ``sample2_loader`` (which is restarted whenever it runs out) and both
    are handed to ``training_step``.  The epoch loss is the mean of the
    values ``training_step`` returned; training stops once that loss has
    failed to beat the best loss by more than ``eps`` for
    ``args.epochs_wo_im`` epochs in a row.  ``update_lr`` is called once
    per epoch.

    Afterwards ``mmd_resnet`` is applied to every batch of
    ``sample2_loader``, the data are plotted on two principal components
    (PCA fitted on ``sample1``) before and after calibration, and
    ``sample1`` plus the calibrated sample2 are written under
    ``args.save_dir``.

    The data, loaders, model and settings are all read from module-level
    names; the function takes no arguments and returns nothing.
    """
    # Start from +inf so the first epoch always counts as an improvement,
    # whatever the scale of the loss.
    best_loss = float('inf')
    eps = 1e-4
    epoch_counter = 0  # consecutive epochs without sufficient improvement
    for epoch in count(1):
        batch_losses = []

        samp2_batches = iter(sample2_loader)
        for batch1 in sample1_loader:
            try:
                batch2 = next(samp2_batches)
            except StopIteration:
                # sample2_loader ran out before sample1_loader did: restart
                # it.  Only exhaustion is handled here; any other error
                # raised by the loader is allowed to propagate.
                samp2_batches = iter(sample2_loader)
                batch2 = next(samp2_batches)

            # Only the first element of each loader item is used.
            batch1 = batch1[0].to(device=device)
            batch2 = batch2[0].to(device=device)

            batch_losses.append(training_step(batch1, batch2))

        epoch_loss = np.mean(batch_losses)

        if epoch_loss < best_loss - eps:
            best_loss = epoch_loss
            epoch_counter = 0
        else:
            epoch_counter += 1

        print('Epoch {}, loss: {:.3f}, counter: {}'.format(
            epoch, epoch_loss, epoch_counter))

        update_lr()

        # '>=' rather than '==' so that a zero or negative patience setting
        # still terminates the loop.
        if epoch_counter >= args.epochs_wo_im:
            break
    print('Finished training')

    # calibrate sample2 -> batch 1

    # Take the network out of training mode before running it on sample2.
    mmd_resnet.train(False)

    calibrated_sample2 = []
    for batch2 in sample2_loader:
        batch2 = batch2[0].to(device=device)
        calibrated_batch = mmd_resnet(batch2)
        calibrated_sample2.append(calibrated_batch.detach().cpu().numpy())

    calibrated_sample2 = np.concatenate(calibrated_sample2)

    # ==============================================================================
    # =                         visualize calibration                              =
    # ==============================================================================

    # PCA is fitted on sample1 only; every other data set is projected onto
    # the same components so the plots are directly comparable.
    pca = decomposition.PCA()
    pca.fit(sample1)
    pc1 = 0
    pc2 = 1
    axis1 = 'PC' + str(pc1)
    axis2 = 'PC' + str(pc2)

    # plot data before calibration
    sample1_pca = pca.transform(sample1)
    sample2_pca = pca.transform(sample2)
    sh.scatterHist(sample1_pca[:, pc1],
                   sample1_pca[:, pc2],
                   sample2_pca[:, pc1],
                   sample2_pca[:, pc2],
                   axis1,
                   axis2,
                   title="Data before calibration",
                   name1='sample1',
                   name2='sample2')

    # plot data after calibration
    calibrated_sample2_pca = pca.transform(calibrated_sample2)
    sh.scatterHist(sample1_pca[:, pc1],
                   sample1_pca[:, pc2],
                   calibrated_sample2_pca[:, pc1],
                   calibrated_sample2_pca[:, pc2],
                   axis1,
                   axis2,
                   title="Data after calibration",
                   name1='sample1',
                   name2='sample2')

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    # NOTE(review): np.save writes NumPy's binary .npy format and appends
    # '.npy' to names that lack it, so these end up as 'sample1.csv.npy' and
    # 'calibrated_sample2.csv.npy'.  The paths are left unchanged in case
    # downstream code relies on them; switch to np.savetxt if real CSV output
    # is wanted.
    np.save(args.save_dir + '/sample1.csv', sample1)
    np.save(args.save_dir + '/calibrated_sample2.csv', calibrated_sample2)
Exemplo n.º 2
0
# Fetch the network outputs for the training data.  The tensors and
# placeholders (rec_a1, c_a1, rec_b1, c_b1, input_a, input_b) are defined
# outside this fragment; their names suggest "reconstruction" and "code"
# outputs -- TODO confirm against the graph definition.
# Target data fed to input_a, fetching rec_a1 and c_a1.
t_rec_train, t_c_train = sess.run([rec_a1, c_a1], feed_dict={input_a: target_train_data})
# Source data fed to the same input_a placeholder, fetching rec_a1; this is
# the array plotted further down as the source after calibration.
s_cal_train = sess.run(rec_a1, feed_dict={input_a: source_train_data})
# Source data fed to input_b, fetching rec_b1 and c_b1.
s_rec_train, s_c_train = sess.run([rec_b1, c_b1], feed_dict={input_b: source_train_data})  
if use_test:
    # The same three fetches, repeated for the test split.
    t_rec_test, t_c_test = sess.run([rec_a1, c_a1], feed_dict={input_a: target_test_data})
    s_cal_test = sess.run(rec_a1, feed_dict={input_a: source_test_data})
    s_rec_test, s_c_test = sess.run([rec_b1, c_b1], feed_dict={input_b: source_test_data})  


# No further sess.run calls follow in this fragment; the rest works on the
# values fetched above.
sess.close()


# Plot the raw training data, target vs. source, on components pc1/pc2.
# `pca`, pc1, pc2, axis1 and axis2 come from outside this fragment.
target_pca = pca.transform(target_train_data)
source_pca = pca.transform(source_train_data)
sh.scatterHist(target_pca[:,pc1], target_pca[:,pc2], source_pca[:,pc1], 
            source_pca[:,pc2], axis1, axis2, title="train data before calibration",
            name1='target', name2='source')


# Same plot after the network: the rec_a1 output for the target data against
# the rec_a1 output for the source data, projected with the same `pca`.
target_rec_pca = pca.transform(t_rec_train)
source_cal_pca = pca.transform(s_cal_train)
sh.scatterHist(target_rec_pca[:,pc1], target_rec_pca[:,pc2], 
               source_cal_pca[:,pc1], source_cal_pca[:,pc2], axis1, axis2, 
               title="train data after calibration", name1='target', name2='source')

# ==============================================================================
# =                                  save data                                 =
# ==============================================================================

# save data for visualization
# Output directory path, namespaced by experiment_name.
save_dir = './output/%s/calibrated_data' % experiment_name
Exemplo n.º 3
0
# Pool the two batches row-wise, source rows first and target rows second,
# once for the raw data and once for the network outputs.
before_calib = np.concatenate(
    (source_train_data, target_train_data), axis=0)
after_calib = np.concatenate(
    (calibrated_source_train_data, reconstructed_target_train_data), axis=0)

# Each pooled matrix gets its own, independently fitted 2-D t-SNE embedding.
embedding_before = TSNE(n_iter=1200, n_components=2).fit_transform(before_calib)
embedding_after = TSNE(n_iter=1200, n_components=2).fit_transform(after_calib)

# One figure per embedding.  Rows [:n_s] are drawn as 'batch 1' and rows
# [n_s:] as 'batch 2'; n_s is defined outside this fragment (presumably the
# number of source rows -- confirm).
for _tsne_embedding, _tsne_title in (
        (embedding_before, 'TSNE embedding before calibration'),
        (embedding_after, 'TSNE embedding after calibration'),
):
    sh.scatterHist(
        _tsne_embedding[:n_s, 0],
        _tsne_embedding[:n_s, 1],
        _tsne_embedding[n_s:, 0],
        _tsne_embedding[n_s:, 1],
        axis1='',
        axis2='',
        title=_tsne_title,
        name1='batch 1',
        name2='batch 2',
        plots_dir=plots_dir,
    )
Exemplo n.º 4
0
# When a test split is in use, project each test-set array with `pca`
# (defined outside this fragment).  The projections are computed in the
# order listed and bound to the matching *_pca names.
if use_test:
    (source_test_data_pca,
     target_test_data_pca,
     reconstructed_source_test_data_pca,
     calibrated_source_test_data_pca,
     reconstructed_target_test_data_pca) = map(
        pca.transform,
        (source_test_data,
         target_test_data,
         reconstructed_source_test_data,
         calibrated_source_test_data,
         reconstructed_target_test_data))

# plot reconstructions
# One figure per domain, target first and then source: the projected
# training data ('true') against its projected reconstruction ('recon').
for _true_pca, _recon_pca, _recon_title in (
        (target_train_data_pca,
         reconstructed_target_train_data_pca,
         "target train data reconstruction"),
        (source_train_data_pca,
         reconstructed_source_train_data_pca,
         "source train data reconstruction"),
):
    sh.scatterHist(
        _true_pca[:, pc1],
        _true_pca[:, pc2],
        _recon_pca[:, pc1],
        _recon_pca[:, pc2],
        axis1,
        axis2,
        title=_recon_title,
        name1='true',
        name2='recon',
        plots_dir=plots_dir,
    )