Code example #1
def plot_wasserstein_loss(ckpts_dir, save_dir, cfg_file, dataset_file):
    cfg = convert_to_float(yaml.load(open(cfg_file)))
    train_dataset = UKBioBankDataset(dataset_file, None, 'train')

    val_loader = torch.utils.data.DataLoader(train_dataset,
                                             cfg['batch_size'],
                                             shuffle=False,
                                             num_workers=cfg['num_workers'])
    # Datasets & Loaders

    os.makedirs(save_dir, exist_ok=True)
    fname = 'ckpt_step_' + str(181) + '.pth'
    png_name = os.path.join(save_dir, 'wass_plot.')

    npy_log = torch.load(os.path.join(ckpts_dir, fname))['numpy_log']
    w_loss = npy_log['wasserstein_loss']
    steps = npy_log['step']

    print(steps, w_loss)
    plt.plot(steps, w_loss)
    # NOTE: 'dot_steps' is not defined in the original snippet; as a plausible
    # choice, highlight every 10th logged step with a red dot.
    dot_steps = np.arange(0, len(steps), 10)
    plt.plot(dot_steps, np.array(w_loss)[dot_steps], 'r.')
    plt.grid(which='both', axis='both')
    # plt.xlim(left=0, right=181)
    plt.yscale('log')
    plt.title('Wasserstein Loss with respect to iteration number')
    plt.xlabel('Iteration')
    plt.ylabel('Wasserstein Loss')
    plt.savefig(png_name + 'png')
    plt.savefig(png_name + 'svg')
    plt.savefig(png_name + 'eps')
    plt.close('all')
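The helper convert_to_float used here (and in code examples #3, #5 and #10) is not shown. A minimal sketch, assuming it only walks the parsed YAML dict and casts string-encoded numbers to float (PyYAML reads values such as 1e-4 as strings); the real implementation may differ:

def convert_to_float(cfg):
    """Recursively cast string-encoded numbers in a parsed YAML config to float (assumed helper)."""
    if isinstance(cfg, dict):
        return {k: convert_to_float(v) for k, v in cfg.items()}
    if isinstance(cfg, str):
        try:
            return float(cfg)
        except ValueError:
            return cfg
    return cfg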
Code example #2
    def __init__(self, gpu_id, log_dir, dataset_root, cfg):
        self.config_class = Dict(yaml.load(open(cfg)))
        self.log_dir = log_dir
        self.gpu_id = gpu_id
        self.dataset_root = dataset_root
        self.val_interval = 1
        self.test_dataset = UKBioBankDataset(self.dataset_root, None, 'test')
        self.val_dataset = UKBioBankDataset(self.dataset_root, None, 'val')

        self.test_loader = torch.utils.data.DataLoader(
            self.test_dataset,
            self.config_class.batch_size,
            shuffle=False,
            num_workers=self.config_class.num_workers)
        self.val_loader = torch.utils.data.DataLoader(
            self.val_dataset,
            self.config_class.batch_size,
            shuffle=False,
            num_workers=self.config_class.num_workers)
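Dict gives attribute-style access to the parsed YAML config (self.config_class.batch_size). This matches the Dict class from the addict package; whether the project uses that package or its own equivalent is an assumption. A short usage sketch under that assumption:

from addict import Dict  # assumption: Dict comes from the addict package
import yaml

cfg = Dict(yaml.safe_load(open('config.yaml')))
print(cfg.batch_size, cfg.num_workers)  # attribute access instead of cfg['batch_size']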
Code example #3
def main(args):

    # Simple training script.
    log_dir = os.path.join(args.runs_path, args.exp_name)
    os.makedirs(log_dir, exist_ok=True)

    config = convert_to_float(yaml.load(open(args.config)))
    copyfile(args.config, os.path.join(log_dir, 'config.yaml'))

    np.random.seed(config['man_seed'])
    torch.manual_seed(config['man_seed'])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    model_cls, model_py = {
        'condgan': (CondGAN, os.path.join(args.models_path, 'cond_gan.py')),
    }[config['gan_name']]
    copyfile(model_py, os.path.join(log_dir, 'gan_model.py'))

    # Datasets & Loaders
    train_dataset = UKBioBankDataset(args.dataset_file, None, 'train')
    val_dataset = UKBioBankDataset(args.dataset_file, None, 'val')

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             config['batch_size'],
                                             shuffle=False,
                                             num_workers=config['num_workers'])

    start_time = time.time()
    print("\n########################################################")
    model = model_cls(config, train_dataset, val_loader, args.gpu_id,
                      config['fid_interval'], log_dir, config['num_epochs'],
                      args.cnn_run_dir, args.metric_model_id)

    model.run_train()
    print("\nTraining took ",
          str(np.around(time.time() - start_time, 2)) + " seconds.")
    print("########################################################\n")
Code example #4
    def set_train_loader(self, aug_fn, num_total_examples=14864):
        """
        Rebuild the training loader, using aug_fn to add noise to the training dataset.

        :param aug_fn: augmentation callable used to inject noise into the training samples.
        :param num_total_examples: number of training examples to load.
        """
        real_dataset = UKBioBankDataset(self.dataset_root,
                                        num_total_examples,
                                        'train',
                                        aug_fn=aug_fn)
        self.train_loader = torch.utils.data.DataLoader(
            real_dataset,
            self.config_class.batch_size,
            shuffle=True,
            num_workers=self.config_class.num_workers)
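aug_fn itself is not defined in these snippets. Below is a minimal sketch of a noise-injection augmentation consistent with how it is called in code example #6 (aug_fn(images, labels) returning the same pair); the name, the Gaussian form of the noise and the use of alpha as its scale are all assumptions. The same signature works whether the dataset applies it per sample or per batch.

import torch

def make_gaussian_aug_fn(alpha):
    """Build an aug_fn that adds zero-mean Gaussian noise scaled by alpha (hypothetical helper)."""
    def aug_fn(images, labels):
        noisy = images + alpha * torch.randn_like(images)
        # keep the matrices in the same range as the real data (plotted with vmin=-1, vmax=1)
        return noisy.clamp(-1.0, 1.0), labels
    return aug_fn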
Code example #5
def plot_wad(ckpts_dir, save_dir, cfg_file, dataset_file):
    cfg = convert_to_float(yaml.load(open(cfg_file)))
    # Datasets & Loaders
    train_dataset = UKBioBankDataset(dataset_file, None, 'train')

    val_loader = torch.utils.data.DataLoader(train_dataset,
                                             cfg['batch_size'],
                                             shuffle=False,
                                             num_workers=cfg['num_workers'])
    model = CondGAN(cfg, train_dataset, val_loader, 0, 10, save_dir, 200, 887)

    os.makedirs(save_dir, exist_ok=True)
    fnames = ['ckpt_step_' + str(i) + '.pth' for i in range(0, 181)]
    steps = []
    wads = []
    png_name = os.path.join(save_dir, 'wad_plot.')

    for i, f in enumerate(fnames):
        print(i)
        fname = os.path.join(ckpts_dir, f)
        gen_w = torch.load(fname)['netg']
        model.netg.load_state_dict(gen_w)
        with torch.no_grad():
            gen_batch, gen_labels = model.netg.generate_fake_images(
                model.hyperparameters['batch_size_metric'])
            model.metric_calculator.feed_batch(gen_batch.detach(),
                                               gen_labels.detach())
            _, _, fid_c_mean = model.metric_calculator.calc_class_agnostic_fid()
        steps.append(i)
        wads.append(fid_c_mean)
    print(steps, wads)
    plt.plot(steps, wads)
    # NOTE: 'dot_steps' is again not defined in the original snippet; assume the
    # same choice as above (every 10th step marked with a red dot).
    dot_steps = np.arange(0, len(steps), 10)
    plt.plot(dot_steps, np.array(wads)[dot_steps], 'r.')
    plt.grid(which='both', axis='both')

    # plt.xlim(left=0, right=181)
    plt.yscale('log')
    plt.title('WAD with respect to iteration number')
    plt.xlabel('Iteration')
    plt.ylabel('WAD')
    plt.savefig(png_name + 'png')
    plt.savefig(png_name + 'svg')
    plt.savefig(png_name + 'eps')
    plt.close('all')
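A possible way to invoke the two plotting helpers, using the run-directory layout from the argparse defaults in code example #9 (the concrete paths are assumptions):

if __name__ == '__main__':
    run_dir = os.path.join(os.getcwd(), 'gan_runs/cond_gan_debug_24')
    dataset_file = os.path.join(os.getcwd(), 'partitioned_dataset_gender.npz')
    plot_wasserstein_loss(os.path.join(run_dir, 'ckpts'),
                          os.path.join(run_dir, 'plots'),
                          os.path.join(run_dir, 'config.yaml'),
                          dataset_file)
    plot_wad(os.path.join(run_dir, 'ckpts'),
             os.path.join(run_dir, 'plots'),
             os.path.join(run_dir, 'config.yaml'),
             dataset_file)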
Code example #6
def test_eval_fid(model_id, gpu_id, batch_size, aug_fn, epochs, log_dir, alpha,
                  save_dir, dataset_file):
    # NOTE: 'args' refers to module-level parsed CLI arguments; args.cfg is the
    # trainer config path.
    noise_trainer = NoisedTrainer(gpu_id, save_dir, dataset_file, cfg=args.cfg)

    os.makedirs(log_dir, exist_ok=True)
    dataset2 = UKBioBankDataset(dataset_file, None, 'train')
    loader2 = torch.utils.data.DataLoader(dataset2, batch_size, shuffle=True)
    wad_list = []
    # two batches are consumed per step (one to augment, one as reference), hence // 2
    max_steps = len(loader2) // 2
    globstep = 0
    calcer = MetricCalculator(model_id,
                              dataset2,
                              batch_size,
                              gpu_id,
                              batch=None)
    for epoch in range(epochs):
        print('ep ', epoch)
        iter_loader2 = iter(loader2)
        for step in range(max_steps):
            print('step ', step)
            inputs1 = next(iter_loader2)
            inputs2 = next(iter_loader2)
            with torch.no_grad():
                aug_inp = aug_fn(*inputs1)
                calcer.reset_ref_batch(inputs2)
                calcer.feed_batch(*aug_inp)
                if epoch == 0 and step == 0:
                    calcer.scatter_plot_activations(
                        os.path.join(save_dir,
                                     str(alpha) + '.svg'))
                _, _, wad = calcer.calc_wad()
            if epoch == 0 and step == 0:
                noise_trainer.set_train_loader(aug_fn)
                testloss, testacc = noise_trainer.run_train_cycle(alpha)
            wad_list.append(wad)
        globstep += 1
    wad = np.array(wad_list)
    wad_mean = np.mean(wad)
    wad_std = np.std(wad)

    return wad_mean, wad_std, testloss, testacc
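A sketch of how test_eval_fid could be driven in a noise-level sweep, using the hypothetical make_gaussian_aug_fn from the sketch after code example #4. It assumes the module-level args (with args.cfg) noted above is in place; the alpha grid, batch size and output paths are assumptions, while model_id=887 matches the default used elsewhere in these snippets.

out_dir = os.path.join(os.getcwd(), 'noise_eval')
os.makedirs(out_dir, exist_ok=True)
results = []
for alpha in [0.0, 0.1, 0.2, 0.5, 1.0]:
    wad_mean, wad_std, test_loss, test_acc = test_eval_fid(
        model_id=887,
        gpu_id=0,
        batch_size=1000,
        aug_fn=make_gaussian_aug_fn(alpha),
        epochs=1,
        log_dir=out_dir,
        alpha=alpha,
        save_dir=out_dir,
        dataset_file=os.path.join(os.getcwd(), 'partitioned_dataset_gender.npz'))
    results.append([alpha, wad_mean, wad_std, test_loss, test_acc])
np.save(os.path.join(out_dir, 'alpha_sweep.npy'), np.array(results))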
Code example #7
def main(args):
    log_dir = os.path.join(args.runs_path, args.exp_name)
    os.makedirs(log_dir, exist_ok=True)

    results_list = []
    start_step = 0
    if os.path.isfile(os.path.join(log_dir, "backup.pth")):

        # If we have already started the experiment, resume it
        config_file = os.path.join(log_dir, 'config.yaml')
        config = Dict(yaml.load(open(config_file)))
        load_dict = torch.load(os.path.join(log_dir, "backup.pth"))
        start_step = load_dict['step']
        results_list = load_dict['results_list']
        np.random.set_state(load_dict['numpy_rng'])
        torch.set_rng_state(load_dict['torch_rng'])
    else:
        config = Dict(yaml.load(open(args.config_file)))
        copyfile(args.config_file, os.path.join(log_dir, 'config.yaml'))
        np.random.seed(config.man_seed)
        torch.manual_seed(config.man_seed)

    # Deterministic training!
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # hyperparam split
    # split the conv1 width grid between two GPUs (assumes gpu_id is 0 or 1)
    conv1_widths = get_log_range(*config.conv1.values()).astype(int)
    conv1_widths = conv1_widths[args.gpu_id * len(conv1_widths) // 2:
                                (args.gpu_id + 1) * len(conv1_widths) // 2]
    conv2_widths = get_log_range(*config.conv2.values()).astype(int)
    lr_values = get_log_range(*config.lr.values())
    mom_values = get_add_range(*config.mom.values())
    wd_values = get_log_range(*config.wd.values())

    val_interval = 1
    max_epochs = 200

    # Datasets & Loaders
    train_dataset = UKBioBankDataset(args.dataset_file, None, 'train')
    val_dataset = UKBioBankDataset(args.dataset_file, None, 'val')

    train_loader = torch.utils.data.DataLoader(train_dataset, config.batch_size, shuffle=True,
                                               num_workers=config.num_workers)
    val_loader = torch.utils.data.DataLoader(val_dataset, config.batch_size, shuffle=False,
                                             num_workers=config.num_workers)

    for step in range(start_step, args.num_runs):
        c1_w = conv1_widths[np.random.randint(0, len(conv1_widths))]
        c2_w = conv2_widths[np.random.randint(0, len(conv2_widths))]
        lr = lr_values[np.random.randint(0, len(lr_values))]
        mom = mom_values[np.random.randint(0, len(mom_values))]
        wd = wd_values[np.random.randint(0, len(wd_values))]
        start_time = time.time()
        print("\n########################################################")
        print("Starting step ", str(step) + " on GPU #" + str(args.gpu_id) + ".")
        print("conv1, conv2 widths: " + str((c1_w, c2_w)))
        print("lr: %.2E, mom: %.2E, wd_ %.2E\n" % (Decimal(lr), Decimal(mom), Decimal(wd)))
        model = ConnectomeConvNet(
            (c1_w, c2_w),
            lr,
            mom,
            wd,
            train_loader,
            val_loader,
            args.gpu_id,
            val_interval,
            step,
            log_dir,
            max_epochs,
        )
        loss, acc, num_params = model.run_train()
        result = [args.gpu_id, step, loss, acc, num_params, c1_w, c2_w, lr, mom, wd]
        results_list.append(result)
        print("\nTraining took ", str(np.around(time.time() - start_time, 2)) + " seconds.")
        print("Loss: ", np.around(loss, 4))
        print("Acc: ", np.around(acc, 4))
        print("Number of params: ", num_params)
        print("########################################################\n")
        save_dict = {
            'numpy_rng': np.random.get_state(),
            'torch_rng': torch.get_rng_state(),
            'results_list': results_list,
            'step': step + 1
        }
        torch.save(save_dict, os.path.join(log_dir, "backup.pth"))
    # saving results to csv
    with open(os.path.join(log_dir, "cnn_search_result.csv"), "w") as f:
        writer = csv.writer(f)
        writer.writerows(results_list)
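get_log_range and get_add_range are assumed helpers that expand each config entry (presumably a min/max/count triple) into a grid of candidate hyperparameter values, log-spaced and linearly spaced respectively. A minimal sketch; the real signatures and spacing may differ:

import numpy as np

def get_log_range(low, high, num):
    """Log-spaced candidate values in [low, high] (assumed helper)."""
    return np.geomspace(low, high, int(num))

def get_add_range(low, high, num):
    """Linearly spaced candidate values in [low, high] (assumed helper)."""
    return np.linspace(low, high, int(num))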
Code example #8
        # ...tail of a plotting helper whose body is omitted from this excerpt
        plt.close('all')


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_file",
                        type=str,
                        default=os.path.join(os.getcwd(),
                                             'partitioned_dataset_gender.npz'))
    parser.add_argument("--out_dir",
                        type=str,
                        default=os.path.join(os.getcwd(), 'cnn_arch_search'))
    parser.add_argument("--cnn_run_dir",
                        type=str,
                        default=os.path.join(os.getcwd(),
                                             'cnn_arch_search/runs'))
    parser.add_argument("--model_id", type=int, default=887)

    args = parser.parse_args()
    batch_size = 7000
    gpu_id = 0
    from data_handling.dataset import UKBioBankDataset

    dataset = UKBioBankDataset(args.dataset_file)
    calcer = MetricCalculator(args.model_id, dataset, batch_size, gpu_id,
                              args.cnn_run_dir)
    calcer.plot_empty_scatter(os.path.join(args.out_dir, 'empty_scatter.png'))
Code example #9
def plot_generated_matrices(ckpt_dir, save_dir, dataset_file):
    os.makedirs(save_dir, exist_ok=True)
    fnames = ['gen_img_it_' + str(i) + '.npy' for i in range(0, 181)]

    # real example
    train_dataset = UKBioBankDataset(dataset_file, None, 'train')
    loader = torch.utils.data.DataLoader(train_dataset,
                                         200,
                                         shuffle=True,
                                         num_workers=0)
    iterloader = iter(loader)
    batch = next(iterloader)
    # take one female and one male example from the shuffled batch
    female_example = batch[0][(batch[1] == 0).view(-1), 0, :, :][0, :, :]
    male_example = batch[0][(batch[1] == 1).view(-1), 0, :, :][0, :, :]
    real_examples = [female_example, male_example]
    fig = plt.figure()
    for i in range(2):
        plt.subplot(1, 2, i + 1)
        plt.imshow(real_examples[i],
                   cmap='jet',
                   interpolation='nearest',
                   vmin=-1,
                   vmax=1)
        plt.title('Sex: ' + ['Female', 'Male'][i])
        plt.axis('off')
        # plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
        #                     hspace=1, wspace=0.1)
        # plt.margins(0.1, 0.1)
    plt.savefig(os.path.join(save_dir, 'real_example') + '.eps',
                bbox_inches='tight',
                pad_inches=0.0)
    plt.savefig(os.path.join(save_dir, 'real_example') + '.png',
                bbox_inches='tight',
                pad_inches=0.0)
    plt.close()

    # generated examples
    for j, fname in enumerate(fnames):
        fname = os.path.join(ckpt_dir, fname)
        img_fname = os.path.join(save_dir, 'gen_' + str(j))
        arr = np.load(fname)
        fig = plt.figure()
        for i in range(2):
            plt.subplot(1, 2, i + 1)
            plt.imshow(arr[i, :, :],
                       cmap='jet',
                       interpolation='nearest',
                       vmin=-1,
                       vmax=1)
            plt.title('Sex: ' + ['Female', 'Male'][i])
            plt.axis('off')
        plt.savefig(img_fname + '.eps', bbox_inches='tight', pad_inches=0.0)
        plt.savefig(img_fname + '.png', bbox_inches='tight', pad_inches=0.0)
        plt.close()

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_file",
                        type=str,
                        help='Path to the dataset .npz file',
                        default=os.path.join(
                            os.getcwd(), 'partitioned_dataset_gender.npz'))
    parser.add_argument("--save_dir",
                        type=str,
                        default=os.path.join(
                            os.getcwd(),
                            'gan_runs/cond_gan_debug_24/plots'))
    parser.add_argument("--ckpts_dir",
                        type=str,
                        default=os.path.join(
                            os.getcwd(),
                            'gan_runs/cond_gan_debug_24/ckpts'))
    parser.add_argument("--cnn_runs_dir",
                        type=str,
                        default=os.path.join(os.getcwd(),
                                             'cnn_arch_search'))

    parser.add_argument("--gan_cfg_file",
                        type=str,
                        default=os.path.join(
                            os.getcwd(),
                            'gan_runs/cond_gan_debug_24/config.yaml'))
    parser.add_argument("--learner_cfg_file",
                        type=str,
                        default=os.path.join(os.getcwd(),
                                             'config/learning_loss.yaml'))

    args = parser.parse_args()

    # Uncomment the desired function.
    # plot_generated_matrices(args.ckpts_dir, args.save_dir, args.dataset_file)
    # plot_wasserstein_loss(args.ckpts_dir, args.save_dir, args.gan_cfg_file, args.dataset_file)
    # plot_wad(args.ckpts_dir, args.save_dir, args.gan_cfg_file, args.dataset_file)
    calc_learner_loss(args.ckpts_dir, args.save_dir, args.gan_cfg_file,
                      args.learner_cfg_file, args.dataset_file,
                      args.cnn_runs_dir)
Code example #10
def calc_learner_loss(ckpt_dir, save_dir, gan_cfg_file, learner_cfg_file,
                      dataset_file, cnn_arch_dir):
    # steps = [0, 50, 100, 150, 174]
    # steps = [25, 75, 125, 180]
    # steps = [280, 320]
    steps = [254, 300]

    # learner_cfg_file is a path to a YAML config; load it for attribute access
    # (same pattern as the other snippets).
    learner_cfg = Dict(yaml.load(open(learner_cfg_file)))

    test_dataset = UKBioBankDataset(dataset_file, None, 'test')
    val_dataset = UKBioBankDataset(dataset_file, None, 'val')

    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        learner_cfg.batch_size,
        shuffle=False,
        num_workers=learner_cfg.num_workers)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        learner_cfg.batch_size,
        shuffle=False,
        num_workers=learner_cfg.num_workers)

    os.makedirs(save_dir, exist_ok=True)
    gan_model = CondGAN(convert_to_float(yaml.load(open(gan_cfg_file))),
                        val_dataset,
                        val_loader,
                        1,
                        1,
                        ckpt_dir,
                        200,
                        cnn_arch_dir=cnn_arch_dir,
                        metric_model_id=887)
    test_losses, test_accs = [], []

    for step in steps:
        fname = os.path.join(ckpt_dir, 'ckpt_step_' + str(step) + '.pth')
        gan_model.netg.load_state_dict(torch.load(fname)['netg'])
        imgs, labels = gan_model.netg.generate_fake_images(14861)
        gen_dataset = GenDataset(imgs.cpu(), labels.cpu())
        train_loader = torch.utils.data.DataLoader(gen_dataset,
                                                   14861,
                                                   shuffle=True,
                                                   num_workers=0)

        classifier = ConnectomeConvNet(
            (learner_cfg.c1, learner_cfg.c2),
            learner_cfg.lr,
            learner_cfg.mom,
            learner_cfg.wd,
            train_loader,
            val_loader,
            1,
            1,
            0,
            os.path.join(save_dir, 'gen' + str(step)),
            200,
            allow_stop=False,
            verbose=True,
        )
        loss, acc, num_params = classifier.run_train()
        testloss, testacc = classifier.test(test_loader)
        print(step, testloss, testacc)
        test_losses.append(testloss)
        test_accs.append(testacc)
    np.save(os.path.join(save_dir, 'loss_result.npy'),
            np.array([steps, test_losses, test_accs]))
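GenDataset wraps the generated matrices and labels so a standard DataLoader can serve them to the classifier; it is not defined in these snippets. A minimal sketch, assuming it simply indexes paired tensors:

import torch

class GenDataset(torch.utils.data.Dataset):
    """Thin Dataset over pre-generated (image, label) tensor pairs (assumed implementation)."""

    def __init__(self, imgs, labels):
        assert imgs.shape[0] == labels.shape[0]
        self.imgs = imgs
        self.labels = labels

    def __len__(self):
        return self.imgs.shape[0]

    def __getitem__(self, idx):
        return self.imgs[idx], self.labels[idx]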