def _make_indexed_slices_indices_types_match(true_graph, false_graph):
  """Match dtype of IndexedSlices.indices in outputs of {true|false}_graphs."""
  indexed_slice_indices = []
  current_index = 0
  true_outputs_flat_with_composites = nest.flatten(
      true_graph.structured_outputs, expand_composites=False)
  false_outputs_flat_with_composites = nest.flatten(
      false_graph.structured_outputs, expand_composites=False)
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for idx, (true_out, false_out) in enumerate(
      zip(true_outputs_flat_with_composites,
          false_outputs_flat_with_composites)):
    if isinstance(true_out, ops.IndexedSlices) != isinstance(
        false_out, ops.IndexedSlices):
      raise TypeError("Cannot reconcile tf.cond %i-th outputs:\n"
                      "  true_fn returned:  %s\n"
                      "  false_fn returned: %s" % (idx, true_out, false_out))
    if isinstance(true_out, ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(true_out):
      current_index += len(nest.flatten(true_out, expand_composites=True))
    else:
      current_index += 1

  if not indexed_slice_indices:
    return

  if current_index != len(true_graph.outputs):
    raise ValueError("Insufficient elements in true_graph.outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" % (current_index, len(true_graph.outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if true_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(true_graph.outputs[index].dtype))
    if false_graph.outputs[index].dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Type of IndexedSlices.indices must be int32 or int64. "
                      "Found: %s" % str(false_graph.outputs[index].dtype))
    if true_graph.outputs[index].dtype != false_graph.outputs[index].dtype:
      if false_graph.outputs[index].dtype == dtypes.int32:
        with false_graph.as_default():
          false_graph.outputs[index] = math_ops.cast(
              false_graph.outputs[index], dtypes.int64)
      else:
        with true_graph.as_default():
          true_graph.outputs[index] = math_ops.cast(
              true_graph.outputs[index], dtypes.int64)

  true_graph.structured_outputs = func_graph_module.pack_sequence_as(
      true_graph.structured_outputs, true_graph.outputs)
  false_graph.structured_outputs = func_graph_module.pack_sequence_as(
      false_graph.structured_outputs, false_graph.outputs)

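# Worked example (not from the original source): a plain-Python sketch of the
# flattened-index bookkeeping above. When expanded, an IndexedSlices yields
# (values, indices, dense_shape), so within each composite the `indices`
# component sits at offset 1; the tuple below is a hypothetical stand-in.
def _indexed_slices_bookkeeping_sketch():
  flat = ["tensor_a",                      # plain tensor, occupies 1 slot
          ("values", "indices", "shape"),  # stand-in for an IndexedSlices
          "tensor_b"]                      # another plain tensor
  positions = []
  current_index = 0
  for out in flat:
    if isinstance(out, tuple):  # mirrors isinstance(out, ops.IndexedSlices)
      positions.append(current_index + 1)  # the `indices` slot
      current_index += len(out)            # composite expands to 3 outputs
    else:
      current_index += 1
  assert positions == [2] and current_index == 5
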
def _make_indexed_slices_indices_types_match(op_type, branch_graphs):
  """Match dtype of IndexedSlices.indices in outputs of branch_graphs."""
  assert branch_graphs
  indexed_slice_indices = []
  current_index = 0
  branch_outputs_flat_with_composites = [
      nest.flatten(branch_graph.structured_outputs, expand_composites=False)
      for branch_graph in branch_graphs
  ]
  outs_per_branch = [len(outs) for outs in branch_outputs_flat_with_composites]
  assert len(set(outs_per_branch)) == 1, outs_per_branch
  # Store indices of IndexedSlices.indices in `indexed_slice_indices`.
  for output_idx, branch_outs in enumerate(
      zip(*branch_outputs_flat_with_composites)):
    if len(set(isinstance(out, ops.IndexedSlices)
               for out in branch_outs)) != 1:
      raise TypeError(
          "Cannot reconcile tf.{op_name} {output_idx}-th outputs:\n"
          "  branches returned: {outputs}".format(
              op_name="cond" if op_type == _COND else "switch_case",
              output_idx=output_idx,
              outputs=branch_outs))
    if isinstance(branch_outs[0], ops.IndexedSlices):
      # indices is the second component of the composite tensor.
      indexed_slice_indices.append(current_index + 1)
    if nest.is_sequence_or_composite(branch_outs[0]):
      current_index += len(
          nest.flatten(branch_outs[0], expand_composites=True))
    else:
      current_index += 1

  if not indexed_slice_indices:
    return

  if current_index != len(branch_graphs[0].outputs):
    raise ValueError("Insufficient elements in branch_graphs[0].outputs.\n"
                     "Expected: %i\n"
                     "Actual: %i" %
                     (current_index, len(branch_graphs[0].outputs)))

  # Cast indices with mismatching types to int64.
  for index in indexed_slice_indices:
    if any(bg.outputs[index].dtype not in (dtypes.int32, dtypes.int64)
           for bg in branch_graphs):
      raise TypeError(
          "Type of IndexedSlices.indices must be int32 or int64. "
          "Found: %s" %
          str([bg.outputs[index].dtype for bg in branch_graphs]))
    if len(set(bg.outputs[index].dtype for bg in branch_graphs)) != 1:
      for branch_graph in branch_graphs:
        if branch_graph.outputs[index].dtype == dtypes.int32:
          with branch_graph.as_default():
            branch_graph.outputs[index] = math_ops.cast(
                branch_graph.outputs[index], dtypes.int64)

  for branch_graph in branch_graphs:
    branch_graph.structured_outputs = func_graph_module.pack_sequence_as(
        branch_graph.structured_outputs, branch_graph.outputs)

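# Worked example (illustrative, plain Python): the promotion rule applied per
# output position above. When branch `indices` dtypes disagree, every int32
# output is cast up to int64 so all branches agree.
def _indices_promotion_sketch():
  branch_dtypes = ["int32", "int64", "int64"]  # hypothetical 3-branch case
  assert all(d in ("int32", "int64") for d in branch_dtypes)
  if len(set(branch_dtypes)) != 1:
    branch_dtypes = ["int64" if d == "int32" else d for d in branch_dtypes]
  assert branch_dtypes == ["int64", "int64", "int64"]
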
def _pack_sequence_as(structured_outputs, op_outputs):
  """Packs the outputs of the gradient If/Case op.

  The branch functions may contain None's in the list of `structured_outputs`.
  `op_outputs` has those outputs missing. So we need to add those Nones to the
  list of `op_outputs` and then pack it in the same structure as
  `structured_outputs`.

  Args:
    structured_outputs: structured_outputs from one of the branch functions.
    op_outputs: List of output tensors of the op.

  Returns:
    `op_outputs` packed like `structured_outputs`.
  """
  outputs_with_nones = []
  counter = 0
  for output in nest.flatten(structured_outputs, expand_composites=True):
    if output is None:
      outputs_with_nones.append(None)
    else:
      outputs_with_nones.append(op_outputs[counter])
      counter += 1
  return func_graph_module.pack_sequence_as(structured_outputs,
                                            outputs_with_nones)

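# Worked example (not part of the original module): the None re-insertion
# above, with plain lists standing in for flattened structures and tensors.
def _pack_with_nones_sketch():
  structured_flat = ["grad_x", None, "grad_y"]  # one branch output is None
  op_outputs = ["t0", "t1"]                     # the op skips the None slot
  outputs_with_nones = []
  counter = 0
  for output in structured_flat:
    if output is None:
      outputs_with_nones.append(None)
    else:
      outputs_with_nones.append(op_outputs[counter])
      counter += 1
  assert outputs_with_nones == ["t0", None, "t1"]
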
def _build_case(branch_index, branch_graphs, branch_inputs, name=None):
  """Creates a `Case` op from `branch_index`, branch graphs and inputs.

  Note that this modifies `branch_graphs` to make the inputs match, and to
  output all intermediate values so they're available for the gradient
  computation.

  `branch_graphs` need not have the same input types, but they must have the
  same output types.

  Args:
    branch_index: integer Tensor
    branch_graphs: List of FuncGraph
    branch_inputs: List of lists of Tensors to be passed to corresponding
      branch_graph as input.
    name: the name for the Case op.

  Returns:
    A list of Tensors which are the outputs of the Case op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_CASE, branch_graphs)
  _check_same_outputs(_CASE, branch_graphs)

  # Add inputs to branch_graphs to make them match. Note that this modifies
  # the graphs in `branch_graphs`.
  case_inputs = _make_inputs_match(branch_graphs, branch_inputs)

  # Create the Case op.
  with ops.control_dependencies(
      sum((list(bg.control_captures) for bg in branch_graphs), [])):
    tensors = gen_functional_ops.case(
        branch_index,
        case_inputs, [t.dtype for t in branch_graphs[0].outputs],
        [util.create_new_tf_function(g) for g in branch_graphs],
        output_shapes=_get_output_shapes(*[g.outputs for g in branch_graphs]),
        name=name)

  # TODO(b/110167197): this requires Case to have at least 1 output
  case_op = tensors[0].op
  util.maybe_set_lowering_attr(case_op)
  util.maybe_propagate_compile_time_consts_in_xla(case_op)

  # Return identities for each output of the Case op, rather than the output
  # of the Case op directly. This makes pruning work if the output of
  # switch_case() is fetched: the lowering pass converts the Case outputs into
  # IdentityN outputs, which if fetched will cause all ops in the taken branch
  # to be run (since it takes all merge ops as input). After lowering, each
  # output identity op will end up with only the appropriate merge op as
  # input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  case_op.graph.prevent_fetching(case_op)
  return func_graph_module.pack_sequence_as(
      branch_graphs[0].structured_outputs, tensors)

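# `_get_output_shapes` is called above but not defined in this section. A
# plausible sketch, consistent with the call sites (one shape per output
# position, relaxed to the most specific shape compatible with every branch):
def _get_output_shapes_sketch(*branch_graph_outputs):
  output_shapes = []
  for out_by_branch in zip(*branch_graph_outputs):
    shape = out_by_branch[0].shape
    for other_out in out_by_branch[1:]:
      # TensorShape.most_specific_compatible_shape relaxes mismatched dims
      # to None so the shape is valid for whichever branch is taken.
      shape = shape.most_specific_compatible_shape(other_out.shape)
    output_shapes.append(shape)
  return output_shapes
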
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include
    added intermediate outputs.
  """
  _check_same_outputs(true_graph, false_graph)

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match(true_graph, false_graph, true_inputs,
                                   false_inputs)

  # Create the If op.
  tensors = gen_functional_ops._if(  # pylint: disable=protected-access
      pred,
      cond_inputs, [t.dtype for t in true_graph.outputs],
      util.create_new_tf_function(true_graph),
      util.create_new_tf_function(false_graph),
      output_shapes=_get_output_shapes(true_graph.outputs,
                                       false_graph.outputs),
      name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)

def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = ops.get_default_graph()._add_control_dependencies
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name,
            collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name,
            collections=ops.get_default_graph()._collections),  # pylint: disable=protected-access
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    outputs = _build_cond(pred, true_graph, false_graph,
                          true_graph.external_captures,
                          false_graph.external_captures,
                          name=scope)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              outputs)

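# Hedged usage sketch for cond_v2 (not from the source; assumes a v1-style
# graph context and that `constant_op` is imported alongside `math_ops`):
#
#   x = constant_op.constant(2.0)
#   y = constant_op.constant(5.0)
#   out = cond_v2(math_ops.less(x, y),
#                 lambda: math_ops.multiply(x, 3.0),
#                 lambda: math_ops.add(y, 4.0))
#
# Both branches are traced into FuncGraphs and a single If op is emitted,
# rather than the Switch/Merge subgraph built by the v1 tf.cond.
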
def func_wrapper(*args):
  args = _convert_to_list(args)
  with ops.name_scope(name) as scope:
    func_graph, captured_args = _compile_function(
        func, args, scope, [], allow_external_captures=True)

    with ops.control_dependencies(list(func_graph.control_captures)):
      outputs = gen_functional_ops.function(
          captured_args,
          to_apply=util.create_new_tf_function(func_graph),
          Tout=func_graph.output_types,
          output_shapes=func_graph.output_shapes)

    # pack_sequence_as requires a list of Tensors, but the gen_ operation
    # returns an Operation under some circumstances (probably when that
    # list would be empty).
    if isinstance(outputs, ops.Operation):
      outputs = outputs.outputs

    return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                              outputs)

def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    outputs = _build_cond(pred, true_graph, false_graph,
                          true_graph.external_captures,
                          false_graph.external_captures,
                          name=scope)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              outputs)

def multi_conv_wrapper(*args):
  inner_options = options if options else {}

  if not isinstance(inner_options, dict):
    raise TypeError(
        "Expected the multi_conv `options` to be a `dict`, but got %s "
        "instead." % (str(inner_options)))

  option_proto = option_flag_pb2.PoplarOptionFlags()
  for key, value in inner_options.items():
    flag = option_proto.flags.add()
    flag.option = key
    flag.value = value

  def func_wrapper(*args):
    with op_util.gradient_override_scope(training=False):
      return inner_func(*args)

  args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
  with ops.name_scope("multi_conv") as scope:
    func_graph, captured_args = functional_ops._compile_function(  # pylint: disable=protected-access
        func_wrapper, args, scope, [], allow_external_captures=True)

    with ops.control_dependencies(list(func_graph.control_captures)):
      outputs = gen_functional_ops.multi_conv(
          captured_args,
          to_apply=util.create_new_tf_function(func_graph),
          Tout=func_graph.output_types,
          output_shapes=func_graph.output_shapes,
          option_flags=json_format.MessageToJson(option_proto))

    return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                              outputs)

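# Worked example (illustrative, plain Python): how the `options` dict above
# is flattened into option flags before JSON serialization. The key name is
# hypothetical; the field names (`flags`, `option`, `value`) mirror the
# PoplarOptionFlags usage in multi_conv_wrapper.
def _option_flags_sketch():
  options = {"availableMemoryProportion": "0.4"}
  flags = [{"option": k, "value": v} for k, v in options.items()]
  assert flags == [{"option": "availableMemoryProportion", "value": "0.4"}]
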
def _pipeline_stage(func,
                    stage_id,
                    device_id,
                    args,
                    training,
                    infeed_queue=None,
                    outfeed_queue=None,
                    name=None):
  """Internal function for compiling a pipeline stage.

  This should not be called directly and doing so will result in undefined
  behaviour.

  Creates a pipeline stage.

  Args:
    func: function which will be executed as a stage.
    stage_id: Stage number.
    device_id: IPU the stage will be mapped to.
    args: arguments to the function.
    training: whether the stage is compiled for training (controls the
      gradient override scope).
    infeed_queue: optional IPUInfeedQueue, if passed, it is dequeued as part
      of this function.
    outfeed_queue: optional IPUOutfeedQueue, if passed, it is enqueued as part
      of this function.
    name: name of this pipeline stage.

  Returns:
    The values after executing func(args), or the control dependency if
    outfeed_queue is not None.
  """
  name = name if name else "pipeline_stage"
  args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access

  func_to_compile = func
  control_outputs = []

  # If we have an infeed, then we wrap the function in another function which
  # dequeues the infeed.
  if infeed_queue:

    def infeed_func_wrapper(*args):
      args = functional_ops._convert_to_list(args)  # pylint: disable=protected-access
      dequeue_ops = functional_ops._convert_to_list(infeed_queue._dequeue())  # pylint: disable=protected-access
      # Deal with the dequeue depending on whether it's a list or dict.
      if len(dequeue_ops) == 1 and isinstance(dequeue_ops[0], dict):
        kwargs = dequeue_ops[0]
        return func(*(args), **kwargs)
      return func(*(args + dequeue_ops))

    func_to_compile = infeed_func_wrapper

  # If we have an outfeed, then we wrap the function in another function which
  # enqueues the outfeed.
  if outfeed_queue:
    func = func_to_compile

    def outfeed_func_wrapper(*args, **kwargs):
      outputs = func(*args, **kwargs)
      # Check if there are output tensors - if there are then enqueue them.
      if not isinstance(outputs, ops.Operation):
        if not isinstance(outputs, dict):
          outputs = functional_ops._convert_to_list(outputs)  # pylint: disable=protected-access
        outputs = outfeed_queue.enqueue(outputs)
        control_outputs.append(outputs)

    func_to_compile = outfeed_func_wrapper

  def gradient_override_wrapper(*args, **kwargs):
    with op_util.gradient_override_scope(training):
      return func_to_compile(*args, **kwargs)

  with ops.name_scope(name) as scope:
    # pylint: disable=protected-access
    try:
      func_graph, captured_args = functional_ops._compile_function(
          gradient_override_wrapper, args, scope, control_outputs)
    except functional_ops._InvalidCaptureException as e:
      raise ValueError(
          "Trying to capture the tensor %s which is not a resource. This "
          "tensor needs to be passed as either part of the `input` or "
          "`infeed_queue` of the pipeline." % (str(e)))
    # pylint: enable=protected-access

    # Create the pipeline stage and lower the function into XLA.
    with ops.control_dependencies(list(func_graph.control_captures)):
      with scopes.ipu_shard(device_id):
        outputs = gen_functional_ops.pipeline_stage(
            captured_args,
            to_apply=util.create_new_tf_function(func_graph),
            Tout=func_graph.output_types,
            output_shapes=func_graph.output_shapes,
            stage_id=stage_id)
    if isinstance(outputs, ops.Operation):
      return outputs
    return func_graph_module.pack_sequence_as(func_graph.structured_outputs,
                                              outputs)

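# Worked example (plain Python, not from the source): the dequeue dispatch
# rule used by infeed_func_wrapper above. A single-element list holding a
# dict becomes keyword arguments; anything else is appended positionally.
def _dequeue_dispatch_sketch(func, args, dequeue_ops):
  if len(dequeue_ops) == 1 and isinstance(dequeue_ops[0], dict):
    return func(*args, **dequeue_ops[0])
  return func(*(args + dequeue_ops))

assert _dequeue_dispatch_sketch(lambda a, b: a + b, [1], [2]) == 3
assert _dequeue_dispatch_sketch(lambda a, b=0: a + b, [1], [{"b": 2}]) == 3
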
def cond_v2(pred, true_fn, false_fn, name="cond"):
  """Like tf.cond, except emits a single If op."""
  if isinstance(pred, bool):
    raise TypeError("pred must not be a Python bool", pred)

  if not name:
    name = "cond"

  with ops.name_scope(name) as scope:
    true_name = util.unique_fn_name(scope, "true")
    false_name = util.unique_fn_name(scope, "false")

    # Automatic control dependencies are added in defuns, but not in v1
    # graphs. Propagate that behavior here.
    add_control_dependencies = util.in_defun()
    pred = ops.convert_to_tensor(pred)

    true_graph = func_graph_module.func_graph_from_py_func(
        true_name,
        true_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            true_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)
    false_graph = func_graph_module.func_graph_from_py_func(
        false_name,
        false_fn, [], {},
        func_graph=util.CondBranchFuncGraph(
            false_name, read_only_collections=False),
        add_control_dependencies=add_control_dependencies,
        op_return_value=pred)

    _check_same_outputs(true_graph, false_graph)

    # Add inputs to true_graph and false_graph to make them match. Note that
    # this modifies true_graph and false_graph.
    cond_inputs = _make_inputs_match(true_graph, false_graph,
                                     true_graph.external_captures,
                                     false_graph.external_captures)

    # Add all intermediate tensors as function outputs so they're available
    # for the gradient computation.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Save the original number of outputs to return to the caller.
    num_cond_outputs = len(true_graph.outputs)

    # Make the number/type of new intermediate outputs match.
    extra_true_outputs, extra_false_outputs = _pad_params(
        true_graph, false_graph, true_intermediates, false_intermediates)

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)

    # Create the If op.
    tensors = gen_functional_ops._if(  # pylint: disable=protected-access
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=scope)

    # TODO(b/110167197) this approach requires cond_v2 to have at least 1
    # output
    util.maybe_set_lowering_attr(tensors[0].op)

    # Return identities for each output of the If op, rather than the output
    # of the If op directly. This makes pruning work if the output of cond()
    # is fetched: the lowering pass converts the If outputs into IdentityN
    # outputs, which if fetched will cause all ops in the taken branch to be
    # run (since it takes all merge ops as input). After lowering, each
    # output identity op will end up with only the appropriate merge op as
    # input.
    # TODO(b/79984175): this doesn't have to be a tuple once we convert to
    # the correct output structure
    tensors = tuple(array_ops.identity(t) for t in tensors)

    return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                              tensors[:num_cond_outputs])

def _build_cond(pred,
                true_graph,
                false_graph,
                true_inputs,
                false_inputs,
                building_gradient,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    building_gradient: Whether this is a gradient If op.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])

  # Save the original number of outputs to return to the caller.
  num_cond_outputs = len(true_graph.outputs)
  # We do not output intermediates of the gradient If op since this is just
  # for backwards compatibility with existing code.
  if not building_gradient and util.output_all_intermediates():
    # Add all intermediate tensors as function outputs so they're available
    # for the gradient computation. Since the outputs of the two functions
    # must match, we wrap all the intermediates in optionals. Each
    # intermediate output will have a value iff its corresponding branch is
    # taken.
    true_intermediates = _get_intermediates(true_graph)
    false_intermediates = _get_intermediates(false_graph)

    # Wrap intermediates in optionals.
    wrapped_true_intermediates = _wrap_intermediates(true_graph,
                                                     true_intermediates)
    wrapped_false_intermediates = _wrap_intermediates(false_graph,
                                                      false_intermediates)

    # Make outputs match by adding none optionals.
    extra_true_outputs, extra_false_outputs = _make_intermediates_match(  # pylint: disable=unbalanced-tuple-unpacking
        [true_graph, false_graph],
        [wrapped_true_intermediates, wrapped_false_intermediates])

    true_graph.outputs.extend(extra_true_outputs)
    false_graph.outputs.extend(extra_false_outputs)
    _check_same_outputs(_COND, [true_graph, false_graph])

  # Create the If op.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    if (true_stateful_ops or false_stateful_ops):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    tensors = op_fn(
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  if_op._true_graph = true_graph
  if_op._false_graph = false_graph
  util.maybe_set_lowering_attr(if_op)
  util.maybe_propagate_compile_time_consts_in_xla(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors[:num_cond_outputs])

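# `_wrap_intermediates` is called above but not defined in this section. A
# plausible sketch, consistent with the surrounding comments (each
# intermediate is wrapped in a DT_VARIANT optional inside its branch graph;
# assumes `gen_dataset_ops` is imported):
def _wrap_intermediates_sketch(func_graph, intermediates):
  with func_graph.as_default():
    # optional_from_value yields an optional that has a value iff the
    # wrapping branch is the one actually taken at runtime.
    return [gen_dataset_ops.optional_from_value([t]) for t in intermediates]
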
def _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,
                name=None):
  """Creates an If op from the specified predicate, branch functions and inputs.

  Note that this modifies true_graph and false_graph to make the inputs match,
  and to output all intermediate values so they're available for the gradient
  computation.

  true_graph and false_graph need not have the same input types, but they must
  have the same output types.

  Args:
    pred: boolean Tensor
    true_graph: FuncGraph
    false_graph: FuncGraph
    true_inputs: a list of Tensors to be passed to true_graph as input.
    false_inputs: a list of Tensors to be passed to false_graph as input.
    name: the name for the If op.

  Returns:
    A list of Tensors which are the outputs of the If op. Does not include
    added intermediate outputs.
  """
  _make_indexed_slices_indices_types_match(_COND, [true_graph, false_graph])
  _check_same_outputs(_COND, [true_graph, false_graph])

  # Add inputs to true_graph and false_graph to make them match. Note that
  # this modifies true_graph and false_graph.
  cond_inputs = _make_inputs_match([true_graph, false_graph],
                                   [true_inputs, false_inputs])

  # Create the If op.
  with ops.control_dependencies(
      list(true_graph.control_captures) + list(false_graph.control_captures)):
    true_stateful_ops = [
        op for op in true_graph.get_operations() if op._is_stateful
    ]
    false_stateful_ops = [
        op for op in false_graph.get_operations() if op._is_stateful
    ]
    # TODO(srbs): Remove this after July 22, 2019. This is required to abide
    # by the 3-week forward compat window of new TF python op generating code
    # with stale runtime binaries.
    if (true_stateful_ops or false_stateful_ops or
        not compat.forward_compatible(2019, 7, 22)):
      op_fn = gen_functional_ops._if
    else:
      op_fn = gen_functional_ops.stateless_if

    tensors = op_fn(
        pred,
        cond_inputs, [t.dtype for t in true_graph.outputs],
        util.create_new_tf_function(true_graph),
        util.create_new_tf_function(false_graph),
        output_shapes=_get_output_shapes(true_graph.outputs,
                                         false_graph.outputs),
        name=name)

  # TODO(b/110167197) this approach requires cond_v2 to have at least 1 output
  if_op = tensors[0].op
  util.maybe_set_lowering_attr(if_op)
  util.maybe_propagate_compile_time_consts_in_xla(if_op)

  # Return identities for each output of the If op, rather than the output of
  # the If op directly. This makes pruning work if the output of cond() is
  # fetched: the lowering pass converts the If outputs into IdentityN outputs,
  # which if fetched will cause all ops in the taken branch to be run (since
  # it takes all merge ops as input). After lowering, each output identity op
  # will end up with only the appropriate merge op as input.
  # TODO(b/79984175): this doesn't have to be a tuple once we convert to the
  # correct output structure
  tensors = [array_ops.identity(t) for t in tensors]

  # Prevent fetching since the variant outputs can't be fetched directly.
  if_op.graph.prevent_fetching(if_op)
  return func_graph_module.pack_sequence_as(true_graph.structured_outputs,
                                            tensors)

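# Worked example (plain Python, illustrative): the op-selection rule above. A
# stateless If is only emitted when neither branch contains stateful ops and
# the runtime is new enough to know the StatelessIf kernel.
def _select_if_op_sketch(true_stateful_ops, false_stateful_ops,
                         fwd_compatible):
  if true_stateful_ops or false_stateful_ops or not fwd_compatible:
    return "If"
  return "StatelessIf"

assert _select_if_op_sketch([], [], True) == "StatelessIf"
assert _select_if_op_sketch(["AssignVariableOp"], [], True) == "If"
assert _select_if_op_sketch([], [], False) == "If"
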