def _train_on_tpu_system(model_fn_wrapper, dequeue_fn):
  """Runs the wrapped model_fn for a whole training loop on every TPU shard.

  Args:
    model_fn_wrapper: object exposing `config.tpu_config` (read for
      `iterations_per_loop` and `num_shards`) and
      `convert_to_single_tpu_train_step`.
    dequeue_fn: callable handed to `convert_to_single_tpu_train_step` to
      build the single-step training function.

  Returns:
    The loss produced by `tpu.shard`. `outputs_from_all_shards=False`, so it
    is taken from a single shard rather than concatenated across shards.
  """
  tpu_config = model_fn_wrapper.config.tpu_config
  steps_per_loop = tpu_config.iterations_per_loop
  shard_count = tpu_config.num_shards

  train_step = model_fn_wrapper.convert_to_single_tpu_train_step(dequeue_fn)

  def _loop_on_one_shard():
    # Repeat the single-step function `steps_per_loop` times; the loss is the
    # only loop-carried value and starts at `_INITIAL_LOSS`.
    return training_loop.repeat(
        steps_per_loop, train_step, [_INITIAL_LOSS], name='loop')

  shard_outputs = tpu.shard(
      _loop_on_one_shard,
      inputs=[],
      num_shards=shard_count,
      outputs_from_all_shards=False)
  (loss,) = shard_outputs
  return loss
def iterate_on_tpu():
  """Builds a loop that runs `run_fn` for `iterations` steps.

  `iterations`, `run_fn` and `initial_loop_values` are free variables taken
  from the enclosing scope.
  """
  return training_loop.repeat(iterations, run_fn, inputs=initial_loop_values)
def loop():
  """Builds a loop that runs `training_step` five times, fed from `infeed`."""
  num_steps = 5
  return training_loop.repeat(num_steps, training_step, infeed_queue=infeed)
def _experimental_run_steps_on_iterator(
    self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` `iterations` times on every TPU replica.

  Inputs come from `multi_worker_iterator.get_next()`, which is called inside
  the loop body and split per replica with `values.select_replica`. The loop
  itself is a `training_loop.repeat` placed on host 0's CPU device, and its
  body calls `tpu.replicate` on the single-step function.

  NOTE(review): this method builds graph ops. The statement order below fixes
  op creation order and the control-flow context captured in
  `self._outer_control_flow_context`, so do not reorder it casually.

  Args:
    fn: step function, called as `fn(ctx, inputs)` where `ctx` is the
      `MultiStepContext` returned by this method.
    multi_worker_iterator: iterator providing `output_shapes` and
      `get_next()`. Every output shape must be fully defined.
    iterations: number of loop iterations passed to `training_loop.repeat`.
    initial_loop_values: optional (possibly nested) initial loop-carried
      values. `None` is treated as an empty dict. The values are flattened and
      repeated once per replica.

  Returns:
    The `input_lib.MultiStepContext`. Its `run_op` is a group of the loop
    outputs, and its last-step outputs are set via `_set_last_step_outputs`.

  Raises:
    ValueError: if any shape in `multi_worker_iterator.output_shapes` is not
      fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(inputs):
    """Single step on the TPU device."""
    fn_result = fn(ctx, inputs)
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Make the returned identities depend on `fn_result` so the step's
      # work runs before the last-step outputs are read.
      with ops.control_dependencies([fn_result]):
        return [array_ops.identity(f) for f in flat_last_step_outputs]
    else:
      return fn_result

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  def rewrite_fn(*args):
    """The rewritten step fn running on TPU."""
    # The loop-carried values are not consumed by the step.
    del args

    per_replica_inputs = multi_worker_iterator.get_next()
    replicate_inputs = []
    for replica_id in range(self._num_replicas_in_sync):
      # The lambda is consumed by map_structure within this same iteration,
      # so late binding of `replica_id` is harmless here.
      select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
      replicate_inputs.append((nest.map_structure(
          select_replica, per_replica_inputs),))

    replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

    # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
    # will flatten it in this case. If run_fn has no tensor outputs,
    # tpu.replicate returns a list of no_ops, we will keep the output as it
    # is.
    if isinstance(replicate_outputs[0], list):
      replicate_outputs = nest.flatten(replicate_outputs)

    return replicate_outputs

  # TODO(sourabhbajaj): The input to while loop should be based on the
  # output type of the step_fn
  assert isinstance(initial_loop_values, list)
  # One copy of the initial values per replica, to line up with the
  # flattened per-replica outputs returned by rewrite_fn.
  initial_loop_values = initial_loop_values * self._num_replicas_in_sync

  # Put the while loop op on host 0.
  with ops.device(self.get_host_cpu_device(0)):
    replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                             initial_loop_values)

  # The outer context is only needed while the loop body is being built.
  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs)

  if isinstance(replicate_outputs, list):
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (flattened)
    # [output0_device0, output1_device0, output2_device0,
    #  output0_device1, output1_device1, output2_device1,
    #  ...]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
    last_step_tensor_outputs = [
        last_step_tensor_outputs[i::output_num] for i in range(output_num)
    ]
  else:
    # no tensors returned.
    last_step_tensor_outputs = []

  _set_last_step_outputs(ctx, last_step_tensor_outputs)
  return ctx
def _experimental_run_steps_on_iterator(self, fn, multi_worker_iterator,
                                        iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` `iterations` times on every TPU replica.

  Input is delivered through TPU infeed. One enqueue op per host is built with
  `self._get_enqueue_op_per_host`, and each step dequeues a tuple with
  `tpu_ops.infeed_dequeue_tuple`. Where the loop runs depends on the container
  strategy's `_disable_training_loop_on_host` flag:

  * True: `training_loop.repeat` is called inside `tpu.replicate`, so the loop
    is part of the replicated TPU computation.
  * False: `training_loop.repeat` is placed on host 0's CPU device, and its
    body calls `tpu.replicate` on the single-step function.

  NOTE(review): this method builds graph ops. The statement order below fixes
  op creation order and the control-flow context captured in
  `self._outer_control_flow_context`, so do not reorder it casually.

  Args:
    fn: step function, called as `fn(ctx, inputs)` where `ctx` is the
      `MultiStepContext` returned by this method and `inputs` is the dequeued
      infeed tuple repacked into the iterator's output structure.
    multi_worker_iterator: iterator providing `output_shapes` and
      `output_types`. Every output shape must be fully defined.
    iterations: number of loop iterations. It is also passed to
      `_get_enqueue_op_per_host`.
    initial_loop_values: optional (possibly nested) initial loop-carried
      values. `None` is treated as an empty dict.

  Returns:
    The `input_lib.MultiStepContext`. Its `run_op` groups the loop outputs
    with the enqueue ops. Its last-step outputs are regrouped per output (one
    entry per replica) and packed back into the structure of
    `ctx.last_step_outputs`. Outputs with a non-None reduce op are collapsed
    to their first entry.

  Raises:
    ValueError: if any shape in `multi_worker_iterator.output_shapes` is not
      fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  # One infeed enqueue op per host. They are grouped into ctx.run_op below.
  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)
  ]

  def dequeue_fn():
    # Dequeue one flat tuple from infeed and rebuild the iterator's structure.
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    # Loop-carried values are ignored; input comes from infeed instead.
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Make the returned identities depend on `fn_result` so the step's
      # work runs before the last-step outputs are read.
      with ops.control_dependencies([fn_result]):
        return [
            array_ops.identity(f) for f in flat_last_step_outputs
        ]
    else:
      return fn_result

  def iterate_on_tpu():
    # Only called in the `_disable_training_loop_on_host` branch, before
    # `initial_loop_values` is rebound in the other branch.
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # pylint: disable=protected-access
  if self._container_strategy()._disable_training_loop_on_host:
    # The loop runs inside the replicated computation. Each replica takes no
    # explicit inputs.
    replicate_inputs = [[]] * self._num_replicas_in_sync
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  else:
    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args
      replicate_inputs = [[]] * self._num_replicas_in_sync
      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    # One copy of the initial values per replica, to line up with the
    # flattened per-replica outputs returned by rewrite_fn.
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on host 0.
    with ops.device(self.get_host_cpu_device(0)):
      replicate_outputs = training_loop.repeat(
          iterations, rewrite_fn, initial_loop_values)

  # The outer context is only needed while the loop body is being built.
  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  if self._container_strategy()._disable_training_loop_on_host:
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [
        x for x in replicate_outputs if not isinstance(x, ops.Operation)
    ]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [
        list(x) for x in zip(*last_step_tensor_outputs)
    ]
  else:
    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(
          last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def train_shard():
  """Builds the per-shard loop: `iterations_per_loop` runs of `train_step`."""
  steps_per_loop = run_config.tpu_config.iterations_per_loop
  initial_loss = [1e7]  # Sole loop-carried value: the starting loss.
  return training_loop.repeat(
      steps_per_loop, train_step, inputs=initial_loss, name='loop')
def _experimental_run_steps_on_iterator(
    self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
  """Builds a graph that runs `fn` `iterations` times on every TPU replica.

  Input is delivered through TPU infeed. One enqueue op per host is built with
  `self._get_enqueue_op_per_host`, and each step dequeues a tuple with
  `tpu_ops.infeed_dequeue_tuple`. Where the loop runs depends on the container
  strategy's `_disable_training_loop_on_host` flag:

  * True: `training_loop.repeat` is called inside `tpu.replicate`, so the loop
    is part of the replicated TPU computation.
  * False: `training_loop.repeat` is placed on host 0's CPU device, and its
    body calls `tpu.replicate` on the single-step function.

  NOTE(review): this method builds graph ops. The statement order below fixes
  op creation order and the control-flow context captured in
  `self._outer_control_flow_context`, so do not reorder it casually.

  Args:
    fn: step function, called as `fn(ctx, inputs)` where `ctx` is the
      `MultiStepContext` returned by this method and `inputs` is the dequeued
      infeed tuple repacked into the iterator's output structure.
    multi_worker_iterator: iterator providing `output_shapes` and
      `output_types`. Every output shape must be fully defined.
    iterations: number of loop iterations. It is also passed to
      `_get_enqueue_op_per_host`.
    initial_loop_values: optional (possibly nested) initial loop-carried
      values. `None` is treated as an empty dict.

  Returns:
    The `input_lib.MultiStepContext`. Its `run_op` groups the loop outputs
    with the enqueue ops. Its last-step outputs are regrouped per output (one
    entry per replica) and packed back into the structure of
    `ctx.last_step_outputs`. Outputs with a non-None reduce op are collapsed
    to their first entry.

  Raises:
    ValueError: if any shape in `multi_worker_iterator.output_shapes` is not
      fully defined.
  """
  output_shapes = multi_worker_iterator.output_shapes
  shapes = nest.flatten(output_shapes)
  if any(not s.is_fully_defined() for s in shapes):
    raise ValueError(
        "TPU currently requires fully defined shapes. Either use "
        "set_shape() on the input tensors or use "
        "dataset.batch(..., drop_remainder=True).")
  types = nest.flatten(multi_worker_iterator.output_types)

  # One infeed enqueue op per host. They are grouped into ctx.run_op below.
  enqueue_ops = [
      self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
                                    iterations)
      for host_id in range(self.num_hosts)]

  def dequeue_fn():
    # Dequeue one flat tuple from infeed and rebuild the iterator's structure.
    dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
    return nest.pack_sequence_as(output_shapes, dequeued)

  # Wrap `fn` for repeat.
  if initial_loop_values is None:
    initial_loop_values = {}
  initial_loop_values = nest.flatten(initial_loop_values)
  ctx = input_lib.MultiStepContext()

  def run_fn(*args, **kwargs):
    """Single step on the TPU device."""
    # Loop-carried values are ignored; input comes from infeed instead.
    del args, kwargs
    fn_result = fn(ctx, dequeue_fn())
    flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
    if flat_last_step_outputs:
      # Make the returned identities depend on `fn_result` so the step's
      # work runs before the last-step outputs are read.
      with ops.control_dependencies([fn_result]):
        return [array_ops.identity(f) for f in flat_last_step_outputs]
    else:
      return fn_result

  def iterate_on_tpu():
    # Only called in the `_disable_training_loop_on_host` branch, before
    # `initial_loop_values` is rebound in the other branch.
    return training_loop.repeat(iterations, run_fn, initial_loop_values)

  # We capture the control_flow_context at this point, before we run `fn`
  # inside a while_loop and TPU replicate context. This is useful in cases
  # where we might need to exit these contexts and get back to the outer
  # context to do some things, for e.g. create an op which should be
  # evaluated only once at the end of the loop on the host. One such usage
  # is in creating metrics' value op.
  self._outer_control_flow_context = (
      ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

  # pylint: disable=protected-access
  if self._container_strategy()._disable_training_loop_on_host:
    # The loop runs inside the replicated computation. Each replica takes no
    # explicit inputs.
    replicate_inputs = [[]] * self._num_replicas_in_sync
    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
  else:
    def rewrite_fn(*args):
      """The rewritten step fn running on TPU."""
      del args
      replicate_inputs = [[]] * self._num_replicas_in_sync
      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)

      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
      # will flatten it in this case. If run_fn has no tensor outputs,
      # tpu.replicate returns a list of no_ops, we will keep the output as it
      # is.
      if isinstance(replicate_outputs[0], list):
        replicate_outputs = nest.flatten(replicate_outputs)

      return replicate_outputs

    # TODO(sourabhbajaj): The input to while loop should be based on the
    # output type of the step_fn
    assert isinstance(initial_loop_values, list)
    # One copy of the initial values per replica, to line up with the
    # flattened per-replica outputs returned by rewrite_fn.
    initial_loop_values = initial_loop_values * self._num_replicas_in_sync

    # Put the while loop op on host 0.
    with ops.device(self.get_host_cpu_device(0)):
      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
                                               initial_loop_values)

  # The outer context is only needed while the loop body is being built.
  del self._outer_control_flow_context
  ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

  if self._container_strategy()._disable_training_loop_on_host:
    # Filter out any ops from the outputs, typically this would be the case
    # when there were no tensor outputs.
    last_step_tensor_outputs = [x for x in replicate_outputs
                                if not isinstance(x, ops.Operation)]

    # Outputs are currently of the structure (grouped by device)
    # [[output0_device0, output1_device0, output2_device0],
    #  [output0_device1, output1_device1, output2_device1]]
    # Convert this to the following structure instead: (grouped by output)
    # [[output0_device0, output0_device1],
    #  [output1_device0, output1_device1],
    #  [output2_device0, output2_device1]]
    last_step_tensor_outputs = [list(x) for x in zip(*last_step_tensor_outputs)]
  else:
    if isinstance(replicate_outputs, list):
      # Filter out any ops from the outputs, typically this would be the case
      # when there were no tensor outputs.
      last_step_tensor_outputs = [
          x for x in replicate_outputs if not isinstance(x, ops.Operation)
      ]

      # Outputs are currently of the structure (flattened)
      # [output0_device0, output1_device0, output2_device0,
      #  output0_device1, output1_device1, output2_device1,
      #  ...]
      # Convert this to the following structure instead: (grouped by output)
      # [[output0_device0, output0_device1],
      #  [output1_device0, output1_device1],
      #  [output2_device0, output2_device1]]
      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
      last_step_tensor_outputs = [
          last_step_tensor_outputs[i::output_num] for i in range(output_num)
      ]
    else:
      # no tensors returned.
      last_step_tensor_outputs = []

  # Convert replicate_outputs to the original dict structure of
  # last_step_outputs.
  last_step_tensor_outputs_dict = nest.pack_sequence_as(
      ctx.last_step_outputs, last_step_tensor_outputs)

  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
    output = last_step_tensor_outputs_dict[name]
    # For outputs that have already been reduced, take the first value
    # from the list as each value should be the same. Else return the full
    # list of values.
    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
    # value.
    if reduce_op is not None:
      # TODO(priyag): Should this return the element or a list with 1 element
      last_step_tensor_outputs_dict[name] = output[0]
  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

  return ctx
def train_shard():
  """Builds the per-shard loop: `iterations_per_loop` runs of `train_step`."""
  steps_per_loop = run_config.tpu_config.iterations_per_loop
  initial_loss = [1e7]  # Sole loop-carried value: the starting loss.
  return training_loop.repeat(
      steps_per_loop, train_step, inputs=initial_loss, name='loop')