Ejemplo n.º 1
0
 def setUp(self):
     """Configure FLAGS and scratch paths for a tiny, synthetic-data run.

     On the first test the Transformer flags are defined, parsed with a dummy
     argv and snapshotted into ``TransformerTaskTest.local_flags``. Later tests
     restore that snapshot, so every test starts from the same flag values
     before the per-test overrides below are applied.
     """
     # Let the base TestCase run its own per-test fixture setup first
     # (presumably tf.test.TestCase, given the use of get_temp_dir()).
     super().setUp()
     temp_dir = self.get_temp_dir()
     if TransformerTaskTest.local_flags is None:
         misc.define_transformer_flags()
         # Parse flags with a dummy argv; the argv list must not be empty.
         flags.FLAGS(['foo'])
         TransformerTaskTest.local_flags = flagsaver.save_flag_values()
     else:
         flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
     # Smallest configuration that still exercises train/eval: tiny params,
     # synthetic data, two training steps, single device, fp32.
     FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
     FLAGS.param_set = 'tiny'
     FLAGS.use_synthetic_data = True
     FLAGS.steps_between_evals = 1
     FLAGS.train_steps = 2
     FLAGS.validation_steps = 1
     FLAGS.batch_size = 8
     FLAGS.num_gpus = 1
     FLAGS.distribution_strategy = 'off'
     FLAGS.dtype = 'fp32'
     self.model_dir = FLAGS.model_dir
     self.temp_dir = temp_dir
     self.vocab_file = os.path.join(temp_dir, 'vocab')
     self.vocab_size = misc.get_model_params(FLAGS.param_set,
                                             0)['vocab_size']
     self.bleu_source = os.path.join(temp_dir, 'bleu_source')
     self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
     # Remember the global mixed-precision policy, presumably so tearDown can
     # restore it after tests that change it -- verify against tearDown.
     self.orig_policy = (
         tf.keras.mixed_precision.experimental.global_policy())
Ejemplo n.º 2
0
        return opt


def _ensure_dir(log_dir):
    """Create ``log_dir`` (with any missing parents) unless it already exists.

    Args:
      log_dir: Directory path accepted by ``tf.io.gfile``.
    """
    if tf.io.gfile.exists(log_dir):
        # Nothing to do: something already lives at this path.
        return
    tf.io.gfile.makedirs(log_dir)


def main(_):
    """Build a TransformerTask from the flags and run the selected ``--mode``.

    Args:
      _: Remaining command-line arguments passed in by ``app.run``; unused.

    Raises:
      ValueError: If ``--mode`` is not "train", "predict" or "eval".
    """
    flags_obj = flags.FLAGS
    with logger.benchmark_context(flags_obj):
        task = TransformerTask(flags_obj)

        selected_mode = flags_obj.mode
        # Each supported mode has a TransformerTask method of the same name.
        if selected_mode not in ("train", "predict", "eval"):
            raise ValueError("Invalid mode {}".format(selected_mode))
        getattr(task, selected_mode)()


if __name__ == "__main__":
    # Opt into TF2 behaviour before any other TensorFlow work happens.
    tf.compat.v1.enable_v2_behavior()
    logging.set_verbosity(logging.INFO)
    # The Transformer flags must be defined before app.run parses argv.
    misc.define_transformer_flags()
    # app.run parses the command line into flags.FLAGS, then calls main().
    app.run(main)