示例#1
0
    def test_reformer2_copy(self):
        """Smoke-test one training step of Reformer2 on the copy-task config."""
        trax.fastmath.disable_jit()

        n_layers = 2
        # batch_size_per_device is ignored (the buckets force batch size 1),
        # but the binding still needs to be set.
        overrides = {
            'batcher.batch_size_per_device': 1,
            'batcher.buckets': ([64], [1, 1]),  # batch size 1.
            'train.steps': 1,
            'Reformer2.n_encoder_layers': n_layers,
            'Reformer2.n_decoder_layers': n_layers,
            'Reformer2.d_ff': 32,
        }

        gin.parse_config_file('reformer2_copy.gin')
        for param_name, param_value in overrides.items():
            gin.bind_parameter(param_name, param_value)

        trainer_lib.train(output_dir=self.create_tempdir().full_path)
示例#2
0
  def test_run_pose_env_collect(self, demo_policy_cls):
    """Runs the random-collect gin config and checks the emitted tfrecords."""
    urdf_root = pose_env.get_pybullet_urdf_root()

    gin.parse_config_file(
        os.path.join(FLAGS.test_srcdir, 'research/pose_env/configs',
                     'run_random_collect.gin'))
    # Each policy class writes into its own output directory.
    root_dir = os.path.join(absltest.get_default_test_tmpdir(),
                            str(demo_policy_cls))
    bindings = (
        ('PoseToyEnv.urdf_root', urdf_root),
        ('collect_eval_loop.root_dir', root_dir),
        ('run_meta_env.num_tasks', 2),
        ('run_meta_env.num_episodes_per_adaptation', 1),
        ('collect_eval_loop.policy_class', demo_policy_cls),
    )
    for name, value in bindings:
      gin.bind_parameter(name, value)

    continuous_collect_eval.collect_eval_loop()

    pattern = os.path.join(root_dir, 'policy_collect', '*.tfrecord')
    self.assertLen(tf.io.gfile.glob(pattern), 2)
示例#3
0
    def test_reformer_wmt_ende(self):
        """Smoke-test one training step of Reformer on the WMT en-de config."""
        trax.fastmath.disable_jit()

        n_layers = 2
        gin.parse_config_file('reformer_wmt_ende.gin')

        overrides = {
            'data_streams.data_dir': _TESTDATA,
            'batcher.batch_size_per_device': 2,
            'train.steps': 1,
            'Reformer.n_encoder_layers': n_layers,
            'Reformer.n_decoder_layers': n_layers,
            'Reformer.d_ff': 32,
        }
        for param_name, param_value in overrides.items():
            gin.bind_parameter(param_name, param_value)

        output_dir = self.create_tempdir().full_path
        trainer_lib.train(output_dir=output_dir)
示例#4
0
def main(argv):
    """Waits for the training job's gin config, then runs evaluation.

    This script may run concurrently with the train script, which writes
    `config.gin` into `FLAGS.base_dir`. The file is polled once per minute
    for up to 10 minutes before giving up.

    Args:
      argv: Unused command-line arguments.

    Raises:
      ValueError: if `config.gin` does not appear within the wait window.
    """
    del argv  # Unused.

    max_wait_minutes = 10
    gin_config_path = os.path.join(FLAGS.base_dir, 'config.gin')
    minutes_waited = 0
    while not gfile.exists(gin_config_path):
        # Count completed one-minute sleeps (not failed checks) so the full
        # 10-minute window elapses before giving up.
        if minutes_waited >= max_wait_minutes:
            raise ValueError('Could not find config.gin in "%s"' %
                             FLAGS.base_dir)
        time.sleep(60)
        minutes_waited += 1

    gin.parse_config_file(gin_config_path, skip_unknown=True)
    gin.finalize()

    run_eval()
示例#5
0
def main(config):
    """Start the worker.

    Parses the gin config and builds an ES worker. When
    `config.master_address` is set, the worker is connected to the master's
    ParameterSyncService. The worker is then served over gRPC until
    interrupted.

    Args:
      config: parsed options; reads `config`, `worker_id`, `log_dir`,
        `master_address`, `run_on_gke` and `port`.
    """
    gin.parse_config_file(config.config)
    logger = misc.utility.create_logger(
        name='es_worker{}'.format(config.worker_id), log_dir=config.log_dir)

    if config.master_address is not None:
        logger.info('master_address: {}'.format(config.master_address))
        channel = grpc.insecure_channel(
            config.master_address,
            [("grpc.max_receive_message_length", _MAX_MSG_LEN)])
        stub = protobuf.roll_out_service_pb2_grpc.ParameterSyncServiceStub(
            channel)
        worker = misc.utility.get_es_worker(logger=logger, master=stub)
    else:
        worker = misc.utility.get_es_worker(logger=logger)

    # Presumably each GKE worker runs in its own pod and can reuse one port;
    # otherwise workers share a host and are offset by id (TODO confirm).
    if config.run_on_gke:
        port = config.port
    else:
        port = config.port + config.worker_id
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=1),
                         options=[
                             ("grpc.max_send_message_length", _MAX_MSG_LEN),
                             ("grpc.max_receive_message_length", _MAX_MSG_LEN)
                         ])

    # Start the RPC server.
    protobuf.roll_out_service_pb2_grpc.add_RollOutServiceServicer_to_server(
        worker, server)
    server.add_insecure_port('[::]:{}'.format(port))
    server.start()
    logger.info('Listening to port {} ...'.format(port))

    try:
        # server.start() returns immediately; keep the main thread alive.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        logger.info('Worker quit.')
    finally:
        # Release the listening port and worker threads on every exit path.
        server.stop(0)
  def score(self,
            inputs,
            targets,
            scores_file=None,
            checkpoint_steps=-1,
            vocabulary=None):
    """Computes the log-likelihood of each target example.

    Args:
      inputs: optional; a filename (str) or a list of input strings.
      targets: a filename (str) or a list of target strings. A str selects
        file-based scoring; anything else selects string-based scoring.
      scores_file: str, path to write per-example scores to, one per line.
      checkpoint_steps: int, list of ints, or None. An int or list of ints
        selects the checkpoint(s) in `model_dir` whose global steps are closest
        to those given. None runs inference continuously as new checkpoints
        appear. -1 (the default) uses the latest checkpoint in the model
        directory.
      vocabulary: vocabularies.Vocabulary object used for tokenization, or None
        for the default SentencePieceVocabulary.
    """
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    # Re-parse the operative gin config saved in the model directory.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))

    vocabulary = (t5.data.get_default_vocabulary() if vocabulary is None
                  else vocabulary)

    if isinstance(targets, str):
      tf.logging.info("scoring targets from file %s" % targets)
      score_fn = utils.score_from_files
    else:
      tf.logging.info("scoring targets from list of strings")
      score_fn = utils.score_from_strings
    score_fn(self.estimator(vocabulary), vocabulary, self._model_type,
             self.batch_size, self._sequence_length, self._model_dir,
             checkpoint_steps, inputs, targets, scores_file)
示例#7
0
    def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
        """Checks that a pretrained LSH-attention Reformer2 copies its inputs.

        Greedy (temperature 0) decoding of each random input must reproduce
        the unpadded input exactly.
        """
        gin.add_config_file_search_path(_CONFIG_DIR)
        # 32 is the max length we trained the checkpoint for.
        test_lengths = [8, 16, 32]
        vocab_size = 13
        np.random.seed(0)
        try:
            for max_len in test_lengths:
                gin.clear_config()
                gin.parse_config_file('reformer2_copy.gin')
                gin.bind_parameter('LSHSelfAttention.predict_mem_len',
                                   2 * max_len)
                gin.bind_parameter('LSHSelfAttention.predict_drop_len',
                                   max_len // 4)

                pred_model = models.Reformer2(mode='predict')

                shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
                shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

                model_path = os.path.join(_TESTDATA,
                                          'reformer2_copy_lsh_attn.pkl.gz')
                pred_model.init_from_file(model_path,
                                          weights_only=True,
                                          input_signature=(shape1l, shape11))
                initial_state = pred_model.state

                for _ in range(3):
                    # Pick a length in [1, max_len].
                    inp_len = np.random.randint(low=1, high=max_len + 1)
                    inputs = np.random.randint(low=1,
                                               high=vocab_size - 1,
                                               size=(1, inp_len))
                    inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                                    mode='constant',
                                    constant_values=0)
                    s = decoding.autoregressive_sample(pred_model,
                                                       inputs=inputs,
                                                       eos_id=-1,
                                                       max_length=inp_len,
                                                       temperature=0.0)
                    np.testing.assert_equal(s[0], inputs[0, :inp_len])
                    # Reset the model state before drawing the next sample.
                    pred_model.state = initial_state
        finally:
            # Make sure to not affect other tests, even if an assertion fails.
            gin.clear_config()
示例#8
0
    def export(self,
               export_dir=None,
               checkpoint_step=-1,
               beam_size=1,
               temperature=1.0,
               vocabulary=None):
        """Exports a TensorFlow SavedModel.

        Args:
          export_dir: str, directory in which to export SavedModels. Defaults
            to `model_dir` when unspecified.
          checkpoint_step: int, checkpoint to export. -1 (the default) selects
            the latest checkpoint in the model directory.
          beam_size: int >= 1, number of beams to use for beam search.
          temperature: float in [0, 1] (must be 0 if beam_size > 1). 0.0 means
            argmax; 1.0 means sampling from the predicted distribution.
          vocabulary: vocabularies.Vocabulary object used for tokenization, or
            None for the default SentencePieceVocabulary.

        Returns:
          The string path to the exported directory.
        """
        if checkpoint_step == -1:
            checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)

        if vocabulary is None:
            vocabulary = t5.data.get_default_vocabulary()
        checkpoint_path = os.path.join(self._model_dir,
                                       "model.ckpt-" + str(checkpoint_step))
        return utils.export_model(
            self.estimator(vocabulary, disable_tpu=True),
            export_dir or self._model_dir,
            vocabulary,
            self._sequence_length,
            batch_size=self.batch_size,
            checkpoint_path=checkpoint_path)
  def predict(self, input_file, output_file, checkpoint_steps=-1,
              beam_size=1, temperature=1.0, vocabulary=None):
    """Predicts targets for the inputs in `input_file`.

    Args:
      input_file: str, path to a text file of newline-separated input prompts.
      output_file: str, path prefix for the predictions file; the checkpoint
        step is appended to this name.
      checkpoint_steps: int, list of ints, or None. An int or list of ints
        selects the checkpoint(s) in `model_dir` whose global steps are closest
        to those given. None runs inference continuously as new checkpoints
        appear. -1 (the default) uses the latest checkpoint in the model
        directory.
      beam_size: int >= 1, number of beams to use for beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1). 0.0 means
        argmax; 1.0 means sampling from the predicted distribution.
      vocabulary: vocabularies.Vocabulary object used for tokenization, or None
        for the default SentencePieceVocabulary.
    """
    # TODO(sharannarang): It would be nice to have a function like
    # load_checkpoint that loads the model once and then calls
    # decode_from_file multiple times without restoring the checkpoint weights
    # again. This would be particularly useful in the colab demo.
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    decode_bindings = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        # Disable score-in-predict mode so predict mode decodes outputs.
        ("tpu_estimator_model_fn.score_in_predict_mode", False),
    )
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for name, value in decode_bindings:
        gin.bind_parameter(name, value)

    if vocabulary is None:
      vocabulary = t5.data.get_default_vocabulary()
    utils.infer_model(
        self.estimator(vocabulary), vocabulary, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, checkpoint_steps,
        input_file, output_file)
示例#10
0
    def test_reformer_wmt_ende(self):
        """Smoke-test one training step of Reformer on the WMT en-de config."""
        trax.math.disable_jit()

        config_path = os.path.join(_CONFIG_DIR, 'reformer_wmt_ende.gin')
        gin.parse_config_file(config_path)

        n_layers = 2
        overrides = [
            ('data_streams.data_dir', _TESTDATA),
            ('batcher.batch_size_per_device', 2),
            ('train.steps', 1),
            ('Reformer.n_encoder_layers', n_layers),
            ('Reformer.n_decoder_layers', n_layers),
            ('Reformer.d_ff', 32),
        ]
        for param_name, param_value in overrides:
            gin.bind_parameter(param_name, param_value)

        with self.tmp_dir() as output_dir:
            trainer_lib.train(output_dir=output_dir)
示例#11
0
def parallel_exps(var_par, varying_par_to_change, gin_path, func):
    """Load the gin parameters for one parallelized experiment and run it.

    Parses the gin file and overrides the varying parameters. It then builds
    the experiment runner via ``func`` and calls its ``run()`` method, which
    runs both the training and testing routines.

    Parameters
    ----------
    var_par: list
        Values of the varying parameters for this single experiment.
    varying_par_to_change: list of str
        Gin binding keys, aligned index-by-index with ``var_par``.
    gin_path: str
        Path of the gin config file to parse.
    func: callable
        Zero-argument factory returning an object with a ``run()`` method.

    Raises
    ------
    ValueError
        If ``var_par`` and ``varying_par_to_change`` differ in length.
    """
    # Validate before touching gin state, so a mismatch can neither leave the
    # config half-bound nor silently drop keys.
    if len(var_par) != len(varying_par_to_change):
        raise ValueError(
            'var_par and varying_par_to_change must have the same length, '
            'got %d and %d' % (len(var_par), len(varying_par_to_change)))

    gin.parse_config_file(gin_path, skip_unknown=True)

    for binding_key, value in zip(varying_par_to_change, var_par):
        gin.bind_parameter(binding_key, value)

    model_runner = func()
    model_runner.run()
示例#12
0
def parse_gin(restore_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Sources are parsed in increasing order of precedence; later calls
    override earlier ones:
      1. optimization and eval defaults;
      2. the latest operative config in ``restore_dir``, if any;
      3. user gin files and bindings from flags.

    Args:
      restore_dir: directory searched for a previously saved operative config.
    """
    # Make user folders available for relative gin file lookups.
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for search_path in search_paths:
        gin.add_config_file_search_path(search_path)

    with gin.unlock_config():
        # Optimization and evaluation defaults.
        opt_default = 'base_tpu.gin' if bool(FLAGS.tpu) else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))
        gin.parse_config_file('eval/basic.gin')

        # Reuse the operative config of an already-trained model, if present.
        operative_config = train_util.get_latest_operative_config(restore_dir)
        if tf.io.gfile.exists(operative_config):
            logging.info('Using operative config: %s', operative_config)
            local_config = cloud.make_file_paths_local(operative_config,
                                                       GIN_PATH)
            gin.parse_config_file(local_config, skip_unknown=True)

        # User gin files and hyperparameter bindings from flags.
        user_gin_files = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
        gin.parse_config_files_and_bindings(user_gin_files,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
示例#13
0
 def eval(self,
          mixture_or_task_name,
          checkpoint_steps=None,
          summary_dir=None,
          split="validation"):
     """Evaluate the model on the given Mixture or Task.

     Args:
       mixture_or_task_name: str, name of a Mixture or Task registered in the
         global `TaskRegistry` or `MixtureRegistry`.
       checkpoint_steps: int, list of ints, or None. An int or list of ints
         evaluates the checkpoints in `model_dir` whose global steps are
         closest to those given. None evaluates continuously as new
         checkpoints appear. -1 uses the latest checkpoint in the model
         directory.
       summary_dir: str, where TensorBoard event summaries for eval are
         written. If None, model_dir/eval_{split} is used.
       split: str, the split to evaluate on.
     """
     if checkpoint_steps == -1:
         checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

     task = get_mixture_or_task_ll(mixture_or_task_name)
     vocabulary = task.get_vocabulary()
     dataset_fn = functools.partial(mesh_eval_dataset_fn_ll,
                                    mixture_or_task_name=mixture_or_task_name)

     # Re-parse the operative gin config saved in the model directory.
     with gin.unlock_config():
         gin.parse_config_file(_operative_config_path(self._model_dir))

     estimator = self.estimator(vocabulary)
     eval_model_ll(
         estimator,
         vocabulary,
         self._sequence_length,
         self.batch_size,
         split,
         self._model_dir,
         dataset_fn,
         summary_dir,
         checkpoint_steps,
         attribute_bit=self.attribute_bit,
         unsupervised_attribute_transfer_metrics=(
             self.unsupervised_attribute_transfer_metrics),
         control_code_bool=self.control_code_bool)
示例#14
0
文件: train.py 项目: yaelandau22/rigl
def main(unused_argv):
  """Builds the network from gin config and runs the mode in FLAGS.mode.

  In 'train_eval' mode the model is trained and evaluated. In 'hessian' mode:
    - the operative gin config saved in FLAGS.logdir is preloaded;
    - the model is tested;
    - its Hessian is computed.
  At the end the operative gin config is written to FLAGS.logdir. In hessian
  mode the filename is prefixed with 'hessian_' so the preloaded
  operative_config.gin is not overwritten.
  """
  tf.random.set_seed(FLAGS.seed)
  init_timer = timer.Timer()
  init_timer.Start()

  if FLAGS.mode == 'hessian':
    # Load default values from the original experiment.
    FLAGS.preload_gin_config = os.path.join(FLAGS.logdir,
                                            'operative_config.gin')

  # Maybe preload a gin config.
  if FLAGS.preload_gin_config:
    config_path = FLAGS.preload_gin_config
    gin.parse_config_file(config_path)
    logging.info('Gin configuration pre-loaded from: %s', config_path)

  gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  ds_train, ds_test, info = utils.get_dataset()
  input_shape = info.features['image'].shape
  num_classes = info.features['label'].num_classes
  logging.info('Input Shape: %s', input_shape)
  logging.info('train samples: %s', info.splits['train'].num_examples)
  logging.info('test samples: %s', info.splits['test'].num_examples)

  pruning_params = utils.get_pruning_params()
  model = utils.get_network(pruning_params, input_shape, num_classes)
  model.summary(print_fn=logging.info)
  if FLAGS.mode == 'train_eval':
    train_model(model, ds_train, ds_test, FLAGS.logdir)
  elif FLAGS.mode == 'hessian':
    test_model(model, ds_test)
    hessian(model, ds_train, FLAGS.logdir)
  logging.info('Total runtime: %.3f s', init_timer.GetDuration())

  # Build the prefix separately. `+` binds tighter than a conditional
  # expression, so an inline `'hessian_' if ... else '' + name` would name
  # the file just 'hessian_'.
  prefix = 'hessian_' if FLAGS.mode == 'hessian' else ''
  logconfigfile_path = os.path.join(FLAGS.logdir,
                                    prefix + 'operative_config.gin')
  with tf.io.gfile.GFile(logconfigfile_path, 'w') as f:
    f.write('# Gin-Config:\n %s' % gin.config.operative_config_str())
  def eval(self, mixture_or_task_name, checkpoint_steps=None,
           summary_dir=None, split="validation", eval_with_score=False,
           beam_size=1, temperature=1.0):
    """Evaluate the model on the given Mixture or Task.

    Args:
      mixture_or_task_name: str, name of a Mixture or Task registered in the
        global `TaskRegistry` or `MixtureRegistry`.
      checkpoint_steps: int, list of ints, or None. An int or list of ints
        evaluates the checkpoints in `model_dir` whose global steps are closest
        to those given. None evaluates continuously as new checkpoints appear.
        -1 uses the latest checkpoint in the model directory.
      summary_dir: str, where TensorBoard event summaries for eval are
        written. If None, model_dir/eval_{split} is used.
      split: str, the mixture/task split to evaluate on.
      eval_with_score: bool, whether to evaluate using log likelihood scores of
        targets instead of decoded predictions.
      beam_size: int, value bound to `Bitransformer.decode.beam_size`.
      temperature: float, value bound to `Bitransformer.decode.temperature`.
    """
    if checkpoint_steps == -1:
      checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

    vocabulary = _get_vocabulary(mixture_or_task_name)
    dataset_fn = functools.partial(
        t5.models.mesh_transformer.mesh_eval_dataset_fn,
        mixture_or_task_name=mixture_or_task_name)

    decode_bindings = (("Bitransformer.decode.beam_size", beam_size),
                       ("Bitransformer.decode.temperature", temperature))
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for name, value in decode_bindings:
        gin.bind_parameter(name, value)

    utils.eval_model(
        estimator=self.estimator(vocabulary,
                                 score_in_predict_mode=eval_with_score),
        vocabulary=vocabulary,
        sequence_length=self._sequence_length,
        batch_size=self.batch_size,
        dataset_split=split,
        model_dir=self._model_dir,
        eval_dataset_fn=dataset_fn,
        eval_summary_dir=summary_dir,
        eval_checkpoint_step=checkpoint_steps,
        eval_with_score=eval_with_score)
示例#16
0
def parse_gin(model_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Later sources override earlier ones: the optimization defaults first,
    then ``model_dir``'s operative_config-0.gin (if present), then user
    gin files and bindings from flags.

    Args:
      model_dir: directory checked for a previously saved operative config.
    """
    search_paths = [GIN_PATH] + FLAGS.gin_search_path
    for search_path in search_paths:
        gin.add_config_file_search_path(search_path)

    with gin.unlock_config():
        opt_default = 'base_tpu.gin' if bool(FLAGS.tpu) else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))

        # A saved operative config means the model has already trained.
        operative_config = os.path.join(model_dir, 'operative_config-0.gin')
        has_operative_config = tf.io.gfile.exists(operative_config)
        if has_operative_config:
            gin.parse_config_file(operative_config, skip_unknown=True)

        gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
示例#17
0
  def test_autoregressive_sample_reformer2_pure_lsh_attn_quality(self):
    """Checks that a pretrained pure-LSH Reformer2 copies its inputs.

    Greedy (temperature 0) decoding of each random input must reproduce the
    unpadded input exactly.
    """
    gin.add_config_file_search_path(_CONFIG_DIR)
    max_len = 32  # 32 is the max length we trained the checkpoint for.
    test_lengths = [8, 16, 32]
    vocab_size = 13
    # The checkpoint is correct on ~90% sequences, set random seed to deflake.
    np.random.seed(0)
    try:
      for test_len in test_lengths:
        gin.clear_config()
        gin.parse_config_file('reformer2_purelsh_copy.gin')
        gin.bind_parameter('PureLSHSelfAttention.predict_mem_len', 2 * max_len)
        gin.bind_parameter('PureLSHSelfAttention.predict_drop_len',
                           2 * max_len)
        gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
        gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

        pred_model = models.Reformer2(mode='predict')

        shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
        shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

        model_path = os.path.join(_TESTDATA, 'reformer2_purelsh_copy.pkl.gz')
        pred_model.init_from_file(model_path, weights_only=True,
                                  input_signature=(shape1l, shape11))
        initial_state = pred_model.state

        for _ in range(2):  # Set low to make the test run reasonably fast.
          # Pick a length in [1, test_len] at random.
          inp_len = np.random.randint(low=1, high=test_len + 1)
          inputs = np.random.randint(low=1, high=vocab_size-1,
                                     size=(1, inp_len))
          inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                          mode='constant', constant_values=0)
          s = decoding.autoregressive_sample(
              pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
              temperature=0.0)

          np.testing.assert_equal(s[0], inputs[0, :inp_len])
          # Reset the model state before drawing the next sample.
          pred_model.state = initial_state
    finally:
      # Make sure to not affect other tests, even if an assertion fails.
      gin.clear_config()
示例#18
0
    def benchmark_halfcheetah_medium_v0(self):
        """Benchmarks CQL-SAC on MuJoCo HalfCheetah (medium-v0) to 500k steps.

        Runs 500k gradient updates. Results at 500k and 1M steps are similar
        enough that running 1M is not worthwhile. Reports the average eval
        return at step 500000.
        """
        self.setUp()
        output_dir = self._get_test_output_dir('halfcheetah_medium_v0_02_eval')
        dataset_path = self.root_data_dir
        start_time_sec = time.time()
        gin.parse_config_file(
            'tf_agents/examples/cql_sac/kumar20/configs/mujoco_medium.gin')
        cql_sac_train_eval.train_eval(
            dataset_path=dataset_path,
            root_dir=output_dir,
            env_name='halfcheetah-medium-v0',
            num_gradient_updates=500000,  # Number of iterations.
            learner_iterations_per_call=500,
            data_shuffle_buffer_size=10000,
            data_num_shards=50,
            data_parallel_reads=500,
            data_prefetch=1000000,
            eval_interval=10000)
        wall_time_sec = time.time() - start_time_sec
        event_file = utils.find_event_log(os.path.join(output_dir, 'eval'))
        values, _ = utils.extract_event_log_values(event_file,
                                                   'Metrics/AverageReturn',
                                                   start_step=10000)

        # Min/Max ranges are very large to only hard fail if very broken. The system
        # monitoring the results owns looking for anomalies. These numbers are based
        # on the results that we were getting in MLCompass as of 04-NOV-2021.
        # Results at 500k steps and 1M steps are similar enough to not make it worth
        # running 1M.
        metric_500k = self.build_metric('average_return_at_env_step500000',
                                        values[500000],
                                        min_value=4400,
                                        max_value=5400)

        self.report_benchmark(wall_time=wall_time_sec,
                              metrics=[metric_500k],
                              extras={})
示例#19
0
    def test_reformer_noencdecattn_wmt_ende(self):
        """Smoke-test one training step of ReformerNoEncDecAttention."""
        trax.fastmath.disable_jit()

        n_layers = 2
        gin.parse_config_file('reformer_noencdecattn_wmt_ende.gin')

        overrides = {
            'data_streams.data_dir': _TESTDATA,
            # Ignored, but needs to be set.
            'batcher.batch_size_per_device': 1,
            'batcher.buckets': ([512], [1, 1]),  # batch size 1.
            'train.steps': 1,
            'ReformerNoEncDecAttention.n_encoder_layers': n_layers,
            'ReformerNoEncDecAttention.n_decoder_layers': n_layers,
            'ReformerNoEncDecAttention.d_ff': 32,
        }
        for param_name, param_value in overrides.items():
            gin.bind_parameter(param_name, param_value)

        trainer_lib.train(output_dir=self.create_tempdir().full_path)
示例#20
0
    def test_terraformer_wmt_ende(self):
        """Smoke-test one training step of Terraformer on WMT en-de."""
        batch_size_per_device = 2
        n_layers = 2

        gin.parse_config_file('terraformer_wmt_ende.gin')

        overrides = [
            ('data_streams.data_dir', _TESTDATA),
            ('batcher.batch_size_per_device', batch_size_per_device),
            ('batcher.buckets',
             ([512], [batch_size_per_device, batch_size_per_device])),
            ('train.steps', 1),
            ('ConfigurableTerraformer.n_encoder_layers', n_layers),
            ('ConfigurableTerraformer.n_decoder_layers', n_layers),
            ('ConfigurableTerraformer.d_ff', 32),
        ]
        for param_name, param_value in overrides:
            gin.bind_parameter(param_name, param_value)

        trainer_lib.train(output_dir=self.create_tempdir().full_path)
示例#21
0
    def export(self,
               export_dir=None,
               checkpoint_step=-1,
               beam_size=1,
               temperature=1.0,
               sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH):
        """Exports a TensorFlow SavedModel.

        Args:
          export_dir: str, directory in which to export SavedModels. Defaults
            to `model_dir` when unspecified.
          checkpoint_step: int, checkpoint to export. -1 (the default) selects
            the latest checkpoint in the model directory.
          beam_size: int >= 1, number of beams to use for beam search.
          temperature: float in [0, 1] (must be 0 if beam_size > 1). 0.0 means
            argmax; 1.0 means sampling from the predicted distribution.
          sentencepiece_model_path: str, path to the SentencePiece model file
            to use for decoding. Must match the one used during training.
        """
        if checkpoint_step == -1:
            checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)
            # Force float32 variable slices and activations for export.
            gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                               "float32")
            gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                               "float32")

        vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
        model_ckpt = "model.ckpt-" + str(checkpoint_step)
        export_dir = export_dir or self._model_dir
        # Pass the checkpoint by keyword. Some mesh_tensorflow versions of
        # export_model take `batch_size` as their fifth parameter, so a
        # positional path would be consumed as the batch size.
        utils.export_model(self.estimator(vocabulary), export_dir, vocabulary,
                           self._sequence_length,
                           checkpoint_path=os.path.join(self._model_dir,
                                                        model_ckpt))
  def finetune(self, mixture_or_task_name, finetune_steps, pretrained_model_dir,
               pretrained_checkpoint_step=-1):
    """Finetunes a model from an existing checkpoint.

    Args:
      mixture_or_task_name: str, the name of the Mixture or Task to finetune
        on. Must be pre-registered in the global `TaskRegistry` or
        `MixtureRegistry`.
      finetune_steps: int, the number of additional steps to train for.
      pretrained_model_dir: str, directory with pretrained model checkpoints
        and operative config.
      pretrained_checkpoint_step: int, checkpoint to initialize weights from.
        -1 uses the latest checkpoint in the pretrained model directory.
    """
    checkpoint_step = (
        _get_latest_checkpoint_from_dir(pretrained_model_dir)
        if pretrained_checkpoint_step == -1 else pretrained_checkpoint_step)

    # Adopt the pretrained model's operative gin config before training.
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(pretrained_model_dir))

    # Train for `finetune_steps` beyond the pretrained checkpoint's step,
    # starting from that checkpoint's weights.
    init_checkpoint = os.path.join(pretrained_model_dir,
                                   "model.ckpt-" + str(checkpoint_step))
    self.train(mixture_or_task_name, checkpoint_step + finetune_steps,
               init_checkpoint=init_checkpoint)
示例#23
0
    def test_reformer_noencdecattn_wmt_ende(self):
        """Smoke-test one training step of ReformerNoEncDecAttention."""
        trax.math.disable_jit()

        config_path = os.path.join(_CONFIG_DIR,
                                   'reformer_noencdecattn_wmt_ende.gin')
        gin.parse_config_file(config_path)

        n_layers = 2
        overrides = [
            ('data_streams.data_dir', _TESTDATA),
            # Ignored, but needs to be set.
            ('batcher.batch_size_per_device', 1),
            ('batcher.buckets', ([513], [1, 1])),  # batch size 1.
            ('train.steps', 1),
            ('ReformerNoEncDecAttention.n_encoder_layers', n_layers),
            ('ReformerNoEncDecAttention.n_decoder_layers', n_layers),
            ('ReformerNoEncDecAttention.d_ff', 32),
        ]
        for param_name, param_value in overrides:
            gin.bind_parameter(param_name, param_value)

        with self.tmp_dir() as output_dir:
            trainer_lib.train(output_dir=output_dir)
示例#24
0
def main(argv):
  """Parses the gin configuration, creates output directories and trains.

  Args:
    argv: command-line arguments; only the program name itself is accepted.

  Raises:
    app.UsageError: if extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments for train.')

  if FLAGS.gin_config:
    # Server run: config files and extra bindings both come from flags.
    gin.parse_config_files_and_bindings(FLAGS.gin_config, FLAGS.gin_bindings)
  else:
    # Local run: a single config file; FLAGS.gin_bindings is not applied here.
    gin.parse_config_file(FLAGS.gin_config_file)

  # Directories used during training to save/restore checkpoints (including
  # the MCTS data sub-directory) and to write TF summaries.
  checkpoint_path = gin.query_parameter('checkpoint_dir/macro.value')
  summary_path = gin.query_parameter('summary_dir/macro.value')
  for required_dir in (checkpoint_path, summary_path,
                       os.path.join(checkpoint_path, 'mcts_data')):
    directory_handling.ensure_dir_exists(required_dir)

  ppo_train()
示例#25
0
文件: inference.py 项目: zeeps31/ddsp
 def parse_gin_config(self, ckpt):
     """Re-parse the model's operative gin config with new length settings.

     Loads the latest operative config stored next to `ckpt` (unknown
     configurables are skipped), then overrides the sample / time-step
     lengths from `self.n_samples` and `self.time_steps` and replaces
     `ProcessorGroup.dag` with a graph that contains no reverb processor.

     Args:
       ckpt: checkpoint path; its parent directory is searched for the
         operative config.
     """
     with gin.unlock_config():
         model_dir = os.path.dirname(ckpt)
         config_path = train_util.get_latest_operative_config(model_dir)
         print(f'Parsing from operative_config {config_path}')
         gin.parse_config_file(config_path, skip_unknown=True)

         # Processor DAG without the reverb stage: harmonic + filtered noise,
         # summed by an Add processor.
         dag_binding = """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]"""
         # Length overrides, followed by the DAG replacement.
         overrides = [
             'Harmonic.n_samples=%d' % self.n_samples,
             'FilteredNoise.n_samples=%d' % self.n_samples,
             'F0LoudnessPreprocessor.time_steps=%d' % self.time_steps,
             'oscillator_bank.use_angular_cumsum=True',
         ]
         overrides.append(dag_binding)
         gin.parse_config(overrides)
示例#26
0
def parse_gin(restore_dir):
  """Parse gin configs from defaults, the model directory and flags.

  Sources are parsed in this order, so later ones override earlier ones:
    1. optimization defaults ('base_tpu.gin' when --tpu is set, otherwise
       'base.gin') from 'ddsp/training/gin/optimization';
    2. the latest operative config in `restore_dir`, if it exists;
    3. the files in --gin_file and the bindings in --gin_param.

  Args:
    restore_dir: model directory searched for a saved operative config.
  """
  # Make GIN_PATH and the user-supplied folders searchable by gin.
  search_paths = [GIN_PATH] + FLAGS.gin_search_path
  for folder in search_paths:
    gin.config.add_config_file_search_path(folder)

  with gin.unlock_config():
    # NOTE: only the 'optimization' defaults are read from the library gin
    # folder (a path relative to the working directory); 'models' and
    # 'dataset' configs are expected to come from the user's gin files.
    defaults_file = 'base_tpu.gin' if FLAGS.tpu else 'base.gin'
    gin.parse_config_file(
        os.path.join('ddsp/training/gin/optimization', defaults_file))

    # An operative config only exists once the model has been trained.
    saved_config = train_util.get_latest_operative_config(restore_dir)
    if tf.io.gfile.exists(saved_config):
      logging.info('Using operative config: %s', saved_config)
      gin.parse_config_file(saved_config, skip_unknown=True)

    # User config files and hyperparameter bindings from flags come last.
    gin.config.parse_config_files_and_bindings(
        FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
示例#27
0
def train(args):
    """Run a stackrl training session configured from parsed CLI arguments.

    Args:
        args: parsed arguments. Reads `directory`, `config_file`, `env_id`
            and `env_args` (an iterable of key/value pairs, or a falsy value).

    Raises:
        FileNotFoundError: if `config_file` is given but exists neither
            inside `directory` nor as a path on its own.
    """
    if args.config_file:
        # A config file inside the run directory takes precedence.
        candidates = (os.path.join(args.directory, args.config_file),
                      args.config_file)
        for config_path in candidates:
            if os.path.isfile(config_path):
                gin.parse_config_file(config_path)
                break
        else:
            raise FileNotFoundError('Failed to find config file {}'.format(
                args.config_file))

    training_kwargs = {'directory': args.directory}
    # Environment overrides from the command line; when none are given, the
    # configured defaults are used for the training env.
    env_overrides = (
        {key: value for key, value in args.env_args} if args.env_args else {})
    if args.env_id:
        env_overrides['env'] = args.env_id
    if env_overrides:
        training_kwargs['env'] = (
            lambda **kwa: stackrl.envs.make(**env_overrides, **kwa))

    stackrl.Training(**training_kwargs).run()
示例#28
0
    def __init__(self, data_path, mode, train_set, validation_set, test_set, max_way_train, max_way_test, max_support_train, max_support_test):
        """Parse the meta-dataset gin config and build the sources `mode` needs.

        For 'train' / 'train_test', builds a multi-source training source from
        `train_set` and one single-source validation source per name in
        `validation_set`. For 'test' / 'train_test' / 'attack', builds one
        single-source test source per name in `test_set`.

        Args:
            data_path: stored on the instance as `self.data_path`.
            mode: one of the mode names above; any other value builds nothing.
            train_set, validation_set, test_set: source names for each split.
            max_way_train, max_support_train: passed to the training episode
                description.
            max_way_test, max_support_test: passed to the validation/test
                episode description.
        """
        self.data_path = data_path
        self.train_dataset_next_task = None
        self.validation_set_dict = {}
        self.test_set_dict = {}
        gin.parse_config_file('./meta_dataset_config.gin')

        if mode in ('train', 'train_test'):
            train_description = self._get_train_episode_description(max_way_train, max_support_train)
            self.train_dataset_next_task = self._init_multi_source_dataset(
                train_set, learning_spec.Split.TRAIN, train_description)

            eval_description = self._get_test_episode_description(max_way_test, max_support_test)
            for source_name in validation_set:
                source_task = self._init_single_source_dataset(
                    source_name, learning_spec.Split.VALID, eval_description)
                # Kept as before: this attribute ends up holding the last
                # validation source that was built.
                self.validation_dataset = source_task
                self.validation_set_dict[source_name] = source_task

        if mode in ('test', 'train_test', 'attack'):
            eval_description = self._get_test_episode_description(max_way_test, max_support_test)
            for source_name in test_set:
                self.test_set_dict[source_name] = self._init_single_source_dataset(
                    source_name, learning_spec.Split.TEST, eval_description)
    if eval_flag:

        latest = tf.train.latest_checkpoint(checkpoint_dir)
        model.load_weights(latest)
        eval_model(eval_dir, model, system, model_name)


if __name__ == '__main__':

    # Positional command-line arguments, in order.
    _arg_names = ('root_dir', 'model_name', 'system_name', 'observations',
                  'num_train_traj', 'num_train_steps', 'seed')
    if len(sys.argv) < len(_arg_names) + 1:
        # Fail with a usage message instead of a bare IndexError.
        sys.exit('Usage: {} {}'.format(
            sys.argv[0], ' '.join('<{}>'.format(n) for n in _arg_names)))

    root_dir = sys.argv[1]
    model_name = sys.argv[2]
    system_name = sys.argv[3]
    observations = sys.argv[4]
    num_train_traj = int(sys.argv[5])
    num_train_steps = int(sys.argv[6])
    seed = int(sys.argv[7])

    # exist_ok avoids the race between an existence check and creation;
    # a non-directory already at root_dir still raises FileExistsError.
    os.makedirs(root_dir, exist_ok=True)

    # NOTE: configs may overwrite bindings from previously parsed files.
    gin.parse_config_file('./config/base.gin')
    gin.parse_config_file('./config/{}.gin'.format(observations))
    if model_name == 'VIN_SO2' and observations == 'pixels':
        # Overwrite step-size for SO2 VIN.
        gin.parse_config_file('./config/pixels_SO2.gin')
    gin.parse_config_file('./config/{}.gin'.format(system_name))

    main(root_dir, model_name, system_name, observations, num_train_traj,
         num_train_steps, seed)
示例#30
0
                if self.results.get(ds_name, 0) < f1:
                    self.results[ds_name] = f1

    def finish(self):
        """Write every F1 stored in `self.results` as a step-0 scalar summary.

        Each value is logged under the tag 'test/<dataset name>/scene/best_f1'
        using `self.summary_writer`.
        """
        with self.summary_writer.as_default():
            for dataset_name in self.results:
                tag = "/".join(("test", dataset_name, "scene", "best_f1"))
                tf.summary.scalar(tag, self.results[dataset_name], step=0)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Train TransNet")
    parser.add_argument("config", help="path to config")
    args = parser.parse_args()

    gin.parse_config_file(args.config)
    options = get_options_dict()

    trn_ds = input_processing.train_pipeline(options["trn_files"]) if len(options["trn_files"]) > 0 else None
    if options["transition_only_trn_files"] is not None:
        trn_ds_ = input_processing.train_transition_pipeline(options["transition_only_trn_files"])
        if trn_ds is not None:
            frac = options["transition_only_data_fraction"]
            trn_ds = tf.data.experimental.sample_from_datasets([trn_ds, trn_ds_], weights=[1 - frac, frac])
        else:
            trn_ds = trn_ds_

    tst_ds = [(name, input_processing.test_pipeline(files))
              for name, files in options["tst_files"].items()]

    if options["original_transnet"]: