def _resolve_entry(relay_map, entry):
    """Look up the Relay expression for one NNVM node-entry reference.

    Parameters
    ----------
    relay_map : dict of int to relay.Expr
        Relay expressions already built, keyed by NNVM node id.

    entry : list
        An NNVM entry ``[node_id, output_index, ...]`` as it appears in a
        node's ``inputs`` list or in the graph's ``heads``.

    Returns
    -------
    relay.Expr
        The expression stored for ``node_id``. If that expression is an
        ``expr.TupleWrapper`` (a multi-output operator), the element selected
        by ``output_index`` is returned instead.
    """
    node_expr = relay_map[entry[0]]
    if isinstance(node_expr, expr.TupleWrapper):
        return node_expr[entry[1]]
    return node_expr


def to_relay(graph, shape_dict, dtype_dict, params):
    """Convert an NNVM graph into the corresponding Relay expression.

    Parameters
    ----------
    graph : Graph or Symbol
        The input graph. A ``Symbol`` is first turned into a graph with
        ``graph_create``.

    shape_dict : dict of str to shape
        The input shapes. The caller's dict is not modified; the shapes of
        ``params`` are merged into a copy.

    dtype_dict : dict of str to str/dtype
        The input dtypes.

    params : dict of str to array
        The parameters. Each value must expose a ``shape`` attribute.

    Returns
    -------
    (expr, params) : Tuple[relay.Expr, dict of str to array]
        The corresponding Relay function and the ``params`` passed in.
        The function body is the single graph output, or otherwise a
        ``Tuple`` of the outputs in the order given by the graph's ``heads``.

    Raises
    ------
    Exception
        If the graph contains an operator missing from
        ``NNVM_OP_2_RELAY_OP``.
    """
    if isinstance(graph, Symbol):
        graph = graph_create(graph)

    # Parameters are graph inputs too, so their shapes are needed for shape
    # inference. Merge them into a copy to leave the caller's dict intact.
    shape_dict = shape_dict.copy()
    shape_dict.update({name: params[name].shape for name in params})
    graph = graph_attr.set_shape_inputs(graph, shape_dict)
    graph = graph_attr.set_dtype_inputs(graph, dtype_dict)
    graph = graph.apply(["InferShape", "InferType"])
    shape = graph.json_attr("shape")
    dtype = [graph_attr.TCODE_TO_DTYPE[code] for code in graph.json_attr("dtype")]

    # Graph outputs as NNVM entries [node_id, output_index, ...]. Their order
    # defines the order of the Relay outputs, so keep the full entries rather
    # than just the node ids.
    heads = json.loads(graph.json())["heads"]
    gidx = graph.index
    relay_map = {}
    fn_params = []

    for nid, node in enumerate(gidx.nodes):
        children = [_resolve_entry(relay_map, entry) for entry in node["inputs"]]
        # Only the first output's shape/dtype is looked up; it types the input
        # variable or is handed to the operator converter.
        eid = gidx.entry_id(nid, 0)
        oshape = shape[eid]
        odtype = dtype[eid]
        attrs = node.get("attrs", {})
        node_name = node["name"]
        op_name = node["op"]
        if op_name == "null":
            # Graph inputs and parameters become Relay function parameters.
            v = var(node_name, shape=oshape, dtype=odtype)
            fn_params.append(v)
            relay_map[nid] = v
        elif op_name in NNVM_OP_2_RELAY_OP:
            str_attrs = StrAttrsDict(attrs)
            relay_map[nid] = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype)
        else:
            raise Exception(
                "nnvm.to_relay: unsupported operator: {0}".format(op_name))

    # Resolve outputs from the heads themselves. This preserves output order,
    # honours the output index of multi-output nodes, and keeps heads that
    # are plain input variables.
    outputs = [_resolve_entry(relay_map, head) for head in heads]
    body = outputs[0] if len(outputs) == 1 else expr.Tuple(outputs)
    func = relay.Function(fn_params, body)
    return func, params
def test_key_is_not_present():
    """``has_attr`` reports a key that was never supplied as absent."""
    str_attrs = StrAttrsDict({"a": 1})
    found = str_attrs.has_attr("b")
    assert not found