Example #1
0
def forward_numpyop(ops: List[NumpyOp],
                    data: MutableMapping[str, Any],
                    state: Dict[str, Any],
                    batched: bool = False) -> None:
    """Call the forward function for list of NumpyOps, and modify the data dictionary in place.

    Args:
        ops: A list of NumpyOps to execute.
        data: The data dictionary.
        state: Information about the current execution context, ex. {"mode": "train"}. Must contain at least the mode.
        batched: Whether the `data` is batched or not.
    """
    for op in ops:
        op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
        try:
            op_data = op.forward_batch(
                op_data, state) if batched else op.forward(op_data, state)
        except ValueError as err:
            # Numpy raises this exact message when an op writes into a read-only array.
            # If the numpy error text changes we'll need to make adjustments in the future.
            # Guard on err.args: a ValueError constructed with no args would otherwise
            # raise IndexError here and mask the real error.
            if err.args and err.args[0] == 'assignment destination is read-only':
                # Retry once with copy-on-write enabled for this op from now on.
                op.in_place_edits = True
                op_data = get_inputs_by_op(op,
                                           data,
                                           copy_on_write=op.in_place_edits)
                op_data = op.forward_batch(
                    op_data, state) if batched else op.forward(op_data, state)
            else:
                # Bare raise preserves the original traceback (re-raising `err`
                # explicitly would append this line to it).
                raise
        if isinstance(op, Delete):
            # Delete ops remove their input keys from the data dictionary entirely.
            for key in op.inputs:
                del data[key]
        if op.outputs:
            write_outputs_by_op(op, data, op_data)
Example #2
0
    def _forward_batch(batch: MutableMapping[str, Any], state: Dict[str, Any], ops: List[TensorOp]) -> None:
        """Run a forward pass through the network's Op chain given a `batch` of data.

        Args:
            batch: A batch of input data. Predictions from the network will be written back into this dictionary.
            state: A dictionary holding information about the current execution context. The TF gradient tape, for
                example will be stored here.
            ops: Which ops to execute.
        """
        for current_op in ops:
            # Gather this op's inputs from the batch, execute it, then publish
            # any declared outputs back into the batch dictionary.
            op_result = get_inputs_by_op(current_op, batch)
            op_result = current_op.forward(op_result, state)
            if current_op.outputs:
                write_outputs_by_op(current_op, batch, op_result)
Example #3
0
def forward_numpyop(ops: List[NumpyOp], data: MutableMapping[str, Any], mode: str) -> None:
    """Call the forward function for list of NumpyOps, and modify the data dictionary in place.

    Args:
        ops: A list of NumpyOps to execute.
        data: The data dictionary.
        mode: The current execution mode ("train", "eval", "test", or "infer").
    """
    for current_op in ops:
        # Collect the op's inputs, run its forward pass, then update `data` in place.
        result = get_inputs_by_op(current_op, data)
        result = current_op.forward(result, {"mode": mode})
        if isinstance(current_op, Delete):
            # Delete ops drop their input keys from the data dictionary.
            for key in current_op.inputs:
                del data[key]
        if current_op.outputs:
            write_outputs_by_op(current_op, data, result)
Example #4
0
def forward_numpyop(ops: List[NumpyOp],
                    data: MutableMapping[str, Any],
                    state: Dict[str, Any],
                    batched: Optional[str] = None) -> Optional[FilteredData]:
    """Call the forward function for list of NumpyOps, and modify the data dictionary in place.

    Args:
        ops: A list of NumpyOps to execute.
        data: The data dictionary.
        state: Information about the current execution context, ex. {"mode": "train"}. Must contain at least the mode.
        batched: Whether the `data` is batched or not. If it is batched, provide the string ('tf', 'torch', or 'np')
            indicating which type of tensors the batch contains.

    Returns:
        FilteredData if any op emitted one (processing stops immediately at that op), otherwise None.
    """
    if batched:
        # Cast data to Numpy before performing batch forward
        for key, val in data.items():
            data[key] = to_tensor(val, target_type='np')
    for op in ops:
        op_data = get_inputs_by_op(op, data, copy_on_write=op.in_place_edits)
        try:
            op_data = op.forward_batch(
                op_data, state) if batched else op.forward(op_data, state)
        except ValueError as err:
            # Numpy raises this exact message when an op writes into a read-only array.
            # If the numpy error text changes we'll need to make adjustments in the future.
            # Guard on err.args: a ValueError constructed with no args would otherwise
            # raise IndexError here and mask the real error.
            if err.args and err.args[0] == 'assignment destination is read-only':
                # Retry once with copy-on-write enabled for this op from now on.
                op.in_place_edits = True
                op_data = get_inputs_by_op(op,
                                           data,
                                           copy_on_write=op.in_place_edits)
                op_data = op.forward_batch(
                    op_data, state) if batched else op.forward(op_data, state)
            else:
                # Bare raise preserves the original traceback (re-raising `err`
                # explicitly would append this line to it).
                raise
        if isinstance(op_data, FilteredData):
            # NOTE: on this early exit the batched data is left as numpy tensors
            # (the cast-back below is skipped) — presumably fine since the batch
            # is being filtered out; confirm against callers.
            return op_data
        if isinstance(op, Delete):
            # Delete ops remove their input keys from the data dictionary entirely.
            for key in op.inputs:
                del data[key]
        if op.outputs:
            write_outputs_by_op(op, data, op_data)
    if batched:
        # Cast data back to original tensor type after performing batch forward
        for key, val in data.items():
            data[key] = to_tensor(val, target_type=batched, shared_memory=True)
    return None
Example #5
0
 def test_multi_key_multi_data(self):
     """Multiple output keys paired with multiple data items map one-to-one."""
     result = {}
     multi_op = Op(outputs=["x", "y"])
     write_outputs_by_op(op=multi_op, store=result, outputs=[1, [1, 2]])
     self.assertEqual(result, {"x": 1, "y": [1, 2]})
Example #6
0
 def test_single_key_multi_data(self):
     """A single output key captures the entire multi-element output list."""
     result = {}
     single_op = Op(outputs="x")
     write_outputs_by_op(op=single_op, store=result, outputs=[1, 2])
     self.assertEqual(result, {"x": [1, 2]})
Example #7
0
 def test_single_key_single_data(self):
     """A single key with a scalar output is written through directly."""
     result = {}
     scalar_op = Op(outputs="x")
     write_outputs_by_op(op=scalar_op, store=result, outputs=1)
     self.assertEqual(result, {"x": 1})