def testNoneHparams(self):

  @registry.register_hparams
  def hp():
    pass

  with self.assertRaisesRegexp(TypeError, "is None"):
    registry.hparams("hp")
def __prepare_model(self, train_mode=False):
  """Prepare utilities for decoding."""
  hparams = registry.hparams(self.params.hparams_set)
  hparams.problem = self.problem
  hparams.problem_hparams = self.problem.get_hparams(hparams)
  if self.params.hparams:
    tf.logging.info("Overriding hparams in %s with %s",
                    self.params.hparams_set, self.params.hparams)
    hparams = hparams.parse(self.params.hparams)
  trainer_run_config = g2p_trainer_utils.create_run_config(
      hparams, self.params)
  if train_mode:
    exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
    self.exp = exp_fn(trainer_run_config, hparams)
  decode_hp = decoding.decode_hparams(self.params.decode_hparams)
  estimator = trainer_lib.create_estimator(
      self.params.model_name,
      hparams,
      trainer_run_config,
      decode_hparams=decode_hp,
      use_tpu=False)
  return estimator, decode_hp, hparams
def training_loop_hparams_from_scoped_overrides(scoped_overrides, trial_id):
  """Create HParams suitable for training loop from scoped HParams.

  Args:
    scoped_overrides: HParams, with keys all scoped by one of HP_SCOPES. These
      parameters are overrides for the base HParams created by
      create_loop_hparams.
    trial_id: str, trial identifier. This is used to register unique HParams
      names for the underlying model and ppo HParams.

  Returns:
    HParams suitable for passing to training_loop.
  """
  trial_hp_overrides = scoped_overrides.values()

  # Create loop, model, and ppo base HParams.
  loop_hp = create_loop_hparams()
  model_hp_name = trial_hp_overrides.get(
      "loop.generative_model_params", loop_hp.generative_model_params)
  model_hp = registry.hparams(model_hp_name).parse(FLAGS.hparams)
  base_algo_params_name = trial_hp_overrides.get(
      "loop.base_algo_params", loop_hp.base_algo_params)
  algo_hp = registry.hparams(base_algo_params_name)

  # Merge them and then override with the scoped overrides.
  combined_hp = merge_unscoped_hparams(
      zip(HP_SCOPES, [loop_hp, model_hp, algo_hp]))
  combined_hp.override_from_dict(trial_hp_overrides)

  # Split out the component hparams.
  loop_hp, model_hp, algo_hp = (
      split_scoped_hparams(HP_SCOPES, combined_hp))

  # Dynamically register the model hparams and set the new name in loop_hp.
  model_hp_name = "model_hp_%s" % str(trial_id)
  dynamic_register_hparams(model_hp_name, model_hp)
  loop_hp.generative_model_params = model_hp_name

  # Dynamically register the algo hparams and set the new name in loop_hp.
  algo_hp_name = "algo_hp_%s" % str(trial_id)
  dynamic_register_hparams(algo_hp_name, algo_hp)
  loop_hp.base_algo_params = algo_hp_name

  return loop_hp
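# A minimal usage sketch of the function above. Everything here is assumed
# for illustration: the scope names mirror the "loop."-prefixed lookups in
# the code, and the override keys, values, and trial id are made up.
scoped_overrides = tf.contrib.training.HParams(**{
    "loop.num_real_env_frames": 100000,  # hypothetical loop-scoped override
    "model.hidden_size": 64,             # hypothetical model-scoped override
    "ppo.learning_rate": 1e-4,           # hypothetical ppo-scoped override
})
loop_hp = training_loop_hparams_from_scoped_overrides(
    scoped_overrides, trial_id="trial_0")
# loop_hp.generative_model_params now names the dynamically registered set
# "model_hp_trial_0", which carries the merged model-scope overrides.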
def testNamedRegistration(self):

  @registry.register_hparams("a")
  def my_hparams_set():
    return 7

  @registry.register_ranged_hparams("a")
  def my_hparams_range(_):
    pass

  self.assertEqual(registry.hparams("a"), my_hparams_set())
  self.assertTrue(registry.ranged_hparams("a") is my_hparams_range)
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
  hparams = registry.hparams(hparams_set)()
  if hparams_overrides_str:
    hparams = hparams.parse(hparams_overrides_str)
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  return hparams
def testHParamSet(self):

  @registry.register_hparams
  def my_hparams_set():
    return 3

  @registry.register_ranged_hparams
  def my_hparams_range(_):
    pass

  self.assertEqual(registry.hparams("my_hparams_set"), my_hparams_set())
  self.assertTrue(
      registry.ranged_hparams("my_hparams_range") is my_hparams_range)
def testExperiment(self):
  exp_fn = lib.create_experiment_fn(
      "transformer",
      "tiny_algo",
      trainer_utils_test.TrainerUtilsTest.data_dir,
      train_steps=1,
      eval_steps=1,
      min_eval_frequency=1,
      use_tpu=False)
  run_config = lib.create_run_config(num_gpus=0, use_tpu=False)
  hparams = registry.hparams("transformer_tiny_tpu")()
  exp = exp_fn(run_config, hparams)
  exp.test()
def testExperiment(self):
  exp_fn = trainer_lib.create_experiment_fn(
      "transformer",
      "tiny_algo",
      self.data_dir,
      train_steps=1,
      eval_steps=1,
      min_eval_frequency=1,
      use_tpu=False)
  run_config = trainer_lib.create_run_config(
      model_dir=self.data_dir, num_gpus=0, use_tpu=False)
  hparams = registry.hparams("transformer_tiny_tpu")
  exp = exp_fn(run_config, hparams)
  exp.test()
def __init__(self,
             hparams,
             mode=tf.estimator.ModeKeys.TRAIN,
             problem_hparams=None,
             data_parallelism=None,
             decode_hparams=None):
  assert hparams.distill_phase in ["train", "distill"]

  if hparams.distill_phase == "train" and hparams.teacher_learning_rate:
    hparams.learning_rate = hparams.teacher_learning_rate
  elif hparams.distill_phase == "distill" and hparams.student_learning_rate:
    hparams.learning_rate = hparams.student_learning_rate

  self.teacher_hparams = registry.hparams(hparams.teacher_hparams)
  self.teacher_model = registry.model(hparams.teacher_model)(
      self.teacher_hparams, mode, problem_hparams, data_parallelism,
      decode_hparams)
  self.student_hparams = registry.hparams(hparams.student_hparams)
  self.student_model = registry.model(hparams.student_model)(
      self.student_hparams, mode, problem_hparams, data_parallelism,
      decode_hparams)
  super(Distillation, self).__init__(hparams, mode, problem_hparams,
                                     data_parallelism, decode_hparams)
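# A hypothetical hparams configuration this constructor expects; the model
# and hparams-set names below are illustrative, not from the source. The
# teacher/student learning-rate swap above suggests a two-phase run:
# distill_phase="train" for the teacher, then "distill" for the student.
hparams = registry.hparams("distill_resnet_base")  # assumed base set
hparams.distill_phase = "distill"
hparams.teacher_model = "resnet"
hparams.teacher_hparams = "resnet_50"  # illustrative teacher set
hparams.student_model = "resnet"
hparams.student_hparams = "resnet_18"  # illustrative student set
model = Distillation(hparams, mode=tf.estimator.ModeKeys.TRAIN)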
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
  """Create HParams with data_dir and problem hparams, if kwargs provided."""
  hparams = registry.hparams(hparams_set)
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  if hparams_overrides_str:
    tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                    hparams_overrides_str)
    hparams = hparams.parse(hparams_overrides_str)
  return hparams
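# A hypothetical call, for illustration; the hparams-set name, override
# string, and paths are example values rather than anything fixed above.
hparams = create_hparams(
    "transformer_base",
    hparams_overrides_str="learning_rate=0.05,batch_size=2048",
    data_dir="/tmp/t2t_data",
    problem_name="translate_ende_wmt32k")
print(hparams.learning_rate)  # 0.05, parsed after data_dir/problem setup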
def testExperimentWithClass(self):
  exp_fn = trainer_lib.create_experiment_fn(
      "transformer",
      algorithmic.TinyAlgo(),
      algorithmic.TinyAlgo.data_dir,
      train_steps=1,
      eval_steps=1,
      min_eval_frequency=1,
      use_tpu=False)
  run_config = trainer_lib.create_run_config(
      model_name="transformer",
      model_dir=algorithmic.TinyAlgo.data_dir,
      num_gpus=0,
      use_tpu=False)
  hparams = registry.hparams("transformer_tiny_tpu")
  exp = exp_fn(run_config, hparams)
  exp.test()
def _create_hparams(self, src_vocab_size, trg_vocab_size, hparams_set_name,
                    problem_name):
  """Creates hparams object.

  This method corresponds to create_hparams() in tensor2tensor's
  trainer_utils module, but replaces the feature encoders with
  DummyTextEncoder's.

  Args:
    src_vocab_size (int): Source vocabulary size.
    trg_vocab_size (int): Target vocabulary size.
    hparams_set_name (string): T2T hparams set name.
    problem_name (string): T2T problem name.

  Returns:
    hparams object.

  Raises:
    LookupError: if the problem name is not in the registry or uses the
      old style problem_hparams.
  """
  hparams = registry.hparams(hparams_set_name)()
  problem = registry.problem(problem_name)
  # The following hack is necessary to prevent the problem from creating
  # the default TextEncoders, which would fail due to the lack of a
  # vocabulary file.
  problem._encoders = {
      "inputs": DummyTextEncoder(vocab_size=src_vocab_size),
      "targets": DummyTextEncoder(vocab_size=trg_vocab_size)
  }
  try:
    hparams.add_hparam("max_terminal_id", self.max_terminal_id)
  except ValueError:  # add_hparam raises ValueError if the name exists.
    if hparams.max_terminal_id != self.max_terminal_id:
      logging.warn("T2T max_terminal_id does not match (%d!=%d)" %
                   (hparams.max_terminal_id, self.max_terminal_id))
  try:
    hparams.add_hparam("closing_bracket_id", self.pop_id)
  except ValueError:
    if hparams.closing_bracket_id != self.pop_id:
      logging.warn("T2T closing_bracket_id does not match (%d!=%d)" %
                   (hparams.closing_bracket_id, self.pop_id))
  p_hparams = problem.get_hparams(hparams)
  hparams.problem_instances = [problem]
  hparams.problems = [p_hparams]
  return hparams
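# A plausible sketch of the DummyTextEncoder used above (assumed, not the
# project's actual class): a TextEncoder stub that only reports a vocabulary
# size, so the problem never needs a vocabulary file.
from tensor2tensor.data_generators import text_encoder


class DummyTextEncoder(text_encoder.TextEncoder):
  """Encoder stub that knows its vocab size but cannot encode or decode."""

  def __init__(self, vocab_size):
    super(DummyTextEncoder, self).__init__(num_reserved_ids=0)
    self._vocab_size = vocab_size

  def encode(self, s):
    raise NotImplementedError("Dummy encoder cannot encode.")

  def decode(self, ids):
    raise NotImplementedError("Dummy encoder cannot decode.")

  @property
  def vocab_size(self):
    return self._vocab_size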
def create_hparams(hparams_set,
                   hparams_overrides_str="",
                   data_dir=None,
                   problem_name=None):
  """Create HParams with data_dir and problem hparams, if kwargs provided."""
  ipdb.set_trace()  # Debug breakpoint: inspect how command-line hparams
                    # replace the defaults.
  # This step resolves all hparams registered for the model and adds them to
  # the current hparams: registry._HPARAMS stores each model's
  # hparams-registration function, and this call executes that function.
  hparams = registry.hparams(hparams_set)()
  if data_dir:
    hparams.add_hparam("data_dir", data_dir)
  if problem_name:
    add_problem_hparams(hparams, problem_name)
  if hparams_overrides_str:
    tf.logging.info("Overriding hparams in %s with %s", hparams_set,
                    hparams_overrides_str)
    hparams = hparams.parse(hparams_overrides_str)
  return hparams
def create_hparams(params_id, data_dir):
  """Returns hyperparameters, including any flag value overrides.

  If the hparams FLAG is set, then it will use any values specified in
  hparams to override any individually-set hyperparameter. This logic
  allows tuners to override hyperparameter settings to find optimal values.

  Args:
    params_id: which set of parameters to choose (must be in _PARAMS above).
    data_dir: the directory containing the training data.

  Returns:
    The hyperparameters as a tf.contrib.training.HParams object.
  """
  hparams = registry.hparams(params_id)()
  hparams.add_hparam("data_dir", data_dir)
  # Command line flags override any of the preceding hyperparameter values.
  if FLAGS.hparams:
    hparams = hparams.parse(FLAGS.hparams)
  return add_problem_hparams(hparams, FLAGS.problems)
def create_hparams(params_id, data_dir, passed_hparams=None):
  """Returns hyperparameters, including any flag value overrides.

  If the hparams FLAG is set, then it will use any values specified in
  hparams to override any individually-set hyperparameter. This logic
  allows tuners to override hyperparameter settings to find optimal values.

  Args:
    params_id: which set of parameters to choose (must be in _PARAMS above).
    data_dir: the directory containing the training data.
    passed_hparams: command-line overrides for some hparams.

  Returns:
    The hyperparameters as a tf.contrib.training.HParams object.
  """
  hparams = registry.hparams(params_id)()
  hparams.add_hparam("data_dir", data_dir)
  # Command line flags override any of the preceding hyperparameter values.
  if passed_hparams:
    hparams = hparams.parse(passed_hparams)
  return hparams
def main(_):
  # gym.logger.set_level(gym.logger.DEBUG)
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  # Not important for experiments past 2018.
  if "wm_policy_param_sharing" not in hparams.values().keys():
    hparams.add_hparam("wm_policy_param_sharing", False)
  directories = player_utils.infer_paths(
      output_dir=FLAGS.output_dir,
      world_model=FLAGS.wm_dir,
      policy=FLAGS.policy_dir,
      data=FLAGS.episodes_data_dir)

  epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)
  if FLAGS.simulated_env:
    env = player_utils.load_data_and_make_simulated_env(
        directories["data"], directories["world_model"],
        hparams, which_epoch_data=epoch)
  else:
    env = player_utils.setup_and_load_epoch(
        hparams, data_dir=directories["data"], which_epoch_data=epoch)
    env = FlatBatchEnv(env)

  env = PlayerEnvWrapper(env)  # pylint: disable=redefined-variable-type
  env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

  if FLAGS.dry_run:
    for _ in range(5):
      env.reset()
      for i in range(50):
        env.step(i % 3)
      env.step(PlayerEnvWrapper.RESET_ACTION)  # reset
    return

  play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
def __prepare_model(self):
  """Prepare utilities for decoding."""
  hparams = registry.hparams(self.params.hparams_set)
  hparams.problem = self.problem
  hparams.problem_hparams = self.problem.get_hparams(hparams)
  if self.params.hparams:
    tf.logging.info("Overriding hparams in %s with %s",
                    self.params.hparams_set, self.params.hparams)
    hparams = hparams.parse(self.params.hparams)
  trainer_run_config = g2p_trainer_utils.create_run_config(
      hparams, self.params)
  exp_fn = g2p_trainer_utils.create_experiment_fn(self.params, self.problem)
  self.exp = exp_fn(trainer_run_config, hparams)
  decode_hp = decoding.decode_hparams(self.params.decode_hparams)
  estimator = trainer_lib.create_estimator(
      self.params.model_name,
      hparams,
      trainer_run_config,
      decode_hparams=decode_hp,
      use_tpu=False)
  return estimator, decode_hp, hparams
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_problems_specs.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s" % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_problems_specs.create_problems_for_game(
          hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs - 1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]
  world_model_dir = directories["world_model"]
  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(
        world_model_dir, env_model_loader, sess, must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # Map special codes.
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
def testUnknownHparams(self):
  with self.assertRaisesRegexp(LookupError, "never registered"):
    registry.hparams("not_registered")
  with self.assertRaisesRegexp(LookupError, "never registered"):
    registry.ranged_hparams("not_registered")
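# Taken together, the tests above pin down the registry contract: bare and
# named registration, LookupError for unknown names, and TypeError when a
# registered function returns None. A toy sketch of that contract follows;
# it is illustrative only, not tensor2tensor's implementation (which adds
# name canonicalization and richer error messages).
_HPARAMS = {}


def register_hparams(name_or_fn):
  """Usable bare (@register_hparams) or with a name (@register_hparams("a"))."""
  def _register(fn, name):
    _HPARAMS[name] = fn
    return fn
  if callable(name_or_fn):
    return _register(name_or_fn, name_or_fn.__name__)
  return lambda fn: _register(fn, name_or_fn)


def hparams(name):
  if name not in _HPARAMS:
    raise LookupError("HParams set %s never registered." % name)
  hp = _HPARAMS[name]()
  if hp is None:
    raise TypeError("HParams %s is None." % name)
  return hp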
def benchmark(self, ckpt_dir, outer_steps=100, inner_steps=1000):
  """Run repeatedly on dummy data to benchmark inference."""
  # Turn off Grappler optimizations.
  options = {"disable_meta_optimizer": True}
  tf.config.optimizer.set_experimental_options(options)

  # Create the model outside the loop body.
  hparams = registry.hparams(self.hparams_set)
  hparams_lib.add_problem_hparams(hparams, self.problem_name)
  model_cls = registry.model(self.model_name)
  model = model_cls(hparams, tf.estimator.ModeKeys.EVAL)

  # Run only the model body (no data pipeline) on device.
  feature_shape = [hparams.batch_size, 3 * self.image_size * self.image_size]
  features = {"targets": tf.zeros(feature_shape, dtype=tf.int32)}

  # Call the model once to initialize the variables. Note that
  # this should never execute.
  with tf.variable_scope(self.model_name) as vso:
    transformed_features = model.bottom(features)
    with tf.variable_scope("body") as vsi:
      body_out = model.body(transformed_features)
    logits = model.top(body_out, features)
    model.loss(logits, features)

  def call_model(features):
    with tf.variable_scope(vso, reuse=tf.AUTO_REUSE):
      transformed_features = model.bottom(features)
      with tf.variable_scope(vsi, reuse=tf.AUTO_REUSE):
        body_out = model.body(transformed_features)
      logits = model.top(body_out, features)
      return model.loss(logits, features)

  # Run the function body in a loop to amortize session overhead.
  loop_index = tf.zeros([], dtype=tf.int32)
  initial_loss = (tf.zeros([]), tf.zeros([]))

  def loop_cond(idx, _):
    return tf.less(idx, tf.constant(inner_steps, dtype=tf.int32))

  def loop_body(idx, _):
    return idx + 1, call_model(features)

  benchmark_op = tf.while_loop(
      loop_cond, loop_body, [loop_index, initial_loss],
      parallel_iterations=1,
      back_prop=False)

  session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(
      allow_growth=False, per_process_gpu_memory_fraction=0.95))
  run_metadata = tf.RunMetadata()
  with tf.Session(config=session_config) as sess:
    self.restore_model(sess, ckpt_dir)
    tps = []
    for idx in range(outer_steps):
      start_time = time.time()
      sess.run(benchmark_op, run_metadata=run_metadata)
      elapsed_time = time.time() - start_time
      tps.append(inner_steps * hparams.batch_size * (64 * 64 * 3) /
                 elapsed_time)
      logging.error("Iterations %d processed %f TPS.", idx, tps[-1])
    # Skip the first iteration where all the setup and allocation happens.
    tps = np.asarray(tps[1:])
    logging.error("Mean/Std/Max/Min throughput = %f / %f / %f / %f",
                  np.mean(tps), np.std(tps), tps.max(), tps.min())
def main(_):
  # gym.logger.set_level(gym.logger.DEBUG)
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  # Not important for experiments past 2018.
  if "wm_policy_param_sharing" not in hparams.values().keys():
    hparams.add_hparam("wm_policy_param_sharing", False)
  directories = player_utils.infer_paths(
      output_dir=FLAGS.output_dir,
      world_model=FLAGS.wm_dir,
      policy=FLAGS.policy_dir,
      data=FLAGS.episodes_data_dir)
  if FLAGS.game_from_filenames:
    hparams.set_hparam(
        "game",
        player_utils.infer_game_name_from_filenames(directories["data"]))
  action_meanings = gym.make(
      full_game_name(hparams.game)).unwrapped.get_action_meanings()
  epoch = FLAGS.epoch if FLAGS.epoch == "last" else int(FLAGS.epoch)

  def make_real_env():
    env = player_utils.setup_and_load_epoch(
        hparams, data_dir=directories["data"], which_epoch_data=None)
    env = FlatBatchEnv(env)  # pylint: disable=redefined-variable-type
    return env

  def make_simulated_env(setable_initial_frames, which_epoch_data):
    env = player_utils.load_data_and_make_simulated_env(
        directories["data"], directories["world_model"], hparams,
        which_epoch_data=which_epoch_data,
        setable_initial_frames=setable_initial_frames)
    return env

  if FLAGS.sim_and_real:
    sim_env = make_simulated_env(
        which_epoch_data=None, setable_initial_frames=True)
    real_env = make_real_env()
    env = SimAndRealEnvPlayer(real_env, sim_env, action_meanings)
  else:
    if FLAGS.simulated_env:
      env = make_simulated_env(  # pylint: disable=redefined-variable-type
          which_epoch_data=epoch, setable_initial_frames=False)
    else:
      env = make_real_env()
    env = SingleEnvPlayer(env, action_meanings)  # pylint: disable=redefined-variable-type

  env = player_utils.wrap_with_monitor(env, FLAGS.video_dir)

  if FLAGS.dry_run:
    env.unwrapped.get_keys_to_action()
    for _ in range(5):
      env.reset()
      for i in range(50):
        env.step(i % 3)
      env.step(PlayerEnv.RETURN_DONE_ACTION)  # reset
    return

  play.play(env, zoom=FLAGS.zoom, fps=FLAGS.fps)
def main(_):
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  output_dir = FLAGS.output_dir

  subdirectories = ["data", "tmp", "world_model", "ppo"]
  using_autoencoder = hparams.autoencoder_train_steps > 0
  if using_autoencoder:
    subdirectories.append("autoencoder")
  directories = setup_directories(output_dir, subdirectories)

  if hparams.game in gym_env.ATARI_GAMES:
    game_with_mode = hparams.game + "_deterministic-v4"
  else:
    game_with_mode = hparams.game

  if using_autoencoder:
    simulated_problem_name = (
        "gym_simulated_discrete_problem_with_agent_on_%s_autoencoded"
        % game_with_mode)
  else:
    simulated_problem_name = ("gym_simulated_discrete_problem_with_agent_on_%s"
                              % game_with_mode)
    if simulated_problem_name not in registry.list_problems():
      tf.logging.info("Game Problem %s not found; dynamically registering",
                      simulated_problem_name)
      gym_env.register_game(hparams.game, game_mode="Deterministic-v4")

  epoch = hparams.epochs - 1
  epoch_data_dir = os.path.join(directories["data"], str(epoch))
  ppo_model_dir = directories["ppo"]
  world_model_dir = directories["world_model"]
  gym_problem = registry.problem(simulated_problem_name)

  model_hparams = trainer_lib.create_hparams(hparams.generative_model_params)
  environment_spec = copy.copy(gym_problem.environment_spec)
  environment_spec.simulation_random_starts = hparams.simulation_random_starts

  batch_env_hparams = trainer_lib.create_hparams(hparams.ppo_params)
  batch_env_hparams.add_hparam("model_hparams", model_hparams)
  batch_env_hparams.add_hparam("environment_spec", environment_spec)
  batch_env_hparams.num_agents = 1

  with temporary_flags({
      "problem": simulated_problem_name,
      "model": hparams.generative_model,
      "hparams_set": hparams.generative_model_params,
      "output_dir": world_model_dir,
      "data_dir": epoch_data_dir,
  }):
    sess = tf.Session()
    env = DebugBatchEnv(batch_env_hparams, sess)
    sess.run(tf.global_variables_initializer())
    env.initialize()

    env_model_loader = tf.train.Saver(tf.global_variables("next_frame*"))
    trainer_lib.restore_checkpoint(
        world_model_dir, env_model_loader, sess, must_restore=True)

    model_saver = tf.train.Saver(
        tf.global_variables(".*network_parameters.*"))
    trainer_lib.restore_checkpoint(ppo_model_dir, model_saver, sess)

    key_mapping = gym_problem.env.env.get_keys_to_action()
    # Map special codes.
    key_mapping[()] = 100
    key_mapping[(ord("r"),)] = 101
    key_mapping[(ord("p"),)] = 102

    play.play(env, zoom=2, fps=10, keys_to_action=key_mapping)
def test_train_pong(self):
  hparams = registry.hparams("pong_model_free")
  hparams.epochs_num = 2
  hparams.num_agents = 2
  hparams.epoch_length = 3
  rl_trainer_lib.train(hparams)
def main():
    args = parser.parse_args()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()
    hparams.res_dropout = args.dropout
    hparams.attn_dropout = args.dropout
    epsilon = -1e10
    if args.dtype == 'float32':
        hparams.dtype = tf.float32
    elif args.dtype == 'float16':
        hparams.dtype = tf.float16
        epsilon = -65500
    elif args.dtype == 'bfloat16':
        hparams.dtype = tf.bfloat16
        epsilon = -65500
    else:
        print('Unknown dtype', args.dtype)
    if args.float16:
        hparams.dtype = tf.bfloat16
        epsilon = -65500

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if args.n_ctx >= 0:
        hparams.n_ctx = args.n_ctx
    if args.n_embd >= 0:
        hparams.n_embd = args.n_embd
    if args.n_head >= 0:
        hparams.n_head = args.n_head
    if args.n_layer >= 0:
        hparams.n_layer = args.n_layer

    if args.sample_length < 0:
        args.sample_length = hparams.n_ctx - 1
    if args.sample_length > hparams.n_ctx:
        raise ValueError(
            "Can't get samples longer than window size: %s" % hparams.n_ctx)
    if args.sample_ctx < 0:
        args.sample_ctx = hparams.n_ctx

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    if args.allow_growth:
        config.gpu_options.allow_growth = True
    if args.disable_layout_optimizer:
        config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tflex.Session(config=config, init_tpu=args.init_tpu) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(
            hparams=hparams,
            length=args.sample_length,
            context=context,
            batch_size=args.batch_size,
            temperature=1.0,
            top_k=args.top_k,
            top_p=args.top_p,
            epsilon=epsilon)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars

        parameter_count = sum([np.prod(v.shape.as_list()) for v in train_vars])
        print("This model is using %d parameters (%.2fM)" %
              (parameter_count, parameter_count / (1024.0 * 1024.0)))

        with tf.variable_scope(tf.get_variable_scope().name, reuse=tf.AUTO_REUSE):
            global_step = tflex.get_variable('global_step') or tf.get_variable(
                'global_step', shape=(), dtype=tf.int32, trainable=False)
            current_step = args.learning_rate_initial_step
            global_step.load(current_step, session=sess)
            if args.learning_rate_cos:
                lr = tflex_sgdr.sgdr_decay_with_warmup(
                    args.learning_rate, global_step,
                    warmup_steps=args.learning_rate_warmup,
                    initial_period_steps=args.learning_rate_period,
                    learning_rate_min=args.learning_rate_min)
            else:
                lr = tflex.get_variable('learn_rate') or tf.get_variable(
                    'learn_rate', shape=(), dtype=tf.float32, trainable=False)
                lr.load(args.learning_rate, session=sess)

        def update_lr(rate=None, step=None):
            if not args.learning_rate_cos:
                if step is None:
                    step = global_step.eval(session=sess)
                if rate is None:
                    rate = args.learning_rate
                if callable(rate):
                    rate = rate(step)
                lr.load(rate, session=sess)
            return lr.eval(session=sess)

        @tflex.register_command
        def set_learning_rate():
            print("Current learn rate: %0.8f" % update_lr())
            print("New learn rate?")
            rate = input('')
            if not rate:
                print("Empty input; not changing anything.")
            else:
                try:
                    rate = float(rate)
                except:
                    print("Invalid input; must be a float")
                print("Setting learn rate to %0.8f" % rate)
                args.learning_rate = rate

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=lr)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(learning_rate=lr)
        elif args.optimizer == 'ada':
            import tensor2tensor.utils.optimize
            from tensor2tensor.utils import hparam
            import tensor2tensor.models.research
            from tensor2tensor.utils import registry
            ada_hparams = registry.hparams('afx_mimic_adam')
            ada_hparams.optimizer_adafactor_beta1 = 0.0
            ada_hparams.optimizer_adafactor_factored = True
            opt = tensor2tensor.utils.optimize.adafactor(
                learning_rate=lr, hparams=ada_hparams)
        else:
            exit('Bad optimizer:', args.optimizer)

        #if tpu_addr:
        #    # https://pulsejet.github.io/blog/posts/tpu-without-estimator/
        #    from tensorflow.contrib.tpu.python.tpu import tpu_function
        #    tpu_function.get_tpu_context().set_number_of_shards(8)
        #    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit("Memory saving gradients are not implemented for gradient accumulation yet.")
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', lr)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))
        if args.save_graph:
            summary_log.add_graph(tf.get_default_graph())

        saver = tflex.Saver(
            var_list=all_vars,
            max_to_keep=args.max_to_keep,
            keep_checkpoint_every_n_hours=100000,
            reshape=args.truncate_weights)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tflex.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tflex.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tflex.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tflex.latest_checkpoint(args.restore_from)
        print('Loading snapshot %s...' % ckpt)
        t0 = time.time()
        if not args.fresh_model:
            saver.restore(sess, ckpt)
        t1 = time.time()
        print('Loaded in %f seconds' % (t1 - t0))

        def make_sampler(dataset, enc, seed, combine):
            if os.path.isdir(dataset) or dataset.endswith('.npz'):
                chunks = load_dataset(enc, dataset, combine)
                data_sampler = Sampler(chunks, seed=seed)
                print('dataset has', data_sampler.total_size, 'tokens',
                      len(chunks), 'chunks')
            else:
                data_sampler = TextSampler(dataset, enc, seed=seed)
            return data_sampler

        print('Loading dataset...')
        seed = None if args.seed < 0 else args.seed
        data_sampler = make_sampler(dataset=args.dataset, enc=enc, seed=seed,
                                    combine=args.combine)
        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_dataset = args.val_dataset if args.val_dataset else args.dataset
            val_data_sampler = make_sampler(dataset=val_dataset, enc=enc,
                                            seed=1, combine=args.combine)
            val_batches = [[val_data_sampler.sample(hparams.n_ctx)
                            for _ in range(args.val_batch_size)]
                           for _ in range(args.val_batch_count)]

        print('Training...')
        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        @tflex.register_command
        def get_tarfile_name(checkpoint_folder):
            """Converts a folder path into a filename for a .tar archive"""
            tarfile_name = checkpoint_folder.replace(os.path.sep, '_') + '.tar'
            return tarfile_name

        def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
            """Copies the checkpoint folder to a mounted Google Drive."""
            #is_mounted()
            checkpoint_folder = os.path.join('checkpoint', run_name)
            if copy_folder:
                shutil.copytree(checkpoint_folder,
                                "/content/drive/My Drive/" + checkpoint_folder)
            else:
                file_path = get_tarfile_name(checkpoint_folder)
                # Reference: https://stackoverflow.com/a/17081026
                with tarfile.open(file_path, 'w') as tar:
                    tar.add(checkpoint_folder)
                shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)

        @tflex.register_command
        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print('Saving',
                  os.path.join(CHECKPOINT_DIR, args.run_name,
                               'model-{}').format(counter))
            t0 = time.time()
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            t1 = time.time()
            print('Saved in %f seconds' % (t1 - t0))
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')
            #copy_checkpoint_to_gdrive()

        @tflex.register_command
        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    print(text)
                    all_text.append(text)
                    index += 1
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(
                    os.path.join(SAMPLE_DIR, args.run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        @tflex.register_command
        def validation():
            if args.val_every <= 0:
                return
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print(
                '{stamp} [{counter} | {time:2.4f}] validation loss = {loss:2.4f}'
                .format(
                    stamp=timestamp(),
                    counter=counter,
                    time=time.time() - start_time,
                    loss=v_val_loss))

        start_time = time.time()

        def elapsed():
            return time.time() - start_time

        def say(msg):
            print('{stamp} [{counter} | {time:2.4f}] {msg}'.format(
                counter=counter, time=elapsed(), msg=msg, stamp=timestamp()))

        def sample_batch():
            #return [data_sampler.sample(args.sample_ctx) for _ in range(args.batch_size)]
            #say('Sampling batch...')
            r = []
            times = []
            for _ in range(args.batch_size):
                start = time.time()
                sample = data_sampler.sample(args.sample_ctx)
                end = time.time()
                elapsed = (end - start)
                r += [sample]
                times += [elapsed]
            total = sum(times)
            avg = total / len(times)
            #say('Sampled %d batches in %.4f seconds (avg per batch: %.4f)' % (args.batch_size, total, avg))
            return r

        prev_time = time.time()
        avg_loss = (0.0, 0.0)

        if args.debug_before_training:
            import pdb
            pdb.set_trace()

        last_saved_time = elapsed()
        while True:
            try:
                now = elapsed()
                if args.save_time > 0 and (((now - last_saved_time) / 60.0) >= args.save_time):
                    save()
                    last_saved_time = now
                elif args.save_every > 0 and (counter % args.save_every == 0):
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1):
                    validation()

                v_rate = update_lr()

                if args.accumulate_gradients > 1:
                    #say('Running opt_reset...')
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        batch = sample_batch()
                        say('Running opt_compute...')
                        sess.run(opt_compute, feed_dict={context: batch})
                    say('Running opt_apply...')
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    batch = sample_batch()
                    say('Running opt_apply...')
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: batch})

                if args.float16:
                    v_loss = tf.to_float(v_loss).eval()

                summary_log.add_summary(v_summary, counter)
                summary_log.flush()

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                now = time.time()
                print('{stamp} [{counter} | {time:2.4f} | {delta:2.2f}s | {ops:2.6f}tokens/s] loss={loss:2.4f} avg={avg:2.4f} rate={rate:0.7f} step={step}'
                      .format(
                          stamp=timestamp(),
                          counter=counter,
                          time=now - start_time,
                          delta=now - prev_time,
                          ops=args.sample_ctx * args.batch_size / (now - prev_time),
                          rate=v_rate,
                          loss=v_loss,
                          avg=avg_loss[0] / avg_loss[1],
                          step=current_step,
                      ))

                counter += 1
                current_step += 1
                global_step.load(current_step, session=sess)

                tflex.check_commands_with_args(
                    session=sess,
                    stamp=timestamp(),
                    counter=counter,
                    time=now - start_time,
                    delta=now - prev_time,
                    ops=args.batch_size / (now - prev_time),
                    rate=v_rate,
                    loss=v_loss,
                    avg=avg_loss[0] / avg_loss[1],
                    avg_loss=avg_loss,
                    step=current_step,
                    train_vars=train_vars,
                    all_vars=all_vars,
                    args=args,
                    data_sampler=data_sampler,
                    ckpt=ckpt,
                    saver=saver,
                )
                if tflex.should_quit():
                    break
                prev_time = now

                if args.debug_print_all_vars:
                    print('all variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.all_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_all_vars = False

                if args.debug_print_trainable_vars:
                    print('trainable variables:')
                    print('name/shape/parameter_count')
                    param_count = 0
                    for x in tf.trainable_variables():
                        shape = x.shape.as_list()
                        count = np.prod(shape)
                        print(x.name, shape, count)
                        param_count += count
                    print('Total parameters:', param_count)
                    args.debug_print_trainable_vars = False
            except KeyboardInterrupt:
                print('interrupted')
                if args.save_on_ctrlc:
                    save()
                if args.debug_on_ctrlc:
                    import pdb
                    pdb.set_trace()
                else:
                    break
def compute_metrics_v2(problem_name,
                       model_name,
                       hparams_name,
                       ckpt_dir,
                       data_dir="/tmp",
                       eval_batch_size=32,
                       eval_steps=100,
                       extra_hparams=None,
                       mode=Modes.EVAL,
                       num_threshold_bins=100):
  registered_model = registry.model(model_name)

  hparams = registry.hparams(hparams_name)
  hparams.mode = mode
  for extra_hparam in extra_hparams or []:
    assert len(extra_hparam) == 2
    if extra_hparam[0] == "mode":
      continue
    # HParams has no `setattr` method; use the builtin instead.
    setattr(hparams, extra_hparam[0], extra_hparam[1])

  problem_instance = registry.problem(problem_name)
  problem_hparams = problem_instance.get_hparams(hparams)

  # Build the eval dataset and get the examples.
  eval_dataset = problem_instance.dataset(mode=Modes.EVAL, data_dir=data_dir)
  eval_dataset = eval_dataset.repeat(None).batch(eval_batch_size)
  eval_dataset_iterator = tfe.Iterator(eval_dataset)

  metrics = {}

  def _merge(metrics, metrics_partial):
    for key, value in metrics_partial.items():
      if key not in metrics:
        metrics[key] = value
      else:
        metrics[key] += value
    return metrics

  with tfe.restore_variables_on_create(ckpt_dir):
    model_instance = registered_model(hparams, mode, problem_hparams)
    for i in range(eval_steps):
      try:
        eval_examples = eval_dataset_iterator.next()
        metrics_partial = model_instance.eager_eval(eval_examples)
        metrics = _merge(metrics, metrics_partial)
        if i % 10 == 0:
          msg = "Finished collecting predictions for eval step {}.".format(i)
          tf.logging.info(msg)
      except Exception:
        # Seeing rare CBT deadline-exceeded errors and don't know how to
        # modify the deadline... More likely to run into the error with more
        # iterations; wasn't seeing it with 10 and almost always seeing it
        # with 100. Could conceivably have to do with running out of examples
        # in the eval set... but there should be over 20k and this would only
        # go through 3200.
        msg = "HACK: Squashing inference error."
        tf.logging.info(msg)

  for key, value in metrics.items():
    metrics[key] = value / eval_steps

  return metrics
def compute_metrics(problem_name,
                    model_name,
                    hparams_name,
                    ckpt_dir,
                    data_dir="/tmp",
                    eval_batch_size=32,
                    eval_steps=100,
                    extra_hparams=None,
                    mode=Modes.EVAL,
                    num_threshold_bins=100):
  if not (isinstance(num_threshold_bins, int) and num_threshold_bins > 0):
    msg = "Num threshold bins should be int > 0, saw {}".format(
        num_threshold_bins)
    raise ValueError(msg)

  registered_model = registry.model(model_name)

  hparams = registry.hparams(hparams_name)
  hparams.mode = mode
  for extra_hparam in extra_hparams or []:
    assert len(extra_hparam) == 2
    if extra_hparam[0] == "mode":
      continue
    # HParams has no `setattr` method; use the builtin instead.
    setattr(hparams, extra_hparam[0], extra_hparam[1])

  problem_instance = registry.problem(problem_name)
  problem_hparams = problem_instance.get_hparams(hparams)

  # Build the eval dataset and get the examples.
  eval_dataset = problem_instance.dataset(mode=Modes.EVAL, data_dir=data_dir)
  eval_dataset = eval_dataset.repeat(None).batch(eval_batch_size)
  eval_dataset_iterator = tfe.Iterator(eval_dataset)

  with tfe.restore_variables_on_create(ckpt_dir):
    model_instance = registered_model(hparams, mode, problem_hparams)
    predictions = np.array([], dtype=np.float32)
    targets = np.array([], dtype=np.float32)
    for i in range(eval_steps):
      try:
        eval_examples = eval_dataset_iterator.next()
        prediction = model_instance.infer(eval_examples)
        # We've concatenated the two embedding vectors followed by the label,
        # so we can obtain the label just by looking at the last value.
        prediction = np.array([thing[-1] for thing in prediction],
                              dtype=np.float32)
        target = tf.squeeze(eval_examples["targets"]).numpy().astype(
            np.float32)
        predictions = np.concatenate([predictions, prediction])
        targets = np.concatenate([targets, target])
        if i % 10 == 0:
          msg = "Finished collecting predictions for eval step {}.".format(i)
          tf.logging.info(msg)
      except Exception:
        # Seeing rare CBT deadline-exceeded errors and don't know how to
        # modify the deadline... More likely to run into the error with more
        # iterations; wasn't seeing it with 10 and almost always seeing it
        # with 100. Could conceivably have to do with running out of examples
        # in the eval set... but there should be over 20k and this would only
        # go through 3200.
        msg = "HACK: Squashing inference error."
        tf.logging.info(msg)

  metrics_set = []
  for i in range(num_threshold_bins):
    # Use float division so this also works under Python 2.
    threshold = i / float(num_threshold_bins)
    metrics = _metrics_given_threshold(predictions, targets,
                                       threshold=threshold)
    metrics["at_threshold"] = threshold
    metrics["num_threshold_bins"] = num_threshold_bins
    metrics_set.append(metrics)

  midpoint_metrics = _metrics_given_threshold(predictions, targets,
                                              threshold=0.5)
  midpoint_metrics["auc"] = _auc_for_metrics_set(metrics_set)

  return midpoint_metrics, metrics_set, predictions, targets
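# _metrics_given_threshold and _auc_for_metrics_set are referenced above but
# not shown. A plausible sketch of the thresholding helper, assuming binary
# {0, 1} targets and score-like predictions (an assumption, not the source's
# definition):
import numpy as np


def _metrics_given_threshold(predictions, targets, threshold=0.5):
  """Sketch: binary classification metrics at a fixed decision threshold."""
  predicted_pos = predictions >= threshold
  actual_pos = targets > 0.5
  tp = float(np.sum(predicted_pos & actual_pos))
  fp = float(np.sum(predicted_pos & ~actual_pos))
  fn = float(np.sum(~predicted_pos & actual_pos))
  tn = float(np.sum(~predicted_pos & ~actual_pos))
  return {
      "precision": tp / max(tp + fp, 1.0),
      "recall": tp / max(tp + fn, 1.0),  # a.k.a. true positive rate
      "false_positive_rate": fp / max(fp + tn, 1.0),
      "accuracy": (tp + tn) / max(len(targets), 1),
  }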
def _test_hparams_set(self, hparams_set):
  hparams = registry.hparams(hparams_set)
  FLAGS.output_dir = tf.test.get_temp_dir()
  trainer_model_free.train(hparams, FLAGS.output_dir, env_problem_name=None)
def create_hparams():
  hparams = registry.hparams(FLAGS.hparams_set)()
  if FLAGS.hparams:
    hparams = hparams.parse(FLAGS.hparams)
  return hparams
def create_rl_hparams():
  hparams = registry.hparams(FLAGS.rl_hparams_set)()
  hparams.parse(FLAGS.rl_hparams)
  return hparams
# -*- coding: utf-8 -*-
"""
@author: 代码医生工作室 (Code Doctor Studio)
@WeChat official account: xiangyuejiqiren (more articles and learning
    materials inside)
@source: companion code for the book <深度学习之TensorFlow工程化项目实战>
    (TensorFlow Engineering Projects for Deep Learning in Practice, 700+ pages)
@technical support for the companion code: bbs.aianaconda.com
    (every question answered)
"""
# 6-19
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor.utils import t2t_model
from tensor2tensor.utils import registry

print(len(registry.list_models()), registry.list_models())
print(registry.model('transformer'))

print(len(registry.list_hparams()), registry.list_hparams('transformer'))
print(registry.hparams('transformer_base_v1'))
def autoencoder_factor(self):
  """By how much to divide sizes when using autoencoders."""
  hparams = registry.hparams(self.ae_hparams_set)
  return 2**hparams.num_hidden_layers
def create_loop_hparams():
  hparams = registry.hparams(FLAGS.loop_hparams_set)
  hparams.parse(FLAGS.loop_hparams)
  return hparams
def _lookup_hparams(self):
  self.hparams = registry.hparams(self.hparams_set)
  self.hparams.data_dir = self.data_dir
  self.p_hparams = self.problem.get_hparams(self.hparams)