def main():
    """Evaluate the ``partial_recall`` metric for a generator with a DRS discriminator.

    Options come from ``parse_option``. The generator and the DRS
    discriminator are built with ``get_gan_model(drs=True)`` and handed to
    ``evaluate_drs_with_attr`` together with the checkpoint step in
    ``args.netG_ckpt_step``.

    Raises:
        AssertionError: if ``args.netG_ckpt_step`` is unset or zero.
        ValueError: if ``args.dataset`` is not ``'celeba'``.
    """
    args = parse_option()
    set_seed(args.seed)
    print(args)

    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    # load model
    assert args.netG_ckpt_step, "args.netG_ckpt_step must be set"
    print(f'load model from {save_path} step: {args.netG_ckpt_step}')
    netG, _, netD_drs, _, _, _ = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        loss_type=args.loss_type,
        drs=True,
    )

    # Both networks must live on the evaluation device regardless of the
    # train/eval switch; previously netD_drs was only moved when eval mode
    # was selected.
    netG.to(device)
    netD_drs.to(device)
    if not args.netG_train_mode:
        netG.eval()
        netD_drs.eval()

    if args.dataset == 'celeba':
        dataset = 'celeba_64'
    else:
        raise ValueError("Dataset should be CelebA")

    evaluate_drs_with_attr(
        metric='partial_recall',
        attr=args.attr,
        log_dir=save_path,
        netG=netG,
        netD_drs=netD_drs,
        dataset=dataset,
        num_real_samples=10000,
        num_fake_samples=10000,
        evaluate_step=args.netG_ckpt_step,
        num_runs=1,
        device=device,
        use_original_netD=args.use_original_netD,
    )
def main():
    """Continue GAN training from a baseline checkpoint with score-based resampling.

    Steps, in order:
      1. Parse CLI options and prepare the output directory.
      2. Load discriminator logits pickled by the baseline run and turn
         them into per-sample scores with ``calculate_scores`` over the
         step range ``[p1_step - 5000, p1_step]``.
      3. Build the training dataloader, weighted by the score selected via
         ``--resample_score`` (unweighted when the flag is omitted), plus an
         unweighted dataloader for the DRS discriminator.
      4. Train with ``LogTrainer`` starting from the baseline checkpoints.
      5. Plot samples from the trained generator, from its DRS wrapper, and
         from the restored baseline generator.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="color_mnist", type=str)
    parser.add_argument("--root", "-r", default="./dataset/colour_mnist", type=str, help="dataset dir")
    parser.add_argument("--work_dir", default="./exp_results", type=str, help="output dir")
    parser.add_argument("--exp_name", default="colour_mnist", type=str, help="exp name")
    parser.add_argument("--baseline_exp_name", default="colour_mnist", type=str, help="exp name")
    parser.add_argument("--model", default="mnistgan", type=str, help="network model")
    parser.add_argument('--gpu', default='0', type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num_pack', default=1, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--num_steps', default=20000, type=int)
    parser.add_argument('--logit_save_steps', default=100, type=int)
    parser.add_argument('--decay', default='None', type=str)
    parser.add_argument('--n_dis', default=1, type=int)
    parser.add_argument('--p1_step', default=10000, type=int)
    parser.add_argument('--major_ratio', default=0.99, type=float)
    parser.add_argument('--num_data', default=10000, type=int)
    parser.add_argument('--resample_score', type=str)
    parser.add_argument("--loss_type", default="hinge", type=str, help="loss type")
    parser.add_argument('--use_eval_logits', type=int)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    baseline_output_dir = f'{args.work_dir}/{args.baseline_exp_name}'
    baseline_save_path = Path(baseline_output_dir)

    prefix = args.exp_name.split('/')[-1]

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    netG, netD, netD_drs, optG, optD, optD_drs = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        drs=True,
        loss_type=args.loss_type,
    )

    netG_ckpt_path = baseline_save_path / f'checkpoints/netG/netG_{args.p1_step}_steps.pth'
    netD_ckpt_path = baseline_save_path / f'checkpoints/netD/netD_{args.p1_step}_steps.pth'
    # The DRS discriminator starts from the same netD checkpoint file.
    netD_drs_ckpt_path = baseline_save_path / f'checkpoints/netD/netD_{args.p1_step}_steps.pth'

    logit_path = baseline_save_path / ('logits_netD_eval.pkl' if args.use_eval_logits == 1 else 'logits_netD_train.pkl')
    print(f'Use logit from: {logit_path}')
    # NOTE: unpickling can execute arbitrary code; only load logit files
    # produced by a trusted baseline run.
    with open(logit_path, "rb") as logit_file:
        logits = pickle.load(logit_file)
    score_start_step = args.p1_step - 5000
    score_end_step = args.p1_step
    score_dict = calculate_scores(logits, start_epoch=score_start_step, end_epoch=score_end_step)

    # --resample_score is optional: without it, train on the unweighted data
    # instead of failing on score_dict[None].
    if args.resample_score is not None:
        sample_weights = score_dict[args.resample_score]
        print(f'sample_weights mean: {sample_weights.mean()}, var: {sample_weights.var()}, max: {sample_weights.max()}, min: {sample_weights.min()}')
    else:
        sample_weights = None

    print_num_params(netG, netD)

    ds_train = get_predefined_dataset(
        dataset_name=args.dataset,
        root=args.root,
        weights=None,
        major_ratio=args.major_ratio,
        num_data=args.num_data
    )
    dl_train = get_dataloader(
        ds_train,
        batch_size=args.batch_size,
        weights=sample_weights)
    dl_drs = get_dataloader(ds_train, batch_size=args.batch_size, weights=None)

    data_iter = iter(dl_train)
    imgs, _, _, _ = next(data_iter)
    plot_data(imgs, num_per_side=8, save_path=save_path, file_name=f'{prefix}_resampled_train_data_p2', vis=None)
    if args.resample_score is not None:
        # NOTE(review): the sorted-score plot is skipped when no metric is
        # selected; confirm whether plot_score_sort accepts
        # plot_metric_name=None before enabling it for that case.
        plot_score_sort(ds_train, score_dict, save_path=save_path, phase=f'{prefix}_{score_start_step}-{score_end_step}_score', plot_metric_name=args.resample_score)

    print(args, netG_ckpt_path, netD_ckpt_path, netD_drs_ckpt_path)

    # Start training
    trainer = LogTrainer(
        output_path=save_path,
        logit_save_steps=args.logit_save_steps,
        netD=netD,
        netG=netG,
        optD=optD,
        optG=optG,
        netG_ckpt_file=netG_ckpt_path,
        netD_ckpt_file=netD_ckpt_path,
        netD_drs_ckpt_file=netD_drs_ckpt_path,
        netD_drs=netD_drs,
        optD_drs=optD_drs,
        dataloader_drs=dl_drs,
        n_dis=args.n_dis,
        num_steps=args.num_steps,
        save_steps=1000,
        vis_steps=100,
        lr_decay=args.decay,
        dataloader=dl_train,
        log_dir=output_dir,
        print_steps=10,
        device=device,
        save_logits=False,
    )
    trainer.train()

    plot_color_mnist_generator(netG, save_path=save_path, file_name=f'{prefix}-eval_p2')

    netG_drs = drs.DRS(netG, netD_drs, device=device)
    # NOTE(review): `percentile` only labels the output file; it is never
    # assigned to netG_drs here. Confirm that DRS's default matches 80.
    percentile = 80
    plot_color_mnist_generator(netG_drs, save_path=save_path, file_name=f'{prefix}-eval_drs_percent{percentile}_p2')

    # Reload the baseline generator to plot phase-1 samples for comparison.
    netG.restore_checkpoint(ckpt_file=netG_ckpt_path)
    netG.to(device)
    plot_color_mnist_generator(netG, save_path=save_path, file_name=f'{prefix}-eval_generated_p1')
Beispiel #3
0
def main():
    """Evaluate a generator checkpoint with FID and, except for FFHQ, IS and PR.

    For ``--dataset ffhq`` only FID is computed, through ``evaluate_ffhq``.
    For ``celeba``, ``cifar10`` and ``imagenet`` the function runs
    ``mmc.metrics.evaluate`` for FID and inception score, followed by
    ``evaluate_pr``.

    Raises:
        AssertionError: if ``--netG_ckpt_step`` is missing or zero.
        ValueError: if no precalculated FID statistics name is known for
            ``--dataset``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="cifar10", type=str)
    parser.add_argument("--root",
                        "-r",
                        default="./dataset/cifar10",
                        type=str,
                        help="dataset dir")
    parser.add_argument("--work_dir",
                        default="./exp_results",
                        type=str,
                        help="output dir")
    parser.add_argument("--exp_name",
                        default="mimicry_pretrained-seed1",
                        type=str,
                        help="exp name")
    parser.add_argument("--model",
                        default="sngan",
                        type=str,
                        help="network model")
    parser.add_argument("--loss_type",
                        default="hinge",
                        type=str,
                        help="loss type")
    parser.add_argument('--gpu',
                        default='0',
                        type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument("--netG_ckpt_step", type=int)
    parser.add_argument("--netG_train_mode", action='store_true')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    # load model
    assert args.netG_ckpt_step, "--netG_ckpt_step is required"
    print(f'load model from {save_path} step: {args.netG_ckpt_step}')
    netG, _, _, _ = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        loss_type=args.loss_type,
    )
    netG.to(device)
    if not args.netG_train_mode:
        netG.eval()

    # Dataset identifiers passed to the evaluation helpers.
    if args.dataset == 'celeba':
        dataset = 'celeba_64'
    elif args.dataset == 'imagenet':
        dataset = 'imagenet_128'
    else:
        dataset = args.dataset

    if args.dataset == 'ffhq':
        stats_file = './precalculated_statistics/fid_stats_ffhq_69k_run_0.npz'
        # Evaluate fid
        evaluate_ffhq(metric='fid',
                      log_dir=save_path,
                      data_path=args.root,
                      netG=netG,
                      dataset=dataset,
                      num_real_samples=50000,
                      num_fake_samples=50000,
                      evaluate_step=args.netG_ckpt_step,
                      num_runs=1,
                      device=device,
                      stats_file=stats_file)
    else:
        # Suffix of the precalculated FID statistics file per dataset.
        fid_stats_names = {
            'celeba': 'celeba_64_202k_run_0',
            'cifar10': 'cifar10_train',
            'imagenet': 'imagenet_128_50k_run_0',
        }
        try:
            stats_name = fid_stats_names[args.dataset]
        except KeyError:
            # Previously this fell through and crashed on an unbound local.
            raise ValueError(
                f"no precalculated FID statistics known for dataset "
                f"{args.dataset!r}; expected one of "
                f"{sorted(fid_stats_names) + ['ffhq']}") from None
        stats_file = f'./precalculated_statistics/fid_stats_{stats_name}.npz'

        # Evaluate fid
        mmc.metrics.evaluate(metric='fid',
                             log_dir=save_path,
                             netG=netG,
                             dataset=dataset,
                             num_real_samples=50000,
                             num_fake_samples=50000,
                             evaluate_step=args.netG_ckpt_step,
                             num_runs=1,
                             device=device,
                             stats_file=stats_file)

        # Evaluate inception score
        mmc.metrics.evaluate(metric='inception_score',
                             log_dir=save_path,
                             netG=netG,
                             num_samples=50000,
                             evaluate_step=args.netG_ckpt_step,
                             num_runs=1,
                             device=device)

        # Evaluate PR
        evaluate_pr(
            log_dir=save_path,
            netG=netG,
            dataset=dataset,
            num_real_samples=10000,
            num_fake_samples=10000,
            evaluate_step=args.netG_ckpt_step,
            num_runs=1,
            device=device,
        )
def main():
    """Score test samples with an autoencoder trained on generator output.

    Steps, in order:
      1. Parse CLI options, prepare the output directory and build the test
         dataloader.
      2. Restore the generator at ``--netG_step``; if a matching DRS
         discriminator checkpoint exists, wrap both in ``DRS``, otherwise
         sample from the bare generator.
      3. Obtain the autoencoder: load it from ``--cae_ckpt_path`` when given,
         else train it with ``train_cae`` on a generated dataset (reused
         from ``--generated_dataset_path`` or freshly produced by
         ``generate_dataset``).
      4. Run ``test_cae`` on the test dataloader, pickle the result and plot
         the samples sorted by it.

    Raises:
        AssertionError: if ``--netG_step`` is missing or zero.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="cifar10", type=str)
    parser.add_argument("--root", "-r", default="./dataset/cifar10", type=str, help="dataset dir")
    parser.add_argument("--work_dir", default="./exp_results", type=str, help="output dir")
    parser.add_argument("--exp_name", default="mimicry_pretrained-seed1", type=str, help="exp name")
    parser.add_argument('--gpu', default='0', type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--epochs', default=50, type=int)
    parser.add_argument("--netG_step", type=int)
    parser.add_argument("--netG_train_mode", action='store_true')
    parser.add_argument("--cae_ckpt_path", type=str)
    parser.add_argument("--model", type=str)
    parser.add_argument("--loss_type", default='ns', type=str)
    parser.add_argument("--generated_dataset_path", type=str)
    parser.add_argument('--major_ratio', default=0.99, type=float)
    parser.add_argument('--num_data', default=10000, type=int)
    parser.add_argument('--num_pack', default=1, type=int)
    parser.add_argument('--topk', action='store_true')
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    # 'mnist_c' is built without the major_ratio / num_data options.
    if args.dataset == 'mnist_c':
        ds_test = get_predefined_dataset(
            dataset_name=args.dataset,
            root=args.root,
        )
    else:
        ds_test = get_predefined_dataset(
            dataset_name=args.dataset,
            root=args.root,
            major_ratio=args.major_ratio,
            num_data=args.num_data
        )
    dl_test = get_dataloader(dataset=ds_test, batch_size=args.batch_size)

    # load model
    assert args.netG_step, "--netG_step is required"
    print(f'load model from: {args.netG_step}')
    netG, _, netD_drs, _, _, _ = get_gan_model(
        args.dataset,
        model=args.model,
        drs=True,
        loss_type=args.loss_type,
        topk=args.topk,
        num_pack=args.num_pack,
        inclusive=True,
        num_data=args.num_data,
        dataloader=dl_test,)
    netG.to(device)
    netD_drs.to(device)
    netG.get_setting(train=False)

    # restore_checkpoint's return value is used as the step label in all
    # output file names below.
    step = netG.restore_checkpoint(ckpt_file=save_path / f'checkpoints/netG/netG_{args.netG_step}_steps.pth')

    netD_drs_ckpt_path = save_path / f'checkpoints/netD_drs/netD_drs_{args.netG_step}_steps.pth'
    use_drs = netD_drs_ckpt_path.exists()
    if use_drs:
        netD_drs.restore_checkpoint(ckpt_file=netD_drs_ckpt_path)
        sampler = DRS(netG=netG, netD=netD_drs, device=device)
    else:
        # No DRS discriminator checkpoint: sample from the plain generator.
        sampler = netG

    print(f'use drs: {use_drs}')

    model = get_ae_model(dataset_name=args.dataset).to(device)
    if args.cae_ckpt_path:
        # map_location keeps a checkpoint saved on GPU loadable on a
        # CPU-only host.
        model.load_state_dict(torch.load(args.cae_ckpt_path, map_location=device))
    else:
        if args.generated_dataset_path:
            print(f'skip data generation, use: {args.generated_dataset_path}')
            generated_dataset_path = args.generated_dataset_path
        else:
            # generate dataset
            generated_dataset_path = save_path / f'netG_{step}_steps_seed{args.seed}_generated_dataset.pkl'
            generate_dataset(sampler, generated_dataset_path, eval_mode=not args.netG_train_mode, device=device)
            print(f'data generated in: {generated_dataset_path}')

        ds_train = get_generated_dataset(dataset_name=args.dataset, root=generated_dataset_path)
        dl_train = get_dataloader(dataset=ds_train, batch_size=args.batch_size)
        cae_ckpt_path = save_path / 'cae_checkpoints' / f'{step}_steps_seed{args.seed}'
        cae_ckpt_path.mkdir(parents=True, exist_ok=True)
        model = train_cae(model, dl_train=dl_train, dl_test=dl_test, save_path=cae_ckpt_path, epochs=args.epochs)

    final_score = test_cae(dl_test, model)
    run_tag = f'netG_{step}_steps_seed{args.seed}_epoch{args.epochs}_ae_score'
    # Close the file deterministically so the pickle is fully flushed.
    with open(save_path / f'{run_tag}.pkl', 'wb') as score_file:
        pickle.dump(final_score, score_file)
    show_sorted_score_samples(
        dataset=ds_test,
        score=final_score,
        save_path=save_path,
        score_name='ae_score',
        plot_name=run_tag)
def main():
    """Train the phase-1 GAN on the (colour MNIST by default) dataset.

    Parses the CLI options, builds the unweighted dataset/dataloader and the
    generator/discriminator pair, runs ``LogTrainer.train`` and finally
    plots samples from the trained generator under the name ``eval_p1``.
    """
    # (flags, add_argument keyword arguments), registered in this order.
    option_table = (
        (("--dataset", "-d"), dict(default="color_mnist", type=str)),
        (("--root", "-r"), dict(default="./dataset/colour_mnist", type=str, help="dataset dir")),
        (("--work_dir",), dict(default="./exp_results", type=str, help="output dir")),
        (("--exp_name",), dict(default="colour_mnist", type=str, help="exp name")),
        (("--loss_type",), dict(default="ns", type=str, help="loss type")),
        (("--model",), dict(default="mnist_dcgan", type=str, help="network model")),
        (('--gpu',), dict(default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')),
        (('--num_pack',), dict(default=1, type=int)),
        (('--batch_size',), dict(default=64, type=int)),
        (('--seed',), dict(default=1, type=int)),
        (('--use_clipping',), dict(action='store_true')),
        (('--num_steps',), dict(default=20000, type=int)),
        (('--logit_save_steps',), dict(default=100, type=int)),
        (('--decay',), dict(default='None', type=str)),
        (('--n_dis',), dict(default=1, type=int)),
        (('--major_ratio',), dict(default=0.99, type=float)),
        (('--num_data',), dict(default=10000, type=int)),
        (('--topk',), dict(default=0, type=int)),
        (('--resample_score',), dict(type=str)),
    )
    parser = argparse.ArgumentParser()
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Output directory: kept both as the raw string (for the trainer's
    # log_dir) and as a Path (for everything else).
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_seed(args.seed)

    cuda_available = torch.cuda.is_available()
    device = "cuda" if cuda_available else "cpu"
    if cuda_available:
        cudnn.benchmark = True

    # Data: no sample weighting in this phase.
    ds_train = get_predefined_dataset(
        dataset_name=args.dataset,
        root=args.root,
        weights=None,
        major_ratio=args.major_ratio,
        num_data=args.num_data,
    )
    dl_train = get_dataloader(ds_train, batch_size=args.batch_size, weights=None)

    # Networks and optimizers; the model factory takes top-k as a boolean.
    topk_enabled = args.topk == 1
    netG, netD, optG, optD = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        num_pack=args.num_pack,
        loss_type=args.loss_type,
        topk=topk_enabled,
        inclusive=True,
        num_data=args.num_data,
        dataloader=dl_train,
    )

    print_num_params(netG, netD)
    print(args)

    # Logits are only saved for the unpacked (num_pack == 1) setting.
    trainer_options = dict(
        output_path=save_path,
        logit_save_steps=args.logit_save_steps,
        netD=netD,
        netG=netG,
        optD=optD,
        optG=optG,
        n_dis=args.n_dis,
        num_steps=args.num_steps,
        save_steps=1000,
        vis_steps=100,
        lr_decay=args.decay,
        dataloader=dl_train,
        log_dir=output_dir,
        print_steps=10,
        device=device,
        topk=args.topk,
        save_logits=args.num_pack == 1,
        save_eval_logits=False,
    )
    LogTrainer(**trainer_options).train()

    plot_color_mnist_generator(netG, save_path=save_path, file_name='eval_p1')
Beispiel #6
0
def main():
    """Run phase-2 "gold" GAN training starting from a baseline checkpoint.

    Parses the CLI options, builds the generator/discriminator with
    ``gold=True``, plots one batch of training data, trains with
    ``LogTrainer`` from the baseline checkpoints at ``--p1_step``, then
    plots samples from the trained generator and from the restored baseline
    generator.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--dataset", "-d", default="color_mnist", type=str)
    add("--root", "-r", default="./dataset/colour_mnist", type=str, help="dataset dir")
    add("--work_dir", default="./exp_results", type=str, help="output dir")
    add("--exp_name", default="colour_mnist", type=str, help="exp name")
    add("--baseline_exp_name", default="colour_mnist", type=str, help="exp name")
    add("--model", default="mnistgan", type=str, help="network model")
    add('--gpu', default='0', type=str, help='id(s) for CUDA_VISIBLE_DEVICES')
    add('--num_pack', default=1, type=int)
    add('--batch_size', default=64, type=int)
    add('--seed', default=1, type=int)
    add('--use_clipping', action='store_true')
    add('--num_steps', default=20000, type=int)
    add('--logit_save_steps', default=100, type=int)
    add('--decay', default='None', type=str)
    add('--n_dis', default=1, type=int)
    add('--p1_step', default=10000, type=int)
    add('--major_ratio', default=0.99, type=float)
    add('--num_data', default=10000, type=int)
    add('--resample_score', type=str)
    add("--loss_type", default="hinge", type=str, help="loss type")
    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # This run's output directory (string form is reused as log_dir).
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    # Directory of the baseline run whose checkpoints seed this one.
    baseline_save_path = Path(f'{args.work_dir}/{args.baseline_exp_name}')

    # Last path component of the experiment name prefixes all plot files.
    prefix = args.exp_name.split('/')[-1]

    set_seed(args.seed)

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        cudnn.benchmark = True
    device = "cuda" if has_cuda else "cpu"

    netG, netD, optG, optD = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        loss_type=args.loss_type,
        gold=True,
    )

    netG_ckpt_path = baseline_save_path / f'checkpoints/netG/netG_{args.p1_step}_steps.pth'
    netD_ckpt_path = baseline_save_path / f'checkpoints/netD/netD_{args.p1_step}_steps.pth'

    print_num_params(netG, netD)

    ds_train = get_predefined_dataset(
        dataset_name=args.dataset,
        root=args.root,
        weights=None,
        major_ratio=args.major_ratio,
        num_data=args.num_data,
    )
    dl_train = get_dataloader(ds_train, batch_size=args.batch_size, weights=None)

    # Visualise the first batch; each batch unpacks into four items and
    # only the first one is plotted.
    data_iter = iter(dl_train)
    imgs, _, _, _ = next(data_iter)
    plot_data(
        imgs,
        num_per_side=8,
        save_path=save_path,
        file_name=f'{prefix}_gold_train_data_p2',
        vis=None,
    )

    print(args, netG_ckpt_path, netD_ckpt_path)

    # Start training
    trainer_options = dict(
        output_path=save_path,
        logit_save_steps=args.logit_save_steps,
        netD=netD,
        netG=netG,
        optD=optD,
        optG=optG,
        netG_ckpt_file=netG_ckpt_path,
        netD_ckpt_file=netD_ckpt_path,
        n_dis=args.n_dis,
        num_steps=args.num_steps,
        save_steps=1000,
        vis_steps=100,
        lr_decay=args.decay,
        dataloader=dl_train,
        log_dir=output_dir,
        print_steps=10,
        device=device,
        save_logits=False,
        gold=True,
        gold_step=args.p1_step,
    )
    trainer = LogTrainer(**trainer_options)
    trainer.train()

    # Samples after phase-2 training.
    plot_color_mnist_generator(netG, save_path=save_path, file_name=f'{prefix}-eval_p2')

    # Samples from the baseline (phase-1) generator, for comparison.
    netG.restore_checkpoint(ckpt_file=netG_ckpt_path)
    netG.to(device)
    plot_color_mnist_generator(netG, save_path=save_path, file_name=f'{prefix}-eval_generated_p1')
def main():
    """Run phase-2 GAN training with a DRS discriminator from a baseline checkpoint.

    Without ``--gold`` the training dataloader is weighted by the score
    named in ``--resample_score``, computed by ``calculate_scores`` from
    the baseline run's pickled evaluation logits over the step range
    ``[p1_step - 5000, p1_step]``. With ``--gold`` the dataloader is
    unweighted. A second, unweighted dataloader feeds the DRS discriminator.

    Exits through ``parser.error`` (status 2) when ``--exp_name`` is missing,
    or when ``--resample_score`` is missing and ``--gold`` is not set.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="cifar10", type=str)
    parser.add_argument("--root",
                        "-r",
                        default="./dataset/cifar10",
                        type=str,
                        help="dataset dir")
    parser.add_argument("--work_dir",
                        default="./exp_results",
                        type=str,
                        help="output dir")
    parser.add_argument("--exp_name", type=str, help="exp name")
    parser.add_argument("--baseline_exp_name", type=str, help="exp name")
    parser.add_argument('--p1_step', default=40000, type=int)
    parser.add_argument("--model",
                        default="sngan",
                        type=str,
                        help="network model")
    parser.add_argument("--loss_type",
                        default="hinge",
                        type=str,
                        help="loss type")
    parser.add_argument('--gpu',
                        default='0',
                        type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--num_steps', default=80000, type=int)
    parser.add_argument('--batch_size', default=64, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument('--decay', default='linear', type=str)
    parser.add_argument('--n_dis', default=5, type=int)
    parser.add_argument('--resample_score', type=str)
    parser.add_argument('--gold', action='store_true')
    parser.add_argument('--topk', action='store_true')
    args = parser.parse_args()

    # Fail before touching the filesystem: without --exp_name the run used
    # to create "<work_dir>/None" and then crash on args.exp_name.split().
    if args.exp_name is None:
        parser.error("--exp_name is required")
    # Resampling needs a score name; otherwise score_dict[None] raised
    # KeyError only after the logits had been loaded.
    if not args.gold and args.resample_score is None:
        parser.error("--resample_score is required unless --gold is set")

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    baseline_output_dir = f'{args.work_dir}/{args.baseline_exp_name}'
    baseline_save_path = Path(baseline_output_dir)

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    prefix = args.exp_name.split('/')[-1]

    # Number of steps before p1_step covered by the score computation.
    # Every dataset branch used the same value, so it is a single constant.
    window = 5000

    if not args.gold:
        logit_path = baseline_save_path / 'logits_netD_eval.pkl'
        print(f'Use logit from: {logit_path}')
        # NOTE: unpickling can execute arbitrary code; only load logit files
        # produced by a trusted baseline run.
        with open(logit_path, "rb") as logit_file:
            logits = pickle.load(logit_file)
        score_start_step = (args.p1_step - window)
        score_end_step = args.p1_step
        score_dict = calculate_scores(logits,
                                      start_epoch=score_start_step,
                                      end_epoch=score_end_step)
        sample_weights = score_dict[args.resample_score]
        print(
            f'sample_weights mean: {sample_weights.mean()}, var: {sample_weights.var()}, max: {sample_weights.max()}, min: {sample_weights.min()}'
        )
    else:
        sample_weights = None

    netG_ckpt_path = baseline_save_path / f'checkpoints/netG/netG_{args.p1_step}_steps.pth'
    netD_ckpt_path = baseline_save_path / f'checkpoints/netD/netD_{args.p1_step}_steps.pth'

    # The DRS discriminator starts from the same netD checkpoint file.
    netD_drs_ckpt_path = baseline_save_path / f'checkpoints/netD/netD_{args.p1_step}_steps.pth'
    netG, netD, netD_drs, optG, optD, optD_drs = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        loss_type=args.loss_type,
        drs=True,
        topk=args.topk,
        gold=args.gold,
    )

    print(f'model: {args.model} - netD_drs_ckpt_path: {netD_drs_ckpt_path}')

    print_num_params(netG, netD)

    ds_train = get_predefined_dataset(dataset_name=args.dataset,
                                      root=args.root,
                                      weights=None)
    dl_train = get_dataloader(ds_train,
                              batch_size=args.batch_size,
                              weights=sample_weights)

    ds_drs = get_predefined_dataset(dataset_name=args.dataset,
                                    root=args.root,
                                    weights=None)
    dl_drs = get_dataloader(ds_drs, batch_size=args.batch_size, weights=None)

    if not args.gold:
        show_sorted_score_samples(ds_train,
                                  score=sample_weights,
                                  save_path=save_path,
                                  score_name=args.resample_score,
                                  plot_name=prefix)

    print(args)

    # Start training
    trainer = LogTrainer(
        output_path=save_path,
        netD=netD,
        netG=netG,
        optD=optD,
        optG=optG,
        netG_ckpt_file=str(netG_ckpt_path),
        netD_ckpt_file=str(netD_ckpt_path),
        netD_drs_ckpt_file=str(netD_drs_ckpt_path),
        netD_drs=netD_drs,
        optD_drs=optD_drs,
        dataloader_drs=dl_drs,
        n_dis=args.n_dis,
        num_steps=args.num_steps,
        save_steps=1000,
        lr_decay=args.decay,
        dataloader=dl_train,
        log_dir=output_dir,
        print_steps=10,
        device=device,
        topk=args.topk,
        gold=args.gold,
        gold_step=args.p1_step,
        save_logits=False,
    )
    trainer.train()
def main():
    """Evaluate FID against the highest- and lowest-scored training samples.

    Per-sample scores are computed by ``calculate_scores`` from the baseline
    run's pickled evaluation logits over the step range
    ``[p1_step - 5000, p1_step]``. The ``--index_num`` samples with the
    largest and with the smallest ``--resample_score`` values are then each
    passed to ``evaluate_with_index`` with ``metric='fid'``.

    Exits through ``parser.error`` (status 2) when ``--resample_score`` is
    missing or ``--index_num`` is not positive.

    Raises:
        AssertionError: if ``--netG_ckpt_step`` is missing or zero.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", "-d", default="cifar10", type=str)
    parser.add_argument("--work_dir", default="./exp_results", type=str, help="output dir")
    parser.add_argument("--exp_name", default="mimicry_pretrained-seed1", type=str, help="exp name")
    parser.add_argument("--baseline_exp_name", type=str, help="exp name")
    parser.add_argument('--p1_step', default=40000, type=int)
    parser.add_argument("--model", default="sngan", type=str, help="network model")
    parser.add_argument("--loss_type", default="hinge", type=str, help="loss type")
    parser.add_argument('--gpu', default='0', type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument("--netG_ckpt_step", type=int)
    parser.add_argument("--netG_train_mode", action='store_true')
    parser.add_argument('--resample_score', type=str)
    parser.add_argument('--gold', action='store_true')
    parser.add_argument('--topk', action='store_true')
    parser.add_argument("--index_num", default=100, type=int, help="number of index to use for FID score")
    args = parser.parse_args()

    # The score name selects the entry of score_dict; without it the run
    # used to die with KeyError after loading the logits.
    if args.resample_score is None:
        parser.error("--resample_score is required")
    # sort_index[-0:] is the whole array and a negative count slices the
    # wrong end, so only positive counts are meaningful.
    if args.index_num <= 0:
        parser.error("--index_num must be a positive integer")

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    output_dir = f'{args.work_dir}/{args.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    baseline_output_dir = f'{args.work_dir}/{args.baseline_exp_name}'
    baseline_save_path = Path(baseline_output_dir)

    set_seed(args.seed)

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    # load model
    assert args.netG_ckpt_step, "--netG_ckpt_step is required"
    print(f'load model from {save_path} step: {args.netG_ckpt_step}')
    netG, _, _, _ = get_gan_model(
        dataset_name=args.dataset,
        model=args.model,
        loss_type=args.loss_type,
        topk=args.topk,
        gold=args.gold,
    )
    netG.to(device)
    if not args.netG_train_mode:
        netG.eval()

    # Dataset identifier passed to the evaluation helper.
    dataset = 'celeba_64' if args.dataset == 'celeba' else args.dataset
    # Number of steps before p1_step covered by the score computation
    # (the same value was used for every dataset).
    window = 5000

    logit_path = baseline_save_path / 'logits_netD_eval.pkl'
    print(f'Use logit from: {logit_path}')
    # NOTE: unpickling can execute arbitrary code; only load logit files
    # produced by a trusted baseline run.
    with open(logit_path, "rb") as logit_file:
        logits = pickle.load(logit_file)
    score_start_step = (args.p1_step - window)
    score_end_step = args.p1_step
    score_dict = calculate_scores(logits, start_epoch=score_start_step, end_epoch=score_end_step)
    sample_weights = score_dict[args.resample_score]
    print(
        f'sample_weights mean: {sample_weights.mean()}, var: {sample_weights.var()}, max: {sample_weights.max()}, min: {sample_weights.min()}')

    print(args)

    # Ascending order: the tail holds the largest weights, the head the
    # smallest.
    sort_index = np.argsort(sample_weights)
    high_index = sort_index[-args.index_num:]
    low_index = sort_index[:args.index_num]

    # Evaluate fid with index of high weight
    evaluate_with_index(
        metric='fid',
        index=high_index,
        log_dir=save_path,
        netG=netG,
        dataset=dataset,
        num_fake_samples=50000,
        evaluate_step=args.netG_ckpt_step,
        num_runs=1,
        device=device,
        stats_file=None,
        name=f'high_{args.resample_score}', )

    # Evaluate fid with index of low weight
    evaluate_with_index(
        metric='fid',
        index=low_index,
        log_dir=save_path,
        netG=netG,
        dataset=dataset,
        num_fake_samples=50000,
        evaluate_step=args.netG_ckpt_step,
        num_runs=1,
        device=device,
        stats_file=None,
        name=f'low_{args.resample_score}', )
def main():
    """Train a GAN with ``LogTrainer`` using options read from the command line.

    Steps, in order: parse the CLI, pin ``CUDA_VISIBLE_DEVICES``, create the
    experiment directory ``<work_dir>/<exp_name>``, seed the RNGs, build the
    networks/optimizers and the training dataloader, then run the trainer.

    For the ``celeba`` and ``cifar10`` datasets the schedule options
    (``num_steps``, ``logit_save_steps``, ``save_logit_after``,
    ``stop_save_logit_after``) are replaced by fixed presets, regardless of
    what was passed on the command line.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", "-d", type=str, default="cifar10")
    cli.add_argument("--root", "-r", type=str, default="./dataset/cifar10",
                     help="dataset dir")
    cli.add_argument("--work_dir", type=str, default="./exp_results",
                     help="output dir")
    cli.add_argument("--exp_name", type=str, default="cifar10",
                     help="exp name")
    cli.add_argument("--model", type=str, default="sngan",
                     help="network model")
    cli.add_argument("--loss_type", type=str, default="hinge",
                     help="loss type")
    cli.add_argument('--gpu', type=str, default='0',
                     help='id(s) for CUDA_VISIBLE_DEVICES')
    cli.add_argument('--num_pack', type=int, default=1)
    cli.add_argument('--batch_size', type=int, default=64)
    cli.add_argument('--seed', type=int, default=1)
    cli.add_argument('--download_dataset', action='store_true')
    cli.add_argument('--topk', action='store_true')
    cli.add_argument('--num_steps', type=int, default=100000)
    cli.add_argument('--logit_save_steps', type=int, default=100)
    cli.add_argument('--decay', type=str, default='linear')
    cli.add_argument('--n_dis', type=int, default=5)
    cli.add_argument('--imb_factor', type=float, default=0.1)
    cli.add_argument('--celeba_class_attr', type=str, default='glass')
    cli.add_argument('--ckpt_step', type=int)
    cli.add_argument('--no_save_logits', action='store_true')
    cli.add_argument('--save_logit_after', type=int, default=30000)
    cli.add_argument('--stop_save_logit_after', type=int, default=60000)
    opts = cli.parse_args()

    # Restrict the visible GPUs before any CUDA query is made below.
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu

    # Experiment directory; also handed to the trainer as its log dir.
    output_dir = f'{opts.work_dir}/{opts.exp_name}'
    save_path = Path(output_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    set_seed(opts.seed)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cuda":
        cudnn.benchmark = True

    netG, netD, optG, optD = get_gan_model(
        dataset_name=opts.dataset,
        model=opts.model,
        loss_type=opts.loss_type,
        topk=opts.topk,
    )
    print_num_params(netG, netD)

    ds_train = get_predefined_dataset(dataset_name=opts.dataset,
                                      root=opts.root)
    dl_train = get_dataloader(ds_train, batch_size=opts.batch_size)

    # Fixed per-dataset schedules; these overwrite the parsed CLI values.
    schedule_presets = {
        'celeba': {
            'num_steps': 75000,
            'logit_save_steps': 100,
            'save_logit_after': 55000,
            'stop_save_logit_after': 60000,
        },
        'cifar10': {
            'num_steps': 50000,
            'logit_save_steps': 100,
            'save_logit_after': 35000,
            'stop_save_logit_after': 40000,
        },
    }
    for option, value in schedule_presets.get(opts.dataset, {}).items():
        setattr(opts, option, value)

    print(opts)

    # Resume from saved checkpoints only when a (non-zero) step was given.
    netG_ckpt_file = None
    netD_ckpt_file = None
    if opts.ckpt_step:
        netG_ckpt_file = save_path / f'checkpoints/netG/netG_{opts.ckpt_step}_steps.pth'
        netD_ckpt_file = save_path / f'checkpoints/netD/netD_{opts.ckpt_step}_steps.pth'

    # Start training
    trainer = LogTrainer(
        netG=netG,
        netD=netD,
        optG=optG,
        optD=optD,
        netG_ckpt_file=netG_ckpt_file,
        netD_ckpt_file=netD_ckpt_file,
        dataloader=dl_train,
        device=device,
        output_path=save_path,
        log_dir=output_dir,
        n_dis=opts.n_dis,
        num_steps=opts.num_steps,
        lr_decay=opts.decay,
        topk=opts.topk,
        save_steps=1000,
        print_steps=10,
        logit_save_steps=opts.logit_save_steps,
        save_logits=not opts.no_save_logits,
        save_logit_after=opts.save_logit_after,
        stop_save_logit_after=opts.stop_save_logit_after,
    )
    trainer.train()
def main():
    """Count generated samples that a binary attribute classifier flags.

    Restores a generator (and, with ``--drs``, a discriminator that is wrapped
    together with the generator in ``DRS``) from
    ``<work_dir>/<exp_name>/checkpoints``, loads a two-class classifier for
    ``--attr`` from ``./convnet_celeba/<attr>.pth``, generates
    ``--num_samples`` images and counts how many receive a non-zero argmax
    label. The two counts are printed and appended as one row to
    ``<work_dir>/<exp_name>/evaluate/step-<step>/count_attribute.csv``
    (a header row is written first when the file does not exist yet).

    Raises:
        SystemExit: via ``parser.error`` when ``--netG_ckpt_step`` is missing
            or ``--batch_size`` is not positive.
        ValueError: when ``--classifier`` is not one of the supported names.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--work_dir",
                        default="./exp_results",
                        type=str,
                        help="output dir")
    parser.add_argument("--exp_name",
                        default="mimicry_pretrained-seed1",
                        type=str,
                        help="exp name")
    parser.add_argument("--model",
                        default="sngan",
                        type=str,
                        help="network model")
    parser.add_argument("--loss_type",
                        default="hinge",
                        type=str,
                        help="loss type")
    parser.add_argument("--classifier",
                        default="vgg16",
                        type=str,
                        help="calssifier network model")
    parser.add_argument('--gpu',
                        default='0',
                        type=str,
                        help='id(s) for CUDA_VISIBLE_DEVICES')
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--seed', default=1, type=int)
    parser.add_argument("--netG_ckpt_step", type=int)
    parser.add_argument("--netG_train_mode", action='store_true')
    parser.add_argument("--use_original_netD", action='store_true')
    parser.add_argument('--attr', default='Bald', type=str)
    parser.add_argument('--drs', action='store_true')
    parser.add_argument('--num_samples', default=50000, type=int)
    args = parser.parse_args()

    # Must be set before the first CUDA query below.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    set_seed(args.seed)
    print(args)

    save_path = f'{args.work_dir}/{args.exp_name}'

    if torch.cuda.is_available():
        device = "cuda"
        cudnn.benchmark = True
    else:
        device = "cpu"

    # Validate explicitly instead of with `assert`, which is stripped under -O
    # and would let a "None" step leak into the checkpoint paths.
    if not args.netG_ckpt_step:
        parser.error('--netG_ckpt_step is required')
    if args.batch_size < 1:
        parser.error('--batch_size must be a positive integer')

    # load model
    print(f'load model from {save_path} step: {args.netG_ckpt_step}')
    netD_drs = None
    if args.drs:
        netG, _, netD_drs, _, _, _ = get_gan_model(dataset_name='celeba',
                                                   model=args.model,
                                                   loss_type=args.loss_type,
                                                   drs=True)
    else:
        netG, _, _, _ = get_gan_model(
            dataset_name='celeba',
            model=args.model,
            loss_type=args.loss_type,
        )

    # Device placement is independent of train/eval mode: the discriminator
    # has to live next to the generator even when --netG_train_mode is given.
    netG.to(device)
    if netD_drs is not None:
        netD_drs.to(device)
    if not args.netG_train_mode:
        netG.eval()
        if netD_drs is not None:
            netD_drs.eval()

    gan_ckpt = f'{save_path}/checkpoints/netG/netG_{args.netG_ckpt_step}_steps.pth'
    print(gan_ckpt)
    netG.restore_checkpoint(ckpt_file=gan_ckpt)
    if netD_drs is not None:
        if args.use_original_netD:
            netD_drs_ckpt = f'{save_path}/checkpoints/netD/netD_{args.netG_ckpt_step}_steps.pth'
        else:
            netD_drs_ckpt = f'{save_path}/checkpoints/netD_drs/netD_drs_{args.netG_ckpt_step}_steps.pth'
        netD_drs.restore_checkpoint(ckpt_file=netD_drs_ckpt)
        # From here on `netG` is the DRS wrapper; it is only used through
        # `generate_images` below.
        netG = DRS(netG=netG, netD=netD_drs, device=device)

    # load classifier
    # The final layer is swapped for a 2-way head. Its location depends on the
    # architecture: VGG keeps it at `classifier[6]`, ResNet/Inception at `fc`.
    print('Load classifier')
    if args.classifier == 'vgg16':
        model = models.vgg16(pretrained=True)
        in_features = model.classifier[6].in_features
        model.classifier[6] = nn.Linear(in_features, 2, bias=True)
    elif args.classifier == 'resnet18':
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2, bias=True)
    elif args.classifier == 'inception':
        model = models.inception_v3(pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, 2, bias=True)
        # NOTE(review): the auxiliary head (AuxLogits.fc) is left untouched;
        # confirm this matches how the inception checkpoint was saved.
    else:
        raise ValueError('model should be vgg16 or resnet18 or inception')

    classifier_path = './convnet_celeba'
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
    model.load_state_dict(
        torch.load(os.path.join(classifier_path, f'{args.attr}.pth'),
                   map_location=device))
    model.to(device)
    # Inference only: disable dropout / freeze batch-norm statistics so the
    # predicted labels are deterministic. `no_grad` alone does not do this.
    model.eval()

    attr_num = 0
    not_attr_num = 0
    # Walk through the requested sample count in chunks so that a remainder
    # (num_samples not divisible by batch_size) is evaluated too.
    remaining = args.num_samples
    with torch.no_grad():
        while remaining > 0:
            current_batch = min(args.batch_size, remaining)
            img = netG.generate_images(current_batch, device=device)
            labels = model(img)
            answers = torch.argmax(labels, dim=1)
            # Any non-zero predicted class index is counted as "attr".
            attr = torch.count_nonzero(answers).item()
            attr_num += attr
            not_attr_num += current_batch - attr
            remaining -= current_batch

    print(f'attr: {attr_num}')
    print(f'not attr: {not_attr_num}')

    output_dir = os.path.join(save_path, 'evaluate',
                              f'step-{args.netG_ckpt_step}')
    os.makedirs(output_dir, exist_ok=True)

    output_file = os.path.join(output_dir, 'count_attribute.csv')
    # Append in both cases; only a freshly created file gets the header row.
    needs_header = not os.path.exists(output_file)
    with open(output_file, 'a', newline='') as f:
        wr = csv.writer(f)
        if needs_header:
            wr.writerow(['', 'attr', 'not attr'])
        wr.writerow([args.attr, attr_num, not_attr_num])