def testTargetSpec(self, target_name):
    """GetTargetSpec builds a target whose log_prob returns one value per row."""
    gin.clear_config()
    # Point the regression datasets at the test fixtures.
    data_paths = {
        "cloud.path": self.test_cloud,
        "german.path": self.test_german,
    }
    for binding, path in data_paths.items():
        gin.bind_parameter(binding, path)
    target, spec = neutra.GetTargetSpec(
        target_name,
        num_dims=5,
        regression_dataset="german",
        regression_type="gamma_scales2",
    )
    batch = tf.ones([2, spec.num_dims])
    log_probs = self.evaluate(target.log_prob(batch))
    self.assertAllEqual([2], log_probs.shape)
def test_evaluate_using_environment_steps(self):
    """Evaluates StddevWithinRuns using 'Metrics/EnvironmentSteps' as the timepoint."""
    gin.bind_parameter('metrics_online.StddevWithinRuns.eval_points', [2001])
    # Two instances of the same metric; results are keyed by metric name.
    metrics = [metrics_online.StddevWithinRuns() for _ in range(2)]
    evaluator = eval_metrics.Evaluator(
        metrics, timepoint_variable='Metrics/EnvironmentSteps')
    results = evaluator.evaluate(self.run_dirs)
    self.assertEqual(list(results.keys()), ['StddevWithinRuns'])
    metric_values = list(results.values())
    self.assertTrue(np.greater(metric_values, 0.).all())
def bind_gin_params(xm_params):
    """Binding parameters from the given dictionary.

    Each value is passed to ``gin.bind_parameter`` unchanged (no quoting or
    parsing is applied). Binding happens inside ``gin.unlock_config()`` so it
    is accepted even if the gin config is currently locked.

    Args:
      xm_params: dict, <key,value> pairs where key is a valid gin parameter.
    """
    tf.logging.info('xm_parameters:\n')
    # Enter the unlocked context once for the whole batch of bindings.
    with gin.unlock_config():
        for param_name, param_value in xm_params.items():
            # Lazy %-args: formatting is skipped if INFO logging is disabled.
            tf.logging.info('%s=%s\n', param_name, param_value)
            gin.bind_parameter(param_name, param_value)
def test_create_actor_behavioral_cloning_agent(self):
    """'actor_behavioral_cloning' yields an ActorBCAgent, which is also a BehavioralCloningAgent."""
    gin.bind_parameter(
        'create_agent.policy_network',
        actor_distribution_network.ActorDistributionNetwork)
    optimizer = tf.compat.v1.train.AdamOptimizer()
    gin.bind_parameter('BehavioralCloningAgent.optimizer', optimizer)
    agent = agent_creators.create_agent(
        agent_name='actor_behavioral_cloning',
        time_step_spec=self._time_step_spec,
        action_spec=self._action_spec)
    expected_classes = (
        behavioral_cloning_agent.BehavioralCloningAgent,
        actor_behavioral_cloning_agent.ActorBCAgent,
    )
    for expected_cls in expected_classes:
        self.assertIsInstance(agent, expected_cls)
def setUp(self):
    """Sets up agent hyper-parameters, a zero state and a tiny replay buffer."""
    super(JaxQuantileAgentTest, self).setUp()
    self.num_actions = 4
    self._num_atoms = 5
    self._min_replay_history = 32
    self._epsilon_decay_period = 90
    self.observation_shape = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE
    self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
    self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
    # A single stacked observation filled with zeros.
    state_shape = (1,) + self.observation_shape + (self.stack_size,)
    self.zero_state = onp.zeros(state_shape)
    replay_bindings = (
        ('OutOfGraphPrioritizedReplayBuffer.replay_capacity', 100),
        ('OutOfGraphPrioritizedReplayBuffer.batch_size', 2),
    )
    for binding, value in replay_bindings:
        gin.bind_parameter(binding, value)
def _build_env():
    """Create the simulated Laikago environment.

    Parses ``CONFIG_FILE_SIM``, binds the rendering flag from
    ``ENABLE_RENDERING``, loads the environment and seeds it with
    ``ENV_RANDOM_SEED``.

    Returns:
      The OpenAI gym environment.
    """
    gin.parse_config_file(CONFIG_FILE_SIM)
    gin.bind_parameter(
        "SimulationParameters.enable_rendering", ENABLE_RENDERING)
    environment = env_loader.load()
    environment.seed(ENV_RANDOM_SEED)
    return environment
def main_configure(
    configs: Sequence[str],
    extra_options: Tuple[str, ...],
    verbosity: str,
    debug: bool = False,
    checkpoint: Optional[str] = None,
    catch_exceptions: bool = True,
    job_type: str = 'training',
    data: Optional[str] = None,
    extension: Optional[str] = None,
    wandb_continue: Optional[str] = None,
) -> Generator[Main, None, None]:
    """Initialise wandb and gin, then yield a configured ``Main``.

    Args:
      configs: gin config files to parse.
      extra_options: extra gin bindings applied after ``configs``.
      verbosity: forwarded to ``Main``.
      debug: forwarded to ``Main``.
      checkpoint: optional checkpoint path; made absolute before any chdir.
      catch_exceptions: forwarded to ``Main``.
      job_type: wandb job type; also names the saved ``config_<job_type>.gin``.
      data: optional data directory. When given, a temporary directory with a
        ``data`` symlink to it becomes the working directory while ``Main``
        is in use.
      extension: forwarded to ``Main``.
      wandb_continue: identifier of a wandb run to resume, or None.

    Yields:
      The configured ``Main`` instance.
    """
    if wandb_continue is not None:
        run = _get_wandb_run(wandb_continue)
        resume_args = dict(resume=True, id=run.id, name=run.name,
                           config=run.config, notes=run.notes, tags=run.tags)
    else:
        resume_args = {}
    wandb.init(sync_tensorboard=False, job_type=job_type, **resume_args)

    gin.parse_config_files_and_bindings(configs, extra_options)
    # Store the logdir as an absolute path so the chdir below cannot change
    # where logs end up.
    with gin.unlock_config():
        gin.bind_parameter(
            'main.base_logdir',
            str(Path(gin.query_parameter('main.base_logdir')).absolute()))

    # Snapshot the config files and extra bindings into the wandb run dir.
    with open(Path(wandb.run.dir) / f'config_{job_type}.gin', 'w') as f:
        for config in configs:
            f.write(f'\n# {config}\n')
            with open(config) as config_file:  # closed promptly, no fd leak
                f.write(config_file.read())
        f.write('\n# Extra options\n')
        f.write('\n'.join(extra_options))

    checkpoint_path = None if checkpoint is None else Path(
        checkpoint).absolute()

    tempdir = None
    original_cwd = None
    try:
        if data:
            # Habitat assumes data is stored in local 'data' directory
            tempdir = TemporaryDirectory()
            (Path(tempdir.name) / 'data').symlink_to(Path(data).absolute())
            original_cwd = os.getcwd()
            os.chdir(tempdir.name)
        yield Main(verbosity, debug=debug, catch_exceptions=catch_exceptions,
                   extension=extension, checkpoint=checkpoint_path)
    finally:
        # Step back out of the temp dir before deleting it, so the process
        # is not left in a nonexistent working directory.
        if original_cwd is not None:
            os.chdir(original_cwd)
        if tempdir:
            tempdir.cleanup()
def predict(self, input_file, output_file, checkpoint_steps=-1,
            beam_size=1, temperature=1.0,
            sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
            vocabulary=None):
    """Run inference on the prompts in a text file.

    Args:
      input_file: str, path to a text file with one input prompt per line.
      output_file: str, path prefix for the predictions file; the checkpoint
        step is appended to this name.
      checkpoint_steps: int, list of ints, or None. For an int or a list of
        ints, inference runs on the checkpoints in `model_dir` whose global
        steps are closest to the requested ones. None means run continuously,
        waiting for new checkpoints. -1 selects the latest checkpoint in the
        model directory.
      beam_size: int >= 1, number of beams used by beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1); 0.0 means
        argmax, 1.0 means sampling from the predicted distribution.
      sentencepiece_model_path: str, SentencePiece model used for decoding;
        must match the one used in training.
      vocabulary: vocabularies.Vocabulary for tokenization, or None to build
        a SentencePieceVocabulary from `sentencepiece_model_path`.
    """
    # TODO(sharannarang) : It would be nice to have a function like
    # load_checkpoint that loads the model once and then call decode_from_file
    # multiple times without having to restore the checkpoint weights again.
    # This would be particularly useful in colab demo.
    steps = checkpoint_steps
    if steps == -1:
        steps = _get_latest_checkpoint_from_dir(self._model_dir)

    decode_overrides = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
    )
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        for binding, value in decode_overrides:
            gin.bind_parameter(binding, value)

    vocab = vocabulary
    if vocab is None:
        vocab = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
    utils.infer_model(
        self.estimator(vocab), vocab, self._sequence_length,
        self.batch_size, self._model_type, self._model_dir, steps,
        input_file, output_file)
def latent1d(ctx, rows, cols, plot, filename, **kwargs):
    """Latent space traversal in 1D."""
    # The docstring above is kept verbatim: it presumably doubles as the CLI
    # help text.
    add_gin(ctx, "config", ["evaluate/visual/latent1d.gin"])
    parse(ctx, set_seed=True)

    # Options the user may omit; None leaves the gin-configured value in place.
    optional_bindings = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.traversal1d.dimensions", rows),
        ("disentangled.visualize.traversal1d.steps", cols),
    )
    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot", plot)
        for binding, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = ctx.obj["dataset"].pipeline()
    disentangled.visualize.traversal1d(
        ctx.obj["model"],
        pipeline,
        dimensions=gin.REQUIRED,
        offset=gin.REQUIRED,
        skip_batches=gin.REQUIRED,
        steps=gin.REQUIRED,
    )
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0,
           sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH):
    """Export the model as a TensorFlow SavedModel.

    Args:
      export_dir: str, directory to write SavedModels into; defaults to
        `model_dir` when not given.
      checkpoint_step: int, checkpoint to export. -1 (the default) selects the
        latest checkpoint in the pretrained model directory.
      beam_size: int >= 1, number of beams used by beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1); 0.0 means
        argmax, 1.0 means sampling from the predicted distribution.
      sentencepiece_model_path: str, SentencePiece model used for decoding;
        must match the one used in training.
    """
    step = checkpoint_step
    if step == -1:
        step = _get_latest_checkpoint_from_dir(self._model_dir)

    overrides = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
        # Force float32 variable and activation dtypes in the exported graph.
        ("utils.get_variable_dtype.slice_dtype", "float32"),
        ("utils.get_variable_dtype.activation_dtype", "float32"),
    )
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        for binding, value in overrides:
            gin.bind_parameter(binding, value)

    vocab = t5.data.SentencePieceVocabulary(sentencepiece_model_path)
    target_dir = export_dir or self._model_dir
    utils.export_model(
        self.estimator(vocab, disable_tpu=True), target_dir, vocab,
        self._sequence_length, batch_size=self.batch_size,
        checkpoint_path=os.path.join(self._model_dir, f"model.ckpt-{step}"))
def export(self, export_dir=None, checkpoint_step=-1, beam_size=1,
           temperature=1.0, vocabulary=None, eval_with_score=False):
    """Export the model as a TensorFlow SavedModel.

    Args:
      export_dir: str, directory to write SavedModels into; defaults to
        `model_dir` when not given.
      checkpoint_step: int, checkpoint to export. -1 (the default) selects the
        latest checkpoint in the pretrained model directory.
      beam_size: int >= 1, number of beams used by beam search.
      temperature: float in [0, 1] (must be 0 if beam_size > 1); 0.0 means
        argmax, 1.0 means sampling from the predicted distribution.
      vocabulary: vocabularies.Vocabulary for tokenization, or None to use
        the default vocabulary returned by `utils.get_vocabulary()`.
      eval_with_score: if True, compute log-likelihood scores of targets;
        if False, run inference to generate outputs.

    Returns:
      The string path to the exported directory.
    """
    step = checkpoint_step
    if step == -1:
        step = utils.get_latest_checkpoint_from_dir(self._model_dir)

    decode_overrides = (
        ("Bitransformer.decode.beam_size", beam_size),
        ("Bitransformer.decode.temperature", temperature),
    )
    with gin.unlock_config():
        gin.parse_config_file(_operative_config_path(self._model_dir))
        for binding, value in decode_overrides:
            gin.bind_parameter(binding, value)

    vocab = utils.get_vocabulary() if vocabulary is None else vocabulary
    target_dir = export_dir or self._model_dir
    estimator = self.estimator(
        vocab, disable_tpu=True, score_in_predict_mode=eval_with_score)
    checkpoint_path = os.path.join(self._model_dir, f"model.ckpt-{step}")
    return mtf_utils.export_model(
        estimator, target_dir, vocab, self._sequence_length,
        self._model_type, batch_size=self.batch_size,
        checkpoint_path=checkpoint_path, eval_with_score=eval_with_score)
def testNeuTraExperiment(self):
    """Smoke test: train, eval, benchmark and tune a tiny NeuTra experiment."""
    gin.clear_config()
    gin.bind_parameter("target_spec.name", "ill_conditioned_gaussian")
    gin.bind_parameter("chain_stats.compute_stats_over_time", True)
    experiment = neutra.NeuTraExperiment(
        bijector="affine", log_dir=self.temp_dir)
    # Every phase uses a batch size of 2 to keep the test fast.
    experiment.Train(4, batch_size=2)
    experiment.Eval(batch_size=2)
    experiment.Benchmark(
        test_num_steps=100, test_batch_size=2, batch_size=2)
    experiment.TuneObjective(
        1, 0.1, batch_size=2, test_num_steps=600,
        f_name="first_moment_mean")
def fixed(ctx, batch_size, filename, rows, cols, plot, verbose, **kwargs):
    """View/save images of dataset given a fixed latent factor."""
    # The docstring above is kept verbatim: it presumably doubles as the CLI
    # help text. Here ctx.obj['dataset'] holds the dataset *name*.
    dataset_name = ctx.obj['dataset']
    add_gin(ctx, 'config', ['evaluate/dataset/{}.gin'.format(dataset_name)])
    parse(ctx)

    # rows/cols are optional; None keeps the gin-configured value.
    grid_bindings = {
        'disentangled.visualize.fixed_factor_data.rows': rows,
        'disentangled.visualize.fixed_factor_data.cols': cols,
    }
    with gin.unlock_config():
        gin.bind_parameter('disentangled.visualize.show.output.show_plot', plot)
        gin.bind_parameter('disentangled.visualize.show.output.filename', filename)
        for binding, value in grid_bindings.items():
            if value is not None:
                gin.bind_parameter(binding, value)

    num_values_per_factor = disentangled.dataset.get(
        dataset_name).num_values_per_factor
    supervised = disentangled.dataset.get(dataset_name).supervised()
    fixed_batches, _ = disentangled.metric.utils.fixed_factor_dataset(
        supervised, batch_size, num_values_per_factor)
    disentangled.visualize.fixed_factor_data(
        fixed_batches, rows=gin.REQUIRED, cols=gin.REQUIRED, verbose=verbose)
def main(_):
    """Configure TF and gin from flags, then train into FLAGS.root_dir."""
    logging.set_verbosity(logging.INFO)
    tf.compat.v1.enable_v2_behavior()
    tf.compat.v1.enable_resource_variables()
    gin.parse_config_files_and_bindings(
        FLAGS.gin_file, FLAGS.gin_bindings, skip_unknown=True)

    output_root = FLAGS.root_dir
    # Expose the root directory to gin configs as the %ROOT_DIR macro.
    with gin.unlock_config():
        gin.bind_parameter('%ROOT_DIR', output_root)
    trainer.train(output_root)
def test_c4_pretrain(self):
    """Smoke test: building C4 pretraining data streams must not raise."""
    _t5_gin_config()
    bindings = {
        'c4_bare_preprocess_fn.spm_path': _spm_path(),
        'batcher.batch_size_per_device': 8,
        'batcher.eval_batch_size': 8,
        'batcher.max_eval_length': 50,
        'batcher.buckets': ([51], [8, 1]),
    }
    for binding, value in bindings.items():
        gin.bind_parameter(binding, value)

    # Just make sure this doesn't throw.
    _ = tf_inputs.data_streams(
        'c4',
        data_dir=_TESTDATA,
        input_name='inputs',
        target_name='targets',
        bare_preprocess_fn=tf_inputs.c4_bare_preprocess_fn)
def test_inits_rewards_to_actions_serialized(self):
    """rewards_to_actions has one row per timestep, one column per action symbol."""
    precision = 2
    gin.bind_parameter('BoxSpaceSerializer.precision', precision)
    obs_size, n_timesteps, n_controls = 3, 6, 2

    observation_space = gym.spaces.Box(shape=(obs_size,), low=0, high=1)
    action_space = gym.spaces.MultiDiscrete(nvec=((2,) * n_controls))
    rewards_to_actions = ppo.init_rewards_to_actions(
        vocab_size=4,
        observation_space=observation_space,
        action_space=action_space,
        n_timesteps=n_timesteps,
    )

    # Each step serializes obs_size * precision observation symbols plus
    # n_controls action symbols.
    symbols_per_step = obs_size * precision + n_controls
    expected_shape = (n_timesteps, n_timesteps * symbols_per_step)
    self.assertEqual(rewards_to_actions.shape, expected_shape)
def _make_space_and_serializer(
    self,
    low=-10,
    high=10,
    shape=(2, ),
    # Weird vocab_size to test that it doesn't only work with powers of 2.
    vocab_size=257,
    # Enough precision to represent float32s accurately.
    precision=4,
):
    """Return a (Box space, serializer) pair with the given bounds and precision."""
    gin.bind_parameter("BoxSpaceSerializer.precision", precision)
    box = gym.spaces.Box(low=low, high=high, shape=shape)
    return box, space_serializer.create(box, vocab_size=vocab_size)
def visual(ctx, rows, cols, plot, filename, **kwargs):
    """Qualitative evaluation of output."""
    # The docstring above is kept verbatim: it presumably doubles as the CLI
    # help text.
    parse(ctx, set_seed=True)

    # Options the user may omit; None keeps the gin-configured value.
    optional_bindings = (
        ("disentangled.visualize.show.output.filename", filename),
        ("disentangled.visualize.reconstructed.rows", rows),
        ("disentangled.visualize.reconstructed.cols", cols),
    )
    with gin.unlock_config():
        gin.bind_parameter("disentangled.visualize.show.output.show_plot", plot)
        for binding, value in optional_bindings:
            if value is not None:
                gin.bind_parameter(binding, value)

    pipeline = ctx.obj["dataset"].pipeline()
    disentangled.visualize.reconstructed(
        ctx.obj["model"], pipeline, rows=gin.REQUIRED, cols=gin.REQUIRED)
def parse_gin_bindings(gin_bindings: Dict = None, ) -> None:
    """Parse any extra gin bindings to the config.

    Args:
      gin_bindings: extra bindings, passed through
        ``replace_human_redable_kwargs`` and then iterated as (key, value)
        pairs. None means there is nothing to bind.

    Each binding is attempted independently inside ``gin.unlock_config()``.
    A binding that raises is logged as ``IGNORED`` (with the reason at debug
    level) and the remaining bindings are still applied.
    """
    if gin_bindings is None:
        log.info('No additional gin bindings to parse')
        return
    log.info(f'Parsing additional bindings: {pformat(gin_bindings)}')
    with gin.unlock_config():
        for key, value in replace_human_redable_kwargs(gin_bindings):
            try:
                gin.bind_parameter(key, value)
            except Exception as exc:  # best-effort: skip bindings gin rejects
                _message = 'IGNORED'
                log.debug(f'Could not bind {key!r}: {exc!r}')
            else:
                _message = 'BOUND '
            log.info(f'{_message} - {key} : {value}')
def test_training_loop_onlinetune(self):
    """Runs the training loop on a minimal OnlineTuneEnv configuration."""
    with self.tmp_dir() as output_dir:
        # Smallest possible model and random 1x1 inputs keep the run cheap.
        model_fn = functools.partial(
            models.MLP, n_hidden_layers=0, n_output_classes=1)
        inputs_fn = functools.partial(
            trax_inputs.random_inputs,
            input_shape=(1, 1),
            input_dtype=np.float32,
            output_shape=(1, 1),
            output_dtype=np.float32)
        env_bindings = [
            ("OnlineTuneEnv.model", model_fn),
            ("OnlineTuneEnv.inputs", inputs_fn),
            ("OnlineTuneEnv.train_steps", 2),
            ("OnlineTuneEnv.eval_steps", 2),
            ("OnlineTuneEnv.output_dir", os.path.join(output_dir, "envs")),
        ]
        for binding, value in env_bindings:
            gin.bind_parameter(binding, value)
        self._run_training_loop("OnlineTuneEnv-v0", output_dir)
def test_finalize(self):
    # Checks that the lock flag set by finalize() is tracked per GinState:
    # locking one state is expected not to affect the other.
    gin.bind_parameter('f.x', 'global')
    gin.finalize()
    self.assertTrue(gin.config_is_locked())
    # Inside a fresh GinState the test expects an unlocked config in which
    # 'f.x' can be rebound independently.
    with GinState() as temp_state:
        gin.bind_parameter('f.x', 'temp')
        self.assertEqual(gin.query_parameter('f.x'), 'temp')
        self.assertFalse(gin.config_is_locked())
    # Leaving the context is expected to restore the finalized outer config.
    self.assertTrue(gin.config_is_locked())
    # Re-entering temp_state: still unlocked until finalized here.
    with temp_state:
        self.assertFalse(gin.config_is_locked())
        gin.config.finalize()
        self.assertTrue(gin.config_is_locked())
    # The lock applied inside temp_state is expected to persist on re-entry.
    with temp_state:
        self.assertTrue(gin.config_is_locked())
def test_significance_map(self):
    """Significance indices restart per serialized observation/action."""
    gin.bind_parameter('BoxSpaceSerializer.precision', 3)
    obs_serializer = space_serializer.create(
        gym.spaces.Box(low=0, high=1, shape=(2, )), vocab_size=2)
    act_serializer = space_serializer.create(
        gym.spaces.MultiDiscrete(nvec=[2, 2]), vocab_size=2)
    significance_map = serialization_utils.significance_map(
        observation_serializer=obs_serializer,
        action_serializer=act_serializer,
        representation_length=20,
    )
    # obs1, act1, obs2, act2, obs3 cut after 4th symbol.
    expected = [0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0, 1, 2, 0, 0, 0, 1, 2, 0]
    np.testing.assert_array_equal(significance_map, expected)
def setUp(self):
    """Sets up agent hyper-parameters, a zero state and agent/replay gin bindings."""
    super(FullRainbowAgentTest, self).setUp()
    self._num_actions = 4
    self._num_atoms = 5
    self._vmax = 7.
    self.observation_shape = dqn_agent.NATURE_DQN_OBSERVATION_SHAPE
    self.observation_dtype = dqn_agent.NATURE_DQN_DTYPE
    self.stack_size = dqn_agent.NATURE_DQN_STACK_SIZE
    # A single stacked observation filled with zeros.
    zero_state_shape = (1,) + self.observation_shape + (self.stack_size,)
    self.zero_state = onp.zeros(zero_state_shape)
    test_bindings = [
        ('OutOfGraphPrioritizedReplayBuffer.replay_capacity', 100),
        ('OutOfGraphPrioritizedReplayBuffer.batch_size', 2),
        ('JaxDQNAgent.min_replay_history', 32),
        ('JaxDQNAgent.epsilon_eval', 0.0),
        ('JaxDQNAgent.epsilon_decay_period', 90),
    ]
    for binding, value in test_bindings:
        gin.bind_parameter(binding, value)
def setUp(self):
    """Sets up the quantile dir binding and the time-step/action specs."""
    quantile_dir = os.path.join(constant.BASE_DIR, 'testdata')
    gin.bind_parameter(
        'get_observation_processing_layer_creator.quantile_file_dir',
        quantile_dir)
    observation_spec = tf.TensorSpec(
        dtype=tf.int64, shape=(), name='callee_users')
    self._time_step_spec = time_step.time_step_spec(observation_spec)
    # Binary decision: 0 or 1.
    self._action_spec = tensor_spec.BoundedTensorSpec(
        dtype=tf.int64,
        shape=(),
        minimum=0,
        maximum=1,
        name='inlining_decision')
    super(AgentCreatorsTest, self).setUp()
def sweep(parameters):
    """Bind every point of a parameter grid in turn and yield its name.

    Args:
      parameters: mapping from gin parameter name to an iterable of candidate
        values. Names are processed in sorted order.

    Yields:
      For each combination of the Cartesian product, after all its values
      have been bound with ``gin.bind_parameter``, a string of
      ``"<last dotted component>:<value>"`` pieces joined by underscores.
    """
    names = sorted(parameters)
    candidate_lists = [parameters[name] for name in names]
    current = {}
    for combo in itertools.product(*candidate_lists):
        for name, value in zip(names, combo):
            gin.bind_parameter(name, value)
            current[name] = value
        yield '_'.join(
            f"{name.split('.')[-1]}:{value}" for name, value in current.items())
def testNeuTraExperiment(self):
    """Smoke test: run each NeuTra experiment phase once in a TF session."""
    gin.clear_config()
    gin.bind_parameter("target_spec.name", "easy_gaussian")
    gin.bind_parameter("target_spec.num_dims", 2)
    experiment = neutra.NeuTraExperiment(
        train_batch_size=2,
        test_chain_batch_size=2,
        bijector="affine",
        log_dir=self.temp_dir)
    with tf.Session() as session:
        experiment.Initialize(session)
        experiment.TrainBijector(session, 1)
        experiment.Eval(session)
        experiment.Benchmark(session)
        experiment.Tune(session, method="random", max_num_trials=1)
def testSingleTrainingStepDiscItersWithEma(self, disc_iters):
    """One training step with generator EMA enabled writes EMA variables.

    After a single step, the checkpoint must contain exactly one
    ``ExponentialMovingAverage`` shadow variable per generator weight.
    """
    parameters = {
        "architecture": c.RESNET_CIFAR,
        "lambda": 1,
        "z_dim": 128,
        # Was misspelled "dics_iters", so the parameterized value was never
        # seen by the GAN. NOTE(review): assumes ModularGAN reads "disc_iters"
        # from `parameters` — confirm against the GAN implementation.
        "disc_iters": disc_iters,
    }
    gin.bind_parameter("ModularGAN.g_use_ema", True)
    dataset = datasets.get_dataset("cifar10")
    gan = ModularGAN(dataset=dataset, parameters=parameters,
                     model_dir=self.model_dir)
    estimator = gan.as_estimator(self.run_config, batch_size=2, use_tpu=False)
    estimator.train(gan.input_fn, steps=1)
    # Check for moving average variables in checkpoint.
    checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
    ema_vars = sorted([
        v[0] for v in tf.train.list_variables(checkpoint_path)
        if v[0].endswith("ExponentialMovingAverage")
    ])
    tf.logging.info("ema_vars=%s", ema_vars)
    expected_ema_vars = sorted([
        "generator/fc_noise/kernel/ExponentialMovingAverage",
        "generator/fc_noise/bias/ExponentialMovingAverage",
        "generator/B1/up_conv_shortcut/kernel/ExponentialMovingAverage",
        "generator/B1/up_conv_shortcut/bias/ExponentialMovingAverage",
        "generator/B1/up_conv1/kernel/ExponentialMovingAverage",
        "generator/B1/up_conv1/bias/ExponentialMovingAverage",
        "generator/B1/same_conv2/kernel/ExponentialMovingAverage",
        "generator/B1/same_conv2/bias/ExponentialMovingAverage",
        "generator/B2/up_conv_shortcut/kernel/ExponentialMovingAverage",
        "generator/B2/up_conv_shortcut/bias/ExponentialMovingAverage",
        "generator/B2/up_conv1/kernel/ExponentialMovingAverage",
        "generator/B2/up_conv1/bias/ExponentialMovingAverage",
        "generator/B2/same_conv2/kernel/ExponentialMovingAverage",
        "generator/B2/same_conv2/bias/ExponentialMovingAverage",
        "generator/B3/up_conv_shortcut/kernel/ExponentialMovingAverage",
        "generator/B3/up_conv_shortcut/bias/ExponentialMovingAverage",
        "generator/B3/up_conv1/kernel/ExponentialMovingAverage",
        "generator/B3/up_conv1/bias/ExponentialMovingAverage",
        "generator/B3/same_conv2/kernel/ExponentialMovingAverage",
        "generator/B3/same_conv2/bias/ExponentialMovingAverage",
        "generator/final_conv/kernel/ExponentialMovingAverage",
        "generator/final_conv/bias/ExponentialMovingAverage",
    ])
    self.assertAllEqual(ema_vars, expected_ema_vars)
def setUp(self):
    """Creates two videos and two raters who rate them in opposite directions.

    Only the second rater's email is verified.
    """
    load_gin_config("backend/ml_model/config/featureless_config.gin")
    gin.bind_parameter(
        "FeaturelessMedianPreferenceAverageRegularizationAggregator.epochs",
        1000)

    # Videos video0 and video1.
    self.videos = [
        Video.objects.create(video_id=f"video{i}") for i in tqdm(range(2))
    ]

    # Raters rater0 and rater1 with their preference and info records.
    self.djangousers = [
        DjangoUser.objects.create_user(username=f"rater{i}",
                                       password=f"1234{i}")
        for i in tqdm(range(2))
    ]
    self.userprefs = [
        UserPreferences.objects.create(user=user) for user in self.djangousers
    ]
    self.userinfos = [
        UserInformation.objects.create(user=user) for user in self.djangousers
    ]

    # Email verification flags, one per rater (rater1 is verified).
    self.verify = [False, True]
    accepted_domain = create_accepted_domain()
    self.vemails = [
        VerifiableEmail.objects.create(
            user=info,
            email=f"{uuid1()}{accepted_domain}",
            is_verified=is_verified)
        for info, is_verified in zip(self.userinfos, self.verify)
    ]

    # All features except the first get a neutral score of 50.
    data_rest = dict.fromkeys(VIDEO_FIELDS[1:], 50)
    self.f = VIDEO_FIELDS[0]
    # rater0 likes video0 (score 0), rater1 likes video1 (score 100).
    for prefs, score in zip(self.userprefs[:2], (0, 100)):
        ExpertRating.objects.create(user=prefs,
                                    video_1=self.videos[0],
                                    video_2=self.videos[1],
                                    **data_rest,
                                    **{self.f: score})
def test_init_from_checkpoint_global_step(self):
    """Warm-starting from a checkpoint also restores its global step."""
    gin.bind_parameter('tf.estimator.RunConfig.save_checkpoints_steps', 100)
    gin.bind_parameter('tf.estimator.RunConfig.keep_checkpoint_max', 10)

    def run_training(t2r_model, model_dir, max_train_steps):
        # Fresh mock input generator per run, built after the model.
        train_eval.train_eval_model(
            t2r_model=t2r_model,
            input_generator_train=mocks.MockInputGenerator(
                batch_size=_BATCH_SIZE),
            max_train_steps=max_train_steps,
            model_dir=model_dir,
            eval_steps=_EVAL_STEPS,
            eval_throttle_secs=_EVAL_THROTTLE_SECS,
            create_exporters_fn=train_eval.create_default_exporters)

    def checkpoint_metas(directory):
        return tf.io.gfile.glob(os.path.join(directory, 'model*.meta'))

    model_dir = self.create_tempdir().full_path
    run_training(
        mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor),
        model_dir, _MAX_TRAIN_STEPS)
    # A checkpoint is saved every 100 steps and at most 10 are kept, so 10
    # remain (presumably _MAX_TRAIN_STEPS is 1000).
    self.assertLen(checkpoint_metas(model_dir), 10)

    # The continued training run writes to its own directory.
    continue_model_dir = self.create_tempdir().full_path
    init_from_checkpoint_fn = functools.partial(
        abstract_model.default_init_from_checkpoint_fn, checkpoint=model_dir)
    run_training(
        mocks.MockT2RModel(
            preprocessor_cls=noop_preprocessor.NoOpPreprocessor,
            init_from_checkpoint_fn=init_from_checkpoint_fn),
        continue_model_dir, _MAX_TRAIN_STEPS + 100)
    # If the global step was restored, only 100 more steps run: one new
    # checkpoint in addition to the initial one -> 2 in total.
    self.assertLen(checkpoint_metas(continue_model_dir), 2)
def _parse_config_item(key, value): if isinstance(value, dict): for k, v in value.items(): _parse_config_item(k, v) return elif isinstance(value, (list, tuple)): for k, v in value: _parse_config_item(k, v) elif value is None: return else: assert (key is not None) # if isinstance(value, six.string_types): # gin.bind_parameter(key, '"{}"'.format(value)) # else: gin.bind_parameter(key, value)