import functools

import tensorflow as tf  # TF1-style tf.estimator / tf.train APIs are used below

# NOTE: the module paths below are assumptions about the surrounding repo layout;
# `Model` and `get_beta_schedule` are defined elsewhere in the scripts this
# excerpt was taken from.
from diffusion_tf import utils
from diffusion_tf.tpu_utils import tpu_utils, datasets, simple_eval_worker


def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=256,
    tfds_data_dir='tensorflow_datasets', load_ckpt=None
):
  """Evaluation loop for use during training."""
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=functools.partial(_load_model, kwargs=kwargs, ds=ds),
    total_bs=total_bs, inception_bs=total_bs, num_inception_samples=50000,
    dataset=ds,
  )
  worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True,
             dump_samples_only=dump_samples_only, load_ckpt=load_ckpt)
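# `evaluation` above and `simple_eval` below build their model through a
# `_load_model` helper that is not part of this excerpt. A minimal sketch,
# assuming it simply rebuilds Model from the dumped training kwargs in the
# same way as the inline lambdas in the evaluation functions further down:
def _load_model(kwargs, ds):
  # Illustrative only: extra fields such as model_mean_type / model_var_type
  # or block_size would be forwarded the same way when present in kwargs.
  return Model(
    model_name=kwargs['model_name'],
    betas=get_beta_schedule(
      kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
      num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
    ),
    loss_type=kwargs['loss_type'],
    num_classes=ds.num_classes,
    dropout=kwargs['dropout'],
    randflip=kwargs['randflip'],
  )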
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear',
    loss_type='noisepred', dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'
):
  """Train a diffusion model (defaults: CelebA-HQ 256) on a TPU."""
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
  kwargs = dict(locals())  # snapshot all arguments (after gs:// rewriting) for naming and dumping
  ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}'
             '_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}'
             '_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(**kwargs),
    model_constructor=lambda: Model(
      model_name=model_name,
      betas=get_beta_schedule(
        beta_schedule, beta_start=beta_start, beta_end=beta_end,
        num_diffusion_timesteps=num_diffusion_timesteps
      ),
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size
    ),
    optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs
  )
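# `get_beta_schedule` is referenced by every entry point here but not defined
# in this excerpt. A minimal sketch of the 'linear' case (the default schedule
# everywhere above and below), assuming a plain np.linspace ramp from
# beta_start to beta_end over num_diffusion_timesteps steps:
import numpy as np

def _get_beta_schedule_sketch(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
  # Illustrative only; the real helper lives elsewhere in the repo and may
  # support additional schedules.
  if beta_schedule == 'linear':
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
  else:
    raise NotImplementedError(beta_schedule)
  assert betas.shape == (num_diffusion_timesteps,)
  return betas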
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfds_data_dir='tensorflow_datasets',
):
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=lambda: Model(
      model_name=kwargs['model_name'],
      betas=get_beta_schedule(
        kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
      ),
      loss_type=kwargs['loss_type'],
      num_classes=ds.num_classes,
      dropout=kwargs['dropout'],
      randflip=kwargs['randflip'],
      block_size=kwargs['block_size']
    ),
    total_bs=total_bs, inception_bs=total_bs, num_inception_samples=2048,
    dataset=ds,
  )
  worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True, dump_samples_only=dump_samples_only)
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='lsun',
    optimizer='adam', total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear',
    loss_type='noisepred', dropout=0.0, randflip=1, block_size=1,
    tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords', log_dir='logs',
    warm_start_model_dir=None
):
  """Train a diffusion model on LSUN TFRecords, optionally warm-starting from an existing checkpoint."""
  region = utils.get_gcp_region()
  tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
  print("tfr_file:", tfr_file)
  print("log_dir:", log_dir)
  kwargs = dict(locals())
  ds = datasets.get_dataset(dataset, tfr_file=tfr_file)
  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}'
             '_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}'
             '_{loss_type}_dropout{dropout}_randflip{randflip}_blk{block_size}'.format(**kwargs),
    model_constructor=lambda: Model(
      model_name=model_name,
      betas=get_beta_schedule(
        beta_schedule, beta_start=beta_start, beta_end=beta_end,
        num_diffusion_timesteps=num_diffusion_timesteps
      ),
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip,
      block_size=block_size
    ),
    optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs,
    warm_start_from=tf.estimator.WarmStartSettings(
      ckpt_to_initialize_from=tf.train.latest_checkpoint(warm_start_model_dir),
      vars_to_warm_start=[".*"]
    ) if warm_start_model_dir else None
  )
def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2', dataset='cifar10',
    optimizer='adam', total_bs=128, grad_clip=1., lr=2e-4, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear',
    model_mean_type='eps', model_var_type='fixedlarge', loss_type='mse',
    dropout=0.1, randflip=1, tfds_data_dir='tensorflow_datasets', log_dir='logs',
    keep_checkpoint_max=2
):
  """Train a diffusion model (defaults: CIFAR-10) with configurable mean/variance parameterizations."""
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
  kwargs = dict(locals())
  ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)
  tpu_utils.run_training(
    date_str='9999-99-99',
    exp_name='{exp_name}_{dataset}_{model_name}_{optimizer}_bs{total_bs}_lr{lr}w{warmup}'
             '_beta{beta_start}-{beta_end}-{beta_schedule}_t{num_diffusion_timesteps}'
             '_{model_mean_type}-{model_var_type}-{loss_type}'
             '_dropout{dropout}_randflip{randflip}'.format(**kwargs),
    model_constructor=lambda: Model(
      model_name=model_name,
      betas=get_beta_schedule(
        beta_schedule, beta_start=beta_start, beta_end=beta_end,
        num_diffusion_timesteps=num_diffusion_timesteps
      ),
      model_mean_type=model_mean_type,
      model_var_type=model_var_type,
      loss_type=loss_type,
      num_classes=ds.num_classes,
      dropout=dropout,
      randflip=randflip
    ),
    optimizer=optimizer, total_bs=total_bs, lr=lr, warmup=warmup, grad_clip=grad_clip,
    train_input_fn=ds.train_input_fn,
    tpu=tpu_name, log_dir=log_dir, dump_kwargs=kwargs,
    iterations_per_loop=2000, keep_checkpoint_max=keep_checkpoint_max
  )
def simple_eval(model_dir, tpu_name, bucket_name_prefix, mode, load_ckpt, total_bs=256,
                tfds_data_dir='tensorflow_datasets'):
  region = utils.get_gcp_region()
  tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
  worker = simple_eval_worker.SimpleEvalWorker(
    tpu_name=tpu_name,
    model_constructor=functools.partial(_load_model, kwargs=kwargs, ds=ds),
    total_bs=total_bs, dataset=ds
  )
  worker.run(mode=mode, logdir=model_dir, load_ckpt=load_ckpt)
def evaluation(
    model_dir, tpu_name, bucket_name_prefix, once=False, dump_samples_only=False, total_bs=128,
    tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords', samples_dir=None,
    num_inception_samples=2048,
):
  """Evaluate a trained LSUN model: Inception/FID stats and optional sample dumps to samples_dir."""
  region = utils.get_gcp_region()
  tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
  kwargs = tpu_utils.load_train_kwargs(model_dir)
  print('loaded kwargs:', kwargs)
  ds = datasets.get_dataset(kwargs['dataset'], tfr_file=tfr_file)
  worker = tpu_utils.EvalWorker(
    tpu_name=tpu_name,
    model_constructor=lambda: Model(
      model_name=kwargs['model_name'],
      betas=get_beta_schedule(
        kwargs['beta_schedule'], beta_start=kwargs['beta_start'], beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps']
      ),
      loss_type=kwargs['loss_type'],
      num_classes=ds.num_classes,
      dropout=kwargs['dropout'],
      randflip=kwargs['randflip'],
      block_size=kwargs['block_size']
    ),
    total_bs=total_bs, inception_bs=total_bs,
    num_inception_samples=num_inception_samples,
    dataset=ds,
    limit_dataset_size=30000  # limit size of dataset for computing Inception features, for memory reasons
  )
  worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True,
             dump_samples_only=dump_samples_only, samples_dir=samples_dir)
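# No CLI wiring is shown in this excerpt. Entry points structured like the
# functions above are commonly exposed with python-fire, which maps each
# function to a subcommand and its keyword arguments to flags. A hedged
# sketch (the `fire` dependency and the script name are assumptions):
if __name__ == '__main__':
  import fire
  # Example invocation (hypothetical script name):
  #   python run_lsun.py train --exp_name=church --tpu_name=my-tpu --bucket_name_prefix=my-bucket
  fire.Fire()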