def forward_numpyop(ops: List[NumpyOp],
                    data: MutableMapping[str, Any],
                    state: Dict[str, Any],
                    batched: bool = False) -> None:
    """Call the forward function for list of NumpyOps, and modify the data dictionary in place.

    Each op reads its inputs from `data`. If the op fails with numpy's "assignment destination is read-only"
    ValueError, the op is marked with `in_place_edits = True` (the flag is stored on the op, so it persists across
    calls) and re-run once on inputs fetched with `copy_on_write=True`. `Delete` ops remove their input keys from
    `data`; ops with outputs have their results written back into `data`.

    Args:
        ops: A list of NumpyOps to execute.
        data: The data dictionary. Modified in place.
        state: Information about the current execution context, ex. {"mode": "train"}. Must contain at least the mode.
        batched: Whether the `data` is batched or not. Selects `op.forward_batch` instead of `op.forward`.

    Raises:
        ValueError: If an op raises a ValueError other than the read-only-assignment error, or if the retry on
            copied inputs fails again.
    """
    read_only_msg = 'assignment destination is read-only'

    def _call(op: NumpyOp, op_data: Any) -> Any:
        # Dispatch to the op's batch-level or per-sample entry point.
        return op.forward_batch(op_data, state) if batched else op.forward(op_data, state)

    for op in ops:
        op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
        try:
            op_data = _call(op, op_data)
        except ValueError as err:
            # `err.args` may be empty (e.g. `raise ValueError()`); indexing it blindly would replace the op's real
            # error with an IndexError. If the numpy error text changes we'll need to adjust `read_only_msg`.
            if not err.args or err.args[0] != read_only_msg:
                raise
            # The op writes into its (read-only) inputs; switch it to copy-on-write inputs and try once more.
            op.in_place_edits = True
            op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
            op_data = _call(op, op_data)
        if isinstance(op, Delete):
            for key in op.inputs:
                del data[key]
        if op.outputs:
            write_outputs_by_op(op, data, op_data)
def _forward_batch(batch: MutableMapping[str, Any], state: Dict[str, Any], ops: List[TensorOp]) -> None:
    """Push `batch` through each op in `ops`, in order, storing results back into `batch`.

    Every op pulls its inputs out of `batch` and runs `forward` with the shared `state`. When the op declares
    outputs, its results are written into `batch`, which makes them visible to later ops in the chain.

    Args:
        batch: A batch of input data. Predictions from the network will be written back into this dictionary.
        state: A dictionary holding information about the current execution context, passed unchanged to every
            op. The TF gradient tape, for example, will be stored here.
        ops: Which ops to execute, in execution order.
    """
    for tensor_op in ops:
        op_inputs = get_inputs_by_op(tensor_op, batch)
        op_results = tensor_op.forward(op_inputs, state)
        if not tensor_op.outputs:
            continue
        write_outputs_by_op(tensor_op, batch, op_results)
def forward_numpyop(ops: List[NumpyOp], data: MutableMapping[str, Any], mode: str) -> None:
    """Run each NumpyOp's `forward` in sequence, updating `data` in place.

    Every op receives its inputs from `data` together with a freshly built state dict holding only `mode`.
    A `Delete` op removes each of its input keys from `data`. Any op with outputs has its results written back.

    Args:
        ops: The NumpyOps to execute, in order.
        data: The data dictionary to read from and write into.
        mode: The current execution mode ("train", "eval", "test", or "infer").
    """
    for numpy_op in ops:
        op_inputs = get_inputs_by_op(numpy_op, data)
        # A new dict per op, so one op mutating its state cannot leak into the next.
        op_results = numpy_op.forward(op_inputs, dict(mode=mode))
        if isinstance(numpy_op, Delete):
            for doomed_key in numpy_op.inputs:
                del data[doomed_key]
        if numpy_op.outputs:
            write_outputs_by_op(numpy_op, data, op_results)
def forward_numpyop(ops: List[NumpyOp],
                    data: MutableMapping[str, Any],
                    state: Dict[str, Any],
                    batched: Optional[str] = None) -> Optional[FilteredData]:
    """Call the forward function for list of NumpyOps, and modify the data dictionary in place.

    When `batched` is set, every value in `data` is first converted to numpy, each op's `forward_batch` is used
    instead of `forward`, and at the end every value is converted back to the `batched` tensor type (with
    `shared_memory=True`).

    If an op fails with numpy's "assignment destination is read-only" ValueError, the op is marked with
    `in_place_edits = True` (the flag is stored on the op, so it persists across calls) and re-run once on inputs
    fetched with `copy_on_write=True`.

    Args:
        ops: A list of NumpyOps to execute.
        data: The data dictionary. Modified in place.
        state: Information about the current execution context, ex. {"mode": "train"}. Must contain at least the mode.
        batched: Whether the `data` is batched or not. If it is batched, provide the string ('tf', 'torch', or 'np')
            indicating which type of tensors the batch contains.

    Returns:
        The `FilteredData` instance if any op returns one, otherwise None. On that early return the remaining ops
        are skipped, and in batched mode the values in `data` are left as numpy (not cast back).

    Raises:
        ValueError: If an op raises a ValueError other than the read-only-assignment error, or if the retry on
            copied inputs fails again.
    """
    read_only_msg = 'assignment destination is read-only'

    def _call(op: NumpyOp, op_data: Any) -> Any:
        # Dispatch to the op's batch-level or per-sample entry point.
        return op.forward_batch(op_data, state) if batched else op.forward(op_data, state)

    if batched:
        # Cast data to Numpy before performing batch forward
        for key, val in data.items():
            data[key] = to_tensor(val, target_type='np')
    for op in ops:
        op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
        try:
            op_data = _call(op, op_data)
        except ValueError as err:
            # `err.args` may be empty (e.g. `raise ValueError()`); indexing it blindly would replace the op's real
            # error with an IndexError. If the numpy error text changes we'll need to adjust `read_only_msg`.
            if not err.args or err.args[0] != read_only_msg:
                raise
            # The op writes into its (read-only) inputs; switch it to copy-on-write inputs and try once more.
            op.in_place_edits = True
            op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
            op_data = _call(op, op_data)
        if isinstance(op_data, FilteredData):
            # NOTE(review): in batched mode `data` is not cast back on this path; presumably the caller discards
            # filtered data — confirm.
            return op_data
        if isinstance(op, Delete):
            for key in op.inputs:
                del data[key]
        if op.outputs:
            write_outputs_by_op(op, data, op_data)
    if batched:
        # Cast data back to original tensor type after performing batch forward
        for key, val in data.items():
            data[key] = to_tensor(val, target_type=batched, shared_memory=True)
    return None
def test_multi_key_multi_data(self):
    """Two output keys each receive the matching element of a two-element outputs list."""
    store = {}
    two_key_op = Op(outputs=["x", "y"])
    write_outputs_by_op(op=two_key_op, store=store, outputs=[1, [1, 2]])
    expected = {"x": 1, "y": [1, 2]}
    self.assertEqual(store, expected)
def test_single_key_multi_data(self):
    """A single output key stores the entire outputs list under that key."""
    store = {}
    one_key_op = Op(outputs="x")
    write_outputs_by_op(op=one_key_op, store=store, outputs=[1, 2])
    expected = {"x": [1, 2]}
    self.assertEqual(store, expected)
def test_single_key_single_data(self):
    """A single output key with a scalar output stores that scalar under the key."""
    store = {}
    one_key_op = Op(outputs="x")
    write_outputs_by_op(op=one_key_op, store=store, outputs=1)
    expected = {"x": 1}
    self.assertEqual(store, expected)