def __init__(
            self,
            config_name,  # name of experiment's config file
            model_path="",  # path to the model. empty string infers the most recent checkpoint
            clusterer_path="",  # path to the clusterer, ignored if gan type doesn't require a clusterer
            pretrained={},  # urls to the pretrained models
            rootdir='./',
            device='cuda:0'):
        self.config = load_config(os.path.join(rootdir, config_name),
                                  'configs/default.yaml')
        self.model_path = model_path
        self.clusterer_path = clusterer_path
        self.rootdir = rootdir
        self.nlabels = self.config['generator']['nlabels']
        self.device = device
        self.pretrained = pretrained

        self.generator = self.get_generator()
        self.generator.eval()
        self.yz_dist = self.get_yz_dist()
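A minimal usage sketch for a wrapper built around this `__init__`; the class name `GANLoader`, the config filename, and the assumption that `yz_dist(n)` returns a batch of label/latent samples are illustrative only, not taken from the original repository.

# Hypothetical usage; names and call signatures below are assumptions.
import torch

loader = GANLoader(config_name='configs/my_experiment.yaml',
                   rootdir='./',
                   device='cuda:0')
y, z = loader.yz_dist(16)            # sample 16 label/latent pairs (assumed interface)
with torch.no_grad():
    x_fake = loader.generator(z, y)  # generate a batch of images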
Example #2
import os

from gan_training.config import load_config  # config helper; module path assumed (as in Example #3)

# DATA = 'CELEBA'
seed_torch(999)  # repo helper that fixes the random seeds
DATA_FIX = 'CELEBA'
Num_epoch = 500 * 10000

DATA = 'Flowers'
NNN = 8000
image_path = './data/102flowers/'  # your image path
image_test = './data/102flowers/'  # your image path for calculating FID

main_path = './code_GAN_Memory/'
load_dir = './pretrained_model/'
out_path = os.path.join(main_path, 'results')

config_path = os.path.join(main_path, 'configs', 'Flowers_celeba.yaml')
config = load_config(config_path, 'configs/default.yaml')
config['data']['train_dir'] = image_path
config['data']['test_dir'] = image_test

config['training']['out_dir'] = out_path
if not os.path.isdir(config['training']['out_dir']):
    os.makedirs(config['training']['out_dir'])

# Short hands
batch_size = config['training']['batch_size']
d_steps = config['training']['d_steps']
restart_every = config['training']['restart_every']
inception_every = config['training']['inception_every']
save_every = config['training']['save_every']
backup_every = config['training']['backup_every']
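For context, `load_config(path, default_path)` in these snippets loads the experiment YAML and overlays it on a default YAML. A minimal sketch of that behaviour, assuming a flat sections-of-dicts layout (not the repository's exact implementation):

import yaml

def load_config_sketch(path, default_path=None):
    # Load the experiment config.
    with open(path) as f:
        config = yaml.safe_load(f)
    if default_path is None:
        return config
    # Overlay the experiment values on the defaults, section by section.
    with open(default_path) as f:
        merged = yaml.safe_load(f)
    for section, values in config.items():
        if isinstance(values, dict):
            merged.setdefault(section, {}).update(values)
        else:
            merged[section] = values
    return merged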
Example #3
import argparse
import os
from os import path

import torch

from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from gan_training.config import (
    load_config, build_models, build_optimizers, build_lr_scheduler,
)

# Arguments
parser = argparse.ArgumentParser(
    description='Train a GAN with different regularization strategies.'
)
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')

args = parser.parse_args()

config = load_config(args.config)
is_cuda = (torch.cuda.is_available() and not args.no_cuda)

# Short hands
batch_size = config['training']['batch_size']
d_steps = config['training']['d_steps']
restart_every = config['training']['restart_every']
inception_every = config['training']['inception_every']
save_every = config['training']['save_every']
backup_every = config['training']['backup_every']
sample_nlabels = config['training']['sample_nlabels']

out_dir = config['training']['out_dir']
checkpoint_dir = path.join(out_dir, 'chkpts')

# Create missing directories
if not path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)
Example #4
import argparse
import os
from os import path

import torch

from utils.visualization import draw_box_batch
from gan_training.config import load_config, get_out_dir  # config helpers; module path assumed

if __name__ == '__main__':
    # Arguments
    parser = argparse.ArgumentParser(
        description='Test a trained 3D controllable GAN and create visualizations.')
    parser.add_argument('config', type=str, help='Path to config file.')
    parser.add_argument('--eval-attr',
                        type=str,
                        default='fid,rot,trans,cam',
                        help='Attributes to evaluate.')

    args = parser.parse_args()

    config = load_config(args.config, 'configs/default.yaml')
    eval_attr = args.eval_attr.split(',')
    is_cuda = torch.cuda.is_available()
    assert is_cuda, 'No GPU device detected!'

    # Shorthands
    nlabels = config['data']['nlabels']
    batch_size = config['test']['batch_size']
    sample_size = config['test']['sample_size']
    sample_nrow = config['test']['sample_nrow']

    out_dir = get_out_dir(config)
    checkpoint_dir = path.join(out_dir, 'chkpts')
    out_dir = path.join(out_dir, 'test')

    # Create missing directories
    if not path.exists(out_dir):
        os.makedirs(out_dir)
Example #5
            x_real = sampler.sample(BS)[0].detach().cpu()
            x_real = [x.detach().cpu() for x in x_real]
            samples.extend(x_real)
        samples = torch.stack(samples[:N], dim=0)
        return pt_to_np(samples)

root = './'
dirs = [root]  # assumed initialisation: seed the search with the root directory

while len(dirs) > 0:
    path = dirs.pop()
    if os.path.isdir(path):     # search down tree for config files
        for d1 in os.listdir(path):
            dirs.append(os.path.join(path, d1))
    else:
        if path.endswith('.yaml'):
            config = load_config(path, default_path='configs/default.yaml')
            outdir = config['training']['out_dir']

            if not os.path.exists(outdir) and config['pretrained'] == {}:
                print('Skipping', path, 'outdir', outdir)
                continue

            results_dir = os.path.join(outdir, 'results')
            checkpoint_dir = os.path.join(outdir, 'chkpts')
            os.makedirs(results_dir, exist_ok=True)

            fid_results, is_results, kl_results, nmodes_results, fsd_results, cluster_results = load_results(results_dir)

            checkpoint_files = os.listdir(checkpoint_dir) if os.path.exists(checkpoint_dir) else []
            if config['pretrained'] != {}:
                checkpoint_files = checkpoint_files + ['pretrained']
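The stack-based directory walk above can also be written with `os.walk`; a functionally equivalent sketch of the config-file search:

# Equivalent search for .yaml config files using os.walk instead of an explicit stack.
for dirpath, _, filenames in os.walk(root):
    for fname in filenames:
        if fname.endswith('.yaml'):
            config = load_config(os.path.join(dirpath, fname),
                                 default_path='configs/default.yaml')
            # ... continue as in the body of the loop above ...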
            
Example #6
import argparse

import numpy as np
import torch

from gan_training.config import load_config  # config helper; module path assumed (as in Example #3)

parser = argparse.ArgumentParser(
    description='Train a GAN with different regularization strategies.'
)
parser.add_argument('config', type=str, help='Path to config file.')
parser.add_argument('lr', type=float, help='learning rate')
parser.add_argument('tau', type=float, help='timescale separation')
parser.add_argument('alpha', type=float, help='RMSProp parameter')
parser.add_argument('type', type=str, help='dataset type')
parser.add_argument('out_dir', type=str, help='output directory')
parser.add_argument('num_iter', type=int, help='number of iterations')
parser.add_argument('reg_param', type=float, help='regularization parameter')
parser.add_argument('random_seed', type=int, help='random seed')
parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda.')
args = parser.parse_args()

config = load_config(args.config, None)

seed = args.random_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

is_cuda = (torch.cuda.is_available() and not args.no_cuda)

config['training']['lr_g'] = args.lr
config['training']['lr_d'] = args.lr*args.tau
config['training']['alpha'] = args.alpha
config['data']['type'] = args.type
config['training']['out_dir'] = args.out_dir
config['training']['reg_param'] = args.reg_param
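The positional arguments map one-to-one onto the config overrides above; an illustrative invocation (script name and hyperparameter values are placeholders):

# python train_tuned.py configs/cifar.yaml 1e-4 2.0 0.99 cifar10 ./output/run0 150000 10.0 42
# With these values the overrides give lr_g = 1e-4 and lr_d = 1e-4 * 2.0 = 2e-4.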
Example #7
import glob
import os

import torch

from gan_training.checkpoints import CheckpointIO  # repo helpers; module paths assumed (cf. Example #3)
from gan_training.config import load_config, build_models
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')  # not defined in the snippet; assumed

# total_inception = dict({})

# model_list.reverse()
all_results = dict({})
all_models = glob.glob("./output/Plot*")
print(len(all_models))
all_models.reverse()
for epoch_id in range(80):
    for model in all_models:
        model_name = "/home/kunxu/Workspace/GAN_PID/{}".format(model)
        key_name = model_name
        if key_name not in all_results:
            all_results[key_name] = []

        config = load_config(os.path.join(model_name, "config.yaml"),
                             'configs/default.yaml')
        generator, discriminator = build_models(config)
        generator = torch.nn.DataParallel(generator)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        ydist = get_ydist(1, device=device)
        checkpoint_io = CheckpointIO(checkpoint_dir="./tmp")
        checkpoint_io.register_modules(generator_test=generator)
        evaluator = Evaluator(generator,
                              zdist,
                              ydist,
                              batch_size=100,
                              device=device)

        ckptpath = os.path.join(
Example #8
import copy
import glob
import os
import pickle

import numpy as np
import torch
import torch.nn as nn
import torchvision

from gan_training.checkpoints import CheckpointIO  # repo helpers; module paths assumed (cf. Example #3)
from gan_training.config import load_config, build_models
from gan_training.distributions import get_ydist, get_zdist
from gan_training.eval import Evaluator
from pytorch_fid.fid_score import calculate_fid_given_paths  # FID helper; import path assumed


def perform_evaluation(run_name, image_type):

    out_dir = os.path.join(os.getcwd(), '..', 'output', run_name)
    checkpoint_dir = os.path.join(out_dir, 'chkpts')
    checkpoints = sorted(glob.glob(os.path.join(checkpoint_dir, '*')))
    evaluation_dict = {}

    for point in checkpoints:
        # Parse the iteration number from the checkpoint filename and only
        # evaluate checkpoints saved every 10000 iterations.
        iter_num = int(point.split('/')[-1].split('_')[1].split('.')[0])
        if iter_num % 10000 != 0:
            continue

        model_file = point.split('/')[-1]

        config = load_config('../configs/fr_default.yaml', None)
        is_cuda = (torch.cuda.is_available())
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)
        device = torch.device("cuda:0" if is_cuda else "cpu")

        generator, discriminator = build_models(config)

        # Put models on gpu if needed
        generator = generator.to(device)
        discriminator = discriminator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)
        discriminator = nn.DataParallel(discriminator)

        generator_test_9 = copy.deepcopy(generator)
        generator_test_99 = copy.deepcopy(generator)
        generator_test_999 = copy.deepcopy(generator)
        generator_test_9999 = copy.deepcopy(generator)

        # Register modules to checkpoint
        checkpoint_io.register_modules(
            generator=generator,
            generator_test_9=generator_test_9,
            generator_test_99=generator_test_99,
            generator_test_999=generator_test_999,
            generator_test_9999=generator_test_9999,
            discriminator=discriminator,
        )

        # Load checkpoint
        load_dict = checkpoint_io.load(model_file)

        # Distributions
        ydist = get_ydist(config['data']['nlabels'], device=device)
        zdist = get_zdist(config['z_dist']['type'],
                          config['z_dist']['dim'],
                          device=device)
        z_sample = torch.Tensor(np.load('z_data.npy')).to(device)

        #for name, model in zip(['0_', '09_', '099_', '0999_', '09999_'], [generator, generator_test_9, generator_test_99, generator_test_999, generator_test_9999]):
        for name, model in zip(
            ['099_', '0999_', '09999_'],
            [generator_test_99, generator_test_999, generator_test_9999]):

            # Evaluator
            evaluator = Evaluator(model, zdist, ydist, device=device)

            x_sample = []

            for i in range(10):
                x = evaluator.create_samples(z_sample[i * 1000:(i + 1) * 1000])
                x_sample.append(x)

            x_sample = torch.cat(x_sample)
            x_sample = x_sample / 2 + 0.5

            if not os.path.exists('fake_data'):
                os.makedirs('fake_data')

            for i in range(10000):
                torchvision.utils.save_image(x_sample[i, :, :, :],
                                             'fake_data/{}.png'.format(i))

            fid_score = calculate_fid_given_paths(
                ['fake_data', image_type + '_real'], 50, True, 2048)
            print(iter_num, name, fid_score)

            os.system("rm -rf " + "fake_data")

            evaluation_dict[(iter_num, name[:-1])] = {'FID': fid_score}

            if not os.path.exists('evaluation_data/' + run_name):
                os.makedirs('evaluation_data/' + run_name)

            pickle.dump(
                evaluation_dict,
                open('evaluation_data/' + run_name + '/eval_fid.p', 'wb'))
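A hypothetical call to the evaluation routine above; the run name and image type are placeholders and must match an existing `../output/<run_name>` folder and a `<image_type>_real` directory of real images for the FID computation:

if __name__ == '__main__':
    # Placeholder arguments; adjust to an actual run directory and real-image folder.
    perform_evaluation(run_name='celeba_baseline', image_type='celeba')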
Example #9
import argparse
import os
import random

import numpy as np
import torch

from gan_training.config import load_config  # config helper; module path assumed (as in Example #3)

# The parser construction and the '--max-iter' flag name below are assumed
# (not given in this snippet); further arguments used later (config_dir, config,
# output_dir, name, nf, bs, reg_param, w_info, mi) are defined in the full script.
parser = argparse.ArgumentParser(description='Train a GAN.')
parser.add_argument('--max-iter',
                    default=-1,
                    type=int,
                    help='Max training iteration')
parser.add_argument('--no-cuda', action='store_true', help='Do not use cuda')
parser.add_argument('--seed', default=1, type=int, help='Random Seed')

args = parser.parse_args()

seed = args.seed
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

config_path = os.path.join(args.config_dir, args.config)
config = load_config(config_path)
is_cuda = (torch.cuda.is_available() and not args.no_cuda)

# = = = = = Customized Configurations = = = = = #
out_dir = os.path.join(args.output_dir, args.name)
config['training']['out_dir'] = out_dir
if args.nf > 0:
    config['generator']['kwargs']['nfilter'] = args.nf
    config['discriminator']['kwargs']['nfilter'] = args.nf
if args.bs > 0:
    config['training']['batch_size'] = args.bs
if args.reg_param > 0:
    config['training']['reg_param'] = args.reg_param
if args.w_info > 0:
    config['training']['w_info'] = args.w_info
if args.mi > 0: