Exemple #1
0
def evaluation(
        model_dir,
        tpu_name,
        bucket_name_prefix,
        once=False,
        dump_samples_only=False,
        total_bs=256,
        tfds_data_dir='tensorflow_datasets',
        load_ckpt=None):
    """Evaluation loop for use during training.

    Reloads the training kwargs stored under ``model_dir``, builds the
    dataset named there, and hands both to a ``tpu_utils.EvalWorker``.

    Args:
        model_dir: Directory the training kwargs are loaded from; also
            passed to the worker as ``logdir``.
        tpu_name: Forwarded to the eval worker.
        bucket_name_prefix: Bucket prefix; the bucket used is
            ``gs://<prefix>-<region>``.
        once: Forwarded to ``worker.run``.
        dump_samples_only: Forwarded to ``worker.run``.
        total_bs: Used as both ``total_bs`` and ``inception_bs``.
        tfds_data_dir: Path of the dataset directory inside the bucket.
        load_ckpt: Forwarded to ``worker.run``.
    """
    gcs_template = 'gs://{}-{}/{}'
    data_root = gcs_template.format(
        bucket_name_prefix, utils.get_gcp_region(), tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    eval_ds = datasets.get_dataset(
        train_kwargs['dataset'], tfds_data_dir=data_root)

    # The model is built lazily by the worker from the saved training kwargs.
    make_model = functools.partial(_load_model, kwargs=train_kwargs, ds=eval_ds)

    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=make_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=50000,
        dataset=eval_ds,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        load_ckpt=load_ckpt,
    )
Exemple #2
0
def train(
    exp_name,
    tpu_name,
    bucket_name_prefix,
    model_name='unet2d16b2c112244',
    dataset='celebahq256',
    optimizer='adam',
    total_bs=64,
    grad_clip=1.,
    lr=0.00002,
    warmup=5000,
    num_diffusion_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule='linear',
    loss_type='noisepred',
    dropout=0.0,
    randflip=1,
    block_size=1,
    tfds_data_dir='tensorflow_datasets',
    log_dir='logs'
):
  """Launch a training run through ``tpu_utils.run_training``.

  ``tfds_data_dir`` and ``log_dir`` are paths inside the bucket
  ``gs://<bucket_name_prefix>-<region>``. Every argument, together with the
  resolved paths and the region, is snapshotted and passed along as
  ``dump_kwargs``; the same snapshot is used to build the experiment name.
  The remaining arguments are forwarded to the model constructor or to
  ``run_training`` as shown below.
  """
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)

  # Snapshot of the arguments (plus `region`). This must stay ahead of any
  # other local assignment, otherwise extra names leak into the dumped kwargs.
  kwargs = dict(locals())

  ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

  def build_model():
    # Called by run_training; the beta schedule is computed at that point.
    betas = get_beta_schedule(
      beta_schedule,
      beta_start=beta_start,
      beta_end=beta_end,
      num_diffusion_timesteps=num_diffusion_timesteps,
    )
    return Model(
      model_name=model_name,
      betas=betas,
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size,
    )

  name_template = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'

  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name=name_template.format(**kwargs),
    model_constructor=build_model,
    optimizer=optimizer,
    total_bs=total_bs,
    lr=lr,
    warmup=warmup,
    grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name,
    log_dir=log_dir,
    dump_kwargs=kwargs,
  )
Exemple #3
0
def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfds_data_dir='tensorflow_datasets',
):
  """Run a ``tpu_utils.EvalWorker`` for the model stored in ``model_dir``.

  The training kwargs saved under ``model_dir`` select the dataset and supply
  the model hyperparameters. ``tfds_data_dir`` is a path inside the bucket
  ``gs://<bucket_name_prefix>-<region>``. ``total_bs`` is used as both the
  worker batch size and the inception batch size; ``once`` and
  ``dump_samples_only`` are forwarded to ``worker.run``.
  """
  gcs_template = 'gs://{}-{}/{}'
  data_root = gcs_template.format(bucket_name_prefix, utils.get_gcp_region(), tfds_data_dir)

  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  eval_ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=data_root)

  def build_model():
    # Hyperparameters are read from the saved kwargs when the worker calls this.
    betas = get_beta_schedule(
      kwargs['beta_schedule'],
      beta_start=kwargs['beta_start'],
      beta_end=kwargs['beta_end'],
      num_diffusion_timesteps=kwargs['num_diffusion_timesteps'],
    )
    return Model(
      model_name=kwargs['model_name'],
      betas=betas,
      loss_type=kwargs['loss_type'],
      num_classes=eval_ds.num_classes,
      dropout=kwargs['dropout'],
      randflip=kwargs['randflip'],
      block_size=kwargs['block_size'],
    )

  eval_worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=build_model,
    total_bs=total_bs,
    inception_bs=total_bs,
    num_inception_samples=2048,
    dataset=eval_ds,
  )
  eval_worker.run(
    logdir=model_dir,
    once=once,
    skip_non_ema_pass=True,
    dump_samples_only=dump_samples_only,
  )
Exemple #4
0
def train(exp_name,
          tpu_name,
          bucket_name_prefix,
          model_name='unet2d16b2c112244',
          dataset='lsun',
          optimizer='adam',
          total_bs=64,
          grad_clip=1.,
          lr=2e-5,
          warmup=5000,
          num_diffusion_timesteps=1000,
          beta_start=0.0001,
          beta_end=0.02,
          beta_schedule='linear',
          loss_type='noisepred',
          dropout=0.0,
          randflip=1,
          block_size=1,
          tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords',
          log_dir='logs',
          warm_start_model_dir=None):
    """Launch a training run through ``tpu_utils.run_training``.

    ``tfr_file`` and ``log_dir`` are paths inside the bucket
    ``gs://<bucket_name_prefix>-<region>``. Every argument, together with the
    resolved paths and the region, is snapshotted and passed along as
    ``dump_kwargs``; the same snapshot is used to build the experiment name.

    Args:
        exp_name: Prefix of the generated experiment name.
        tpu_name: Forwarded to ``run_training`` as ``tpu``.
        bucket_name_prefix: Bucket prefix (see above).
        model_name, loss_type, dropout, randflip, block_size: Forwarded to
            the ``Model`` constructor.
        dataset: Name passed to ``datasets.get_dataset``.
        optimizer, total_bs, grad_clip, lr, warmup: Forwarded to
            ``run_training``.
        num_diffusion_timesteps, beta_start, beta_end, beta_schedule:
            Forwarded to ``get_beta_schedule``.
        tfr_file: Path of the TFRecord file inside the bucket.
        log_dir: Path of the log directory inside the bucket.
        warm_start_model_dir: If set, all variables are warm-started from the
            latest checkpoint found in this directory.

    Raises:
        ValueError: If ``warm_start_model_dir`` is set but contains no
            checkpoint.
    """
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)

    # Snapshot of the arguments (plus `region`). This must stay ahead of any
    # other local assignment, otherwise extra names leak into the dumped kwargs.
    kwargs = dict(locals())

    # Resolve the warm-start checkpoint up front. latest_checkpoint() returns
    # None when the directory holds no checkpoint, and a WarmStartSettings
    # wrapping None is still truthy, so without this check the problem would
    # only surface later, inside the training machinery.
    warm_start_settings = None
    if warm_start_model_dir:
        latest_ckpt = tf.train.latest_checkpoint(warm_start_model_dir)
        if latest_ckpt is None:
            raise ValueError(
                'No checkpoint found in warm_start_model_dir: {}'.format(
                    warm_start_model_dir))
        warm_start_settings = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=latest_ckpt,
            vars_to_warm_start=[".*"])

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)
    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=
        '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'
        .format(**kwargs),
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(beta_schedule,
                                    beta_start=beta_start,
                                    beta_end=beta_end,
                                    num_diffusion_timesteps=
                                    num_diffusion_timesteps),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size),
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        warm_start_from=warm_start_settings)
Exemple #5
0
def train(exp_name,
          tpu_name,
          bucket_name_prefix,
          model_name='unet2d16b2',
          dataset='cifar10',
          optimizer='adam',
          total_bs=128,
          grad_clip=1.,
          lr=2e-4,
          warmup=5000,
          num_diffusion_timesteps=1000,
          beta_start=0.0001,
          beta_end=0.02,
          beta_schedule='linear',
          model_mean_type='eps',
          model_var_type='fixedlarge',
          loss_type='mse',
          dropout=0.1,
          randflip=1,
          tfds_data_dir='tensorflow_datasets',
          log_dir='logs',
          keep_checkpoint_max=2):
    """Launch a training run through ``tpu_utils.run_training``.

    ``tfds_data_dir`` and ``log_dir`` are paths inside the bucket
    ``gs://<bucket_name_prefix>-<region>``. Every argument, together with the
    resolved paths and the region, is snapshotted and passed along as
    ``dump_kwargs``; the same snapshot is used to build the experiment name.
    The remaining arguments are forwarded to the model constructor or to
    ``run_training`` as shown below.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(
        bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)

    # Snapshot of the arguments (plus `region`). This must stay ahead of any
    # other local assignment, otherwise extra names leak into the dumped kwargs.
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

    def build_model():
        # Called by run_training; the beta schedule is computed at that point.
        betas = get_beta_schedule(
            beta_schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps)
        return Model(
            model_name=model_name,
            betas=betas,
            model_mean_type=model_mean_type,
            model_var_type=model_var_type,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip)

    name_template = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{model_mean_type}-{model_var_type}-{loss_type}_dropout{dropout}_randflip{randflip}'

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=name_template.format(**kwargs),
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        iterations_per_loop=2000,
        keep_checkpoint_max=keep_checkpoint_max)
Exemple #6
0
def simple_eval(model_dir,
                tpu_name,
                bucket_name_prefix,
                mode,
                load_ckpt,
                total_bs=256,
                tfds_data_dir='tensorflow_datasets'):
    """Run a ``SimpleEvalWorker`` for the model stored in ``model_dir``.

    The training kwargs saved under ``model_dir`` select the dataset and are
    handed to ``_load_model`` to build the model. ``tfds_data_dir`` is a path
    inside the bucket ``gs://<bucket_name_prefix>-<region>``. ``mode`` and
    ``load_ckpt`` are forwarded to ``worker.run``.
    """
    gcs_template = 'gs://{}-{}/{}'
    data_root = gcs_template.format(
        bucket_name_prefix, utils.get_gcp_region(), tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    eval_ds = datasets.get_dataset(
        train_kwargs['dataset'], tfds_data_dir=data_root)

    # The model is built lazily by the worker from the saved training kwargs.
    make_model = functools.partial(_load_model, kwargs=train_kwargs, ds=eval_ds)

    eval_worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name,
        model_constructor=make_model,
        total_bs=total_bs,
        dataset=eval_ds,
    )
    eval_worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)
Exemple #7
0
def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
    samples_dir=None,
    num_inception_samples=2048,
):
    """Run a ``tpu_utils.EvalWorker`` for the model stored in ``model_dir``.

    The training kwargs saved under ``model_dir`` select the dataset and supply
    the model hyperparameters. ``tfr_file`` is a path inside the bucket
    ``gs://<bucket_name_prefix>-<region>``. ``total_bs`` is used as both the
    worker batch size and the inception batch size; ``once``,
    ``dump_samples_only`` and ``samples_dir`` are forwarded to ``worker.run``.

    NOTE(review): the LSUN ``train`` snippet defaults ``tfr_file`` to a path
    under ``lsun/church/`` while this default has no ``church/`` directory;
    confirm which layout matches the bucket.
    """
    gcs_template = 'gs://{}-{}/{}'
    records_path = gcs_template.format(
        bucket_name_prefix, utils.get_gcp_region(), tfr_file)

    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    eval_ds = datasets.get_dataset(kwargs['dataset'], tfr_file=records_path)

    def build_model():
        # Hyperparameters are read from the saved kwargs when the worker
        # calls this.
        betas = get_beta_schedule(
            kwargs['beta_schedule'],
            beta_start=kwargs['beta_start'],
            beta_end=kwargs['beta_end'],
            num_diffusion_timesteps=kwargs['num_diffusion_timesteps'])
        return Model(
            model_name=kwargs['model_name'],
            betas=betas,
            loss_type=kwargs['loss_type'],
            num_classes=eval_ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size'])

    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=eval_ds,
        # limit size of dataset for computing Inception features, for memory
        # reasons
        limit_dataset_size=30000,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        samples_dir=samples_dir,
    )