def _torch_while(self, data: Dict[str, Tensor], state: Dict[str, Any]) -> Dict[str, Tensor]:
    """Repeatedly execute this op's sub-ops until the repeat condition returns False.

    Args:
        data: A data dictionary to be used during looping.
        state: The state variables to be considered during looping.

    Returns:
        A reference to the updated data dictionary.
    """
    while True:
        # Re-gather the condition inputs every pass, since the ops may overwrite them.
        condition_args = [data[name] for name in self.repeat_inputs]
        if not self.repeat(*condition_args):
            break
        BaseNetwork._forward_batch(data, state, self.ops)
    return data
def __init__(self,
             pipeline: Pipeline,
             network: BaseNetwork,
             epochs: int,
             train_steps_per_epoch: Optional[int] = None,
             eval_steps_per_epoch: Optional[int] = None,
             traces: Union[None, Trace, Scheduler[Trace], Iterable[Union[Trace, Scheduler[Trace]]]] = None,
             log_steps: Optional[int] = 100,
             monitor_names: Union[None, str, Iterable[str]] = None):
    """Set up the training system from the given pipeline/network configuration.

    Args:
        pipeline: The data pipeline to draw batches from.
        network: The network to train/evaluate.
        epochs: Total number of epochs to run.
        train_steps_per_epoch: Optional cap on training steps per epoch.
        eval_steps_per_epoch: Optional cap on evaluation steps per epoch.
        traces: Trace(s) and/or trace scheduler(s) to run during execution.
        log_steps: Logging frequency in steps (None disables logging; 0 disables
            only the per-step train logging).
        monitor_names: Extra keys to monitor in addition to the network's loss keys.
    """
    self.traces_in_use = []
    # NOTE: inspect.stack()[2] must be evaluated directly in this frame so that the
    # recorded path points at the user's script two call levels up (history tracking).
    caller_frame = inspect.stack()[2]
    self.filepath = os.path.realpath(caller_frame.filename)
    assert log_steps is None or log_steps >= 0, \
        "log_steps must be None or positive (or 0 to disable only train logging)"
    # Always monitor the network's loss keys in addition to any user-requested names.
    self.monitor_names = to_set(monitor_names) | network.get_loss_keys()
    self.system = System(network=network,
                         pipeline=pipeline,
                         traces=to_list(traces),
                         log_steps=log_steps,
                         total_epochs=epochs,
                         train_steps_per_epoch=train_steps_per_epoch,
                         eval_steps_per_epoch=eval_steps_per_epoch,
                         system_config=self.fe_summary())
def _tf_body(
        self, cnd: List[Tensor], data: Dict[str, Tensor], state: Dict[str, Any]
) -> Tuple[List[Tensor], Dict[str, Tensor], Dict[str, Any]]:
    """Execute one iteration of the tf.while_loop body.

    The `cnd` argument is never read here; tf.while_loop requires the cond and body
    functions to share the same argument signature, so it must still be accepted.

    Args:
        cnd: A list of arguments to be passed to the condition function.
        data: A data dictionary to be used during looping.
        state: The state variables to be considered during looping.

    Returns:
        The updated `cnd` values, along with the modified data and state dictionaries.
    """
    BaseNetwork._forward_batch(data, state, self.ops)
    # Re-extract the condition inputs after the forward pass so the loop condition
    # sees their updated values on the next check.
    updated_cnd = [data[name] for name in self.repeat_inputs]
    return updated_cnd, data, state
def __init__(self,
             pipeline: Pipeline,
             network: BaseNetwork,
             epochs: int,
             max_train_steps_per_epoch: Optional[int] = None,
             max_eval_steps_per_epoch: Optional[int] = None,
             traces: Union[None, Trace, Scheduler[Trace], Iterable[Union[Trace, Scheduler[Trace]]]] = None,
             log_steps: Optional[int] = 100,
             monitor_names: Union[None, str, Iterable[str]] = None):
    """Set up the training system from the given pipeline/network configuration.

    Args:
        pipeline: The data pipeline to draw batches from.
        network: The network to train/evaluate.
        epochs: Total number of epochs to run.
        max_train_steps_per_epoch: Optional cap on training steps per epoch.
        max_eval_steps_per_epoch: Optional cap on evaluation steps per epoch.
        traces: Trace(s) and/or trace scheduler(s) to run during execution.
        log_steps: Logging frequency in steps (None disables logging; 0 disables
            only the per-step train logging).
        monitor_names: Extra keys to monitor in addition to the network's loss keys.
    """
    self.traces_in_use = []
    assert log_steps is None or log_steps >= 0, \
        "log_steps must be None or positive (or 0 to disable only train logging)"
    # Always monitor the network's loss keys in addition to any user-requested names.
    extra_names = to_set(monitor_names)
    self.monitor_names = extra_names | network.get_loss_keys()
    self.system = System(network=network,
                         pipeline=pipeline,
                         traces=to_list(traces),
                         log_steps=log_steps,
                         total_epochs=epochs,
                         max_train_steps_per_epoch=max_train_steps_per_epoch,
                         max_eval_steps_per_epoch=max_eval_steps_per_epoch,
                         system_config=self.fe_summary())
def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
    """Run the wrapped ops either a fixed number of times or while a condition holds.

    Args:
        data: Input tensors, ordered to correspond with `self.inputs`.
        state: Information about the current execution context.

    Returns:
        The tensors named by `self.outputs`, taken from the looped data dictionary.
    """
    # Set retain to true since might loop over a gradient aware op
    self.op.fe_retain_graph(True)
    data = {key: elem for key, elem in zip(self.inputs, data)}
    if isinstance(self.repeat, int):
        # Perform n-1 rounds with all ops having retain_graph == True.
        # (loop variable is unused, so name it `_` per convention)
        for _ in range(self.repeat - 1):
            BaseNetwork._forward_batch(data, state, self.ops)
        # Let retain be whatever it was meant to be for the final sequence
        self.op.fe_retain_graph(self.retain_graph)
        # Final round of ops
        BaseNetwork._forward_batch(data, state, self.ops)
    else:
        # Condition-driven repetition: run once, then keep looping via while_fn.
        BaseNetwork._forward_batch(data, state, self.ops)
        data = self.while_fn(data, state)
        # TODO - Find some magic way to invoke this at the right moment
        self.op.fe_retain_graph(self.retain_graph)
    return [data[key] for key in self.outputs]
def forward(self, data: List[Tensor], state: Dict[str, Any]) -> List[Tensor]:
    """Execute this op's sub-ops once over the given tensors.

    Args:
        data: Input tensors, ordered to correspond with `self.inputs`.
        state: Information about the current execution context.

    Returns:
        The tensors named by `self.outputs`.
    """
    batch = dict(zip(self.inputs, data))
    BaseNetwork._forward_batch(batch, state, self.ops)
    return [batch[name] for name in self.outputs]