def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
    samples_dir=None,
    num_inception_samples=2048,
):
    """Run the TPU evaluation worker for a trained model.

    The training hyperparameters are reloaded from ``model_dir`` with
    ``tpu_utils.load_train_kwargs``, so the evaluated network is rebuilt with
    the same model name, beta schedule, loss type, dropout, randflip and
    block size that were recorded for training.

    Args:
        model_dir: Directory of the training run; also passed to the worker
            as its ``logdir``.
        tpu_name: TPU to run on, forwarded to ``tpu_utils.EvalWorker``.
        bucket_name_prefix: GCS bucket name; ``tfr_file`` is resolved inside
            ``gs://<bucket_name_prefix>/``.
        once: Forwarded to ``EvalWorker.run``.
        dump_samples_only: Forwarded to ``EvalWorker.run``.
        total_bs: Batch size, used both as ``total_bs`` and ``inception_bs``.
        tfr_file: TFRecord path relative to the bucket root.
        samples_dir: Forwarded to ``EvalWorker.run``.
        num_inception_samples: Forwarded to ``EvalWorker``.
    """
    # The GCP region is not appended to the bucket name here.
    dataset_path = 'gs://{bucket}/{path}'.format(bucket=bucket_name_prefix, path=tfr_file)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    ds = datasets.get_dataset(train_kwargs['dataset'], tfr_file=dataset_path)

    def build_model():
        # Rebuild the network from the recorded training configuration.
        name = train_kwargs['model_name']
        betas = get_beta_schedule(
            train_kwargs['beta_schedule'],
            beta_start=train_kwargs['beta_start'],
            beta_end=train_kwargs['beta_end'],
            num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
        )
        return Model(
            model_name=name,
            betas=betas,
            loss_type=train_kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=train_kwargs['dropout'],
            randflip=train_kwargs['randflip'],
            block_size=train_kwargs['block_size'],
        )

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=ds,
        # Cap the dataset size used for Inception features, for memory reasons.
        limit_dataset_size=30000,
    )
    worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        samples_dir=samples_dir,
    )
def train(exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2', dataset='cifar10',
          optimizer='adam', total_bs=128, grad_clip=1., lr=2e-4, warmup=5000,
          num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02,
          beta_schedule='linear', model_mean_type='eps', model_var_type='fixedlarge',
          loss_type='mse', dropout=0.1, randflip=1, tfds_data_dir='tensorflow_datasets',
          log_dir='logs', keep_checkpoint_max=2):
    """Launch TPU training of a diffusion model (defaults target CIFAR-10).

    ``tfds_data_dir`` and ``log_dir`` are resolved inside the region-specific
    bucket ``gs://<bucket_name_prefix>-<region>/``, where the region comes
    from ``utils.get_gcp_region()``. The experiment name is built from the
    hyperparameters. The full argument snapshot (including ``region``) is
    passed to ``run_training`` as ``dump_kwargs``.

    Args:
        exp_name: Prefix of the generated experiment name.
        tpu_name: TPU to train on, passed as ``tpu``.
        bucket_name_prefix: Bucket name, suffixed with ``-<region>``.
        Remaining arguments: Model, diffusion and optimizer hyperparameters,
            forwarded to ``Model``/``get_beta_schedule``/``run_training``.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{bucket}-{region}/{path}'.format(
        bucket=bucket_name_prefix, region=region, path=tfds_data_dir)
    log_dir = 'gs://{bucket}-{region}/{path}'.format(
        bucket=bucket_name_prefix, region=region, path=log_dir)
    # Snapshot of every argument (with resolved GCS paths) plus `region`. It
    # names the run and is dumped with it, so no other local may be bound
    # above this line.
    kwargs = dict(locals())
    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

    def build_model():
        # `ds.num_classes` is read lazily, when the trainer builds the model.
        return Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule,
                beta_start=beta_start,
                beta_end=beta_end,
                num_diffusion_timesteps=num_diffusion_timesteps,
            ),
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
        )

    run_name = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{model_mean_type}-{model_var_type}-{loss_type}_dropout{dropout}_randflip{randflip}'.format(**kwargs)

    tpu_utils.run_training(
        date_str='9999-99-99',  # fixed placeholder date string
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        iterations_per_loop=2000,
        keep_checkpoint_max=keep_checkpoint_max,
    )
def _load_model(kwargs, ds):
    """Build a ``Model`` from a dict of recorded training hyperparameters.

    Args:
        kwargs: Mapping with at least ``model_name``, ``beta_schedule``,
            ``beta_start``, ``beta_end``, ``num_diffusion_timesteps``,
            ``model_mean_type``, ``model_var_type``, ``loss_type``,
            ``dropout`` and ``randflip`` keys.
        ds: Dataset object; only its ``num_classes`` attribute is used.

    Returns:
        The constructed ``Model`` instance.
    """
    name = kwargs['model_name']
    schedule = get_beta_schedule(
        kwargs['beta_schedule'],
        beta_start=kwargs['beta_start'],
        beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps'],
    )
    mean_type = kwargs['model_mean_type']
    var_type = kwargs['model_var_type']
    loss = kwargs['loss_type']
    n_classes = ds.num_classes
    drop = kwargs['dropout']
    flip = kwargs['randflip']
    return Model(
        model_name=name,
        betas=schedule,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss,
        num_classes=n_classes,
        dropout=drop,
        randflip=flip,
    )
def train(
    exp_name,
    tpu_name,
    bucket_name_prefix,
    model_name='unet2d16b2c112244',
    dataset='lsun',
    optimizer='adam',
    total_bs=64,
    grad_clip=1.,
    lr=2e-5,
    warmup=5000,
    num_diffusion_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule='linear',
    loss_type='noisepred',
    dropout=0.0,
    randflip=1,
    block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords',
    log_dir='logs',
    warm_start_model_dir=None,
):
    """Launch TPU training of a diffusion model from a TFRecord dataset.

    ``tfr_file`` and ``log_dir`` are resolved inside ``gs://<bucket_name_prefix>/``
    (no GCP region suffix). The experiment name is built from the
    hyperparameters, and the argument snapshot is passed to ``run_training``
    as ``dump_kwargs``.

    Args:
        exp_name: Prefix of the generated experiment name.
        tpu_name: TPU to train on, passed as ``tpu``.
        bucket_name_prefix: GCS bucket name.
        warm_start_model_dir: If truthy, every variable (pattern ``".*"``) is
            warm-started from ``tf.train.latest_checkpoint`` of this directory.
        Remaining arguments: Model, diffusion and optimizer hyperparameters,
            forwarded to ``Model``/``get_beta_schedule``/``run_training``.
    """
    tfr_file = 'gs://{bucket}/{path}'.format(bucket=bucket_name_prefix, path=tfr_file)
    log_dir = 'gs://{bucket}/{path}'.format(bucket=bucket_name_prefix, path=log_dir)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)
    # Snapshot of the arguments (GCS paths already resolved). It names the
    # run and is dumped with it, so no other local may be bound above here.
    kwargs = dict(locals())
    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

    def build_model():
        # `ds.num_classes` is read lazily, when the trainer builds the model.
        return Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule,
                beta_start=beta_start,
                beta_end=beta_end,
                num_diffusion_timesteps=num_diffusion_timesteps,
            ),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size,
        )

    run_name = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(**kwargs)

    if warm_start_model_dir:
        warm_start = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=tf.train.latest_checkpoint(warm_start_model_dir),
            vars_to_warm_start=[".*"],
        )
    else:
        warm_start = None

    tpu_utils.run_training(
        date_str='9999-99-99',  # fixed placeholder date string
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        warm_start_from=warm_start,
    )
def sample_tf(bs=1, nb=1, which=None):
    """Draw samples from the released TF diffusion checkpoints and save PNGs.

    For every requested model this builds the sampling graph, restores the
    checkpoint listed in ``ckpts`` and runs ``nb`` batches of ``bs`` samples.
    Sample ``k`` of a model is written to ``results/tf_<name>/<k:06>.png``.

    Args:
        bs: Number of samples per batch.
        nb: Number of batches drawn per model.
        which: A model name or an iterable of names. Defaults to
            ``["cifar10", "lsun_bedroom", "lsun_cat", "lsun_church"]``. A name
            may carry an ``"ema_"`` prefix (e.g. ``"ema_cifar10"``). In that
            case an EMA is created with ``make_ema``, and sampling uses
            ``model.progressive_samples_fn`` inside ``utils.ema_scope``.

    Raises:
        ValueError: If a requested name (after stripping ``"ema_"``) is not a
            known model. This is checked before any model is built.
    """
    import tqdm
    # add diffusion/scripts to PYTHONPATH, too
    from diffusion_tf.tpu_utils.tpu_utils import make_ema
    import diffusion_tf.utils as utils
    from scripts.run_cifar import Model as cifar10_model
    from scripts.run_lsun import Model as lsun_model
    from diffusion_tf.diffusion_utils_2 import get_beta_schedule
    import PIL.Image

    ckpts = {
        "cifar10": "diffusion_models_release/diffusion_cifar10_model/model.ckpt-790000",
        "lsun_bedroom": "diffusion_models_release/diffusion_lsun_bedroom_model/model.ckpt-2388000",
        "lsun_cat": "diffusion_models_release/diffusion_lsun_cat_model/model.ckpt-1761000",
        "lsun_church": "diffusion_models_release/diffusion_lsun_church_model/model.ckpt-4432000",
    }
    models = {
        "cifar10": cifar10_model,
        "lsun_bedroom": lsun_model,
        "lsun_cat": lsun_model,
        "lsun_church": lsun_model,
    }
    betas = get_beta_schedule(beta_schedule="linear", beta_start=0.0001,
                              beta_end=0.02, num_diffusion_timesteps=1000)
    cifar10_config = {
        "model_name": "unet2d16b2",
        "model_mean_type": "eps",
        "model_var_type": "fixedlarge",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
    }
    lsun_config = {
        "model_name": "unet2d16b2c112244",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
        "block_size": 1,
    }
    model_configs = {
        "cifar10": cifar10_config,
        "lsun_bedroom": lsun_config,
        "lsun_cat": lsun_config,
        "lsun_church": lsun_config,
    }
    img_shapes = {
        "cifar10": (32, 32, 3),
        "lsun_bedroom": (256, 256, 3),
        "lsun_cat": (256, 256, 3),
        "lsun_church": (256, 256, 3),
    }

    def split_ema(name):
        # "ema_<base>" -> (True, "<base>"); any other name -> (False, name).
        if name.startswith("ema_"):
            return True, name[len("ema_"):]
        return False, name

    if which is None:
        which = ["cifar10", "lsun_bedroom", "lsun_cat", "lsun_church"]
    elif isinstance(which, str):
        which = [which]
    else:
        # Materialize so the validation pass below does not exhaust an iterator.
        which = list(which)

    # Fail fast on typos instead of after earlier models have been sampled.
    unknown = [name for name in which if split_ema(name)[1] not in ckpts]
    if unknown:
        raise ValueError(
            "Unknown model name(s) {}; expected one of {} "
            "(optionally prefixed with 'ema_')".format(unknown, sorted(ckpts)))

    for name in which:
        out_dir = "results/tf_{}".format(name)
        os.makedirs(out_dir, exist_ok=True)
        print("Writing tf samples in {}".format(out_dir))
        ema, basename = split_ema(name)
        with tf.Session() as sess:
            print("Loading {} model".format(name))
            model = models[basename](**model_configs[basename])
            # build graph; x_ is a NaN-filled input of the target shape
            # (presumably only its shape is used by the sampler -- verify).
            x_ = tf.fill([bs, *img_shapes[basename]], value=np.nan)
            y = tf.fill([bs], value=0)
            # This call must come first even in the EMA case: it presumably
            # creates the variables that tf.trainable_variables() hands to
            # make_ema below.
            sample = model.samples_fn(x_, y)
            global_step = tf.train.get_or_create_global_step()
            if ema:
                ema_, _ = make_ema(
                    global_step=global_step,
                    ema_decay=1e-10,
                    trainable_variables=tf.trainable_variables())
                with utils.ema_scope(ema_):
                    print('===== EMA SAMPLES =====')
                    sample = model.progressive_samples_fn(x_, y)
            # load ckpt
            ckpt = ckpts[basename]
            print('restoring')
            saver = tf.train.Saver()
            saver.restore(sess, ckpt)
            global_step_val = sess.run(global_step)
            print('restored global step: {}'.format(global_step_val))
            for ib in tqdm.tqdm(range(nb), desc="Batch"):
                samples = sess.run(sample)["samples"]
                # (x + 1) * 127.5 maps [-1, 1] onto [0, 255]. Clip before the
                # cast so any value outside that range saturates instead of
                # wrapping around in uint8.
                images = np.clip((samples + 1.0) * 127.5, 0, 255).astype(np.uint8)
                for i, image in enumerate(images):
                    PIL.Image.fromarray(image).save(
                        "results/tf_{}/{:06}.png".format(name, ib * bs + i))
        # Must run outside the Session context: a fresh graph for the next model.
        tf.reset_default_graph()
def sample_tf(bs=1):
    """Draw one batch of ``bs`` samples from each released TF diffusion model.

    For each of ``cifar10``, ``lsun_bedroom``, ``lsun_cat`` and ``lsun_church``
    this builds the sampling graph, restores the checkpoint listed in
    ``ckpts`` and writes sample ``i`` to ``results/tf_<name>/<i:06>.png``.

    Args:
        bs: Number of samples drawn per model.
    """
    # add diffusion/scripts to PYTHONPATH, too
    from scripts.run_cifar import Model as cifar10_model
    from scripts.run_lsun import Model as lsun_model
    from diffusion_tf.diffusion_utils_2 import get_beta_schedule
    import PIL.Image

    ckpts = {
        "cifar10": "diffusion_models_release/diffusion_cifar10_model/model.ckpt-790000",
        "lsun_bedroom": "diffusion_models_release/diffusion_lsun_bedroom_model/model.ckpt-2388000",
        "lsun_cat": "diffusion_models_release/diffusion_lsun_cat_model/model.ckpt-1761000",
        "lsun_church": "diffusion_models_release/diffusion_lsun_church_model/model.ckpt-4432000",
    }
    models = {
        "cifar10": cifar10_model,
        "lsun_bedroom": lsun_model,
        "lsun_cat": lsun_model,
        "lsun_church": lsun_model,
    }
    betas = get_beta_schedule(beta_schedule="linear", beta_start=0.0001,
                              beta_end=0.02, num_diffusion_timesteps=1000)
    cifar10_config = {
        "model_name": "unet2d16b2",
        "model_mean_type": "eps",
        "model_var_type": "fixedlarge",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
    }
    lsun_config = {
        "model_name": "unet2d16b2c112244",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
        "block_size": 1,
    }
    model_configs = {
        "cifar10": cifar10_config,
        "lsun_bedroom": lsun_config,
        "lsun_cat": lsun_config,
        "lsun_church": lsun_config,
    }
    img_shapes = {
        "cifar10": (32, 32, 3),
        "lsun_bedroom": (256, 256, 3),
        "lsun_cat": (256, 256, 3),
        "lsun_church": (256, 256, 3),
    }

    for name in ["cifar10", "lsun_bedroom", "lsun_cat", "lsun_church"]:
        os.makedirs("results/tf_{}".format(name), exist_ok=True)
        with tf.Session() as sess:
            print("Loading {} model".format(name))
            model = models[name](**model_configs[name])
            # build graph; x_ is a NaN-filled input of the target shape
            # (presumably only its shape is used by the sampler -- verify).
            x_ = tf.fill([bs, *img_shapes[name]], value=np.nan)
            y = tf.fill([bs], value=0)
            sample_op = model.samples_fn(x_, y)
            global_step = tf.train.get_or_create_global_step()
            # load ckpt
            ckpt = ckpts[name]
            print('restoring')
            saver = tf.train.Saver()
            saver.restore(sess, ckpt)
            global_step_val = sess.run(global_step)
            print('restored global step: {}'.format(global_step_val))
            # test sampling
            samples = sess.run(sample_op)["samples"]
            # (x + 1) * 127.5 maps [-1, 1] onto [0, 255]. Clip before the cast
            # so any value outside that range saturates instead of wrapping
            # around in uint8.
            images = np.clip((samples + 1.0) * 127.5, 0, 255).astype(np.uint8)
            for i, image in enumerate(images):
                PIL.Image.fromarray(image).save("results/tf_{}/{:06}.png".format(name, i))
        # Must run outside the Session context: a fresh graph for the next model.
        tf.reset_default_graph()