Example #1
def parse_gin(restore_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory."""
    # Add user folders to the gin search path.
    for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
        gin.add_config_file_search_path(gin_search_path)

    # Parse gin configs, later calls override earlier ones.
    with gin.unlock_config():
        # Optimization defaults.
        use_tpu = bool(FLAGS.tpu)
        opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))
        eval_default = 'eval/basic.gin'
        gin.parse_config_file(eval_default)

        # Load operative_config if it exists (model has already trained).
        operative_config = train_util.get_latest_operative_config(restore_dir)
        if tf.io.gfile.exists(operative_config):
            logging.info('Using operative config: %s', operative_config)
            operative_config = cloud.make_file_paths_local(
                operative_config, GIN_PATH)
            gin.parse_config_file(operative_config, skip_unknown=True)

        # User gin config and user hyperparameters from flags.
        gin_file = cloud.make_file_paths_local(FLAGS.gin_file, GIN_PATH)
        gin.parse_config_files_and_bindings(gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
Example #2
def parse_gin(restore_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory."""
    # Add user folders to the gin search path.
    for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
        gin.add_config_file_search_path(gin_search_path)

    # Parse gin configs, later calls override earlier ones.
    with gin.unlock_config():
        # Optimization defaults.
        use_tpu = bool(FLAGS.tpu)
        opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))

        # Load operative_config if it exists (model has already trained).
        # operative_config = train_util.get_latest_operative_config(restore_dir)
        # if tf.io.gfile.exists(operative_config):
        #   # Copy the config file from gstorage
        #   helper_functions.copy_config_file_from_gstorage(operative_config, LAST_OPERATIVE_CONFIG_PATH)
        #   logging.info('Using operative config: %s', operative_config)
        #   gin.parse_config_file(LAST_OPERATIVE_CONFIG_PATH, skip_unknown=True)
        # gin.parse_config_file(operative_config, skip_unknown=True)

        # User gin config and user hyperparameters from flags.
        gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
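The "later calls override earlier ones" comment is the contract that makes this layering work: defaults are parsed first, and each later parse simply rebinds the same parameters. A small sketch with illustrative names:

import gin


@gin.configurable
def make_model(num_layers=2):
    return num_layers


gin.parse_config('make_model.num_layers = 4')  # defaults layer
gin.parse_config('make_model.num_layers = 8')  # later user layer wins
assert make_model() == 8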
Example #3
  def test_autoregressive_sample_reformer2_lsh_attn_quality(self):
    gin.add_config_file_search_path(_CONFIG_DIR)
    max_len = 32  # 32 is the max length we trained the checkpoint for.
    test_lengths = [8, 16, 32]
    vocab_size = 13
    # The checkpoint is correct on ~90% sequences, set random seed to deflake.
    np.random.seed(0)
    for test_len in test_lengths:
      gin.clear_config()
      gin.parse_config_file('reformer2_copy.gin')
      gin.bind_parameter('LSHSelfAttention.predict_mem_len', 2 * max_len)
      gin.bind_parameter('LSHSelfAttention.predict_drop_len', 2 * max_len)

      pred_model = models.Reformer2(mode='predict')

      shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
      shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

      model_path = os.path.join(_TESTDATA, 'reformer2_copy_lsh_attn.pkl.gz')
      pred_model.init_from_file(model_path, weights_only=True,
                                input_signature=(shape1l, shape11))
      initial_state = pred_model.state

      for _ in range(2):  # Set low to make the test run reasonably fast.
        # Pick a length in [1, test_len] at random.
        inp_len = np.random.randint(low=1, high=test_len + 1)
        inputs = np.random.randint(low=1, high=vocab_size-1, size=(1, inp_len))
        inputs = np.pad(inputs, [(0, 0), (0, max_len - inp_len)],
                        mode='constant', constant_values=0)
        s = decoding.autoregressive_sample(
            pred_model, inputs=inputs, eos_id=-1, max_length=inp_len,
            temperature=0.0)
        np.testing.assert_equal(s[0], inputs[0, :inp_len])
        pred_model.state = initial_state
    gin.clear_config()  # Make sure to not affect other tests.
Example #4
def main(_):
  name = FLAGS.name
  if FLAGS.seed is not None and name:
    name = '-'.join([name, str(FLAGS.seed)])
  run = wandb.init(name=name, sync_tensorboard=True, entity='krshna', project='sqrl-neurips',
                   config=FLAGS, monitor_gym=FLAGS.monitor, config_exclude_keys=EXCLUDE_KEYS,
                   notes=FLAGS.notes, resume=FLAGS.resume_id)

  logging.set_verbosity(logging.INFO)
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)

  if os.environ.get('CONFIG_DIR'):
    gin.add_config_file_search_path(os.environ.get('CONFIG_DIR'))

  config = wandb.config

  # Only update root_path if not resuming a run
  if not wandb.run.resumed:
    update_root(config)

  if FLAGS.load_run:
    load_prev_run(config)

  gin_bindings = FLAGS.gin_param or []

  for gin_file in config.gin_files:
    if gin_file == 'sac_safe_online.gin':
      gin_file = 'sqrl.gin'
    gin.parse_config_file(gin_file, [])

  gin_bindings = gin_bindings_from_config(config) + gin_bindings
  gin.parse_config_files_and_bindings([], gin_bindings)

  if FLAGS.num_threads:
    tf.config.threading.set_inter_op_parallelism_threads(FLAGS.num_threads)

  trainer.train_eval(config.root_dir, load_root_dir=FLAGS.load_dir,
                     batch_size=config.batch_size,
                     seed=FLAGS.seed, train_metrics_callback=wandb.log,
                     eager_debug=FLAGS.eager_debug,
                     monitor=FLAGS.monitor,
                     debug_summaries=FLAGS.debug_summaries,
                     pretraining=(not FLAGS.finetune),
                     finetune_sc=FLAGS.finetune_sc, wandb=True)

  if config.train_finetune and not config.finetune:
    with gin.unlock_config():
      finetune_bindings = finetune_gin_bindings(config)
      gin.parse_config_files_and_bindings([], finetune_bindings)
    trainer.train_eval(config.root_dir, load_root_dir=FLAGS.load_dir,
                       batch_size=config.batch_size,
                       seed=FLAGS.seed, train_metrics_callback=wandb.log,
                       eager_debug=FLAGS.eager_debug,
                       monitor=FLAGS.monitor,
                       debug_summaries=FLAGS.debug_summaries,
                       pretraining=False,
                       finetune_sc=FLAGS.finetune_sc, wandb=True)
Example #5
def main(_):
    flags.mark_flags_as_required(['base_dir'])
    if FLAGS.custom_base_dir_from_hparams is not None:
        FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                      FLAGS.custom_base_dir_from_hparams)
    else:
        # Add Work unit to base directory path, if it exists.
        if 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
            FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))
    xm_parameters = (None
                     if 'xm_parameters' not in FLAGS else FLAGS.xm_parameters)
    if xm_parameters:
        xm_params = json.loads(xm_parameters)
        if 'env_name' in xm_params:
            FLAGS.env_name = xm_params['env_name']
    if FLAGS.env_name is None:
        base_dir = os.path.join(
            FLAGS.base_dir, '{}_{}'.format(FLAGS.num_states,
                                           FLAGS.num_actions))
    else:
        base_dir = os.path.join(FLAGS.base_dir, FLAGS.env_name)
    base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator)
    if not tf.io.gfile.exists(base_dir):
        tf.io.gfile.makedirs(base_dir)
    if FLAGS.env_name is not None:
        gin.add_config_file_search_path(_ENV_CONFIG_PATH)
        gin.parse_config_files_and_bindings(
            config_files=[f'{FLAGS.env_name}.gin'],
            bindings=FLAGS.gin_bindings,
            skip_unknown=False)
        env_id = mon_minigrid.register_environment()
        env = gym.make(env_id)
        env = RGBImgObsWrapper(env)  # Get pixel observations
        # Get tabular observation and drop the 'mission' field:
        env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
        env = coloring_wrapper.ColoringWrapper(env)
    if FLAGS.env_name is None:
        env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
        # We add the discount factor to the environment.

    env.gamma = FLAGS.gamma

    logging.set_verbosity(logging.INFO)
    gin_files = []
    gin_bindings = FLAGS.gin_bindings

    runner = TrainRunner(base_dir, env, FLAGS.epochs, FLAGS.lr,
                         FLAGS.estimator, FLAGS.alpha, FLAGS.optimizer,
                         FLAGS.use_l2_reg, FLAGS.reg_coeff,
                         FLAGS.use_penalty, FLAGS.j, FLAGS.num_rows,
                         jax.random.PRNGKey(0), FLAGS.epochs - 1,
                         FLAGS.epochs - 1)
    runner.train()
Example #6
def main(_):
    logging.set_verbosity(logging.INFO)
    if FLAGS.debug:
        logging.set_verbosity(logging.DEBUG)
    logging.debug('Executing eagerly: %s', tf.executing_eagerly())

    if os.environ.get('CONFIG_DIR'):
        gin.add_config_file_search_path(os.environ.get('CONFIG_DIR'))

    root_dir = FLAGS.root_dir or FLAGS.load_dir
    if os.environ.get('EXP_DIR') and not os.path.exists(root_dir):
        root_dir = os.path.join(os.environ.get('EXP_DIR'), root_dir)

    gin_files = FLAGS.gin_file or []
    if FLAGS.train_sc:
        op_config = osp.join(root_dir, 'train/operative_config-0.gin')
        if osp.exists(op_config):
            gin_files.append(op_config)
    logging.debug('parsing config files: %s', gin_files)

    gin_bindings = FLAGS.gin_param or []
    if FLAGS.num_steps:
        gin_bindings.append('NUM_STEPS = {}'.format(FLAGS.num_steps))
    if FLAGS.lr:
        gin_bindings.append('SC_LEARNING_RATE = {}'.format(FLAGS.lr))
    logging.debug('parsing gin bindings: %s', gin_bindings)
    gin.parse_config_files_and_bindings(gin_files,
                                        gin_bindings,
                                        skip_unknown=True)

    if FLAGS.seed:
        random.seed(FLAGS.seed)
        np.random.seed(FLAGS.seed)
        tf.compat.v1.set_random_seed(FLAGS.seed)
        logging.debug('Set seed: %d', FLAGS.seed)

    if FLAGS.train_sc:
        train_sc.train_eval(root_dir,
                            safety_critic_bias_init_val=FLAGS.sc_bias_init_val,
                            safety_critic_kernel_scale=FLAGS.sc_kernel_scale,
                            fail_weight=FLAGS.fail_weight,
                            seed=FLAGS.seed,
                            monitor=FLAGS.monitor,
                            debug_summaries=FLAGS.debug_summaries)
    else:
        trainer.train_eval(root_dir,
                           load_root_dir=FLAGS.load_dir,
                           pretraining=(not FLAGS.finetune),
                           monitor=FLAGS.monitor,
                           eager_debug=FLAGS.eager_debug,
                           seed=FLAGS.seed,
                           debug_summaries=FLAGS.debug_summaries)
Example #7
    def test_autoregressive_sample_terraformer_pure_lsh_attn_quality(self):
        gin.add_config_file_search_path(_CONFIG_DIR)
        max_len = 32  # 32 is the max length we trained the checkpoint for.
        test_lengths = [8, 16, 32]
        vocab_size = 13
        # The checkpoint is correct on ~90% sequences, set random seed to deflake.
        np.random.seed(0)
        for test_len in test_lengths:
            gin.clear_config()
            gin.parse_config_file('terraformer_purelsh_copy.gin')
            gin.bind_parameter('PureLSHSelfAttention.predict_mem_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttention.predict_drop_len',
                               2 * max_len)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.bias', False)
            gin.bind_parameter('PureLSHSelfAttentionWrapper.num_weights', 2)

            pred_model = models.ConfigurableTerraformer(mode='predict')

            shape11 = shapes.ShapeDtype((1, 1), dtype=np.int32)
            shape1l = shapes.ShapeDtype((1, max_len), dtype=np.int32)

            model_path = os.path.join(_TESTDATA,
                                      'terraformer_purelsh_copy.pkl.gz')
            pred_model.init_from_file(model_path,
                                      weights_only=True,
                                      input_signature=(shape1l, shape11))
            initial_state = pred_model.state

            for _ in range(2):  # Set low to make the test run reasonably fast.
                # Pick a length in [1, test_len] at random.
                inp_len = np.random.randint(low=1, high=test_len + 1)
                inputs = np.random.randint(low=1,
                                           high=vocab_size - 1,
                                           size=(1, max_len))
                # TODO(jaszczur): properly fix padding in terraformer predict mode,
                # and add a test here.
                s = decoding.autoregressive_sample(pred_model,
                                                   inputs=inputs,
                                                   eos_id=-1,
                                                   max_length=inp_len,
                                                   temperature=0.0)

                np.testing.assert_equal(s[0], inputs[0, :inp_len])
                pred_model.state = initial_state
        gin.clear_config()  # Make sure to not affect other tests.
Example #8
def run():
    """Run the beam pipeline to create synthetic dataset."""
    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options)
    with beam.Pipeline(options=pipeline_options) as pipeline:
        for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
            gin.add_config_file_search_path(gin_search_path)
        gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)

        np.random.seed(FLAGS.random_seed)
        _ = (pipeline
             | beam.Create(np.random.randint(2**32, size=FLAGS.num_examples))
             | beam.ParDo(GenerateExampleFn(gin.config_str()))
             | beam.Reshuffle()
             | beam.Map(_float_dict_to_tfexample)
             | beam.io.tfrecordio.WriteToTFRecord(FLAGS.output_tfrecord_path,
                                                  num_shards=FLAGS.num_shards,
                                                  coder=beam.coders.ProtoCoder(
                                                      tf.train.Example)))
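Passing gin.config_str() into the DoFn is the usual way to ship gin state to Beam workers, since the parsed registry itself does not travel with the pickled pipeline. A sketch of what a DoFn like GenerateExampleFn presumably does on the worker side (the class body below is an assumption for illustration, not the original source):

import apache_beam as beam
import gin


class GenerateExampleFn(beam.DoFn):
    """Hypothetical DoFn that restores gin state once per worker."""

    def __init__(self, gin_config_str):
        self._gin_config_str = gin_config_str

    def setup(self):
        # Workers start with an empty gin registry, so re-parse the
        # serialized config before processing any elements.
        with gin.unlock_config():
            gin.parse_config(self._gin_config_str)

    def process(self, seed):
        # ... generate one example using gin-configured functions ...
        yield seed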
Example #9
def main(_):
  logging.set_verbosity(logging.INFO)
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
  logging.debug('Executing eagerly: %s', tf.executing_eagerly())
  if os.environ.get('CONFIG_DIR'):
    gin.add_config_file_search_path(os.environ.get('CONFIG_DIR'))
  root_dir = FLAGS.root_dir
  if os.environ.get('EXP_DIR'):
    root_dir = os.path.join(os.environ.get('EXP_DIR'), root_dir)

  logging.debug('parsing config files: %s', FLAGS.gin_file)
  if FLAGS.seed:
    # bindings.append(('trainer.train_eval.seed', FLAGS.seed))
    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    tf.compat.v1.set_random_seed(FLAGS.seed)
    logging.debug('Set seed: %d', FLAGS.seed)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)

  trainer.train_eval(root_dir, eager_debug=FLAGS.eager_debug, seed=FLAGS.seed)
Example #10
def parse_gin(model_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory."""
    # Add user folders to the gin search path.
    for gin_search_path in [GIN_PATH] + FLAGS.gin_search_path:
        gin.add_config_file_search_path(gin_search_path)

    # Parse gin configs, later calls override earlier ones.
    with gin.unlock_config():
        # Optimization defaults.
        use_tpu = bool(FLAGS.tpu)
        opt_default = 'base.gin' if not use_tpu else 'base_tpu.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))

        # Load operative_config if it exists (model has already trained).
        operative_config = os.path.join(model_dir, 'operative_config-0.gin')
        if tf.io.gfile.exists(operative_config):
            gin.parse_config_file(operative_config, skip_unknown=True)

        # User gin config and user hyperparameters from flags.
        gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                            FLAGS.gin_param,
                                            skip_unknown=True)
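skip_unknown=True matters when re-loading an operative config: the dumped file may bind configurables that the current binary never imports. A hedged sketch of that behavior (the binding name is a placeholder):

import gin

# Stand-in for a line from an operative config dumped by an older binary;
# the configurable it references is not registered in this process.
stale_binding = 'removed_module.removed_fn.param = 1'

# With skip_unknown=True the unknown binding is ignored instead of raising.
gin.parse_config(stale_binding, skip_unknown=True)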
Example #11
def main(_):
  logging.set_verbosity(logging.INFO)
  if FLAGS.debug:
    logging.set_verbosity(logging.DEBUG)
  if os.environ.get('CONFIG_DIR'):
    gin.add_config_file_search_path(os.environ.get('CONFIG_DIR'))
  config = wandb.config
  if not wandb.run.resumed:  # do not make changes
    root_path = []
    if os.environ.get('EXP_DIR'):
      root_path.append(os.environ.get('EXP_DIR'))
    root_path.append(config.root_dir)
    root_path.append(str(os.environ.get('WANDB_RUN_ID', 0)))
    config.update(dict(root_dir=osp.join(*root_path)), allow_val_change=True)
  else:
    config.update(dict(num_steps=FLAGS.num_steps), allow_val_change=True)
  gin_files = config.gin_files
  gin_bindings = gin_bindings_from_config(config)
  gin.parse_config_files_and_bindings(gin_files, gin_bindings)
  # tf.config.threading.set_inter_op_parallelism_threads(12)
  trainer.train_eval(config.root_dir, batch_size=config.batch_size, seed=FLAGS.seed,
                     train_metrics_callback=wandb.log, eager_debug=FLAGS.eager_debug,
                     monitor=FLAGS.monitor)
Example #12
def main(argv):
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # pylint:disable=g-import-not-at-top
    if FLAGS.task == "maze":
        from gfsa.training import train_maze_lib
        train_fn = train_maze_lib.train
    elif FLAGS.task == "edge_supervision":
        from gfsa.training import train_edge_supervision_lib
        train_fn = train_edge_supervision_lib.train
    elif FLAGS.task == "var_misuse":
        from gfsa.training import train_var_misuse_lib
        train_fn = train_var_misuse_lib.train
    else:
        raise ValueError(f"Unrecognized task {FLAGS.task}")
    # pylint:enable=g-import-not-at-top

    print("Setting up Gin configuration")

    for include_dir in FLAGS.gin_include_dirs:
        gin.add_config_file_search_path(include_dir)

    gin.bind_parameter("simple_runner.training_loop.artifacts_dir",
                       FLAGS.train_artifacts_dir)
    gin.bind_parameter("simple_runner.training_loop.log_dir",
                       FLAGS.train_log_dir)

    gin.parse_config_files_and_bindings(FLAGS.gin_files,
                                        FLAGS.gin_bindings,
                                        finalize_config=False,
                                        skip_unknown=False)

    gin.finalize()

    train_fn(runner=simple_runner)
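The finalize_config=False / gin.finalize() split lets a script layer in programmatic bindings before locking the config against further mutation. A compact sketch of the same sequence with placeholder names:

import gin


@gin.configurable
def training_loop(artifacts_dir=None, log_dir=None):
    return artifacts_dir, log_dir


# Programmatic bindings first, then user files/flags, then lock everything.
gin.bind_parameter('training_loop.artifacts_dir', '/tmp/artifacts')
gin.parse_config_files_and_bindings(
    [], ['training_loop.log_dir = "/tmp/logs"'], finalize_config=False)
gin.finalize()  # Further changes now require gin.unlock_config().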
Example #13
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    try:
        suffix = 0
        command_dir = os.path.join(FLAGS.model_dir, "commands")
        tf.io.gfile.makedirs(command_dir)
        command_filename = os.path.join(command_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(command_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags(
        skip_unknown=(FLAGS.skip_all_gin_unknowns
                      or (mesh_transformer.DEPRECATED_GIN_REFERENCES +
                          tuple(FLAGS.additional_deprecated_gin_references))),
        finalize_config=False)
    # We must override this binding explicitly since it is set to a deprecated
    # function or class in many existing configs.
    gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
    gin.finalize()

    # Set cache dir after loading gin to avoid unintentionally overriding it.
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   tpu_topology=FLAGS.tpu_topology,
                                   model_parallelism=FLAGS.model_parallelism,
                                   model_dir=FLAGS.model_dir,
                                   batch_size=FLAGS.batch_size,
                                   sequence_length={
                                       "inputs": FLAGS.input_sequence_length,
                                       "targets": FLAGS.target_sequence_length
                                   })

        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(
                checkpoint_steps=checkpoint_steps,
                input_file=FLAGS.input_file,
                output_file=FLAGS.output_file,
                beam_size=FLAGS.beam_size,
                temperature=FLAGS.temperature,
                keep_top_k=FLAGS.keep_top_k,
            )
        elif FLAGS.mode == "score":
            model.score(FLAGS.input_file,
                        FLAGS.target_file,
                        scores_file=FLAGS.output_file,
                        checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode in ("export_predict", "export_score"):
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps,
                         beam_size=FLAGS.beam_size,
                         temperature=FLAGS.temperature,
                         keep_top_k=FLAGS.keep_top_k,
                         eval_with_score=(FLAGS.mode == "export_score"))
        else:
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                                   "float32")
                gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                                   "float32")
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
Example #14
# Lint as: python3
"""Utility functions."""
import collections
import os
from os import path
from absl import flags
import dataclasses
import flax
import gin
import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image

gin.add_config_file_search_path('../')

gin.config.external_configurable(flax.nn.relu, module='flax.nn')
gin.config.external_configurable(flax.nn.sigmoid, module='flax.nn')
gin.config.external_configurable(flax.nn.softplus, module='flax.nn')


@flax.struct.dataclass
class TrainState:
    optimizer: flax.optim.Optimizer


@flax.struct.dataclass
class Stats:
    loss: float
    losses: float
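gin.external_configurable (spelled gin.config.external_configurable above) registers a function that was not written with gin in mind, so .gin files can reference it by name. A minimal sketch; the softplus function and the 'myactivations' module label are made up for illustration:

import math

import gin


def softplus(x):
    """A plain function, defined without any gin decorator."""
    return math.log(1 + math.exp(x))


# Register it after the fact so configs can reference @myactivations.softplus.
gin.external_configurable(softplus, module='myactivations')


@gin.configurable
def apply_activation(x, activation=gin.REQUIRED):
    return activation(x)


gin.parse_config('apply_activation.activation = @myactivations.softplus')
print(apply_activation(0.0))  # log(2) ~= 0.693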
Example #15
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    models_dir_name = FLAGS.model_dir_name
    if FLAGS.model_dir_counter >= 0:
        models_dir_name += "_%s" % str(FLAGS.model_dir_counter)
    models_dir = os.path.join(FLAGS.base_dir, models_dir_name)

    model_dir = os.path.join(models_dir, FLAGS.model_size)
    try:
        tf.io.gfile.makedirs(model_dir)
        suffix = 0
        command_filename = os.path.join(model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    # Load and print a few examples.
    st_task = TaskRegistry_ll.get("processed_cctk")
    sequence_length = {"inputs": 64, "targets": 64}
    sequence_length["attribute"] = 64  # Could also be 1, but then packing is not efficient.
    sequence_length["codeprefixedtargets"] = 64
    sequence_length["controlcode"] = 64

    with gin.config_scope('caet5'):
        ds = st_task.get_dataset(split="validation",
                                 sequence_length=sequence_length)

    print("A few preprocessed validation examples...")
    for ex in tfds.as_numpy(ds.take(5)):
        print(ex)
    """
    print("unitests")

    mixture_or_task_name = "processed_cctk"
    from caet5.models.mesh_transformer import mesh_train_dataset_fn_ll
    from caet5.data.utils import get_mixture_or_task_ll, MixtureRegistry_ll

    from mesh_tensorflow_caet5.dataset import pack_or_pad_ll

    mixture_or_task = get_mixture_or_task_ll("mixture_processed_cctk")

    with gin.config_scope('caet5'):
        dsbis = mixture_or_task.get_dataset(split="train", sequence_length=sequence_length)

    
    #ds2 = pack_or_pad_ll(dsbis, sequence_length, pack=False,
    #                     feature_keys=tuple(mixture_or_task.output_features), ensure_eos=True)
    

    def filter_attribute_1_fn(x):
        return tf.equal(x["attribute"][0], 1)

    def filter_attribute_2_fn(x):
        return tf.equal(x["attribute"][0], 2)

    ds_attribute_1 = dsbis.filter(filter_attribute_1_fn)
    ds_attribute_2 = dsbis.filter(filter_attribute_2_fn)

    ds2_attribute_1 = pack_or_pad_ll(
        ds_attribute_1, sequence_length, pack=False,
        feature_keys=tuple(mixture_or_task.output_features),
        ensure_eos=True)  # (not straightforward) Adapt packing so that pack=True
    ds2_attribute_2 = pack_or_pad_ll(
        ds_attribute_2, sequence_length, pack=False,
        feature_keys=tuple(mixture_or_task.output_features),
        ensure_eos=True)  # (not straightforward) Adapt packing so that pack=True

    ds3_attribute_1 = ds2_attribute_1
    ds3_attribute_2 = ds2_attribute_2

    def f1():
        return ds3_attribute_1

    def f2():
        return ds3_attribute_2

    def interleave_map_fn(x):
        return tf.cond(tf.equal(x, 0), f1, f2)

    ds3 = tf.data.Dataset.range(2).interleave(
        interleave_map_fn, cycle_length=2,
        block_length=4,
        num_parallel_calls=tf.data.experimental.AUTOTUNE)

    print("A few preprocessed validation examples...")
    for ex in tfds.as_numpy(ds3.take(80)):
        print(ex)
    """

    if FLAGS.use_model_api:
        # Modifying original T5 in CAE-T5
        transformer.make_bitransformer = make_bitransformer_ll
        utils.tpu_estimator_model_fn = tpu_estimator_model_fn_ll

        model_parallelism, train_batch_size, keep_checkpoint_max = {
            "small": (1, 256, 16),
            "base": (2, 128, 8),
            "large": (8, 64, 4),
            "3B": (8, 16, 1),
            "11B": (8, 16, 1)
        }[FLAGS.model_size]

        model = MtfModel_ll(
            tpu_job_name=FLAGS.tpu_job_name,
            tpu=FLAGS.tpu,
            gcp_project=FLAGS.gcp_project,
            tpu_zone=FLAGS.tpu_zone,
            model_dir=model_dir,
            model_parallelism=model_parallelism,
            batch_size=train_batch_size,
            learning_rate_schedule=0.003,
            save_checkpoints_steps=2000,
            keep_checkpoint_max=keep_checkpoint_max,  # if ON_CLOUD else None,
            iterations_per_loop=100,
            model_type="bitransformer",
            unsupervised_attribute_transfer_metrics=True)

        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "finetune":
            pretrained_dir = os.path.join(FLAGS.base_pretrained_model_dir,
                                          FLAGS.model_size)

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           pretrained_model_dir=pretrained_dir,
                           finetune_steps=FLAGS.train_steps)

        elif FLAGS.mode == "eval":
            model.batch_size = train_batch_size * 4
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)

            # print_random_predictions("yelp", sequence_length, model_dir, n=10)

        elif FLAGS.mode == "predict":
            if FLAGS.predict_batch_size > 0:
                model.batch_size = FLAGS.predict_batch_size
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file,
                          temperature=0)
        else:
            raise ValueError("--mode flag must be set when using Model API.")

    else:
        raise NotImplementedError()
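The gin.config_scope('caet5') blocks above select scope-qualified bindings: the same configurable can carry different values per scope, written as scope/configurable.param in the config. A small illustrative sketch:

import gin


@gin.configurable
def get_dataset(split='train'):
    return split


gin.parse_config("""
get_dataset.split = 'train'
caet5/get_dataset.split = 'validation'
""")

assert get_dataset() == 'train'
with gin.config_scope('caet5'):
    assert get_dataset() == 'validation'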
Example #16
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        suffix = 0
        command_filename = os.path.join(FLAGS.model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(FLAGS.model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)
        elif FLAGS.mode == "score":
            model.score(FLAGS.input_file,
                        FLAGS.target_file,
                        scores_file=FLAGS.output_file,
                        checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "export":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.batch_size = FLAGS.export_batch_size
            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps)
        else:
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            with gin.unlock_config():
                gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                                   "float32")
                gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                                   "float32")
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
Example #17
def main(_):
    flags.mark_flags_as_required(["task"])

    if FLAGS.module_import:
        import_modules(FLAGS.module_import)

    # Load gin parameters if they've been defined.
    try:
        for gin_file_path in FLAGS.gin_location_prefix:
            gin.add_config_file_search_path(gin_file_path)
        gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
    except AttributeError:
        # Otherwise, use default settings.
        gin.parse_config_files_and_bindings(None, None)

    total_examples = 0
    if FLAGS.task is not None:
        task_or_mixture = seqio.TaskRegistry.get(FLAGS.task)
    elif FLAGS.mixture is not None:
        task_or_mixture = seqio.MixtureRegistry.get(FLAGS.mixture)

    ds = task_or_mixture.get_dataset(sequence_length=sequence_length(),
                                     split=FLAGS.split,
                                     use_cached=False,
                                     shuffle=FLAGS.shuffle)

    keys = re.findall(r"{([\w+]+)}", FLAGS.format_string)

    def _example_to_string(ex):
        key_to_string = {}
        for k in keys:
            if k not in ex:
                key_to_string[k] = ""
                continue
            value = ex[k]
            if FLAGS.detokenize:
                try:
                    value = task_or_mixture.output_features[
                        k].vocabulary.decode_tf(tf.abs(value))
                except RuntimeError as err:
                    value = f"Error {err} while decoding {value}"
                if (FLAGS.apply_postprocess_fn and k == "targets"
                        and hasattr(task_or_mixture, "postprocess_fn")):
                    value = task_or_mixture.postprocess_fn(value)
            if tf.rank(value) == 0:
                value = [value]
            if tf.is_numeric_tensor(value):
                value = tf.strings.format("{}",
                                          tf.squeeze(value),
                                          summarize=-1)
            else:
                value = tf.strings.join(value, separator="\n\n")
            key_to_string[k] = pretty(value.numpy().decode("utf-8"))
        return FLAGS.format_string.format(**key_to_string)

    for ex in ds:
        print(_example_to_string(ex))
        total_examples += 1
        if total_examples == FLAGS.max_examples:
            break
    return
Example #18
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))

    tf.io.gfile.makedirs(FLAGS.model_dir)
    suffix = 0
    command_filename = os.path.join(FLAGS.model_dir, "command")
    while tf.io.gfile.exists(command_filename):
        suffix += 1
        command_filename = os.path.join(FLAGS.model_dir,
                                        "command.{}".format(suffix))
    with tf.io.gfile.GFile(command_filename, "w") as f:
        f.write(" ".join(sys.argv))

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        if FLAGS.checkpoint_mode == "latest":
            ckpts = tf.io.gfile.glob(FLAGS.model_dir + "model.*index")
            ckpts = [re.sub(".*ckpt-", "", c) for c in ckpts]
            ckpts = sorted([int(c.replace(".index", "")) for c in ckpts])
            checkpoint_step = ckpts[-1]
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_step = "all"
        else:
            checkpoint_step = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_step=checkpoint_step,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        else:
            model.predict(checkpoint_step=checkpoint_step,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)

    else:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
Example #19

# Lint as: python3
"""Tests for ddsp.training.models.autoencoder."""

from absl.testing import parameterized
from ddsp.core import tf_float32
from ddsp.training import models
import gin
import numpy as np
import pkg_resources
import tensorflow as tf

GIN_PATH = pkg_resources.resource_filename(__name__, '../gin')
gin.add_config_file_search_path(GIN_PATH)


class AutoencoderTest(parameterized.TestCase, tf.test.TestCase):
    def setUp(self):
        """Create some dummy input data for the chain."""
        super().setUp()
        # Create inputs.
        self.n_batch = 4
        self.n_frames = 1001
        self.n_samples = 64000
        inputs = {
            'loudness_db': np.zeros([self.n_batch, self.n_frames]),
            'f0_hz': np.zeros([self.n_batch, self.n_frames]),
            'audio': np.random.randn(self.n_batch, self.n_samples),
        }
Example #20
def main(_):
  flags.mark_flags_as_required(['base_dir'])
  if FLAGS.custom_base_dir_from_hparams is not None:
    FLAGS.base_dir = os.path.join(FLAGS.base_dir,
                                  FLAGS.custom_base_dir_from_hparams)
  else:
    # Add Work unit to base directory path, if it exists.
    if 'xm_wid' in FLAGS and FLAGS.xm_wid > 0:
      FLAGS.base_dir = os.path.join(FLAGS.base_dir, str(FLAGS.xm_wid))
  xm_parameters = (None
                   if 'xm_parameters' not in FLAGS else FLAGS.xm_parameters)
  if xm_parameters:
    xm_params = json.loads(xm_parameters)
    if 'env_name' in xm_params:
      FLAGS.env_name = xm_params['env_name']
  if FLAGS.env_name is None:
    base_dir = os.path.join(
        FLAGS.base_dir, '{}_{}'.format(FLAGS.num_states, FLAGS.num_actions))
  else:
    base_dir = os.path.join(FLAGS.base_dir, FLAGS.env_name)
  base_dir = os.path.join(base_dir, 'PVF', FLAGS.estimator, f'lr_{FLAGS.lr}')
  if not tf.io.gfile.exists(base_dir):
    tf.io.gfile.makedirs(base_dir)
  if FLAGS.env_name is not None:
    gin.add_config_file_search_path(_ENV_CONFIG_PATH)
    gin.parse_config_files_and_bindings(
        config_files=[f'{FLAGS.env_name}.gin'],
        bindings=FLAGS.gin_bindings,
        skip_unknown=False)
    env_id = mon_minigrid.register_environment()
    env = gym.make(env_id)
    env = RGBImgObsWrapper(env)  # Get pixel observations
    # Get tabular observation and drop the 'mission' field:
    env = mdp_wrapper.MDPWrapper(env, get_rgb=False)
    env = coloring_wrapper.ColoringWrapper(env)
  if FLAGS.env_name is None:
    env = random_mdp.RandomMDP(FLAGS.num_states, FLAGS.num_actions)
    # We add the discount factor to the environment.
  env.gamma = FLAGS.gamma
  P = utils.transition_matrix(env, rl_basics.policy_random(env))  # pylint: disable=invalid-name
  S = P.shape[0]  # pylint: disable=invalid-name
  Psi = jnp.linalg.solve(jnp.eye(S) - env.gamma * P, jnp.eye(S))  # pylint: disable=invalid-name
  # Normalize tasks so that they have maximum value 1.
  max_task_value = np.max(Psi, axis=0)
  Psi /= max_task_value  # pylint: disable=invalid-name

  left_vectors, _, _ = jnp.linalg.svd(Psi)  # pylint: disable=invalid-name
  approx_error = utils.approx_error(left_vectors, FLAGS.d, Psi)

  # Initialization of Phi.
  representation_init = jax.random.normal(  # pylint: disable=invalid-name
      jax.random.PRNGKey(0),
      (S, FLAGS.d),  # pylint: disable=invalid-name
      dtype=jnp.float64)
  representations, grads = train(representation_init,
                                 Psi, FLAGS.epochs, FLAGS.lr,
                                 jax.random.PRNGKey(0), FLAGS.estimator,
                                 FLAGS.alpha, FLAGS.optimizer, FLAGS.use_l2_reg,
                                 FLAGS.reg_coeff, FLAGS.use_penalty, FLAGS.j,
                                 FLAGS.num_rows, FLAGS.skipsize_train)

  gm_distances = calc_gm_distances(representations, left_vectors[:, :FLAGS.d],
                                   FLAGS.skipsize)
  x_len = len(gm_distances)
  frob_norms = calc_frob_norms(representations, Psi, FLAGS.skipsize)
  if FLAGS.d == 1:
    dot_products = calc_dot_products(representations, left_vectors[:, :FLAGS.d],
                                     FLAGS.skipsize)
  else:
    dot_products = np.zeros((x_len,))
  grad_norms = calc_grad_norms(grads, FLAGS.skipsize)
  phi_norms = calc_Phi_norm(representations, FLAGS.skipsize)
  phi_ranks = calc_sranks(representations, FLAGS.skipsize)

  prefix = f'alpha{FLAGS.alpha}_j{FLAGS.j}_d{FLAGS.d}_regcoeff{FLAGS.reg_coeff}'

  with tf.io.gfile.GFile(osp.join(base_dir, f'{prefix}.npy'), 'wb') as f:
    np.save(
        f, {
            'gm_distances': gm_distances,
            'dot_products': dot_products,
            'frob_norms': frob_norms,
            'approx_error': approx_error,
            'grad_norms': grad_norms,
            'representations': representations,
            'phi_norms': phi_norms,
            'phi_ranks': phi_ranks
        },
        allow_pickle=True)
Example #21
    def setUp(self):
        super().setUp()
        gin.clear_config()
        gin.add_config_file_search_path(_CONFIG_DIR)
        test_utils.ensure_flag('test_tmpdir')
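This setUp pairs naturally with a matching cleanup so gin state never leaks between tests. A minimal skeleton of the full pattern (the class name and config path are illustrative):

import gin
from absl.testing import absltest

_CONFIG_DIR = 'testdata/configs'  # illustrative path


class GinConfigTest(absltest.TestCase):

    def setUp(self):
        super().setUp()
        gin.clear_config()  # start every test from an empty registry
        gin.add_config_file_search_path(_CONFIG_DIR)

    def tearDown(self):
        gin.clear_config()  # make sure to not affect other tests
        super().tearDown()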
Example #22
def main(_):
    # https://github.com/google-research/text-to-text-transfer-transformer/blob/c0ea75dbe9e35a629ae2e3c964ef32adc0e997f3/t5/models/mesh_transformer_main.py#L149
    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                        FLAGS.gin_param,
                                        finalize_config=True)
    pl.seed_everything(1234)

    task_functions_maps = get_tasks_functions_maps(
        partial(get_default_preprocessing_functions,
                str_replace_newlines=FLAGS.str_replace_newline))

    train_keynames = get_all_keynames_from_dir(FLAGS.train_basedir)
    val_keynames = get_all_keynames_from_dir(FLAGS.val_basedir)

    train_datasets = get_datasets_dict_from_task_functions_map(
        keynames=train_keynames,
        tasks_functions_maps=task_functions_maps,
        t5_prefix=FLAGS.t5_tokenizer_prefix,
        max_source_length=FLAGS.max_source_length)

    val_datasets = get_datasets_dict_from_task_functions_map(
        keynames=val_keynames,
        tasks_functions_maps=task_functions_maps,
        t5_prefix=FLAGS.t5_tokenizer_prefix,
        max_source_length=FLAGS.max_source_length)

    # Initializing model
    model = T5OCRBaseline(t5_model_prefix=FLAGS.t5_model_prefix,
                          t5_tokenizer_prefix=FLAGS.t5_tokenizer_prefix,
                          optimizer=FLAGS.optimizer,
                          learning_rate=FLAGS.learning_rate,
                          generate_max_length=FLAGS.generate_max_length)

    # Trainer
    if FLAGS.debug:
        logger = False
        trainer_callbacks = []
    else:
        logger = pl.loggers.NeptuneLogger(
            close_after_fit=False,
            api_key=os.environ["NEPTUNE_API_TOKEN"],
            # project_name is set via gin file
            # params=None,
            tags=[FLAGS.t5_model_prefix, FLAGS.task_train, 't5_ocr_baseline'])
        checkpoint_callback = pl.callbacks.ModelCheckpoint(
            prefix=(f"experiment_id={logger.experiment.id}"
                    f"-task={FLAGS.task_train}-t5_model_prefix="
                    f"{model.t5_model_prefix.replace('-', '_')}"),
            dirpath=os.path.join(FLAGS.checkpoint_basedir, "t5_ocr_baseline"),
            filename=("{step}-{epoch}-{val_precision:.6f}-{val_recall:.6f}"
                      "-{val_f1:.6f}-{val_exact_match:.6f}"),
            monitor="val_f1",
            mode="max",
            save_top_k=1,
            verbose=True)
        # Patience comes from gin
        early_stop_callback = pl.callbacks.EarlyStopping(
            monitor='val_f1', mode='max', patience=FLAGS.patience)
        trainer_callbacks = [checkpoint_callback, early_stop_callback]

    trainer = pl.Trainer(checkpoint_callback=not FLAGS.debug,
                         log_gpu_memory=True,
                         profiler=FLAGS.debug,
                         logger=logger,
                         callbacks=trainer_callbacks,
                         progress_bar_refresh_rate=1,
                         log_every_n_steps=1)

    # Dataloaders
    train_loader_kwargs = {
        'num_workers': mp.cpu_count(),
        'shuffle': trainer.overfit_batches == 0,
        'pin_memory': True
    }

    if trainer.overfit_batches != 0:
        with gin.unlock_config():
            gin.bind_parameter(
                'get_dataloaders_dict_from_datasets_dict.batch_size', 1)

    eval_loader_kwargs = {**train_loader_kwargs, **{'shuffle': False}}

    train_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=train_datasets, dataloader_kwargs=train_loader_kwargs)
    val_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=val_datasets, dataloader_kwargs=eval_loader_kwargs)

    print(f'gin total config: {gin.config_str()}')
    print(f'gin operative config: {gin.operative_config_str()}')
    print(f"flags used: {FLAGS.flags_into_string()}")

    trainer.fit(model,
                train_dataloader=train_dataloaders[FLAGS.task_train],
                val_dataloaders=val_dataloaders[FLAGS.task_train])

    # Logging best metrics and saving best checkpoint on Neptune experiment
    if logger:
        trainer.logger.experiment.log_text(
            log_name='best_model_path',
            x=trainer.checkpoint_callback.best_model_path)
        trainer.logger.experiment.log_metric(
            'best_model_val_f1',
            trainer.checkpoint_callback.best_model_score.item())
        if FLAGS.upload_best_checkpoint:
            trainer.logger.experiment.log_artifact(
                trainer.checkpoint_callback.best_model_path)

        trainer.logger.experiment.stop()
Example #23
def main(_):
    if FLAGS.module_import:
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    try:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        suffix = 0
        command_filename = os.path.join(FLAGS.model_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(FLAGS.model_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except tf.errors.PermissionDeniedError:
        logging.info(
            "No write access to model directory. Skipping command logging.")

    utils.parse_gin_defaults_and_flags()

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   model_dir=FLAGS.model_dir)

        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":

            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(checkpoint_steps=checkpoint_steps,
                          input_file=FLAGS.input_file,
                          output_file=FLAGS.output_file)
        elif FLAGS.mode == "export":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps)
    else:
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
Example #24
File: config.py Project: jackd/grax
"""
Can be imported from `gin` to add directory and `grax/projects` to gin search path.

Example config file:

```gin
import grax.config
include "grax_config/single/fit.gin"
include "gat/configs/pubmed.gin"
```

"""
import os

import gin

base_dir = os.path.dirname(__file__)
for path in (base_dir, os.path.join(base_dir, "projects")):
    gin.add_config_file_search_path(path)
Example #25
def main(_):
    # https://github.com/google-research/text-to-text-transfer-transformer/blob/c0ea75dbe9e35a629ae2e3c964ef32adc0e997f3/t5/models/mesh_transformer_main.py#L149
    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    gin.parse_config_files_and_bindings(FLAGS.gin_file,
                                        FLAGS.gin_param,
                                        finalize_config=True)
    pl.seed_everything(1234)
    with gin.config_scope('sroie_t5_baseline'):
        task_functions_maps = get_tasks_functions_maps()

    # Datasets
    with gin.config_scope('train_sroie'):
        train_keynames = get_all_keynames_from_dir()

    with gin.config_scope('validation_sroie'):
        val_keynames = get_all_keynames_from_dir()

    train_datasets = get_datasets_dict_from_task_functions_map(
        keynames=train_keynames, tasks_functions_maps=task_functions_maps)
    val_datasets = get_datasets_dict_from_task_functions_map(
        keynames=val_keynames, tasks_functions_maps=task_functions_maps)

    with gin.config_scope('task_train'):
        task_train = operative_macro()

    # Initializing model
    model = T5OCRBaseline()

    # Trainer
    if FLAGS.debug:
        logger = False
        trainer_callbacks = []
    else:
        logger = NeptuneLogger(
            close_after_fit=False,
            api_key=os.environ["NEPTUNE_API_TOKEN"],
            # project_name is set via gin file
            # params=None,
            tags=[model.t5_model_prefix, task_train, 't5_ocr_baseline'])
        with gin.config_scope('sroie_t5_baseline'):
            checkpoint_callback = config_model_checkpoint(
                monitor=None if FLAGS.best_model_run_mode else "val_f1",
                dirpath=("/home/marcospiau/final_project_ia376j/checkpoints/"
                         f"{logger.project_name.replace('/', '_')}/"
                         "t5_ocr_baseline/"),
                prefix=(
                    f"experiment_id={logger.experiment.id}-task={task_train}-"
                    "t5_model_prefix="
                    f"{model.t5_model_prefix.replace('-', '_')}"),
                filename=("{step}-{epoch}-{val_precision:.6f}-{val_recall:.6f}"
                          "-{val_f1:.6f}-{val_exact_match:.6f}"),
                mode="max",
                save_top_k=None if FLAGS.best_model_run_mode else 1,
                verbose=True)
        early_stop_callback = config_early_stopping_callback()
        trainer_callbacks = [checkpoint_callback, early_stop_callback]

    trainer = Trainer(
        checkpoint_callback=not FLAGS.debug,
        log_gpu_memory=True,
        # profiler=FLAGS.debug,
        logger=logger,
        callbacks=trainer_callbacks,
        progress_bar_refresh_rate=1,
        log_every_n_steps=1)
    # Dataloaders
    train_loader_kwargs = {
        'num_workers': mp.cpu_count(),
        'shuffle': trainer.overfit_batches == 0,
        'pin_memory': True
    }

    if trainer.overfit_batches != 0:
        with gin.unlock_config():
            gin.bind_parameter(
                'get_dataloaders_dict_from_datasets_dict.batch_size', 1)

    eval_loader_kwargs = {**train_loader_kwargs, **{'shuffle': False}}

    train_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=train_datasets, dataloader_kwargs=train_loader_kwargs)
    val_dataloaders = get_dataloaders_dict_from_datasets_dict(
        datasets_dict=val_datasets, dataloader_kwargs=eval_loader_kwargs)

    # Logging important artifacts and params
    if logger:
        to_upload = {
            'gin_operative_config.gin': gin.operative_config_str(),
            'gin_complete_config.gin': gin.config_str(),
            'abseil_flags.txt': FLAGS.flags_into_string()
        }
        for destination, content in to_upload.items():
            buffer = StringIO(initial_value=content)
            buffer.seek(0)
            logger.log_artifact(buffer, destination=destination)
        params_to_log = dict()
        params_to_log['str_replace_newlines'] = gin.query_parameter(
            'sroie_t5_baseline/get_default_preprocessing_functions.'
            'str_replace_newlines')
        params_to_log['task_train'] = task_train
        params_to_log['patience'] = early_stop_callback.patience
        params_to_log['max_epochs'] = trainer.max_epochs
        params_to_log['min_epochs'] = trainer.min_epochs
        params_to_log[
            'accumulate_grad_batches'] = trainer.accumulate_grad_batches
        params_to_log['batch_size'] = train_dataloaders[task_train].batch_size

        for k, v in params_to_log.items():
            logger.experiment.set_property(k, v)

    trainer.fit(model,
                train_dataloader=train_dataloaders[task_train],
                val_dataloaders=val_dataloaders[task_train])

    # Logging best metrics and saving best checkpoint on Neptune experiment
    if logger:
        trainer.logger.experiment.log_text(
            log_name='best_model_path',
            x=trainer.checkpoint_callback.best_model_path)
        if not FLAGS.best_model_run_mode:
            trainer.logger.experiment.log_metric(
                'best_model_val_f1',
                trainer.checkpoint_callback.best_model_score.item())
        if FLAGS.upload_best_checkpoint:
            trainer.logger.experiment.log_artifact(
                trainer.checkpoint_callback.best_model_path)

        trainer.logger.experiment.stop()
Example #26
    def setUp(self):
        super().setUp()
        gin.clear_config()
        gin.add_config_file_search_path(_CONFIG_DIR)