Example #1
  def testNoneHparams(self):

    @registry.register_hparams
    def hp():
      pass

    with self.assertRaisesRegexp(TypeError, "is None"):
      registry.hparams("hp")
Example #2
  def __prepare_model(self, train_mode=False):
    """Prepare utilities for decoding."""
    hparams = registry.hparams(self.params.hparams_set)
    hparams.problem = self.problem
    hparams.problem_hparams = self.problem.get_hparams(hparams)
    if self.params.hparams:
      tf.logging.info("Overriding hparams in %s with %s",
                      self.params.hparams_set,
                      self.params.hparams)
      hparams = hparams.parse(self.params.hparams)
    trainer_run_config = g2p_trainer_utils.create_run_config(hparams,
        self.params)
    if train_mode:
      exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
      self.exp = exp_fn(trainer_run_config, hparams)

    decode_hp = decoding.decode_hparams(self.params.decode_hparams)
    estimator = trainer_lib.create_estimator(
        self.params.model_name,
        hparams,
        trainer_run_config,
        decode_hparams=decode_hp,
        use_tpu=False)

    return estimator, decode_hp, hparams
Example #3
def training_loop_hparams_from_scoped_overrides(scoped_overrides, trial_id):
  """Create HParams suitable for training loop from scoped HParams.

  Args:
    scoped_overrides: HParams, with keys all scoped by one of HP_SCOPES. These
      parameters are overrides for the base HParams created by
      create_loop_hparams.
    trial_id: str, trial identifier. This is used to register unique HParams
      names for the underlying model and ppo HParams.

  Returns:
    HParams suitable for passing to training_loop.
  """
  trial_hp_overrides = scoped_overrides.values()

  # Create loop, model, and ppo base HParams
  loop_hp = create_loop_hparams()
  model_hp_name = trial_hp_overrides.get(
      "loop.generative_model_params", loop_hp.generative_model_params)
  model_hp = registry.hparams(model_hp_name).parse(FLAGS.hparams)
  base_algo_params_name = trial_hp_overrides.get(
      "loop.base_algo_params", loop_hp.base_algo_params)
  algo_hp = registry.hparams(base_algo_params_name)

  # Merge them and then override with the scoped overrides
  combined_hp = merge_unscoped_hparams(
      zip(HP_SCOPES, [loop_hp, model_hp, algo_hp]))
  combined_hp.override_from_dict(trial_hp_overrides)

  # Split out the component hparams
  loop_hp, model_hp, algo_hp = (
      split_scoped_hparams(HP_SCOPES, combined_hp))

  # Dynamically register the model hparams and set the new name in loop_hp
  model_hp_name = "model_hp_%s" % str(trial_id)
  dynamic_register_hparams(model_hp_name, model_hp)
  loop_hp.generative_model_params = model_hp_name

  # Dynamically register the algo hparams and set the new name in loop_hp
  algo_hp_name = "algo_hp_%s" % str(trial_id)
  dynamic_register_hparams(algo_hp_name, algo_hp)
  loop_hp.base_algo_params = algo_hp_name

  return loop_hp
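
The dynamic_register_hparams helper is not shown here; a minimal sketch, assuming it merely wraps registry.register_hparams around a function that returns a copy of the given HParams:

def dynamic_register_hparams(name, hparams):
  @registry.register_hparams(name)
  def _dynamic_hparams():
    # Rebuild a fresh HParams so later mutation of `hparams` has no effect.
    return tf.contrib.training.HParams(**hparams.values())
  return _dynamic_hparams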
Example #4
  def testNamedRegistration(self):

    @registry.register_hparams("a")
    def my_hparams_set():
      return 7

    @registry.register_ranged_hparams("a")
    def my_hparams_range(_):
      pass

    self.assertEqual(registry.hparams("a"), my_hparams_set())
    self.assertTrue(registry.ranged_hparams("a") is my_hparams_range)
Example #5
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
  hparams = registry.hparams(hparams_set)()
  if hparams_overrides_str:
    hparams = hparams.parse(hparams_overrides_str)
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  return hparams
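
This example calls registry.hparams(hparams_set)() with a trailing (), while other examples in this listing call registry.hparams(hparams_set) directly: older tensor2tensor registries returned the registered function, newer ones call it for you and return the HParams instance. A version-tolerant lookup, sketched as an assumption rather than taken from any example here:

def lookup_hparams(name):
  hp = registry.hparams(name)
  # Older registries hand back the registering function; call it if so.
  return hp() if callable(hp) else hp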
Example #6
  def testHParamSet(self):

    @registry.register_hparams
    def my_hparams_set():
      return 3

    @registry.register_ranged_hparams
    def my_hparams_range(_):
      pass

    self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
    self.assertTrue(
        registry.ranged_hparams("my_hparams_range") is my_hparams_range)
Example #7
  def testExperiment(self):
    exp_fn = lib.create_experiment_fn(
        "transformer",
        "tiny_algo",
        trainer_utils_test.TrainerUtilsTest.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False)
    run_config = lib.create_run_config(num_gpus=0, use_tpu=False)
    hparams = registry.hparams("transformer_tiny_tpu")()
    exp = exp_fn(run_config, hparams)
    exp.test()
Example #8
  def testExperiment(self):
    exp_fn = trainer_lib.create_experiment_fn(
        "transformer",
        "tiny_algo",
        self.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False)
    run_config = trainer_lib.create_run_config(
        model_dir=self.data_dir, num_gpus=0, use_tpu=False)
    hparams = registry.hparams("transformer_tiny_tpu")
    exp = exp_fn(run_config, hparams)
    exp.test()
Example #9
  def __init__(self,
               hparams,
               mode=tf.estimator.ModeKeys.TRAIN,
               problem_hparams=None,
               data_parallelism=None,
               decode_hparams=None):
    assert hparams.distill_phase in ["train", "distill"]

    if hparams.distill_phase == "train" and hparams.teacher_learning_rate:
      hparams.learning_rate = hparams.teacher_learning_rate
    elif hparams.distill_phase == "distill" and hparams.student_learning_rate:
      hparams.learning_rate = hparams.student_learning_rate

    self.teacher_hparams = registry.hparams(hparams.teacher_hparams)
    self.teacher_model = registry.model(
        hparams.teacher_model)(self.teacher_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    self.student_hparams = registry.hparams(hparams.student_hparams)
    self.student_model = registry.model(
        hparams.student_model)(self.student_hparams, mode, problem_hparams,
                               data_parallelism, decode_hparams)
    super(Distillation, self).__init__(hparams, mode, problem_hparams,
                                       data_parallelism, decode_hparams)
Example #10
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
    """Create HParams with data_dir and problem hparams, if kwargs provided."""
    hparams = registry.hparams(hparams_set)
    if data_dir:
        hparams.add_hparam("data_dir", data_dir)
    if problem_name:
        add_problem_hparams(hparams, problem_name)
    if hparams_overrides_str:
        tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                        hparams_overrides_str)
        hparams = hparams.parse(hparams_overrides_str)
    return hparams
Example #11
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
  """Create HParams with data_dir and problem hparams, if kwargs provided."""
  hparams = registry.hparams(hparams_set)
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  if hparams_overrides_str:
    tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                    hparams_overrides_str)
    hparams = hparams.parse(hparams_overrides_str)
  return hparams
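
A usage sketch for create_hparams above (the hparams set and problem names are illustrative registry entries):

hparams = create_hparams(
    "transformer_base",
    hparams_overrides_str="batch_size=1024,learning_rate=0.05",
    data_dir="/tmp/t2t_data",
    problem_name="translate_ende_wmt32k")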
Example #12
  def testExperimentWithClass(self):
    exp_fn = trainer_lib.create_experiment_fn(
        "transformer",
        algorithmic.TinyAlgo(),
        algorithmic.TinyAlgo.data_dir,
        train_steps=1,
        eval_steps=1,
        min_eval_frequency=1,
        use_tpu=False)
    run_config = trainer_lib.create_run_config(
        model_name="transformer",
        model_dir=algorithmic.TinyAlgo.data_dir,
        num_gpus=0,
        use_tpu=False)
    hparams = registry.hparams("transformer_tiny_tpu")
    exp = exp_fn(run_config, hparams)
    exp.test()
Example #14
    def _create_hparams(self, src_vocab_size, trg_vocab_size, hparams_set_name,
                        problem_name):
        """Creates hparams object.

        This method corresponds to create_hparams() in tensor2tensor's
        trainer_utils module, but replaces the feature encoders with
        DummyFeatureEncoder's.

        Args:
            src_vocab_size (int): Source vocabulary size.
            trg_vocab_size (int): Target vocabulary size.
            hparams_set_name (string): T2T hparams set name.
            problem_name (string): T2T problem name.

        Returns:
            hparams object.

        Raises:
            LookupError if the problem name is not in the registry or
            uses the old style problem_hparams.
        """
        hparams = registry.hparams(hparams_set_name)()
        problem = registry.problem(problem_name)
        # The following hack is necessary to prevent the problem from creating
        # the default TextEncoders, which would fail due to the lack of a
        # vocabulary file.
        problem._encoders = {
            "inputs": DummyTextEncoder(vocab_size=src_vocab_size),
            "targets": DummyTextEncoder(vocab_size=trg_vocab_size)
        }
        try:
            hparams.add_hparam("max_terminal_id", self.max_terminal_id)
        except ValueError:
            # add_hparam raises ValueError if the hparam already exists.
            if hparams.max_terminal_id != self.max_terminal_id:
                logging.warn("T2T max_terminal_id does not match (%d!=%d)" %
                             (hparams.max_terminal_id, self.max_terminal_id))
        try:
            hparams.add_hparam("closing_bracket_id", self.pop_id)
        except ValueError:
            if hparams.closing_bracket_id != self.pop_id:
                logging.warn("T2T closing_bracket_id does not match (%d!=%d)" %
                             (hparams.closing_bracket_id, self.pop_id))
        p_hparams = problem.get_hparams(hparams)
        hparams.problem_instances = [problem]
        hparams.problems = [p_hparams]
        return hparams
Example #15
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
    """Create HParams with data_dir and problem hparams, if kwargs provided."""
    ipdb.set_trace()  # debugging breakpoint left in place in the original code
    # Override the default hyperparameters with those passed on the command line.

    # This step resolves the hparams set registered by the model and merges it
    # into the current hparams. registry._HPARAMS stores each model's
    # hparams-registration function; this call executes that function.
    hparams = registry.hparams(hparams_set)()
    if data_dir:
        hparams.add_hparam("data_dir", data_dir)
    if problem_name:
        add_problem_hparams(hparams, problem_name)
    if hparams_overrides_str:
        tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                        hparams_overrides_str)
        hparams = hparams.parse(hparams_overrides_str)
    return hparams
Example #16
def create_hparams(params_id, data_dir):
  """Returns hyperparameters, including any flag value overrides.

  If the hparams FLAG is set, then it will use any values specified in
  hparams to override any individually-set hyperparameter. This logic
  allows tuners to override hyperparameter settings to find optimal values.

  Args:
    params_id: which set of parameters to choose (must be in _PARAMS above).
    data_dir: the directory containing the training data.

  Returns:
    The hyperparameters as a tf.contrib.training.HParams object.
  """
  hparams = registry.hparams(params_id)()
  hparams.add_hparam("data_dir", data_dir)
  # Command line flags override any of the preceding hyperparameter values.
  if FLAGS.hparams:
    hparams = hparams.parse(FLAGS.hparams)

  return add_problem_hparams(hparams, FLAGS.problems)
Example #17
def create_hparams(params_id, data_dir, passed_hparams=None):
  """Returns hyperparameters, including any flag value overrides.

  If the hparams FLAG is set, then it will use any values specified in
  hparams to override any individually-set hyperparameter. This logic
  allows tuners to override hyperparameter settings to find optimal values.

  Args:
    params_id: which set of parameters to choose (must be in _PARAMS above).
    data_dir: the directory containing the training data.
    passed_hparams: command-line overrides for some hparams.

  Returns:
    The hyperparameters as a tf.contrib.training.HParams object.
  """
  hparams = registry.hparams(params_id)()
  hparams.add_hparam("data_dir", data_dir)
  # Command line flags override any of the preceding hyperparameter values.
  if passed_hparams:
    hparams = hparams.parse(passed_hparams)

  return hparams
Example #18
def main(_):
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    # Not important for experiments past 2018
    if "wm_policy_param_sharing" not in hparams.values().keys():
        hparams.add_hparam("wm_policy_param_sharing", False)
    directories = player_utils.infer_paths(output_dir=FLAGS.output_dir,
                                           world_model=FLAGS.wm_dir,
                                           policy=FLAGS.policy_dir,
                                           data=FLAGS.episodes_data_dir)
    epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

    if FLAGS.simulated_env:
        env = player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=epoch)
    else:
        env = player_utils.setup_and_load_epoch(hparams,
                                                data_dir=directories["data"],
                                                which_epoch_data=epoch)
        env = FlatBatchEnv(env)

    env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type

    env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

    if FLAGS.dry_run:
        for _ in range(5):
            env.reset()
            for i in range(50):
                env.step(i % 3)
            env.step(PlayerEnvWrapper.RESET_ACTION)  # reset
        return

    play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
Example #19
    def __prepare_model(self):
        """Prepare utilities for decoding."""
        hparams = registry.hparams(self.params.hparams_set)
        hparams.problem = self.problem
        hparams.problem_hparams = self.problem.get_hparams(hparams)
        if self.params.hparams:
            tf.logging.info("Overriding hparams in %s with %s",
                            self.params.hparams_set, self.params.hparams)
            hparams = hparams.parse(self.params.hparams)
        trainer_run_config = g2p_trainer_utils.create_run_config(
            hparams, self.params)
        exp_fn = g2p_trainer_utils.create_experiment_fn(
            self.params, self.problem)
        self.exp = exp_fn(trainer_run_config, hparams)

        decode_hp = decoding.decode_hparams(self.params.decode_hparams)
        estimator = trainer_lib.create_estimator(self.params.model_name,
                                                 hparams,
                                                 trainer_run_config,
                                                 decode_hparams=decode_hp,
                                                 use_tpu=False)

        return estimator, decode_hp, hparams
Example #20
def main(_):
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    output_dir = FLAGS.output_dir

    subdirectories = ["data", "tmp", "world_model", "ppo"]
    using_autoencoder = hparams.autoencoder_train_steps > 0
    if using_autoencoder:
        subdirectories.append("autoencoder")
    directories = setup_directories(output_dir, subdirectories)

    if hparams.game in gym_problems_specs.ATARI_GAMES:
        game_with_mode = hparams.game + "_deterministic-v4"
    else:
        game_with_mode = hparams.game

    if using_autoencoder:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded" %
            game_with_mode)
    else:
        simulated_problem_name = (
            "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
        if simulated_problem_name not in registry.list_problems():
            tf.logging.info(
                "Game Problem %s not found; dynamically registering",
                simulated_problem_name)
            gym_problems_specs.create_problems_for_game(
                hparams.game, game_mode="Deterministic-v4")

    epoch = hparams.epochs - 1
    epoch_data_dir = os.path.join(directories["data"], str(epoch))
    ppo_model_dir = directories["ppo"]

    world_model_dir = directories["world_model"]

    gym_problem = registry.problem(simulated_problem_name)

    model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
    environment_spec = copy.copy(gym_problem.environment_spec)
    environment_spec.simulation_random_starts = hparams.simulation_random_starts

    batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
    batch_env_hparams.add_hparam("model_hparams", model_hparams)
    batch_env_hparams.add_hparam("environment_spec", environment_spec)
    batch_env_hparams.num_agents = 1

    with temporary_flags({
            "problem": simulated_problem_name,
            "model": hparams.generative_model,
            "hparams_set": hparams.generative_model_params,
            "output_dir": world_model_dir,
            "data_dir": epoch_data_dir,
    }):
        sess = tf.Session()
        env = DebugBatchEnv(batch_env_hparams, sess)
        sess.run(tf.global_variables_initializer())
        env.initialize()

        env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
        trainer_lib.restore_checkpoint(world_model_dir,
                                       env_model_loader,
                                       sess,
                                       must_restore=True)

        model_saver = tf.train.Saver(
            tf.global_variables(".*network_parameters.*"))
        trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

        key_mapping = gym_problem.env.env.get_keys_to_action()
        # map special codes
        key_mapping[()] = 100
        key_mapping[(ord("r"), )] = 101
        key_mapping[(ord("p"), )] = 102

        play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
Example #21
  def testUnknownHparams(self):
    with self.assertRaisesRegexp(LookupError, "never registered"):
      registry.hparams("not_registered")
    with self.assertRaisesRegexp(LookupError, "never registered"):
      registry.ranged_hparams("not_registered")
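
Building on the LookupError contract tested above, a lookup guarded with a fallback set might look like this sketch (the default name is illustrative):

def hparams_or_default(name, default_name="transformer_base"):
  try:
    return registry.hparams(name)
  except LookupError:
    return registry.hparams(default_name)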
Example #22
    def benchmark(self, ckpt_dir, outer_steps=100, inner_steps=1000):
        """Run repeatedly on dummy data to benchmark inference."""
        # Turn off Grappler optimizations.
        options = {"disable_meta_optimizer": True}
        tf.config.optimizer.set_experimental_options(options)

        # Create the model outside the loop body.
        hparams = registry.hparams(self.hparams_set)
        hparams_lib.add_problem_hparams(hparams, self.problem_name)
        model_cls = registry.model(self.model_name)
        model = model_cls(hparams, tf.estimator.ModeKeys.EVAL)

        # Run only the model body (no data pipeline) on device.
        feature_shape = [
            hparams.batch_size, 3 * self.image_size * self.image_size
        ]
        features = {"targets": tf.zeros(feature_shape, dtype=tf.int32)}

        # Call the model once to initialize the variables. Note that
        # this should never execute.
        with tf.variable_scope(self.model_name) as vso:
            transformed_features = model.bottom(features)
            with tf.variable_scope("body") as vsi:
                body_out = model.body(transformed_features)
            logits = model.top(body_out, features)
            model.loss(logits, features)

        def call_model(features):
            with tf.variable_scope(vso, reuse=tf.AUTO_REUSE):
                transformed_features = model.bottom(features)
                with tf.variable_scope(vsi, reuse=tf.AUTO_REUSE):
                    body_out = model.body(transformed_features)
                logits = model.top(body_out, features)
                return model.loss(logits, features)

        # Run the function body in a loop to amortize session overhead.
        loop_index = tf.zeros([], dtype=tf.int32)
        initial_loss = (tf.zeros([]), tf.zeros([]))

        def loop_cond(idx, _):
            return tf.less(idx, tf.constant(inner_steps, dtype=tf.int32))

        def loop_body(idx, _):
            return idx + 1, call_model(features)

        benchmark_op = tf.while_loop(loop_cond,
                                     loop_body, [loop_index, initial_loss],
                                     parallel_iterations=1,
                                     back_prop=False)

        session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
            allow_growth=False, per_process_gpu_memory_fraction=0.95))
        run_metadata = tf.RunMetadata()
        with tf.Session(config=session_config) as sess:
            self.restore_model(sess, ckpt_dir)
            tps = []
            for idx in range(outer_steps):
                start_time = time.time()
                sess.run(benchmark_op, run_metadata=run_metadata)
                elapsed_time = time.time() - start_time
                tps.append(inner_steps * hparams.batch_size * (64 * 64 * 3) /
                           elapsed_time)
                logging.error("Iterations %d processed %f TPS.", idx, tps[-1])
            # Skip the first iteration where all the setup and allocation happens.
            tps = np.asarray(tps[1:])
            logging.error("Mean/Std/Max/Min throughput = %f / %f / %f / %f",
                          np.mean(tps), np.std(tps), tps.max(), tps.min())
Example #23
def main(_):
    # gym.logger.set_level(gym.logger.DEBUG)
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    # Not important for experiments past 2018
    if "wm_policy_param_sharing" not in hparams.values().keys():
        hparams.add_hparam("wm_policy_param_sharing", False)
    directories = player_utils.infer_paths(output_dir=FLAGS.output_dir,
                                           world_model=FLAGS.wm_dir,
                                           policy=FLAGS.policy_dir,
                                           data=FLAGS.episodes_data_dir)
    if FLAGS.game_from_filenames:
        hparams.set_hparam(
            "game",
            player_utils.infer_game_name_from_filenames(directories["data"]))
    action_meanings = gym.make(full_game_name(hparams.game)).\
        unwrapped.get_action_meanings()
    epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

    def make_real_env():
        env = player_utils.setup_and_load_epoch(hparams,
                                                data_dir=directories["data"],
                                                which_epoch_data=None)
        env = FlatBatchEnv(env)  # pylint: disable=redefined-variable-type
        return env

    def make_simulated_env(setable_initial_frames, which_epoch_data):
        env = player_utils.load_data_and_make_simulated_env(
            directories["data"],
            directories["world_model"],
            hparams,
            which_epoch_data=which_epoch_data,
            setable_initial_frames=setable_initial_frames)
        return env

    if FLAGS.sim_and_real:
        sim_env = make_simulated_env(which_epoch_data=None,
                                     setable_initial_frames=True)
        real_env = make_real_env()
        env = SimAndRealEnvPlayer(real_env, sim_env, action_meanings)
    else:
        if FLAGS.simulated_env:
            env = make_simulated_env(  # pylint: disable=redefined-variable-type
                which_epoch_data=epoch,
                setable_initial_frames=False)
        else:
            env = make_real_env()
        env = SingleEnvPlayer(env, action_meanings)  # pylint: disable=redefined-variable-type

    env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

    if FLAGS.dry_run:
        env.unwrapped.get_keys_to_action()
        for _ in range(5):
            env.reset()
            for i in range(50):
                env.step(i % 3)
            env.step(PlayerEnv.RETURN_DONE_ACTION)  # reset
        return

    play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
Example #24
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = ("gym_simulated_discrete_problem_with_agent_on_%s"
                              % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs-1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]

  world_model_dir = directories["world_model"]

  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(
        tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(world_model_dir, env_model_loader, sess,
                                   must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # map special codes
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
Example #25
  def test_train_pong(self):
    hparams = registry.hparams("pong_model_free")
    hparams.epochs_num = 2
    hparams.num_agents = 2
    hparams.epoch_length = 3
    rl_trainer_lib.train(hparams)
Example #26
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))
    if args.n_ctx >= 0:
        hparams.n_ctx=args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd=args.n_embd
    if args.n_head >= 0:
        hparams.n_head=args.n_head
    if args.n_layer >= 0:
        hparams.n_layer=args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
      args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:], logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)


        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)" % (parameter_count, parameter_count/(1024.0*1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable('global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup, initial_period_steps=args.learning_rate_period, learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable('learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
          if not args.learning_rate_cos:
            if step is None:
              step = global_step.eval(session=sess)
            if rate is None:
              rate = args.learning_rate
            if callable(rate):
              rate = rate(step)
            lr.load(rate, session=sess)
          return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
          print("Current learn rate: %0.8f" % update_lr())
          print("New learn rate?")
          rate = input('')
          if not rate:
            print("Empty input; not changing anything.")
            return
          try:
            rate = float(rate)
          except ValueError:
            print("Invalid input; must be a float.")
            return
          print("Setting learn rate to %0.8f" % rate)
          args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(learning_rate=lr, hparams=ada_hparams)
        else:
            exit('Bad optimizer: %s' % args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(
                opt=opt,
                var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
          if os.path.isdir(dataset) or dataset.endswith('.npz'):
            chunks = load_dataset(enc, dataset, combine)
            data_sampler = Sampler(chunks, seed=seed)
            print('dataset has', data_sampler.total_size, 'tokens', len(chunks), 'chunks')
          else:
            data_sampler = TextSampler(dataset, enc, seed=seed)
          return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed, combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc, seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx) for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'

            return tarfile_name


        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()

            checkpoint_folder = os.path.join('checkpoint', run_name)

            if copy_folder:
                shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)

                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)

                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            t0 = time.time()
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            if args.val_every <= 0:
              return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            return time.time() - start_time

        def say(msg):
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            #return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]
            #say('Sampling batch...')
            r = []
            times = []
            for _ in range(args.batch_size):
                start = time.time()
                sample = data_sampler.sample(args.sample_ctx)
                end = time.time()
                elapsed = (end - start)
                r += [sample]
                times += [elapsed]
            total = sum(times)
            avg = total / len(times)
            #say('Sampled %d batches in %.4f seconds (avg per batch: %.4f)' % (args.batch_size, total, avg))
            return r

        prev_time = time.time()
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        while True:
            try:
                now = elapsed()
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    #say('Running opt_reset...')
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                    .format(
                        stamp=timestamp(),
                        counter=counter,
                        time=now - start_time,
                        delta=now - prev_time,
                        ops=args.sample_ctx * args.batch_size / (now - prev_time),
                        rate=v_rate,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        step=current_step,
                        ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                    )
                if tflex.should_quit():
                  break

                prev_time = now
                if args.debug_print_all_vars:
                    print('all variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.all_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print('trainable variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.trainable_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
Example #27
def compute_metrics_v2(problem_name,
                       model_name,
                       hparams_name,
                       ckpt_dir,
                       data_dir="/tmp",
                       eval_batch_size=32,
                       eval_steps=100,
                       extra_hparams=(),  # avoid a mutable default argument
                       mode=Modes.EVAL,
                       num_threshold_bins=100):

  registered_model = registry.model(model_name)

  hparams = registry.hparams(hparams_name)
  hparams.mode = mode

  for extra_hparam in extra_hparams:
    assert len(extra_hparam) == 2
    if extra_hparam[0] == "mode":
      continue
    hparams.set_hparam(extra_hparam[0], extra_hparam[1])

  problem_instance = registry.problem(problem_name)
  problem_hparams = problem_instance.get_hparams(hparams)

  # Build the eval dataset and get the examples
  eval_dataset = problem_instance.dataset(mode=Modes.EVAL, data_dir=data_dir)

  eval_dataset = eval_dataset.repeat(None).batch(eval_batch_size)
  eval_dataset_iterator = tfe.Iterator(eval_dataset)

  metrics = {}

  def _merge(metrics, metrics_partial):
    for key, value in metrics_partial.items():
      if key not in metrics:
        metrics[key] = value
      else:
        metrics[key] += value
    return metrics

  with tfe.restore_variables_on_create(ckpt_dir):

    model_instance = registered_model(hparams, mode, problem_hparams)

    for i in range(eval_steps):

      try:

        eval_examples = eval_dataset_iterator.next()

        metrics_partial = model_instance.eager_eval(eval_examples)
        metrics = _merge(metrics, metrics_partial)

        if i % 10 == 0:
          msg = "Finished collecting predictions for eval step {}.".format(i)
          tf.logging.info(msg)

      except Exception:
        # Seeing rare CBT deadline-exceeded errors and don't know how to modify
        # the deadline... More likely to run into the error with more
        # iterations; wasn't seeing it with 10 and almost always seeing it with
        # 100. Could conceivably have to do with running out of examples in the
        # eval set... but there should be over 20k and this would only go
        # through 3200.
        msg = "HACK: Squashing inference error."
        tf.logging.info(msg)

  for key, value in metrics.items():
    metrics[key] = value / eval_steps

  return metrics
Example #28
def compute_metrics(problem_name,
                    model_name,
                    hparams_name,
                    ckpt_dir,
                    data_dir="/tmp",
                    eval_batch_size=32,
                    eval_steps=100,
                    extra_hparams=(),  # avoid a mutable default argument
                    mode=Modes.EVAL,
                    num_threshold_bins=100):

  if not (isinstance(num_threshold_bins, int) and num_threshold_bins > 0):
    msg = "Num threshold bins should be int > 0, saw {}".format(
        num_threshold_bins)
    raise ValueError(msg)

  registered_model = registry.model(model_name)

  hparams = registry.hparams(hparams_name)
  hparams.mode = mode

  for extra_hparam in extra_hparams:
    assert len(extra_hparam) == 2
    if extra_hparam[0] == "mode":
      continue
    hparams.set_hparam(extra_hparam[0], extra_hparam[1])

  problem_instance = registry.problem(problem_name)
  problem_hparams = problem_instance.get_hparams(hparams)

  # Build the eval dataset and get the examples
  eval_dataset = problem_instance.dataset(mode=Modes.EVAL, data_dir=data_dir)

  eval_dataset = eval_dataset.repeat(None).batch(eval_batch_size)
  eval_dataset_iterator = tfe.Iterator(eval_dataset)

  with tfe.restore_variables_on_create(ckpt_dir):

    model_instance = registered_model(hparams, mode, problem_hparams)

    predictions = np.array([], dtype=np.float32)
    targets = np.array([], dtype=np.float32)

    for i in range(eval_steps):

      try:

        eval_examples = eval_dataset_iterator.next()

        prediction = model_instance.infer(eval_examples)

        # We've concatenated the two embedding vectors followed
        # by the label so we can obtain the label just by looking
        # at the last value
        prediction = np.array([thing[-1] for thing in prediction],
                              dtype=np.float32)

        target = tf.squeeze(eval_examples["targets"]).numpy().astype(np.float32)

        predictions = np.concatenate([predictions, prediction])
        targets = np.concatenate([targets, target])

        if i % 10 == 0:
          msg = "Finished collecting predictions for eval step {}.".format(i)
          tf.logging.info(msg)

      except Exception:
        # Seeing rare CBT deadline-exceeded errors and don't know how to modify
        # the deadline... More likely to run into the error with more
        # iterations; wasn't seeing it with 10 and almost always seeing it with
        # 100. Could conceivably have to do with running out of examples in the
        # eval set... but there should be over 20k and this would only go
        # through 3200.
        msg = "HACK: Squashing inference error."
        tf.logging.info(msg)

  metrics_set = []

  for i in range(num_threshold_bins):
    threshold = i / num_threshold_bins
    metrics = _metrics_given_threshold(predictions,
                                       targets,
                                       threshold=threshold)
    metrics["at_threshold"] = threshold
    metrics["num_threshold_bins"] = num_threshold_bins
    metrics_set.append(metrics)

  midpoint_metrics = _metrics_given_threshold(predictions,
                                              targets,
                                              threshold=0.5)
  midpoint_metrics["auc"] = _auc_for_metrics_set(metrics_set)

  return midpoint_metrics, metrics_set, predictions, targets
Example #29
  def _test_hparams_set(self, hparams_set):
    hparams = registry.hparams(hparams_set)
    FLAGS.output_dir = tf.test.get_temp_dir()
    trainer_model_free.train(hparams, FLAGS.output_dir,
                             env_problem_name=None)
Example #30
def create_hparams():
    hparams = registry.hparams(FLAGS.hparams_set)()
    if FLAGS.hparams:
        hparams = hparams.parse(FLAGS.hparams)
    return hparams
Example #31
def create_rl_hparams():
    hparams = registry.hparams(FLAGS.rl_hparams_set)()
    hparams.parse(FLAGS.rl_hparams)
    return hparams
Example #32
  def testUnknownHparams(self):
    with self.assertRaisesRegexp(LookupError, "never registered"):
      registry.hparams("not_registered")
    with self.assertRaisesRegexp(LookupError, "never registered"):
      registry.ranged_hparams("not_registered")
Example #33
# -*- coding: utf-8 -*-
"""
@author: 代码医生工作室 (Code Doctor Studio)
@WeChat official account: xiangyuejiqiren (more articles and learning materials inside)
@source: companion code for "Deep Learning with TensorFlow: Engineering Projects in Practice" (700+ pages)
@code support: bbs.aianaconda.com (every question answered)
"""

#6-19

import tensorflow as tf
from tensor2tensor import models

from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry

print(len(registry.list_models()), registry.list_models())
print(registry.model('transformer'))
print(len(registry.list_hparams()), registry.list_hparams('transformer'))
print(registry.hparams('transformer_base_v1'))
Example #34
  def autoencoder_factor(self):
    """By how much to divide sizes when using autoencoders."""
    hparams = registry.hparams(self.ae_hparams_set)
    return 2**hparams.num_hidden_layers
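
For instance, if the autoencoder hparams set had num_hidden_layers=3 (a hypothetical value; the real one comes from self.ae_hparams_set), the factor would be 2**3 = 8, i.e. sizes are divided by eight.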
Example #35
def create_loop_hparams():
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  return hparams
Example #36
def create_loop_hparams():
    hparams = registry.hparams(FLAGS.loop_hparams_set)
    hparams.parse(FLAGS.loop_hparams)
    return hparams
Example #37
  def _lookup_hparams(self):

    self.hparams = registry.hparams(self.hparams_set)
    self.hparams.data_dir = self.data_dir
    self.p_hparams = self.problem.get_hparams(self.hparams)