def decode_loop_fn():
  """Builds the on-device decode loop and returns it.

  If `self.num_batches` is falsy (e.g. 0 or None), the loop is built with
  `infinite_repeat(decode_fn, infeed)`. Its definition is not shown here;
  presumably it builds an unbounded loop, so confirm against its source.
  Otherwise `decode_fn` is repeated `self.num_batches` times with
  `training_loop.repeat`, reading its inputs from the `infeed` queue.

  Returns:
    Whatever the selected loop builder returns.
  """
  # Both branches hand the built loop back to the caller. This matches the
  # other loop builders, which return the result of `repeat`; without it
  # the function would always return None.
  if not self.num_batches:
    return infinite_repeat(decode_fn, infeed)
  return training_loop.repeat(self.num_batches, decode_fn, infeed_queue=infeed)
def TpuTrain():
  """Repeats `TpuTrainStep` for `self._train_steps_per_loop` iterations.

  The loop carries no inputs and is named 'train_loop'.

  Returns:
    The value returned by `tpu_training_loop.repeat`.
  """
  return tpu_training_loop.repeat(
      self._train_steps_per_loop,
      TpuTrainStep,
      inputs=[],
      name='train_loop',
  )
def TpuEval():
  """Runs `TpuEvalStep` in a loop and finalizes the accumulated metrics.

  The loop runs `self._steps_per_loop` times, is seeded with
  `self._eval_metrics.initial_values`, and is named 'eval_loop'.

  Returns:
    `self._eval_metrics.FinalizeMetrics(...)` applied to the loop outputs.
  """
  per_step_outputs = tpu_training_loop.repeat(
      self._steps_per_loop,
      TpuEvalStep,
      inputs=self._eval_metrics.initial_values,
      name='eval_loop',
  )
  # The finalized metrics are averaged over the self._steps_per_loop steps.
  return self._eval_metrics.FinalizeMetrics(per_step_outputs)
def eval_loop(self):
  """Builds the eval loop and pushes its final predictions onto the outfeed.

  A zero predictions buffer of shape
  `[self.eval_steps, self.eval_batch_size // self.num_replicas, 2]` is
  threaded, together with a step counter starting at 0, through
  `int(self.eval_steps)` iterations of `self.eval_step`. Variable reuse is
  enabled on the current variable scope first.

  Returns:
    A `tf.no_op()` with a control dependency on the outfeed enqueue of the
    final predictions buffer.
  """
  replica_batch = self.eval_batch_size // self.num_replicas
  tf.get_variable_scope().reuse_variables()
  initial_predictions = tf.zeros([self.eval_steps, replica_batch, 2])
  # The loop yields (counter, predictions); only the predictions are kept.
  _, final_predictions = training_loop.repeat(
      int(self.eval_steps),
      self.eval_step,
      [tf.constant(0), initial_predictions],
  )
  enqueue_op = tpu_ops.outfeed_enqueue_tuple([final_predictions])
  with tf.control_dependencies([enqueue_op]):
    return tf.no_op()
def DecodeLoopFn():
  """Repeats `_DecodeStep` `self._steps_per_loop` times with no carried inputs.

  Returns:
    The value returned by `tpu_training_loop.repeat`.
  """
  decode_loop = tpu_training_loop.repeat(
      self._steps_per_loop, _DecodeStep, inputs=[])
  return decode_loop
def loop():
  """Repeats `training_step` five times, reading inputs from `infeed`.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  num_iterations = 5
  return training_loop.repeat(
      num_iterations, training_step, infeed_queue=infeed)
def _experimental_run_steps_on_iterator(
    self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
  """Builds ops that run `fn` on every replica for multiple steps.

  Each step pulls one batch from `multi_worker_iterator`, splits it per
  replica, and runs `fn` on all replicas via `tpu.replicate`. The step is
  built once when `self.steps_per_run == 1`; otherwise it is wrapped in
  `training_loop.repeat(iterations, ...)`. The whole thing is placed on
  `self._host_device`.

  Args:
    fn: Step function, called as `fn(ctx, inputs)` inside each replica.
      If it records outputs in `ctx.last_step_outputs`, those become the
      replica's tensor outputs.
    multi_worker_iterator: Iterator providing `output_shapes` and
      `get_next()`. All of its output shapes must be fully defined.
    iterations: Loop count passed to `training_loop.repeat`. It is not used
      when `self.steps_per_run == 1`.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened and tiled once per replica.

  Returns:
    The `input_lib.MultiStepContext` passed to `fn`. Its `run_op` is set to
    a group of the replicated outputs, and its last-step outputs are set
    (via `_set_last_step_outputs`) to the tensor outputs regrouped by
    output across replicas.

  Raises:
    ValueError: If any of the iterator's output shapes is not fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(inputs):
    """Single step on the TPU device."""
    fn_result = fn(ctx, inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Identity copies that run after `fn_result`, so the recorded
      # last-step outputs are read only once the step has executed.
      with ops.control_dependencies([fn_result]):
        return [array_ops.identity(f) for f in flat_last_step_outputs]
    else:
      return fn_result

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  def rewrite_fn(*args):
    """The rewritten step fn running on TPU."""
    # Loop-carried values from `repeat` are ignored; inputs come from the
    # iterator on every step.
    del args
    per_replica_inputs = multi_worker_iterator.get_next()
    replicate_inputs = []
    for replica_id in range(self._num_replicas_in_sync):
      # The lambda is consumed immediately by map_structure below, so the
      # late-binding of `replica_id` is harmless here.
      select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
      replicate_inputs.append((nest.map_structure(
          select_replica, per_replica_inputs),))

    replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

    # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
    # will flatten it in this case. If run_fn has no tensor outputs,
    # tpu.replicate returns a list of no_ops, we will keep the output as it
    # is.
    if isinstance(replicate_outputs[0], list):
      replicate_outputs = nest.flatten(replicate_outputs)

    return replicate_outputs

  # TODO(sourabhbajaj): The input to while loop should be based on the
  # output type of the step_fn
  assert isinstance(initial_loop_values, list)
  # One copy of the flattened initial values per replica, presumably to line
  # up with rewrite_fn's flattened per-replica outputs (see TODO above).
  initial_loop_values = initial_loop_values * self._num_replicas_in_sync

  # Put the while loop op on TPU host 0.
  with ops.device(self._host_device):
    # NOTE(review): the single-step shortcut checks self.steps_per_run, not
    # `iterations`; if a caller passes iterations != steps_per_run the two
    # disagree. Confirm this is intended.
    if self.steps_per_run == 1:
      replicate_outputs = rewrite_fn()
    else:
      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                               initial_loop_values)

  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs)

  if isinstance(replicate_outputs, list):
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (flattened)
    # [output0_device0, output1_device0, output2_device0,
    #  output0_device1, output1_device1, output2_device1,
    #  ...]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
    last_step_tensor_outputs = [
        last_step_tensor_outputs[i::output_num] for i in range(output_num)
    ]
  else:
    # no tensors returned.
    last_step_tensor_outputs = []

  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx
def train_eval_loop():
  """Repeats `train_eval_step` `self.num_epochs_tensor` times.

  No loop-carried state is threaded through the iterations.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  num_epochs = self.num_epochs_tensor
  no_loop_state = []
  return training_loop.repeat(num_epochs, train_eval_step, no_loop_state)
def train_eval_loop():
  """Repeats `train_eval_step` `self.hparams.max_train_epochs` times.

  The loop carries no state between iterations.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  max_epochs = self.hparams.max_train_epochs
  return training_loop.repeat(
      max_epochs,
      train_eval_step,
      [],
  )
def train_loop():
  """Repeats `tpu_train_step` `self.train_steps_tensor` times.

  The loop is seeded with a single carried value, `_INITIAL_LOSS`.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  num_steps = self.train_steps_tensor
  loop_state = [_INITIAL_LOSS]
  return training_loop.repeat(num_steps, tpu_train_step, loop_state)
def train_eval_loop():
  """Repeats `train_eval_step` `self.max_train_iterations` times.

  A single carried value, initialized to 0, is threaded through the loop.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  max_iterations = self.max_train_iterations
  initial_carry = [0]
  return training_loop.repeat(max_iterations, train_eval_step, initial_carry)
def train_loop():
  """Repeats `tpu_train_step` `self.iterations` times starting from `_INITIAL_LOSS`.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  step_count = self.iterations
  return training_loop.repeat(
      step_count,
      tpu_train_step,
      [_INITIAL_LOSS],
  )
def train_loop():
  """Repeats `train_step` `self.iterations_per_loop` times.

  The loop's initial input is a scalar `tf.constant(0)`, passed as-is (not
  wrapped in a list).

  Returns:
    The value returned by `training_loop.repeat`.
  """
  steps = self.iterations_per_loop
  initial_value = tf.constant(0)
  return training_loop.repeat(steps, train_step, initial_value)
def eval_loop(self):
  """Repeats `self.eval_step` `int(self.eval_steps)` times with variable reuse.

  Reuse is switched on for the current variable scope before the loop is
  built. No loop inputs are passed.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  current_scope = tf.get_variable_scope()
  current_scope.reuse_variables()
  num_steps = int(self.eval_steps)
  return training_loop.repeat(num_steps, self.eval_step)
def eval_loop():
  """Repeats `tpu_eval_step` `self.eval_steps_tensor` times with no carried state.

  Returns:
    The value returned by `training_loop.repeat`.
  """
  empty_inputs = []
  loop_op = training_loop.repeat(
      self.eval_steps_tensor, tpu_eval_step, empty_inputs)
  return loop_op
def _experimental_run_steps_on_iterator(self, fn, multi_worker_iterator,
                                        iterations, initial_loop_values=None):
  """Builds ops that run `fn` on every replica for multiple steps.

  Each step pulls one batch from `multi_worker_iterator`, splits it per
  replica, and runs `fn` on all replicas via `tpu.replicate`. The step is
  built once when `self.steps_per_run == 1`; otherwise it is wrapped in
  `training_loop.repeat(iterations, ...)`. The whole thing is placed on
  `self._host_device`.

  Unlike some variants of this method, this one performs no up-front check
  that the iterator's output shapes are fully defined.

  Args:
    fn: Step function, called as `fn(ctx, inputs)` inside each replica.
      If it records outputs in `ctx.last_step_outputs`, those become the
      replica's tensor outputs.
    multi_worker_iterator: Iterator providing `get_next()`, which yields the
      per-replica inputs for one step.
    iterations: Loop count passed to `training_loop.repeat`. It is not used
      when `self.steps_per_run == 1`.
    initial_loop_values: Optional nested structure of initial loop-carried
      values. It is flattened and tiled once per replica.

  Returns:
    The `input_lib.MultiStepContext` passed to `fn`. Its `run_op` is set to
    a group of the replicated outputs, and its last-step outputs are set
    (via `_set_last_step_outputs`) to the tensor outputs regrouped by
    output across replicas.
  """
  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(inputs):
    """Single step on the TPU device."""
    fn_result = fn(ctx, inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Identity copies that run after `fn_result`, so the recorded
      # last-step outputs are read only once the step has executed.
      with ops.control_dependencies([fn_result]):
        return [
            array_ops.identity(f) for f in flat_last_step_outputs
        ]
    else:
      return fn_result

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  def rewrite_fn(*args):
    """The rewritten step fn running on TPU."""
    # Loop-carried values from `repeat` are ignored; inputs come from the
    # iterator on every step.
    del args
    per_replica_inputs = multi_worker_iterator.get_next()
    replicate_inputs = []
    for replica_id in range(self._num_replicas_in_sync):
      # The lambda is consumed immediately by map_structure below, so the
      # late-binding of `replica_id` is harmless here.
      select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
      replicate_inputs.append(
          (nest.map_structure(select_replica, per_replica_inputs), ))

    replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

    # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
    # will flatten it in this case. If run_fn has no tensor outputs,
    # tpu.replicate returns a list of no_ops, we will keep the output as it
    # is.
    if isinstance(replicate_outputs[0], list):
      replicate_outputs = nest.flatten(replicate_outputs)

    return replicate_outputs

  # TODO(sourabhbajaj): The input to while loop should be based on the
  # output type of the step_fn
  assert isinstance(initial_loop_values, list)
  # One copy of the flattened initial values per replica, presumably to line
  # up with rewrite_fn's flattened per-replica outputs (see TODO above).
  initial_loop_values = initial_loop_values * self._num_replicas_in_sync

  # Put the while loop op on TPU host 0.
  with ops.device(self._host_device):
    # NOTE(review): the single-step shortcut checks self.steps_per_run, not
    # `iterations`; if a caller passes iterations != steps_per_run the two
    # disagree. Confirm this is intended.
    if self.steps_per_run == 1:
      replicate_outputs = rewrite_fn()
    else:
      replicate_outputs = training_loop.repeat(
          iterations, rewrite_fn, initial_loop_values)

  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs)

  if isinstance(replicate_outputs, list):
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (flattened)
    # [output0_device0, output1_device0, output2_device0,
    #  output0_device1, output1_device1, output2_device1,
    #  ...]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    output_num = len(
        last_step_tensor_outputs) // self._num_replicas_in_sync
    last_step_tensor_outputs = [
        last_step_tensor_outputs[i::output_num] for i in range(output_num)
    ]
  else:
    # no tensors returned.
    last_step_tensor_outputs = []

  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx
def eval_loop():
  """Builds the eval loop, or a no-op when there are no eval steps.

  Returns:
    `tf.no_op()` unless `self.eval_steps > 0` holds; otherwise the value
    returned by `training_loop.repeat` running `tpu_eval_step`
    `self.eval_steps` times with no carried state.
  """
  # Written as `not (... > 0)` rather than `<= 0` so values that compare
  # False both ways (e.g. NaN) still take the no-op path, as before.
  if not self.eval_steps > 0:
    return tf.no_op()
  return training_loop.repeat(self.eval_steps, tpu_eval_step, [])