Ejemplo n.º 1
0
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
):
  """Evaluate (or only dump samples from) a trained LSUN diffusion model on a TPU.

  The training arguments stored for the run in ``model_dir`` are reloaded via
  ``tpu_utils.load_train_kwargs`` and used to rebuild the dataset and the model;
  a ``tpu_utils.EvalWorker`` then performs the evaluation.

  Args:
    model_dir: Directory of the training run; used to load the saved kwargs and
      passed to ``EvalWorker.run`` as ``logdir``.
    tpu_name: Name of the TPU to run on.
    bucket_name_prefix: GCS bucket; ``tfr_file`` is resolved as
      ``gs://<bucket_name_prefix>/<tfr_file>``.
    once: Forwarded to ``EvalWorker.run``.
    dump_samples_only: Forwarded to ``EvalWorker.run``.
    total_bs: Used as both ``total_bs`` and ``inception_bs`` of the worker.
    tfr_file: TFRecord path relative to the bucket. NOTE(review): this default
      lacks the ``church/`` sub-directory that the LSUN training default has —
      confirm which layout the bucket actually uses.
    samples_dir: Forwarded to ``EvalWorker.run``.
    num_inception_samples: Forwarded to ``EvalWorker``.
  """
  # Paths are resolved directly under the bucket; the GCP-region suffix lookup
  # (utils.get_gcp_region) is not used here.
  dataset_path = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)
  train_kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', train_kwargs)
  ds = datasets.get_dataset(train_kwargs['dataset'], tfr_file=dataset_path)

  def build_model():
    # Called lazily by the worker; rebuilds the model from the saved training kwargs.
    model_name = train_kwargs['model_name']
    betas = get_beta_schedule(
      train_kwargs['beta_schedule'],
      beta_start=train_kwargs['beta_start'],
      beta_end=train_kwargs['beta_end'],
      num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
    )
    return Model(
      model_name=model_name,
      betas=betas,
      loss_type=train_kwargs['loss_type'],
      num_classes=ds.num_classes,
      dropout=train_kwargs['dropout'],
      randflip=train_kwargs['randflip'],
      block_size=train_kwargs['block_size'],
    )

  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=build_model,
    total_bs=total_bs,
    inception_bs=total_bs,
    num_inception_samples=num_inception_samples,
    dataset=ds,
    # Cap the dataset size used for computing Inception features, for memory reasons.
    limit_dataset_size=30000,
  )
  worker.run(
    logdir=model_dir,
    once=once,
    skip_non_ema_pass=True,
    dump_samples_only=dump_samples_only,
    samples_dir=samples_dir,
  )
Ejemplo n.º 2
0
def train(exp_name, tpu_name, bucket_name_prefix,
          model_name='unet2d16b2', dataset='cifar10', optimizer='adam',
          total_bs=128, grad_clip=1., lr=2e-4, warmup=5000,
          num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02,
          beta_schedule='linear', model_mean_type='eps',
          model_var_type='fixedlarge', loss_type='mse', dropout=0.1,
          randflip=1, tfds_data_dir='tensorflow_datasets', log_dir='logs',
          keep_checkpoint_max=2):
    """Train a diffusion model (CIFAR-10 defaults) on a TPU.

    ``tfds_data_dir`` and ``log_dir`` are rewritten to
    ``gs://<bucket_name_prefix>-<region>/<path>``, with the region taken from
    ``utils.get_gcp_region()``. A snapshot of every argument (plus ``region``)
    builds the experiment name and is passed to ``tpu_utils.run_training`` as
    ``dump_kwargs``.

    Args:
      exp_name: Prefix of the experiment name; the full name also encodes the
        main hyperparameters.
      tpu_name: Name of the TPU to train on.
      bucket_name_prefix: Bucket name prefix; ``-<region>`` is appended to it.
      model_name, num_diffusion_timesteps, beta_start, beta_end,
        beta_schedule, model_mean_type, model_var_type, loss_type, dropout,
        randflip: Forwarded to ``Model`` / ``get_beta_schedule``.
      optimizer, total_bs, lr, warmup, grad_clip, keep_checkpoint_max:
        Forwarded to ``tpu_utils.run_training``.
      dataset: Dataset name passed to ``datasets.get_dataset``.
      tfds_data_dir: TFDS directory relative to the bucket.
      log_dir: Log directory relative to the bucket.
    """
    # Note: `region` is a local, so the locals() snapshot below records it too.
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    # Snapshot must be taken before any other local is bound, or it would be
    # dumped as well.
    hparams = dict(locals())

    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
    run_name = (
        '{exp_name}_{dataset}_{model_name}_{optimizer}'
        '_bs{total_bs}_lr{lr}w{warmup}'
        '_beta{beta_start}-{beta_end}-{beta_schedule}'
        '_t{num_diffusion_timesteps}'
        '_{model_mean_type}-{model_var_type}-{loss_type}'
        '_dropout{dropout}_randflip{randflip}'
    ).format(**hparams)

    def build_model():
        # Invoked by run_training; closes over the (post-rewrite) arguments.
        return Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule,
                beta_start=beta_start,
                beta_end=beta_end,
                num_diffusion_timesteps=num_diffusion_timesteps),
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip)

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=hparams,
        iterations_per_loop=2000,
        keep_checkpoint_max=keep_checkpoint_max)
Ejemplo n.º 3
0
def _load_model(kwargs, ds):
    """Rebuild a ``Model`` from a dict of saved training arguments.

    Args:
      kwargs: Mapping that must contain ``model_name``, ``beta_schedule``,
        ``beta_start``, ``beta_end``, ``num_diffusion_timesteps``,
        ``model_mean_type``, ``model_var_type``, ``loss_type``, ``dropout``
        and ``randflip``; a missing key raises ``KeyError``.
      ds: Dataset object; only its ``num_classes`` attribute is read.

    Returns:
      The ``Model`` built from those arguments.
    """
    # Lookups happen in the same order as the Model(...) keyword arguments.
    model_name = kwargs['model_name']
    schedule_name = kwargs['beta_schedule']
    schedule_args = {
        key: kwargs[key]
        for key in ('beta_start', 'beta_end', 'num_diffusion_timesteps')
    }
    betas = get_beta_schedule(schedule_name, **schedule_args)
    return Model(
        model_name=model_name,
        betas=betas,
        model_mean_type=kwargs['model_mean_type'],
        model_var_type=kwargs['model_var_type'],
        loss_type=kwargs['loss_type'],
        num_classes=ds.num_classes,
        dropout=kwargs['dropout'],
        randflip=kwargs['randflip'],
    )
Ejemplo n.º 4
0
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
    optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
    warm_start_model_dir=None
):
  """Train an LSUN diffusion model on a TPU.

  ``tfr_file`` and ``log_dir`` are resolved under ``gs://<bucket_name_prefix>/``.
  A snapshot of all arguments (with the resolved paths) builds the experiment
  name and is passed to ``tpu_utils.run_training`` as ``dump_kwargs``.

  Args:
    exp_name: Prefix of the experiment name; the full name also encodes the
      main hyperparameters.
    tpu_name: Name of the TPU to train on.
    bucket_name_prefix: GCS bucket holding the data and the logs.
    model_name, num_diffusion_timesteps, beta_start, beta_end, beta_schedule,
      loss_type, dropout, randflip, block_size: Forwarded to ``Model`` /
      ``get_beta_schedule``.
    dataset: Dataset name passed to ``datasets.get_dataset``.
    optimizer, total_bs, lr, warmup, grad_clip: Forwarded to
      ``tpu_utils.run_training``.
    tfr_file: TFRecord path relative to the bucket.
    log_dir: Log directory relative to the bucket.
    warm_start_model_dir: Optional directory whose latest checkpoint is used to
      warm-start all variables (``vars_to_warm_start=[".*"]``).

  Raises:
    ValueError: If ``warm_start_model_dir`` is given but contains no checkpoint.
  """
  # Paths are resolved directly under the bucket (no GCP-region suffix).
  tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)
  log_dir = 'gs://{}/{}'.format(bucket_name_prefix, log_dir)
  print("tfr_file:", tfr_file)
  print("log_dir:", log_dir)
  # Snapshot of the arguments. It must be taken before any other local is
  # bound; otherwise that local would also land in dump_kwargs.
  kwargs = dict(locals())
  ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

  warm_start_from = None
  if warm_start_model_dir:
    # tf.train.latest_checkpoint returns None when the directory holds no
    # checkpoint; report that clearly instead of handing None to WarmStartSettings.
    warm_start_ckpt = tf.train.latest_checkpoint(warm_start_model_dir)
    if warm_start_ckpt is None:
      raise ValueError(
        'No checkpoint found in warm_start_model_dir: {}'.format(warm_start_model_dir))
    warm_start_from = tf.estimator.WarmStartSettings(
      ckpt_to_initialize_from=warm_start_ckpt,
      vars_to_warm_start=[".*"]
    )

  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
      **kwargs),
    model_constructor=lambda: Model(
      model_name=model_name,
      betas=get_beta_schedule(
        beta_schedule, beta_start=beta_start, beta_end=beta_end, num_diffusion_timesteps=num_diffusion_timesteps
      ),
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size
    ),
    optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs,
    warm_start_from=warm_start_from
  )
Ejemplo n.º 5
0
def sample_tf(bs=1, nb=1, which=None):
    """Sample images from the released TF diffusion checkpoints and save them as PNGs.

    For every requested model, the graph is built in a fresh ``tf.Session``,
    the matching checkpoint under ``diffusion_models_release/`` is restored,
    and ``nb`` batches of ``bs`` images are written to
    ``results/tf_<name>/<index:06>.png``.

    Args:
      bs: Number of images per sampling batch.
      nb: Number of batches drawn per model.
      which: A single model name or an iterable of names. Base names are
        "cifar10", "lsun_bedroom", "lsun_cat" and "lsun_church"; an "ema_"
        prefix samples through ``progressive_samples_fn`` under
        ``utils.ema_scope``. Defaults to the four base names.
    """
    import tqdm
    # add diffusion/scripts to PYTHONPATH, too
    from diffusion_tf.tpu_utils.tpu_utils import make_ema
    import diffusion_tf.utils as utils
    from scripts.run_cifar import Model as cifar10_model
    from scripts.run_lsun import Model as lsun_model
    from diffusion_tf.diffusion_utils_2 import get_beta_schedule
    import PIL.Image

    ckpts = {
        "cifar10":
        "diffusion_models_release/diffusion_cifar10_model/model.ckpt-790000",
        "lsun_bedroom":
        "diffusion_models_release/diffusion_lsun_bedroom_model/model.ckpt-2388000",
        "lsun_cat":
        "diffusion_models_release/diffusion_lsun_cat_model/model.ckpt-1761000",
        "lsun_church":
        "diffusion_models_release/diffusion_lsun_church_model/model.ckpt-4432000",
    }
    models = {
        "cifar10": cifar10_model,
        "lsun_bedroom": lsun_model,
        "lsun_cat": lsun_model,
        "lsun_church": lsun_model,
    }
    betas = get_beta_schedule(beta_schedule="linear",
                              beta_start=0.0001,
                              beta_end=0.02,
                              num_diffusion_timesteps=1000)
    cifar10_config = {
        "model_name": "unet2d16b2",
        "model_mean_type": "eps",
        "model_var_type": "fixedlarge",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
    }
    lsun_config = {
        "model_name": "unet2d16b2c112244",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
        "block_size": 1,
    }
    model_configs = {
        "cifar10": cifar10_config,
        "lsun_bedroom": lsun_config,
        "lsun_cat": lsun_config,
        "lsun_church": lsun_config,
    }
    img_shapes = {
        "cifar10": (32, 32, 3),
        "lsun_bedroom": (256, 256, 3),
        "lsun_cat": (256, 256, 3),
        "lsun_church": (256, 256, 3),
    }

    if which is None:
        which = ["cifar10", "lsun_bedroom", "lsun_cat", "lsun_church"]
    elif isinstance(which, str):
        # Accept a single model name as well as a collection of names.
        which = [which]

    for name in which:
        out_dir = "results/tf_{}".format(name)
        os.makedirs(out_dir, exist_ok=True)
        print("Writing tf samples in {}".format(out_dir))
        ema = name.startswith("ema_")
        basename = name[len("ema_"):] if ema else name
        with tf.Session() as sess:
            print("Loading {} model".format(name))
            model = models[basename](**model_configs[basename])
            # build graph; x_ is NaN-filled, presumably only used for its shape (TODO confirm)
            x_ = tf.fill([bs, *img_shapes[basename]], value=np.nan)
            y = tf.fill([bs], value=0)

            # NOTE(review): keep this call even on the EMA path. It runs before
            # tf.trainable_variables() is read below, so it is presumably what
            # creates the variables that make_ema averages.
            sample = model.samples_fn(x_, y)

            global_step = tf.train.get_or_create_global_step()
            if ema:
                ema_, _ = make_ema(
                    global_step=global_step,
                    ema_decay=1e-10,
                    trainable_variables=tf.trainable_variables())
                with utils.ema_scope(ema_):
                    print('===== EMA SAMPLES =====')
                    sample = model.progressive_samples_fn(x_, y)

            # load ckpt; the Saver is created after make_ema, so it also
            # covers any variables make_ema added to the graph.
            print('restoring')
            saver = tf.train.Saver()
            saver.restore(sess, ckpts[basename])
            global_step_val = sess.run(global_step)
            print('restored global step: {}'.format(global_step_val))
            for ib in tqdm.tqdm(range(nb), desc="Batch"):
                # test sampling
                samples = sess.run(sample)["samples"]
                for i, img in enumerate(samples):
                    # (x + 1) * 127.5 maps [-1, 1] onto [0, 255]; clip so values
                    # outside that range saturate instead of wrapping in the uint8 cast.
                    np_sample = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
                    PIL.Image.fromarray(np_sample).save(
                        "{}/{:06}.png".format(out_dir, ib * bs + i))
        tf.reset_default_graph()
Ejemplo n.º 6
0
def sample_tf(bs=1):
    """Draw one batch of ``bs`` samples from each released TF checkpoint.

    For "cifar10", "lsun_bedroom", "lsun_cat" and "lsun_church", in turn, the
    model graph is built in a fresh ``tf.Session``, the checkpoint under
    ``diffusion_models_release/`` is restored, and the images are written to
    ``results/tf_<name>/<index:06>.png``.

    Args:
      bs: Number of images sampled per model.
    """
    # add diffusion/scripts to PYTHONPATH, too
    from scripts.run_cifar import Model as cifar10_model
    from scripts.run_lsun import Model as lsun_model
    from diffusion_tf.diffusion_utils_2 import get_beta_schedule
    import PIL.Image

    ckpts = {
        "cifar10": "diffusion_models_release/diffusion_cifar10_model/model.ckpt-790000",
        "lsun_bedroom": "diffusion_models_release/diffusion_lsun_bedroom_model/model.ckpt-2388000",
        "lsun_cat": "diffusion_models_release/diffusion_lsun_cat_model/model.ckpt-1761000",
        "lsun_church": "diffusion_models_release/diffusion_lsun_church_model/model.ckpt-4432000",
    }
    models = {
        "cifar10": cifar10_model,
        "lsun_bedroom": lsun_model,
        "lsun_cat": lsun_model,
        "lsun_church": lsun_model,
    }
    betas = get_beta_schedule(beta_schedule="linear", beta_start=0.0001,
                              beta_end=0.02, num_diffusion_timesteps=1000)
    cifar10_config = {
        "model_name": "unet2d16b2",
        "model_mean_type": "eps",
        "model_var_type": "fixedlarge",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
    }
    lsun_config = {
        "model_name": "unet2d16b2c112244",
        "betas": betas,
        "loss_type": "noisepred",
        "num_classes": 1,
        "dropout": 0,
        "randflip": 0,
        "block_size": 1,
    }
    model_configs = {
        "cifar10": cifar10_config,
        "lsun_bedroom": lsun_config,
        "lsun_cat": lsun_config,
        "lsun_church": lsun_config,
    }
    img_shapes = {
        "cifar10": (32, 32, 3),
        "lsun_bedroom": (256, 256, 3),
        "lsun_cat": (256, 256, 3),
        "lsun_church": (256, 256, 3),
    }

    for name in ["cifar10", "lsun_bedroom", "lsun_cat", "lsun_church"]:
        out_dir = "results/tf_{}".format(name)
        os.makedirs(out_dir, exist_ok=True)
        with tf.Session() as sess:
            print("Loading {} model".format(name))
            model = models[name](**model_configs[name])
            # build graph; x_ is NaN-filled, presumably only used for its shape (TODO confirm)
            x_ = tf.fill([bs, *img_shapes[name]], value=np.nan)
            y = tf.fill([bs], value=0)
            sample_op = model.samples_fn(x_, y)
            global_step = tf.train.get_or_create_global_step()
            # load ckpt
            print('restoring')
            saver = tf.train.Saver()
            saver.restore(sess, ckpts[name])
            global_step_val = sess.run(global_step)
            print('restored global step: {}'.format(global_step_val))
            # test sampling
            samples = sess.run(sample_op)["samples"]
            for i, img in enumerate(samples):
                # (x + 1) * 127.5 maps [-1, 1] onto [0, 255]; clip so values
                # outside that range saturate instead of wrapping in the uint8 cast.
                img_u8 = np.clip((img + 1.0) * 127.5, 0, 255).astype(np.uint8)
                PIL.Image.fromarray(img_u8).save("{}/{:06}.png".format(out_dir, i))
        tf.reset_default_graph()