Exemple #1
0
    def __call__(self, *args, return_losses=False, **kwargs):
        """Invoke the model after clearing the per-call losses dictionary.

        Args:
          *args: Positional arguments forwarded to call(). Any argument that is
            a dict is first passed through copy_if_tf_function().
          return_losses: When True, also return the losses dictionary with a
            'total_loss' entry added, computed by self.sum_losses().
          **kwargs: Keyword arguments forwarded to call() unchanged.

        Returns:
          outputs: The value returned by super().__call__(), i.e. the model
            outputs produced by call().
          losses: Only when return_losses=True: the dictionary
            {loss_name: loss_value}, including 'total_loss'.
        """
        # Dict arguments may be mutated downstream; copying them when running
        # in graph mode keeps the call free of side effects on caller objects.
        forwarded = []
        for value in args:
            if isinstance(value, dict):
                value = copy_if_tf_function(value)
            forwarded.append(value)

        # Fresh losses dict for this call (presumably filled in during call();
        # that code is not visible here).
        self._losses_dict = {}
        outputs = super().__call__(*forwarded, **kwargs)

        if return_losses:
            total_loss = self.sum_losses(self._losses_dict)
            self._losses_dict['total_loss'] = total_loss
            return outputs, self._losses_dict
        return outputs
Exemple #2
0
  def call(self, conditioning):
    """Run the decoder and merge its split outputs into the conditioning dict.

    Args:
      conditioning: Dictionary of conditioning signals. It is passed through
        core.copy_if_tf_function() first; the resulting dictionary is the one
        that gets updated and returned.

    Returns:
      The conditioning dictionary updated with the decoder output, split into a
      dictionary by nn.split_to_dict() using self.output_splits.

    Raises:
      ValueError: If the split decoder output is not a dict.
    """
    conditioning = core.copy_if_tf_function(conditioning)
    decoded = nn.split_to_dict(self.decode(conditioning), self.output_splits)
    if not isinstance(decoded, dict):
      raise ValueError('Decoder must output a dictionary of signals.')
    conditioning.update(decoded)
    return conditioning
Exemple #3
0
    def run_dag(self,
                inputs: TensorDict,
                verbose: bool = True,
                **kwargs) -> TensorDict:
        """Connects and runs submodules of dag.

        Nodes of self.dag are executed in order. Each node is indexed as
        (module_key, input_keys[, output_keys]): module_key is looked up as an
        attribute of self to get the module, each entry of input_keys is
        resolved against the outputs gathered so far via core.nested_lookup(),
        and the optional output_keys are passed to core.to_dict() when the
        module does not return a dict.

        Args:
          inputs: A dictionary of input tensors fed to the dag.
          verbose: Log the dag routing (module keys and input/output shapes)
            while running.
          **kwargs: Other kwargs to pass to submodules, such as keras kwargs.

        Returns:
          A nested dictionary of all the output tensors. It holds the dag
          inputs (under 'inputs' and also merged into the top level), each
          module's outputs under its module_key, and the final module's
          outputs under the reserved key 'out'.

        Raises:
          ValueError: If self.dag contains no nodes, so there is no final
            module output to alias as 'out'.
        """
        inputs = core.copy_if_tf_function(inputs)
        # Initialize the outputs with inputs to the dag.
        outputs = {'inputs': inputs}
        # TODO(jesseengel): Remove this cluttering of the base namespace. Only there
        # for backwards compatibility.
        outputs.update(inputs)

        def shape(d):
            # Nested list-shapes for logging; assumes every leaf has a .shape
            # attribute (e.g. tensors) -- TODO confirm.
            return tf.nest.map_structure(lambda x: list(x.shape), d)

        # Sentinel so an empty dag is reported clearly instead of surfacing as
        # an UnboundLocalError on `module_outputs` below.
        unset = object()
        module_outputs = unset

        # Run through the DAG nodes in sequential order.
        for node in self.dag:
            # node[0] is used as an attribute name on self, so it must already
            # be a string key here (presumably modules were registered on self
            # and replaced by their keys elsewhere -- TODO confirm).
            module_key, input_keys = node[0], node[1]
            module = getattr(self, module_key)
            # Optionally specify output keys if module does not return dict.
            output_keys = node[2] if len(node) > 2 else None

            # Get the inputs to the node. A separate name avoids shadowing the
            # dag-level `inputs` argument.
            node_inputs = [
                core.nested_lookup(key, outputs) for key in input_keys
            ]

            if verbose:
                logging.info('Input to Module: %s\nKeys: %s\nIn: %s\n',
                             module_key, input_keys, shape(node_inputs))

            if is_processor(module):
                # Processor modules.
                module_outputs = module(*node_inputs,
                                        return_outputs_dict=True,
                                        **kwargs)
            else:
                # Network modules.
                module_outputs = module(*node_inputs, **kwargs)

            if not isinstance(module_outputs, dict):
                module_outputs = core.to_dict(module_outputs, output_keys)

            if verbose:
                logging.info('Output from Module: %s\nOut: %s\n', module_key,
                             shape(module_outputs))

            # Add module outputs to the dictionary.
            outputs[module_key] = module_outputs

        if module_outputs is unset:
            raise ValueError('Cannot run an empty dag: self.dag has no nodes.')

        # Alias final module output as dag output.
        # 'out' is a reserved key for final dag output.
        outputs['out'] = module_outputs

        return outputs
Exemple #4
0
 def call(self, dag_inputs: TensorDict) -> tf.Tensor:
     """Compute a signal from a dictionary of dag inputs.

     Like Processor, but specific to having an input dictionary: controls are
     computed from the inputs, then turned into a signal.

     Args:
       dag_inputs: Dictionary of inputs. It is passed through
         core.copy_if_tf_function() before use.

     Returns:
       The result of self.get_signal() applied to the output of
       self.get_controls().
     """
     controls = self.get_controls(core.copy_if_tf_function(dag_inputs))
     return self.get_signal(controls)