コード例 #1
0
ファイル: array.py プロジェクト: VislaLabs/webdnn-1
def _convert_split_axis(converter: ChainerConverter,
                        c_op: "chainer.functions.SplitAxis"):
    """Translate a Chainer ``SplitAxis`` node into a WebDNN ``SplitAxis`` operator.

    Only the explicit split-point form is handled. When the Chainer node was
    built from a section count, ``NotImplementedError`` is raised.

    Args:
        converter: converter holding the Chainer -> WebDNN variable mapping.
        c_op: the Chainer ``SplitAxis`` function node to convert.
    """
    x = converter.get_variable(c_op.inputs[0])

    major, _minor, _patch = semver(chainer.__version__)
    if major < 4:
        # Before v4 one attribute holds either split points or an int section count.
        if isinstance(c_op.indices_or_sections, int):
            raise NotImplementedError(
                "[ChainerConverter] SplitAxis with sections are not supported."
            )
        split_points = c_op.indices_or_sections
    else:
        # From v4 on, split points live in `indices`; it is None when sections were given.
        # https://github.com/chainer/chainer/commit/906a8e9b0837cd9a4e5ee6f1dbda26431a1e12d1#diff-9e610d281c820d44c4a0cbf0ca6263fd
        if c_op.indices is None:
            raise NotImplementedError(
                "[ChainerConverter] SplitAxis with sections are not supported."
            )
        split_points = c_op.indices

    pieces = SplitAxis(None, sections=split_points, axis=x.order.axes[c_op.axis])(x)
    for position, piece in enumerate(pieces):
        # Each `c_op.outputs` entry is called to obtain the Chainer output
        # (presumably a weak reference — confirm).
        output_ref = c_op.outputs[position]
        converter.set_variable(output_ref(), piece)
コード例 #2
0
def get_variable_data(
        variable: "T_VARIABLE") -> Union[np.ndarray, "chainer.cuda.ndarray"]:
    """Return the data array held by ``variable``.

    Placeholder implementation: does nothing and returns ``None``. A working
    version is defined in the ``try: import chainer`` section for supported
    Chainer versions (e.g. v3.x).
    """


def to_variable_node(c_var: "chainer.Variable") -> "T_VARIABLE":
    """Convert a ``chainer.Variable`` into a variable node (a ``T_VARIABLE`` instance).

    Placeholder implementation: does nothing and returns ``None``. A working
    version is defined in the ``try: import chainer`` section for supported
    Chainer versions.
    """


# Whether a usable Chainer installation was detected; starts as False.
# NOTE(review): presumably switched to True by the `try: import chainer` block
# that follows once a supported version is found — confirm.
FLAG_CHAINER_INSTALLED = False

try:
    import chainer
    import chainer.computational_graph

    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(chainer.__version__)

    if VERSION_MAJOR == 3:
        # v3.x.x

        # In v3, Many functions are represented as instance of `chainer.function_node.FunctionNode`. However some functions are still
        # instance of `chainer.function.Function` (ex. Im2Col).
        T_FUNCTION = (chainer.FunctionNode, chainer.Function)
        T_VARIABLE = chainer.variable.VariableNode

        def get_variable_data(variable: T_VARIABLE):
            """Return the node's own data, or the data of ``variable._variable()`` if it has none."""
            if variable.data is not None:
                return variable.data
            # The node carries no data itself; read it from the object returned
            # by calling `_variable` (presumably the owning Variable — confirm).
            # noinspection PyProtectedMember
            return variable._variable().data

        def to_variable_node(c_var: chainer.Variable):
コード例 #3
0
ファイル: converter.py プロジェクト: zhangaz1/webdnn
from webdnn.graph.order import Order
from webdnn.graph.placeholder import Placeholder
from webdnn.graph.variable import Variable
from webdnn.graph.variables.attributes.input import Input
from webdnn.graph.variables.attributes.output import Output
from webdnn.graph.variables.constant_variable import ConstantVariable
from webdnn.util import console

# True only when Keras (with its TensorFlow backend) imports successfully and
# its version passes the check below.
FLAG_KERAS_INSTALLED = False

try:
    import keras
    import keras.backend as K
    import tensorflow as tf

    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(keras.__version__)
    # Supported: v2.1.3 and later within the v2 series. Tuple comparison is
    # lexicographic, so this equals
    # MAJOR == 2 and ((MINOR == 1 and PATCH >= 3) or MINOR >= 2).
    if not (VERSION_MAJOR == 2
            and (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) >= (2, 1, 3)):
        raise NotImplementedError(
            f"WebDNN supports Keras >=v2.1.3,<v3.0.0. Currently, keras {keras.__version__} is installed."
        )

    FLAG_KERAS_INSTALLED = True

except Exception:
    # Best effort: a missing or unsupported Keras must not break importing this
    # module. Log the traceback and leave FLAG_KERAS_INSTALLED as False.
    console.warning(traceback.format_exc())


def _to_list(x):
    return x if isinstance(x, (list, tuple)) else [x]
コード例 #4
0
ファイル: converter.py プロジェクト: saibabanadh/webdnn
"""
import tempfile
from os import path

from webdnn.frontend.converter import Converter
from webdnn.frontend.onnx import ONNXConverter
from webdnn.frontend.util import semver
from webdnn.graph.graph import Graph
from webdnn.util import console

# True only when PyTorch (with torch.onnx) imports and its version is supported.
FLAG_PYTORCH_INSTALLED = False
try:
    import torch
    import torch.onnx

    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(torch.__version__)
    # semver() presumably yields integer components, so the minor version is
    # compared with 3. Comparing it with the float 0.3 would accept every
    # minor >= 1 (e.g. v0.1, v0.2). Only the v0 series is accepted.
    if not (VERSION_MAJOR == 0 and VERSION_MINOR >= 3):
        raise NotImplementedError(
            f"WebDNN supports PyTorch >=v0.3.0,<v1.0.0. Currently, PyTorch {torch.__version__} is installed."
        )

    FLAG_PYTORCH_INSTALLED = True

except ImportError:
    # PyTorch is optional: note it at debug level and leave the flag False.
    # An unsupported version is not an ImportError and propagates to the importer.
    console.debug("PyTorch is not completely installed.")

FLAG_ONNX_INSTALLED = False
try:
    import onnx
コード例 #5
0
ファイル: converter.py プロジェクト: fossabot/hash2face
from webdnn.graph.order import Order
from webdnn.graph.placeholder import Placeholder
from webdnn.graph.variable import Variable
from webdnn.graph.variables.attributes.input import Input
from webdnn.graph.variables.attributes.output import Output
from webdnn.graph.variables.constant_variable import ConstantVariable
from webdnn.optimizer.sub_rules.constant_folding import ConstantFolding
from webdnn.optimizer.tensorflow_frontend_optimize_rule import TensorFlowFrontendOptimizeRule
from webdnn.util import console

# Whether TensorFlow could be imported. Starts False so that a failed import
# does not report TensorFlow as installed.
FLAG_TF_INSTALLED = False

try:
    import tensorflow as tf

    # The import succeeded. An unsupported version below only produces a
    # warning and leaves this flag True.
    FLAG_TF_INSTALLED = True

    # `tf.__version__` is available across TensorFlow releases, unlike `tf.VERSION`.
    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(tf.__version__)
    if not ((VERSION_MAJOR == 1) and (2 <= VERSION_MINOR <= 4)):
        raise NotImplementedError(
            f"WebDNN supports TensorFlow >=v1.2.0,<v1.5.0. Currently, TensorFlow {tf.__version__} is installed."
        )

except Exception:
    # Best effort: a missing or unsupported TensorFlow must not break importing
    # this module, so only log the traceback.
    console.warning(traceback.format_exc())


class TensorFlowConverter(Converter["tf.Operation"]):
    """TensorFlowConverter(batch_size=1)

    Converter for `TensorFlow <https://www.tensorflow.org/>`_

    Args: