def BuildTpuSubgraph(self):
  """Builds the TPU decode graph and stores its packed outputs.

  The model is instantiated inside an eval-mode, opportunistic-variable-reuse
  scope, the task's `Decode` is sharded over `self.data_parallelism` shards
  with `tpu.split_compile_and_shard`, and the flat per-shard results are
  re-packed into the nested structure `Decode` returned.

  Side effects: sets `self._model`, `self._model_task`, `self.metrics_nm`,
  `self._compile_op` and `self.metrics`.

  Returns:
    None.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        self._model_task = self._model.GetTask()
        on_tpu = py_utils.use_tpu()
        generator = self._model_task.input_generator
        if on_tpu:
          input_batch = generator.CreateTpuFeeds()
        else:
          # Off TPU, split the host batch into per-client splits instead.
          input_batch = generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        decoded = self._model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(decoded)
        return self.metrics_nm.Flatten()

  self._compile_op, shard_outputs = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # `self.metrics_nm` was captured while tracing _DecodeFn; use its structure
  # to re-nest the flat shard outputs.
  self.metrics = py_utils.NestedMap(self.metrics_nm).Pack(shard_outputs)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU decode graph fed by an input generator built up front.

  The input generator is instantiated and its TPU enqueue ops are created
  before the model. It is then attached to the task as its `input` child
  inside the compiled decode function.

  Side effects: sets `self._input`, `self._model`, `self._task`,
  `self.metrics_nm`, `self._compile_op` and `self.metrics`.

  Returns:
    None.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()

  # The input generator must exist (with its enqueue ops) before the model.
  input_params = self._task_params.input
  self._input = input_params.Instantiate()
  self._input.CreateTpuEnqueueOps()
  # NOTE(review): presumably prevents the task from instantiating its own
  # input child, since the pre-built one is attached below — confirm.
  self.SkipCreateChild(self._task_params)

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        task = self._model.GetTask()
        self._task = task
        task.AddChild('input', self._input)
        batch = task.input.TpuDequeueBatch()
        self.metrics_nm = py_utils.NestedMap(task.Decode(batch))
        return self.metrics_nm.Flatten()

  self._compile_op, shard_outputs = tpu.split_compile_and_shard(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Re-nest the flat shard outputs with the structure captured in tracing.
  self.metrics = py_utils.NestedMap(self.metrics_nm).Pack(shard_outputs)
  return None
def ConstructFPropBPropGraph(self):
  """Builds the forward and backward graph for `self._task`.

  Resets the step seed, runs forward propagation with the task's default
  theta, then backpropagation. If `self.ema` is truthy, the task's
  exponential moving average is applied as well.
  """
  py_utils.ResetStepSeed()
  task = self._task
  task.FPropDefaultTheta()
  task.BProp()
  if not self.ema:
    return
  tf.logging.info('ApplyExponentialMovingAverage on %s', task)
  task.ApplyExponentialMovingAverage(self.ema)
def FProp(self, theta, input_batch):
  """Forward propagation.

  This default `FProp` supports batch splitting in synchronous and
  asynchronous training when sub-classes implement `FPropTower`. On TPU it
  delegates to `_FPropTpu`; otherwise to `_FPropSplitInputBatch`. The result
  is passed to `_FPropResult` before being returned.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: The input batch. A `NestedMap` of tensors. Or, if input batch
      splitting is used, a list of `NestedMap`, one for each split.

  Returns:
    Two dicts:

    - A dict containing str keys and (metric, weight) pairs as values, where
      one of the keys is expected to be 'loss'.
    - A dict containing arbitrary tensors describing something about each
      training example, where the first dimension of each tensor is the batch
      index.
  """
  p = self.params
  with tf.name_scope('fprop'), tf.name_scope(p.name):
    # Always reset step seed at the start of a new global_step.
    py_utils.ResetStepSeed()
    fprop_impl = (
        self._FPropTpu if py_utils.use_tpu() else self._FPropSplitInputBatch)
    metrics, per_example = fprop_impl(theta, input_batch)
    self._FPropResult(metrics, per_example)
  return metrics, per_example
def Bak(inputs, outputs, d_outputs):
  """Backward step.

  Backward pass for a stack of reversible sub-layers. Walks the layers from
  last to first; each layer's `ReverseAndGrad` returns its reconstructed
  inputs, the gradient w.r.t. those inputs, and its weight gradients, which
  are threaded into the previous layer.

  Free variables (`num_layers`, `self`, `theta`, `extra_inputs`,
  `initial_step_seed`, `final_step_seed`) come from an enclosing scope that
  is not part of this block.

  Args:
    inputs: Unused.
    outputs: Pair `(output_acts, step_seeds)`: the stack's final output
      activations and, per layer, an `(f_seed, g_seed)` pair.
    d_outputs: Sequence whose first element is the gradient w.r.t. the final
      output activations.

  Returns:
    A list: zeros shaped like `initial_step_seed`; a `NestedMap` with a zero
    `global_step` gradient and `sub_layers` gradients in forward layer order;
    the gradient w.r.t. the stack inputs; and zeros for each `extra_inputs`
    tensor.
  """
  del inputs  # unused
  output_acts, step_seeds = outputs
  d_outputs = d_outputs[0]

  d_layer_thetas = []
  for layer_idx in reversed(range(num_layers)):
    # Seeds recorded in the forward pass for this layer's f and g blocks.
    f_seed, g_seed = step_seeds[layer_idx]
    layer = self.sub_layers[layer_idx]
    layer_theta = theta.sub_layers[layer_idx]

    input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
        layer_theta, output_acts, d_outputs, f_seed, g_seed, *extra_inputs)

    d_layer_thetas.append(d_theta)
    # Passes reconstructed inputs to the previous layer.
    output_acts = input_acts
    d_outputs = d_inputs

  # NOTE(review): d_inputs is only bound inside the loop above; a stack with
  # num_layers == 0 would raise NameError at the return below.
  py_utils.ResetStepSeed(final_step_seed)
  d_theta = py_utils.NestedMap(
      global_step=tf.zeros_like(initial_step_seed))
  # Gradients were collected last-layer-first; restore forward order.
  d_theta.sub_layers = list(reversed(d_layer_thetas))

  # No gradient flows to the extra inputs.
  extra_grads = [tf.zeros_like(t) for t in extra_inputs]
  return [
      tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads
  ]
def testTransformerAttentionLayerDeterministicDropout(self):
  """Checks TransformerAttentionLayer outputs with deterministic dropout.

  Uses `DeterministicDropoutLayer` as the residual dropout and compares the
  attention context and probabilities against golden values.
  """
  with self.session(use_gpu=True) as sess:
    # Needed to generate a seed pair.
    py_utils.ResetStepSeed()
    py_utils.GetOrCreateGlobalStep()

    depth = 4
    p = layers_with_attention.TransformerAttentionLayer.Params()
    p.name = 'transformer_atten'
    p.source_dim = depth
    p.is_masked = False
    p.num_attention_heads = 2
    p.residual_dropout_tpl = layers.DeterministicDropoutLayer.Params()
    p.residual_dropout_prob = 0.1
    transformer_atten = layers_with_attention.TransformerAttentionLayer(p)

    (source_vecs, source_padding, _,
     _) = self._testTransformerAttentionLayerInputs(depth=depth)

    ctx, probs = transformer_atten.FProp(transformer_atten.theta, source_vecs,
                                         source_padding)

    tf.global_variables_initializer().run()
    actual_ctx, actual_probs = sess.run([ctx, probs])

    # pylint: disable=bad-whitespace
    # pyformat: disable
    expected_ctx = np.array([
        [[-1.45762944,  1.5337404 ,  0.34037334, -0.97208667],
         [-1.35992002, -1.06530988,  1.53705895,  2.79370689]],
        [[ 0.00657134,  1.12030125, -1.32564592, -1.73569465],
         [-0.80793667, -0.10877949, -0.80295694,  2.25494242]],
        [[ 1.76956046, -0.50777751, -1.19745886, -1.46751583],
         [-1.79178905, -0.77374339,  1.31586027,  2.98173356]],
        [[-0.85498607, -0.37413225,  1.25707364, -0.50043333],
         [ 1.62276983,  0.50820369, -1.52967572, -2.02076197]],
        [[-0.66754031, -0.68657839, -0.51643699,  1.96581018],
         [-1.4816376 ,  0.89419198, -0.57226259,  1.90177512]]
    ], dtype=np.float32)

    expected_probs = np.array([
        [[ 0.21387868,  0.22080734,  0.        ,  0.        ,  0.56531399],
         [ 0.        ,  0.30584112,  0.24723588,  0.44692296,  0.        ]],
        [[ 0.25358215,  0.50932312,  0.        ,  0.        ,  0.23709476],
         [ 0.        ,  0.56834149,  0.2632803 ,  0.16837817,  0.        ]],
        [[ 0.38519409,  0.55454361,  0.        ,  0.        ,  0.06026226],
         [ 0.        ,  0.33708778,  0.21976741,  0.4431448 ,  0.        ]],
        [[ 0.27139962,  0.12790371,  0.        ,  0.        ,  0.60069668],
         [ 0.        ,  0.31849149,  0.28174096,  0.39976761,  0.        ]],
        [[ 0.16272782,  0.15781289,  0.        ,  0.        ,  0.67945927],
         [ 0.        ,  0.55003977,  0.26049581,  0.18946445,  0.        ]]
    ], dtype=np.float32)
    # pyformat: enable
    # pylint: enable=bad-whitespace
    self.assertAllClose(expected_ctx, actual_ctx, rtol=1e-05, atol=1e-05)
    self.assertAllClose(expected_probs, actual_probs, rtol=1e-05, atol=1e-05)
def FProp(self, theta, inputs, *extra_inputs):
  """Runs the stack of reversible sub-layers.

  When `self.params.custom_gradient` is set, the forward pass (`Fwd`) is
  wrapped in `py_utils.CallDefun` with `Bak` as its gradient function. The
  backward pass then rebuilds each layer's inputs via `ReverseAndGrad`.
  Otherwise the sub-layers are simply applied in sequence.

  Args:
    theta: A `NestedMap` whose `sub_layers` holds one theta per sub-layer.
    inputs: Input activations for the first sub-layer.
    *extra_inputs: Extra inputs passed to every sub-layer. In the
      custom-gradient path their gradients are zeros.

  Returns:
    The output activations of the last sub-layer.
  """
  # Step seed in effect on entry; Fwd resets to it before running the layers.
  initial_step_seed = py_utils.GetStepSeed()
  # Step seed restored after the custom-gradient forward (below) and after
  # the backward pass (end of Bak).
  final_step_seed = py_utils.GenerateSeedFromName(
      tf.no_op(name='new_step_seed').name)
  num_layers = len(self.sub_layers)

  def Bak(inputs, outputs, d_outputs):
    """Backward step."""
    del inputs  # unused
    output_acts, step_seeds = outputs
    d_outputs = d_outputs[0]

    d_layer_thetas = []
    for layer_idx in reversed(range(num_layers)):
      f_seed, g_seed = step_seeds[layer_idx]
      layer = self.sub_layers[layer_idx]
      layer_theta = theta.sub_layers[layer_idx]

      input_acts, d_inputs, d_theta = layer.ReverseAndGrad(
          layer_theta, output_acts, d_outputs, f_seed, g_seed, *extra_inputs)

      d_layer_thetas.append(d_theta)
      # Passes reconstructed inputs to the previous layer.
      output_acts = input_acts
      d_outputs = d_inputs

    py_utils.ResetStepSeed(final_step_seed)
    # NOTE(review): this gradient map carries only `sub_layers`. If `theta`
    # also has other entries (e.g. `global_step`), the structure returned to
    # CallDefun may need to match — confirm.
    d_theta = py_utils.NestedMap()
    # Gradients were collected last-layer-first; restore forward order.
    d_theta.sub_layers = list(reversed(d_layer_thetas))

    extra_grads = [tf.zeros_like(t) for t in extra_inputs]
    # One gradient per Fwd input: (initial_step_seed, theta, acts,
    # extra_inputs).
    return [
        tf.zeros_like(initial_step_seed), d_theta, d_inputs, extra_grads
    ]

  def Fwd(xs):
    """Forward pass."""
    initial_step_seed, theta, acts, extra_inputs = xs

    py_utils.ResetStepSeed(initial_step_seed)
    layer_step_seeds = []

    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      # Each layer reports the seeds its f/g blocks used, for Bak to replay.
      acts, f_seed, g_seed = layer.FProp(layer_theta, acts, *extra_inputs)
      layer_step_seeds += [(f_seed, g_seed)]
    return [acts, layer_step_seeds]

  if self.params.custom_gradient:
    acts, _ = py_utils.CallDefun(
        Fwd, [initial_step_seed, theta, inputs, extra_inputs], Bak)
    py_utils.ResetStepSeed(final_step_seed)
    return acts
  else:
    acts = inputs
    for layer_theta, layer in zip(theta.sub_layers, self.sub_layers):
      acts, _, _ = layer.FProp(layer_theta, acts, *extra_inputs)
    return acts
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph by compiling the decode loop in eval mode.

  Also records on `self.spmd` whether the input params request a partitioned
  infeed queue.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  input_params = self._task_params.input
  self.spmd = input_params.use_partitioned_infeed_queue
  with cluster_factory.SetEval(True):
    self._CompileDecodeLoop()
def Fwd(xs):
  """Forward pass.

  Args:
    xs: A sequence `(initial_step_seed, theta, acts, extra_inputs)`.

  Returns:
    `[acts, layer_step_seeds]`: the last layer's output activations and one
    `(f_seed, g_seed)` tuple per layer, in layer order.
  """
  initial_step_seed, theta, acts, extra_inputs = xs
  # Start the stack from the caller-supplied step seed.
  py_utils.ResetStepSeed(initial_step_seed)

  seeds_per_layer = []
  for sub_theta, sub_layer in zip(theta.sub_layers, self.sub_layers):
    acts, f_seed, g_seed = sub_layer.FProp(sub_theta, acts, *extra_inputs)
    seeds_per_layer.append((f_seed, g_seed))
  return [acts, seeds_per_layer]
def BuildTpuSubgraph(self):
  """Builds the on-device decode loop and the host-side outfeed dequeue ops.

  The model is instantiated with immediate variable instantiation disabled,
  and the input generator's variables and TPU enqueue ops are created on
  the host. The decode step (dequeue a batch, run `Decode`, outfeed the
  flattened metrics) is repeated `self._steps_per_loop` times in an
  on-device loop compiled with `tpu.split_compile_and_shard`.

  Side effects: sets `self.spmd`, `self._model`, `self._task`,
  `self.metrics_nm`, `self._compile_op`, `self.decode_loop` and
  `self.metrics`.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  device_assignment = py_utils.GetTpuDeviceAssignment()
  self.spmd = self._task_params.input.use_partitioned_infeed_queue
  # Everything below, including tracing _DecodeStep during compilation, runs
  # in eval mode; _DecodeStep itself does not enter SetEval.
  with cluster_factory.SetEval(True):
    with cluster_factory.SetImmediatelyInstantiateVariables(False):
      self._model = self._task_params.Instantiate()
    self._task = self._model.GetTask()
    # Input variables and enqueue ops are created here, outside the TPU
    # computation; model variables are instantiated inside _DecodeStep.
    self._task.input.InstantiateVariables()
    self._task.input.CreateTpuEnqueueOps()

    def _DecodeStep():
      """Decode call to be compiled for TPU."""
      with py_utils.OpportunisticVariableReuseScope(True):
        self._model.InstantiateVariables()
        input_batch = self._task.input.TpuDequeueBatch()
        metrics_dict = self._task.Decode(input_batch)
      self.metrics_nm = py_utils.NestedMap(metrics_dict)
      # With a partitioned infeed (spmd), pin the outfeed to TPU core 0.
      device = tpu.core(0) if self.spmd else ''
      with tf.device(device):
        outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
            self.metrics_nm.Flatten())
        return [outfeed_enqueue]

    @tpu_function.on_device_training_loop
    def DecodeLoopFn():
      return tpu_training_loop.repeat(self._steps_per_loop, _DecodeStep,
                                      inputs=[])

    self._compile_op, self.decode_loop = tpu.split_compile_and_shard(
        DecodeLoopFn,
        num_shards=self.data_parallelism,
        device_assignment=device_assignment)
    # Get a list of outfeed ops.
    self.metrics = self._OutfeedDequeue()
    # Pack the list of outfeed ops with structure in self.metrics_nm.
    self.metrics = tf.nest.pack_sequence_as(self.metrics_nm, self.metrics)
  return
def BuildTpuSubgraph(self):
  """Builds the decode graph with `tf.tpu.batch_parallel` in eval mode.

  The task's `Decode` runs on `self.data_parallelism` shards. The flat
  per-shard results are re-packed into the nested structure `Decode`
  returned and stored on `self.metrics`.

  Returns:
    None.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._model = self._task_params.Instantiate()
        task = self._model.GetTask()
        self._model_task = task
        batch = task.GetInputBatch()
        self.metrics_nm = py_utils.NestedMap(task.Decode(batch))
        return self.metrics_nm.Flatten()

  shard_outputs = tf.tpu.batch_parallel(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())
  # Re-nest the flat shard outputs using the structure captured in tracing.
  self.metrics = py_utils.NestedMap(self.metrics_nm).Pack(shard_outputs)
  return None
def BuildTpuSubgraph(self):
  """Builds the decode graph with `tf.tpu.batch_parallel` plus a checkpointer.

  The task's `Decode` runs on `self.data_parallelism` shards. The flat
  results are re-packed onto `self.metrics`, and `self._checkpointer` is
  created for the model instantiated during tracing.

  Returns:
    None.
  """
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    # NOTE(review): the model is not instantiated under
    # `cluster_factory.SetEval(True)` here; confirm eval mode is configured
    # elsewhere for this decode path.
    with py_utils.OpportunisticVariableReuseScope(True):
      self._model = self._task_params.Instantiate()
      task = self._model.GetTask()
      self._model_task = task
      decoded = task.Decode(task.GetInputBatch())
      self.metrics_nm = py_utils.NestedMap(decoded)
      return self.metrics_nm.Flatten()

  shard_outputs = tf.tpu.batch_parallel(
      _DecodeFn,
      num_shards=self.data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  # `self._model` only exists once _DecodeFn has been traced above.
  self._checkpointer = checkpointer.Checkpointer(self._checkpoint_dir,
                                                 self._model)
  self.metrics = py_utils.NestedMap(self.metrics_nm).Pack(shard_outputs)
  return None
def BuildTpuSubgraph(self):
  """Builds one TPU program that trains for a loop and then decodes.

  The compiled computation runs `self._train_steps_per_loop` training steps
  (`TpuTrain`) and, under a control dependency on that loop, one decode
  step (`_DecodeFn`). It is sharded over `self.data_parallelism` shards. The
  flat decode outputs are re-packed onto `self.metrics`.

  Returns:
    None.
  """
  # Log run hyperparameters via mlp_log when ML-perf logging is enabled.
  if self._ml_perf_log:
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
       [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    def TpuTrain():
      # Repeats TpuTrainStep self._train_steps_per_loop times on device.
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')
      return loop_result

  # Reset the step seed before building the decode model.
  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        if py_utils.use_tpu():
          input_batch = self._decode_model_task.input_generator.CreateTpuFeeds(
          )
        else:
          input_batch = self._decode_model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._decode_model_task.Decode(input_batch)
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  @tpu_function.on_device_training_loop
  def TrainAndDecode():
    # Decode only after the training loop's result is available.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  # Re-nest the flat decode outputs with the structure captured in tracing.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def BuildTpuSubgraph(self):
  """Builds the TPU decode subgraph by calling `_CompileDecodeFn` in eval mode.

  Returns:
    None.
  """
  tf.logging.info('DecodeProgram BuildTpuSubGraph')
  py_utils.ResetStepSeed()
  with cluster_factory.SetEval(True):
    self._CompileDecodeFn()
def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                   *extra_inputs):
  """Implements Algorithm 1 in the revnet paper.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
    d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and dy2,
      the total derivatives.
    f_seed: Scalar tensor. The step seed used in forward for the f block.
    g_seed: Scalar tensor. The step seed used in forward for the g block. The
      step seeds are needed for deterministic randomness, e.g. to ensure
      dropout generates the same random mask in forward and reverse_grad.
    *extra_inputs: additional inputs that will be passed to both f and g. No
      gradient will be computed for these inputs.

  Returns:
    A tuple of NestedMaps

    - inputs: .split1 and .split2 corresponding to x1 and x2.
    - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
      derivatives with respect to inputs.
    - d_theta: has the same structure as theta. The total derivatives with
      respect to weights. Its `global_step` entry is all zeros.
  """

  # Stop gradient on the outputs to avoid circular symbolic dependency.
  y1 = tf.stop_gradient(outputs.split1)
  y2 = tf.stop_gradient(outputs.split2)
  dy1 = d_outputs.split1
  dy2 = d_outputs.split2

  # Computes the reverse.
  # The subtractions below invert y1 = z1 = x1 + f(x2), y2 = x2 + g(z1).
  # Each block is re-run with the seed it used in forward so its random ops
  # reproduce the forward values.
  z1 = y1
  py_utils.ResetStepSeed(g_seed)
  gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
  x2 = y2 - gz1
  py_utils.ResetStepSeed(f_seed)
  fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
  x1 = z1 - fx2

  # Computes the gradients.
  # Total derivative of z1: direct path (dy1) plus the path through g into y2.
  dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
  # Total derivative of x2: direct path into y2 plus the path through f into z1.
  dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]

  # Weight gradients; weights not reached by the block get zeros rather than
  # None so the result can be packed back into theta's structure.
  dgw = tf.gradients(
      gz1,
      theta.g_block.Flatten(),
      dy2,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dgw = theta.g_block.Pack(dgw)

  dfw = tf.gradients(
      fx2,
      theta.f_block.Flatten(),
      dz1,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dfw = theta.f_block.Pack(dfw)

  # dx1 equals dz1 because x1 enters z1 = x1 + f(x2) with unit coefficient.
  return (py_utils.NestedMap(split1=x1, split2=x2),
          py_utils.NestedMap(split1=dz1, split2=dx2),
          py_utils.NestedMap(
              f_block=dfw,
              g_block=dgw,
              global_step=tf.zeros_like(theta.global_step)))