def main(_):
    """Seed the RNGs, load the Gin configuration, and dispatch on --run_mode.

    The positional argv argument is ignored. --gin_config may be either a path
    to a Gin file or literal Gin text.
    """
    seed = FLAGS.task
    np.random.seed(seed)
    tf.set_random_seed(seed)

    # Only distributed runs use the real task index; local runs act as task 0.
    task = FLAGS.task if FLAGS.distributed else 0

    config_spec = FLAGS.gin_config
    if config_spec:
        if not tf.gfile.Exists(config_spec):
            # Not an existing file: treat the flag value as inline Gin text.
            gin.parse_config(config_spec)
        else:
            with tf.gfile.Open(config_spec) as config_file:
                gin.parse_config(config_file)

    gin.finalize()

    mode = FLAGS.run_mode
    if mode == 'collect_eval_loop':
        raise NotImplementedError('collect_eval_loops')

    if mode == 'collect_eval_once':
        # This mode passes the raw --task flag rather than the local `task`.
        train_collect_eval.train_collect_eval(
            root_dir=FLAGS.root_dir, train_fn=None, task=FLAGS.task)
    elif mode == 'train_only':
        train_collect_eval.train_collect_eval(
            root_dir=FLAGS.root_dir,
            do_collect_eval=False,
            task=task,
            master=FLAGS.master,
            ps_tasks=FLAGS.ps_tasks)
    else:
        # Default: synchronous train-collect-eval.
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir, task=task)
# Example No. 2 (score: 0)
def load_operative_gin_configurations(operative_config_dir):
    """Parse the operative Gin config found under `operative_config_dir`.

    The file path is resolved with `operative_config_path`. It is parsed inside
    `gin.unlock_config()`, after which `gin.finalize()` is called and the
    loaded path is logged.
    """
    config_file = operative_config_path(operative_config_dir)
    with gin.unlock_config():
        gin.parse_config_file(config_file)
    gin.finalize()
    logging.info('Operative Gin configurations loaded from %s.', config_file)
# Example No. 3 (score: 0)
    def _setup(self, config):
        """Initialise TF, load Gin bindings from `config`, and reset model state.

        Args:
            config: dict-like trial config. Required keys: 'config_files'
                (a string or a list), 'bindings', 'mutable_bindings'. Optional
                keys: 'allow_growth', 'verbosity', 'config_dir',
                'initial_weights_path'.
        """
        util.tf_init(gpus=None,
                     allow_growth=config.get('allow_growth', True),
                     eager=False)
        logging.set_verbosity(config.get('verbosity', logging.INFO))
        logging.info("calling setup")

        files = config['config_files']
        # A single file name is accepted as shorthand for a one-element list.
        if isinstance(files, six.string_types):
            files = [files]

        with gin.unlock_config():
            search_dir = config.get('config_dir')
            if search_dir is None:
                search_dir = util.get_config_dir()
            util.parse_config(search_dir, files, config['bindings'],
                              finalize_config=False)
            _parse_config_item(None, config['mutable_bindings'])
            gin.finalize()

        # Clear cached components so they are rebuilt from the new config.
        self._generators = self._problem = self._optimizer = self._model = None
        self._reset_callbacks()

        weights_path = config.get('initial_weights_path', None)
        if weights_path is not None:
            self.model.load_weights(weights_path)

        self._save_operative_config()
# Example No. 4 (score: 0)
def main(config, eager):
    """Train an agent configured by a Gin file.

    Args:
        config: path of the Gin config file to parse.
        eager: value forwarded to `tf.config.run_functions_eagerly`.
    """
    ray.init()
    # Bind every configurable, then lock the configuration before training.
    gin.parse_config_file(config)
    gin.finalize()
    tf.config.run_functions_eagerly(eager)
    train()
# Example No. 5 (score: 0)
def main(agent_dir, num_episodes, max_episode_steps, save_videos):
    """Replay a trained agent using the config.gin saved in its directory.

    Args:
        agent_dir: directory holding the agent and its config.gin.
        num_episodes: forwarded to `play`.
        max_episode_steps: forwarded to `play`.
        save_videos: forwarded to `play`.
    """
    # The imported name is unused; the import is presumably kept for its side
    # effects (e.g. registering Gin configurables) -- confirm.
    # noinspection PyUnresolvedReferences
    from interact import train

    config_path = os.path.join(agent_dir, "config.gin")
    gin.parse_config_file(config_path)
    gin.finalize()
    play(agent_dir, num_episodes, max_episode_steps, save_videos)
# Example No. 6 (score: 0)
def parse_gin_in_request(request: Dict) -> None:
    """Apply the Gin config and bindings carried by a request, then lock Gin.

    Args:
        request (Dict): Request dictionary (see zpy.requests). The optional
            keys 'gin_config' and 'gin_bindings' are read; a missing key is
            passed on as None.
    """
    config_value = request.get('gin_config', None)
    zpy.gin.parse_gin_config(gin_config=config_value)
    binding_value = request.get('gin_bindings', None)
    zpy.gin.parse_gin_bindings(gin_bindings=binding_value)
    gin.finalize()
# Example No. 7 (score: 0)
def main(_):
    """Configure TF, logging and Gin, start a Comet experiment, then train.

    The Comet API key is read from the COMET_API_KEY environment variable
    rather than being hard-coded in source control.
    """
    import os  # Local import: the file-level imports are not visible here.

    tf.compat.v1.enable_v2_behavior()
    logging.set_verbosity(logging.INFO)
    gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
    gin.finalize()
    # SECURITY: never commit API keys. NOTE(review): the key that used to be
    # hard-coded here should be revoked. If COMET_API_KEY is unset, None is
    # passed and comet_ml falls back to its own configuration lookup
    # (confirm for the installed comet_ml version).
    experiment = Experiment(api_key=os.environ.get("COMET_API_KEY"),
                            project_name="safemrl",
                            workspace="krishpop")
    # `experiment` stays bound so the object remains referenced for the whole
    # training run.
    train_eval(FLAGS.root_dir, run_eval=FLAGS.run_eval)
# Example No. 8 (score: 0)
def parse_config(text_name: str = 'config') -> None:
    """Parse a Blender text datablock as Gin config and finalize it.

    If `bpy.data.texts` has no entry named `text_name`, a warning is logged
    and nothing is parsed.
    """
    config_text = bpy.data.texts.get(text_name, None)
    if config_text is not None:
        log.info(f'Loading gin config {text_name}')
        # Presumably enabled so the config can be re-parsed within the same
        # Blender session -- confirm.
        gin.enter_interactive_mode()
        with gin.unlock_config():
            gin.parse_config(config_text.as_string())
            gin.finalize()
    else:
        log.warning(f'Could not find {text_name} in texts.')
def main(argv):
    """Load the model directory's config.gin and run `all_distractors_eval`.

    Raises:
        ValueError: if config.gin does not exist in FLAGS.base_dir.
    """
    del argv  # Unused.

    # Load gin.config settings stored in the model directory. This script
    # does NOT wait for the file: it fails immediately if config.gin is
    # missing, so only start it once the train script has written its config.
    gin_config_path = os.path.join(FLAGS.base_dir, 'config.gin')
    if not gfile.exists(gin_config_path):
        raise ValueError('Could not find config.gin in "%s"' % FLAGS.base_dir)

    # skip_unknown=True ignores bindings for configurables not imported here.
    gin.parse_config_file(gin_config_path, skip_unknown=True)
    gin.finalize()
    all_distractors_eval(FLAGS.base_dir)
# Example No. 10 (score: 0)
def parse_config(text_name: str = 'config') -> None:
    """Load the Gin configuration stored in a Blender text block.

    Args:
        text_name (str, optional): Name of the text block in `bpy.data.texts`
            holding the Gin config. Defaults to 'config'. If no such block
            exists, a warning is logged and the function returns.
    """
    source_block = bpy.data.texts.get(text_name, None)
    if source_block is None:
        log.warning(f'Could not find {text_name} in texts.')
        return

    log.info(f'Loading gin config {text_name}')
    gin.enter_interactive_mode()
    gin_source = source_block.as_string()
    with gin.unlock_config():
        gin.parse_config(gin_source)
        gin.finalize()
# Example No. 11 (score: 0)
 def test_finalize(self):
     """Lock state is tracked per GinState and survives leaving/re-entering.

     Finalizing the global config locks it. A fresh GinState starts unlocked
     with independent bindings. Finalizing inside that state locks it, and the
     lock is still set when the same state is entered again.
     """
     gin.bind_parameter('f.x', 'global')
     gin.finalize()
     self.assertTrue(gin.config_is_locked())
     with GinState() as temp_state:
         # A fresh GinState is unlocked and accepts its own bindings.
         gin.bind_parameter('f.x', 'temp')
         self.assertEqual(gin.query_parameter('f.x'), 'temp')
         self.assertFalse(gin.config_is_locked())
     # Leaving temp_state brings back the (locked) global state.
     self.assertTrue(gin.config_is_locked())
     with temp_state:
         # Re-entered temp_state is still unlocked: it was never finalized.
         self.assertFalse(gin.config_is_locked())
         gin.config.finalize()
         self.assertTrue(gin.config_is_locked())
     with temp_state:
         # The lock set inside temp_state persisted across exit and re-entry.
         self.assertTrue(gin.config_is_locked())
# Example No. 12 (score: 0)
def main(argv):
    """Wait for the train script's config.gin, load it, then call `run_eval`.

    Raises:
        ValueError: if config.gin has not appeared in FLAGS.base_dir after
            waiting ten minutes.
    """
    del argv  # Unused.

    # Load gin.config settings stored in model directory. It is possible to run
    # this script concurrently with the train script. In this case, wait for the
    # train script to start up and actually write out a gin config file.
    # Poll once a minute, for up to 10 minutes, before giving up.
    max_wait_minutes = 10
    gin_config_path = os.path.join(FLAGS.base_dir, 'config.gin')
    minutes_waited = 0
    while not gfile.exists(gin_config_path):
        # Check the budget before sleeping so the full 10 minutes elapse
        # (followed by one final existence check) before raising.
        if minutes_waited >= max_wait_minutes:
            raise ValueError('Could not find config.gin in "%s"' %
                             FLAGS.base_dir)
        time.sleep(60)
        minutes_waited += 1

    # skip_unknown=True ignores bindings for configurables not imported here.
    gin.parse_config_file(gin_config_path, skip_unknown=True)
    gin.finalize()

    run_eval()
def gin_sacred(config_files, main_fcn, db_name='causal_sparse', base_dir=None):
    """Launch a sacred experiment configured by .gin files via `tune_gin`.

    Args:
        config_files: Gin config files handed to `load_config_files`.
        main_fcn: function bound into `inner_fcn` and run by the tuner.
        db_name: database name forwarded in the config update.
        base_dir: root output directory; the current working directory when
            None. The run directory is <base_dir>/<joined config names>/<uid>.

    Returns:
        Whatever `tune_gin` returns.
    """
    config_names = load_config_files(config_files)
    gin.finalize()

    experiment_name = '_'.join(config_names)
    root_dir = os.getcwd() if base_dir is None else base_dir

    # Unique run id: UTC timestamp followed by a uuid1.
    timestamp = datetime.datetime.utcnow().strftime("%Y_%m_%d_%H_%M_%S")
    run_uid = timestamp + "__" + str(uuid1())

    run_dir = os.path.join(root_dir, experiment_name, run_uid)
    os.makedirs(run_dir, exist_ok=True)

    wrapped_fcn = partial(inner_fcn, main_fcn=main_fcn)
    # functools.partial objects have no __name__ of their own; copy the
    # wrapped function's name (presumably needed downstream for naming).
    wrapped_fcn.__name__ = main_fcn.__name__

    config_update = {
        'name': experiment_name,
        'base_dir': run_dir,
        'db_name': db_name,
        'sources': config_files,
    }
    return tune_gin(wrapped_fcn, config_update=config_update,
                    name=experiment_name)
# Example No. 14 (score: 0)
def main(argv):
    """Pick the training library for --task, configure Gin, and train.

    Raises:
        app.UsageError: if extra positional command-line arguments are given.
        ValueError: if --task is not maze, edge_supervision or var_misuse.
    """
    if len(argv) > 1:
        raise app.UsageError("Too many command-line arguments.")

    # Each task's library is imported only inside its branch, so only the
    # selected task's module is loaded.
    # pylint:disable=g-import-not-at-top
    task_name = FLAGS.task
    if task_name == "maze":
        from gfsa.training import train_maze_lib as train_lib
    elif task_name == "edge_supervision":
        from gfsa.training import train_edge_supervision_lib as train_lib
    elif task_name == "var_misuse":
        from gfsa.training import train_var_misuse_lib as train_lib
    else:
        raise ValueError(f"Unrecognized task {task_name}")
    # pylint:enable=g-import-not-at-top
    train_fn = train_lib.train

    print("Setting up Gin configuration")

    for search_dir in FLAGS.gin_include_dirs:
        gin.add_config_file_search_path(search_dir)

    # Directories from flags are bound before the Gin files are parsed.
    flag_bindings = (
        ("simple_runner.training_loop.artifacts_dir", FLAGS.train_artifacts_dir),
        ("simple_runner.training_loop.log_dir", FLAGS.train_log_dir),
    )
    for gin_name, gin_value in flag_bindings:
        gin.bind_parameter(gin_name, gin_value)

    gin.parse_config_files_and_bindings(
        FLAGS.gin_files,
        FLAGS.gin_bindings,
        finalize_config=False,
        skip_unknown=False,
    )
    gin.finalize()

    train_fn(runner=simple_runner)
# Example No. 15 (score: 0)
def main(_):
    """Entry point for the Mesh TensorFlow T5 runner.

    Imports any user modules, records the command line in the model
    directory, and parses Gin defaults and flags. It then either drives an
    `mtf_model.MtfModel` according to --mode (when --use_model_api is set)
    or calls `utils.run`.

    Raises:
        ValueError: if --checkpoint_steps is set without
            --checkpoint_mode=specific; if finetune/export are not given
            exactly one checkpoint; if --mode is unset with the Model API or
            set without it.
    """
    if FLAGS.module_import:
        # Import user-specified modules (presumably so they can register
        # tasks/mixtures or Gin configurables before parsing -- confirm).
        for module in FLAGS.module_import:
            importlib.import_module(module)

    if FLAGS.t5_tfds_data_dir:
        t5.data.set_tfds_data_dir_override(FLAGS.t5_tfds_data_dir)

    # Add search path for gin files stored in package.
    gin.add_config_file_search_path(
        pkg_resources.resource_filename(__name__, "gin"))
    try:
        # Record the invocation in <model_dir>/commands/command[.N], picking
        # the first unused suffix so earlier commands are not overwritten.
        suffix = 0
        command_dir = os.path.join(FLAGS.model_dir, "commands")
        tf.io.gfile.makedirs(command_dir)
        command_filename = os.path.join(command_dir, "command")
        while tf.io.gfile.exists(command_filename):
            suffix += 1
            command_filename = os.path.join(command_dir,
                                            "command.{}".format(suffix))
        with tf.io.gfile.GFile(command_filename, "w") as f:
            f.write(" ".join(sys.argv))
    except (tf.errors.PermissionDeniedError, tf.errors.InvalidArgumentError):
        # Command logging is best-effort; a read-only model_dir is tolerated.
        logging.info(
            "No write access to model directory. Skipping command logging.")

    # skip_unknown is either the --skip_all_gin_unknowns value (when truthy)
    # or a tuple of known-deprecated Gin references to skip.
    utils.parse_gin_defaults_and_flags(
        skip_unknown=(FLAGS.skip_all_gin_unknowns
                      or (mesh_transformer.DEPRECATED_GIN_REFERENCES +
                          tuple(FLAGS.additional_deprecated_gin_references))),
        finalize_config=False)
    # We must override this binding explicitly since it is set to a deprecated
    # function or class in many existing configs.
    gin.bind_parameter("run.vocabulary", mesh_transformer.get_vocabulary())
    gin.finalize()

    # Set cache dir after loading gin to avoid unintentionally overriding it.
    t5.data.add_global_cache_dirs(FLAGS.additional_task_cache_dirs)

    if FLAGS.use_model_api:
        model = mtf_model.MtfModel(tpu_job_name=FLAGS.tpu_job_name,
                                   tpu=FLAGS.tpu,
                                   gcp_project=FLAGS.gcp_project,
                                   tpu_zone=FLAGS.tpu_zone,
                                   tpu_topology=FLAGS.tpu_topology,
                                   model_parallelism=FLAGS.model_parallelism,
                                   model_dir=FLAGS.model_dir,
                                   batch_size=FLAGS.batch_size,
                                   sequence_length={
                                       "inputs": FLAGS.input_sequence_length,
                                       "targets": FLAGS.target_sequence_length
                                   })

        # Explicit checkpoint steps are only meaningful in "specific" mode.
        if FLAGS.checkpoint_mode != "specific" and FLAGS.checkpoint_steps:
            raise ValueError(
                "checkpoint_mode is set to %s and checkpoint_steps is "
                "also set. To use a particular checkpoint, please set "
                "checkpoint_mode to 'specific'. For other modes, please "
                "ensure that checkpoint_steps is not set." %
                FLAGS.checkpoint_mode)

        # Translate --checkpoint_mode into the value the model API expects:
        # -1 for latest, "all", or a list of explicit step integers.
        # NOTE(review): the else branch iterates FLAGS.checkpoint_steps; if
        # that flag is None (e.g. "specific" mode with no steps given), this
        # raises TypeError -- confirm the flag's default.
        if FLAGS.checkpoint_mode == "latest":
            checkpoint_steps = -1
        elif FLAGS.checkpoint_mode == "all":
            checkpoint_steps = "all"
        else:
            checkpoint_steps = [int(c) for c in FLAGS.checkpoint_steps]

        if FLAGS.mode == "train":
            model.train(mixture_or_task_name=FLAGS.mixture_or_task,
                        steps=FLAGS.train_steps)
        elif FLAGS.mode == "eval":
            model.eval(mixture_or_task_name=FLAGS.mixture_or_task,
                       checkpoint_steps=checkpoint_steps,
                       summary_dir=FLAGS.eval_summary_dir,
                       split=FLAGS.eval_split)
        elif FLAGS.mode == "finetune":
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for finetuning a model.")

            # Unwrap the single-element list produced in "specific" mode.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.finetune(mixture_or_task_name=FLAGS.mixture_or_task,
                           steps=FLAGS.train_steps,
                           pretrained_model_dir=FLAGS.pretrained_model_dir,
                           checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode == "predict":
            model.predict(
                checkpoint_steps=checkpoint_steps,
                input_file=FLAGS.input_file,
                output_file=FLAGS.output_file,
                beam_size=FLAGS.beam_size,
                temperature=FLAGS.temperature,
                keep_top_k=FLAGS.keep_top_k,
            )
        elif FLAGS.mode == "score":
            model.score(FLAGS.input_file,
                        FLAGS.target_file,
                        scores_file=FLAGS.output_file,
                        checkpoint_steps=checkpoint_steps)
        elif FLAGS.mode in ("export_predict", "export_score"):
            if not (FLAGS.checkpoint_mode == "latest" or
                    (FLAGS.checkpoint_mode == "specific"
                     and len(FLAGS.checkpoint_steps) == 1)):
                raise ValueError(
                    "Must specify a single checkpoint for exporting a model.")

            # Unwrap the single-element list produced in "specific" mode.
            if isinstance(checkpoint_steps, list):
                checkpoint_steps = checkpoint_steps[0]

            model.export(export_dir=FLAGS.export_dir,
                         checkpoint_step=checkpoint_steps,
                         beam_size=FLAGS.beam_size,
                         temperature=FLAGS.temperature,
                         keep_top_k=FLAGS.keep_top_k,
                         eval_with_score=(FLAGS.mode == "export_score"))
        else:
            raise ValueError("--mode flag must be set when using Model API.")
    else:
        if FLAGS.mode:
            raise ValueError(
                "--mode flag should only be set when using Model API.")
        if not FLAGS.tpu:
            # Without a TPU, force float32 variable/activation dtypes. The
            # config is already finalized, so it is temporarily unlocked.
            with gin.unlock_config():
                gin.bind_parameter("utils.get_variable_dtype.slice_dtype",
                                   "float32")
                gin.bind_parameter("utils.get_variable_dtype.activation_dtype",
                                   "float32")
        utils.run(tpu_job_name=FLAGS.tpu_job_name,
                  tpu=FLAGS.tpu,
                  gcp_project=FLAGS.gcp_project,
                  tpu_zone=FLAGS.tpu_zone,
                  model_dir=FLAGS.model_dir)
# Example No. 16 (score: 0)
import gin
import runner


def main():
    """Execute one run through `runner.run()` with the current Gin bindings."""
    return_value = runner.run()
    del return_value  # The result is intentionally discarded.


if __name__ == '__main__':
    # First run: load test.gin and lock the configuration.
    gin.parse_config_file('test.gin')
    gin.finalize()
    main()
    # Second run: temporarily unlock the finalized config to parse test2.gin
    # on top of the existing bindings, then run again with the result.
    with gin.unlock_config():
        gin.parse_config_file('test2.gin')
    main()
# Example No. 17 (score: 0)
# File: gin.py  Project: kant/zpy
def parse_gin_in_request(request: Dict) -> None:
    """Parse the optional 'gin_config' and 'gin_bindings' entries of a request.

    Missing keys are forwarded as None. The Gin config is finalized afterwards.
    """
    zpy.gin.parse_gin_config(gin_config=request.get('gin_config'))
    zpy.gin.parse_gin_bindings(gin_bindings=request.get('gin_bindings'))
    # Lock the configuration once both sources have been applied.
    gin.finalize()