def test_reformer2_copy(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 1  # Ignored, but needs to be set.
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('reformer2_copy.gin')
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('batcher.buckets', ([64], [1, 1]))  # Batch size 1.
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('Reformer2.n_encoder_layers', n_layers)
  gin.bind_parameter('Reformer2.n_decoder_layers', n_layers)
  gin.bind_parameter('Reformer2.d_ff', d_ff)
  output_dir = self.create_tempdir().full_path
  _ = trainer_lib.train(output_dir=output_dir)
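# A minimal, self-contained sketch of the gin pattern the tests in this
# collection rely on (the `demo_train` function here is hypothetical, not part
# of trax): any function or class decorated with @gin.configurable can have its
# keyword arguments rebound by name via 'configurable_name.parameter', which is
# what strings like 'Reformer2.d_ff' and 'train.steps' refer to.
import gin


@gin.configurable
def demo_train(steps=1000, d_ff=512):
  return steps, d_ff


gin.bind_parameter('demo_train.steps', 1)
gin.bind_parameter('demo_train.d_ff', 32)
assert demo_train() == (1, 32)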
def test_run_pose_env_collect(self, demo_policy_cls):
  urdf_root = pose_env.get_pybullet_urdf_root()
  config_dir = 'research/pose_env/configs'
  gin_config = os.path.join(
      FLAGS.test_srcdir, config_dir, 'run_random_collect.gin')
  gin.parse_config_file(gin_config)
  tmp_dir = absltest.get_default_test_tmpdir()
  root_dir = os.path.join(tmp_dir, str(demo_policy_cls))
  gin.bind_parameter('PoseToyEnv.urdf_root', urdf_root)
  gin.bind_parameter('collect_eval_loop.root_dir', root_dir)
  gin.bind_parameter('run_meta_env.num_tasks', 2)
  gin.bind_parameter('run_meta_env.num_episodes_per_adaptation', 1)
  gin.bind_parameter('collect_eval_loop.policy_class', demo_policy_cls)
  continuous_collect_eval.collect_eval_loop()
  output_files = tf.io.gfile.glob(
      os.path.join(root_dir, 'policy_collect', '*.tfrecord'))
  self.assertLen(output_files, 2)
def test_reformer_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 2
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('reformer_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
  gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
  gin.bind_parameter('Reformer.d_ff', d_ff)
  output_dir = self.create_tempdir().full_path
  _ = trainer_lib.train(output_dir=output_dir)
def main(argv):
  del argv
  # Load gin.config settings stored in the model directory. It is possible to
  # run this script concurrently with the train script. In this case, wait for
  # the train script to start up and actually write out a gin config file.
  # Wait up to 10 minutes (periodically checking for file existence) before
  # giving up.
  gin_config_path = os.path.join(FLAGS.base_dir, 'config.gin')
  num_tries = 0
  while not gfile.exists(gin_config_path):
    num_tries += 1
    if num_tries >= 10:
      raise ValueError('Could not find config.gin in "%s"' % FLAGS.base_dir)
    time.sleep(60)
  gin.parse_config_file(gin_config_path, skip_unknown=True)
  gin.finalize()
  run_eval()
def main(config):
  """Start the worker."""
  gin.parse_config_file(config.config)
  logger = misc.utility.create_logger(
      name='es_worker{}'.format(config.worker_id), log_dir=config.log_dir)

  if config.master_address is not None:
    logger.info('master_address: {}'.format(config.master_address))
    channel = grpc.insecure_channel(
        config.master_address,
        [("grpc.max_receive_message_length", _MAX_MSG_LEN)])
    stub = protobuf.roll_out_service_pb2_grpc.ParameterSyncServiceStub(channel)
    worker = misc.utility.get_es_worker(logger=logger, master=stub)
  else:
    worker = misc.utility.get_es_worker(logger=logger)

  if config.run_on_gke:
    port = config.port
  else:
    port = config.port + config.worker_id

  server = grpc.server(
      futures.ThreadPoolExecutor(max_workers=1),
      options=[("grpc.max_send_message_length", _MAX_MSG_LEN),
               ("grpc.max_receive_message_length", _MAX_MSG_LEN)])
  # Start the RPC server.
  protobuf.roll_out_service_pb2_grpc.add_RollOutServiceServicer_to_server(
      worker, server)
  server.add_insecure_port('[::]:{}'.format(port))
  server.start()
  logger.info('Listening to port {} ...'.format(port))
  try:
    while True:
      time.sleep(_ONE_DAY_IN_SECONDS)
  except KeyboardInterrupt:
    logger.info('Worker quit.')
def score(self,
          inputs,
          targets,
          scores_file=None,
          checkpoint_steps=-1,
          vocabulary=None):
  """Computes log-likelihood of target per example in targets.

  Args:
    inputs: optional - a string (filename), or a list of strings (inputs)
    targets: a string (filename), or a list of strings (targets)
    scores_file: str, path to write example scores to, one per line.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()

  if isinstance(targets, str):
    tf.logging.info("scoring targets from file %s" % targets)
    utils.score_from_files(self.estimator(vocabulary), vocabulary,
                           self._model_type, self.batch_size,
                           self._sequence_length, self._model_dir,
                           checkpoint_steps, inputs, targets, scores_file)
  else:
    tf.logging.info("scoring targets from list of strings")
    utils.score_from_strings(self.estimator(vocabulary), vocabulary,
                             self._model_type, self.batch_size,
                             self._sequence_length, self._model_dir,
                             checkpoint_steps, inputs, targets, scores_file)
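# `_operative_config_path` above is a project-local helper; a plausible sketch,
# assuming the convention (used by T5/Mesh TensorFlow training) of saving the
# operative gin config as 'operative_config.gin' inside the model directory.
# The exact filename is an assumption, not verified against the t5 source.
import os


def _operative_config_path(model_dir):
  # The operative config written out at training time, reparsed at inference.
  return os.path.join(model_dir, "operative_config.gin")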
def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
  gin.add_config_file_search_path(_CONFIG_DIR)
  # 32 is the max length we trained the checkpoint for.
  test_lengths = [8, 16, 32]
  vocab_size = 13
  np.random.seed(0)
  for max_len in test_lengths:
    gin.clear_config()
    gin.parse_config_file('reformer2_copy.gin')
    gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
    gin.bind_parameter('LSHSelfAttention.predict_drop_len', max_len // 4)

    pred_model = models.Reformer2(mode='predict')

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

    model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=(shape1l, shape11))
    initial_state = pred_model.state

    for _ in range(3):
      # Pick a length in [1, max_len].
      inp_len = np.random.randint(low=1, high=max_len + 1)
      inputs = np.random.randint(low=1, high=vocab_size - 1, size=(1, inp_len))
      inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                      mode='constant', constant_values=0)
      s = decoding.autoregressive_sample(
          pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
          temperature=0.0)
      np.testing.assert_equal(s[0], inputs[0, :inp_len])
      pred_model.state = initial_state
def export(self,
           export_dir=None,
           checkpoint_step=-1,
           beam_size=1,
           temperature=1.0,
           vocabulary=None):
  """Exports a TensorFlow SavedModel.

  Args:
    export_dir: str, a directory in which to export SavedModels. Will use
      `model_dir` if unspecified.
    checkpoint_step: int, checkpoint to export. If -1 (default), use the
      latest checkpoint from the pretrained model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.

  Returns:
    The string path to the exported directory.
  """
  if checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  export_dir = export_dir or self._model_dir
  return utils.export_model(
      self.estimator(vocabulary, disable_tpu=True),
      export_dir,
      vocabulary,
      self._sequence_length,
      batch_size=self.batch_size,
      checkpoint_path=os.path.join(self._model_dir, model_ckpt))
def predict(self,
            input_file,
            output_file,
            checkpoint_steps=-1,
            beam_size=1,
            temperature=1.0,
            vocabulary=None):
  """Predicts targets from the given inputs.

  Args:
    input_file: str, path to a text file containing newline-separated input
      prompts to predict from.
    output_file: str, path prefix of output file to write predictions to.
      Note the checkpoint step will be appended to the given filename.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      inference will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      inference continuously waiting for new checkpoints. If -1, get the
      latest checkpoint from the model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    vocabulary: vocabularies.Vocabulary object to use for tokenization, or
      None to use the default SentencePieceVocabulary.
  """
  # TODO(sharannarang): It would be nice to have a function like
  # load_checkpoint that loads the model once and then calls decode_from_file
  # multiple times without having to restore the checkpoint weights again.
  # This would be particularly useful in a colab demo.
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
    gin.bind_parameter("tpu_estimator_model_fn.score_in_predict_mode", False)

  if vocabulary is None:
    vocabulary = t5.data.get_default_vocabulary()
  utils.infer_model(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
      input_file, output_file)
def test_reformer_wmt_ende(self):
  trax.math.disable_jit()
  batch_size_per_device = 2
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file(
      os.path.join(_CONFIG_DIR, 'reformer_wmt_ende.gin'))
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('Reformer.n_encoder_layers', n_layers)
  gin.bind_parameter('Reformer.n_decoder_layers', n_layers)
  gin.bind_parameter('Reformer.d_ff', d_ff)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
def parallel_exps(var_par, varying_par_to_change, gin_path, func):
    """Run a single parallelized experiment: load the base parameters from the
    gin config, bind the varying parameters, then run both the training and
    testing routines.

    Parameters
    ----------
    var_par: list
        Values of the varying parameters for this experiment.
    varying_par_to_change: list
        Names of the gin parameters to rebind, aligned with `var_par`.
    gin_path: str
        Path to the gin config file with the base experiment parameters.
    func: callable
        Factory returning the model runner to execute.
    """
    gin.parse_config_file(gin_path, skip_unknown=True)
    for i in range(len(var_par)):
        gin.bind_parameter(varying_par_to_change[i], var_par[i])
    model_runner = func()
    model_runner.run()
def parse_gin(restore_dir):
  """Parse gin config from --gin_file, --gin_param, and the model directory."""
  # Add user folders to the gin search path.
  for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
    gin.add_config_file_search_path(gin_search_path)

  # Parse gin configs, later calls override earlier ones.
  with gin.unlock_config():
    # Optimization defaults.
    use_tpu = bool(FLAGS.tpu)
    opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
    gin.parse_config_file(os.path.join('optimization', opt_default))
    eval_default = 'eval/basic.gin'
    gin.parse_config_file(eval_default)

    # Load operative_config if it exists (model has already trained).
    operative_config = train_util.get_latest_operative_config(restore_dir)
    if tf.io.gfile.exists(operative_config):
      logging.info('Using operative config: %s', operative_config)
      operative_config = cloud.make_file_paths_local(operative_config,
                                                     GIN_PATH)
      gin.parse_config_file(operative_config, skip_unknown=True)

    # User gin config and user hyperparameters from flags.
    gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
    gin.parse_config_files_and_bindings(gin_file, FLAGS.gin_param,
                                        skip_unknown=True)
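# "Later calls override earlier ones": gin keeps a single global config, and
# re-binding a parameter replaces the previous value. That ordering is what
# lets parse_gin() layer defaults first, then the trained model's operative
# config, then user flags. A tiny demonstration with a hypothetical
# configurable (not from any of the projects above):
import gin


@gin.configurable
def make_model(layers=2):
  return layers


gin.parse_config('make_model.layers = 4')  # Earlier call: the defaults layer.
gin.parse_config('make_model.layers = 8')  # Later call wins.
assert make_model() == 8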
def eval(self,
         mixture_or_task_name,
         checkpoint_steps=None,
         summary_dir=None,
         split="validation"):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      evaluation will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      eval continuously waiting for new checkpoints. If -1, get the latest
      checkpoint from the model directory.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the split to evaluate on.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocabulary = get_mixture_or_task_ll(mixture_or_task_name).get_vocabulary()
  dataset_fn = functools.partial(
      mesh_eval_dataset_fn_ll, mixture_or_task_name=mixture_or_task_name)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
  eval_model_ll(
      self.estimator(vocabulary), vocabulary, self._sequence_length,
      self.batch_size, split, self._model_dir, dataset_fn, summary_dir,
      checkpoint_steps,
      attribute_bit=self.attribute_bit,
      unsupervised_attribute_transfer_metrics=(
          self.unsupervised_attribute_transfer_metrics),
      control_code_bool=self.control_code_bool)
def main(unused_argv):
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')

  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)

  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)

  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)

  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())

  # Parenthesize the conditional so the 'hessian_' prefix is prepended to the
  # filename rather than the conditional swallowing it.
  logconfigfile_path = os.path.join(
      FLAGS.logdir,
      ('hessian_' if FLAGS.mode == 'hessian' else '') + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
def eval(self,
         mixture_or_task_name,
         checkpoint_steps=None,
         summary_dir=None,
         split="validation",
         eval_with_score=False,
         beam_size=1,
         temperature=1.0):
  """Evaluate the model on the given Mixture or Task.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to evaluate
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    checkpoint_steps: int, list of ints, or None. If an int or list of ints,
      evaluation will be run on the checkpoint files in `model_dir` whose
      global steps are closest to the global steps provided. If None, run
      eval continuously waiting for new checkpoints. If -1, get the latest
      checkpoint from the model directory.
    summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    split: str, the mixture/task split to evaluate on.
    eval_with_score: bool, whether to evaluate using log likelihood scores of
      targets instead of decoded predictions.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sample according to predicted distribution.
  """
  if checkpoint_steps == -1:
    checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)
  vocabulary = _get_vocabulary(mixture_or_task_name)
  dataset_fn = functools.partial(
      t5.models.mesh_transformer.mesh_eval_dataset_fn,
      mixture_or_task_name=mixture_or_task_name,
  )
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
  estimator = self.estimator(vocabulary, score_in_predict_mode=eval_with_score)
  utils.eval_model(
      estimator=estimator,
      vocabulary=vocabulary,
      sequence_length=self._sequence_length,
      batch_size=self.batch_size,
      dataset_split=split,
      model_dir=self._model_dir,
      eval_dataset_fn=dataset_fn,
      eval_summary_dir=summary_dir,
      eval_checkpoint_step=checkpoint_steps,
      eval_with_score=eval_with_score)
def parse_gin(model_dir):
  """Parse gin config from --gin_file, --gin_param, and the model directory."""
  # Add user folders to the gin search path.
  for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
    gin.add_config_file_search_path(gin_search_path)

  # Parse gin configs, later calls override earlier ones.
  with gin.unlock_config():
    # Optimization defaults.
    use_tpu = bool(FLAGS.tpu)
    opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
    gin.parse_config_file(os.path.join('optimization', opt_default))

    # Load operative_config if it exists (model has already trained).
    operative_config = os.path.join(model_dir, 'operative_config-0.gin')
    if tf.io.gfile.exists(operative_config):
      gin.parse_config_file(operative_config, skip_unknown=True)

    # User gin config and user hyperparameters from flags.
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param,
                                        skip_unknown=True)
def test_autoregressive_sample_reformer2_pure_lsh_attn_quality(self):
  gin.add_config_file_search_path(_CONFIG_DIR)
  max_len = 32  # 32 is the max length we trained the checkpoint for.
  test_lengths = [8, 16, 32]
  vocab_size = 13
  # The checkpoint is correct on ~90% of sequences, set random seed to deflake.
  np.random.seed(0)
  for test_len in test_lengths:
    gin.clear_config()
    gin.parse_config_file('reformer2_purelsh_copy.gin')
    gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len)
    gin.bind_parameter('PureLSHSelfAttention.predict_drop_len', 2 * max_len)
    gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
    gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

    pred_model = models.Reformer2(mode='predict')

    shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
    shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

    model_path = os.path.join(_TESTDATA, 'reformer2_purelsh_copy.pkl.gz')
    pred_model.init_from_file(model_path, weights_only=True,
                              input_signature=(shape1l, shape11))
    initial_state = pred_model.state

    for _ in range(2):  # Set low to make the test run reasonably fast.
      # Pick a length in [1, test_len] at random.
      inp_len = np.random.randint(low=1, high=test_len + 1)
      inputs = np.random.randint(low=1, high=vocab_size - 1, size=(1, inp_len))
      inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                      mode='constant', constant_values=0)
      s = decoding.autoregressive_sample(
          pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
          temperature=0.0)
      np.testing.assert_equal(s[0], inputs[0, :inp_len])
      pred_model.state = initial_state
  gin.clear_config()  # Make sure to not affect other tests.
def benchmark_halfcheetah_medium_v0(self):
  """Benchmarks MuJoCo HalfCheetah to 1M steps."""
  self.setUp()
  output_dir = self._get_test_output_dir('halfcheetah_medium_v0_02_eval')
  dataset_path = self.root_data_dir
  start_time_sec = time.time()
  gin.parse_config_file(
      'tf_agents/examples/cql_sac/kumar20/configs/mujoco_medium.gin')
  cql_sac_train_eval.train_eval(
      dataset_path=dataset_path,
      root_dir=output_dir,
      env_name='halfcheetah-medium-v0',
      num_gradient_updates=500000,  # Number of iterations.
      learner_iterations_per_call=500,
      data_shuffle_buffer_size=10000,
      data_num_shards=50,
      data_parallel_reads=500,
      data_prefetch=1000000,
      eval_interval=10000)
  wall_time_sec = time.time() - start_time_sec
  event_file = utils.find_event_log(os.path.join(output_dir, 'eval'))
  values, _ = utils.extract_event_log_values(event_file,
                                             'Metrics/AverageReturn',
                                             start_step=10000)
  # Min/max ranges are very large so this only hard-fails if something is very
  # broken; the system monitoring the results owns looking for anomalies.
  # These numbers are based on the results we were getting in MLCompass as of
  # 04-NOV-2021. Results at 500k steps and 1M steps are similar enough that it
  # is not worth running to 1M.
  metric_500k = self.build_metric(
      'average_return_at_env_step500000',
      values[500000],
      min_value=4400,
      max_value=5400)
  self.report_benchmark(
      wall_time=wall_time_sec, metrics=[metric_500k], extras={})
def test_reformer_noencdecattn_wmt_ende(self):
  trax.fastmath.disable_jit()
  batch_size_per_device = 1  # Ignored, but needs to be set.
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('batcher.buckets', ([512], [1, 1]))  # Batch size 1.
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.d_ff', d_ff)
  output_dir = self.create_tempdir().full_path
  _ = trainer_lib.train(output_dir=output_dir)
def test_terraformer_wmt_ende(self):
  batch_size_per_device = 2
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file('terraformer_wmt_ende.gin')
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter(
      'batcher.buckets',
      ([512], [batch_size_per_device, batch_size_per_device]))
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('ConfigurableTerraformer.n_encoder_layers', n_layers)
  gin.bind_parameter('ConfigurableTerraformer.n_decoder_layers', n_layers)
  gin.bind_parameter('ConfigurableTerraformer.d_ff', d_ff)
  output_dir = self.create_tempdir().full_path
  _ = trainer_lib.train(output_dir=output_dir)
def export(self,
           export_dir=None,
           checkpoint_step=-1,
           beam_size=1,
           temperature=1.0,
           sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH):
  """Exports a TensorFlow SavedModel.

  Args:
    export_dir: str, a directory in which to export SavedModels. Will use
      `model_dir` if unspecified.
    checkpoint_step: int, checkpoint to export. If -1 (default), use the
      latest checkpoint from the pretrained model directory.
    beam_size: int, a number >= 1 specifying the number of beams to use for
      beam search.
    temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1);
      0.0 means argmax, 1.0 means sample according to predicted distribution.
    sentencepiece_model_path: str, path to the SentencePiece model file to
      use for decoding. Must match the one used during training.
  """
  if checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(self._model_dir))
    gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
    gin.bind_parameter("Bitransformer.decode.temperature", temperature)
    gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
    gin.bind_parameter("utils.get_variable_dtype.activation_dtype", "float32")

  vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  export_dir = export_dir or self._model_dir
  utils.export_model(self.estimator(vocabulary), export_dir, vocabulary,
                     self._sequence_length,
                     os.path.join(self._model_dir, model_ckpt))
def finetune(self,
             mixture_or_task_name,
             finetune_steps,
             pretrained_model_dir,
             pretrained_checkpoint_step=-1):
  """Finetunes a model from an existing checkpoint.

  Args:
    mixture_or_task_name: str, the name of the Mixture or Task to finetune
      on. Must be pre-registered in the global `TaskRegistry` or
      `MixtureRegistry`.
    finetune_steps: int, the number of additional steps to train for.
    pretrained_model_dir: str, directory with pretrained model checkpoints
      and operative config.
    pretrained_checkpoint_step: int, checkpoint to initialize weights from.
      If -1, use the latest checkpoint from the pretrained model directory.
  """
  if pretrained_checkpoint_step == -1:
    checkpoint_step = _get_latest_checkpoint_from_dir(pretrained_model_dir)
  else:
    checkpoint_step = pretrained_checkpoint_step
  with gin.unlock_config():
    gin.parse_config_file(_operative_config_path(pretrained_model_dir))
  model_ckpt = "model.ckpt-" + str(checkpoint_step)
  self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
             init_checkpoint=os.path.join(pretrained_model_dir, model_ckpt))
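# `_get_latest_checkpoint_from_dir` is another project-local helper used by
# several of the methods above; a plausible implementation, assuming standard
# TensorFlow checkpoint naming ('model.ckpt-<global_step>'). This is a sketch,
# not the verified t5 source.
import tensorflow as tf


def _get_latest_checkpoint_from_dir(model_dir):
  # tf.train.latest_checkpoint returns e.g. '<model_dir>/model.ckpt-10000';
  # the global step is the integer suffix after the final '-'.
  latest = tf.train.latest_checkpoint(model_dir)
  if latest is None:
    raise ValueError('No checkpoints found in model dir: %s' % model_dir)
  return int(latest.split('-')[-1])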
def test_reformer_noencdecattn_wmt_ende(self):
  trax.math.disable_jit()
  batch_size_per_device = 1  # Ignored, but needs to be set.
  steps = 1
  n_layers = 2
  d_ff = 32
  gin.parse_config_file(
      os.path.join(_CONFIG_DIR, 'reformer_noencdecattn_wmt_ende.gin'))
  gin.bind_parameter('data_streams.data_dir', _TESTDATA)
  gin.bind_parameter('batcher.batch_size_per_device', batch_size_per_device)
  gin.bind_parameter('batcher.buckets', ([513], [1, 1]))  # Batch size 1.
  gin.bind_parameter('train.steps', steps)
  gin.bind_parameter('ReformerNoEncDecAttention.n_encoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.n_decoder_layers', n_layers)
  gin.bind_parameter('ReformerNoEncDecAttention.d_ff', d_ff)
  with self.tmp_dir() as output_dir:
    _ = trainer_lib.train(output_dir=output_dir)
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments for train.')

  if not FLAGS.gin_config:
    # Run the experiments locally.
    gin.parse_config_file(FLAGS.gin_config_file)
  else:
    # Run the experiments on a server.
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)

  # Create the `checkpoint` and `summary` directories used during training
  # to save/restore the model and write TF summaries.
  checkpoint_dir_str = gin.query_parameter('checkpoint_dir/macro.value')
  summary_dir_str = gin.query_parameter('summary_dir/macro.value')
  mcts_checkpoint_dir_str = os.path.join(checkpoint_dir_str, 'mcts_data')
  app_directories = [
      checkpoint_dir_str, summary_dir_str, mcts_checkpoint_dir_str
  ]
  for d in app_directories:
    directory_handling.ensure_dir_exists(d)

  ppo_train()
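# gin.query_parameter reads values back out of the parsed config. Top-level
# assignments in a gin file (e.g. checkpoint_dir = '/tmp/ckpts') define macros,
# which are addressable as '<name>/macro.value'; that is what the queries above
# rely on. A minimal sketch of this, defining the macro from a string instead
# of a file:
import gin

gin.parse_config("checkpoint_dir = '/tmp/ckpts'")
assert gin.query_parameter('checkpoint_dir/macro.value') == '/tmp/ckpts'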
def parse_gin_config(self, ckpt):
  """Parse the model operative config with new length parameters."""
  with gin.unlock_config():
    ckpt_dir = os.path.dirname(ckpt)
    operative_config = train_util.get_latest_operative_config(ckpt_dir)
    print(f'Parsing from operative_config {operative_config}')
    gin.parse_config_file(operative_config, skip_unknown=True)
    # Set gin params to new length.
    # Remove reverb processor.
    pg_string = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""
    gin.parse_config([
        'Harmonic.n_samples=%d' % self.n_samples,
        'FilteredNoise.n_samples=%d' % self.n_samples,
        'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
        'oscillator_bank.use_angular_cumsum=True',
        pg_string,
    ])
def parse_gin(restore_dir):
  """Parse gin config from --gin_file, --gin_param, and the model directory."""
  # Add user folders to the gin search path.
  for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
    gin.config.add_config_file_search_path(gin_search_path)

  # Parse gin configs, later calls override earlier ones.
  with gin.unlock_config():
    # Optimization defaults. Uses the main library's 'optimization' gin folder
    # only; the 'models' and 'dataset' folders use the drive gin configs.
    use_tpu = bool(FLAGS.tpu)
    opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
    gin.parse_config_file(
        os.path.join('ddsp/training/gin/optimization', opt_default))

    # Load operative_config if it exists (model has already trained).
    operative_config = train_util.get_latest_operative_config(restore_dir)
    if tf.io.gfile.exists(operative_config):
      logging.info('Using operative config: %s', operative_config)
      gin.parse_config_file(operative_config, skip_unknown=True)

    # User gin config and user hyperparameters from flags.
    gin.config.parse_config_files_and_bindings(
        FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
def train(args):
  # Parse the config file.
  if args.config_file:
    if os.path.isfile(os.path.join(args.directory, args.config_file)):
      gin.parse_config_file(os.path.join(args.directory, args.config_file))
    elif os.path.isfile(args.config_file):
      gin.parse_config_file(args.config_file)
    else:
      raise FileNotFoundError(
          'Failed to find config file {}'.format(args.config_file))

  kwargs = {'directory': args.directory}
  # If env id or env args are provided, use them for the training env
  # (otherwise, the configured defaults are used).
  make_args = {k: v for k, v in args.env_args} if args.env_args else {}
  if args.env_id:
    make_args['env'] = args.env_id
  if make_args:
    kwargs['env'] = lambda **kwa: stackrl.envs.make(**make_args, **kwa)

  # Run training.
  training = stackrl.Training(**kwargs)
  training.run()
def __init__(self, data_path, mode, train_set, validation_set, test_set,
             max_way_train, max_way_test, max_support_train,
             max_support_test):
    self.data_path = data_path
    self.train_dataset_next_task = None
    self.validation_set_dict = {}
    self.test_set_dict = {}
    gin.parse_config_file('./meta_dataset_config.gin')

    if mode == 'train' or mode == 'train_test':
        train_episode_description = self._get_train_episode_description(
            max_way_train, max_support_train)
        self.train_dataset_next_task = self._init_multi_source_dataset(
            train_set, learning_spec.Split.TRAIN, train_episode_description)

        test_episode_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for item in validation_set:
            next_task = self.validation_dataset = (
                self._init_single_source_dataset(
                    item, learning_spec.Split.VALID, test_episode_description))
            self.validation_set_dict[item] = next_task

    if mode == 'test' or mode == 'train_test' or mode == 'attack':
        test_episode_description = self._get_test_episode_description(
            max_way_test, max_support_test)
        for item in test_set:
            next_task = self._init_single_source_dataset(
                item, learning_spec.Split.TEST, test_episode_description)
            self.test_set_dict[item] = next_task
if eval_flag:
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    model.load_weights(latest)
    eval_model(eval_dir, model, system, model_name)


if __name__ == '__main__':
    root_dir = sys.argv[1]
    model_name = sys.argv[2]
    system_name = sys.argv[3]
    observations = sys.argv[4]
    num_train_traj = int(sys.argv[5])
    num_train_steps = int(sys.argv[6])
    seed = int(sys.argv[7])

    if not os.path.isdir(root_dir):
        os.makedirs(root_dir)

    # NOTE: Configs may overwrite arguments set in previous files.
    gin.parse_config_file('./config/base.gin')
    gin.parse_config_file('./config/{}.gin'.format(observations))
    if model_name == 'VIN_SO2' and observations == 'pixels':
        # Overwrite step-size for SO2 VIN.
        gin.parse_config_file('./config/pixels_SO2.gin')
    gin.parse_config_file('./config/{}.gin'.format(system_name))

    main(root_dir, model_name, system_name, observations, num_train_traj,
         num_train_steps, seed)
        if self.results.get(ds_name, 0) < f1:
            self.results[ds_name] = f1

    def finish(self):
        with self.summary_writer.as_default():
            for ds_name, f1 in self.results.items():
                tf.summary.scalar("test/" + ds_name + "/scene/best_f1", f1,
                                  step=0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train TransNet")
    parser.add_argument("config", help="path to config")
    args = parser.parse_args()

    gin.parse_config_file(args.config)
    options = get_options_dict()

    trn_ds = (input_processing.train_pipeline(options["trn_files"])
              if len(options["trn_files"]) > 0 else None)
    if options["transition_only_trn_files"] is not None:
        trn_ds_ = input_processing.train_transition_pipeline(
            options["transition_only_trn_files"])
        if trn_ds is not None:
            frac = options["transition_only_data_fraction"]
            trn_ds = tf.data.experimental.sample_from_datasets(
                [trn_ds, trn_ds_], weights=[1 - frac, frac])
        else:
            trn_ds = trn_ds_
    tst_ds = [(name, input_processing.test_pipeline(files))
              for name, files in options["tst_files"].items()]

    if options["original_transnet"]: