Example #1
File: fx.py Project: lswzjuer/FlexFlow
def parse_adaptivepool2d(op_str, node, pool_type):
    assert len(node.inedges) == 1, "wrong number of inputs"
    #FIXME kernel size, stride, and padding are hard-coded placeholders
    op_str = op_str + enum_to_str(OpType, OpType.POOL2D) + ", "
    op_str = op_str + str(3) + ", "
    op_str = op_str + str(1) + ", "
    op_str = op_str + str(0) + ", "
    op_str = op_str + str(enum_to_int(PoolType, pool_type)) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"
    return op_str
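The kernel size (3), stride (1), and padding (0) above are hard-coded placeholders, as the FIXME notes. A minimal sketch of driving the parser with a stub node follows; SimpleNamespace stands in for the tracer's real node class, and the exact text produced by enum_to_str/enum_to_int is project-specific, so both are assumptions here.

from types import SimpleNamespace

# Hypothetical stand-in for the tracer's node type; only the field the
# parser reads (inedges) is provided.
node = SimpleNamespace(inedges=[object()])
line = parse_adaptivepool2d("pool5, conv4:, ", node, PoolType.POOL_AVG)
# line resembles: "pool5, conv4:, POOL2D, 3, 1, 0, <POOL_AVG int>, <AC_MODE_NONE int>\n"
print(line)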
Example #2
def parse_linear(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.LINEAR)) + ", "
    op_str = op_str + str(node.module.out_features) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
    if node.module.bias is not None:
        op_str = op_str + "1\n"
    else:
        op_str = op_str + "0\n"
    return op_str
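For a concrete sense of the record format, a hedged sketch exercising parse_linear with a stub node (the SimpleNamespace stub and the enum integers shown in the comment are assumptions, not part of the FlexFlow API):

import torch.nn as nn
from types import SimpleNamespace

# SimpleNamespace is a hypothetical stand-in for the tracer's node class.
node = SimpleNamespace(module=nn.Linear(512, 10), inedges=[object()])
line = parse_linear("fc, flat1:, ", node)
# Appends: "<LINEAR int>, 10, <AC_MODE_NONE int>, 1\n" (trailing 1 = bias present)
print(line)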
Example #3
File: fx.py Project: lswzjuer/FlexFlow
def parse_pool2d(op_str, node, pool_type):
    assert len(node.inedges) == 1, "wrong number of inputs"
    #FIXME MaxPool2d's ceil_mode is not captured here
    op_str = op_str + enum_to_str(OpType, OpType.POOL2D) + ", "
    op_str = op_str + str(node.module.kernel_size) + ", "
    op_str = op_str + str(node.module.stride) + ", "
    op_str = op_str + str(node.module.padding) + ", "
    op_str = op_str + str(enum_to_int(PoolType, pool_type)) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"
    return op_str
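The same pattern works for MaxPool2d; note that kernel_size, stride, and padding stay plain ints here, while tuple-valued arguments would be embedded verbatim by str(), e.g. as "(3, 3)". A sketch under the same stub-node assumption:

import torch.nn as nn
from types import SimpleNamespace

pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
node = SimpleNamespace(module=pool, inedges=[object()])
line = parse_pool2d("pool1, conv1:, ", node, PoolType.POOL_MAX)
# Appends: "POOL2D, 3, 2, 1, <POOL_MAX int>, <AC_MODE_NONE int>\n"
print(line)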
Example #4
def parse_conv2d(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.CONV2D)) + ", "
    op_str = op_str + str(node.module.out_channels) + ", "
    op_str = op_str + str(node.module.kernel_size[0]) + ", "
    op_str = op_str + str(node.module.kernel_size[1]) + ", "
    op_str = op_str + str(node.module.stride[0]) + ", "
    op_str = op_str + str(node.module.stride[1]) + ", "
    op_str = op_str + str(node.module.padding[0]) + ", "
    op_str = op_str + str(node.module.padding[1]) + ", "
    op_str = op_str + str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
    if node.module.bias is not None:
        op_str = op_str + "1\n"
    else:
        op_str = op_str + "0\n"
    return op_str
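A sketch for a typical convolution; nn.Conv2d normalizes scalar kernel_size/stride/padding into pairs, which is what the [0]/[1] indexing above relies on (stub node and enum integers are assumptions, as before):

import torch.nn as nn
from types import SimpleNamespace

conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
node = SimpleNamespace(module=conv, inedges=[object()])
line = parse_conv2d("conv1, input:, ", node)
# Appends: "<CONV2D int>, 64, 7, 7, 2, 2, 3, 3, <AC_MODE_NONE int>, 1\n"
print(line)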
Example #5
File: fx.py Project: msbaines/FlexFlow
def torch_to_flexflow(model, filename):
    graph = __symbolic_trace(model)
    out_file = open(filename, "w")

    for node in graph:
        # op name
        op_str = node.name + ", "

        # op inedges
        inedges = node.inedges[0]
        if not isinstance(inedges, (list, tuple)):
            inedges = [inedges]
        for inedge in inedges:
            if inedge.name == "x":
                op_str = op_str + "input" + ":"
            else:
                op_str = op_str + inedge.name + ":"
        op_str = op_str + ", "

        #op type
        if type(node) == OutputNode:
            #FIXME assume there is 1 output
            assert len(node.inedges) == 1, "wrong format"
            op_str = op_str + str(enum_to_int(OpType, OpType.OUTPUT)) + "\n"

        if type(node) == FunctionNode:
            function_name = str(node.function)
            if function_name.find('add') >= 0:
                assert len(node.inedges) == 2, "wrong number of inputs"
                op_str = op_str + str(enum_to_int(OpType, OpType.ADD)) + "\n"
            elif function_name.find('cat') >= 0:
                #FIXME assume it is a merge
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.CONCAT)) + ", "
                op_str = op_str + str(node.inedges[1]) + "\n"
            elif function_name.find('flatten') >= 0:
                assert len(node.inedges) == 1, "wrong number of inputs"
                op_str = op_str + str(enum_to_int(OpType, OpType.FLAT)) + "\n"
            else:
                # Unrecognized type
                assert False, "Unrecognized built-in function: {}".format(
                    function_name)

        if type(node) == ModuleNode:
            assert len(node.inedges) == 1, "wrong format"

            if type(node.module) == torch.nn.modules.linear.Linear:
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.LINEAR)) + ", "
                op_str = op_str + str(node.module.out_features) + ", "
                op_str = op_str + str(
                    enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
                if node.module.bias is not None:
                    op_str = op_str + "1\n"
                else:
                    op_str = op_str + "0\n"

            elif type(node.module) == torch.nn.modules.conv.Conv2d:
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.CONV2D)) + ", "
                op_str = op_str + str(node.module.out_channels) + ", "
                op_str = op_str + str(node.module.kernel_size[0]) + ", "
                op_str = op_str + str(node.module.kernel_size[1]) + ", "
                op_str = op_str + str(node.module.stride[0]) + ", "
                op_str = op_str + str(node.module.stride[1]) + ", "
                op_str = op_str + str(node.module.padding[0]) + ", "
                op_str = op_str + str(node.module.padding[1]) + ", "
                op_str = op_str + str(
                    enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + ", "
                if node.module.bias is not None:
                    op_str = op_str + "1\n"
                else:
                    op_str = op_str + "0\n"

            elif type(node.module) == torch.nn.modules.pooling.MaxPool2d:
                #FIXME MaxPool2d supports ceil_mode
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.POOL2D)) + ", "
                op_str = op_str + str(node.module.kernel_size) + ", "
                op_str = op_str + str(node.module.stride) + ", "
                op_str = op_str + str(node.module.padding) + ", "
                op_str = op_str + str(enum_to_int(PoolType,
                                                  PoolType.POOL_MAX)) + ", "
                op_str = op_str + str(
                    enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"

            elif type(node.module) == torch.nn.modules.pooling.AvgPool2d:
                #FIXME AvgPool2d supports ceil_mode
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.POOL2D)) + ", "
                op_str = op_str + str(node.module.kernel_size) + ", "
                op_str = op_str + str(node.module.stride) + ", "
                op_str = op_str + str(node.module.padding) + ", "
                op_str = op_str + str(enum_to_int(PoolType,
                                                  PoolType.POOL_AVG)) + ", "
                op_str = op_str + str(
                    enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)) + "\n"

            elif type(node.module) == torch.nn.modules.batchnorm.BatchNorm2d:
                # FIXME BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) args are not in FF
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.BATCH_NORM)) + "\n"

            elif type(node.module) == torch.nn.modules.dropout.Dropout:
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.DROPOUT)) + ", "
                op_str = op_str + str(node.module.p) + "\n"

            elif type(node.module) == torch.nn.modules.flatten.Flatten:
                op_str = op_str + str(enum_to_int(OpType, OpType.FLAT)) + "\n"

            elif type(node.module) == torch.nn.modules.activation.ReLU:
                op_str = op_str + str(enum_to_int(OpType, OpType.RELU)) + "\n"

            elif type(node.module) == torch.nn.modules.activation.Sigmoid:
                op_str = op_str + str(enum_to_int(OpType,
                                                  OpType.SIGMOID)) + "\n"

            elif type(node.module) == torch.nn.modules.activation.Tanh:
                op_str = op_str + str(enum_to_int(OpType, OpType.TANH)) + "\n"

            elif type(node.module) == torch.nn.modules.activation.ELU:
                op_str = op_str + str(enum_to_int(OpType, OpType.ELU)) + "\n"

            else:
                print(node.module)
                assert False, "unknown op"

        print(op_str)
        out_file.write(op_str)

    out_file.close()
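End to end, the exporter can be driven as below. This is a minimal sketch assuming this fx.py is importable and that its tracer supports the layers used; neither is guaranteed by the snippets above.

import torch.nn as nn

# A small model built only from layer types the exporter recognizes.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)
torch_to_flexflow(model, "model.ff")
# model.ff then holds one comma-separated record per traced node, e.g.
#   <name>, <inedge names>:, <OpType int>[, op-specific fields...]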
Example #6
def parse_concat(op_str, node):
    #FIXME assume it is a merge
    op_str = op_str + str(enum_to_int(OpType, OpType.CONCAT)) + ", "
    op_str = op_str + str(node.inedges[1]) + "\n"
    return op_str
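node.inedges[1] here is presumably the captured dim argument of torch.cat rather than a real edge, which is why it is written straight into the record (an inference from Example #5, not something the source states). A sketch with a hypothetical stub node:

from types import SimpleNamespace

# inedges[0]: the list of input edges; inedges[1]: the concat dimension.
node = SimpleNamespace(inedges=[[object(), object()], 1])
line = parse_concat("cat1, branch1:branch2:, ", node)
# Appends: "<CONCAT int>, 1\n"
print(line)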
Example #7
def parse_add(op_str, node):
    assert len(node.inedges) == 2, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.ADD)) + "\n"
    return op_str
Example #8
def parse_output(op_str, node):
    #FIXME assume there is 1 output
    assert len(node.inedges) == 1, "wrong format"
    op_str = op_str + str(enum_to_int(OpType, OpType.OUTPUT)) + "\n"
    return op_str
Example #9
def parse_input(op_str, node):
    assert node.inedges is None, "wrong format"
    op_str = op_str + str(enum_to_int(OpType, OpType.INPUT)) + "\n"
    return op_str
Example #10
def parse_softmax(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.SOFTMAX)) + "\n"
    return op_str
Example #11
def parse_tanh(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.TANH)) + "\n"
    return op_str
Example #12
def parse_sigmoid(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.SIGMOID)) + "\n"
    return op_str
Example #13
def parse_dropout(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.DROPOUT)) + ", "
    op_str = op_str + str(node.module.p) + "\n"
    return op_str
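A sketch showing how the dropout probability travels into the record (stub node assumed, enum integer project-specific):

import torch.nn as nn
from types import SimpleNamespace

node = SimpleNamespace(module=nn.Dropout(p=0.5), inedges=[object()])
line = parse_dropout("drop1, fc1:, ", node)
# Appends: "<DROPOUT int>, 0.5\n"
print(line)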
Example #14
def parse_batchnorm2d(op_str, node):
    assert len(node.inedges) == 1, "wrong number of inputs"
    # FIXME BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) args are not in FF
    op_str = op_str + str(enum_to_int(OpType, OpType.BATCH_NORM)) + "\n"
    return op_str
Example #15
def parse_flat(op_str, node):
    #assert len(node.inedges) == 1, "wrong number of inputs"
    op_str = op_str + str(enum_to_int(OpType, OpType.FLAT)) + "\n"
    return op_str
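Examples #6 through #15 factor the branches of torch_to_flexflow into standalone helpers. A hedged sketch of a dispatch table that could tie them together follows; the table and parse_module_node are illustrations of how the helpers compose, not code from either FlexFlow project:

import torch.nn as nn

# Hypothetical dispatch from module type to parser.
MODULE_PARSERS = {
    nn.Linear: parse_linear,
    nn.Conv2d: parse_conv2d,
    nn.BatchNorm2d: parse_batchnorm2d,
    nn.Dropout: parse_dropout,
    nn.Flatten: parse_flat,
    nn.Sigmoid: parse_sigmoid,
    nn.Tanh: parse_tanh,
    nn.Softmax: parse_softmax,
}

def parse_module_node(op_str, node):
    parser = MODULE_PARSERS.get(type(node.module))
    assert parser is not None, "unknown op: {}".format(type(node.module))
    return parser(op_str, node)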