def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=256,
    tfds_data_dir='tensorflow_datasets',
    load_ckpt=None,
):
    """Evaluation loop for use during training.

    Loads the training kwargs stored for ``model_dir``, builds the dataset
    they name, and runs a ``tpu_utils.EvalWorker`` against ``model_dir``.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``EvalWorker``.
        bucket_name_prefix: Combined with the GCP region reported by
            ``utils.get_gcp_region()`` to form the bucket name
            ``gs://<prefix>-<region>``.
        once: Forwarded to ``EvalWorker.run``.
        dump_samples_only: Forwarded to ``EvalWorker.run``.
        total_bs: Used for both ``total_bs`` and ``inception_bs`` of the worker.
        tfds_data_dir: Path inside the bucket, appended to the bucket URL.
        load_ckpt: Forwarded to ``EvalWorker.run``.
    """
    gcp_region = utils.get_gcp_region()
    data_root = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region, tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=data_root)

    # The worker receives a zero-argument factory, not a model instance.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)
    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=50000,
        dataset=dataset,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        load_ckpt=load_ckpt,
    )
def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
    samples_dir=None,
    num_inception_samples=2048,
):
    """Run the evaluation worker for ``model_dir`` on a TFRecord dataset.

    Loads the training kwargs stored for ``model_dir``, builds the dataset
    they name from ``tfr_file``, and runs a ``tpu_utils.EvalWorker``.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``EvalWorker``.
        bucket_name_prefix: Bucket name used to build ``gs://<bucket>/<tfr_file>``.
        once: Forwarded to ``EvalWorker.run``.
        dump_samples_only: Forwarded to ``EvalWorker.run``.
        total_bs: Used for both ``total_bs`` and ``inception_bs`` of the worker.
        tfr_file: Path of the TFRecord file inside the bucket.
        samples_dir: Forwarded to ``EvalWorker.run``.
        num_inception_samples: Forwarded to ``EvalWorker``.
    """
    # The bucket name is used as-is; no GCP region suffix is appended here.
    tfr_url = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_url)

    def build_model():
        # Evaluated lazily, each time the worker asks for a model.
        model_name = train_kwargs['model_name']
        betas = get_beta_schedule(
            train_kwargs['beta_schedule'],
            beta_start=train_kwargs['beta_start'],
            beta_end=train_kwargs['beta_end'],
            num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
        )
        return Model(
            model_name=model_name,
            betas=betas,
            loss_type=train_kwargs['loss_type'],
            num_classes=dataset.num_classes,
            dropout=train_kwargs['dropout'],
            randflip=train_kwargs['randflip'],
            block_size=train_kwargs['block_size'],
        )

    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=dataset,
        # Limit size of dataset for computing Inception features, for memory reasons.
        limit_dataset_size=30000,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        samples_dir=samples_dir,
    )
def evaluation(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    once=False,
    dump_samples_only=False,
    total_bs=128,
    tfds_data_dir='tensorflow_datasets',
):
    """Run the evaluation worker for ``model_dir`` on a TFDS dataset.

    Loads the training kwargs stored for ``model_dir``, builds the dataset
    they name, and runs a ``tpu_utils.EvalWorker`` with 2048 Inception samples.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``EvalWorker``.
        bucket_name_prefix: Combined with the GCP region reported by
            ``utils.get_gcp_region()`` to form the bucket name
            ``gs://<prefix>-<region>``.
        once: Forwarded to ``EvalWorker.run``.
        dump_samples_only: Forwarded to ``EvalWorker.run``.
        total_bs: Used for both ``total_bs`` and ``inception_bs`` of the worker.
        tfds_data_dir: Path inside the bucket, appended to the bucket URL.
    """
    gcp_region = utils.get_gcp_region()
    data_root = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region, tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=data_root)

    def build_model():
        # Evaluated lazily, each time the worker asks for a model.
        model_name = train_kwargs['model_name']
        betas = get_beta_schedule(
            train_kwargs['beta_schedule'],
            beta_start=train_kwargs['beta_start'],
            beta_end=train_kwargs['beta_end'],
            num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'],
        )
        return Model(
            model_name=model_name,
            betas=betas,
            loss_type=train_kwargs['loss_type'],
            num_classes=dataset.num_classes,
            dropout=train_kwargs['dropout'],
            randflip=train_kwargs['randflip'],
            block_size=train_kwargs['block_size'],
        )

    eval_worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=2048,
        dataset=dataset,
    )
    eval_worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
    )
def simple_eval(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    mode,
    load_ckpt=None,
    total_bs=16,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
):
    """Run a ``SimpleEvalWorker`` for ``model_dir`` on a TFRecord dataset.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``SimpleEvalWorker``.
        bucket_name_prefix: Bucket name used to build ``gs://<bucket>/<tfr_file>``.
        mode: Forwarded to ``SimpleEvalWorker.run``, which defines the
            accepted values.
        load_ckpt: Forwarded to ``SimpleEvalWorker.run``.
        total_bs: Forwarded to ``SimpleEvalWorker``.
        tfr_file: Path of the TFRecord file inside the bucket.
    """
    # The bucket name is used as-is; no GCP region suffix is appended here.
    tfr_url = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_url)

    # The worker receives a zero-argument factory, not a model instance.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)
    eval_worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        dataset=dataset,
    )
    eval_worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)
def simple_eval(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    mode,
    load_ckpt=None,
    total_bs=16,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
    samples_dir='gs://tensorfork-arfa-euw4/ddpm-samples',
):
    """Run a ``SimpleEvalWorker`` for ``model_dir`` on a TFRecord dataset.

    Two environment variables, when set, override the loaded training kwargs
    before the model is built:

    * ``NUM_DIFFUSION_TIMESTEPS`` replaces ``num_diffusion_timesteps``
      (converted with ``int``).
    * ``BETA_SCHEDULE`` replaces ``beta_schedule``.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``SimpleEvalWorker``.
        bucket_name_prefix: Bucket name used to build ``gs://<bucket>/<tfr_file>``.
        mode: Forwarded to ``SimpleEvalWorker.run``, which defines the
            accepted values.
        load_ckpt: Forwarded to ``SimpleEvalWorker.run``.
        total_bs: Forwarded to ``SimpleEvalWorker``.
        tfr_file: Path of the TFRecord file inside the bucket.
        samples_dir: Forwarded to ``SimpleEvalWorker.run``.

    Raises:
        ValueError: If ``NUM_DIFFUSION_TIMESTEPS`` is set but is not an integer.
    """
    # The bucket name is used as-is; no GCP region suffix is appended here.
    tfr_url = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)

    # Apply optional environment overrides to the loaded hyperparameters.
    timesteps_override = os.environ.get('NUM_DIFFUSION_TIMESTEPS')
    if timesteps_override is not None:
        train_kwargs['num_diffusion_timesteps'] = int(timesteps_override)
    schedule_override = os.environ.get('BETA_SCHEDULE')
    if schedule_override is not None:
        train_kwargs['beta_schedule'] = schedule_override

    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_url)

    # The worker receives a zero-argument factory, not a model instance.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)
    eval_worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        dataset=dataset,
    )
    eval_worker.run(
        mode=mode,
        logdir=model_dir,
        load_ckpt=load_ckpt,
        samples_dir=samples_dir,
    )
def simple_eval(
    model_dir,
    tpu_name,
    bucket_name_prefix,
    mode,
    load_ckpt,
    total_bs=256,
    tfds_data_dir='tensorflow_datasets',
):
    """Run a ``SimpleEvalWorker`` for ``model_dir`` on a TFDS dataset.

    Args:
        model_dir: Directory the training kwargs are loaded from; also passed
            to the worker as ``logdir``.
        tpu_name: Forwarded to ``SimpleEvalWorker``.
        bucket_name_prefix: Combined with the GCP region reported by
            ``utils.get_gcp_region()`` to form the bucket name
            ``gs://<prefix>-<region>``.
        mode: Forwarded to ``SimpleEvalWorker.run``, which defines the
            accepted values.
        load_ckpt: Forwarded to ``SimpleEvalWorker.run``.
        total_bs: Forwarded to ``SimpleEvalWorker``.
        tfds_data_dir: Path inside the bucket, appended to the bucket URL.
    """
    gcp_region = utils.get_gcp_region()
    data_root = 'gs://{}-{}/{}'.format(bucket_name_prefix, gcp_region, tfds_data_dir)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    dataset = datasets.get_dataset(train_kwargs['dataset'], tfds_data_dir=data_root)

    # The worker receives a zero-argument factory, not a model instance.
    build_model = functools.partial(_load_model, kwargs=train_kwargs, ds=dataset)
    eval_worker = simple_eval_worker.SimpleEvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        dataset=dataset,
    )
    eval_worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)