Example no. 1
0
    def _chainerx_apply_fallback_postprocess(
            self, chainerx_in_data, inputs, outputs):
        """Register fallback forward results with ChainerX and wrap them.

        Presumably called after ``forward`` was computed through the
        non-ChainerX fallback path (inferred from the name; confirm against
        the caller).

        Args:
            chainerx_in_data: Input data forwarded unchanged to
                ``chainerx._core._function_node_forward``.
            inputs: Sequence of input objects. Each non-``None`` entry is
                wrapped in ``variable._ChainerxVariableNodeProps``;
                ``None`` entries stay ``None``.
            outputs: Sequence of forward outputs. They are converted with
                ``backend.to_chainerx`` and paired element-wise with the
                converted arrays.

        Returns:
            tuple: One ``_to_variable_with_chainerx_fallback_array`` result
            per (converted output, original output) pair.

        Side effects:
            Inserts an op-node via ``chainerx._core._function_node_forward``
            and assigns ``self.inputs``.
        """
        # TODO(hvy): Take configuration.config.enable_backprop into
        # account?
        cx_outputs = backend.to_chainerx(outputs)

        # The ChainerX binding expects lists, so "nothing retained"
        # (None) is normalized to an empty list.
        retained_in = self._input_indexes_to_retain
        if retained_in is None:
            retained_in = []
        retained_out = self._output_indexes_to_retain
        if retained_out is None:
            retained_out = []

        # Insert a ChainerX op-node that calls FunctionNode.backward in
        # backprop. Note that cx_outputs may not require gradients.
        chainerx._core._function_node_forward(
            self, chainerx_in_data, cx_outputs, retained_in, retained_out)

        # Keep None placeholders in position so input indexes stay aligned.
        node_props = []
        for x in inputs:
            if x is None:
                node_props.append(None)
            else:
                node_props.append(variable._ChainerxVariableNodeProps(x))
        self.inputs = tuple(node_props)

        wrapped = []
        for cx_arr, fallback_arr in six.moves.zip(cx_outputs, outputs):
            wrapped.append(
                _to_variable_with_chainerx_fallback_array(
                    cx_arr, fallback_arr))
        return tuple(wrapped)
Example no. 2
0
    def _chainerx_apply_fallback_postprocess(
            self, chainerx_in_data, inputs, outputs):
        """Register fallback forward results with ChainerX and wrap them.

        Presumably called after ``forward`` was computed through the
        non-ChainerX fallback path (inferred from the name; confirm against
        the caller).

        Args:
            chainerx_in_data: Input data forwarded unchanged to
                ``chainerx._core._function_node_forward``.
            inputs: Sequence of input objects. Each non-``None`` entry is
                wrapped in ``variable._ChainerxVariableNodeProps``;
                ``None`` entries are kept as ``None``.
            outputs: Sequence of forward outputs. They are converted with
                ``backend.to_chainerx`` and paired element-wise with the
                converted arrays.

        Returns:
            tuple: One ``_to_variable_with_chainerx_fallback_array`` result
            per (converted output, original output) pair.

        Side effects:
            Inserts an op-node via ``chainerx._core._function_node_forward``
            and assigns ``self.inputs``.
        """
        # TODO(hvy): Take configuration.config.enable_backprop into
        # account?
        chainerx_out_data = backend.to_chainerx(outputs)

        # Insert a ChainerX op-node that calls FunctionNode.backward in
        # backprop. Note that chainerx_out_data may not require gradients.
        # The retain-index attributes may be None; the binding gets [].
        chainerx._core._function_node_forward(
            self, chainerx_in_data, chainerx_out_data,
            [] if self._input_indexes_to_retain is None
            else self._input_indexes_to_retain,
            [] if self._output_indexes_to_retain is None
            else self._output_indexes_to_retain)

        # Pass None inputs through unchanged instead of wrapping them in
        # node props. This keeps positions aligned with ``inputs`` and
        # avoids constructing _ChainerxVariableNodeProps(None).
        self.inputs = tuple([
            None if x is None
            else variable._ChainerxVariableNodeProps(x) for x in inputs])

        ret = tuple([
            _to_variable_with_chainerx_fallback_array(
                chainerx_out_array, out_array)
            for chainerx_out_array, out_array
            in six.moves.zip(chainerx_out_data, outputs)])
        return ret