Example #1
    def _torch_while(self, data: Dict[str, Tensor],
                     state: Dict[str, Any]) -> Dict[str, Tensor]:
        """A helper function to invoke a while loop.

        Args:
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            A reference to the updated data dictionary.
        """
        while self.repeat(*[data[var_name]
                            for var_name in self.repeat_inputs]):
            BaseNetwork._forward_batch(data, state, self.ops)
        return data
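
The pattern above reduces to a plain Python while-loop driven by a user-supplied condition over selected dictionary entries. Below is a minimal standalone sketch of that idea, with hypothetical names (`repeat_while`, `step`) that are not part of FastEstimator's API:

def repeat_while(data, cond, cond_inputs, step):
    # Loop while the condition over the selected dictionary entries holds
    while cond(*[data[name] for name in cond_inputs]):
        step(data)  # mutates `data` in place, like _forward_batch above
    return data

# Usage: count up until the condition fails
result = repeat_while({"i": 0},
                      cond=lambda i: i < 3,
                      cond_inputs=["i"],
                      step=lambda d: d.update(i=d["i"] + 1))
assert result["i"] == 3
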
Example #2
    def __init__(self,
                 pipeline: Pipeline,
                 network: BaseNetwork,
                 epochs: int,
                 train_steps_per_epoch: Optional[int] = None,
                 eval_steps_per_epoch: Optional[int] = None,
                 traces: Union[None, Trace, Scheduler[Trace],
                               Iterable[Union[Trace,
                                              Scheduler[Trace]]]] = None,
                 log_steps: Optional[int] = 100,
                 monitor_names: Union[None, str, Iterable[str]] = None):
        self.traces_in_use = []
        self.filepath = os.path.realpath(
            inspect.stack()[2].filename)  # Record this for history tracking
        assert log_steps is None or log_steps >= 0, \
            "log_steps must be None or positive (or 0 to disable only train logging)"
        # Always monitor the network's loss keys alongside any user-requested keys
        self.monitor_names = to_set(monitor_names) | network.get_loss_keys()
        self.system = System(network=network,
                             pipeline=pipeline,
                             traces=to_list(traces),
                             log_steps=log_steps,
                             total_epochs=epochs,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch,
                             system_config=self.fe_summary())
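
For reference, a constructor like this is normally reached through the library's top-level API rather than called directly. A rough usage sketch, assuming FastEstimator-style objects (`pipeline` and `network` stand in for already-built instances):

import fastestimator as fe

# pipeline = fe.Pipeline(...)        # assumed built elsewhere
# network = fe.Network(ops=[...])    # assumed built elsewhere
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=2,
                         log_steps=100)
estimator.fit()  # runs the training loop configured by __init__ above
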
Example #3
    def _tf_body(
        self, cnd: List[Tensor], data: Dict[str, Tensor], state: Dict[str, Any]
    ) -> Tuple[List[Tensor], Dict[str, Tensor], Dict[str, Any]]:
        """A helper function to execute the body of a while method.

        Note that `cnd` is unused here, but required since tf.while_loop needs the cond and body to have the same input
        argument signatures.

        Args:
            cnd: A list of arguments to be passed to the condition function.
            data: A data dictionary to be used during looping.
            state: The state variables to be considered during looping.

        Returns:
            The updated `cnd` values, along with the modified data and state dictionaries.
        """
        BaseNetwork._forward_batch(data, state, self.ops)
        return [data[var_name] for var_name in self.repeat_inputs], data, state
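
The constraint described in the docstring comes from TensorFlow itself: `tf.while_loop` requires `cond` and `body` to accept the same loop variables. A minimal self-contained illustration:

import tensorflow as tf

i = tf.constant(0)
total = tf.constant(0)
# Both callables must take the same loop variables (i, total)
cond = lambda i, total: i < 5
body = lambda i, total: (i + 1, total + i)
i_final, total_final = tf.while_loop(cond, body, [i, total])
# i_final == 5, total_final == 0 + 1 + 2 + 3 + 4 == 10
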
Example #4
    def __init__(self,
                 pipeline: Pipeline,
                 network: BaseNetwork,
                 epochs: int,
                 max_train_steps_per_epoch: Optional[int] = None,
                 max_eval_steps_per_epoch: Optional[int] = None,
                 traces: Union[None, Trace, Scheduler[Trace], Iterable[Union[Trace, Scheduler[Trace]]]] = None,
                 log_steps: Optional[int] = 100,
                 monitor_names: Union[None, str, Iterable[str]] = None):
        self.traces_in_use = []
        assert log_steps is None or log_steps >= 0, \
            "log_steps must be None or positive (or 0 to disable only train logging)"
        self.monitor_names = to_set(monitor_names) | network.get_loss_keys()
        self.system = System(network=network,
                             pipeline=pipeline,
                             traces=to_list(traces),
                             log_steps=log_steps,
                             total_epochs=epochs,
                             max_train_steps_per_epoch=max_train_steps_per_epoch,
                             max_eval_steps_per_epoch=max_eval_steps_per_epoch,
                             system_config=self.fe_summary())
Example #5
    def forward(self, data: List[Tensor], state: Dict[str,
                                                      Any]) -> List[Tensor]:
        # Set retain to True since we might loop over a gradient-aware op
        self.op.fe_retain_graph(True)

        data = {key: elem for key, elem in zip(self.inputs, data)}
        if isinstance(self.repeat, int):
            for _ in range(self.repeat - 1):
                # Perform n-1 rounds with all ops having retain_graph == True
                BaseNetwork._forward_batch(data, state, self.ops)
            # Let retain be whatever it was meant to be for the final sequence
            self.op.fe_retain_graph(self.retain_graph)
            # Final round of ops
            BaseNetwork._forward_batch(data, state, self.ops)
        else:
            BaseNetwork._forward_batch(data, state, self.ops)
            data = self.while_fn(data, state)
            # TODO - Find some magic way to invoke this at the right moment
            self.op.fe_retain_graph(self.retain_graph)
        return [data[key] for key in self.outputs]
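
The `fe_retain_graph` bookkeeping reflects a PyTorch rule: `backward()` frees the autograd graph by default, so backpropagating through the same graph more than once requires `retain_graph=True` on every pass except the last. A small demonstration:

import torch

x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
y.backward(retain_graph=True)  # keep the graph alive for another pass
y.backward()  # would raise a RuntimeError without retain_graph=True above
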
Example #6
    def forward(self, data: List[Tensor], state: Dict[str,
                                                      Any]) -> List[Tensor]:
        # Map positional inputs to a data dictionary keyed by input name
        data = {key: elem for key, elem in zip(self.inputs, data)}
        BaseNetwork._forward_batch(data, state, self.ops)
        return [data[key] for key in self.outputs]
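
The list -> dict -> ops -> list convention shared by both `forward` variants can be reduced to a few lines (hypothetical `run_ops` helper and toy op, not FastEstimator code):

def run_ops(inputs, outputs, ops, data_list):
    data = dict(zip(inputs, data_list))  # positional values keyed by name
    for op in ops:                       # each op reads/writes dict keys
        op(data)
    return [data[key] for key in outputs]

# A toy op that doubles "x" into "y"
assert run_ops(["x"], ["y"], [lambda d: d.update(y=d["x"] * 2)], [21]) == [42]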