def _convert_split_axis(converter: ChainerConverter, c_op: "chainer.functions.SplitAxis"):
    """Convert a Chainer ``SplitAxis`` node into a WebDNN ``SplitAxis`` operator.

    Only the explicit split-indices form is handled. A node that was built
    with a number of sections is rejected.

    Args:
        converter: converter holding the Chainer -> WebDNN variable mapping.
        c_op: the Chainer ``SplitAxis`` node to convert.

    Raises:
        NotImplementedError: if ``c_op`` splits by section count instead of indices.
    """
    x = converter.get_variable(c_op.inputs[0])

    major, _minor, _patch = semver(chainer.__version__)

    # The internal attributes of SplitAxis differ between Chainer versions:
    #   - v4 and later: ``indices`` (None when sections were requested)
    #   - earlier:      ``indices_or_sections`` (an int when sections were requested)
    # https://github.com/chainer/chainer/commit/906a8e9b0837cd9a4e5ee6f1dbda26431a1e12d1#diff-9e610d281c820d44c4a0cbf0ca6263fd
    if major >= 4:
        split_points = c_op.indices
        given_as_sections = split_points is None
    else:
        split_points = c_op.indices_or_sections
        given_as_sections = isinstance(split_points, int)

    if given_as_sections:
        raise NotImplementedError(
            "[ChainerConverter] SplitAxis with sections are not supported."
        )

    target_axis = x.order.axes[c_op.axis]
    pieces = SplitAxis(None, sections=split_points, axis=target_axis)(x)

    # c_op.outputs holds references that must be called to obtain the
    # Chainer output variables (each output is called with `()` below).
    for idx, piece in enumerate(pieces):
        output_var = c_op.outputs[idx]()
        converter.set_variable(output_var, piece)
def get_variable_data( variable: "T_VARIABLE") -> Union[np.ndarray, "chainer.cuda.ndarray"]: ... # return variable's data def to_variable_node(c_var: "chainer.Variable") -> "T_VARIABLE": ... # convert "chainer.Variable" into variable node (T_VARIABLE instance) FLAG_CHAINER_INSTALLED = False try: import chainer import chainer.computational_graph VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(chainer.__version__) if VERSION_MAJOR == 3: # v3.x.x # In v3, Many functions are represented as instance of `chainer.function_node.FunctionNode`. However some functions are still # instance of `chainer.function.Function` (ex. Im2Col). T_FUNCTION = (chainer.FunctionNode, chainer.Function) T_VARIABLE = chainer.variable.VariableNode def get_variable_data(variable: T_VARIABLE): # noinspection PyProtectedMember return variable._variable( ).data if variable.data is None else variable.data def to_variable_node(c_var: chainer.Variable):
from webdnn.graph.order import Order
from webdnn.graph.placeholder import Placeholder
from webdnn.graph.variable import Variable
from webdnn.graph.variables.attributes.input import Input
from webdnn.graph.variables.attributes.output import Output
from webdnn.graph.variables.constant_variable import ConstantVariable
from webdnn.util import console

# Becomes True only after keras, keras.backend and tensorflow imported
# successfully AND the installed Keras version passed the check below.
FLAG_KERAS_INSTALLED = False

try:
    import keras
    import keras.backend as K
    import tensorflow as tf

    VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(keras.__version__)
    # Accepted range: major version 2, and at least 2.1.3.
    if not (VERSION_MAJOR == 2 and (
            (VERSION_MINOR == 1 and VERSION_PATCH >= 3) or VERSION_MINOR >= 2)):
        # The message states the range actually enforced above
        # (previously it claimed any v2.*.* was supported).
        raise NotImplementedError(
            f"WebDNN supports Keras >=v2.1.3,<v3.0.0. Currently, keras {keras.__version__} is installed."
        )

    FLAG_KERAS_INSTALLED = True

except Exception:
    # Keras support is optional: any failure (import error or unsupported
    # version) is reported as a warning and FLAG_KERAS_INSTALLED stays False.
    console.warning(traceback.format_exc())


def _to_list(x):
    """Return ``x`` as-is if it is a list or tuple; otherwise wrap it in a one-element list."""
    return x if isinstance(x, (list, tuple)) else [x]
""" import tempfile from os import path from webdnn.frontend.converter import Converter from webdnn.frontend.onnx import ONNXConverter from webdnn.frontend.util import semver from webdnn.graph.graph import Graph from webdnn.util import console FLAG_PYTORCH_INSTALLED = False try: import torch import torch.onnx VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(torch.__version__) if not ((VERSION_MAJOR == 0) and (VERSION_MINOR >= 0.3)): raise NotImplementedError( f"WebDNN supports PyTorch >= v0.3 Currently, PyTorch {torch.__version__} is installed." ) FLAG_PYTORCH_INSTALLED = True except ImportError as e: console.debug("PyTorch is not completely installed.") pass FLAG_ONNX_INSTALLED = False try: import onnx
from webdnn.graph.order import Order from webdnn.graph.placeholder import Placeholder from webdnn.graph.variable import Variable from webdnn.graph.variables.attributes.input import Input from webdnn.graph.variables.attributes.output import Output from webdnn.graph.variables.constant_variable import ConstantVariable from webdnn.optimizer.sub_rules.constant_folding import ConstantFolding from webdnn.optimizer.tensorflow_frontend_optimize_rule import TensorFlowFrontendOptimizeRule from webdnn.util import console FLAG_TF_INSTALLED = True try: import tensorflow as tf VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH = semver(tf.VERSION) if not ((VERSION_MAJOR == 1) and (2 <= VERSION_MINOR <= 4)): raise NotImplementedError( f"WebDNN supports TensorFlow >=v1.2.0,<=v1.4.0 Currently, TensorFlow {tf.VERSION} is installed." ) except Exception as e: console.warning(traceback.format_exc()) class TensorFlowConverter(Converter["tf.Operation"]): """TensorFlowConverter(batch_size=1) Converter for `TensorFlow <https://www.tensorflow.org/>`_ Args: