def testTransformerScheduleWithStartStep(self):
  """Checks `TransformerSchedule` when `start_step` is set.

  A reference schedule (warmup_steps=3000, model_dim=512) is compared with a
  copy whose only difference is `start_step=1000`. The test pins the copy's
  values at steps 0..4000 and then checks that, from step 3000 on, the copy
  evaluated at `step` equals the reference evaluated at `step + start_step`.
  """
  start_step = 1000
  ref_params = schedule.TransformerSchedule.Params().Set(
      warmup_steps=3000, model_dim=512)
  ref_lrs = ref_params.Instantiate()
  shifted_params = ref_params.Copy().Set(start_step=start_step)
  lrs = shifted_params.Instantiate()

  with self.session() as sess:
    sess.run(tf.global_variables_initializer())

    # Warmup respects real global_step no matter start_step.
    warmup_fetches = []
    for step in range(0, 4001, 1000):
      with py_utils.GlobalStepContext(step):
        warmup_fetches.append(lrs.Value())
    warmup_values = sess.run(warmup_fetches)
    expected_values = [0, 0.000269, 0.000538, 0.000699, 0.000625]
    print(warmup_values)
    self.assertAllClose(warmup_values, expected_values)

    # After warmup, ref_lrs and lrs are the same function up to a translation
    # of the x-axis by start_step.
    ref_fetches = []
    fetches = []
    for step in range(3000, 8000, 1000):
      with py_utils.GlobalStepContext(step):
        fetches.append(lrs.Value())
      with py_utils.GlobalStepContext(step + start_step):
        ref_fetches.append(ref_lrs.Value())
    ref_values, values = sess.run([ref_fetches, fetches])
    print(ref_values)
    self.assertAllClose(ref_values, values)
def testTransformerMLPerfSchedule(self):
  """Compares `TransformerMLPerfSchedule` against `TransformerSchedule`.

  Both schedules use warmup_steps=4000 and model_dim=512; the MLPerf variant
  additionally sets warmup_init_fraction=0.3. During warmup the MLPerf value
  is expected to equal `0.3 * peak_lr + 0.7 * base`; from step 4000 on the
  two schedules are expected to coincide.
  """
  params = schedule.TransformerMLPerfSchedule.Params().Set(
      warmup_steps=4000, warmup_init_fraction=.3, model_dim=512)
  lrs = params.Instantiate()

  base_params = schedule.TransformerSchedule.Params().Set(
      warmup_steps=4000, model_dim=512)
  base_lrs = base_params.Instantiate()

  with self.session():
    # Linear warmup starting from 0.3 * peak_lr.
    peak_lr = 0.000698684
    for step in (0, 1000, 2000, 3000, 4000):
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(
            .3 * peak_lr + .7 * base_lrs.Value().eval(),
            lrs.Value().eval())

    # Test that the schedule is identical with transformer-lr after 4k steps.
    # A single comparison per step is enough: the same assertion used to be
    # repeated three times verbatim, which added no coverage.
    for step in (4000, 4010, 5000):
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(base_lrs.Value().eval(), lrs.Value().eval())
def testTransformerScheduleNoWarmUp(self):
  """Checks `TransformerScheduleNoWarmUp` with decay_start=4000.

  The schedule is expected to stay at 0.000698684 up to step 4000, to take
  the pinned values at steps 4500 and 5000, and to coincide with a
  `TransformerSchedule` (warmup_steps=4000, model_dim=512) from step 4000 on.
  """
  params = schedule.TransformerScheduleNoWarmUp.Params().Set(
      decay_start=4000, model_dim=512)
  lrs = params.Instantiate()

  base_params = schedule.TransformerSchedule.Params().Set(
      warmup_steps=4000, model_dim=512)
  base_lrs = base_params.Instantiate()

  with self.session():
    # Tests that the schedule is flat up until 4000 steps.
    for step in (0, 1000, 2000, 3000, 4000):
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(lrs.Value().eval(), 0.000698684)
    with py_utils.GlobalStepContext(4500):
      self.assertAllClose(lrs.Value().eval(), 0.000658735)
    with py_utils.GlobalStepContext(5000):
      self.assertAllClose(lrs.Value().eval(), 0.000624937)

    # Test that the schedule is identical with transformer-lr after 4k steps.
    # A single comparison per step is enough: the same assertion used to be
    # repeated three times verbatim, which added no coverage.
    for step in (4000, 4010, 5000):
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(base_lrs.Value().eval(), lrs.Value().eval())
def testCosineWithLinearRamp(self):
  """Checks the schedule set by `lr_util.SetCosineLR` stays below 1.

  The schedule is configured with warmup_epoch=1, total_epoch=10 and
  warmup_init=0; its value must be below 1 both at step 8 and at step 48.
  """
  p = self._testParams()
  lr_util.SetCosineLR(
      p.train, p.input, warmup_epoch=1, total_epoch=10, warmup_init=0.)
  schedule_layer = p.train.lr_schedule.Instantiate()
  with self.session():
    # Step 8 probes the linear ramp up, step 48 the cosine ramp down.
    for step in (8, 48):
      with py_utils.GlobalStepContext(step):
        value = self.evaluate(schedule_layer.Value())
        self.assertLess(value, 1.)
def testExponentialWithoutLinearRamp(self):
  """Checks the schedule set by `lr_util.SetExponentialLR` without warmup.

  With exp_start_epoch=0 and total_epoch=10 the value must be exactly 1 at
  step 0 and below 1 at step 4.
  """
  p = self._testParams()
  lr_util.SetExponentialLR(
      p.train, p.input, exp_start_epoch=0, total_epoch=10)
  schedule_layer = p.train.lr_schedule.Instantiate()
  with self.session():
    # Peak learning rate at 0.
    with py_utils.GlobalStepContext(0):
      peak = self.evaluate(schedule_layer.Value())
      self.assertEqual(peak, 1.)
    # Exponential ramp down within first epoch.
    with py_utils.GlobalStepContext(4):
      decayed = self.evaluate(schedule_layer.Value())
      self.assertLess(decayed, 1.)
def test_linear_rampup_exp_decay_schedule_nowarmup(self):
  """Checks `LinearRampupExponentialDecay` with `warmup=0`.

  The JAX schedule is compared with hard-coded reference values and then with
  the Lingvo (TF) schedule built from equivalent parameters.
  """
  p = schedules.LinearRampupExponentialDecay.Params().Set(
      warmup=0, decay_start=0, decay_end=100, max=1.0, min_ratio=0.01)
  lr_schedule = p.Instantiate()
  jit_value = jax.jit(lr_schedule.value)

  xs = [0, 50, 100, 150, 200]
  expected_values = [1., 0.1, 0.01, 0.01, 0.01]
  with self.subTest(name='reference_values'):
    for count, expected_value in zip(xs, expected_values):
      self.assertAllClose(jit_value(jnp.array(count)), expected_value)

  tf_p = tf_schedule.LinearRampupExponentialDecay.Params().Set(
      warmup=0, decay_start=0, decay_end=100, max=1.0, min=0.01)
  tf_lr_schedule = tf_p.Instantiate()
  with self.subTest(name='lingvo_values'):
    for count in xs:
      if count == 0:
        # Lingvo implementation does not support no warm-up. It just adds a
        # warm-up consisting of a single step. Hence, no comparison at step 0,
        # the only step that single-step warm-up can affect.
        continue
      with tf_py_utils.GlobalStepContext(count):
        self.assertAllClose(
            jit_value(jnp.array(count)),
            tf_lr_schedule.Value().numpy())
def CellFn(theta, state0, inputs):
  """A cell fn that is executed inside of StackedRecurrent.

  Runs `FProp` of the `i`-th cell on the tensors found in `inputs` and packs
  the results into the next state.

  Args:
    theta: Weights passed to the cell's `FProp`; `theta.global_step` is also
      read to derive the global step used while running the cell.
    state0: Ignored.
    inputs: A map holding the micro-batch state and the cell inputs under the
      keys 's0', 's1', ...

  Returns:
    A `(state1, extras)` pair: `state1` carries the micro-batch state plus
    every non-None cell output under 's<index>'; `extras` is an empty
    `NestedMap`.
  """
  del state0

  cell_inputs = []
  for idx, shape in enumerate(state_shapes[i]):
    if shape is None:
      cell_inputs.append(None)
      continue
    tensor = inputs['s{}'.format(idx)]
    tensor.set_shape(shape)
    cell_inputs.append(tensor)

  with CellFnFropOpReplacementWrapper():
    tf.logging.info('cell {} input {}'.format(i, cell_inputs))
    mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
    # Each micro batch gets its own step value, derived from the real global
    # step and the micro-batch state.
    micro_batch_offset = tf.cast(mb_tensor, theta.global_step.dtype)
    gs_tensor = theta.global_step * p.num_micro_batches + micro_batch_offset
    _, cell = self._cells[i]
    with py_utils.GlobalStepContext(gs_tensor):
      outputs = cell.FProp(theta, *cell_inputs)

  state1 = py_utils.NestedMap()
  state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
  outputs = _ToTuple(outputs)
  assert len(outputs) == len(state_shapes[i + 1])
  for idx, output in enumerate(outputs):
    if output is not None:
      state1['s{}'.format(idx)] = output
  return state1, py_utils.NestedMap()
def testLayerWithPassiveAsymQDomain(self):
  """Runs the sample quantized projection layer with `PassiveAsymQDomain`.

  Checks the layer output against pinned values via `_testLayerHelper`, then
  runs the post-training-step update at global step 16 and verifies that
  every quantization variable differs from its initial value.
  """
  # pyformat: disable
  expected = [[[0., -0.03921568, -0.02352941, -0.00784314],
               [0.0862745, 0.13333333, -0.03137255, 0.06274509],
               [0., 0., 0., 0.],
               [-0.02352941, -0.17254901, -0.05490196, 0.02352941]],
              [[0., 0., 0., 0.],
               [0., 0., 0., 0.],
               [-0.02352941, -0.10196078, -0.00784314, 0.07058823],
               [0.02352941, -0.1490196, -0.09411764, 0.01568627]]]
  # pyformat: enable

  with self.session():
    p = quant_test_lib.SampleQuantizedProjectionLayer.Params()
    p.qdomain.default = quant_utils.PassiveAsymQDomain.Params()
    l = self._testLayerHelper(
        'testLayerWithPassiveAsymQDomain', p, expected=expected)
    init_minmax_vars = l.qdomain_default._qvars.Transform(lambda x: x.eval())
    print('Initial Minmax vars:', init_minmax_vars)

    # Record.
    with py_utils.GlobalStepContext(16):
      self.evaluate([l.PostTrainingStepUpdate()])
    minmax_vars = l.qdomain_default._qvars.Transform(lambda x: x.eval())
    print('Minmax vars:', minmax_vars)

    # Make sure that the vars have moved from their defaults.
    for key in minmax_vars:
      before = init_minmax_vars[key]
      after = minmax_vars[key]
      self.assertNotEqual(before, after)
def test_linear_rampup_exp_decay_schedule(self):
  """Checks `LinearRampupExponentialDecay` with warmup, plateau and decay.

  The JAX schedule (warmup=100, decay_start=200, decay_end=300) is compared
  with hard-coded reference values and with the Lingvo (TF) schedule built
  from equivalent parameters at the same steps.
  """
  p = schedules.LinearRampupExponentialDecay.Params().Set(
      warmup=100, decay_start=200, decay_end=300, max=1.0, min_ratio=0.01)
  lr_schedule = p.Instantiate()
  jit_value = jax.jit(lr_schedule.value)

  xs = [0, 10, 20, 100, 120, 150, 200, 250, 300, 350]
  expected_values = [0.0, 0.1, 0.2, 1.0, 1.0, 1.0, 1.0, 0.1, 0.01, 0.01]
  with self.subTest(name='reference_values'):
    for count, expected_value in zip(xs, expected_values):
      actual = jit_value(jnp.array(count))
      self.assertAllClose(actual, expected_value)

  tf_p = tf_schedule.LinearRampupExponentialDecay.Params().Set(
      warmup=100, decay_start=200, decay_end=300, max=1.0, min=0.01)
  tf_lr_schedule = tf_p.Instantiate()
  with self.subTest(name='lingvo_values'):
    for count in xs:
      with tf_py_utils.GlobalStepContext(count):
        jax_value = jit_value(jnp.array(count))
        tf_value = tf_lr_schedule.Value().numpy()
        self.assertAllClose(jax_value, tf_value)
def test_transformer_schedule_with_decay_end_fixed(self):
  """Checks that the `Transformer` schedule is constant after `decay_end`.

  The value just before `decay_end` must be greater than the value at
  `decay_end`, and later values must equal it. The JAX schedule is also
  compared with the Lingvo (TF) `TransformerSchedule` around `decay_end`.
  """
  p = schedules.Transformer.Params().Set(
      warmup_steps=4000, model_dim=512, decay_end=5000)
  lr_schedule = p.Instantiate()
  jit_value = jax.jit(lr_schedule.value)

  # Tests that the schedule is fixed after decay end steps.
  v_decay_end = lr_schedule.value(jnp.array(p.decay_end))
  with self.subTest(name='reference_values'):
    just_before = jit_value(jnp.array(p.decay_end - 1))
    self.assertGreater(just_before, v_decay_end)
    for offset in (1, 1000):
      later = jit_value(jnp.array(p.decay_end + offset))
      self.assertAllClose(later, v_decay_end)

  tf_p = tf_schedule.TransformerSchedule.Params().Set(
      warmup_steps=4000, model_dim=512, decay_end=5000)
  tf_lr_schedule = tf_p.Instantiate()
  with self.subTest(name='lingvo_values'):
    for count in range(p.decay_end - 1, p.decay_end + 20, 2):
      with tf_py_utils.GlobalStepContext(count):
        jax_value = jit_value(jnp.array(count))
        tf_value = tf_lr_schedule.Value().numpy()
        self.assertAllClose(jax_value, tf_value)
def testCombinedLRSchedule(self):
  """Checks `CombinedMinimumSchedule` over three sub-schedules.

  The sub-schedules are a linear segment from 1 to 8, a linear segment held
  at 8, and an exponential segment from 8 to 0.5. The combined value is
  sampled every 1,000,000 steps and compared with pinned values.
  """
  linear_up = schedule.LinearSchedule.Params().Set(
      start=(0., 1.), limit=(2000000, 8.))
  plateau = schedule.LinearSchedule.Params().Set(
      start=(2000000., 8.), limit=(4000000, 8.))
  exp_down = schedule.ExponentialSchedule.Params().Set(
      start=(4000000., 8.), limit=(8000000, 0.5))
  p = schedule.CombinedMinimumSchedule.Params().Set(
      schedules=[linear_up, plateau, exp_down])

  expected = [
      # Linear increasing.
      [0, 1.0],
      [1000000, 4.5],
      # Constant
      [2000000, 8.0],
      [3000000, 8.0],
      # Exponentially decreasing.
      [4000000, 8.0],
      [5000000, 4.0],
      [6000000, 2.0],
      [7000000, 1.0],
      [8000000, 0.5],
      [9000000, 0.5],
  ]
  with self.session():
    lrs = p.Instantiate()
    pts = []
    for step in range(0, 10000000, 1000000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(pts, expected)
def testStepwiseExponentialSchedule(self):
  """Checks `StepwiseExponentialSchedule` with decay=0.5 every 1000 steps."""
  p = schedule.StepwiseExponentialSchedule.Params()
  p.decay = 0.5
  p.num_steps_per_decay = 1000
  decay = p.Instantiate()

  # (global step, expected schedule value); the value only changes when a
  # multiple of num_steps_per_decay is reached.
  cases = [
      (0, 1.0),
      (999, 1.0),
      (1000, 0.5),
      (1999, 0.5),
      (2000, 0.25),
  ]
  with self.session():
    for step, expected in cases:
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(decay.Value().eval(), expected)
def testLinearRampupExponentialDecayScaledByNumSplitScheduleNoWarmUp(self):
  """Checks the num-split-scaled rampup/decay schedule with `warmup=0`.

  The schedule is instantiated under a sync trainer_client cluster with 8
  GPUs, sampled every 1,000,000 steps, and compared with pinned values.
  """
  schedule_cls = schedule.LinearRampupExponentialDecayScaledByNumSplitSchedule
  p = schedule_cls.Params().Set(
      warmup=0, decay_start=32000000, decay_end=64000000, min=0.5)

  expected = [
      # Constant
      [0, 8.0],
      [1000000, 8.0],
      [2000000, 8.0],
      [3000000, 8.0],
      # Exponentially decreasing.
      [4000000, 8.0],
      [5000000, 4.0],
      [6000000, 2.0],
      [7000000, 1.0],
      [8000000, 0.5],
      [9000000, 0.5],
  ]
  with self.session(), cluster_factory.ForTestingWorker(
      mode='sync', job='trainer_client', gpus=8):
    lrs = p.Instantiate()
    pts = []
    for step in range(0, 10000000, 1000000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(pts, expected)
def testConstantOne(self):
  """Checks that `ConstantOne` evaluates to 1.0 at several global steps."""
  with self.session(use_gpu=False):
    lrs = schedule.ConstantOne.Params().Instantiate()
    for step in (0, 10, 100, 1000000):
      with py_utils.GlobalStepContext(step):
        value = lrs.Value().eval()
        self.assertAllClose(value, 1.0)
def FProp(self, theta, input_batch):
  """Forward propagation.

  This default `FProp` implementation supports batch splitting in synchronous
  and asynchronous training when sub-classes implement `FPropTower`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: The input batch. A `NestedMap` of tensors. Or, if input batch
      splitting is used, a list of `NestedMap`, one for each split.

  Returns:
    (dict, dict):

    - A dict containing str keys and (metric, weight) pairs as values, where
      one of the keys is expected to be 'loss'.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  name = self.params.name
  with tf.name_scope('fprop'), tf.name_scope(name), py_utils.GlobalStepContext(
      self._global_step_var):
    # Always reset step seed at the start of a new global_step.
    py_utils.ResetStepSeed()
    metrics, per_example = self._FPropSplitInputBatch(theta, input_batch)
    self._FPropResult(metrics, per_example)
  return metrics, per_example
def testLinearRampupSqrtDecayByBatchSizeAndReplicasSchedule(self):
  """Checks `LinearRampupSqrtDecayByBatchSizeAndReplicas` at pinned steps.

  The schedule uses warmup_examples=100000 and batch_size=100 and is
  instantiated under a sync trainer_client cluster with 10 GPUs.
  """
  p = schedule.LinearRampupSqrtDecayByBatchSizeAndReplicas.Params().Set(
      warmup_examples=100000, batch_size=100)

  # (global step, expected schedule value).
  cases = [
      (-1, 0.0),
      (49, 0.05),
      (99, 0.1),
      (399, 0.05),
      (1599, 0.025),
  ]
  with self.session(), cluster_factory.ForTestingWorker(
      mode='sync', job='trainer_client', gpus=10):
    lrs = p.Instantiate()
    for step, expected in cases:
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(lrs.Value().eval(), expected)
def testLinearRampupCosineSchedule(self):
  """Checks `LinearRampupCosineSchedule` at pinned steps.

  The schedule uses warmup_steps=200, initial_value=3.0, final_value=1.0,
  total_steps=400000 and num_splits=1.
  """
  p = schedule.LinearRampupCosineSchedule.Params().Set(
      warmup_steps=200,
      initial_value=3.0,
      final_value=1.0,
      total_steps=400000,
      num_splits=1)

  expected = [
      [0, 0.0],
      [100, 1.5],
      [200, 3.0],
      [100000, math.cos(math.pi / 4) + 2.],  # angle=pi/4
      [200000, 2.0],  # angle=pi/2, half-way
      [300000, math.cos(math.pi * 3 / 4) + 2.],  # angle=pi*3/4
      [400000, 1.0],
  ]
  with self.session():
    lrs = p.Instantiate()
    pts = []
    for step in (0, 100, 200, 100000, 200000, 300000, 400000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(pts, expected)
def testPiecewiseSchedule(self):
  """Checks `PiecewiseSchedule` made of a linear and a cosine segment."""
  # Linear ramp-up in 20000 steps, cosine decay in 40000 steps.
  ramp_up = schedule.LinearSchedule.Params().Set(
      start=(0, 0.), limit=(20000, 2.))
  cosine_decay = schedule.CosineSchedule.Params().Set(
      initial_value=2.0, total_steps=40000)
  p = schedule.PiecewiseSchedule.Params().Set(
      boundaries=[20000], schedules=[ramp_up, cosine_decay])

  expected = [
      [0, 0.0],
      [10000, 1.0],  # half-way in linear ramp-up.
      [20000, 2.0],  # completed linear ramp-up.
      [30000, math.cos(math.pi / 4) + 1.],  # pi/4.
      [40000, 1.0],  # pi/2.
      [50000, math.cos(math.pi * 3 / 4) + 1.],  # pi*3/4.
      [60000, 0.0],  # pi.
  ]
  with self.session():
    lrs = p.Instantiate()
    pts = []
    for step in range(0, 70000, 10000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(pts, expected)
def testContinuousSchedule_CanOverrideStart(self):
  """Checks `ContinuousSchedule` with a non-default `start_step`.

  With initial_value=2.0, start_step=1000 and half_life_steps=100 the value
  is expected to stay at 2.0 through step 1000 and to halve every 100 steps
  afterwards.
  """
  p = schedule.ContinuousSchedule.Params()
  p.initial_value = 2.0
  p.start_step = 1000
  p.half_life_steps = 100
  decay = p.Instantiate()

  # (global step, expected schedule value).
  cases = [
      (0, 2.0),
      (1000, 2.0),
      (1100, 1.0),
      (1200, 0.5),
      (1300, 0.25),
  ]
  with self.session():
    for step, expected in cases:
      with py_utils.GlobalStepContext(step):
        self.assertAllClose(decay.Value().eval(), expected)
def testPolynomialLRSchedule(self):
  """Checks `PolynomialSchedule` (power=2) values and output rank."""
  p = schedule.PolynomialSchedule.Params().Set(
      power=2, start=(0, 0.), limit=(20000, 2.))

  expected = [
      [0, 0.0],
      [10000, 0.5],  # 2 * (0.5 ** 2)
      [20000, 2.0],
  ]
  with self.session():
    lrs = p.Instantiate()
    pts = []
    for step in (0, 10000, 20000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(pts, expected)

    # The schedule value must be a scalar.
    with py_utils.GlobalStepContext(42):
      rank = len(lrs.Value().shape)
      self.assertEqual(rank, 0)
def PostTrainingLoop(self, outfeed=None):
  """Construct the post training loop op.

  Groups the `ApplyPostTrainingLoop()` op of every learner into a single op
  and stores it in `self._post_training_loop_op`. Nothing is returned.

  Args:
    outfeed: a dict of tensors dequeued from TPU outfeed queue. Not read by
      this implementation.
  """
  with py_utils.GlobalStepContext(self._global_step_var):
    learner_ops = [opt.ApplyPostTrainingLoop() for opt in self.learners]
    self._post_training_loop_op = tf.group(*learner_ops)
def test_sqrt_decay_schedule_values(self, count):
  """Checks the JAX `SqrtDecay` schedule against the Lingvo (TF) one.

  Args:
    count: The step at which both schedules (warmup_steps=4000) are compared.
  """
  lr_schedule = schedules.SqrtDecay.Params().Set(
      warmup_steps=4000).Instantiate()
  jit_value = jax.jit(lr_schedule.value)

  tf_lr_schedule = tf_schedule.SqrtDecay.Params().Set(
      warmup_steps=4000).Instantiate()
  with tf_py_utils.GlobalStepContext(count):
    jax_value = jit_value(jnp.array(count))
    tf_value = tf_lr_schedule.Value().numpy()
    self.assertAllClose(jax_value, tf_value)
def _ParseProcessor(processor):
  """Traces the python callable `processor` into a TF concrete function.

  Args:
    processor: A callable taking either a single argument (the record) or two
      arguments named `source_id` and `record`. It must return a pair
      `(output, bucketing_key)` where `output` is a non-empty list of tensors
      or a non-empty `NestedMap` of tensors.

  Returns:
    A tuple `(proc_fn, out_types, output_tmpl)`: the concrete function, the
    list of its output dtypes (the last of which is asserted to be
    `tf.int32`), and a `NestedMap` whose `out_values` holds the structure
    returned by `processor`.

  Raises:
    ValueError: If `processor` takes neither a single argument nor exactly
      the two arguments `source_id` and `record`.
  """
  output_tmpl = py_utils.NestedMap()

  @tf.function(autograph=False)
  def _FlatOutputProcessor(source_id, record):
    """Calls `processor` and returns its flattened outputs + bucketing key."""
    spec = tf_inspect.getargspec(processor)
    tf.logging.debug('GenericInput.processor.argspec=%s', spec)
    # `self` is not an argument the caller has to provide.
    arg_names = set(spec.args) - {'self'}
    if len(arg_names) == 1:
      output, bucketing_key = processor(record)
    elif arg_names == {'source_id', 'record'}:
      output, bucketing_key = processor(source_id=source_id, record=record)
    else:
      raise ValueError(
          'GenericInput: processor should take either a single arg '
          'or two args named as "source_id" and "record". '
          'Actual: %s' % arg_names)

    # The output must be a non-empty list or NestedMap made of tensors only.
    if isinstance(output, list):
      assert output
      assert all(isinstance(x, tf.Tensor) for x in output), '{}'.format(output)
    else:
      assert isinstance(output, py_utils.NestedMap), '{}'.format(output)
      assert output
      assert all(isinstance(x, tf.Tensor)
                 for x in output.Flatten()), '{}'.format(output.DebugString())

    bucketing_key = tf.cast(bucketing_key, tf.int32)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', output,
                     bucketing_key)
    # Remember the output structure on the enclosing function's template.
    output_tmpl.out_values = output
    flat_output_tmpl = output_tmpl.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_output_tmpl)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     py_utils.GetExtraInputs(), py_utils.GetExtraArgs(),
                     py_utils.GetExtraVars())
    assert not py_utils.GetExtraArgs(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   py_utils.GetExtraArgs()))
    return flat_output_tmpl + [bucketing_key]

  source_id_spec = tf.TensorSpec([], tf.int32)
  record_spec = tf.TensorSpec([], tf.string)
  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by _FlatOutputProcessor.
    proc_fn = _FlatOutputProcessor.get_concrete_function(
        source_id_spec, record_spec)

  output_args = proc_fn.function_def.signature.output_arg
  out_types = [tf.DType(arg.type) for arg in output_args]
  assert out_types[-1] == tf.int32, ('%s is not expected.' % out_types[-1])
  return proc_fn, out_types, output_tmpl
def testDevBasedSchedule(self):
  """Checks `DevBasedSchedule` as dev-set metrics are appended over time.

  The schedule uses tolerance=1.0, window=2, decay=0.5 and min_factor=0.20.
  Metric values are appended to the schedule's metric history one at a time
  and the schedule value is checked after every append.
  """
  logdir = tf.test.get_temp_dir()
  tf.io.gfile.mkdir(os.path.join(logdir, 'eval_dev'))

  p = schedule.DevBasedSchedule.Params()
  p.tolerance = 1.0
  p.window = 2
  p.decay = 0.5
  p.min_factor = 0.20
  early_stop.MetricHistory.SetLogdirInMetricHistories(p, logdir)

  lrs = p.Instantiate()
  self.assertEqual(lrs.theta.cur_factor.name, 'LRSched/cur_factor/var:0')
  self.assertEqual(lrs.theta.ref_step.name, 'LRSched/ref_step/var:0')

  mh = lrs._metric_history
  mh.params.local_filesystem = True

  # (step of the appended metric, metric value, expected schedule value).
  cases = [
      (1, 10.0, 1.0),  # best = 1
      (2, 5.0, 1.0),  # best = 2
      (5, 4.0, 0.5),  # best = 2, out of window
      (6, 4.0, 0.5),  # best = 2, ref = 5, in window
      (9, 4.0, 0.25),  # best = 2, ref = 5, out of window
      (10, 3.9, 0.25),  # best = 10
      (13, 3.0, 0.20),  # best = 10, out of window, min factor
  ]
  with self.session():
    self.evaluate(tf.global_variables_initializer())
    with py_utils.GlobalStepContext(0):
      for metric_step, metric_value, expected in cases:
        mh.ConditionalAppend(mh.params.jobname, mh.params.metric,
                             metric_step, metric_value)
        self.assertAllClose(lrs.Value().eval(), expected)
def testPiecewiseConstant(self):
  """Checks `PiecewiseConstantSchedule` just below and past each boundary.

  The global step is supplied as a `tf.constant` rather than a Python int.
  """
  cls = schedule.PiecewiseConstantSchedule
  with self.session(use_gpu=False):
    boundaries = [300000, 400000, 500000]
    values = [1.0, 0.1, 0.01, 0.001]
    step_tensors = [
        tf.constant(step) for step in [299999, 399999, 499999, 599999]
    ]
    outs = []
    for step_tensor in step_tensors:
      with py_utils.GlobalStepContext(step_tensor):
        params = cls.Params().Set(boundaries=boundaries, values=values)
        lrs = params.Instantiate()
        outs.append(lrs.Value().eval())
    self.assertAllClose([1.0, 0.1, 0.01, 0.001], outs)
def testInverseSigmoid(self):
  """Checks `InverseSigmoid` (k=10000) every 25,000 steps up to 175,000."""
  p = schedule.InverseSigmoid.Params().Set(k=10000)

  expected = [
      [0, 0.999900],
      [25000, 0.998783],
      [50000, 0.985376],
      [75000, 0.846880],
      [100000, 0.312242],
      [125000, 0.035928],
      [150000, 0.003050],
      [175000, 0.000251],
  ]
  with self.session():
    lrs = p.Instantiate()
    pts = []
    for step in range(0, 200000, 25000):
      with py_utils.GlobalStepContext(step):
        pts.append([step, lrs.Value().eval()])
    self.assertAllClose(expected, pts)
def testLRDecay(self):
  """Checks the piecewise learning-rate decay configured on the test params.

  The schedule is evaluated just before and just after each of the three
  boundaries; all values are fetched in a single `evaluate` call.
  """
  with self.session(use_gpu=False, graph=tf.Graph()):
    tp = self._testParams().train
    tp.lr_schedule.boundaries = [300000, 400000, 500000]
    tp.lr_schedule.values = [1.0, 0.1, 0.01, 0.001]
    lrs = tp.lr_schedule.Instantiate()

    probe_steps = (299999, 300001, 399999, 400001, 499999, 500001)
    fetches = []
    for step in probe_steps:
      with py_utils.GlobalStepContext(step):
        fetches.append(lrs.Value())
    values = self.evaluate(fetches)
    self.assertAllClose([1.0, 0.1, 0.1, 0.01, 0.01, 0.001], values)
def testExponentialWithLinearRamp(self):
  """Checks the schedule set by `lr_util.SetExponentialLR` with warmup.

  With warmup_epoch=1, exp_start_epoch=2 and total_epoch=10 the value must be
  below 1 at step 8, exactly 1 at steps 16 and 24, and below 1 at step 48.
  """
  p = self._testParams()
  lr_util.SetExponentialLR(
      p.train,
      p.input,
      warmup_epoch=1,
      exp_start_epoch=2,
      total_epoch=10,
      warmup_init=0.)
  schedule_layer = p.train.lr_schedule.Instantiate()

  def _ValueAt(step):
    """Evaluates the schedule with the global step overridden to `step`."""
    with py_utils.GlobalStepContext(step):
      return self.evaluate(schedule_layer.Value())

  with self.session():
    # Linear ramp up.
    self.assertLess(_ValueAt(8), 1.)
    # Peak learning rate.
    self.assertEqual(_ValueAt(16), 1.)
    # Still at peak learning rate.
    self.assertEqual(_ValueAt(24), 1.)
    # Exponential ramp down.
    self.assertLess(_ValueAt(48), 1.)
def _InitIterator(self):
  """Creates and caches the dataset and its iterator for the current host.

  Does nothing when a dataset is already registered under `self.host_id`.
  Otherwise builds the dataset with `self.GetDataset()`, marks it as
  non-deterministic, and stores the dataset and an iterator over it in
  `self._dataset` and `self._iterator`, both keyed by `self.host_id`.
  """
  if self.host_id in self._dataset:
    return

  with py_utils.GlobalStepContext(None):
    # Hide global_step tensor from being captured by dataset function.
    ds = self.GetDataset()

  # Options must be attached with `with_options`: mutating the object returned
  # by `ds.options()` is not a supported way to configure a dataset, so the
  # setting could be silently dropped (or rejected by newer TF releases).
  options = tf.data.Options()
  options.experimental_deterministic = False
  ds = ds.with_options(options)
  self._dataset[self.host_id] = ds

  if tf.executing_eagerly():
    it = iter(ds)
  else:
    it = tf.data.make_initializable_iterator(ds)
  self._iterator[self.host_id] = it
def Value(self):
  """Returns the value of the sub-schedule active at the current global step.

  Every sub-schedule is evaluated with the global step overridden to the
  number of steps elapsed since the start of its interval (clamped at 0).
  `py_utils.PiecewiseConstant` then picks among those values using the real
  global step and `p.boundaries`.
  """
  p = self.params
  step = tf.cast(py_utils.GetGlobalStep(), tf.int64)
  # The first interval starts at step 0; each boundary starts the next one.
  starts = [0] + p.boundaries

  values = []
  for start, sub_schedule in zip(starts, self.schedules):
    zero = tf.cast(0, step.dtype)
    elapsed = step - tf.cast(start, step.dtype)
    relative_step = tf.maximum(zero, elapsed)
    with py_utils.GlobalStepContext(relative_step):
      values.append(sub_schedule.Value())

  return py_utils.PiecewiseConstant(step, p.boundaries, values,
                                    values[0].dtype)