def testCurriculumDataSourceTransitionsCorrectlyWithMixingDataSource(self):
  sources = [
      datasource.WithinBatchMixingDataSource.Params().Set(
          file_patterns=['file1', 'file2'], weights=[1, 5]),
      datasource.WithinBatchMixingDataSource.Params().Set(
          file_patterns=['file3', 'file4'], weights=[2, 3])
  ]
  boundary = 5
  ds_params = datasource.CurriculumDataSource.Params().Set(
      datasource_params=sources, boundaries=[boundary])
  ds = ds_params.Instantiate()
  ret = ds.BuildDataSource(_MockDataSourceFromFilePattern)

  with tf.Session() as sess:
    # Advance the global step to the next curriculum stage.
    global_step = py_utils.GetOrCreateGlobalStepVar()
    tf.global_variables_initializer().run()
    set_global_step = tf.assign(global_step, boundary, name='advance_step')
    sess.run(set_global_step)
    ret.data = sess.run([ret.data])

  self.assertCountEqual(
      sorted(ret.keys()), ['bprop_variable_filters', 'data'])
  self.assertAllEqual(ret.data, [[b'file3,file4']])
  self.assertCountEqual(ret.bprop_variable_filters, [''])

def _MaybeConstructSharedModel(self, train_cfg):
  """Construct a single shared copy of the model if this is a MultiTaskModel.

  If the share_model_object parameter is set, for MultiTaskModels,
  we create a MultiTaskSubModel for each task, but construct the model only
  once.

  Args:
    train_cfg: The params for a SingleTaskModel or MultiTaskModel.

  Returns:
    A MultiTaskModel, if train_cfg is a MultiTaskModel params object.
  """
  if not issubclass(train_cfg.cls, base_model.MultiTaskModel):
    return None
  if not train_cfg.share_model_object:
    return None

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(
        self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
        else self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        shared_model = train_cfg.Instantiate()
        shared_model.InstantiateVariables()
  return shared_model

def testSharedEncBiasWeights(self):
  model_dim = 4
  key_value_dim = 2
  num_heads = 2
  g = tf.Graph()
  with g.as_default(), self.SetEval(True):
    _ = py_utils.GetOrCreateGlobalStepVar()  # for DeterministicDropout
    builder = FakeMoEBuilder.Params().Set(
        num_devices=FLAGS.num_partitions,
        dropout_rate=0,
        model_dim=model_dim,
        attention_key_value_dim=key_value_dim,
        attention_num_heads=num_heads)
    builder = builder.Instantiate()
    p = builder._Seq('model', builder.FakeLayer('layer0'),
                     builder.FakeLayer('layer1'))
    layer = p.Instantiate()
    all_vars = tf.trainable_variables()
    tf.logging.info(all_vars)
    self.assertEqual(1, len(all_vars))
  with tf.Session(graph=g) as sess, self.SetEval(True):
    x = tf.ones([model_dim])
    y = layer.FPropDefaultTheta(x)
    sess.run(tf.global_variables_initializer())
    y_val = sess.run(y)
    self.assertAllEqual([3.] * model_dim, y_val)

def testCurriculumDataSourceTransitionsCorrectlyWithMixingDataSource(self):
  sources = [
      datasource.SimpleDataSource.Params().Set(
          file_pattern=['file1', 'file2'], weights=[1, 5]),
      datasource.SimpleDataSource.Params().Set(
          file_pattern=['file3', 'file4'], weights=[2, 3])
  ]
  boundary = 5
  ds_params = datasource.CurriculumDataSource.Params().Set(
      sub=sources, boundaries=[boundary])
  ds = ds_params.Instantiate()
  ds.SetInputGenerator(TestInputGenerator.Params().Instantiate())

  with self.session():
    # Advance the global step to the next curriculum stage.
    global_step = py_utils.GetOrCreateGlobalStepVar()
    self.evaluate(tf.global_variables_initializer())
    set_global_step = tf.assign(global_step, boundary, name='advance_step')
    self.evaluate(set_global_step)
    batch = ds.GetNext()
    ret = ds.GetMeta()
    ret.data = self.evaluate(batch.data)

  self.assertCountEqual(
      sorted(ret.keys()), ['bprop_variable_filters', 'data'])
  self.assertEqual(ret.data, b'file3,file4')
  self.assertCountEqual(ret.bprop_variable_filters, [''])

def _buildGraphAndSaver(logdir,
                        keep_latest_n=5,
                        keep_every_n_hours=None,
                        save_async=False):
  tf.random.set_seed(123)
  g = tf.Graph()
  with g.as_default():
    p = mnist.LeNet5().Task()
    p.input = mnist.LeNet5().Train()
    with cluster_factory.ForTestingWorker(mode='sync', job='controller'):
      _ = p.Instantiate()
    gsv = py_utils.GetOrCreateGlobalStepVar()
    inc = gsv.assign_add(1)
    variables = tf.all_variables()
    sanity_checks = [([gsv], saver.InRange(0, 10))]
    for var in variables:
      sanity_checks.append(([var], saver.IsFinite()))
    sav = saver.Saver(
        logdir,
        variables,
        sanity_checks,
        keep_latest_n=keep_latest_n,
        keep_every_n_hours=keep_every_n_hours,
        async_save=save_async)
  return g, sav, inc

def _testLayerHelper(self,
                     test_case,
                     p,
                     expected=None,
                     not_expected=None,
                     global_step=-1):
  tf.random.set_seed(398847392)
  np.random.seed(12345)
  p.name = 'proj'
  p.input_dim = 3
  p.output_dim = 4
  p.params_init = py_utils.WeightInit.Gaussian(0.1)
  l = p.Instantiate()
  in_padding = tf.constant([[[0], [0], [1], [0]], [[1], [1], [0], [0]]],
                           dtype=tf.float32)
  inputs = tf.constant(
      np.random.normal(0.1, 0.5, [2, 4, 3]), dtype=tf.float32)
  output = l.FPropDefaultTheta(inputs, in_padding)
  self.evaluate(tf.global_variables_initializer())

  if global_step >= 0:
    self.evaluate(
        tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))

  output = output.eval()
  print('QuantizableLayerTest output', test_case, ':\n',
        np.array_repr(output))
  if expected is not None:
    self.assertAllClose(output, expected)
  if not_expected is not None:
    self.assertNotAllClose(output, not_expected)
  return l

def setUp(self):
  super().setUp()
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()
  cluster = cluster_factory.SetRequireSequentialInputOrder(True)
  cluster.params.in_unit_test = True
  cluster.__enter__()

def testInitRulesDirectory(self):
  train_dir = os.path.join(self.get_temp_dir(), 'testInitRulesDirectory')
  os.mkdir(train_dir)
  p = base_model.SingleTaskModel.Params(LinearModel.Params())
  p.input = base_input_generator.BaseInputGenerator.Params()
  b1 = 1234
  g1 = 10
  b2 = 12345
  g2 = 100
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    self.evaluate(tf.global_variables_initializer())
    saver = checkpointer.Checkpointer(train_dir, model)
    self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), g1))
    self.evaluate(tf.assign(model.GetTask().vars.b, b1))
    saver.Save(sess, model.global_step)
    self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), g2))
    self.evaluate(tf.assign(model.GetTask().vars.b, b2))
    saver.Save(sess, model.global_step)

  train_dir_2 = os.path.join(self.get_temp_dir(),
                             'testInitRulesDirectory_2')

  # Set init_checkpoint_rules to only restore b from a specific checkpoint:
  # the first one, not the latest one.
  rules = [('(.*)', '%s')]
  spec_dir = os.path.join(train_dir, 'ckpt-00000010')
  p.train.init_from_checkpoint_rules = {spec_dir: (rules, [])}
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    saver = checkpointer.Checkpointer(train_dir_2, model)
    saver.RestoreIfNeeded(sess)
    new_b = self.evaluate(model.GetTask().vars.b)
    self.assertEqual(b1, new_b)

  # Set init_checkpoint_rules to restore all from the latest checkpoint
  # by specifying just the original train directory, not a specific
  # checkpoint.
  rules = [('(.*)', '%s')]
  p.train.init_from_checkpoint_rules = {train_dir: (rules, [])}
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    saver = checkpointer.Checkpointer(train_dir_2, model)
    saver.RestoreIfNeeded(sess)
    new_b = self.evaluate(model.GetTask().vars.b)
    self.assertEqual(b2, new_b)

def BuildDataSource(self, data_source_from_file_pattern_fn):
  """Reads and returns the input batch for the current curriculum stage.

  Args:
    data_source_from_file_pattern_fn: a function to read and return an input
      batch from a string file_pattern.

  Returns:
    A NestedMap containing:
      data: a tuple of tf.Tensor or `.NestedMap` of tf.Tensor.

  Raises:
    ValueError: inconsistent sizes between boundaries and datasource_params,
      specification of unsupported datasources, or out-of-order boundaries.
  """
  p = self.params

  if len(p.datasource_params) != len(p.boundaries) + 1:
    raise ValueError(
        'Expected p.datasource_params to have one more entry than '
        'p.boundaries. Found %d datasource_params, and %d boundaries' %
        (len(p.datasource_params), len(p.boundaries)))

  for ds_p in p.datasource_params:
    if 'bprop_variable_filters' in ds_p:
      if any(filter for filter in ds_p.bprop_variable_filters):
        raise ValueError('CurriculumDataSource does not support distinct '
                         'bprop_variable_filters per stage.')

  for idx in range(len(p.boundaries) - 1):
    if p.boundaries[idx] > p.boundaries[idx + 1]:
      raise ValueError('Expected p.boundaries to monotonically increase, '
                       'but found %d > %d at position %d' %
                       (p.boundaries[idx], p.boundaries[idx + 1], idx))

  global_step = py_utils.GetOrCreateGlobalStepVar()
  datasources = [ds_p.Instantiate() for ds_p in p.datasource_params]

  def GetDatasourceFn(idx):

    def DatasourceFn():
      datasource = datasources[idx].BuildDataSource(
          data_source_from_file_pattern_fn)
      datasource.pop('bprop_variable_filters', None)
      return datasource

    return DatasourceFn

  cases = []
  for idx in range(len(p.boundaries)):
    cases.append(
        (tf.less(global_step,
                 tf.constant(p.boundaries[idx], dtype=global_step.dtype)),
         GetDatasourceFn(idx)))

  ret = tf.case(cases, default=GetDatasourceFn(-1))
  ret.bprop_variable_filters = p.bprop_variable_filters
  return ret

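# A minimal, standalone sketch of the stage-selection pattern implemented by
# BuildDataSource above: tf.case checks the (predicate, branch) pairs in
# order, the first predicate `global_step < boundary` that holds wins, and
# the default branch serves the final stage. Everything below (including
# CurriculumStageSketch and MakeStageFn) is illustrative and not part of the
# library; it assumes TF1-style graph mode.
import tensorflow.compat.v1 as tf


def CurriculumStageSketch(boundaries=(5, 10)):
  """Returns a scalar tensor holding the current curriculum stage index."""
  global_step = tf.train.get_or_create_global_step()

  def MakeStageFn(stage_id):
    # Each tf.case branch must be a zero-arg callable returning tensors.
    return lambda: tf.constant(stage_id)

  cases = [(tf.less(global_step, tf.constant(b, dtype=global_step.dtype)),
            MakeStageFn(i)) for i, b in enumerate(boundaries)]
  # Stage 0 while step < 5, stage 1 while step < 10, stage 2 afterwards.
  return tf.case(cases, default=MakeStageFn(len(boundaries)))
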
def _create_session(self, *args, **kwargs):
  sess = super()._create_session(*args, **kwargs)
  with sess.graph.as_default():
    # Ensure the global_step variable is created in every new session.
    global_step = py_utils.GetOrCreateGlobalStepVar()
    sess.run(
        tf.cond(
            tf.is_variable_initialized(global_step), tf.no_op,
            lambda: tf.variables_initializer([global_step])))
  return sess

def testSingleCheckpoint(self):
  logdir = tempfile.mkdtemp()
  g = tf.Graph()
  with g.as_default():
    _ = py_utils.GetOrCreateGlobalStepVar()
    sav = saver.Saver(logdir, tf.all_variables(), [], keep_latest_n=1)
  with self.session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    _ = sav.Save(sess)

def setUp(self):
  super().setUp()
  with contextlib.ExitStack() as stack:
    stack.enter_context(py_utils.VariableStore())
    self.addCleanup(stack.pop_all().close)
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()
  cluster = cluster_factory.SetRequireSequentialInputOrder(True)
  cluster.params.in_unit_test = True
  cluster.__enter__()

def testSaveRestore(self, use_custom_saver):
  FLAGS.use_custom_saver = use_custom_saver
  train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
  os.mkdir(train_dir)
  p = base_model.SingleTaskModel.Params(LinearModel.Params())
  p.input = base_input_generator.BaseInputGenerator.Params()
  final_global_step = 10
  expected_w = [0.38615, 2.975221, -0.852826]
  initial_b = 1.418741
  final_b = 1234
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    self.evaluate(tf.global_variables_initializer())
    w, b = self.evaluate([model.GetTask().vars.w, model.GetTask().vars.b])
    self.assertAllClose(expected_w, w)
    self.assertAlmostEqual(initial_b, b, places=5)
    saver = checkpointer.Checkpointer(train_dir, model)
    self.evaluate(
        tf.assign(py_utils.GetOrCreateGlobalStepVar(), final_global_step))
    self.evaluate(tf.assign(model.GetTask().vars.b, final_b))
    saver.Save(sess, model.global_step)
    w, b = self.evaluate([model.GetTask().vars.w, model.GetTask().vars.b])
    self.assertAllClose(expected_w, w)
    self.assertEqual(final_b, b)

  self.assertTrue(
      os.path.isfile(
          os.path.join(train_dir, 'ckpt-%08d.index' % final_global_step)))

  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    saver = checkpointer.Checkpointer(train_dir, model)
    saver.RestoreIfNeeded(sess)
    w, b, global_step = self.evaluate(
        [model.GetTask().vars.w, model.GetTask().vars.b, model.global_step])
    self.assertAllClose(expected_w, w)
    self.assertEqual(final_b, b)
    self.assertEqual(final_global_step, global_step)
    # Restoring from the checkpoint always works, even though the variables
    # are already initialized.
    saver.Restore(sess)

def testFakeQuantizationScheduleFromDefun(self):
  p = quant_utils.FakeQuantizationSchedule.Params()
  p.clip_start_step = 5
  p.clip_end_step = 10
  p.quant_start_step = 15
  p.start_cap = 6.0
  p.end_cap = 1.0
  with self.session():
    cc_schedule = p.Instantiate()
    self.evaluate(tf.global_variables_initializer())
    # Move to the fully quantized part of the schedule.
    self.evaluate(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 16))

    @tf.Defun(tf.float32, tf.float32)
    def ExampleFunction8(x, cc_state):
      return cc_schedule.ApplyClippingWithState(cc_state, x, bits=8)

    @tf.Defun(tf.float32, tf.float32)
    def ExampleFunction16(x, cc_state):
      return cc_schedule.ApplyClippingWithState(cc_state, x, bits=16)

    a = tf.constant(1.0)
    b = tf.constant(0.5)

    # 8bit value.
    v = ExampleFunction8(a * b, cc_schedule.GetState(cc_schedule.theta))
    self.assertAllClose(v.eval(), 0.5)

    # 16bit value.
    v = ExampleFunction16(a * b, cc_schedule.GetState(cc_schedule.theta))
    self.assertAllClose(v.eval(), 0.5)

    # An incomplete implementation requires special-case gradient logic.
    # This tests it, specifically in a Defun, which caused issues.
    # 8bit gradient.
    g = tf.gradients(
        ExampleFunction8(a * b, cc_schedule.GetState(cc_schedule.theta)),
        [a, b])
    g = [t.eval() for t in g]
    print('Gradient8:', g)
    self.assertAllClose(g, (0.5, 1.0))

    # 16bit gradient.
    g = tf.gradients(
        ExampleFunction16(a * b, cc_schedule.GetState(cc_schedule.theta)),
        [a, b])
    g = [t.eval() for t in g]
    print('Gradient16:', g)
    self.assertAllClose(g, (0.5, 1.0))

def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  if self._num_micro_batches == 1:
    return self._opt.apply_gradients(grads_and_vars, global_step)
  global_step = global_step or py_utils.GetOrCreateGlobalStepVar()
  with tf.init_scope():
    self._create_slots([v for (_, v) in grads_and_vars])

  accums = []
  variables = []
  for g, v in grads_and_vars:
    accum = self.get_slot(v, 'grad_accum')
    variables.append(v)
    # pytype: disable=attribute-error
    if isinstance(g, tf.IndexedSlices):
      scaled_grad = tf.IndexedSlices(
          g.values / self._num_micro_batches,
          g.indices,
          dense_shape=g.dense_shape)
    else:
      scaled_grad = g / self._num_micro_batches
    accum_tensor = accum.read_value()
    accums.append(accum.assign(accum_tensor + scaled_grad))
    # pytype: enable=attribute-error

  def _ApplyAndReset():
    normalized_accums = accums
    if self._apply_crs_to_grad:
      normalized_accums = [
          tf.tpu.cross_replica_sum(accum.read_value()) for accum in accums
      ]
    apply_op = self._opt.apply_gradients(
        list(zip(normalized_accums, variables)))
    with tf.control_dependencies([apply_op]):
      zero_op = [tf.assign(accum, tf.zeros_like(accum)) for accum in accums]
    return tf.group(zero_op, tf.assign_add(global_step, 1))

  def _Accum():
    return tf.no_op()

  accum_step = tf.cond(
      tf.equal(
          tf.math.floormod(self._counter + 1, self._num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      _Accum)  # Accumulate gradients.

  with tf.control_dependencies([tf.group(accums)]):
    return tf.group(accum_step, tf.assign_add(self._counter, 1))

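# For intuition: apply_gradients above implements the standard
# accumulate-then-apply pattern. Each call adds grad / N to a slot variable,
# and every N-th call applies the averaged gradient once and zeros the slot.
# A minimal sketch of the same arithmetic with a single variable and plain
# SGD follows; it is illustrative only, and all names below are made up.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

num_micro_batches = 2
lr = 0.1
w = tf.get_variable('w', initializer=1.0)
grad = tf.placeholder(tf.float32, shape=[])
accum = tf.get_variable('w_grad_accum', initializer=0.0, trainable=False)
counter = tf.get_variable('counter', initializer=0, trainable=False)

# Accumulate grad / N so the slot holds the running mean after N calls.
accum_op = accum.assign_add(grad / num_micro_batches)


def _ApplyAndReset():
  # One SGD step with the averaged gradient, then clear the accumulator.
  with tf.control_dependencies([w.assign_sub(lr * accum.read_value())]):
    return tf.group(accum.assign(0.0))


with tf.control_dependencies([accum_op]):
  accum_step = tf.cond(
      tf.equal(tf.math.floormod(counter + 1, num_micro_batches), 0),
      _ApplyAndReset,  # Apply the accumulated gradients and reset.
      tf.no_op)  # Otherwise just accumulate.

with tf.control_dependencies([accum_step]):
  step_op = counter.assign_add(1)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for g in (2.0, 4.0):  # two micro-batch gradients whose mean is 3.0
    sess.run(step_op, feed_dict={grad: g})
  print(sess.run(w))  # 1.0 - 0.1 * 3.0 = 0.7
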
def testUniTransformerParallelFPropEntmax(self):
  length_dim = 4
  graph = tf.Graph()
  params = gshard_builder.UniTransformer.Params().Set(
      gated_gelu=False,
      gated_ffn_activation=tf.nn.relu,
      positional_embedding=False,
      dtype=tf.float32,
      name='transformer',
      parallel_ffn=True,
      hidden_dim_reshape_segments=2,
      conv_kernel_size=2,
      builder=gshard_builder.RecurrentDenseBuilderParallelDecode.Params()
      .Set(
          device_mesh_shape=[1, 1],
          device_mesh=None,
          relative_attention_num_buckets=32,
          relative_attention_type='bias',
          relative_attention_max_distance=128,
          dtype=tf.float32,
          num_devices=1,  # we call .Split num_devices on axis 0 (batch)
          relative_attention_use_universal_1d_position=True,
          model_dim=32,
          model_dim_reshape_segments=2,
          attention_num_memory_heads=1,
          proj_weight_hdim=2,
          attention_num_heads=8,
          ff_dim=128,
          attention_key_value_dim=8,
          attention_combine_dims=True),
      batch_size=32,
      sequence_length=length_dim,
      num_transformer_layers=2,
      aux_loss_coef=0.0,
      loss_denominator=None,
      label_smoothing=0,
      vocab_size=128,
      max_length=length_dim,
      use_entmax=True)
  with graph.as_default():
    py_utils.GetOrCreateGlobalStepVar()
    params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0)
    tf.random.set_seed(24332)
    model = params.Instantiate()
  with tf.Session(graph=graph) as sess:
    input_batch = self._PreLoadInput()
    loss = model.FPropDefaultTheta(input_batch)[0]['loss'][0]
    sess.run(tf.global_variables_initializer())
    loss_eval = sess.run(loss)
    test_utils.CompareToGoldenSingleFloat(self, 5.146667, loss_eval)

def _testUniTransformerFProp(self, use_moe=False):
  length_dim = 4
  graph = tf.Graph()
  params = gshard_builder.UniTransformer.Params().Set(
      gated_gelu=False,
      moe=use_moe,
      moe_gated_gelu=use_moe,
      positional_embedding=False,
      dtype=tf.float32,
      name='transformer',
      builder=gshard_builder.DenseBuilder.Params().Set(
          device_mesh_shape=[1, 1],
          device_mesh=None,
          relative_attention_num_buckets=32,
          relative_attention_type='bias',
          relative_attention_max_distance=128,
          dtype=tf.float32,
          num_devices=1,  # we call .Split num_devices on axis 0 (batch)
          relative_attention_use_universal_1d_position=True,
          e_dim=2 if use_moe else None,
          num_groups=1 if use_moe else None,
          c_dim=2 if use_moe else None,
          model_dim=32,
          attention_num_heads=8,
          moe_hidden_dim=128,
          ff_dim=128,
          attention_key_value_dim=8,
          attention_combine_dims=True),
      batch_size=32,
      sequence_length=length_dim,
      num_transformer_layers=2,
      aux_loss_coef=0.0,
      loss_denominator=None,
      label_smoothing=0,
      vocab_size=128,
      max_length=length_dim)
  with graph.as_default():
    py_utils.GetOrCreateGlobalStepVar()
    params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0)
    tf.random.set_seed(24332)
    model = params.Instantiate()
  with tf.Session(graph=graph) as sess:
    input_batch = self._PreLoadInput()
    loss = model.FPropDefaultTheta(input_batch)[0]['loss'][0]
    sess.run(tf.global_variables_initializer())
    loss_eval = sess.run(loss)
    golden_float = 5.761248 if use_moe else 5.635831
    test_utils.CompareToGoldenSingleFloat(self, golden_float, loss_eval)

def __init__(self, params):
  """Initializes this Model."""
  assert issubclass(params.cls, BaseModel)
  self._global_step_var = py_utils.GetOrCreateGlobalStepVar()
  self._global_step = tf.identity(
      self._global_step_var, name='global_step_tensor')
  super(BaseModel, self).__init__(params)
  self._ema = None
  tp = self.params.train
  tf.logging.info('Training parameters for %s: %s', params.cls, tp)
  if tp.ema_decay > 0:
    assert tp.ema_decay < 1.0
    self._ema = tf.train.ExponentialMovingAverage(
        decay=tp.ema_decay, num_updates=self.global_step)

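# The EMA object created above maintains shadow copies of variables once it
# is applied. Passing num_updates=global_step makes the effective decay
# min(decay, (1 + step) / (10 + step)), so averaging warms up quickly early
# in training. A minimal sketch of the underlying
# tf.train.ExponentialMovingAverage mechanics follows; variable names are
# illustrative.
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

global_step = tf.train.get_or_create_global_step()
v = tf.get_variable('v', initializer=0.0)
ema = tf.train.ExponentialMovingAverage(decay=0.999, num_updates=global_step)

update_v = v.assign_add(1.0)
ema_op = ema.apply([v])  # creates the shadow variable for v

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(update_v)
  sess.run(ema_op)
  # At step 0 the effective decay is min(0.999, 1/10) = 0.1, so the shadow
  # becomes 0.1 * 0.0 + 0.9 * 1.0 = 0.9 after one update.
  print(sess.run([v, ema.average(v)]))  # [1.0, 0.9]
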
def testRandomChoicePreprocessor(self):
  p = input_preprocessors.RandomChoicePreprocessor.Params()
  # Construct 4 preprocessors, each producing a different value.
  base = input_preprocessors.ConstantPreprocessor.Params()
  c1 = (base.Copy().Set(constants={'value': 1}),
        schedule.Constant.Params().Set(value=1))
  c2 = (base.Copy().Set(constants={'value': 2}),
        schedule.Constant.Params().Set(value=2))
  c3 = (base.Copy().Set(constants={'value': 3}),
        schedule.Constant.Params().Set(value=3))
  c4 = (base.Copy().Set(constants={'value': 4}),
        schedule.Constant.Params().Set(value=4))
  p.subprocessors = [c1, c2, c3, c4]

  # Create the global step because schedules depend on it.
  _ = py_utils.GetOrCreateGlobalStepVar()
  preprocessor = p.Instantiate()

  features = py_utils.NestedMap()
  shapes = py_utils.NestedMap()
  dtypes = py_utils.NestedMap()

  # Verify shapes / dtypes.
  new_shapes = preprocessor.TransformShapes(shapes)
  new_dtypes = preprocessor.TransformDTypes(dtypes)
  self.assertEqual(new_shapes.value, tf.TensorShape([]))
  self.assertEqual(new_dtypes.value, tf.int64)

  self.evaluate(tf.global_variables_initializer())
  new_features = preprocessor.TransformFeatures(features)

  counts = [0, 0, 0, 0]
  with self.session() as sess:
    # Run 10000 times to estimate the probability distribution.
    for _ in range(10000):
      new_features_np = sess.run(new_features)
      counts[new_features_np.value - 1] += 1

  # Check that the distribution roughly matches [0.1, 0.2, 0.3, 0.4].
  self.assertTrue(counts[0] > 800 and counts[0] < 1200)
  self.assertTrue(counts[1] > 1800 and counts[1] < 2200)
  self.assertTrue(counts[2] > 2800 and counts[2] < 3200)
  self.assertTrue(counts[3] > 3800 and counts[3] < 4200)

def testEmbedding(self):
  builder = gshard_builder.DenseBuilder.Params().Set(
      model_dim=4, model_dim_reshape_segments=2).Instantiate()
  ids = [[1, 2, 3], [3, 2, 1]]
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(24332)
    py_utils.GetOrCreateGlobalStepVar()
    emb_layer_p = builder.Embedding('emb', vocab_dim=4)
    emb_layer = emb_layer_p.Instantiate()
    enc_out = emb_layer.FPropDefaultTheta(
        tf.convert_to_tensor(ids, dtype=tf.int32))

  expected_val = [[[[-0.67452705, -2.6386688], [1.1666715, 0.04592554]],
                   [[-1.0561675, -0.48270327], [0.7765603, 0.6768117]],
                   [[0.8349989, 0.67100984], [-0.15557083, 1.275625]]],
                  [[[0.8349989, 0.67100984], [-0.15557083, 1.275625]],
                   [[-1.0561675, -0.48270327], [0.7765603, 0.6768117]],
                   [[-0.67452705, -2.6386688], [1.1666715, 0.04592554]]]]
  with self.session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    enc_out_vals = sess.run(enc_out)
    self.assertAllClose(expected_val, enc_out_vals)

def testBertTransformerFProp(self):
  graph = tf.Graph()
  params = self._BertTransformerParams()
  with graph.as_default():
    py_utils.GetOrCreateGlobalStepVar()
    params.params_init = py_utils.WeightInit.Xavier(scale=1.0, seed=0)
    tf.random.set_seed(24332)
    bert_model = params.Instantiate()
  with tf.Session(graph=graph) as sess:
    input_batch = self._PreLoadInput()
    metrics_t = bert_model.FPropDefaultTheta(input_batch)[0]
    loss = bert_model.loss
    sess.run(tf.global_variables_initializer())
    num_total_tokens = sess.run(metrics_t['num_total_tokens'][0])
    test_utils.CompareToGoldenSingleFloat(self, 7, num_total_tokens)
    num_masked_tokens = sess.run(metrics_t['num_masked_tokens'][0])
    test_utils.CompareToGoldenSingleFloat(self, 3, num_masked_tokens)
    loss_eval = sess.run(loss)
    golden_float = 5.97803
    test_utils.CompareToGoldenSingleFloat(self, golden_float, loss_eval)

def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task.
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')
  self._variable_renaming_rules = []
  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    try:
      with self._graph.as_default(), self._GetSession(
          cluster_def=self._cluster_def,
          disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
      ) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  _WaitTillInit()

  train_cfg = self.params
  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a
    # shared_model.
    ps = program_schedule_params.Instantiate(shared_model=shared_model)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(
        self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
        else self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          train_params=train_cfg.train,
          save_only=True)

def testAccumulator(self):
  # testAccumulator compares
  # - explicit averaging of independently computed var_grads1 and
  #   var_grads2,
  # - Accumulator(SGD) optimizer effectively doing this over 2 steps.
  np.random.seed(12345)
  np_input1 = np.random.normal(0.1, 0.5, [2, 4, 3])
  np.random.seed(12346)
  np_input2 = np.random.normal(0.1, 0.5, [2, 4, 3])

  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs2 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    in_padding2 = tf.zeros([2, 4, 1], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    output2 = proj_layer.FPropDefaultTheta(inputs2, in_padding2)
    loss1 = tf.reduce_sum(output1)
    loss2 = tf.reduce_sum(output2)
    var_grads1 = py_utils.ComputeGradients(loss1, proj_layer.vars)
    var_grads2 = py_utils.ComputeGradients(loss2, proj_layer.vars)
    op = optimizer.SGD.Params()
    opt = op.Instantiate()
    lr = 1e-1
    with tf.control_dependencies([loss1, loss2]):
      var_update_op1 = opt.Apply(
          lr, py_utils.ApplyGradMultiplier(var_grads1, 1. / 2.))
      with tf.control_dependencies([var_update_op1]):
        var_update_op2 = opt.Apply(
            lr, py_utils.ApplyGradMultiplier(var_grads2, 1. / 2.))

    self.evaluate(tf.global_variables_initializer())
    vars1 = self.evaluate(proj_layer.vars.Flatten())
    loss1_1, grads1_1, loss1_2, grads1_2 = sess.run(
        [
            loss1,
            var_grads1.Transform(tuple), loss2,
            var_grads2.Transform(tuple)
        ],
        feed_dict={
            inputs1: np_input1,
            inputs2: np_input2,
        },
    )
    sess.run([var_update_op2],
             feed_dict={
                 inputs1: np_input1,
                 inputs2: np_input2,
             })
    vars1_1 = self.evaluate(proj_layer.vars.Flatten())

  with self.session(use_gpu=True, graph=tf.Graph()) as sess:
    tf.random.set_seed(123456)
    params = layers.ProjectionLayer.Params()
    params.name = 'proj'
    params.dtype = tf.float64
    params.input_dim = 3
    params.output_dim = 2
    params.params_init = py_utils.WeightInit.Gaussian(0.01, 123456)
    params.batch_norm = False
    proj_layer = layers.ProjectionLayer(params)
    in_padding1 = tf.zeros([2, 4, 1], dtype=tf.float64)
    inputs1 = tf.placeholder(shape=[2, 4, 3], dtype=tf.float64)
    output1 = proj_layer.FPropDefaultTheta(inputs1, in_padding1)
    loss = tf.reduce_sum(output1)
    var_grads = py_utils.ComputeGradients(loss, proj_layer.vars)
    op = optimizer.Accumulator.Params().Set(
        accum_steps=2, dtype=tf.float64, optimizer_tpl=optimizer.SGD.Params())
    opt = op.Instantiate()
    lr = 1e-1
    with cluster_factory.ForTestingWorker(add_summary=True):
      var_update_op = opt.Apply(lr, var_grads)
    increment_global_step_op = tf.assign_add(
        py_utils.GetOrCreateGlobalStepVar(), 1)

    self.evaluate(tf.global_variables_initializer())
    vars2 = self.evaluate(proj_layer.vars.Flatten())
    loss2_1, grads2_1 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input1,
        })
    loss2_2, grads2_2 = sess.run(
        [loss, var_grads.Transform(tuple)], feed_dict={
            inputs1: np_input2,
        })
    acc_0 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    sess.run([var_update_op], feed_dict={
        inputs1: np_input1,
    })
    acc_1 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_intermediate = self.evaluate(proj_layer.vars.Flatten())
    self.evaluate(increment_global_step_op)
    sess.run([var_update_op], feed_dict={
        inputs1: np_input2,
    })
    acc_2 = self.evaluate(
        [v for v in tf.global_variables() if 'grad_accumulator' in v.name])[0]
    vars2_1 = self.evaluate(proj_layer.vars.Flatten())

    summary = tf.Summary.FromString(self.evaluate(tf.summary.merge_all()))
    tf.logging.info(f'summary: {summary}')
    self.assertEqual(summary.value[0].tag, 'sgd_lr')

    self.assertAllClose(vars1, vars2)
    self.assertAllClose(acc_0, np.zeros_like(acc_0))
    self.assertAllClose(acc_1, grads2_1['w'][1])
    self.assertAllClose(acc_2, np.zeros_like(acc_0))
    self.assertAllClose(loss1_1, loss2_1)
    self.assertAllClose(loss1_2, loss2_2)
    self.assertAllClose(grads1_1, grads2_1)
    self.assertAllClose(grads1_2, grads2_2)
    self.assertAllClose(vars1, vars2_intermediate)
    self.assertAllClose(vars2[0], grads2_1['w'][0])
    self.assertAllClose(vars2[0], grads2_2['w'][0])
    self.assertAllClose(
        vars1[0] - 0.5 * lr * (grads1_1['w'][1] + grads1_2['w'][1]),
        vars1_1[0])
    self.assertAllClose(
        vars2[0] - 0.5 * lr * (grads2_1['w'][1] + grads2_2['w'][1]),
        vars2_1[0])
    self.assertAllClose(vars2, vars2_intermediate)
    self.assertAllClose(vars1_1, vars2_1)

def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports an InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless the user explicitly sets it to
  True.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model
      if given.
    freeze_defaults: Default initializes the graph and freezes it. Useful
      for early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this
      path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing
      purposes.

  Returns:
    An InferenceGraph proto.

  Raises:
    ValueError: if the model does not support the listed subgraphs.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless the user explicitly enables them.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)
      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating the model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      _ = py_utils.GetOrCreateGlobalStepVar()
      try:
        mdl = model_cfg.Instantiate()
        task = mdl.GetTask(model_task_name)

        variables_to_restore = (
            _MakeVariableDictionary(tf.global_variables())
            if not mdl.ema else mdl.ema.variables_to_restore(
                mdl.variables_for_ema))

        if bfloat16_override:
          saver_var_spec = (
              bfloat16_variables
              .get_saver_spec_for_variables_with_bf16_overrides(
                  variables_to_restore))
        else:
          saver_var_spec = variables_to_restore

        saver = tf.train.Saver(saver_var_spec)
        tf.variables_initializer(
            tf.global_variables(), name='init_all_variables')
        if IsTpu(device_options) and device_options.gen_init_op:
          tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

        inference_graph_proto = inference_graph_pb2.InferenceGraph()
        subgraphs_proto = task.Inference()
        if isinstance(subgraphs_proto, dict):
          subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
        for name, subgraph in subgraphs_proto.subgraphs.items():
          if not subgraph_filter or name in subgraph_filter:
            inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

        # Add a table init op and global variable init op to the graph.
        # Tables can be declared anywhere in the graph, so this op has to
        # be added last.
        tf.tables_initializer(name='init_all_tables')
      finally:
        # Reset TPU-related flags after model instantiation.
        FLAGS.enable_asserts = old_enable_asserts
        FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(
        graph, inference_graph_proto, preserve_colocation_nodes=False)
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s',
                      freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))

  return inference_graph_proto

def testFakeQuantizationScheduleTraining(self):
  p = quant_utils.FakeQuantizationSchedule.Params()
  p.clip_start_step = 5
  p.clip_end_step = 10
  p.quant_start_step = 15
  p.start_cap = 6.0
  p.end_cap = 1.0
  with self.session() as sess:
    cc_schedule = p.Instantiate()
    tf.global_variables_initializer().run()
    # Step 0: No clipping.
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0), (-100.0, 100.0))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not quantized.

    # Step 5: Clipping active but not yet quantizing.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 5))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-6.0, 5.953125))  # 6 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not quantized.

    # Step 7: Middle of clipping range.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 7))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-4.0, 3.96875))  # 4 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not quantized.

    # Step 10: End of clipping range.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 10))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-1.0, 0.9921875))  # 1 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not quantized.

    # Step 11: No more clipping but not yet quantizing.
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), 11))
    self.assertAllClose(
        self._ClipExample(cc_schedule, 100.0),
        (-1.0, 0.9921875))  # 1 * 127/128
    self.assertAllClose(
        self._ClipExample(cc_schedule, 0.123456),
        (-0.123456, 0.123456))  # Not quantized.

    # Steps 15-16: Quantizing at full clip.
    for step in (15, 16):
      sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), step))
      self.assertAllClose(
          self._ClipExample(cc_schedule, 100.0),
          (-1.0, 0.9921875))  # 1 * 127/128
      self.assertAllClose(
          self._ClipExample(cc_schedule, 0.123456),
          (-0.125, 0.125))  # Quantized.

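# The golden constants above follow from a linear ramp of the clipping cap:
# clipping is inactive before clip_start_step, the cap then interpolates
# linearly from start_cap to end_cap over [clip_start_step, clip_end_step],
# and the positive bound is scaled by 127/128 (the asymmetry of a symmetric
# signed 8-bit range). A quick arithmetic check of those constants; the
# helper below is illustrative, not the library implementation.
def _InterpolatedCap(step, clip_start=5, clip_end=10, start_cap=6.0,
                     end_cap=1.0):
  t = min(max((step - clip_start) / (clip_end - clip_start), 0.0), 1.0)
  return start_cap + (end_cap - start_cap) * t


for step in (5, 7, 10, 11):
  cap = _InterpolatedCap(step)
  print(step, (-cap, cap * 127.0 / 128.0))
# 5  -> (-6.0, 5.953125)
# 7  -> (-4.0, 3.96875)
# 10 -> (-1.0, 0.9921875)
# 11 -> (-1.0, 0.9921875)
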
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super().__init__(train_cfg, model_task_name, logdir, tf_master, **kwargs)

  data_parallelism = self._cluster.num_splits_per_client
  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')
  self._variable_renaming_rules = []
  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')
    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir,
                   'trainer_params.txt')
  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  train_cfg = self.params

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit(job=None):
    """Wait until the model is ready."""
    try:
      # tpu.initialize_system() is called with None as embedding_config, as
      # embedding_config is not available yet. Later in _Loop, it is called
      # with the correct embedding_config. Since it cannot be called twice
      # in the same graph with different embedding_config, we use a
      # dummy_graph here.
      dummy_graph = tf.Graph()
      with dummy_graph.as_default():
        tpu_initialize_system_op = tf.tpu.initialize_system(
            embedding_config=None, job=job)

      with self._GetSession(graph=dummy_graph) as sess:
        topology = sess.run(tpu_initialize_system_op)

      if train_cfg.train.tpu_device_order_mode is None:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism)
      else:
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(
                num_devices_per_split, topology),
            num_replicas=data_parallelism,
            device_order_mode=train_cfg.train.tpu_device_order_mode)
      py_utils.SetTpuDeviceAssignment(device_assignment, job)
      tf.logging.info('device_assignment.core_assignment: %s',
                      str(device_assignment.core_assignment))
      tf.logging.info('device_assignment.topology.device_coordinates: %s',
                      str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='init_start', value=None)
  if len(self._cluster.all_worker_names) > 1:
    for worker in self._cluster.all_worker_names:
      _WaitTillInit(worker)
  else:
    _WaitTillInit(None)

  shared_model = self._MaybeConstructSharedModel(train_cfg)

  self._program_schedule_dict = {}
  self._programs = []

  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    # If the model was created above, we'll inject it here as a
    # shared_model.
    ps = program_schedule_params.Instantiate(
        shared_model=shared_model, tf_master=self._tf_master)
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  with self._graph.as_default(), tf.container(self._container_id):
    with self._cluster, tf.device(self._cluster.GetPlacer()):
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
          py_utils.ClearTpuSummaryTensors()

      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()
      self._initialize_global_vars = tf.global_variables_initializer()

      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer(init_op=self._initialize_global_vars)

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          init_op=self._initialize_global_vars,
          train_params=train_cfg.train,
          save_only=True)

      self._load_ops = tf.get_collection(py_utils.TPU_EMBEDDING_LOAD_OPS)
      self._retrieve_ops = tf.get_collection(
          py_utils.TPU_EMBEDDING_RETRIEVE_OPS)
      tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
      self._tpu_embedding = (
          tpu_embedding_collection[0] if tpu_embedding_collection else None)
      tf.io.write_graph(self._graph.as_graph_def(), self._checkpoint_dir,
                        'train.pbtxt')

def _create_session(self, *args, **kwargs):
  sess = super(TestCase, self)._create_session(*args, **kwargs)
  with sess.graph.as_default():
    # Ensure the global_step variable is created in every new session.
    py_utils.GetOrCreateGlobalStepVar()
  return sess

def setUp(self):
  super(TestCase, self).setUp()
  # Ensure the global_step variable is created in the default graph.
  py_utils.GetOrCreateGlobalStepVar()

def _BPropForVariables(self, vmap):
  """Constructs the backward graph."""
  bprop_variable_filters = self.input_generator.GetBpropVariableFilters()
  # Only compute the mask if the variable filters are not empty.
  if bprop_variable_filters != [''] * len(bprop_variable_filters):
    self._ComputeGradientMask(bprop_variable_filters)
  train_ops = {}  # mapping from op name to op.
  gradient_mask = None
  if self._per_input_gradient_mask:
    # TODO(neerajgaur): Change this to use source_selected from input_batch.
    onehot = self.input_generator.GetInputSourceOneHot()
    gradient_mask = {
        k: tf.tensordot(v, onehot, 1)
        for k, v in six.iteritems(self._per_input_gradient_mask)
    }
  all_losses = []
  for optimization in self.learners:
    loss_name = optimization.params.name
    metric = self._metrics.get(loss_name, None)
    if metric is None:
      raise ValueError('Loss %s not found in metrics %s' %
                       (loss_name, list(self._metrics.keys())))
    loss = metric[0]
    all_losses.append(loss)
    train_ops['train/%s' % loss_name], eval_metrics = optimization.Apply(
        loss,
        vmap,
        gradient_mask=gradient_mask,
        gradient_adjuster=self.AdjustGradients)
    for key, (value, weight) in six.iteritems(eval_metrics):
      self.AddEvalMetric(key + '/' + loss_name, value, weight)

  relevant_bn_updates, _ = py_utils.FindRelevantBatchNormUpdates(
      all_losses, tf.get_collection(py_utils.BATCH_NORM_UPDATES))
  train_ops['bn_updates'] = relevant_bn_updates

  # Get the op to update the weight masks and thresholds.
  train_ops['mask_updates'] = self._GetMaskUpdateOp()

  # Post training step update.
  train_ops['post_step'] = self.PostTrainingStepUpdate(self.global_step)

  with tf.control_dependencies(tf.nest.flatten(train_ops)):
    true_global_step = py_utils.GetOrCreateGlobalStepVar()
    with tf.colocate_with(true_global_step):
      increment_global_steps = tf.assign_add(true_global_step, 1)
    if self._global_step_var != true_global_step:
      with tf.colocate_with(self._global_step_var):
        increment_global_steps = tf.group(
            increment_global_steps, tf.assign_add(self._global_step_var, 1))
  train_ops['global_step'] = increment_global_steps

  # If we are using TPU embeddings, generate the monolithic send
  # gradient op.
  tpu_embedding_activations = tf.get_collection(
      py_utils.TPU_EMBEDDING_ACTIVATIONS)
  if tpu_embedding_activations:
    tpu_embedding_activations_dict = tpu_embedding_activations[0]
    tpu_embedding = tf.get_collection(py_utils.TPU_EMBEDDING)[0]
    tpu_embedding_send_gradient_op = py_utils.ComputeTpuEmbeddingGradients(
        self.loss, tpu_embedding_activations_dict, tpu_embedding)
    train_ops['tpu_embedding'] = tpu_embedding_send_gradient_op

  for op_name, op in six.iteritems(train_ops):
    assert op is not None, op_name

  # TODO(rpang): try to structure _train_op as:
  #   tf.cond(skip_step, <only update skip stats>, <all updates>)
  # so that we skip all other updates when a step is skipped.
  self._train_op = tf.group(*tf.nest.flatten(train_ops), name='bprop')

def testBatchNormLayer(self):
  p = base_model.SingleTaskModel.Params()
  p.task = self.TestParams(layers.BatchNormLayer.Params().Set(dim=1))
  p.task.train.ema_decay = 0.9
  p.task.train.ema_decay_moving_vars = True
  model = p.Instantiate()
  self.assertIsNotNone(model.ema)
  task = model._task
  task._train_op = tf.no_op()
  task.ApplyExponentialMovingAverage(model.ema)
  layer = task.encoder
  self.assertLen(layer.vars, 4)
  for var in layer.vars.Flatten():
    self.assertIsNotNone(model.ema.average(var), msg=var.name)
  beta = layer.vars.beta
  mean = layer.vars.moving_mean

  global_step = 100
  beta_1 = np.asarray([.2])
  mean_1 = np.asarray([.03])
  beta_1_ema = beta_1 * .1
  mean_1_ema = mean_1 * .1
  with self.session() as sess:
    # Test EMA values.
    sess.run(tf.global_variables_initializer())
    sess.run(tf.assign(py_utils.GetOrCreateGlobalStepVar(), global_step))
    sess.run(tf.assign(beta, beta_1))
    sess.run(tf.assign(mean, mean_1))
    sess.run(task._post_train_ops)
    self.assertAllClose(
        [beta_1, beta_1_ema, mean_1, mean_1_ema],
        sess.run(
            [beta, model.ema.average(beta), mean,
             model.ema.average(mean)]))

    # Test checkpointer.
    train_dir = os.path.join(self.get_temp_dir(), 'testSaveRestore')
    os.mkdir(train_dir)
    saver = checkpointer.Checkpointer(train_dir, model)
    saver.Save(sess, model.global_step)
    self.assertTrue(
        os.path.isfile(
            os.path.join(train_dir, 'ckpt-%08d.index' % global_step)))

  # Restore from ckpt in training mode.
  with self.session(graph=tf.Graph()) as sess:
    model = p.Instantiate()
    self.assertIsNotNone(model.ema)
    task = model._task
    task._train_op = tf.no_op()
    task.ApplyExponentialMovingAverage(model.ema)
    layer = task.encoder
    for var in layer.vars.Flatten():
      self.assertIsNotNone(model.ema.average(var), msg=var.name)
    beta = layer.vars.beta
    mean = layer.vars.moving_mean

    saver = checkpointer.Checkpointer(train_dir, model)
    saver.RestoreIfNeeded(sess)
    self.assertAllClose(
        [beta_1, beta_1_ema, mean_1, mean_1_ema],
        sess.run(
            [beta, model.ema.average(beta), mean,
             model.ema.average(mean)]))

  # Restore from ckpt in eval mode.
  with self.session(graph=tf.Graph()) as sess, self.SetEval(True):
    model = p.Instantiate()
    self.assertIsNotNone(model.ema)
    task = model._task
    # task._train_op = tf.no_op()
    # task.ApplyExponentialMovingAverage(model.ema)
    layer = task.encoder
    # for var in layer.vars.Flatten():
    #   self.assertIsNotNone(model.ema.average(var), msg=var.name)
    beta = layer.vars.beta
    mean = layer.vars.moving_mean

    saver = checkpointer.Checkpointer(train_dir, model)
    saver.RestoreIfNeeded(sess)
    # Both beta and mean should use the EMA value.
    self.assertAllClose([beta_1_ema, mean_1_ema], sess.run([beta, mean]))