Exemplo n.º 1
0
def test_vector_dummy_data():
    """Op "A" of tensor_dummy_data.prototxt must evaluate to a 2x3 array of 4.0."""
    proto_file = join(PROTO_PATH, "tensor_dummy_data.prototxt")
    layer_a = parse_prototxt(proto_file).get("A")
    # Reference value the executed op is compared against.
    expected = np.full((2, 3), 4.)
    with executor(layer_a) as run:
        actual = run()
    assert np.array_equal(actual, expected)
Exemplo n.º 2
0
def test_tensor_const_sum():
    """Op "C" of tensor_const_sum.prototxt must equal a 2x3 array of 4.0 plus one of 3.0."""
    proto_file = join(PROTO_PATH, "tensor_const_sum.prototxt")
    layer_c = parse_prototxt(proto_file).get("C")
    with executor(layer_c) as run:
        actual = run()

    # Build the reference sum element-wise with NumPy.
    expected = np.full((2, 3), 4.) + np.full((2, 3), 3.)
    assert np.array_equal(actual, expected)
Exemplo n.º 3
0
def test_scalar_dummy_data():
    """Op "A" of scalar_dummy_data.prototxt must evaluate to exactly 1.0."""
    proto_file = join(PROTO_PATH, "scalar_dummy_data.prototxt")
    layer_a = parse_prototxt(proto_file).get("A")
    with executor(layer_a) as run:
        actual = run()
    assert actual == 1.
Exemplo n.º 4
0
def test_scalar_const_sum():
    """Op "C" of scalar_const_sum.prototxt must evaluate to exactly 4.0."""
    proto_file = join(PROTO_PATH, "scalar_const_sum.prototxt")
    layer_c = parse_prototxt(proto_file).get("C")
    with executor(layer_c) as run:
        actual = run()
    assert actual == 4.
Exemplo n.º 5
0
# ----------------------------------------------------------------------------
# Copyright 2016 Nervana Systems Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ----------------------------------------------------------------------------

from __future__ import print_function
import ngraph.transformers as ngt
from ngraph.frontends.caffe.cf_importer.importer import parse_prototxt

# Prototxt file describing the graph to import.
model = "sum.prototxt"
# Parse the prototxt (verbosely); the result is queried by layer name below.
op_map = parse_prototxt(model, verbose=True)
# Look up the op handle stored under the layer name "D".
op = op_map.get("D")
# Create a transformer, wrap the op in a computation, and run it once.
transformer = ngt.make_transformer()
computation = transformer.computation(op)
res = computation()
print("Result is:", res)
# EOF