def train(
        exp_name, tpu_name, bucket_name_prefix,
        model_name='unet2d16b2c112244', dataset='celebahq256', optimizer='adam',
        total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
        num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02,
        beta_schedule='linear', loss_type='noisepred', dropout=0.0,
        randflip=1, block_size=1, tfds_data_dir='tensorflow_datasets',
        log_dir='logs'):
    """Launch a TPU training run whose data is read through ``tfds_data_dir``.

    ``tfds_data_dir`` and ``log_dir`` are paths inside a regional bucket and
    are expanded to ``gs://<bucket_name_prefix>-<region>/<path>``. The region
    comes from ``utils.get_gcp_region()``.

    The experiment name encodes the main hyper-parameters. A snapshot of all
    arguments (plus ``region`` and the expanded paths) is passed to
    ``tpu_utils.run_training`` as ``dump_kwargs``.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    # Take the snapshot here, before any other local is bound: it feeds both
    # the run-name template and ``dump_kwargs``.
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfds_data_dir=tfds_data_dir)

    def build_model():
        # Called by run_training; closes over this call's arguments and ``ds``.
        schedule = get_beta_schedule(
            beta_schedule,
            beta_start=beta_start,
            beta_end=beta_end,
            num_diffusion_timesteps=num_diffusion_timesteps)
        return Model(
            model_name=model_name,
            betas=schedule,
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size)

    run_name = (
        '{exp_name}_{dataset}_{model_name}_{optimizer}'
        '_bs{total_bs}_lr{lr}w{warmup}'
        '_beta{beta_start}-{beta_end}-{beta_schedule}'
        '_t{num_diffusion_timesteps}_{loss_type}'
        '_dropout{dropout}_randflip{randflip}_blk{block_size}'
    ).format_map(kwargs)

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=run_name,
        model_constructor=build_model,
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs)
def evaluation(
        model_dir, tpu_name, bucket_name_prefix, once=False,
        dump_samples_only=False, total_bs=128,
        tfr_file='tensorflow_datasets/lsun/church-r08.tfrecords',
        samples_dir=None, num_inception_samples=2048,
):
    """Evaluate the checkpoints of a training run using a TFRecord dataset.

    The model is rebuilt from the training kwargs loaded for ``model_dir``
    by ``tpu_utils.load_train_kwargs``. The dataset is read from
    ``gs://<bucket_name_prefix>/<tfr_file>``.

    NOTE(review): ``evaluation`` is defined again further down in this
    module, so at import time that later definition replaces this one.
    Confirm which variant is intended.

    Args:
        model_dir: directory of the training run. Its kwargs are loaded from
            here, and it is passed to ``worker.run`` as ``logdir``.
        tpu_name: TPU to run the ``EvalWorker`` on.
        bucket_name_prefix: GCS bucket name, used verbatim.
        once: forwarded to ``worker.run``.
        dump_samples_only: forwarded to ``worker.run``.
        total_bs: eval batch size, also used as the Inception batch size.
        tfr_file: TFRecord path inside the bucket.
        samples_dir: forwarded to ``worker.run``.
        num_inception_samples: forwarded to ``EvalWorker``.
    """
    # No GCP region suffix is appended to the bucket name here.
    tfr_file = 'gs://{}/{}'.format(bucket_name_prefix, tfr_file)

    train_kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', train_kwargs)
    ds = datasets.get_dataset(train_kwargs['dataset'], tfr_file=tfr_file)

    def build_model():
        # Rebuild the Model from the stored training hyper-parameters.
        betas = get_beta_schedule(
            train_kwargs['beta_schedule'],
            beta_start=train_kwargs['beta_start'],
            beta_end=train_kwargs['beta_end'],
            num_diffusion_timesteps=train_kwargs['num_diffusion_timesteps'])
        return Model(
            model_name=train_kwargs['model_name'],
            betas=betas,
            loss_type=train_kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=train_kwargs['dropout'],
            randflip=train_kwargs['randflip'],
            block_size=train_kwargs['block_size'])

    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        model_constructor=build_model,
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=ds,
        # Limit the dataset size used for Inception features, for memory reasons.
        limit_dataset_size=30000,
    )
    worker.run(
        logdir=model_dir,
        once=once,
        skip_non_ema_pass=True,
        dump_samples_only=dump_samples_only,
        samples_dir=samples_dir,
    )
def evaluation(
        model_dir, tpu_name, bucket_name_prefix, once=False,
        dump_samples_only=False, total_bs=128,
        tfds_data_dir='tensorflow_datasets', num_inception_samples=2048):
    """Evaluate the checkpoints of a training run using a TFDS data directory.

    The model is rebuilt from the training kwargs loaded for ``model_dir``
    by ``tpu_utils.load_train_kwargs``. ``tfds_data_dir`` is a path inside a
    regional bucket and is expanded to
    ``gs://<bucket_name_prefix>-<region>/<tfds_data_dir>``, with the region
    taken from ``utils.get_gcp_region()``.

    NOTE(review): this redefines an ``evaluation`` declared earlier in the
    module; only this definition is bound after import.

    Args:
        model_dir: directory of the training run. Its kwargs are loaded from
            here, and it is passed to ``worker.run`` as ``logdir``.
        tpu_name: TPU to run the ``EvalWorker`` on.
        bucket_name_prefix: bucket name prefix; the region is appended.
        once: forwarded to ``worker.run``.
        dump_samples_only: forwarded to ``worker.run``.
        total_bs: eval batch size, also used as the Inception batch size.
        tfds_data_dir: TFDS data directory inside the bucket.
        num_inception_samples: number of samples ``EvalWorker`` uses for the
            Inception-based metrics. Previously hard-coded to 2048; the
            default keeps that behavior.
    """
    region = utils.get_gcp_region()
    tfds_data_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfds_data_dir)
    kwargs = tpu_utils.load_train_kwargs(model_dir)
    print('loaded kwargs:', kwargs)
    ds = datasets.get_dataset(kwargs['dataset'], tfds_data_dir=tfds_data_dir)
    worker = tpu_utils.EvalWorker(
        tpu_name=tpu_name,
        # Rebuild the Model from the stored training hyper-parameters
        # (presumably the ``dump_kwargs`` saved by train; confirm).
        model_constructor=lambda: Model(
            model_name=kwargs['model_name'],
            betas=get_beta_schedule(
                kwargs['beta_schedule'],
                beta_start=kwargs['beta_start'],
                beta_end=kwargs['beta_end'],
                num_diffusion_timesteps=kwargs['num_diffusion_timesteps']),
            loss_type=kwargs['loss_type'],
            num_classes=ds.num_classes,
            dropout=kwargs['dropout'],
            randflip=kwargs['randflip'],
            block_size=kwargs['block_size']),
        total_bs=total_bs,
        inception_bs=total_bs,
        num_inception_samples=num_inception_samples,
        dataset=ds,
    )
    worker.run(logdir=model_dir, once=once, skip_non_ema_pass=True,
               dump_samples_only=dump_samples_only)
def train(exp_name, tpu_name, bucket_name_prefix,
          model_name='unet2d16b2c112244', dataset='lsun', optimizer='adam',
          total_bs=64, grad_clip=1., lr=2e-5, warmup=5000,
          num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02,
          beta_schedule='linear', loss_type='noisepred', dropout=0.0,
          randflip=1, block_size=1,
          tfr_file='tensorflow_datasets/lsun/church/church-r08.tfrecords',
          log_dir='logs', warm_start_model_dir=None):
    """Launch a TPU training run on a TFRecord dataset, optionally warm-started.

    ``tfr_file`` and ``log_dir`` are paths inside a regional bucket and are
    expanded to ``gs://<bucket_name_prefix>-<region>/<path>``. The region
    comes from ``utils.get_gcp_region()``.

    The experiment name encodes the main hyper-parameters. A snapshot of all
    arguments (plus ``region`` and the expanded paths) is passed to
    ``tpu_utils.run_training`` as ``dump_kwargs``.

    NOTE(review): this redefines a ``train`` declared earlier in the module;
    only this definition is bound after import.

    Args:
        warm_start_model_dir: if truthy, all variables (``".*"``) are
            initialized from the latest checkpoint found in this directory.

    Raises:
        ValueError: if ``warm_start_model_dir`` is given but
            ``tf.train.latest_checkpoint`` finds no checkpoint in it.
    """
    region = utils.get_gcp_region()
    tfr_file = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, tfr_file)
    log_dir = 'gs://{}-{}/{}'.format(bucket_name_prefix, region, log_dir)
    print("tfr_file:", tfr_file)
    print("log_dir:", log_dir)
    # Take the snapshot before binding any other local: it feeds both the
    # run-name template and ``dump_kwargs``.
    kwargs = dict(locals())

    ds = datasets.get_dataset(dataset, tfr_file=tfr_file)

    warm_start_from = None
    if warm_start_model_dir:
        # latest_checkpoint returns None when the directory holds no
        # checkpoint. Fail loudly here instead of handing None to the
        # Estimator.
        ckpt_path = tf.train.latest_checkpoint(warm_start_model_dir)
        if ckpt_path is None:
            raise ValueError(
                'No checkpoint found in warm_start_model_dir: {}'.format(
                    warm_start_model_dir))
        warm_start_from = tf.estimator.WarmStartSettings(
            ckpt_to_initialize_from=ckpt_path, vars_to_warm_start=[".*"])

    tpu_utils.run_training(
        date_str='9999-99-99',
        exp_name=(
            '{exp_name}_{dataset}_{model_name}_{optimizer}'
            '_bs{total_bs}_lr{lr}w{warmup}'
            '_beta{beta_start}-{beta_end}-{beta_schedule}'
            '_t{num_diffusion_timesteps}_{loss_type}'
            '_dropout{dropout}_randflip{randflip}_blk{block_size}'
        ).format(**kwargs),
        model_constructor=lambda: Model(
            model_name=model_name,
            betas=get_beta_schedule(
                beta_schedule,
                beta_start=beta_start,
                beta_end=beta_end,
                num_diffusion_timesteps=num_diffusion_timesteps),
            loss_type=loss_type,
            num_classes=ds.num_classes,
            dropout=dropout,
            randflip=randflip,
            block_size=block_size),
        optimizer=optimizer,
        total_bs=total_bs,
        lr=lr,
        warmup=warmup,
        grad_clip=grad_clip,
        train_input_fn=ds.train_input_fn,
        tpu=tpu_name,
        log_dir=log_dir,
        dump_kwargs=kwargs,
        warm_start_from=warm_start_from)
def _load_model(kwargs, ds):
    """Construct a ``Model`` from a mapping of training hyper-parameters.

    Args:
        kwargs: mapping that must provide ``model_name``, ``beta_schedule``,
            ``beta_start``, ``beta_end``, ``num_diffusion_timesteps``,
            ``loss_type``, ``dropout``, ``block_size`` and ``randflip``.
        ds: dataset object; only its ``num_classes`` attribute is read.

    Returns:
        The newly constructed ``Model``.
    """
    noise_schedule = get_beta_schedule(
        kwargs['beta_schedule'],
        beta_start=kwargs['beta_start'],
        beta_end=kwargs['beta_end'],
        num_diffusion_timesteps=kwargs['num_diffusion_timesteps'],
    )
    model_args = {
        'model_name': kwargs['model_name'],
        'betas': noise_schedule,
        'loss_type': kwargs['loss_type'],
        'num_classes': ds.num_classes,
        'dropout': kwargs['dropout'],
        'block_size': kwargs['block_size'],
        'randflip': kwargs['randflip'],
    }
    return Model(**model_args)