예제 #1
0
def save_metadata(hparams):
    """Write the current flags and the given hparams into FLAGS.output_dir.

    The directory (after `~` expansion) is created when it does not exist.
    Files written there:
      * flags.txt: the text from `FLAGS.flags_into_string()`, or, when FLAGS
        has no such method, one `--name=value` line per entry of its raw
        "__flags" dict.
      * flags_t2t.txt: one `--name=value` line per flag registered under
        "tensor2tensor.utils.flags"; only produced on the
        `flags_into_string` path and only when that text is non-empty.
      * hparams.json: JSON of a copy of `hparams` with "modality" removed.

    Args:
      hparams: hyperparameters to save; the "modality" entry is deleted from
        the object returned by `hparams_lib.copy_hparams`, not from this one.
    """
    target_dir = os.path.expanduser(FLAGS.output_dir)
    if not tf.gfile.Exists(target_dir):
        tf.gfile.MakeDirs(target_dir)

    def _dump(basename, text):
        # Write already-built text to a file inside the output directory.
        with tf.gfile.Open(os.path.join(target_dir, basename), "w") as out:
            out.write(text)

    # Collect the flag listings.
    if not hasattr(FLAGS, "flags_into_string"):
        # No flags_into_string available: fall back to the raw flag mapping.
        all_flags_text = "\n".join(
            "--%s=%s" % (name, str(value))
            for name, value in FLAGS.__dict__["__flags"].items())
        t2t_flags_text = None
    else:
        all_flags_text = FLAGS.flags_into_string()
        t2t_flags = FLAGS.flags_by_module_dict()["tensor2tensor.utils.flags"]
        t2t_flags_text = "\n".join(
            "--%s=%s" % (flag.name, flag.value) for flag in t2t_flags)

    _dump("flags.txt", all_flags_text)
    if t2t_flags_text:
        _dump("flags_t2t.txt", t2t_flags_text)

    # Save hparams as hparams.json. The Modality class is not JSON
    # serializable, so it is removed from the copy first.
    serializable = hparams_lib.copy_hparams(hparams)
    serializable.del_hparam("modality")

    with tf.gfile.Open(os.path.join(target_dir, "hparams.json"), "w") as out:
        out.write(serializable.to_json(indent=0, sort_keys=True))
예제 #2
0
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        """Build the Estimator spec for a Mesh TensorFlow model.

        Copies `hparams`, instantiates `cls` with the copy, builds the model
        on a new `mtf.Graph`, lowers it onto a mesh implementation and
        returns the spec that matches `mode`.

        Args:
          hparams: model hyperparameters. A copy is made first; the copy
            receives `use_tpu` and, in PREDICT mode, every entry of
            `decode_hparams`.
          features: forwarded to the model's `mtf_model_fn`,
            `estimator_spec_predict` or `estimator_spec_eval`.
          labels: only forwarded to `estimator_spec_eval`.
          mode: a `tf.estimator.ModeKeys` value.
          config: optional; when truthy and `use_tpu` is false, its
            `data_parallelism` attribute is used to pick the mesh devices.
          params: must provide `params["context"]` when `use_tpu` is true;
            not read otherwise.
          decode_hparams: optional; passed to the model constructor and, in
            PREDICT mode, merged into the hparams copy.
          use_tpu: selects the SIMD mesh implementation (true) or the
            placement mesh implementation (false).

        Returns:
          In PREDICT mode the result of `model.estimator_spec_predict`; in
          EVAL mode the result of `model.estimator_spec_eval`; otherwise a
          `TPUEstimatorSpec` (TPU) or `tf.estimator.EstimatorSpec` built
          with `ModeKeys.TRAIN`.
        """
        hparams = hparams_lib.copy_hparams(hparams)
        hparams.use_tpu = use_tpu
        # Merge decode_hparams into hparams if present.
        if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
            # values() is iterated as a name -> value mapping.
            for k, v in six.iteritems(decode_hparams.values()):
                if hasattr(hparams, k) and getattr(hparams, k) != v:
                    tf.logging.warning(
                        "Overriding hparams.%s with %s from decode_hparams" %
                        (k, v))
                setattr(hparams, k, v)

        # Instantiate model
        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams)

        global_step = tf.train.get_global_step()

        mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(hparams.layout)
        if use_tpu:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape, layout_rules, mesh_devices, ctx.device_assignment)
        else:
            var_placer = None
            if data_parallelism is None or len(
                    data_parallelism.ps_devices) == 1:
                # No explicit device list: one empty device string per
                # mesh position.
                mesh_devices = [""] * mesh_shape.size
            else:
                assert len(data_parallelism.ps_devices) == mesh_shape.size
                mesh_devices = data_parallelism.ps_devices
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)
        # PREDICT mode
        if mode == tf.estimator.ModeKeys.PREDICT:
            return model.estimator_spec_predict(features, mesh, mesh_impl,
                                                use_tpu)

        logits, loss = model.mtf_model_fn(features, mesh)
        if use_tpu and logits is not None:
            logits = mtf.anonymize(logits)

        # TRAIN mode: build gradient and optimizer ops on the mtf graph
        # before it is lowered.
        if mode == tf.estimator.ModeKeys.TRAIN:
            var_grads = mtf.gradients(
                [loss], [v.outputs[0] for v in graph.trainable_variables])
            lr = learning_rate.learning_rate_schedule(hparams)
            tf.summary.scalar("learning_rate", lr)
            mtf_lr = mtf.import_tf_tensor(
                mesh, tf.convert_to_tensor(lr, dtype=tf.float32),
                mtf.Shape([]))
            optimizer = mtf.optimize.make_optimizer(hparams, mtf_lr)
            update_ops = optimizer.apply_grads(var_grads,
                                               graph.trainable_variables)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = lowering.export_to_tf_tensor(loss)
        tf_loss = tf.to_float(tf_loss)
        # Logits are exported in the EVAL branch below, which is the only
        # place they are consumed.

        if mode == tf.estimator.ModeKeys.TRAIN:
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                hparams.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

        # EVAL mode
        if mode == tf.estimator.ModeKeys.EVAL:
            tf_logits = lowering.export_to_tf_tensor(logits)
            return model.estimator_spec_eval(features, tf_logits, labels,
                                             tf_loss, restore_hook, use_tpu)

        if use_tpu:
            # TPU host call. Important: need to be called before remove_summaries()
            if hparams.tpu_enable_host_call:
                host_call = t2t_model.create_host_call(hparams.model_dir)
            else:
                host_call = None

            if hparams.warm_start_from:

                def scaffold_fn():
                    t2t_model.initialize_from_ckpt(
                        ckpt_dir=hparams.warm_start_from, hparams=hparams)
                    return tf.train.Scaffold()
            else:
                scaffold_fn = None

            t2t_model.remove_summaries()
            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                host_call=host_call,
                training_hooks=[restore_hook, saver_hook],
                scaffold_fn=scaffold_fn)
        else:
            if hparams.warm_start_from:
                t2t_model.initialize_from_ckpt(
                    ckpt_dir=hparams.warm_start_from, hparams=hparams)
            return tf.estimator.EstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_chief_hooks=[restore_hook, saver_hook])
예제 #3
0
    def estimator_model_fn(cls,
                           hparams,
                           features,
                           labels,
                           mode,
                           config=None,
                           params=None,
                           decode_hparams=None,
                           use_tpu=False):
        """Build the Estimator spec for a GAN whose networks are hub modules.

        Instantiates `cls`, wraps its generator and discriminator in
        `hub.Module` objects, assembles a GAN model from them and returns
        the train, eval or predict spec matching `mode`.

        Args:
          hparams: model hyperparameters; a copy is used. Fields read here:
            `batch_size`, `generator_lr`, `discriminator_lr`, `beta1`,
            `apply_psf`, `gen_steps`, `disc_steps`.
          features: dict of tensors. 'inputs' and 'ps' are always read;
            'psf' is read only when present and `hparams.apply_psf` is set.
          labels: not used by this function.
          mode: one of the TRAIN, EVAL or PREDICT mode keys.
          config: optional; when truthy and `use_tpu` is false, its
            `data_parallelism` attribute is given to the model constructor.
          params: not used by this function.
          decode_hparams: forwarded to the model constructor.
          use_tpu: when true, `config.data_parallelism` is ignored.

        Returns:
          The estimator spec produced by `get_train_estimator_spec`,
          `get_eval_estimator_spec` or `get_predict_estimator_spec`.

        Raises:
          ValueError: if `mode` is not TRAIN, EVAL or PREDICT.
        """
        if mode not in (model_fn_lib.ModeKeys.TRAIN,
                        model_fn_lib.ModeKeys.EVAL,
                        model_fn_lib.ModeKeys.PREDICT):
            raise ValueError('Mode not recognized: %s' % mode)

        hparams = hparams_lib.copy_hparams(hparams)

        data_parallelism = None
        if not use_tpu and config:
            data_parallelism = config.data_parallelism
        reuse = tf.get_variable_scope().reuse

        # Instantiate model
        model = cls(hparams,
                    mode,
                    data_parallelism=data_parallelism,
                    decode_hparams=decode_hparams,
                    _reuse=reuse)

        generator_inputs = model.sample_noise()
        # rename inputs for clarity
        real_data = features['inputs']
        img_shape = common_layers.shape_list(real_data)[1:4]
        real_data.set_shape([hparams.batch_size] + img_shape)

        # To satisfy the TFGAN API, real data is set to None on predict.
        if mode == tf.estimator.ModeKeys.PREDICT:
            real_data = None

        # First optimizer uses the generator learning rate, second the
        # discriminator learning rate.
        optimizers = Optimizers(
            tf.compat.v1.train.AdamOptimizer(hparams.generator_lr,
                                             hparams.beta1),
            tf.compat.v1.train.AdamOptimizer(hparams.discriminator_lr,
                                             hparams.beta1))

        # Creates tfhub modules for both generator and discriminator
        def make_discriminator_spec():
            input_layer = tf.placeholder(tf.float32, shape=[None] + img_shape)
            disc_output = model.discriminator(input_layer, None, mode)
            hub.add_signature(inputs=input_layer, outputs=disc_output)

        disc_spec = hub.create_module_spec(make_discriminator_spec)

        def make_generator_spec():
            input_layer = tf.placeholder(
                tf.float32,
                shape=[None] + common_layers.shape_list(generator_inputs)[1:])
            gen_output = model.generator(input_layer, mode)
            hub.add_signature(inputs=input_layer, outputs=gen_output)

        gen_spec = hub.create_module_spec(make_generator_spec)

        # Create the modules
        discriminator_module = hub.Module(disc_spec,
                                          name="Discriminator_Module",
                                          trainable=True)
        generator_module = hub.Module(gen_spec,
                                      name="Generator_Module",
                                      trainable=True)

        # Wraps the modules into functions expected by TF-GAN
        def generator(code, mode):
            # `mode` is accepted for the expected call signature but unused.
            out = generator_module(code)
            # Applying convolution by PSF convolution
            if hparams.apply_psf and 'psf' in features:
                out = convolve(out,
                               tf.cast(features['psf'][..., 0], tf.complex64))

            # Adds noise according to the provided power spectrum; entries
            # where features['ps'] >= 9 contribute no noise.
            noise = tf.spectral.rfft2d(tf.random_normal(out.get_shape()[:3]))
            thresholded_ps = tf.where(features['ps'] >= 9,
                                      tf.zeros_like(features['ps']),
                                      tf.sqrt(tf.exp(features['ps'])))
            noise = noise * tf.cast(thresholded_ps, tf.complex64)
            out = out + tf.expand_dims(tf.spectral.irfft2d(noise), axis=-1)
            return out

        def discriminator(image, conditioning, mode):
            # Only `image` is used; the other arguments are ignored.
            return discriminator_module(image)

        # Make GANModel, which encapsulates the GAN model architectures.
        gan_model = get_gan_model(mode,
                                  generator,
                                  discriminator,
                                  real_data,
                                  generator_inputs,
                                  add_summaries=model.summaries)

        # Make GANLoss, which encapsulates the losses.
        if mode in [tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL]:
            gan_loss = tfgan_train.gan_loss(gan_model,
                                            model.generator_loss,
                                            model.discriminator_loss,
                                            add_summaries=True)

        # Make the EstimatorSpec, which incorporates the GANModel, losses, eval
        # metrics, and optimizers (if required).
        if mode == tf.estimator.ModeKeys.TRAIN:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks(
                namedtuples.GANTrainSteps(hparams.gen_steps,
                                          hparams.disc_steps))
            estimator_spec = get_train_estimator_spec(gan_model,
                                                      gan_loss,
                                                      optimizers,
                                                      get_hooks_fn,
                                                      is_chief=True)
        elif mode == tf.estimator.ModeKeys.EVAL:
            estimator_spec = get_eval_estimator_spec(gan_model, gan_loss)
        else:  # tf.estimator.ModeKeys.PREDICT
            # Register hub modules for export
            hub.register_module_for_export(generator_module, "generator")
            hub.register_module_for_export(discriminator_module,
                                           "discriminator")
            estimator_spec = get_predict_estimator_spec(gan_model)
        return estimator_spec