Ejemplo n.º 1
0
def train(
    exp_name,
    tpu_name,
    bucket_name_prefix,
    model_name='unet2d16b2c112244',
    dataset='celebahq256',
    optimizer='adam',
    total_bs=64,
    grad_clip=1.,
    lr=0.00002,
    warmup=5000,
    num_diffusion_timesteps=1000,
    beta_start=0.0001,
    beta_end=0.02,
    beta_schedule='linear',
    loss_type='noisepred',
    dropout=0.0,
    randflip=1,
    block_size=1,
    tfds_data_dir='tensorflow_datasets',
    log_dir='logs',
):
  """Configure a training run and hand it to ``tpu_utils.run_training``.

  ``tfds_data_dir`` and ``log_dir`` are rewritten into
  ``gs://<bucket_name_prefix>-<region>/<path>`` URIs, where the region comes
  from ``utils.get_gcp_region()``. The dataset is loaded via
  ``datasets.get_dataset`` and a deferred ``Model`` constructor is passed to
  the training runner.

  Args:
    exp_name: Leading component of the generated experiment name.
    tpu_name: Passed to ``run_training`` as ``tpu``.
    bucket_name_prefix: Bucket name without its ``-<region>`` suffix.
    model_name, loss_type, dropout, randflip, block_size: Forwarded unchanged
      to ``Model``.
    dataset: Name passed to ``datasets.get_dataset``.
    optimizer, total_bs, grad_clip, lr, warmup: Forwarded unchanged to
      ``run_training``.
    num_diffusion_timesteps, beta_start, beta_end, beta_schedule: Arguments
      for ``get_beta_schedule``; its result becomes the model's ``betas``.
    tfds_data_dir: Path inside the bucket, given to ``datasets.get_dataset``.
    log_dir: Path inside the bucket, given to ``run_training``.

  Returns:
    None.
  """
  region = utils.get_gcp_region()
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)

  # Snapshot of every argument (with the rewritten gs:// paths) plus `region`.
  # It must be taken before any further local is bound, because the snapshot
  # is both the source of the experiment name and the payload of dump_kwargs.
  kwargs = dict(locals())

  ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

  run_name = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
    **kwargs)

  def build_model():
    # Deferred on purpose: the beta schedule and the Model are only created
    # when this callable is invoked, not when train() sets up the run.
    noise_betas = get_beta_schedule(
      beta_schedule,
      beta_start=beta_start,
      beta_end=beta_end,
      num_diffusion_timesteps=num_diffusion_timesteps,
    )
    return Model(
      model_name=model_name,
      betas=noise_betas,
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size,
    )

  tpu_utils.run_training(
    date_str='9999-99-99',  # fixed literal, not the current date
    exp_name=run_name,
    model_constructor=build_model,
    optimizer=optimizer,
    total_bs=total_bs,
    lr=lr,
    warmup=warmup,
    grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name,
    log_dir=log_dir,
    dump_kwargs=kwargs,
  )
Ejemplo n.º 2
0
def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
    samples_dir=None,
    num_inception_samples=2048,
):
  """Evaluate a trained model from ``model_dir`` with ``tpu_utils.EvalWorker``.

  The hyperparameters saved at training time are read back through
  ``tpu_utils.load_train_kwargs(model_dir)`` and used to rebuild the same
  ``Model``; the dataset is read from a TFRecords file in a GCS bucket.

  Args:
    model_dir: Directory holding the training kwargs; also passed to
      ``worker.run`` as ``logdir``.
    tpu_name: Passed to ``EvalWorker`` as ``tpu_name``.
    bucket_name_prefix: Used as the complete bucket name in the ``gs://`` URI.
    once, dump_samples_only, samples_dir: Forwarded unchanged to
      ``EvalWorker.run``.
    total_bs: Used for both ``total_bs`` and ``inception_bs``.
    tfr_file: Path of the TFRecords file inside the bucket.
    num_inception_samples: Forwarded unchanged to ``EvalWorker``.

  Returns:
    None.
  """
  # The GCP region lookup is disabled in this variant, so no "-<region>"
  # suffix is appended to the bucket name.
  tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)

  train_kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', train_kwargs)
  ds = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_file)

  def build_model():
    # Rebuild the model from the saved training hyperparameters; evaluated
    # only when the callable is invoked.
    saved_model_name = train_kwargs['model_name']
    saved_betas = get_beta_schedule(
      train_kwargs['beta_schedule'],
      beta_start=train_kwargs['beta_start'],
      beta_end=train_kwargs['beta_end'],
      num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
    )
    return Model(
      model_name=saved_model_name,
      betas=saved_betas,
      loss_type=train_kwargs['loss_type'],
      num_classes=ds.num_classes,
      dropout=train_kwargs['dropout'],
      randflip=train_kwargs['randflip'],
      block_size=train_kwargs['block_size'],
    )

  eval_worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=build_model,
    total_bs=total_bs,
    inception_bs=total_bs,
    num_inception_samples=num_inception_samples,
    dataset=ds,
    # limit size of dataset for computing Inception features, for memory reasons
    limit_dataset_size=30000,
  )
  eval_worker.run(
    logdir=model_dir,
    once=once,
    skip_non_ema_pass=True,
    dump_samples_only=dump_samples_only,
    samples_dir=samples_dir,
  )
Ejemplo n.º 3
0
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets', num_inception_samples=2048,
):
  """Evaluate a trained model from ``model_dir`` with ``tpu_utils.EvalWorker``.

  The hyperparameters saved at training time are read back through
  ``tpu_utils.load_train_kwargs(model_dir)`` and used to rebuild the same
  ``Model``. The dataset is loaded from
  ``gs://<bucket_name_prefix>-<region>/<tfds_data_dir>``, where the region
  comes from ``utils.get_gcp_region()``.

  Args:
    model_dir: Directory holding the training kwargs; also passed to
      ``worker.run`` as ``logdir``.
    tpu_name: Passed to ``EvalWorker`` as ``tpu_name``.
    bucket_name_prefix: Bucket name without its ``-<region>`` suffix.
    once, dump_samples_only: Forwarded unchanged to ``EvalWorker.run``.
    total_bs: Used for both ``total_bs`` and ``inception_bs``.
    tfds_data_dir: Path inside the bucket, given to ``datasets.get_dataset``.
    num_inception_samples: Forwarded to ``EvalWorker``. Defaults to 2048,
      the value that was previously hard-coded in the call.

  Returns:
    None.
  """
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    # Deferred constructor: rebuilds the model from the saved hyperparameters
    # only when the callable is invoked.
    model_constructor=lambda: Model(
      model_name=kwargs['model_name'],
      betas=get_beta_schedule(
        kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
      ),
      loss_type=kwargs['loss_type'],
      num_classes=ds.num_classes,
      dropout=kwargs['dropout'],
      randflip=kwargs['randflip'],
      block_size=kwargs['block_size']
    ),
    total_bs=total_bs, inception_bs=total_bs, num_inception_samples=num_inception_samples,
    dataset=ds,
  )
  worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only)
Ejemplo n.º 4
0
def train(exp_name,
          tpu_name,
          bucket_name_prefix,
          model_name='unet2d16b2c112244',
          dataset='lsun',
          optimizer='adam',
          total_bs=64,
          grad_clip=1.,
          lr=2e-5,
          warmup=5000,
          num_diffusion_timesteps=1000,
          beta_start=0.0001,
          beta_end=0.02,
          beta_schedule='linear',
          loss_type='noisepred',
          dropout=0.0,
          randflip=1,
          block_size=1,
          tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords',
          log_dir='logs',
          warm_start_model_dir=None):
    """Configure a training run and hand it to ``tpu_utils.run_training``.

    ``tfr_file`` and ``log_dir`` are rewritten into
    ``gs://<bucket_name_prefix>-<region>/<path>`` URIs (region from
    ``utils.get_gcp_region()``) and printed. When ``warm_start_model_dir`` is
    truthy, the run is warm-started from the latest checkpoint found there.

    Args:
        exp_name: Leading component of the generated experiment name.
        tpu_name: Passed to ``run_training`` as ``tpu``.
        bucket_name_prefix: Bucket name without its ``-<region>`` suffix.
        model_name, loss_type, dropout, randflip, block_size: Forwarded
            unchanged to ``Model``.
        dataset: Name passed to ``datasets.get_dataset``.
        optimizer, total_bs, grad_clip, lr, warmup: Forwarded unchanged to
            ``run_training``.
        num_diffusion_timesteps, beta_start, beta_end, beta_schedule:
            Arguments for ``get_beta_schedule``; its result becomes the
            model's ``betas``.
        tfr_file: Path of the TFRecords file inside the bucket.
        log_dir: Path inside the bucket, given to ``run_training``.
        warm_start_model_dir: Optional directory searched with
            ``tf.train.latest_checkpoint``; falsy disables warm starting.

    Returns:
        None.
    """
    region = utils.get_gcp_region()
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)

    # Snapshot of every argument (with the rewritten gs:// paths) plus
    # `region`. It must be taken before any further local is bound, because
    # it feeds both the experiment name and dump_kwargs.
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

    run_name = '{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(
        **kwargs)

    def build_model():
        # Deferred on purpose: the schedule and the Model are only created
        # when this callable is invoked.
        noise_betas = get_beta_schedule(
            beta_schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps)
        return Model(model_name=model_name,
                     betas=noise_betas,
                     loss_type=loss_type,
                     num_classes=ds.num_classes,
                     dropout=dropout,
                     randflip=randflip,
                     block_size=block_size)

    input_fn = ds.train_input_fn

    if warm_start_model_dir:
        # The ".*" pattern is passed so that every variable name matches.
        warm_start = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=tf.train.latest_checkpoint(
                warm_start_model_dir),
            vars_to_warm_start=[".*"])
    else:
        warm_start = None

    tpu_utils.run_training(
        date_str='9999-99-99',  # fixed literal, not the current date
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        warm_start_from=warm_start)
Ejemplo n.º 5
0
def _load_model(kwargs, ds):
    """Build a ``Model`` from a hyperparameter mapping and a dataset.

    Args:
        kwargs: Mapping providing ``model_name``, ``beta_schedule``,
            ``beta_start``, ``beta_end``, ``num_diffusion_timesteps``,
            ``loss_type``, ``dropout``, ``block_size`` and ``randflip``.
        ds: Dataset object; only its ``num_classes`` attribute is read.

    Returns:
        The constructed ``Model`` instance.
    """
    name = kwargs['model_name']
    betas = get_beta_schedule(
        kwargs['beta_schedule'],
        beta_start=kwargs['beta_start'],
        beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps'],
    )
    # NOTE: model_mean_type / model_var_type are intentionally not forwarded;
    # they were left disabled in this loader.
    return Model(
        model_name=name,
        betas=betas,
        loss_type=kwargs['loss_type'],
        num_classes=ds.num_classes,
        dropout=kwargs['dropout'],
        block_size=kwargs['block_size'],
        randflip=kwargs['randflip'],
    )