def get_flow_inference_function(
    checkpoint, height,
    width):
  """Builds a flow-inference callable from a restored RAFT-based SMURF model.

  Clears the Keras session, limits RAFT recurrent refinement to 32 iterations
  through gin, restores `checkpoint`, and binds fixed inference options onto
  `infer_no_tf_function`.

  Args:
    checkpoint: Path to the checkpoint to restore.
    height: Image height to use for inference.
    width: Image width to use for inference.

  Returns:
    A `functools.partial` around the restored model's `infer_no_tf_function`.
    The input size is fixed and flow is resized to image resolution.
    Occlusion and backward-flow inference are disabled.
  """
  tf.keras.backend.clear_session()
  gin.parse_config('raft_model_parameters.max_rec_iters = 32')
  model = SMURFNet(
      checkpoint, flow_architecture='raft', feature_architecture='raft')
  model.restore()
  inference_options = dict(
      input_height=height,
      input_width=width,
      resize_flow_to_img_res=True,
      infer_occlusion=False,
      infer_bw=False,
  )
  return functools.partial(model.infer_no_tf_function, **inference_options)
Ejemplo n.º 2
0
    def test_with_mock_training(self):
        """Trains a mock T2R model with an async export hook and checks outputs.

        Verifies that training writes into the model directory, that every
        exported model directory is non-empty, and that an
        `ExportedSavedModelPredictor` can restore from the export directory.
        """
        model_dir = self.create_tempdir().full_path
        export_dir = os.path.join(model_dir, _EXPORT_DIR)

        t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            device_type='tpu',
            use_avg_model_params=True)
        input_generator = mocks.MockInputGenerator(batch_size=_BATCH_SIZE)
        export_hook = async_export_hook_builder.AsyncExportHookBuilder(
            export_dir=export_dir,
            create_export_fn=async_export_hook_builder.default_create_export_fn)

        # One iteration per TPU loop and a checkpoint on every step.
        for binding in ('tf.contrib.tpu.TPUConfig.iterations_per_loop=1',
                        'tf.estimator.RunConfig.save_checkpoints_steps=1'):
            gin.parse_config(binding)

        # Optimize the network.
        train_eval.train_eval_model(
            t2r_model=t2r_model,
            input_generator_train=input_generator,
            train_hook_builders=[export_hook],
            model_dir=model_dir,
            max_train_steps=_MAX_STEPS)

        self.assertNotEmpty(tf.io.gfile.listdir(model_dir))
        exported_models = tf.io.gfile.listdir(export_dir)
        self.assertNotEmpty(exported_models)
        for exported_model in exported_models:
            exported_path = os.path.join(export_dir, exported_model)
            self.assertNotEmpty(tf.io.gfile.listdir(exported_path))

        predictor = exported_savedmodel_predictor.ExportedSavedModelPredictor(
            export_dir=export_dir)
        self.assertTrue(predictor.restore())
Ejemplo n.º 3
0
def run(agent,
        game,
        num_steps,
        root_dir,
        restore_ckpt,
        use_legacy_checkpoint=False):
    """Main entrypoint for running and generating visualizations.

    Resets the default TF graph, binds the Atari game and a small replay
    capacity through gin, and writes images under
    `<root_dir>/agent_viz/<game>/<agent>/images`.

    Args:
      agent: str, agent type to use.
      game: str, Atari 2600 game to run.
      num_steps: int, number of steps to play game.
      root_dir: str, root directory where files will be stored.
      restore_ckpt: str, path to the checkpoint to reload.
      use_legacy_checkpoint: bool, whether to restore from a legacy (pre-Keras)
        checkpoint.
    """
    tf.compat.v1.reset_default_graph()
    gin_bindings = """
  atari_lib.create_atari_environment.game_name = '{}'
  WrappedReplayBuffer.replay_capacity = 300
  """.format(game)
    base_dir = os.path.join(root_dir, 'agent_viz', game, agent)
    gin.parse_config(gin_bindings)

    runner = create_runner(base_dir, restore_ckpt, agent, use_legacy_checkpoint)
    image_dir = os.path.join(base_dir, 'images')
    runner.visualize(image_dir, num_global_steps=num_steps)
Ejemplo n.º 4
0
    def test_train_mnist(self):
        """Trains an MNIST MLP for 1000 steps to compare to other implementations.

        Cross-entropy loss and accuracy are evaluated every 50 steps
        (`eval_N=10`); their values are visible in the test log.
        """
        gin.parse_config([
            'batch_fn.batch_size_per_device = 256',
            'batch_fn.eval_batch_size = 256',
        ])

        # Flatten -> 2 x (Dense(512) + Relu) -> Dense(10) -> LogSoftmax.
        layers = [tl.Flatten()]
        for _ in range(2):
            layers += [tl.Dense(512), tl.Relu()]
        layers += [tl.Dense(10), tl.LogSoftmax()]
        mnist_model = tl.Serial(*layers)

        train_stream = itertools.cycle(_mnist_dataset().train_stream(1))
        task = training.TrainTask(
            train_stream, tl.CrossEntropyLoss(), adafactor.Adafactor(.02))

        def every_50_steps(step):
            return step % 50 == 0

        eval_stream = itertools.cycle(_mnist_dataset().eval_stream(1))
        eval_task = training.EvalTask(
            eval_stream,
            [tl.CrossEntropyLoss(), tl.AccuracyScalar()],
            names=['CrossEntropyLoss', 'AccuracyScalar'],
            eval_at=every_50_steps,
            eval_N=10)

        loop = training.Loop(mnist_model, task, eval_task=eval_task)
        loop.run(n_steps=1000)
        self.assertEqual(loop.current_step(), 1000)
    def eval(self,
             mixture_or_task_name,
             checkpoint_steps=None,
             summary_dir=None,
             split="validation"):
        """Evaluate the model on the given Mixture or Task.

        Re-parses the model directory's operative gin config and then this
        object's gin bindings before delegating to `utils.eval_model`.

        Args:
          mixture_or_task_name: str, the name of the Mixture or Task to evaluate
            on. Must be pre-registered in the global `TaskRegistry` or
            `MixtureRegistry.`
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, evaluation will be run on the checkpoint files in `model_dir`
            whose global steps are closest to the global steps provided. If
            None, run eval continuously waiting for new checkpoints. If -1, get
            the latest checkpoint from the model directory.
          summary_dir: str, path to write TensorBoard events file summaries for
            eval. If None, use model_dir/eval_{split}.
          split: str, the mixture/task split to evaluate on.
        """
        if checkpoint_steps == -1:
            checkpoint_steps = _get_latest_checkpoint_from_dir(self._model_dir)

        mesh_transformer = t5.models.mesh_transformer
        vocabulary = mesh_transformer.get_vocabulary(mixture_or_task_name)
        dataset_fn = functools.partial(
            mesh_transformer.mesh_eval_dataset_fn,
            mixture_or_task_name=mixture_or_task_name)

        # Saved operative config first, then this object's bindings on top.
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(self._model_dir))
            gin.parse_config(self._gin_bindings)

        estimator = self.estimator(vocabulary)
        utils.eval_model(estimator, vocabulary, self._sequence_length,
                         self.batch_size, split, self._model_dir, dataset_fn,
                         summary_dir, checkpoint_steps)
Ejemplo n.º 6
0
    def _worker(self, root_dir, parameters, device_queue):
        """Runs one `train_eval` job with `parameters` on a queued device.

        Takes a device id from `device_queue` and exposes only that GPU via
        CUDA_VISIBLE_DEVICES (or no GPU when `self._conf.use_gpu` is false).
        It then binds `parameters` through gin, runs `train_eval(root_dir)`,
        and returns the device to the queue.

        Args:
          root_dir: Directory passed to `train_eval`.
          parameters: Mapping of gin binding name to value; each item becomes
            a `name=value` binding.
          device_queue: Queue of device ids shared between workers.

        Raises:
          Any exception raised by the job, after logging its traceback. The
          device is returned to the queue in either case.
        """
        device_acquired = False
        try:
            # Sleep for random seconds to avoid crowded launching.
            time.sleep(random.uniform(0, 3))

            device = device_queue.get()
            device_acquired = True
            if self._conf.use_gpu:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

            from alf.utils.common import set_per_process_memory_growth
            set_per_process_memory_growth()

            logging.set_verbosity(logging.INFO)

            logging.info("parameters %s", parameters)
            with gin.unlock_config():
                gin.parse_config(
                    ['%s=%s' % (k, v) for k, v in parameters.items()])
            train_eval(root_dir)
        except Exception:
            logging.exception("Worker failed with parameters %s", parameters)
            raise
        finally:
            # Return the device even on failure. Otherwise a worker blocked in
            # device_queue.get() could wait for it forever.
            if device_acquired:
                device_queue.put(device)
Ejemplo n.º 7
0
 def parse_gin_config(self, ckpt):
     """Parses the checkpoint's operative gin config plus streaming overrides.

     The latest operative config is looked up in the directory containing
     `ckpt`. The following overrides are then applied:

       * The F0 power preprocessor is set to a single time step.
       * The Harmonic and FilteredNoise synths are sized to one frame of
         samples.
       * The processor DAG is replaced with one that has no reverb processor.

     Args:
       ckpt: Path to a checkpoint file; its directory holds the operative
         config.
     """
     with gin.unlock_config():
         operative_config = train_util.get_latest_operative_config(
             os.path.dirname(ckpt))
         print(f'Parsing from operative_config {operative_config}')
         gin.parse_config_file(operative_config, skip_unknown=True)

         # Samples per frame implied by the trained configuration.
         frames = gin.query_parameter('F0PowerPreprocessor.time_steps')
         total_samples = gin.query_parameter('Harmonic.n_samples')
         samples_per_frame = int(total_samples / frames)

         # Processor DAG without the reverb processor.
         dag_binding = """ProcessorGroup.dag = [
   (@synths.Harmonic(),
     ['amps', 'harmonic_distribution', 'f0_hz']),
   (@synths.FilteredNoise(),
     ['noise_magnitudes']),
   (@processors.Add(),
     ['filtered_noise/signal', 'harmonic/signal']),
   ]"""
         streaming_bindings = [
             'F0PowerPreprocessor.time_steps=1',
             f'Harmonic.n_samples={samples_per_frame}',
             f'FilteredNoise.n_samples={samples_per_frame}',
             dag_binding,
         ]
         gin.parse_config(streaming_bindings)
Ejemplo n.º 8
0
def load_model(instrument_model, audio_length):
    """Locates a solo-instrument checkpoint and configures gin for `audio_length`.

    Parses the model's `operative_config-0.gin`, derives the hop size from the
    trained `Additive.n_samples / DefaultPreprocessor.time_steps`, and rebinds
    the synth sample counts and preprocessor time steps for the new length.

    Args:
      instrument_model: Instrument name. The checkpoint folder is
        `CKPT_DIR/solo_<instrument_model lowercased>_ckpt`.
      audio_length: Desired number of audio samples. It is rounded down to a
        whole number of hops.

    Returns:
      A tuple `(ckpt, time_steps, n_samples)`:
        * `ckpt`: the checkpoint prefix path.
        * `time_steps`: the number of frames.
        * `n_samples`: the number of samples after rounding.

    Raises:
      FileNotFoundError: If the model folder contains no `model.ckpt` file.
      ValueError: If `audio_length` is shorter than one hop.
    """
    # Build checkpoint path.
    # Assumes only one checkpoint in the folder, 'model.ckpt-[iter]'.
    model_dir = os.path.join(CKPT_DIR,
                             "solo_%s_ckpt" % instrument_model.lower())
    ckpt_files = [
        f for f in tf.gfile.ListDirectory(model_dir) if "model.ckpt" in f
    ]
    if not ckpt_files:
        raise FileNotFoundError(
            "No 'model.ckpt' files found in %s" % model_dir)
    # Drop the file suffix: 'model.ckpt-N.index' -> 'model.ckpt-N'.
    ckpt_name = ".".join(ckpt_files[0].split(".")[:2])
    ckpt = os.path.join(model_dir, ckpt_name)

    # Parse gin config
    with gin.unlock_config():
        gin_file = os.path.join(model_dir, "operative_config-0.gin")
        gin.parse_config_file(gin_file, skip_unknown=True)

    # Ensure dimensions sampling rates are equal
    time_steps_train = gin.query_parameter("DefaultPreprocessor.time_steps")
    n_samples_train = gin.query_parameter("Additive.n_samples")
    hop_size = int(n_samples_train / time_steps_train)

    time_steps = int(audio_length / hop_size)
    if time_steps < 1:
        raise ValueError(
            "audio_length=%s is shorter than one hop (%d samples)" %
            (audio_length, hop_size))
    n_samples = time_steps * hop_size

    gin_params = [
        "Additive.n_samples = {}".format(n_samples),
        "FilteredNoise.n_samples = {}".format(n_samples),
        "DefaultPreprocessor.time_steps = {}".format(time_steps),
    ]

    with gin.unlock_config():
        gin.parse_config(gin_params)

    return ckpt, time_steps, n_samples
Ejemplo n.º 9
0
  def configure_gin(self, ckpt):
    """Parses the operative config for `ckpt`, then applies streaming overrides.

    The overrides do three things:

      * Set the F0 power preprocessor to one time step.
      * Size the Harmonic and FilteredNoise synths to one frame of samples.
      * Replace the processor DAG with one that has no reverb processor.

    Args:
      ckpt: Checkpoint path passed to `parse_operative_config`.
    """
    parse_operative_config(ckpt)

    # Samples per frame implied by the trained configuration.
    time_steps = gin.query_parameter('F0PowerPreprocessor.time_steps')
    n_samples = gin.query_parameter('Harmonic.n_samples')
    frame_size = int(n_samples / time_steps)

    # Remove reverb processor.
    dag_binding = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""
    streaming_bindings = [
        'F0PowerPreprocessor.time_steps = 1',
        f'Harmonic.n_samples = {frame_size}',
        f'FilteredNoise.n_samples = {frame_size}',
        dag_binding,
    ]

    with gin.unlock_config():
      gin.parse_config(streaming_bindings)
Ejemplo n.º 10
0
  def test_build_layer(self, kwarg_modules):
    """Builds a ConfigurableDAGLayer from gin and checks shapes and variables.

    Args:
      kwarg_modules: If truthy, use `self.gin_config_kwarg_modules`;
        otherwise use `self.gin_config_dag_modules`.
    """
    if kwarg_modules:
      gin_config = self.gin_config_kwarg_modules
    else:
      gin_config = self.gin_config_dag_modules
    with gin.unlock_config():
      gin.clear_config()
      gin.parse_config(gin_config)

    dag_layer = ConfigurableDAGLayer()
    outputs = dag_layer(self.inputs)
    self.assertIsInstance(outputs, dict)

    z = outputs['bottleneck']['z_bottleneck']
    decoder_rec = outputs['decoder']['reconstruction']
    out_rec = outputs['out']['reconstruction']

    # The layer must produce correctly sized tensors.
    for tensor in (outputs['test_data'],
                   outputs['inputs']['test_data'],
                   decoder_rec):
      self.assertEqual(tensor.shape, self.x.shape)
    self.assertEqual(z.shape[-1], self.z_dims)
    self.assertAllClose(decoder_rec, out_rec)

    # The DAGLayer must own the variables of its sub-modules.
    self.assertLen(dag_layer.trainable_variables, 6)  # 3 weights, 3 biases.
    def test_l1_attack(self, attack_name, random_start):
        """Checks that the L1 attack stays in the epsilon ball and sparsity bound."""
        num_iter, step_size, epsilon, percentile = 4, 1.0, 2.5, 99
        l1_bindings = {
            "attacks.l1_config.num_iter": num_iter,
            "attacks.l1_config.step_size": step_size,
            "attacks.l1_config.epsilon": epsilon,
            "attacks.l1_config.percentile": percentile,
            "attacks.union_config.restart": 5,
        }
        gin.parse_config([f"{name} = {value}"
                          for name, value in l1_bindings.items()])

        x = tf.random.uniform(shape=self.batched_input_shape)
        y = tf.random.categorical(
            tf.zeros([self.batch_size, self.num_classes]), 1)
        attack = attacks.construct_attack(attack_name)
        adv_x = attack.attack(tf.constant(x),
                              tf.constant(y),
                              self.model,
                              self.loss_fn,
                              random_start=random_start)

        # Flatten the perturbation per example.
        diff = tf.reshape(adv_x - x, (self.batch_size, -1)).numpy()
        l1_norm = np.linalg.norm(diff, ord=1, axis=-1)
        self.assertAllLessEqual(l1_norm, epsilon + 1e-5)

        # NOTE(review): with num_iter=4 and percentile=99 this bound equals
        # 3.96 * diff.shape[1], which exceeds the number of coordinates, so the
        # assertion cannot fail. Confirm whether (100 - percentile) was meant.
        touched = np.count_nonzero(diff, axis=-1)
        self.assertAllLessEqual(touched,
                                diff.shape[1] * num_iter * percentile / 100)
    def test_train(self):
        """Runs a tiny training job on mocked MNIST and inspects the checkpoint.

        Expects the trainer to write checkpoint `ckpt-2` into `self.ckpt_dir`,
        containing variables whose names start with "model" and "selector".
        """
        gin.parse_config([
            "data.preprocess_image.height = 28",
            "data.preprocess_image.width = 28",
            "data.preprocess_image.num_channels = 1",
            "data.get_test_dataset.batch_size = 1",
            "data.get_test_dataset.dataset = 'mnist'",
            "data.get_training_dataset.batch_size = 1",
            "data.get_training_dataset.dataset = 'mnist'",
            "data.get_training_dataset.shuffle_buffer_size = 1",
            "data.get_validation_dataset.batch_size = 1",
            "data.get_validation_dataset.dataset = 'mnist'",
            "data.get_validation_dataset.split = '2'",
            "resnet.build_resnet_v1.input_shape = (28, 28, 1)",
            "resnet.build_resnet_v1.depth = 8",
            "selectors.construct_representation_selector.selection_strategy = 'multiweight'",
            "selectors.construct_representation_selector.sample_freq = 1",
            "selectors.construct_representation_selector.update_freq = 1",
            "trainer.train.epochs = 2",
            "trainer.train.steps_per_epoch = 1",
            "trainer.train.representation_list = [('identity', 'l2'), ('dct', 'l2')]",
        ])
        ckpt_dir = self.ckpt_dir.full_path
        with tfds.testing.mock_data(num_examples=10):
            trainer.train(ckpt_dir, self.summary_dir.full_path)

        # Use the same string path that was handed to the trainer.
        ckpt_path = os.path.join(ckpt_dir, "ckpt-2")
        self.assertTrue(tf.io.gfile.exists(ckpt_path + ".index"))
        variable_names = [
            name for name, _ in tf.train.list_variables(ckpt_path)
        ]
        self.assertTrue(any(name.startswith("model") for name in variable_names))
        self.assertTrue(
            any(name.startswith("selector") for name in variable_names))
Ejemplo n.º 13
0
    def configure_gin(self, ckpt):
        """Parses the operative config for `ckpt`, then applies streaming overrides.

        The overrides do three things:

          * Swap the Autoencoder preprocessor for an F0PowerPreprocessor with
            one time step.
          * Size the Harmonic and FilteredNoise synths to one frame of samples.
          * Replace the processor DAG with one without reverb and crop
            processors.

        Args:
          ckpt: Checkpoint path passed to `parse_operative_config`.
        """
        parse_operative_config(ckpt)

        # time_steps lives on whichever preprocessor the Autoencoder uses.
        preprocessor_selector = gin.query_parameter(
            'Autoencoder.preprocessor').scoped_selector
        time_steps = gin.query_parameter(f'{preprocessor_selector}.time_steps')
        n_samples = gin.query_parameter('Harmonic.n_samples')
        # Presumably bound to a macro rather than a literal int in some configs.
        if not isinstance(n_samples, int):
            n_samples = gin.query_parameter('%n_samples')
        samples_per_frame = int(n_samples / time_steps)

        # Remove reverb and crop processors.
        dag_binding = """ProcessorGroup.dag = [
    (@synths.Harmonic(),
      ['amps', 'harmonic_distribution', 'f0_hz']),
    (@synths.FilteredNoise(),
      ['noise_magnitudes']),
    (@processors.Add(),
      ['filtered_noise/signal', 'harmonic/signal']),
    ]"""
        streaming_bindings = [
            'Autoencoder.preprocessor = @F0PowerPreprocessor()',
            'F0PowerPreprocessor.time_steps = 1',
            f'Harmonic.n_samples = {samples_per_frame}',
            f'FilteredNoise.n_samples = {samples_per_frame}',
            dag_binding,
        ]

        with gin.unlock_config():
            gin.parse_config(streaming_bindings)
Ejemplo n.º 14
0
        def __init__(
            self,
            env_class,
            agent_class,
            network_fn,
            model_class,
            model_network_fn,
            config,
            init_hooks,
        ):
            """Configures TF threading and gin, then builds env, agent, handler.

            Args:
              env_class: Zero-argument callable creating the environment.
              agent_class: Callable creating the agent. It receives
                `model_class=model_class` when `model_class` is not None.
              network_fn: Network factory; if truthy, it is wrapped with
                `metrics=None`.
              model_class: Optional model class forwarded to `agent_class`.
              model_network_fn: Model network factory; if truthy, it is wrapped
                with `metrics=None`.
              config: Gin bindings, parsed with `skip_unknown=True`.
              init_hooks: Zero-argument callables run after gin parsing.
            """
            # Use a single thread both between independent tf.op-s (inter-op)
            # and inside a single op (intra-op).
            import tensorflow as tf  # pylint: disable=import-outside-toplevel
            threading_config = tf.config.threading
            threading_config.set_inter_op_parallelism_threads(1)
            threading_config.set_intra_op_parallelism_threads(1)

            gin.parse_config(config, skip_unknown=True)

            for init_hook in init_hooks:
                init_hook()

            self.env = env_class()
            if model_class is None:
                self.agent = agent_class()
            else:
                self.agent = agent_class(model_class=model_class)

            # Metrics cause some problems with Ray, and no networks are
            # trained inside the worker, so metrics are switched off.
            if network_fn:
                network_fn = functools.partial(network_fn, metrics=None)
            if model_network_fn:
                model_network_fn = functools.partial(
                    model_network_fn, metrics=None)
            self._request_handler = core.RequestHandler(
                network_fn, model_network_fn=model_network_fn)
  def estimator(self, vocabulary, init_checkpoint=None, disable_tpu=False,
                score_in_predict_mode=False):
    """Builds an estimator with `utils.get_estimator` from this object's settings.

    When not running on TPU (falsy `self._tpu` or `disable_tpu`), variable
    slice and activation dtypes are bound to float32 first. In every case
    `self._gin_bindings` is then parsed.

    Args:
      vocabulary: Vocabulary passed to `utils.get_estimator`.
      init_checkpoint: Optional checkpoint passed as `init_checkpoint`.
      disable_tpu: If true, build with an empty mesh shape, no mesh devices,
        and `use_tpu=None`.
      score_in_predict_mode: Passed through to `utils.get_estimator`.

    Returns:
      The estimator returned by `utils.get_estimator`.
    """
    on_tpu = self._tpu and not disable_tpu
    with gin.unlock_config():
      if not on_tpu:
        gin.bind_parameter("utils.get_variable_dtype.slice_dtype", "float32")
        gin.bind_parameter(
            "utils.get_variable_dtype.activation_dtype", "float32")
      gin.parse_config(self._gin_bindings)

    if disable_tpu:
      mesh_shape, mesh_devices, use_tpu = mtf.Shape([]), None, None
    else:
      mesh_shape, mesh_devices, use_tpu = (
          self._mesh_shape, self._mesh_devices, self._tpu)

    return utils.get_estimator(
        model_type=self._model_type,
        vocabulary=vocabulary,
        layout_rules=self._layout_rules,
        mesh_shape=mesh_shape,
        mesh_devices=mesh_devices,
        model_dir=self._model_dir,
        batch_size=self.batch_size,
        sequence_length=self._sequence_length,
        autostack=self._autostack,
        learning_rate_schedule=self._learning_rate_schedule,
        keep_checkpoint_max=self._keep_checkpoint_max,
        save_checkpoints_steps=self._save_checkpoints_steps,
        optimizer=self._optimizer,
        predict_fn=self._predict_fn,
        variable_filter=self._variable_filter,
        ensemble_inputs=self._ensemble_inputs,
        use_tpu=use_tpu,
        tpu_job_name=self._tpu_job_name,
        iterations_per_loop=self._iterations_per_loop,
        cluster=self._cluster,
        init_checkpoint=init_checkpoint,
        score_in_predict_mode=score_in_predict_mode)
Ejemplo n.º 16
0
    def test_singletons(self):
        @gin.configurable
        class Champ(object):
            count = 0

            def __init__(self):
                Champ.count += 1

        config = '''
chuck_norris/singleton.constructor = @Champ
f.x = @chuck_norris/singleton()
g.z = @chuck_norris/singleton()
'''
        gin.parse_config(config)
        self.assertEqual(Champ.count, 0)
        f()
        self.assertEqual(Champ.count, 1)
        g()
        self.assertEqual(Champ.count, 1)
        with GinState(copy_state=True):
            f()
            self.assertEqual(Champ.count, 1)
        with GinState():
            gin.parse_config(config)
            f()
            self.assertEqual(Champ.count, 2)
Ejemplo n.º 17
0
    def _worker(self, root_dir, parameters, device_queue):
        """Runs one search job with `parameters` on a queued device.

        Takes a device id from `device_queue` and exposes only that GPU via
        CUDA_VISIBLE_DEVICES (or no GPU when `self._conf.use_gpu` is false).
        If CUDA is available, "cuda" becomes alf's default device. The worker
        then binds `parameters` through gin and disables
        `TrainerConfig.confirm_checkpoint_upon_crash`. Finally it runs
        `train_eval(FLAGS.ml_type, root_dir)` and returns the device to the
        queue.

        Args:
          root_dir: Directory passed to `train_eval`.
          parameters: Mapping of gin binding name to value; each item becomes
            a `name=value` binding.
          device_queue: Queue of device ids shared between workers.

        Raises:
          Any exception raised by the job, after logging its traceback. The
          device is returned to the queue in either case.
        """
        device_acquired = False
        try:
            # Sleep for random seconds to avoid crowded launching.
            time.sleep(random.uniform(0, 3))

            device = device_queue.get()
            device_acquired = True
            if self._conf.use_gpu:
                os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
            else:
                os.environ["CUDA_VISIBLE_DEVICES"] = ""  # run on cpu

            if torch.cuda.is_available():
                alf.set_default_device("cuda")
            logging.set_verbosity(logging.INFO)

            logging.info("Search parameters %s", parameters)
            with gin.unlock_config():
                gin.parse_config(
                    ['%s=%s' % (k, v) for k, v in parameters.items()])
                gin.parse_config(
                    "TrainerConfig.confirm_checkpoint_upon_crash=False")
            train_eval(FLAGS.ml_type, root_dir)
        except Exception:
            logging.info(traceback.format_exc())
            raise
        finally:
            # Return the device even on failure. Otherwise a worker blocked in
            # device_queue.get() could wait for it forever.
            if device_acquired:
                device_queue.put(device)
Ejemplo n.º 18
0
def main(argv):
    """Imports modules, applies gin configuration, and runs `evaluation()`.

    Gin bindings are read first from --config_file (blank lines dropped) and
    then from --params. If --run_functions_eagerly is set, tf.functions run
    eagerly. --eval_dir is created if missing.
    """
    del argv  # Unused.
    # Import modules BEFORE running Gin (presumably so their configurables
    # are registered before bindings are parsed).
    for module_name in FLAGS.import_module or []:
        __import__(module_name)

    # First, try to parse from a config file.
    if FLAGS.config_file:
        with tf.io.gfile.GFile(FLAGS.config_file) as f:
            lines = f.readlines()
        bindings = [six.ensure_str(line) for line in lines if line.strip()]
        gin.parse_config('\n'.join(bindings))

    if FLAGS.params:
        gin.parse_config(FLAGS.params)

    if FLAGS.run_functions_eagerly:
        tf.config.experimental_run_functions_eagerly(True)

    if not tf.io.gfile.exists(FLAGS.eval_dir):
        tf.io.gfile.makedirs(FLAGS.eval_dir)

    evaluation()
Ejemplo n.º 19
0
def parse_gin(model_dir):
    """Parse gin config from --gin_file, --gin_param, and the model directory.

    Parse order (later parses override earlier ones):

      1. The optimization defaults, `optimization/base.gin` or
         `optimization/base_tpu.gin` when --tpu is set.
      2. `<model_dir>/operative_config-0.gin`, if it exists.
      3. `ddsp.core.cumsum.use_tpu`.
      4. The user's --gin_file and --gin_param values.
    """
    # Add user folders to the gin search path.
    for search_path in [GIN_PATH] + FLAGS.gin_search_path:
        gin.add_config_file_search_path(search_path)

    use_tpu = bool(FLAGS.tpu)
    operative_config = os.path.join(model_dir, 'operative_config-0.gin')

    with gin.unlock_config():
        opt_default = 'base_tpu.gin' if use_tpu else 'base.gin'
        gin.parse_config_file(os.path.join('optimization', opt_default))

        # Present only when the model has already trained.
        if tf.io.gfile.exists(operative_config):
            gin.parse_config_file(operative_config, skip_unknown=True)

        # Only use the custom cumsum for TPUs.
        gin.parse_config(f'ddsp.core.cumsum.use_tpu={use_tpu}')

        gin.parse_config_files_and_bindings(
            FLAGS.gin_file, FLAGS.gin_param, skip_unknown=True)
    def __init__(
        self,
        env_class,
        agent_class,
        network_fn,
        config,
        scope,
        init_hooks,
        compress_episodes,
    ):
        """Applies gin config and hooks, then builds env, agent and handler.

        Args:
          env_class: Zero-argument callable creating the environment.
          agent_class: Zero-argument callable creating the agent.
          network_fn: Passed to `core.RequestHandler`.
          config: Gin bindings, parsed with `skip_unknown=True`.
          scope: Gin config scope active while building env, agent and
            request handler.
          init_hooks: Zero-argument callables run after gin parsing.
          compress_episodes: Stored as `self._compress_episodes`.
        """
        gin.parse_config(config, skip_unknown=True)
        for init_hook in init_hooks:
            init_hook()

        # Restrict TensorFlow to one inter-op and one intra-op thread.
        import tensorflow as tf  # pylint: disable=import-outside-toplevel
        threading_config = tf.config.threading
        threading_config.set_inter_op_parallelism_threads(1)
        threading_config.set_intra_op_parallelism_threads(1)

        with gin.config_scope(scope):
            self.env = env_class()
            self.agent = agent_class()
            self._request_handler = core.RequestHandler(network_fn)

        self._compress_episodes = compress_episodes
Ejemplo n.º 21
0
def _test_stability(max_time=5, render=False, test_generator=None):
    """Tests the stability of the controller using speed profiles."""
    locomotion_controller_setup.load_sim_config(render=render)
    gin.parse_config(SCENARIO_SET_CONFIG)
    if FLAGS.add_random_push:
        locomotion_controller_setup.add_random_push_config()

    env = env_loader.load()
    controller = locomotion_controller_setup.setup_controller(env.robot,
                                                              gait=FLAGS.gait)

    for name, speed_profile in test_generator():
        env.reset()
        controller.reset()
        current_time = 0
        while current_time < max_time:
            current_time = env.get_time_since_reset()
            lin_speed, ang_speed = _generate_linear_angular_speed(
                current_time, speed_profile[0], speed_profile[1])
            _update_controller_params(controller, lin_speed, ang_speed)

            # Needed before every call to get_action().
            controller.update()
            hybrid_action = controller.get_action()

            _, _, done, _ = env.step(hybrid_action)
            if done:
                break
        print(
            f"Scene name: flat ground. Random push: {FLAGS.add_random_push}. "
            f"Survival time for {name} = {speed_profile[1]} is {current_time}")
  def test_transformer_steps(self, config, expected_block_count):
    """Checks transformer_steps block structure and that gradients work.

    Args:
      config: Gin config string applied before building the component.
      expected_block_count: Number of top-level parameter blocks expected.
    """
    gin.parse_config(config)

    inputs = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
        neighbor_mask=jnp.zeros((5, 5), jnp.float32),
        num_real_nodes_per_graph=4)
    _, params = edge_supervision_models.transformer_steps.init(
        jax.random.PRNGKey(0), **inputs)

    # This component should contain the right number of blocks, each with
    # 4 sublayers.
    self.assertLen(params, expected_block_count)
    for block_params in params.values():
      self.assertLen(block_params, 4)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.transformer_steps.call, **inputs)
    outs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outs)
Ejemplo n.º 23
0
 def setUp(self):
   """Resets gin to GIN_CONFIG and patches `TFDSLoader.load`."""
   super().setUp()
   # Every test starts from the same gin configuration.
   gin.clear_config()
   gin.parse_config(GIN_CONFIG)
   # Stop all patches started below when the test finishes.
   self.addCleanup(mock.patch.stopall)
   load_patcher = mock.patch.object(loaders.TFDSLoader, 'load', autospec=True)
   self.mock_load = load_patcher.start()
  def test_transformer_steps_masking(self):
    """Transformer should mask out padding even if not masked to neighbors.

    With 4 real nodes out of 5, each attention row should put weight 0.25 on
    the real nodes and 0 on the padding node, for both heads.
    """
    gin.parse_config(
        textwrap.dedent("""\
            transformer_steps.layers = 1
            transformer_steps.share_weights = False
            transformer_steps.mask_to_neighbors = False
            NodeSelfAttention.heads = 2
            NodeSelfAttention.query_key_dim = 3
            NodeSelfAttention.value_dim = 4
            """))

    with flax.nn.capture_module_outputs() as outputs:
      edge_supervision_models.transformer_steps.init(
          jax.random.PRNGKey(0),
          node_embeddings=jnp.zeros((5, 3), jnp.float32),
          edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
          neighbor_mask=jnp.zeros((5, 5), jnp.float32),
          num_real_nodes_per_graph=4)

    # Exactly one captured attention-weights output is expected.
    captured = [
        value[0] for key, value in outputs.as_dict().items()
        if key.endswith("attend/attention_weights")
    ]
    attention_weights, = captured
    expected_row = [0.25, 0.25, 0.25, 0.25, 0.0]
    expected = np.array([[expected_row] * 5] * 2)
    np.testing.assert_allclose(attention_weights, expected)
Ejemplo n.º 25
0
 def testSynchronousTrainCollectEval(self):
   """End-to-end integration test.

   First collects data with a random policy without training
   (`random_collect.gin`). It then runs synchronous train, collect and eval
   (`train_dqn.gin`).
   """
   env = grasping_env.KukaGraspingProceduralEnv(downsample_width=64,
                                                downsample_height=64,
                                                continuous=True,
                                                remove_height_hack=True,
                                                render_mode='DIRECT')
   config_dir = os.path.join(FLAGS.test_srcdir, 'testdata')

   def parse_gin_file(filename):
     with open(os.path.join(config_dir, filename), 'r') as config_file:
       gin.parse_config(config_file)

   run_kwargs = dict(collect_env=env,
                     eval_env=None,
                     test_env=None,
                     root_dir=self._root_dir)

   # Collect initial data from random policy without training.
   parse_gin_file('random_collect.gin')
   train_collect_eval.train_collect_eval(train_fn=None, **run_kwargs)

   # Run training (synchronous train, collect, & eval).
   parse_gin_file('train_dqn.gin')
   train_collect_eval.train_collect_eval(**run_kwargs)
  def test_nri_steps(self):
    """Checks nri_steps block structure and that gradients work."""
    gin.parse_config(
        textwrap.dedent("""\
            graph_layers.NRIEdgeLayer.allow_non_adjacent = True
            graph_layers.NRIEdgeLayer.mlp_vtoe_dims = [4, 4]
            nri_steps.mlp_etov_dims = [8, 8]
            nri_steps.with_residual_layer_norm = True
            nri_steps.layers = 3
            """))

    inputs = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32),
        num_real_nodes_per_graph=4)
    _, params = edge_supervision_models.nri_steps.init(
        jax.random.PRNGKey(0), **inputs)

    # One block per layer. Each block contains 5 sublayers:
    # - NRI message pass
    # - Three dense layers (from mlp_etov_dims, then back to embedding space)
    # - Layer norm
    self.assertLen(params, 3)
    for block_params in params.values():
      self.assertLen(block_params, 5)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.nri_steps.call, **inputs)
    outs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outs)
Ejemplo n.º 27
0
    def testGinConfig(self):
        """CategoricalQNetwork takes its layer params from gin.

        Also checks the logits shapes and the trainable-variable count.
        """
        batch_size = 3
        num_state_dims = 5
        action_spec = tensor_spec.BoundedTensorSpec([1], tf.int32, 0, 1)
        num_actions = action_spec.maximum - action_spec.minimum + 1
        self.assertEqual(num_actions, 2)

        obs_shape = [3, 3, num_state_dims]
        observations_spec = tensor_spec.TensorSpec(obs_shape, tf.float32)
        observations = tf.random.uniform([batch_size] + obs_shape)
        next_observations = tf.random.uniform([batch_size] + obs_shape)
        time_steps = ts.restart(observations, batch_size)
        next_time_steps = ts.restart(next_observations, batch_size)

        gin.parse_config("""
        CategoricalQNetwork.conv_layer_params = [(16, 2, 1), (15, 2, 1)]
        CategoricalQNetwork.fc_layer_params = [4, 3, 5]
    """)

        q_network = categorical_q_network.CategoricalQNetwork(
            input_tensor_spec=observations_spec, action_spec=action_spec)

        logits, _ = q_network(time_steps.observation)
        next_logits, _ = q_network(next_time_steps.observation)
        expected_shape = [batch_size, num_actions, q_network._num_atoms]
        for output in (logits, next_logits):
            self.assertAllEqual(output.shape.as_list(), expected_shape)

        # Six layers (two conv, three fc, one final logits layer), each with a
        # kernel and a bias: 12 trainable variables in total.
        self.assertLen(q_network.trainable_variables, 12)
  def test_ggtnn_steps(self):
    """Checks ggnn_steps parameter structure and that gradients work."""
    gin.parse_config(
        textwrap.dedent("""\
            edge_supervision_models.ggnn_steps.iterations = 10
            graph_layers.LinearMessagePassing.message_dim = 5
            """))

    embeddings = dict(
        node_embeddings=jnp.zeros((5, 3), jnp.float32),
        edge_embeddings=jnp.zeros((5, 5, 4), jnp.float32))
    _, params = edge_supervision_models.ggnn_steps.init(
        jax.random.PRNGKey(0), **embeddings)

    # Only one "step" block is expected, with two sublayers.
    self.assertEqual(set(params.keys()), {"step"})
    self.assertLen(params["step"], 2)

    # Gradients should work.
    apply_fn = functools.partial(
        edge_supervision_models.ggnn_steps.call, **embeddings)
    outs, vjp_fn = jax.vjp(apply_fn, params)
    vjp_fn(outs)
def main(_):
    """Seeds RNGs, applies and finalizes gin config, and dispatches on --run_mode.

    --gin_config may be either a path to an existing file or a literal gin
    config string. `task` is FLAGS.task when --distributed is set, else 0.
    """
    np.random.seed(FLAGS.task)
    tf.set_random_seed(FLAGS.task)

    task = FLAGS.task if FLAGS.distributed else 0

    gin_config = FLAGS.gin_config
    if gin_config:
        if tf.gfile.Exists(gin_config):
            # Parse as a file.
            with tf.gfile.Open(gin_config) as config_file:
                gin.parse_config(config_file)
        else:
            gin.parse_config(gin_config)

    gin.finalize()

    run_mode = FLAGS.run_mode
    if run_mode == 'collect_eval_once':
        # NOTE(review): passes FLAGS.task rather than the local `task` used by
        # the other modes. Confirm this is intended for non-distributed runs.
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              train_fn=None,
                                              task=FLAGS.task)
    elif run_mode == 'train_only':
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              do_collect_eval=False,
                                              task=task,
                                              master=FLAGS.master,
                                              ps_tasks=FLAGS.ps_tasks)
    elif run_mode == 'collect_eval_loop':
        raise NotImplementedError('collect_eval_loops')
    else:
        # Synchronous train-collect-eval.
        train_collect_eval.train_collect_eval(root_dir=FLAGS.root_dir,
                                              task=task)
Ejemplo n.º 30
0
def parse_args(args=None):
    """Applies command-line arguments as gin configuration, then seeds.

    Binds `torchexp.config.manual_seed.seed` to the `%seed` macro and sets
    that macro to None. Each argument is then handled in order:

      * `--yaml=path`: passed to `_read_yaml_macros`.
      * `--gin=path`: parsed as a gin config file.
      * `key=value`: parsed as the binding `key = value`. Unless
        `check_gin_special(value)` is true, `value` is first converted with
        `literal_eval` (falling back to the raw string) and re-encoded with
        `repr`.

    Finally `manual_seed()` is called.

    Args:
      args: List of argument strings; defaults to `sys.argv[1:]`.

    Raises:
      ValueError: If an argument does not contain `=`.
    """
    gin.parse_config('torchexp.config.manual_seed.seed = %seed')
    gin.bind_parameter('%seed', None)

    if args is None:
        args = sys.argv[1:]

    for arg in args:
        try:
            key, value = arg.split('=', maxsplit=1)
        except ValueError:
            # `from None` hides the uninformative tuple-unpacking error.
            raise ValueError(f'The argument `{arg}` is not accepted!'
                             ' All argument should be the form name=value,'
                             ' --yaml=config.yaml or --gin=config.gin') from None
        if key == '--yaml':
            _read_yaml_macros(value)
        elif key == '--gin':
            gin.parse_config_file(value)
        else:
            if not check_gin_special(value):
                try:
                    value = literal_eval(value)
                except (ValueError, TypeError, SyntaxError):
                    # Not a Python literal; keep the raw string. TypeError
                    # covers inputs like `{[1]: 2}` (unhashable key).
                    pass
                value = repr(value)
            gin.parse_config(f'{key} = {value}')

    manual_seed()