def call(self, inputs, mask=None):
  """Call the model on new inputs.

  In this case `call` just reapplies all ops in the graph to the new inputs
  (i.e. builds a new computational graph from the provided inputs).

  Arguments:
    inputs: A tensor or list of tensors.
    mask: A mask or list of masks. A mask can be either a tensor or None
      (no mask).

  Returns:
    A tensor if there is a single output, or a list of tensors if there
    is more than one output.
  """
  inputs = nest.flatten(inputs)
  if mask is None:
    masks = [None for _ in range(len(inputs))]
  else:
    masks = nest.flatten(mask)
  if context.in_graph_mode():
    # Try to retrieve cached outputs if the layer has already been called
    # on these exact inputs.
    cache_key = (layers_util.object_list_uid(inputs) + '_' +
                 layers_util.object_list_uid(masks))
    if cache_key in self._output_tensor_cache:
      # Cache hit.
      return self._output_tensor_cache[cache_key]
  # Actually apply the network graph to the new inputs.
  outputs, _ = self._run_internal_graph(inputs, masks)
  return outputs
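
# --- Illustrative sketch (not part of the TensorFlow source): the caching
# above keys on the *identity* of the input and mask objects, so a repeated
# call on the exact same tensors returns the cached outputs. A minimal
# standalone version of the scheme; `toy_object_list_uid` and `ToyNetwork`
# are hypothetical stand-ins for `layers_util.object_list_uid` and the
# network class.


def toy_object_list_uid(object_list):
  # Identity-based key: two calls with the very same objects share a key.
  return ', '.join(str(id(x)) for x in object_list)


class ToyNetwork(object):

  def __init__(self, fn):
    self._fn = fn
    self._output_tensor_cache = {}

  def call(self, inputs, mask=None):
    masks = mask if mask is not None else [None] * len(inputs)
    cache_key = (toy_object_list_uid(inputs) + '_' +
                 toy_object_list_uid(masks))
    if cache_key in self._output_tensor_cache:
      return self._output_tensor_cache[cache_key]  # Cache hit.
    outputs = self._fn(inputs)
    self._output_tensor_cache[cache_key] = outputs
    return outputs


net = ToyNetwork(lambda xs: [v * 2.0 for v in xs])
x1, x2 = 1.0, 2.0
first = net.call([x1, x2])
assert net.call([x1, x2]) is first  # Same input objects: cache hit.
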
def get_updates_for(self, inputs=None):
  # If the wrapper modifies the inputs, use the modified inputs to
  # get the updates from the inner layer.
  inner_inputs = inputs
  if inputs is not None:
    uid = tf_layers_util.object_list_uid(inputs)
    if uid in self._input_map:
      inner_inputs = self._input_map[uid]
  updates = self.layer.get_updates_for(inner_inputs)
  updates += super(Wrapper, self).get_updates_for(inputs)
  return updates
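
# --- Illustrative sketch (not part of the TensorFlow source): the lookup
# above matters because the wrapper's `call` may replace the inputs (e.g. by
# reshaping) before the inner layer sees them, so the inner layer's updates
# are conditioned on the *transformed* tensor. `self._input_map` redirects
# the query. `ToyInner` and `ToyWrapper` below are hypothetical stand-ins.


class ToyInner(object):

  def __init__(self):
    self._updates = {}  # uid of conditioning inputs -> list of updates.

  def call(self, inputs):
    self._updates[str(id(inputs))] = ['update_%d' % id(inputs)]
    return inputs

  def get_updates_for(self, inputs):
    return self._updates.get(str(id(inputs)), [])


class ToyWrapper(object):

  def __init__(self, layer):
    self.layer = layer
    self._input_map = {}  # uid of original inputs -> transformed inputs.

  def call(self, inputs):
    transformed = ('reshaped', inputs)  # Stand-in for an actual reshape.
    self._input_map[str(id(inputs))] = transformed
    return self.layer.call(transformed)

  def get_updates_for(self, inputs):
    inner_inputs = self._input_map.get(str(id(inputs)), inputs)
    return self.layer.get_updates_for(inner_inputs)


wrapper = ToyWrapper(ToyInner())
x = [1.0, 2.0, 3.0]
wrapper.call(x)
# Without the indirection this lookup would miss, because the inner layer
# keyed its updates on the transformed tensor rather than on `x` itself.
assert len(wrapper.get_updates_for(x)) == 1
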
def call(self, inputs, training=None, mask=None):
  kwargs = {}
  if has_arg(self.layer.call, 'training'):
    kwargs['training'] = training
  # Track whether any inner call used the learning phase. A one-element list
  # lets the nested `step` function update the flag without `nonlocal`
  # (which is unavailable in Python 2).
  uses_learning_phase = [False]

  input_shape = K.int_shape(inputs)
  if input_shape[0]:
    # Batch size matters, use rnn-based implementation.
    def step(x, _):
      output = self.layer.call(x, **kwargs)
      if hasattr(output, '_uses_learning_phase'):
        uses_learning_phase[0] = (output._uses_learning_phase or
                                  uses_learning_phase[0])
      return output, []

    _, outputs, _ = K.rnn(step, inputs, initial_states=[], unroll=False)
    y = outputs
  else:
    # No batch size specified, therefore the layer will be able
    # to process batches of any size.
    # We can go with reshape-based implementation for performance.
    input_length = input_shape[1]
    if not input_length:
      input_length = K.shape(inputs)[1]
    # Shape: (num_samples * timesteps, ...). And track the
    # transformation in self._input_map.
    input_uid = tf_layers_util.object_list_uid(inputs)
    inputs = K.reshape(inputs, (-1,) + input_shape[2:])
    self._input_map[input_uid] = inputs
    # (num_samples * timesteps, ...)
    y = self.layer.call(inputs, **kwargs)
    if hasattr(y, '_uses_learning_phase'):
      uses_learning_phase[0] = y._uses_learning_phase
    # Shape: (num_samples, timesteps, ...)
    output_shape = self.compute_output_shape(input_shape).as_list()
    y = K.reshape(y, (-1, input_length) + tuple(output_shape[2:]))

  # Apply activity regularizer if any:
  if (hasattr(self.layer, 'activity_regularizer') and
      self.layer.activity_regularizer is not None):
    regularization_loss = self.layer.activity_regularizer(y)
    self.add_loss(regularization_loss, inputs)

  if uses_learning_phase[0]:
    y._uses_learning_phase = True
  return y
def call(self, inputs, training=None, mask=None):
  kwargs = {}
  if has_arg(self.layer.call, 'training'):
    kwargs['training'] = training
  # Track whether any inner call used the learning phase. A one-element list
  # lets the nested `step` function update the flag without `nonlocal`
  # (which is unavailable in Python 2).
  uses_learning_phase = [False]

  input_shape = K.int_shape(inputs)
  if input_shape[0]:
    # Batch size matters, use rnn-based implementation.
    def step(x, _):
      output = self.layer.call(x, **kwargs)
      if hasattr(output, '_uses_learning_phase'):
        uses_learning_phase[0] = (output._uses_learning_phase or
                                  uses_learning_phase[0])
      return output, []

    _, outputs, _ = K.rnn(step, inputs, initial_states=[], unroll=False)
    y = outputs
  else:
    # No batch size specified, therefore the layer will be able
    # to process batches of any size.
    # We can go with reshape-based implementation for performance.
    input_length = input_shape[1]
    if not input_length:
      input_length = array_ops.shape(inputs)[1]
    # Shape: (num_samples * timesteps, ...). And track the
    # transformation in self._input_map.
    input_uid = tf_layers_util.object_list_uid(inputs)
    inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
    self._input_map[input_uid] = inputs
    # (num_samples * timesteps, ...)
    y = self.layer.call(inputs, **kwargs)
    if hasattr(y, '_uses_learning_phase'):
      uses_learning_phase[0] = y._uses_learning_phase
    # Shape: (num_samples, timesteps, ...)
    output_shape = self.compute_output_shape(input_shape).as_list()
    y = array_ops.reshape(y, (-1, input_length) + tuple(output_shape[2:]))

  # Apply activity regularizer if any:
  if (hasattr(self.layer, 'activity_regularizer') and
      self.layer.activity_regularizer is not None):
    regularization_loss = self.layer.activity_regularizer(y)
    self.add_loss(regularization_loss, inputs)

  if uses_learning_phase[0]:
    y._uses_learning_phase = True
  return y
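
# --- Illustrative sketch (not part of the TensorFlow source): both `call`
# variants above (Keras backend vs. `array_ops`) use the same reshape trick
# when the static batch size is unknown: fold the time axis into the batch
# axis, apply the inner layer once, then unfold. The same transformation in
# plain NumPy, with a dense matmul standing in for the wrapped layer:

import numpy as np


def time_distributed_reshape(x, inner_fn):
  # Fold time into the batch axis: (batch, time, ...) -> (batch * time, ...).
  num_samples, timesteps = x.shape[0], x.shape[1]
  flat = x.reshape((num_samples * timesteps,) + x.shape[2:])
  y = inner_fn(flat)  # The inner op never sees the time axis.
  # Unfold: (batch * time, ...) -> (batch, time, ...).
  return y.reshape((num_samples, timesteps) + y.shape[1:])


x = np.random.rand(4, 10, 8)  # (batch, time, features)
w = np.random.rand(8, 3)
y = time_distributed_reshape(x, lambda f: f.dot(w))
assert y.shape == (4, 10, 3)
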
def _run_internal_graph(self, inputs, masks=None):
  """Computes output tensors for new inputs.

  # Note:
    - Expects `inputs` to be a list (potentially with 1 element).
    - Can be run on non-Keras tensors.

  Arguments:
    inputs: List of tensors.
    masks: List of masks (tensors or None).

  Returns:
    Two lists: output_tensors, output_masks
  """
  # Note: masking support is relevant mainly for Keras.
  # It cannot be factored out without having to fully reimplement the network
  # calling logic on the Keras side. We choose to incorporate it in
  # GraphNetwork because 1) it may be useful to fully support in tf.layers in
  # the future and 2) Keras is a major user of GraphNetwork. If you don't
  # use masking, it does not interfere with regular behavior at all and you
  # can ignore it.
  if masks is None:
    masks = [None for _ in range(len(inputs))]

  # Dictionary mapping reference tensors to tuples
  # (computed tensor, computed mask).
  # We assume a 1:1 mapping from tensor to mask.
  # TODO(fchollet): raise exception when a `.compute_mask()` call
  # does not return a list the same size as `call`.
  tensor_map = {}
  for x, y, mask in zip(self.inputs, inputs, masks):
    tensor_map[str(id(x))] = (y, mask)

  depth_keys = list(self._nodes_by_depth.keys())
  depth_keys.sort(reverse=True)
  for depth in depth_keys:
    nodes = self._nodes_by_depth[depth]
    for node in nodes:
      # This is always a single layer, never a list.
      layer = node.outbound_layer
      reference_input_tensors = node.input_tensors
      reference_output_tensors = node.output_tensors

      # If all previous input tensors are available in tensor_map,
      # then call node.inbound_layer on them.
      computed_data = []  # List of tuples (input, mask).
      for x in reference_input_tensors:
        if str(id(x)) in tensor_map:
          computed_data.append(tensor_map[str(id(x))])

      if len(computed_data) == len(reference_input_tensors):
        # Call layer (reapplying ops to new inputs).
        with ops.name_scope(layer.name):
          if node.arguments:
            kwargs = node.arguments
          else:
            kwargs = {}
          if len(computed_data) == 1:
            computed_tensor, computed_mask = computed_data[0]
            # Ensure mask propagation if applicable.
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_mask
            output_tensors = nest.flatten(
                layer.call(computed_tensor, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensor, computed_mask))
            else:
              output_masks = [None for _ in range(len(output_tensors))]
            computed_tensors = [computed_tensor]
            computed_masks = [computed_mask]
          else:
            computed_tensors = [x[0] for x in computed_data]
            computed_masks = [x[1] for x in computed_data]
            if 'mask' in estimator_util.fn_args(layer.call):
              if 'mask' not in kwargs:
                kwargs['mask'] = computed_masks
            output_tensors = nest.flatten(
                layer.call(computed_tensors, **kwargs))
            if hasattr(layer, 'compute_mask'):
              output_masks = nest.flatten(
                  layer.compute_mask(computed_tensors, computed_masks))
            else:
              output_masks = [None for _ in range(len(output_tensors))]

          # Apply activity regularizer if any:
          if layer.activity_regularizer is not None:
            regularization_losses = [
                layer.activity_regularizer(x) for x in computed_tensors
            ]
            layer.add_loss(regularization_losses, computed_tensors)

        if context.in_graph_mode():
          # Update model updates and losses:
          # Keep track of updates that depend on the inputs
          # (e.g. BN updates).
          self.add_update(layer.get_updates_for(computed_tensors), inputs)
          # Keep track of unconditional updates (e.g. a counter).
          self.add_update(layer.get_updates_for(None), None)
          # Keep track of losses that depend on the inputs
          # (e.g. activity regularizers).
          self.add_loss(layer.get_losses_for(computed_tensors), inputs)
          # Keep track of unconditional losses
          # (e.g. weight regularizers).
          self.add_loss(layer.get_losses_for(None), None)

        # Update tensor_map.
        for x, y, mask in zip(reference_output_tensors, output_tensors,
                              output_masks):
          tensor_map[str(id(x))] = (y, mask)

  output_tensors = []
  output_masks = []
  output_shapes = []
  for x in self.outputs:
    assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
    tensor, mask = tensor_map[str(id(x))]
    output_shapes.append(layers_util.static_shape(x))
    output_tensors.append(tensor)
    output_masks.append(mask)

  if len(output_tensors) == 1:
    output_tensors = output_tensors[0]
    if output_shapes is not None:
      output_shapes = output_shapes[0]
    if output_masks is not None:
      output_masks = output_masks[0]

  if context.in_graph_mode():
    # Update cache;
    # keys are based on the ids of input tensors and input masks.
    cache_key = (layers_util.object_list_uid(inputs) + '_' +
                 layers_util.object_list_uid(masks))
    self._output_tensor_cache[cache_key] = output_tensors
    if output_masks is not None:
      self._output_mask_cache[cache_key] = output_masks
    if output_shapes is not None:
      input_shapes = [layers_util.static_shape(x) for x in inputs]
      cache_key = layers_util.object_list_uid(input_shapes)
      self._output_shape_cache[cache_key] = output_shapes

  return output_tensors, output_masks
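
# --- Illustrative sketch (not part of the TensorFlow source): stripped of
# masking, caching, and loss bookkeeping, `_run_internal_graph` is a
# depth-ordered topological execution: a node fires only once every one of
# its reference input tensors has a computed counterpart in `tensor_map`.
# `ToyNode` and `run_graph` below are hypothetical stand-ins.

import collections

ToyNode = collections.namedtuple(
    'ToyNode', ['layer_fn', 'input_tensors', 'output_tensors'])


def run_graph(nodes_by_depth, model_inputs, new_inputs, model_outputs):
  # Map reference tensors (by id) to their freshly computed counterparts.
  tensor_map = {str(id(x)): y for x, y in zip(model_inputs, new_inputs)}
  for depth in sorted(nodes_by_depth, reverse=True):  # Inputs come first.
    for node in nodes_by_depth[depth]:
      computed = [tensor_map[str(id(x))] for x in node.input_tensors
                  if str(id(x)) in tensor_map]
      if len(computed) != len(node.input_tensors):
        continue  # Not all of this node's inputs are available yet.
      for ref, out in zip(node.output_tensors, node.layer_fn(computed)):
        tensor_map[str(id(ref))] = out
  return [tensor_map[str(id(x))] for x in model_outputs]


# Toy graph computing out = double(a) + b.
a, b, h, out = object(), object(), object(), object()
nodes_by_depth = {
    1: [ToyNode(lambda xs: [2 * xs[0]], [a], [h])],
    0: [ToyNode(lambda xs: [xs[0] + xs[1]], [h, b], [out])],
}
assert run_graph(nodes_by_depth, [a, b], [3.0, 4.0], [out]) == [10.0]
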
def compute_output_shape(self, input_shape):
  if isinstance(input_shape, list):
    input_shapes = []
    for shape in input_shape:
      if shape is not None:
        input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
      else:
        input_shapes.append(None)
  else:
    if input_shape is not None:
      input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
    else:
      input_shapes = [None]

  if len(input_shapes) != len(self._input_layers):
    raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                     ': model has ' + str(len(self._input_layers)) +
                     ' tensor inputs.')

  cache_key = layers_util.object_list_uid(input_shapes)
  if cache_key not in self._output_shape_cache:
    # Cache miss. We have to run the network graph manually (recursive calls
    # to `compute_output_shape`).
    layers_to_output_shapes = {}
    for i in range(len(input_shapes)):
      layer = self._input_layers[i]
      input_shape = input_shapes[i]
      # It's an input layer: then `compute_output_shape` is identity,
      # and there is only one node and one tensor output.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = input_shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          # This is always a single layer, never a list.
          layer = node.outbound_layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Potentially redundant list,
          # same size as node.input_tensors.
          input_shapes = []
          for j in range(len(node.inbound_layers)):
            inbound_layer = node.inbound_layers[j]
            node_index = node.node_indices[j]
            tensor_index = node.tensor_indices[j]
            shape_key = inbound_layer.name + '_%s_%s' % (node_index,
                                                         tensor_index)
            input_shape = layers_to_output_shapes[shape_key]
            input_shapes.append(input_shape)

          if len(input_shapes) == 1:
            output_shape = layer.compute_output_shape(input_shapes[0])
          else:
            output_shape = layer.compute_output_shape(input_shapes)
          if isinstance(output_shape, list):
            output_shapes = [
                tuple(tensor_shape.TensorShape(shape).as_list())
                for shape in output_shape
            ]
          else:
            output_shapes = [
                tuple(tensor_shape.TensorShape(output_shape).as_list())
            ]

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j in range(len(output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = output_shapes[j]

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes
  else:
    # Cache hit.
    output_shapes = self._output_shape_cache[cache_key]

  if isinstance(output_shapes, list):
    if len(output_shapes) == 1:
      return tensor_shape.TensorShape(output_shapes[0])
    else:
      return [tensor_shape.TensorShape(shape) for shape in output_shapes]
  else:
    return tensor_shape.TensorShape(output_shapes)
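
# --- Illustrative sketch (not part of the TensorFlow source): unlike the
# tensor cache, shapes are cached by value (tuples are hashable), and
# intermediate results are stored under '<layer name>_<node index>_<tensor
# index>' keys. A minimal demonstration of that keying scheme, with
# hypothetical layer names:


def toy_shape_key(layer_name, node_index, tensor_index):
  return layer_name + '_%s_%s' % (node_index, tensor_index)


layers_to_output_shapes = {}
# Input layers are the identity on shapes; each owns node 0 / tensor 0.
layers_to_output_shapes[toy_shape_key('input_1', 0, 0)] = (None, 32)
# A downstream layer reads its inbound shape and writes its own.
in_shape = layers_to_output_shapes[toy_shape_key('input_1', 0, 0)]
layers_to_output_shapes[toy_shape_key('dense_1', 0, 0)] = (in_shape[0], 10)
assert layers_to_output_shapes['dense_1_0_0'] == (None, 10)
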
def _compute_output_shape(self, input_shape):
  if isinstance(input_shape, list):
    input_shapes = []
    for shape in input_shape:
      if shape is not None:
        input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list()))
      else:
        input_shapes.append(None)
  else:
    if input_shape is not None:
      input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())]
    else:
      input_shapes = [None]

  if len(input_shapes) != len(self._input_layers):
    raise ValueError('Invalid input_shape argument ' + str(input_shape) +
                     ': model has ' + str(len(self._input_layers)) +
                     ' tensor inputs.')

  cache_key = layers_util.object_list_uid(input_shapes)
  if cache_key not in self._output_shape_cache:
    # Cache miss. We have to run the network graph manually (recursive calls
    # to `_compute_output_shape`).
    layers_to_output_shapes = {}
    for i in range(len(input_shapes)):
      layer = self._input_layers[i]
      input_shape = input_shapes[i]
      # It's an input layer: then `_compute_output_shape` is identity,
      # and there is only one node and one tensor output.
      shape_key = layer.name + '_0_0'
      layers_to_output_shapes[shape_key] = input_shape

    depth_keys = list(self._nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Iterate over nodes, by depth level.
    if len(depth_keys) > 1:
      for depth in depth_keys:
        nodes = self._nodes_by_depth[depth]
        for node in nodes:
          # This is always a single layer, never a list.
          layer = node.outbound_layer
          if layer in self._input_layers:
            # We've already covered the input layers
            # a few lines above.
            continue
          # Potentially redundant list,
          # same size as node.input_tensors.
          input_shapes = []
          for j in range(len(node.inbound_layers)):
            inbound_layer = node.inbound_layers[j]
            node_index = node.node_indices[j]
            tensor_index = node.tensor_indices[j]
            shape_key = inbound_layer.name + '_%s_%s' % (node_index,
                                                         tensor_index)
            input_shape = layers_to_output_shapes[shape_key]
            input_shapes.append(input_shape)

          if len(input_shapes) == 1:
            output_shape = layer._compute_output_shape(input_shapes[0])  # pylint: disable=protected-access
          else:
            output_shape = layer._compute_output_shape(input_shapes)  # pylint: disable=protected-access
          if isinstance(output_shape, list):
            output_shapes = [
                tuple(tensor_shape.TensorShape(shape).as_list())
                for shape in output_shape
            ]
          else:
            output_shapes = [
                tuple(tensor_shape.TensorShape(output_shape).as_list())
            ]

          node_index = layer._inbound_nodes.index(node)  # pylint: disable=protected-access
          for j in range(len(output_shapes)):
            shape_key = layer.name + '_%s_%s' % (node_index, j)
            layers_to_output_shapes[shape_key] = output_shapes[j]

    # Read final output shapes from layers_to_output_shapes.
    output_shapes = []
    for i in range(len(self._output_layers)):
      layer, node_index, tensor_index = self._output_coordinates[i]
      shape_key = layer.name + '_%s_%s' % (node_index, tensor_index)
      output_shapes.append(layers_to_output_shapes[shape_key])
    # Store in cache.
    self._output_shape_cache[cache_key] = output_shapes
  else:
    # Cache hit.
    output_shapes = self._output_shape_cache[cache_key]

  if isinstance(output_shapes, list):
    if len(output_shapes) == 1:
      return tensor_shape.TensorShape(output_shapes[0])
    else:
      return [tensor_shape.TensorShape(shape) for shape in output_shapes]
  else:
    return tensor_shape.TensorShape(output_shapes)