Ejemplo n.º 1
0
def parse_flat(op_str, node):
    if type(node) == FunctionNode:
        assert len(node.inedges) == 2, "wrong number of inputs"
    elif type(node) == ModuleNode:
        assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.FLAT) + "\n"
    return op_str
Ejemplo n.º 2
0
def parse_permute(op_str, node):
    assert len(node.inedges) >= 1
    op_str = op_str + enum_to_str(OpType, OpType.PERMUTE) + ", "
    for dim in node.inedges[1:-1]:
        op_str = op_str + str(dim) + ", "
    op_str = op_str + str(node.inedges[-1]) + "\n"
    return op_str
Ejemplo n.º 3
0
def parse_concat(op_str, node):
    #FIXME assume it is a merge
    assert len(node.inedges[0]) >= 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.CONCAT) + ", "
    if len(node.inedges) == 1:
        op_str = op_str + str(1) + "\n"
    else:
        op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
Ejemplo n.º 4
0
def parse_linear(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.LINEAR) + ", "
    op_str = op_str + str(node.module.out_features) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
    if node.module.bias != None:
        op_str = op_str + "1\n"
    else:
        op_str = op_str + "0\n"
    return op_str
Ejemplo n.º 5
0
def parse_expand(op_str, node):
    assert len(node.inedges) >= 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.EXPAND) + ", "
    input_shape = node.inedges[1:]
    for dim in input_shape[:-1]:
        op_str = op_str + (str(dim) if type(dim) is int else
                           (str(dim) + ":")) + ", "
    op_str = op_str + (str(input_shape[-1]) if type(input_shape[-1]) is int
                       else (str(input_shape[-1]) + ":")) + "\n"
    return op_str
Ejemplo n.º 6
0
def parse_adaptivepool2d(op_str, node, pool_type):
    assert len(node.inedges) == 1, "wrong number of inputs"
    #FIXME fix kernel, stride and padding
    op_str = op_str + enum_to_str(OpType, OpType.POOL2D) + ", "
    op_str = op_str + str(3) + ", "
    op_str = op_str + str(1) + ", "
    op_str = op_str + str(0) + ", "
    op_str = op_str + str(enum_to_int(PoolType, pool_type)) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"
    return op_str
Ejemplo n.º 7
0
def parse_pool2d(op_str, node, pool_type):
    assert len(node.inedges) == 1, "wrong number of inputs"
    #FIXME MaxPool2d supports ceil_mode
    op_str = op_str + enum_to_str(OpType, OpType.POOL2D) + ", "
    op_str = op_str + str(node.module.kernel_size) + ", "
    op_str = op_str + str(node.module.stride) + ", "
    op_str = op_str + str(node.module.padding) + ", "
    op_str = op_str + str(enum_to_int(PoolType, pool_type)) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"
    return op_str
Ejemplo n.º 8
0
def parse_reshape(op_str, node):
    assert len(node.inedges) >= 2
    op_str = op_str + enum_to_str(OpType, OpType.RESHAPE) + ", "
    if len(node.inedges) == 2:
        input_shape = node.inedges[1]
    else:
        input_shape = node.inedges[1:]
    for dim in input_shape[:-1]:
        op_str = op_str + (str(dim) if type(dim) is int else
                           (str(dim) + ":")) + ", "
    op_str = op_str + (str(input_shape[-1]) if type(input_shape[-1]) is int
                       else (str(input_shape[-1]) + ":")) + "\n"
    return op_str
Ejemplo n.º 9
0
def parse_conv2d(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.CONV2D) + ", "
    op_str = op_str + str(node.module.out_channels) + ", "
    op_str = op_str + str(node.module.kernel_size[0]) + ", "
    op_str = op_str + str(node.module.kernel_size[1]) + ", "
    op_str = op_str + str(node.module.stride[0]) + ", "
    op_str = op_str + str(node.module.stride[1]) + ", "
    op_str = op_str + str(node.module.padding[1]) + ", "
    op_str = op_str + str(node.module.padding[1]) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
    op_str = op_str + str(node.module.groups) + ", "
    if node.module.bias != None:
        op_str = op_str + "1\n"
    else:
        op_str = op_str + "0\n"
    return op_str
Ejemplo n.º 10
0
def parse_scalarmul(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.SCALAR_MULTIPLY) + ", "
    op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
Ejemplo n.º 11
0
def parse_softmax(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.SOFTMAX) + "\n"
    return op_str
Ejemplo n.º 12
0
def parse_transpose(op_str, node):
    assert len(node.inedges) == 3, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.TRANSPOSE)
    op_str = op_str + ", " + str(node.inedges[1]) + ", " + str(
        node.inedges[2]) + "\n"
    return op_str
Ejemplo n.º 13
0
def parse_split(op_str, node):
    #FIXME may be 3
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.SPLIT) + ", "
    op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
Ejemplo n.º 14
0
def parse_parameter(op_str, parameter):
    op_str = op_str + enum_to_str(OpType, OpType.INIT_PARAM) + ", "
    for dim in parameter.shape[:-1]:
        op_str = op_str + str(dim) + ", "
    op_str = op_str + str(parameter.shape[-1]) + "\n"
    return op_str
Ejemplo n.º 15
0
def parse_identity(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.IDENTITY) + "\n"
    return op_str
Ejemplo n.º 16
0
def parse_dropout(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.DROPOUT) + ", "
    op_str = op_str + str(node.module.p) + "\n"
    return op_str
Ejemplo n.º 17
0
def parse_batchnorm2d(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    # FIXME BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) args are not in FF
    op_str = op_str + enum_to_str(OpType, OpType.BATCH_NORM) + "\n"
    return op_str
Ejemplo n.º 18
0
def parse_scalartruediv(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.SCALAR_TRUEDIV) + ", "
    op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
Ejemplo n.º 19
0
def parse_batchmatmul(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.BATCH_MATMUL) + "\n"
    return op_str
Ejemplo n.º 20
0
def parse_output(op_str, node):
    #FIXME assume there is 1 output
    assert len(node.inedges) == 1, "wrong format"
    op_str = op_str + enum_to_str(OpType, OpType.OUTPUT) + "\n"
    return op_str
Ejemplo n.º 21
0
def parse_layernorm(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.LAYER_NORM) + "\n"
    return op_str
Ejemplo n.º 22
0
def parse_getattr(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.GETATTR) + ", "
    op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
Ejemplo n.º 23
0
def parse_input(op_str, node):
    assert node.inedges == None, "wrong format"
    op_str = op_str + enum_to_str(OpType, OpType.INPUT) + "\n"
    return op_str
Ejemplo n.º 24
0
def parse_sigmoid(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.SIGMOID) + "\n"
    return op_str
Ejemplo n.º 25
0
def parse_add(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.ADD) + "\n"
    return op_str
Ejemplo n.º 26
0
def parse_elu(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + enum_to_str(OpType, OpType.ELU) + "\n"
    return op_str