def testTargetSpec(self, target_name):
    """Checks that GetTargetSpec builds a target with a batched log_prob.

    Args:
      target_name: name of the target passed to `neutra.GetTargetSpec`
        (presumably supplied by a test parameterization decorator that is not
        visible here — TODO confirm).
    """
    # Start from an empty gin config so bindings from other tests cannot leak in.
    gin.clear_config()
    # Point the dataset loaders at the test fixtures.
    gin.bind_parameter("cloud.path", self.test_cloud)
    gin.bind_parameter("german.path", self.test_german)

    target, spec = neutra.GetTargetSpec(
        target_name,
        num_dims=5,
        regression_dataset="german",
        regression_type="gamma_scales2")
    # Evaluate log_prob on a batch of 2 all-ones points; expect one value each.
    lp = self.evaluate(target.log_prob(tf.ones([2, spec.num_dims])))
    self.assertAllEqual([2], lp.shape)
 def test_evaluate_using_environment_steps(self):
     """Evaluates StddevWithinRuns with environment steps as the timepoint.

     Two instances of the same metric class are evaluated. The assertion below
     expects a single 'StddevWithinRuns' key, so results appear to be keyed by
     metric name (inferred from the assertion).
     """
     # Evaluate the metric only at environment step 2001.
     gin.bind_parameter('metrics_online.StddevWithinRuns.eval_points',
                        [2001])
     metric_instances = [
         metrics_online.StddevWithinRuns(),
         metrics_online.StddevWithinRuns()
     ]
     evaluator = eval_metrics.Evaluator(
         metric_instances, timepoint_variable='Metrics/EnvironmentSteps')
     results = evaluator.evaluate(self.run_dirs)
     self.assertEqual(list(results.keys()), ['StddevWithinRuns'])
     # Every reported standard deviation must be strictly positive.
     self.assertTrue(np.greater(list(results.values()), 0.).all())
# Example #3
def bind_gin_params(xm_params):
    """Binds every entry of the given dictionary as a gin parameter.

    Values are bound exactly as given (no quoting or type conversion). Each
    binding is logged before it is applied.

    Args:
      xm_params: dict, <key,value> pairs where key is a valid gin parameter.
    """
    tf.logging.info('xm_pararameters:\n')
    # Unlock once for the whole batch instead of re-entering the context for
    # every parameter; the previous lock state is restored on exit.
    with gin.unlock_config():
        for param_name, param_value in xm_params.items():
            # Lazy %-args: the message is only formatted if INFO is enabled.
            tf.logging.info('%s=%s\n', param_name, param_value)
            gin.bind_parameter(param_name, param_value)
# Example #4
 def test_create_actor_behavioral_cloning_agent(self):
     """Checks that create_agent builds an actor BC agent from gin bindings.

     The policy network class and the optimizer are injected through gin
     before the agent is created. The result must be an instance of both
     BehavioralCloningAgent and ActorBCAgent.
     """
     gin.bind_parameter('create_agent.policy_network',
                        actor_distribution_network.ActorDistributionNetwork)
     gin.bind_parameter('BehavioralCloningAgent.optimizer',
                        tf.compat.v1.train.AdamOptimizer())
     tf_agent = agent_creators.create_agent(
         agent_name='actor_behavioral_cloning',
         time_step_spec=self._time_step_spec,
         action_spec=self._action_spec)
     self.assertIsInstance(tf_agent,
                           behavioral_cloning_agent.BehavioralCloningAgent)
     self.assertIsInstance(tf_agent,
                           actor_behavioral_cloning_agent.ActorBCAgent)
# Example #5
 def setUp(self):
   """Sets shared test constants and binds a small prioritized replay buffer."""
   super(JaxQuantileAgentTest, self).setUp()
   self.num_actions = 4
   self._num_atoms = 5
   # NOTE(review): unlike the replay-buffer settings below, these two values
   # are not bound through gin here; presumably the tests pass them to the
   # agent explicitly — confirm.
   self._min_replay_history = 32
   self._epsilon_decay_period = 90
   self.observation_shape = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE
   self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
   self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
   # One all-zeros stacked observation: (1,) + observation_shape + (stack_size,).
   self.zero_state = onp.zeros(
       (1,) + self.observation_shape + (self.stack_size,))
   # Replay buffer with capacity 100 and batch size 2 for the tests.
   gin.bind_parameter('OutOfGraphPrioritizedReplayBuffer.replay_capacity', 100)
   gin.bind_parameter('OutOfGraphPrioritizedReplayBuffer.batch_size', 2)
def _build_env():
    """Builds the environment for the Laikago robot.

    The simulation gin config is parsed first. The rendering flag is bound
    afterwards, so ENABLE_RENDERING takes precedence over the file's value.

    Returns:
      The OpenAI gym environment, seeded with ENV_RANDOM_SEED.
    """
    gin.parse_config_file(CONFIG_FILE_SIM)
    gin.bind_parameter("SimulationParameters.enable_rendering", ENABLE_RENDERING)

    laikago_env = env_loader.load()
    laikago_env.seed(ENV_RANDOM_SEED)
    return laikago_env
# Example #7
def main_configure(
    configs: Sequence[str],
    extra_options: Tuple[str, ...],
    verbosity: str,
    debug: bool = False,
    checkpoint: Optional[str] = None,
    catch_exceptions: bool = True,
    job_type: str = 'training',
    data: Optional[str] = None,
    extension: Optional[str] = None,
    wandb_continue: Optional[str] = None,
) -> Generator[Main, None, None]:
    """Initialises wandb and gin, then yields a configured `Main` instance.

    Args:
      configs: gin config file paths, parsed in order.
      extra_options: extra gin binding strings applied after `configs`.
      verbosity: forwarded to `Main`.
      debug: forwarded to `Main`.
      checkpoint: optional checkpoint path; made absolute before use.
      catch_exceptions: forwarded to `Main`.
      job_type: wandb job type; also names the saved `config_<job_type>.gin`.
      data: optional data directory. When given, the process changes into a
        temporary directory that contains a `data` symlink to it. The original
        working directory is restored, and the temporary directory removed,
        when the generator finishes.
      extension: forwarded to `Main`.
      wandb_continue: optional identifier of a wandb run to resume (resolved
        by `_get_wandb_run`).

    Yields:
      The configured `Main` object.
    """
    if wandb_continue is not None:
        run = _get_wandb_run(wandb_continue)
        resume_args = dict(resume=True,
                           id=run.id,
                           name=run.name,
                           config=run.config,
                           notes=run.notes,
                           tags=run.tags)
    else:
        resume_args = {}
    wandb.init(sync_tensorboard=False, job_type=job_type, **resume_args)
    gin.parse_config_files_and_bindings(configs, extra_options)
    # Resolve the log directory against the current directory now, before any
    # chdir below, so the path stays valid afterwards.
    with gin.unlock_config():
        gin.bind_parameter(
            'main.base_logdir',
            str(Path(gin.query_parameter('main.base_logdir')).absolute()))
    # Save a copy of the gin configuration next to the wandb run files.
    with open(Path(wandb.run.dir) / f'config_{job_type}.gin', 'w') as f:
        for config in configs:
            f.write(f'\n# {config}\n')
            # Close each source config promptly instead of leaking the handle.
            with open(config) as config_file:
                f.write(config_file.read())
        f.write('\n# Extra options\n')
        f.write('\n'.join(extra_options))
    checkpoint_path = None if checkpoint is None else Path(
        checkpoint).absolute()
    tempdir = None
    original_cwd = None
    try:
        if data:
            # Habitat assumes data is stored in local 'data' directory
            tempdir = TemporaryDirectory()
            (Path(tempdir.name) / 'data').symlink_to(Path(data).absolute())
            original_cwd = os.getcwd()
            os.chdir(tempdir.name)
        yield Main(verbosity,
                   debug=debug,
                   catch_exceptions=catch_exceptions,
                   extension=extension,
                   checkpoint=checkpoint_path)
    finally:
        # Leave the temporary directory before deleting it; otherwise the
        # process is left with a working directory that no longer exists.
        if original_cwd is not None:
            os.chdir(original_cwd)
        if tempdir:
            tempdir.cleanup()
    def predict(self,
                input_file,
                output_file,
                checkpoint_steps=-1,
                beam_size=1,
                temperature=1.0,
                sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
                vocabulary=None):
        """Predicts targets from the given inputs.

        The operative gin config saved in the model directory is restored
        first, and the decoding settings are then overridden with
        `beam_size` and `temperature`.

        Args:
          input_file: str, path to a text file containing newline-separated
            input prompts to predict from.
          output_file: str, path prefix of output file to write predictions
            to. Note the checkpoint step will be appended to the given
            filename.
          checkpoint_steps: int, list of ints, or None. If an int or list of
            ints, inference will be run on the checkpoint files in `model_dir`
            whose global steps are closest to the global steps provided. If
            None, run inference continuously waiting for new checkpoints. If
            -1, get the latest checkpoint from the model directory.
          beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
          temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax, 1.0 means sample according to
            predicted distribution.
          sentencepiece_model_path: str, path to the SentencePiece model file
            to use for decoding. Must match the one used during training.
          vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use a SentencePieceVocabulary with the provided
            sentencepiece_model_path.
        """
        # TODO(sharannarang) : It would be nice to have a function like
        # load_checkpoint that loads the model once and then call decode_from_file
        # multiple times without having to restore the checkpoint weights again.
        # This would be particularly useful in colab demo.
        model_dir = self._model_dir
        if checkpoint_steps == -1:
            checkpoint_steps = _get_latest_checkpoint_from_dir(model_dir)

        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(model_dir))
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)

        vocab = vocabulary
        if vocab is None:
            vocab = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
        estimator = self.estimator(vocab)
        utils.infer_model(estimator, vocab, self._sequence_length,
                          self.batch_size, self._model_type, model_dir,
                          checkpoint_steps, input_file, output_file)
def latent1d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 1D.

    Adds the latent1d evaluation gin file, parses the configuration (with
    `set_seed=True`) and applies the command-line overrides. It then calls
    `disentangled.visualize.traversal1d` on the model and the dataset
    pipeline stored in `ctx.obj`. `filename`, `rows` (traversal dimensions)
    and `cols` (traversal steps) override the gin config only when they are
    not None.
    """
    add_gin(ctx, "config", ["evaluate/visual/latent1d.gin"])
    parse(ctx, set_seed=True)

    optional_overrides = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.traversal1d.dimensions", rows),
        ("disentangled.visualize.traversal1d.steps", cols),
    )
    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for param_name, override in optional_overrides:
            if override is not None:
                gin.bind_parameter(param_name, override)

    pipeline = ctx.obj["dataset"].pipeline()
    # The traversal settings are passed as gin.REQUIRED, so their values must
    # come from the gin configuration (config file or the overrides above).
    disentangled.visualize.traversal1d(
        ctx.obj["model"],
        pipeline,
        dimensions=gin.REQUIRED,
        offset=gin.REQUIRED,
        skip_batches=gin.REQUIRED,
        steps=gin.REQUIRED,
    )
# Example #10
  def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
             temperature=1.0,
             sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH):
    """Exports a TensorFlow SavedModel.

    The operative gin config is restored and then overridden with the
    decoding settings and float32 variable/activation dtypes. The estimator
    is built with `disable_tpu=True`.

    Args:
      export_dir: str, a directory in which to export SavedModels. Will use
        `model_dir` if unspecified.
      checkpoint_step: int, checkpoint to export. If -1 (default), use the
        latest checkpoint from the pretrained model directory.
      beam_size: int, a number >= 1 specifying the number of beams to use for
        beam search.
      temperature: float, a value between 0 and 1 (must be 0 if beam_size > 1)
        0.0 means argmax, 1.0 means sample according to predicted distribution.
      sentencepiece_model_path: str, path to the SentencePiece model file to use
        for decoding. Must match the one used during training.
    """
    if checkpoint_step == -1:
      checkpoint_step = _get_latest_checkpoint_from_dir(self._model_dir)

    export_overrides = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        ("utils.get_variable_dtype.slice_dtype", "float32"),
        ("utils.get_variable_dtype.activation_dtype", "float32"),
    )
    with gin.unlock_config():
      gin.parse_config_file(_operative_config_path(self._model_dir))
      for gin_name, gin_value in export_overrides:
        gin.bind_parameter(gin_name, gin_value)

    vocabulary = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
    target_dir = export_dir or self._model_dir
    estimator = self.estimator(vocabulary, disable_tpu=True)
    checkpoint_path = os.path.join(self._model_dir,
                                   "model.ckpt-" + str(checkpoint_step))
    utils.export_model(
        estimator, target_dir, vocabulary, self._sequence_length,
        batch_size=self.batch_size, checkpoint_path=checkpoint_path)
# Example #11
    def export(self,
               export_dir=None,
               checkpoint_step=-1,
               beam_size=1,
               temperature=1.0,
               vocabulary=None,
               eval_with_score=False):
        """Exports a TensorFlow SavedModel.

        Args:
          export_dir: str, a directory in which to export SavedModels. Will
            use `model_dir` if unspecified.
          checkpoint_step: int, checkpoint to export. If -1 (default), use the
            latest checkpoint from the pretrained model directory.
          beam_size: int, a number >= 1 specifying the number of beams to use
            for beam search.
          temperature: float, a value between 0 and 1 (must be 0 if
            beam_size > 1). 0.0 means argmax, 1.0 means sample according to
            predicted distribution.
          vocabulary: vocabularies.Vocabulary object to use for tokenization,
            or None to use the default SentencePieceVocabulary.
          eval_with_score: If True, compute log-likelihood scores of targets.
            If False, do inference to generate outputs.

        Returns:
          The string path to the exported directory.
        """
        model_dir = self._model_dir
        if checkpoint_step == -1:
            checkpoint_step = utils.get_latest_checkpoint_from_dir(model_dir)

        # Restore the operative gin config, then override decoding settings.
        with gin.unlock_config():
            gin.parse_config_file(_operative_config_path(model_dir))
            gin.bind_parameter("Bitransformer.decode.beam_size", beam_size)
            gin.bind_parameter("Bitransformer.decode.temperature", temperature)

        vocab = utils.get_vocabulary() if vocabulary is None else vocabulary
        target_dir = export_dir or model_dir
        ckpt_path = os.path.join(model_dir, "model.ckpt-" + str(checkpoint_step))
        estimator = self.estimator(vocab,
                                   disable_tpu=True,
                                   score_in_predict_mode=eval_with_score)
        return mtf_utils.export_model(estimator,
                                      target_dir,
                                      vocab,
                                      self._sequence_length,
                                      self._model_type,
                                      batch_size=self.batch_size,
                                      checkpoint_path=ckpt_path,
                                      eval_with_score=eval_with_score)
    def testNeuTraExperiment(self):
        """Smoke-tests a NeuTraExperiment end to end with tiny settings.

        Runs Train, Eval, Benchmark and TuneObjective with batch size 2 and
        small step counts. There are no assertions: the test passes if none of
        the calls raise.
        """
        # Start from an empty gin config so earlier tests cannot leak bindings.
        gin.clear_config()
        gin.bind_parameter("target_spec.name", "ill_conditioned_gaussian")
        gin.bind_parameter("chain_stats.compute_stats_over_time", True)
        exp = neutra.NeuTraExperiment(bijector="affine", log_dir=self.temp_dir)

        exp.Train(4, batch_size=2)
        exp.Eval(batch_size=2)
        exp.Benchmark(test_num_steps=100, test_batch_size=2, batch_size=2)
        exp.TuneObjective(1,
                          0.1,
                          batch_size=2,
                          test_num_steps=600,
                          f_name="first_moment_mean")
def fixed(ctx, batch_size, filename, rows, cols, plot, verbose, **kwargs):
    """View/save images of dataset given a fixed latent factor.

    Args:
      ctx: command context; `ctx.obj['dataset']` is used both to build a gin
        file path and as the key for `disentangled.dataset.get` (presumably a
        dataset name — TODO confirm).
      batch_size: batch size passed to `fixed_factor_dataset`.
      filename: optional output filename; only bound when not None, so a
        value from the parsed gin config is not overwritten with None.
      rows: optional number of rows; only bound when not None.
      cols: optional number of columns; only bound when not None.
      plot: whether to show the plot (always bound).
      verbose: forwarded to `fixed_factor_data`.
      **kwargs: unused.
    """
    dataset_name = ctx.obj['dataset']
    add_gin(ctx, 'config', ['evaluate/dataset/{}.gin'.format(dataset_name)])
    parse(ctx)

    optional_overrides = (
        ('disentangled.visualize.show.output.filename', filename),
        ('disentangled.visualize.fixed_factor_data.rows', rows),
        ('disentangled.visualize.fixed_factor_data.cols', cols),
    )
    with gin.unlock_config():
        gin.bind_parameter('disentangled.visualize.show.output.show_plot',
                           plot)
        # Only override values that were actually supplied.
        for param_name, override in optional_overrides:
            if override is not None:
                gin.bind_parameter(param_name, override)

    # Look the dataset up once and reuse it for both properties.
    dataset_obj = disentangled.dataset.get(dataset_name)
    num_values_per_factor = dataset_obj.num_values_per_factor
    supervised = dataset_obj.supervised()

    fixed_data, _ = disentangled.metric.utils.fixed_factor_dataset(
        supervised, batch_size, num_values_per_factor)

    # rows/cols are gin.REQUIRED, so their values come from the gin config.
    disentangled.visualize.fixed_factor_data(fixed_data,
                                             rows=gin.REQUIRED,
                                             cols=gin.REQUIRED,
                                             verbose=verbose)
# Example #14
def main(_):
    """Configures TF and gin from flags, then trains under FLAGS.root_dir."""
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()
    tf.compat.v1.enable_resource_variables()

    # Unknown configurables in the gin files/bindings are skipped, not fatal.
    gin.parse_config_files_and_bindings(
        FLAGS.gin_file, FLAGS.gin_bindings, skip_unknown=True)

    output_root = FLAGS.root_dir
    # Make the root directory available to gin configs as %ROOT_DIR
    # (presumably referenced as a macro by the config files — confirm).
    with gin.unlock_config():
        gin.bind_parameter('%ROOT_DIR', output_root)

    trainer.train(output_root)
# Example #15
  def test_c4_pretrain(self):
    """Smoke test: builds 'c4' data streams from the test data directory."""
    _t5_gin_config()

    gin.bind_parameter('c4_bare_preprocess_fn.spm_path', _spm_path())

    # Small batches and a single eval bucket keep the pipeline tiny.
    gin.bind_parameter('batcher.batch_size_per_device', 8)
    gin.bind_parameter('batcher.eval_batch_size', 8)
    gin.bind_parameter('batcher.max_eval_length', 50)
    gin.bind_parameter('batcher.buckets', ([51], [8, 1]))

    # Just make sure this doesn't throw.
    _ = tf_inputs.data_streams(
        'c4', data_dir=_TESTDATA, input_name='inputs', target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)
# Example #16
 def test_inits_rewards_to_actions_serialized(self):
   """Checks the shape of the reward-to-action matrix for serialized spaces."""
   precision = 2
   gin.bind_parameter('BoxSpaceSerializer.precision', precision)
   obs_size = 3
   n_timesteps = 6
   n_controls = 2
   rewards_to_actions = ppo.init_rewards_to_actions(
       vocab_size=4,
       observation_space=gym.spaces.Box(shape=(obs_size,), low=0, high=1),
       action_space=gym.spaces.MultiDiscrete(nvec=((2,) * n_controls)),
       n_timesteps=n_timesteps,
   )
   # Expected symbol count: per timestep, obs_size * precision observation
   # symbols plus n_controls action symbols.
   n_action_symbols = n_timesteps * (obs_size * precision + n_controls)
   self.assertEqual(rewards_to_actions.shape, (n_timesteps, n_action_symbols))
 def _make_space_and_serializer(
     self,
     low=-10,
     high=10,
     shape=(2, ),
     # Weird vocab_size to test that it doesn't only work with powers of 2.
     vocab_size=257,
     # Enough precision to represent float32s accurately.
     precision=4,
 ):
     """Builds a gym Box space and a serializer for it.

     The serializer precision is bound through gin before the serializer is
     created (presumably so `space_serializer.create` picks it up).

     Returns:
       A `(space, serializer)` tuple.
     """
     box = gym.spaces.Box(low=low, high=high, shape=shape)
     gin.bind_parameter("BoxSpaceSerializer.precision", precision)
     return box, space_serializer.create(box, vocab_size=vocab_size)
def visual(ctx, rows, cols, plot, filename, **kwargs):
    """Qualitative evaluation of output.

    Parses the configuration (with `set_seed=True`), applies the command-line
    overrides and calls `disentangled.visualize.reconstructed` on the model
    and dataset pipeline stored in `ctx.obj`. `filename`, `rows` and `cols`
    override the gin config only when they are not None.
    """
    parse(ctx, set_seed=True)

    optional_overrides = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.reconstructed.rows", rows),
        ("disentangled.visualize.reconstructed.cols", cols),
    )
    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot",
                           plot)
        for param_name, override in optional_overrides:
            if override is not None:
                gin.bind_parameter(param_name, override)

    pipeline = ctx.obj["dataset"].pipeline()
    disentangled.visualize.reconstructed(ctx.obj["model"],
                                         pipeline,
                                         rows=gin.REQUIRED,
                                         cols=gin.REQUIRED)
# Example #19 (file: gin.py, project: kant/zpy)
def parse_gin_bindings(gin_bindings: Dict = None, ) -> None:
    """ Parse any extra gin bindings to the config.

    Bindings are applied one at a time on a best-effort basis. A binding that
    gin rejects is logged as IGNORED (with the reason at debug level), and the
    remaining bindings are still applied.

    Args:
      gin_bindings: optional mapping of bindings. It is passed through
        `replace_human_redable_kwargs`, which yields the (key, value) pairs
        that are actually bound.
    """
    if gin_bindings is None:
        log.info('No additional gin bindings to parse')
        return
    log.info(f'Parsing additional bindings: {pformat(gin_bindings)}')
    with gin.unlock_config():
        for key, value in replace_human_redable_kwargs(gin_bindings):
            try:
                gin.bind_parameter(key, value)
            # Deliberately broad (best-effort), but no longer a bare except:
            # KeyboardInterrupt / SystemExit are allowed to propagate.
            except Exception as err:
                _message = 'IGNORED'
                log.debug(f'Could not bind {key}: {err!r}')
            else:
                _message = 'BOUND  '
            log.info(f'{_message} - {key} : {value}')
 def test_training_loop_onlinetune(self):
   """Runs the training loop on OnlineTuneEnv-v0 with a tiny configuration.

   The environment is configured through gin with an MLP with no hidden
   layers and one output class, random 1x1 float32 inputs and outputs, and 2
   train and 2 eval steps. Environment output goes under `<output_dir>/envs`.
   """
   with self.tmp_dir() as output_dir:
     gin.bind_parameter("OnlineTuneEnv.model", functools.partial(
         models.MLP, n_hidden_layers=0, n_output_classes=1))
     gin.bind_parameter("OnlineTuneEnv.inputs", functools.partial(
         trax_inputs.random_inputs,
         input_shape=(1, 1),
         input_dtype=np.float32,
         output_shape=(1, 1),
         output_dtype=np.float32))
     gin.bind_parameter("OnlineTuneEnv.train_steps", 2)
     gin.bind_parameter("OnlineTuneEnv.eval_steps", 2)
     gin.bind_parameter(
         "OnlineTuneEnv.output_dir", os.path.join(output_dir, "envs"))
     self._run_training_loop("OnlineTuneEnv-v0", output_dir)
# Example #21
 def test_finalize(self):
     """Checks that a GinState keeps its own bindings and lock state.

     After the global config is finalized (locked), a fresh GinState is
     unlocked and holds its own bindings. The global lock is back in effect
     after exiting it. Re-entering the same GinState restores that state's
     lock status, including a finalize performed inside it.
     """
     gin.bind_parameter('f.x', 'global')
     gin.finalize()
     self.assertTrue(gin.config_is_locked())
     with GinState() as temp_state:
         gin.bind_parameter('f.x', 'temp')
         self.assertEqual(gin.query_parameter('f.x'), 'temp')
         self.assertFalse(gin.config_is_locked())
     # Leaving the temporary state restores the locked global config.
     self.assertTrue(gin.config_is_locked())
     with temp_state:
         self.assertFalse(gin.config_is_locked())
         gin.config.finalize()
         self.assertTrue(gin.config_is_locked())
     # The finalize performed inside temp_state is remembered on re-entry.
     with temp_state:
         self.assertTrue(gin.config_is_locked())
# Example #22
 def test_significance_map(self):
     """Checks significance_map for a Box observation and a 2-control action.

     With precision 3, each 2-dim observation serializes to 6 symbols, and
     the expected map repeats 0, 1, 2 across them. The two action symbols map
     to 0. The map is cut at representation_length=20.
     """
     gin.bind_parameter('BoxSpaceSerializer.precision', 3)
     significance_map = serialization_utils.significance_map(
         observation_serializer=space_serializer.create(gym.spaces.Box(
             low=0, high=1, shape=(2, )),
                                                        vocab_size=2),
         action_serializer=space_serializer.create(
             gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2),
         representation_length=20,
     )
     np.testing.assert_array_equal(
         significance_map,
         # obs1, act1, obs2, act2, obs3 cut after 4th symbol.
         [0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0],
     )
# Example #23
 def setUp(self):
   """Sets shared test constants and binds small replay/agent settings."""
   super(FullRainbowAgentTest, self).setUp()
   self._num_actions = 4
   self._num_atoms = 5
   self._vmax = 7.
   self.observation_shape = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE
   self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
   self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
   # One all-zeros stacked observation: (1,) + observation_shape + (stack_size,).
   self.zero_state = onp.zeros((1,) + self.observation_shape +
                               (self.stack_size,))
   # Replay buffer with capacity 100 and batch size 2 for the tests.
   gin.bind_parameter('OutOfGraphPrioritizedReplayBuffer.replay_capacity', 100)
   gin.bind_parameter('OutOfGraphPrioritizedReplayBuffer.batch_size', 2)
   # Agent hyperparameters supplied through gin for all tests in this class.
   gin.bind_parameter('JaxDQNAgent.min_replay_history', 32)
   gin.bind_parameter('JaxDQNAgent.epsilon_eval', 0.0)
   gin.bind_parameter('JaxDQNAgent.epsilon_decay_period', 90)
# Example #24
 def setUp(self):
     """Binds the quantile test-data dir and builds scalar int64 specs."""
     # Observation processing reads its quantile files from the testdata dir.
     gin.bind_parameter(
         'get_observation_processing_layer_creator.quantile_file_dir',
         os.path.join(constant.BASE_DIR, 'testdata'))
     observation_spec = tf.TensorSpec(dtype=tf.int64,
                                      shape=(),
                                      name='callee_users')
     self._time_step_spec = time_step.time_step_spec(observation_spec)
     # Binary action: scalar int64 bounded to [0, 1].
     self._action_spec = tensor_spec.BoundedTensorSpec(
         dtype=tf.int64,
         shape=(),
         minimum=0,
         maximum=1,
         name='inlining_decision')
     # NOTE(review): the base-class setUp runs after the gin binding above; if
     # it resets gin state, the binding would be lost — confirm this ordering
     # is intended.
     super(AgentCreatorsTest, self).setUp()
def sweep(parameters):
    """Iterates over the Cartesian product of gin parameter values.

    For each combination, every parameter is bound with gin before the
    combination's label is yielded. The caller therefore sees that
    combination's bindings while handling the label.

    Args:
      parameters: dict mapping gin parameter names to iterables of values.

    Yields:
      str label such as "lr:0.1_steps:10": the last dotted component of each
      parameter name (in sorted name order) joined with its value.
    """
    names = sorted(parameters)
    current = {}
    for combo in itertools.product(*(parameters[name] for name in names)):
        for name, value in zip(names, combo):
            gin.bind_parameter(name, value)
            current[name] = value
        yield '_'.join(
            f"{name.split('.')[-1]}:{value}" for name, value in current.items())
# Example #26
    def testNeuTraExperiment(self):
        """Smoke-tests NeuTraExperiment inside a TF1 session.

        Initializes the experiment, trains the bijector for 1 step, then
        evaluates, benchmarks and tunes (one random trial). There are no
        assertions: the test passes if none of the calls raise.
        """
        # Start from an empty gin config so earlier tests cannot leak bindings.
        gin.clear_config()
        gin.bind_parameter("target_spec.name", "easy_gaussian")
        gin.bind_parameter("target_spec.num_dims", 2)
        exp = neutra.NeuTraExperiment(train_batch_size=2,
                                      test_chain_batch_size=2,
                                      bijector="affine",
                                      log_dir=self.temp_dir)

        with tf.Session() as sess:
            exp.Initialize(sess)
            exp.TrainBijector(sess, 1)
            exp.Eval(sess)
            exp.Benchmark(sess)
            exp.Tune(sess, method="random", max_num_trials=1)
# Example #27
 def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
     """Trains one step with generator EMA and checks the EMA variables.

     Args:
       disc_iters: discriminator iterations per step (presumably supplied by
         a test parameterization decorator that is not visible here).
     """
     parameters = {
         "architecture": c.RESNET_CIFAR,
         "lambda": 1,
         "z_dim": 128,
         # Spelled "disc_iters" (was the typo "dics_iters"), so the
         # parameterized value is found under the expected key.
         "disc_iters": disc_iters,
     }
     gin.bind_parameter("ModularGAN.g_use_ema", True)
     dataset = datasets.get_dataset("cifar10")
     gan = ModularGAN(dataset=dataset,
                      parameters=parameters,
                      model_dir=self.model_dir)
     estimator = gan.as_estimator(self.run_config,
                                  batch_size=2,
                                  use_tpu=False)
     estimator.train(gan.input_fn, steps=1)
     # Check for moving average variables in checkpoint.
     checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
     ema_vars = sorted([
         v[0] for v in tf.train.list_variables(checkpoint_path)
         if v[0].endswith("ExponentialMovingAverage")
     ])
     tf.logging.info("ema_vars=%s", ema_vars)
     expected_ema_vars = sorted([
         "generator/fc_noise/kernel/ExponentialMovingAverage",
         "generator/fc_noise/bias/ExponentialMovingAverage",
         "generator/B1/up_conv_shortcut/kernel/ExponentialMovingAverage",
         "generator/B1/up_conv_shortcut/bias/ExponentialMovingAverage",
         "generator/B1/up_conv1/kernel/ExponentialMovingAverage",
         "generator/B1/up_conv1/bias/ExponentialMovingAverage",
         "generator/B1/same_conv2/kernel/ExponentialMovingAverage",
         "generator/B1/same_conv2/bias/ExponentialMovingAverage",
         "generator/B2/up_conv_shortcut/kernel/ExponentialMovingAverage",
         "generator/B2/up_conv_shortcut/bias/ExponentialMovingAverage",
         "generator/B2/up_conv1/kernel/ExponentialMovingAverage",
         "generator/B2/up_conv1/bias/ExponentialMovingAverage",
         "generator/B2/same_conv2/kernel/ExponentialMovingAverage",
         "generator/B2/same_conv2/bias/ExponentialMovingAverage",
         "generator/B3/up_conv_shortcut/kernel/ExponentialMovingAverage",
         "generator/B3/up_conv_shortcut/bias/ExponentialMovingAverage",
         "generator/B3/up_conv1/kernel/ExponentialMovingAverage",
         "generator/B3/up_conv1/bias/ExponentialMovingAverage",
         "generator/B3/same_conv2/kernel/ExponentialMovingAverage",
         "generator/B3/same_conv2/bias/ExponentialMovingAverage",
         "generator/final_conv/kernel/ExponentialMovingAverage",
         "generator/final_conv/bias/ExponentialMovingAverage",
     ])
     self.assertAllEqual(ema_vars, expected_ema_vars)
# Example #28
    def setUp(self):
        """Creates a two-video, two-rater fixture with opposite preferences.

        rater0's email is unverified and rater1's is verified. Each rater
        submits one ExpertRating of video0 vs video1. All fields are 50 except
        the first field in VIDEO_FIELDS, which is 0 for rater0 and 100 for
        rater1.
        """
        load_gin_config("backend/ml_model/config/featureless_config.gin")
        gin.bind_parameter(
            "FeaturelessMedianPreferenceAverageRegularizationAggregator.epochs",
            1000)

        # creating videos
        self.videos = [
            Video.objects.create(video_id=f"video{i}") for i in tqdm(range(2))
        ]

        # creating users
        self.djangousers = [
            DjangoUser.objects.create_user(username=f"rater{i}",
                                           password=f"1234{i}")
            for i in tqdm(range(2))
        ]
        self.userprefs = [
            UserPreferences.objects.create(user=u) for u in self.djangousers
        ]

        # making the user verified
        # (only rater1 ends up verified: self.verify is [False, True])
        self.userinfos = [
            UserInformation.objects.create(user=u) for u in self.djangousers
        ]
        self.verify = [False, True]
        accepted_domain = create_accepted_domain()
        self.vemails = [
            VerifiableEmail.objects.create(user=ui,
                                           email=f"{uuid1()}{accepted_domain}",
                                           is_verified=verify)
            for ui, verify in zip(self.userinfos, self.verify)
        ]

        # Neutral value (50) for every field except the first one.
        data_rest = {k: 50 for k in VIDEO_FIELDS[1:]}
        self.f = VIDEO_FIELDS[0]

        # rater0 likes video0, rater1 likes video1
        ExpertRating.objects.create(user=self.userprefs[0],
                                    video_1=self.videos[0],
                                    video_2=self.videos[1],
                                    **data_rest,
                                    **{self.f: 0})
        ExpertRating.objects.create(user=self.userprefs[1],
                                    video_1=self.videos[0],
                                    video_2=self.videos[1],
                                    **data_rest,
                                    **{self.f: 100})
    def test_init_from_checkpoint_global_step(self):
        """Tests that warm-starting from a checkpoint also restores global step.

        A model is trained and checkpointed. A second model is then initialized
        from that checkpoint via `default_init_from_checkpoint_fn` and trained
        in a fresh directory up to 100 more steps. If the global step was
        restored, only one new checkpoint appears next to the initial one.
        """
        gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps',
                           100)
        gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)
        model_dir = self.create_tempdir().full_path
        mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor)

        mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)

        train_eval.train_eval_model(
            t2r_model=mock_t2r_model,
            input_generator_train=mock_input_generator_train,
            max_train_steps=_MAX_TRAIN_STEPS,
            model_dir=model_dir,
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)
        # The model trains for 1000 steps and saves a checkpoint each 100 steps and
        # keeps 10 -> len == 10.
        self.assertLen(
            tf.io.gfile.glob(os.path.join(model_dir, 'model*.meta')), 10)

        # The continuous training has its own directory.
        continue_model_dir = self.create_tempdir().full_path
        init_from_checkpoint_fn = functools.partial(
            abstract_model.default_init_from_checkpoint_fn,
            checkpoint=model_dir)
        continue_mock_t2r_model = mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=init_from_checkpoint_fn)
        continue_mock_input_generator_train = mocks.MockInputGenerator(
            batch_size=_BATCH_SIZE)
        train_eval.train_eval_model(
            t2r_model=continue_mock_t2r_model,
            input_generator_train=continue_mock_input_generator_train,
            model_dir=continue_model_dir,
            max_train_steps=_MAX_TRAIN_STEPS + 100,
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)
        # If the model was successful restored including the global step, only 1
        # additional checkpoint to the init one should be created -> len == 2.
        self.assertLen(
            tf.io.gfile.glob(os.path.join(continue_model_dir, 'model*.meta')),
            2)
# Example #30
def _parse_config_item(key, value):
    if isinstance(value, dict):
        for k, v in value.items():
            _parse_config_item(k, v)
        return
    elif isinstance(value, (list, tuple)):
        for k, v in value:
            _parse_config_item(k, v)
    elif value is None:
        return
    else:
        assert (key is not None)
        # if isinstance(value, six.string_types):
        #     gin.bind_parameter(key, '"{}"'.format(value))
        # else:
        gin.bind_parameter(key, value)