def parse_adaptivepool2d(op_str, node, pool_type):
    """Append POOL2D fields for an adaptive 2-D pooling node.

    Args:
        op_str: line prefix built so far (node name and input edges).
        node: traced node; must have exactly one input edge.
        pool_type: ``PoolType`` member selecting max/avg pooling.

    Returns:
        ``op_str`` followed by ``<op type>, 3, 1, 0, <pool type>, <acti>`` and
        a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    # FIXME fix kernel, stride and padding: these are hard-coded and
    # node.module is not consulted at all.
    kernel, stride, padding = 3, 1, 0
    # NOTE(review): the op type is written with enum_to_str here, whereas the
    # other emitters in this file (including the POOL2D branch of
    # torch_to_flexflow) write str(enum_to_int(...)). Confirm which form the
    # reader of this file expects.
    fields = [
        enum_to_str(OpType, OpType.POOL2D),
        str(kernel),
        str(stride),
        str(padding),
        str(enum_to_int(PoolType, pool_type)),
        str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)),
    ]
    return op_str + ", ".join(fields) + "\n"
def parse_linear(op_str, node):
    """Append the LINEAR op fields for a fully-connected module node.

    Fields written after ``op_str``: op type, ``out_features``, activation
    mode (always ``AC_MODE_NONE``), and ``1``/``0`` for whether the module
    has a bias.

    Args:
        op_str: line prefix built so far (node name and input edges).
        node: traced node whose ``module`` is the linear layer; must have
            exactly one input edge.

    Returns:
        ``op_str`` with the linear fields and a trailing newline appended.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    module = node.module
    # Identity check against None (PEP 8). The bias is presumably a tensor
    # parameter when present, and `tensor != None` only works through an
    # operator fallback rather than a direct None test.
    bias_flag = "1" if module.bias is not None else "0"
    fields = [
        str(enum_to_int(OpType, OpType.LINEAR)),
        str(module.out_features),
        str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)),
        bias_flag,
    ]
    return op_str + ", ".join(fields) + "\n"
def parse_pool2d(op_str, node, pool_type):
    """Append POOL2D fields for a 2-D pooling module node.

    Args:
        op_str: line prefix built so far (node name and input edges).
        node: traced node whose ``module`` provides ``kernel_size``,
            ``stride`` and ``padding``; must have exactly one input edge.
        pool_type: ``PoolType`` member selecting max/avg pooling.

    Returns:
        ``op_str`` followed by ``<op type>, <kernel>, <stride>, <padding>,
        <pool type>, <acti>`` and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    # FIXME MaxPool2d supports ceil_mode, which is not emitted.
    # NOTE(review): the op type is written with enum_to_str, unlike the
    # str(enum_to_int(...)) form used by the other emitters in this file.
    # Confirm which form the reader expects.
    # NOTE(review): if kernel_size/stride/padding are tuples, str() yields
    # e.g. "(2, 2)", which contains commas; confirm the reader copes.
    pool = node.module
    fields = [
        enum_to_str(OpType, OpType.POOL2D),
        str(pool.kernel_size),
        str(pool.stride),
        str(pool.padding),
        str(enum_to_int(PoolType, pool_type)),
        str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)),
    ]
    return op_str + ", ".join(fields) + "\n"
def parse_conv2d(op_str, node):
    """Append the CONV2D op fields for a 2-D convolution module node.

    Fields written after ``op_str``: op type, ``out_channels``,
    ``kernel_size[0]``, ``kernel_size[1]``, ``stride[0]``, ``stride[1]``,
    ``padding[0]``, ``padding[1]``, activation mode (always
    ``AC_MODE_NONE``), and ``1``/``0`` for whether the module has a bias.

    Args:
        op_str: line prefix built so far (node name and input edges).
        node: traced node whose ``module`` is the convolution; must have
            exactly one input edge.

    Returns:
        ``op_str`` with the conv fields and a trailing newline appended.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    conv = node.module
    # NOTE(review): newer PyTorch also accepts string padding ("same" /
    # "valid"); indexing it would yield characters. Confirm that callers use
    # numeric padding.
    fields = [
        str(enum_to_int(OpType, OpType.CONV2D)),
        str(conv.out_channels),
        str(conv.kernel_size[0]),
        str(conv.kernel_size[1]),
        str(conv.stride[0]),
        str(conv.stride[1]),
        # Emit both padding components in the same [0], [1] order as
        # kernel_size and stride (previously padding[1] was written twice).
        str(conv.padding[0]),
        str(conv.padding[1]),
        str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE)),
        "1" if conv.bias is not None else "0",
    ]
    return op_str + ", ".join(fields) + "\n"
def _fx_inedges_field(node):
    """Return ``node``'s input-edge names, each followed by ':' ("x" is written as "input")."""
    # NOTE(review): the edge list is taken from node.inedges[0], while the
    # arity checks elsewhere use len(node.inedges); confirm this layout.
    inedges = node.inedges[0]
    # A single edge may be stored bare rather than inside a list/tuple.
    if type(inedges) not in (list, tuple):
        inedges = [inedges]
    return "".join(
        ("input" if edge.name == "x" else edge.name) + ":" for edge in inedges
    )


def _fx_function_op_fields(node):
    """Return the op fields (newline-terminated) for a built-in FunctionNode."""
    function_name = str(node.function)
    # NOTE(review): matching is by substring of str(node.function), so any
    # function whose name contains "add"/"cat"/"flatten" is mapped here.
    if "add" in function_name:
        assert len(node.inedges) == 2, "wrong number of inputs"
        return str(enum_to_int(OpType, OpType.ADD)) + "\n"
    if "cat" in function_name:
        # FIXME assume it is a merge
        return (str(enum_to_int(OpType, OpType.CONCAT)) + ", "
                + str(node.inedges[1]) + "\n")
    if "flatten" in function_name:
        assert len(node.inedges) == 1, "wrong number of inputs"
        return str(enum_to_int(OpType, OpType.FLAT)) + "\n"
    # Raise explicitly (same exception type as before) so that this cannot
    # fall through and write a malformed line when asserts are disabled.
    raise AssertionError(
        "Unrecogonized built-in function: {}".format(function_name))


def _fx_module_op_fields(node):
    """Return the op fields (newline-terminated) for a ModuleNode."""
    assert len(node.inedges) == 1, "wrong format"
    module = node.module
    module_type = type(module)
    no_acti = lambda: str(enum_to_int(ActiMode, ActiMode.AC_MODE_NONE))

    if module_type == torch.nn.modules.linear.Linear:
        return ", ".join([
            str(enum_to_int(OpType, OpType.LINEAR)),
            str(module.out_features),
            no_acti(),
            "1" if module.bias is not None else "0",
        ]) + "\n"

    if module_type == torch.nn.modules.conv.Conv2d:
        return ", ".join([
            str(enum_to_int(OpType, OpType.CONV2D)),
            str(module.out_channels),
            str(module.kernel_size[0]),
            str(module.kernel_size[1]),
            str(module.stride[0]),
            str(module.stride[1]),
            # Both padding components, in the same order as kernel/stride
            # (previously padding[1] was written twice).
            str(module.padding[0]),
            str(module.padding[1]),
            no_acti(),
            "1" if module.bias is not None else "0",
        ]) + "\n"

    pooling = torch.nn.modules.pooling
    if module_type == pooling.MaxPool2d or module_type == pooling.AvgPool2d:
        # FIXME MaxPool2d / AvgPool2d support ceil_mode, which is not emitted.
        if module_type == pooling.MaxPool2d:
            pool_type = PoolType.POOL_MAX
        else:
            pool_type = PoolType.POOL_AVG
        return ", ".join([
            str(enum_to_int(OpType, OpType.POOL2D)),
            str(module.kernel_size),
            str(module.stride),
            str(module.padding),
            str(enum_to_int(PoolType, pool_type)),
            no_acti(),
        ]) + "\n"

    if module_type == torch.nn.modules.dropout.Dropout:
        return (str(enum_to_int(OpType, OpType.DROPOUT)) + ", "
                + str(module.p) + "\n")

    # Modules whose line carries only the op type.
    # FIXME BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True,
    # track_running_stats=True) args are not in FF
    op_only = {
        torch.nn.modules.batchnorm.BatchNorm2d: OpType.BATCH_NORM,
        torch.nn.modules.flatten.Flatten: OpType.FLAT,
        torch.nn.modules.activation.ReLU: OpType.RELU,
        torch.nn.modules.activation.Sigmoid: OpType.SIGMOID,
        torch.nn.modules.activation.Tanh: OpType.TANH,
        torch.nn.modules.activation.ELU: OpType.ELU,
    }
    op_type = op_only.get(module_type)
    if op_type is not None:
        return str(enum_to_int(OpType, op_type)) + "\n"

    print(node.module)
    raise AssertionError("unknown op")


def torch_to_flexflow(model, filename):
    """Trace ``model`` and write one text line per traced node to ``filename``.

    Each line is ``<name>, <in_1>:<in_2>:..., <op fields>``. The op fields
    start with ``str(enum_to_int(OpType, ...))``, followed by op-specific
    parameters. An input edge named ``"x"`` is written as ``"input"``. Every
    line is also echoed to stdout.

    Args:
        model: the PyTorch module to trace with ``__symbolic_trace``.
        filename: path of the output text file (overwritten).

    Raises:
        AssertionError: on an unexpected input count, an unrecognized
            built-in function, or an unsupported module type.
    """
    graph = __symbolic_trace(model)
    # `with` closes the file even if an unsupported node aborts the export.
    with open(filename, "w") as out_file:
        for node in graph:
            op_str = node.name + ", " + _fx_inedges_field(node) + ", "
            node_type = type(node)
            if node_type == OutputNode:
                # FIXME assume there is 1 output
                assert len(node.inedges) == 1, "wrong format"
                op_str += str(enum_to_int(OpType, OpType.OUTPUT)) + "\n"
            elif node_type == FunctionNode:
                op_str += _fx_function_op_fields(node)
            elif node_type == ModuleNode:
                op_str += _fx_module_op_fields(node)
            print(op_str)
            out_file.write(op_str)
def parse_concat(op_str, node):
    """Append CONCAT fields: the op type followed by ``node.inedges[1]``.

    No input-count check is performed.

    Returns:
        ``op_str`` followed by ``<op type>, <inedges[1]>`` and a newline.
    """
    # FIXME: assumes this concat is a merge.
    parts = [str(enum_to_int(OpType, OpType.CONCAT)), str(node.inedges[1])]
    return "".join([op_str, ", ".join(parts), "\n"])
def parse_add(op_str, node):
    """Append the ADD op-type field for an element-wise addition node.

    The node must have exactly two input edges.

    Returns:
        ``op_str`` followed by the ADD op code and a newline.
    """
    assert len(node.inedges) == 2, "wrong number of inputs"
    add_code = str(enum_to_int(OpType, OpType.ADD))
    return "".join([op_str, add_code, "\n"])
def parse_output(op_str, node):
    """Append the OUTPUT op-type field for the graph's output node.

    Returns:
        ``op_str`` followed by the OUTPUT op code and a newline.
    """
    # FIXME: assumes there is exactly one output.
    assert len(node.inedges) == 1, "wrong format"
    output_code = str(enum_to_int(OpType, OpType.OUTPUT))
    return "".join((op_str, output_code, "\n"))
def parse_input(op_str, node):
    """Append the INPUT op-type field for a graph input node.

    An input node must have no incoming edges (``node.inedges is None``).

    Returns:
        ``op_str`` followed by the INPUT op code and a newline.
    """
    # Identity comparison with None, per PEP 8.
    assert node.inedges is None, "wrong format"
    return op_str + str(enum_to_int(OpType, OpType.INPUT)) + "\n"
def parse_softmax(op_str, node):
    """Append the SOFTMAX op-type field; the node must have a single input.

    Returns:
        ``op_str`` followed by the SOFTMAX op code and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    softmax_code = str(enum_to_int(OpType, OpType.SOFTMAX))
    line_tail = softmax_code + "\n"
    return op_str + line_tail
def parse_tanh(op_str, node):
    """Append the TANH op-type field; the node must have a single input.

    Returns:
        ``op_str`` followed by the TANH op code and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    tanh_code = str(enum_to_int(OpType, OpType.TANH))
    return "".join([op_str, tanh_code, "\n"])
def parse_sigmoid(op_str, node):
    """Append the SIGMOID op-type field; the node must have a single input.

    Returns:
        ``op_str`` followed by the SIGMOID op code and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    sigmoid_code = str(enum_to_int(OpType, OpType.SIGMOID))
    line_tail = sigmoid_code + "\n"
    return op_str + line_tail
def parse_dropout(op_str, node):
    """Append DROPOUT fields: the op type followed by ``node.module.p``.

    The node must have a single input.

    Returns:
        ``op_str`` followed by ``<op type>, <p>`` and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    parts = (str(enum_to_int(OpType, OpType.DROPOUT)), str(node.module.p))
    return op_str + ", ".join(parts) + "\n"
def parse_batchnorm2d(op_str, node):
    """Append the BATCH_NORM op-type field; the node must have a single input.

    Only the op type is written; none of the module's configuration is
    emitted.

    Returns:
        ``op_str`` followed by the BATCH_NORM op code and a newline.
    """
    assert len(node.inedges) == 1, "wrong number of inputs"
    # FIXME BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True,
    # track_running_stats=True) args are not in FF
    bn_code = str(enum_to_int(OpType, OpType.BATCH_NORM))
    return "".join((op_str, bn_code, "\n"))
def parse_flat(op_str, node):
    """Append the FLAT op-type field for a flatten node.

    Returns:
        ``op_str`` followed by the FLAT op code and a newline.
    """
    # NOTE: unlike its sibling parsers, this one performs no input-count
    # check; a single-input assertion existed here but was commented out.
    flat_code = str(enum_to_int(OpType, OpType.FLAT))
    line_tail = flat_code + "\n"
    return op_str + line_tail