def save_metadata(hparams):
    """Dump the command-line flags and `hparams` into `FLAGS.output_dir`.

    The output directory (with `~` expanded) is created when it does not
    exist yet. Up to three files are then written into it:

    * `flags.txt`: the text returned by `FLAGS.flags_into_string()` when
      that method exists, otherwise `--name=value` lines built from the raw
      `__flags` mapping.
    * `flags_t2t.txt`: `--name=value` lines for the flags listed under the
      `tensor2tensor.utils.flags` module; only written when
      `flags_into_string` exists and the resulting text is non-empty.
    * `hparams.json`: JSON for a copy of `hparams` from which the
      `modality` entry has been deleted.

    Args:
      hparams: hyperparameter object; it is copied via
        `hparams_lib.copy_hparams` before the `modality` entry is removed.
    """
    out_dir = os.path.expanduser(FLAGS.output_dir)
    if not tf.gfile.Exists(out_dir):
        tf.gfile.MakeDirs(out_dir)

    # Render the flags as text.
    if not hasattr(FLAGS, "flags_into_string"):
        # No flags_into_string(): fall back to the raw flag mapping.
        raw_flags = FLAGS.__dict__["__flags"]
        all_flags_text = "\n".join(
            "--%s=%s" % (flag_name, str(flag))
            for (flag_name, flag) in raw_flags.items()
        )
        t2t_flags_text = None
    else:
        all_flags_text = FLAGS.flags_into_string()
        t2t_flags = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
        t2t_flags_text = "\n".join(
            "--%s=%s" % (flag.name, flag.value) for flag in t2t_flags
        )

    # flags.txt is always written; flags_t2t.txt only when there is content.
    pending_files = [("flags.txt", all_flags_text)]
    if t2t_flags_text:
        pending_files.append(("flags_t2t.txt", t2t_flags_text))
    for basename, text in pending_files:
        with tf.gfile.Open(os.path.join(out_dir, basename), "w") as out_file:
            out_file.write(text)

    # Save hparams as hparams.json.
    serializable_hparams = hparams_lib.copy_hparams(hparams)
    # Modality class is not JSON serializable so remove.
    serializable_hparams.del_hparam("modality")

    with tf.gfile.Open(os.path.join(out_dir, "hparams.json"), "w") as out_file:
        out_file.write(serializable_hparams.to_json(indent=0, sort_keys=True))
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Build the estimator spec for a Mesh TensorFlow model.

    Args:
      cls: class instantiated as the model, called as
        `cls(hparams, mode, data_parallelism=..., decode_hparams=...)`.
      hparams: hyperparameters; a copy is made and `use_tpu` is set on it.
      features: input features forwarded to the model.
      labels: labels; only forwarded to `estimator_spec_eval`.
      mode: a `tf.estimator.ModeKeys` value.
      config: optional run config; when truthy and `use_tpu` is false, its
        `data_parallelism` attribute is handed to the model.
      params: mapping that must provide a `"context"` entry when `use_tpu`
        is true; not read otherwise.
      decode_hparams: optional decode hyperparameters. In PREDICT mode each
        of its values is copied onto the `hparams` copy.
      use_tpu: whether to build the SIMD (TPU) mesh implementation.

    Returns:
      In PREDICT mode, the result of `model.estimator_spec_predict`; in EVAL
      mode, the result of `model.estimator_spec_eval`; otherwise a training
      spec (`tpu_estimator.TPUEstimatorSpec` when `use_tpu`, else
      `tf.estimator.EstimatorSpec`).
    """
    mode_keys = tf.estimator.ModeKeys
    predicting = mode == mode_keys.PREDICT
    training = mode == mode_keys.TRAIN
    evaluating = mode == mode_keys.EVAL

    hparams = hparams_lib.copy_hparams(hparams)
    hparams.use_tpu = use_tpu

    # Merge decode_hparams into hparams if present.
    if predicting and decode_hparams is not None:
        for key, value in six.iteritems(decode_hparams.values()):
            overrides_existing = (
                hasattr(hparams, key) and getattr(hparams, key) != value)
            if overrides_existing:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams"
                    % (key, value))
            setattr(hparams, key, value)

    # Instantiate the model.
    data_parallelism = (
        config.data_parallelism if (not use_tpu and config) else None)
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)

    if use_tpu:
        tpu_context = params["context"]
        host_count = tpu_context.num_hosts
        place_on_host = tpu_context.tpu_host_placement_function
        host_devices = [
            place_on_host(host_id=host) for host in range(host_count)
        ]
        # TODO(ylc): Better estimation of replica cache size?
        per_replica_cache = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        host_memory_usage = [per_replica_cache * tpu_context.num_replicas]
        host_memory_usage = host_memory_usage + [0] * (host_count - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(
            host_devices, host_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices,
            tpu_context.device_assignment)
    else:
        var_placer = None
        if (data_parallelism is not None
                and len(data_parallelism.ps_devices) != 1):
            # Several ps devices: exactly one per mesh slot is required.
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        else:
            mesh_devices = [""] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh", var_placer)

    # PREDICT mode: the model builds its own spec from the mesh.
    if predicting:
        return model.estimator_spec_predict(
            features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode: gradients and optimizer updates on the mtf graph.
    if training:
        var_grads = mtf.gradients(
            [loss], [var.outputs[0] for var in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        tf.summary.scalar("learning_rate", lr)
        lr_tensor = tf.convert_to_tensor(lr, dtype=tf.float32)
        mtf_lr = mtf.import_tf_tensor(mesh, lr_tensor, mtf.Shape([]))
        optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
        update_ops = optimizer.apply_grads(
            var_grads, graph.trainable_variables)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))
    if logits and not training:
        tf_logits = lowering.export_to_tf_tensor(logits)

    if training:
        tf_update_ops = [
            lowering.lowered_operation(op) for op in update_ops
        ]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        checkpoint_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[checkpoint_listener])

    # EVAL mode.
    if evaluating:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(
            features, tf_logits, labels, tf_loss, restore_hook, use_tpu)

    # Training spec for the non-TPU case.
    if not use_tpu:
        if hparams.warm_start_from:
            t2t_model.initialize_from_ckpt(
                ckpt_dir=hparams.warm_start_from, hparams=hparams)
        return tf.estimator.EstimatorSpec(
            tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_chief_hooks=[restore_hook, saver_hook])

    # TPU host call. Important: need to be called before remove_summaries()
    host_call = (
        t2t_model.create_host_call(hparams.model_dir)
        if hparams.tpu_enable_host_call else None)

    def _warm_start_scaffold():
        # Initialize from the warm-start checkpoint, then use a default
        # Scaffold.
        t2t_model.initialize_from_ckpt(
            ckpt_dir=hparams.warm_start_from, hparams=hparams)
        return tf.train.Scaffold()

    scaffold_fn = _warm_start_scaffold if hparams.warm_start_from else None

    t2t_model.remove_summaries()
    return tpu_estimator.TPUEstimatorSpec(
        mode=tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        host_call=host_call,
        training_hooks=[restore_hook, saver_hook],
        scaffold_fn=scaffold_fn)
def estimator_model_fn(cls,
                       hparams,
                       features,
                       labels,
                       mode,
                       config=None,
                       params=None,
                       decode_hparams=None,
                       use_tpu=False):
    """Build the estimator spec for a GAN whose networks are TF-Hub modules.

    The generator and discriminator of the instantiated model are wrapped
    into TF-Hub modules, combined into a GAN model via `get_gan_model`, and
    turned into the estimator spec matching `mode`.

    Args:
      cls: class instantiated as the model, called as
        `cls(hparams, mode, data_parallelism=..., decode_hparams=...,
        _reuse=...)`.
      hparams: hyperparameters; copied before use. Reads `batch_size`,
        `generator_lr`, `discriminator_lr`, `beta1`, `apply_psf`, and (in
        TRAIN mode) `gen_steps` / `disc_steps`.
      features: mapping providing `'inputs'` and `'ps'`, and optionally
        `'psf'`. `'inputs'` is read in every mode, including PREDICT, where
        the real data is discarded afterwards.
      labels: unused.
      mode: one of the TRAIN / EVAL / PREDICT mode keys.
      config: optional run config; when truthy and `use_tpu` is false, its
        `data_parallelism` attribute is handed to the model.
      params: unused.
      decode_hparams: forwarded to the model constructor.
      use_tpu: only affects whether `config.data_parallelism` is used.

    Returns:
      The spec produced by `get_train_estimator_spec`,
      `get_eval_estimator_spec` or `get_predict_estimator_spec`, depending
      on `mode`.

    Raises:
      ValueError: if `mode` is not one of the three recognized mode keys.
    """
    if mode not in [
            model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
            model_fn_lib.ModeKeys.PREDICT
    ]:
        raise ValueError('Mode not recognized: %s' % mode)

    hparams = hparams_lib.copy_hparams(hparams)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    reuse = tf.get_variable_scope().reuse
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams,
        _reuse=reuse)

    generator_inputs = model.sample_noise()

    # Rename inputs for clarity and pin the batch dimension.
    real_data = features['inputs']
    img_shape = common_layers.shape_list(real_data)[1:4]
    real_data.set_shape([hparams.batch_size] + img_shape)

    # To satisfy the TF-GAN API, real data is set to None on predict.
    if mode == tf.estimator.ModeKeys.PREDICT:
        real_data = None

    optimizers = Optimizers(
        tf.compat.v1.train.AdamOptimizer(hparams.generator_lr, hparams.beta1),
        tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                         hparams.beta1))

    # Create TF-Hub module specs for both generator and discriminator.
    def make_discriminator_spec():
        input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
        disc_output = model.discriminator(input_layer, None, mode)
        hub.add_signature(inputs=input_layer, outputs=disc_output)

    disc_spec = hub.create_module_spec(make_discriminator_spec)

    def make_generator_spec():
        input_layer = tf.placeholder(
            tf.float32,
            shape=[None] + common_layers.shape_list(generator_inputs)[1:])
        gen_output = model.generator(input_layer, mode)
        hub.add_signature(inputs=input_layer, outputs=gen_output)

    gen_spec = hub.create_module_spec(make_generator_spec)

    # Create the modules
    discriminator_module = hub.Module(
        disc_spec, name="Discriminator_Module", trainable=True)
    generator_module = hub.Module(
        gen_spec, name="Generator_Module", trainable=True)

    # Wrap the modules into the callables handed to get_gan_model.
    def generator(code, mode):
        # `mode` is accepted but not used here.
        # NOTE(review): presumably the TF-GAN wrapper calls this as
        # (code, mode) -- confirm against get_gan_model.
        out = generator_module(code)

        # Convolve with the PSF when enabled and one is provided.
        if hparams.apply_psf and 'psf' in features:
            out = convolve(
                out, tf.cast(features['psf'][..., 0], tf.complex64))

        # Add noise shaped by features['ps']: entries with ps >= 9 are
        # zeroed, all others are scaled by sqrt(exp(ps)).
        # NOTE(review): presumably 'ps' is a log power spectrum and 9 is a
        # masking threshold -- confirm.
        power_spectrum = features['ps']
        noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
        thresholded_ps = tf.where(power_spectrum >= 9,
                                  tf.zeros_like(power_spectrum),
                                  tf.sqrt(tf.exp(power_spectrum)))
        noise = noise * tf.cast(thresholded_ps, tf.complex64)
        out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
        return out

    def discriminator(image, conditioning, mode):
        # Only the image is passed on; conditioning and mode are ignored.
        return discriminator_module(image)

    # Make GANModel, which encapsulates the GAN model architectures.
    gan_model = get_gan_model(
        mode,
        generator,
        discriminator,
        real_data,
        generator_inputs,
        add_summaries=model.summaries)

    # Make GANLoss, which encapsulates the losses.
    if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
        gan_loss = tfgan_train.gan_loss(
            gan_model,
            model.generator_loss,
            model.discriminator_loss,
            add_summaries=True)

    # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
    # metrics, and optimizers (if required).
    if mode == tf.estimator.ModeKeys.TRAIN:
        get_hooks_fn = tfgan_train.get_sequential_train_hooks(
            namedtuples.GANTrainSteps(hparams.gen_steps, hparams.disc_steps))
        estimator_spec = get_train_estimator_spec(
            gan_model, gan_loss, optimizers, get_hooks_fn, is_chief=True)
    elif mode == tf.estimator.ModeKeys.EVAL:
        estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
    else:  # tf.estimator.ModeKeys.PREDICT
        # Register hub modules for export
        hub.register_module_for_export(generator_module, "generator")
        hub.register_module_for_export(discriminator_module, "discriminator")
        estimator_spec = get_predict_estimator_spec(gan_model)
    return estimator_spec