Ejemplo n.º 1
0
def test_infer_type():
    """Run the InferType pass on ``cast(add(x, x))`` and check the dtypes.

    The ``cast1`` node must report dtype 1 (the value given to ``sym.cast``)
    and the ``add1`` node must report dtype 0.
    """
    # NOTE(review): ``x`` is declared without a dtype here, so the
    # ``add1 == 0`` check depends on how InferType treats an unannotated
    # input in the NNVM version this targets -- confirm.
    data = sym.Variable('x')
    summed = sym.add(data, data, name='add1')
    casted = sym.cast(summed, dtype=1, name="cast1")
    inferred = graph.create(casted).apply('InferType')

    # Serialize the graph to recover each node's position by name.
    layout = json.loads(inferred.apply('SaveJSON').json_attr('json'))
    row_ptr = layout['node_row_ptr']
    position = {}
    for idx, node in enumerate(layout['nodes']):
        position[node['name']] = idx

    dtypes = inferred.json_attr('dtype')

    def dtype_of(node_name):
        # The per-entry dtype list is indexed through node_row_ptr,
        # presumably the offset of the node's first output entry.
        return dtypes[row_ptr[position[node_name]]]

    assert dtype_of("cast1") == 1
    assert dtype_of("add1") == 0
Ejemplo n.º 2
0
def test_infer_type():
    """Run InferType on ``cast(add(x, x))`` with an explicitly typed input.

    ``x`` is declared with ``dtype=0`` and the graph attribute
    ``dtype_attr_key`` is set to ``"dtype"`` before the pass runs
    (presumably telling InferType which node attribute holds the declared
    type). Afterwards ``cast1`` must be dtype 1 and ``add1`` dtype 0.
    """
    data = sym.Variable('x', dtype=0)
    summed = sym.add(data, data, name='add1')
    casted = sym.cast(summed, dtype=1, name="cast1")

    graph_obj = graph.create(casted)
    # Must be set before InferType is applied.
    graph_obj._set_json_attr("dtype_attr_key", "dtype")
    inferred = graph_obj.apply('InferType')

    # Serialize the graph to recover each node's position by name.
    layout = json.loads(inferred.apply('SaveJSON').json_attr('json'))
    row_ptr = layout['node_row_ptr']
    position = dict((node['name'], idx)
                    for idx, node in enumerate(layout['nodes']))

    dtypes = inferred.json_attr('dtype')
    # The per-entry dtype list is indexed through node_row_ptr,
    # presumably the offset of each node's first output entry.
    assert dtypes[row_ptr[position["cast1"]]] == 1
    assert dtypes[row_ptr[position["add1"]]] == 0
Ejemplo n.º 3
0
def test_place_device():
    """Run PlaceDevice on a two-stage graph and check the assigned devices.

    ``x`` is tagged ``stage1`` and ``add2`` / the ``exp`` node are tagged
    ``stage2``. The assign map sends ``stage1`` to device 0 and ``stage2``
    to device 1, and ``"cross_device_copy"`` is configured as the device
    copy op. The test expects ``add2`` and ``add3`` on device 1 and
    ``cast1`` on device 0.
    """
    # NOTE(review): cast1 and add3 carry no explicit device_group; the
    # expected devices presumably come from PlaceDevice propagating groups
    # from neighbouring nodes -- confirm against the pass implementation.
    inp = sym.Variable('x', device_group="stage1")
    doubled = sym.add(inp, inp, name='add1')
    casted = sym.cast(doubled, dtype=1, name="cast1")
    stage2_sum = sym.add(casted, casted, device_group="stage2", name="add2")
    # The exp node is created before add3, matching the original order,
    # so any auto-generated name for it is unchanged.
    stage2_exp = sym.exp(casted, device_group="stage2")
    out = sym.add(stage2_sum, stage2_exp, name="add3")

    pre = graph.create(out)
    pre._set_json_attr("device_group_attr_key", "device_group")
    pre._set_json_attr("device_assign_map", {"stage1": 0, "stage2": 1},
                       "dict_str_int")
    pre._set_json_attr("device_copy_op", "cross_device_copy")
    placed = pre.apply("PlaceDevice")

    # Serialize the graph to recover each node's position by name.
    layout = json.loads(placed.apply('SaveJSON').json_attr('json'))
    row_ptr = layout['node_row_ptr']
    position = {node['name']: idx
                for idx, node in enumerate(layout['nodes'])}

    devices = placed.json_attr('device')
    for node_name, want in (("add2", 1), ("add3", 1), ("cast1", 0)):
        # The per-entry device list is indexed through node_row_ptr,
        # presumably the offset of the node's first output entry.
        assert devices[row_ptr[position[node_name]]] == want
Ejemplo n.º 4
0
def test_place_device():
    """Check the device each node gets from the PlaceDevice pass.

    Graph: ``add3(add2(c, c), exp(c))`` with ``c = cast1(add1(x, x))``.
    ``x`` is in group ``stage1`` (mapped to device 0), and ``add2`` and the
    ``exp`` node are in ``stage2`` (mapped to device 1). The test expects
    ``add2`` and ``add3`` on device 1 and ``cast1`` on device 0.
    """
    # NOTE(review): cast1 and add3 have no explicit device_group; their
    # expected devices presumably come from group propagation inside
    # PlaceDevice -- confirm.
    source = sym.Variable('x', device_group="stage1")
    c = sym.cast(sym.add(source, source, name='add1'), dtype=1, name="cast1")
    left = sym.add(c, c, device_group="stage2", name="add2")
    # exp is built before add3, preserving the original creation order.
    right = sym.exp(c, device_group="stage2")
    head = sym.add(left, right, name="add3")

    g0 = graph.create(head)
    settings = (
        ("device_group_attr_key", "device_group"),
        ("device_assign_map", {"stage1": 0, "stage2": 1}, "dict_str_int"),
        ("device_copy_op", "cross_device_copy"),
    )
    for args in settings:
        g0._set_json_attr(*args)
    placed = g0.apply("PlaceDevice")

    # Serialize the graph to recover each node's position by name.
    dump = json.loads(placed.apply('SaveJSON').json_attr('json'))
    offsets = dump['node_row_ptr']
    where = {entry['name']: pos for pos, entry in enumerate(dump['nodes'])}

    def device_of(node_name):
        # The per-entry device list is indexed through node_row_ptr,
        # presumably the offset of the node's first output entry.
        return placed.json_attr('device')[offsets[where[node_name]]]

    assert device_of("add2") == 1
    assert device_of("add3") == 1
    assert device_of("cast1") == 0