예제 #1
0
def evaluation(  # evaluation loop for use during training
        model_dir,
        tpu_name,
        bucket_name_prefix,
        once=False,
        dump_samples_only=False,
        total_bs=256,
        tfds_data_dir='tensorflow_datasets',
        load_ckpt=None):
    """Run the evaluation loop for a trained model on a TPU worker.

    Resolves the TFDS data directory to a region-specific GCS bucket,
    restores the kwargs saved at training time, builds the dataset, and
    delegates the actual evaluation to ``tpu_utils.EvalWorker``.
    """
    gcp_region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region,
                                           tfds_data_dir)
    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'],
                                   tfds_data_dir=tfds_data_dir)
    # Model construction is deferred so the worker can instantiate it on-device.
    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=functools.partial(
            _load_model, kwargs=train_kwargs, ds=dataset),
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=50000,
        dataset=dataset,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        load_ckpt=load_ckpt,
    )
예제 #2
0
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None, num_inception_samples=2048,
):
  """Evaluate a trained model against a TFRecord dataset on a TPU worker.

  Resolves ``tfr_file`` inside the given GCS bucket, restores the kwargs
  saved at training time, and delegates evaluation to
  ``tpu_utils.EvalWorker`` (with a bounded dataset size for Inception
  feature computation).
  """
  tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)
  train_kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', train_kwargs)
  dataset = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_file)

  def build_model():
    # Reconstruct the model exactly as configured at training time.
    beta_schedule = get_beta_schedule(
      train_kwargs['beta_schedule'],
      beta_start=train_kwargs['beta_start'],
      beta_end=train_kwargs['beta_end'],
      num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
    )
    return Model(
      model_name=train_kwargs['model_name'],
      betas=beta_schedule,
      loss_type=train_kwargs['loss_type'],
      num_classes=dataset.num_classes,
      dropout=train_kwargs['dropout'],
      randflip=train_kwargs['randflip'],
      block_size=train_kwargs['block_size'],
    )

  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=build_model,
    total_bs=total_bs,
    inception_bs=total_bs,
    num_inception_samples=num_inception_samples,
    dataset=dataset,
    # Cap the dataset used for computing Inception features, for memory reasons.
    limit_dataset_size=30000,
  )
  worker.run(
    logdir=model_dir, once=once, skip_non_ema_pass=True,
    dump_samples_only=dump_samples_only, samples_dir=samples_dir,
  )
예제 #3
0
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets',
):
  """Evaluate a trained model against a TFDS dataset on a TPU worker.

  Resolves the TFDS data directory to a region-specific GCS bucket,
  restores the kwargs saved at training time, and hands evaluation off to
  ``tpu_utils.EvalWorker``.
  """
  gcp_region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region, tfds_data_dir)
  train_kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', train_kwargs)
  dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=tfds_data_dir)

  def build_model():
    # Reconstruct the model exactly as configured at training time.
    beta_schedule = get_beta_schedule(
      train_kwargs['beta_schedule'],
      beta_start=train_kwargs['beta_start'],
      beta_end=train_kwargs['beta_end'],
      num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
    )
    return Model(
      model_name=train_kwargs['model_name'],
      betas=beta_schedule,
      loss_type=train_kwargs['loss_type'],
      num_classes=dataset.num_classes,
      dropout=train_kwargs['dropout'],
      randflip=train_kwargs['randflip'],
      block_size=train_kwargs['block_size'],
    )

  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=build_model,
    total_bs=total_bs,
    inception_bs=total_bs,
    num_inception_samples=2048,
    dataset=dataset,
  )
  worker.run(
    logdir=model_dir, once=once, skip_non_ema_pass=True,
    dump_samples_only=dump_samples_only,
  )
예제 #4
0
def simple_eval(model_dir, tpu_name, bucket_name_prefix, mode, load_ckpt=None, total_bs=16, tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords'):
  """Run a single-mode evaluation pass via ``SimpleEvalWorker``.

  Resolves ``tfr_file`` inside the given GCS bucket, restores the kwargs
  saved at training time, and runs the requested evaluation ``mode``.
  """
  tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)
  train_kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', train_kwargs)
  dataset = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_file)
  # Model construction is deferred so the worker can instantiate it on-device.
  eval_worker = simple_eval_worker.SimpleEvalWorker(
    tpu_name=tpu_name,
    model_constructor=functools.partial(_load_model, kwargs=train_kwargs, ds=dataset),
    total_bs=total_bs,
    dataset=dataset,
  )
  eval_worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)
예제 #5
0
def simple_eval(model_dir, tpu_name, bucket_name_prefix, mode, load_ckpt=None, total_bs=16, tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir='gs://tensorfork-arfa-euw4/ddpm-samples'):
  """Run a single-mode evaluation pass via ``SimpleEvalWorker``.

  Resolves ``tfr_file`` inside the given GCS bucket, restores the kwargs
  saved at training time, and runs the requested evaluation ``mode``,
  writing samples to ``samples_dir``.

  Environment overrides (applied to the restored training kwargs):
    NUM_DIFFUSION_TIMESTEPS -- overrides 'num_diffusion_timesteps' (int)
    BETA_SCHEDULE           -- overrides 'beta_schedule' (str)
  """
  tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)
  kwargs = tpu_utils.load_train_kwargs(model_dir)

  # Single env lookup via .get() instead of the 'in'-then-subscript pattern.
  num_timesteps_override = os.environ.get('NUM_DIFFUSION_TIMESTEPS')
  if num_timesteps_override is not None:
    kwargs['num_diffusion_timesteps'] = int(num_timesteps_override)
  beta_schedule_override = os.environ.get('BETA_SCHEDULE')
  if beta_schedule_override is not None:
    kwargs['beta_schedule'] = beta_schedule_override

  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfr_file=tfr_file)
  # Model construction is deferred so the worker can instantiate it on-device.
  worker = simple_eval_worker.SimpleEvalWorker(
    tpu_name=tpu_name, model_constructor=functools.partial(_load_model, kwargs=kwargs, ds=ds),
    total_bs=total_bs, dataset=ds)
  worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt, samples_dir=samples_dir)
예제 #6
0
def simple_eval(model_dir,
                tpu_name,
                bucket_name_prefix,
                mode,
                load_ckpt,
                total_bs=256,
                tfds_data_dir='tensorflow_datasets'):
    """Run a single-mode evaluation pass via ``SimpleEvalWorker``.

    Resolves the TFDS data directory to a region-specific GCS bucket,
    restores the kwargs saved at training time, and runs the requested
    evaluation ``mode`` from the given checkpoint.
    """
    gcp_region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region,
                                           tfds_data_dir)
    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'],
                                   tfds_data_dir=tfds_data_dir)
    # Model construction is deferred so the worker can instantiate it on-device.
    eval_worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name,
        model_constructor=functools.partial(
            _load_model, kwargs=train_kwargs, ds=dataset),
        total_bs=total_bs,
        dataset=dataset)
    eval_worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)