Пример #1
0
    def SoftmaxCrossEntropyWithLogits(graph: Graph.Graph, curNode: Graph.Node,
                                      dictNodeNameToOutVarStr: dict,
                                      extraNodeInfoDict: dict):
        """Build the AST for a softmax cross-entropy node.

        The node must have exactly two inputs: the first is treated as the
        logits and the second as the one-hot true distribution q(x).
        Softmax of the logits gives p(x), and the emitted expression is
        cross-entropy = -sum_x(q(x) * log(p(x))).

        The result is two chained ``AST.Let`` bindings:
        ``temp_softmax = softmax(logits)`` followed by
        ``temp_1 = -(reduce_+(labels .* log(temp_softmax)))`` whose body is
        ``temp_1``.

        Returns:
            A ``(None, ast)`` tuple where ``ast`` is the outer ``AST.Let``.
        """
        inputNames = curNode.getInputsRef()
        assert (len(inputNames) == 2)
        logitsId = AST.ID(dictNodeNameToOutVarStr[inputNames[0]])
        labelsId = AST.ID(dictNodeNameToOutVarStr[inputNames[1]])

        # TODO : softmax or implement here ?
        # Outer binding: temp_softmax = softmax(logits). Its body is filled
        # in below once the cross-entropy binding has been built.
        softmaxVar = AST.ID('temp_softmax')
        softmaxCall = AST.Func(TFNodesAST.getOperatorsIdx('softmax'),
                               logitsId)
        retAST = AST.Let(softmaxVar, softmaxCall, None)

        # Inner binding: temp_1 = -(sum(labels .* log(temp_softmax))).
        lossVar = AST.ID('temp_1')
        negateOp = TFNodesAST.getOperatorsIdx('-')
        elemMulOp = TFNodesAST.getOperatorsIdx('.*')
        logOfSoftmax = AST.Func(TFNodesAST.getOperatorsIdx('log'),
                                AST.ID('temp_softmax'))
        weightedLog = AST.BOp(labelsId, elemMulOp, logOfSoftmax)
        # Reduce along column to get a row-vector.
        # NOTE(review): the dimension is passed as a bare Python int 1 --
        # confirm AST.Reduce accepts a raw int rather than an AST node.
        summed = AST.Reduce(weightedLog, 1, TFNodesAST.getOperatorsIdx('+'))
        negatedSum = AST.UOp(negateOp, summed)
        retAST.expr = AST.Let(lossVar, negatedSum, AST.ID('temp_1'))

        return (None, retAST)
Пример #2
0
 def LogSoftmax(graph: Graph.Graph, curNode: Graph.Node,
                dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Build the AST for a log-softmax node.

     The node must have exactly one input x. The emitted expression is
     ``x + (-(log(reduce_+(exp(x)))))``, with the reduction dimension
     given as ``AST.Int(-1)`` (presumably the last axis -- confirm against
     AST.Reduce).

     Returns:
         A ``(None, ast)`` tuple where ``ast`` is the top-level ``AST.BOp``.
     """
     inputNames = curNode.getInputsRef()
     assert (len(inputNames) == 1)

     # Denominator term: sum of exp(x) along the reduction dimension.
     expOp = TFNodesAST.getOperatorsIdx('exp')
     srcVarName = dictNodeNameToOutVarStr[inputNames[0]]
     expOfInput = AST.Func(expOp, AST.ID(srcVarName))
     sumOfExp = AST.Reduce(expOfInput, AST.Int(-1),
                           TFNodesAST.getOperatorsIdx('+'))

     # Subtraction is expressed as addition of the negated log-sum.
     inputId = AST.ID(srcVarName)
     addOp = TFNodesAST.getOperatorsIdx('+')
     negateOp = TFNodesAST.getOperatorsIdx('-')
     logOfSum = AST.Func(TFNodesAST.getOperatorsIdx('log'), sumOfExp)
     negLogOfSum = AST.UOp(negateOp, logOfSum)
     return (None, AST.BOp(inputId, addOp, negLogOfSum))
Пример #3
0
 def Shape(graph: Graph.Graph, curNode: Graph.Node,
           dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Build the AST for a shape node.

     The node must have exactly one input; the result applies the 'shape'
     operator to the variable bound to that input.

     Returns:
         A ``(None, ast)`` tuple where ``ast`` is the ``AST.Func`` call.
     """
     inputNames = curNode.getInputsRef()
     assert (len(inputNames) == 1)

     shapeOp = TFNodesAST.getOperatorsIdx('shape')
     operand = AST.ID(dictNodeNameToOutVarStr[inputNames[0]])
     return (None, AST.Func(shapeOp, operand))
Пример #4
0
 def FloorDiv(graph: Graph.Graph, curNode: Graph.Node,
              dictNodeNameToOutVarStr: dict, extraNodeInfoDict: dict):
     """Build the AST for a floor-division node.

     The node must have exactly two inputs a and b; the emitted expression
     is ``floor(a ./ b)``.

     Returns:
         A ``(None, ast)`` tuple where ``ast`` is the 'floor' ``AST.Func``.
     """
     inputNames = curNode.getInputsRef()
     assert (len(inputNames) == 2)

     # Element-wise real division first, then floor the quotient.
     dividend = AST.ID(dictNodeNameToOutVarStr[inputNames[0]])
     divideOp = TFNodesAST.getOperatorsIdx('./')
     divisor = AST.ID(dictNodeNameToOutVarStr[inputNames[1]])
     quotient = AST.BOp(dividend, divideOp, divisor)

     floorOp = TFNodesAST.getOperatorsIdx('floor')
     return (None, AST.Func(floorOp, quotient))
Пример #5
0
    def ArgMax(node, value_info, node_name_to_out_var_dict,
               innermost_let_ast_node, out_var_count, mtdAST):
        """Append a SeeDot ARGMAX call for an ONNX node to the program.

        Wraps ``node`` in ``OnnxNode``, applies ``SeeDotParser.ARGMAX`` to
        the variable bound to the node's first input, binds the result to a
        freshly generated variable name, and records that name as the
        output variable for ``node.outputs[0]`` in
        ``node_name_to_out_var_dict``.

        Returns:
            A ``(innermost_let_ast_node, out_var_count)`` tuple with the
            updated innermost node and the counter advanced by one.
        """
        node = OnnxNode(node)
        if DEBUG:
            print(node)

        input_names = node.inputs
        argmax_ast = AST.Func(
            SeeDotParser.ARGMAX,
            AST.ID(node_name_to_out_var_dict[input_names[0]]))

        # Bind the ARGMAX result to a new variable in the program.
        result_var = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, argmax_ast, result_var, mtdAST)

        node_name_to_out_var_dict[node.outputs[0]] = result_var

        # One new variable was consumed, so hand back the next free index.
        return (innermost_let_ast_node, out_var_count + 1)
Пример #6
0
    def Relu(node, value_info, node_name_to_out_var_dict,
             innermost_let_ast_node, out_var_count, mtdAST):
        """Append a SeeDot RELU call for an ONNX node to the program.

        Wraps ``node`` in ``OnnxNode`` and requires exactly one input. When
        the input's ``value_info`` entry has four or more dimensions, the
        input is first passed through ``get_reshaped_input_ast`` and the
        RELU result is passed through ``get_reshaped_output_ast``; each
        step is bound to a freshly generated variable. The final variable
        name is recorded for ``node.outputs[0]`` in
        ``node_name_to_out_var_dict``.

        Returns:
            A ``(innermost_let_ast_node, out_var_count)`` tuple with the
            updated innermost node and the counter advanced by one for
            every variable created (one or three).
        """
        node = OnnxNode(node)

        inputsRef = node.inputs
        assert (len(inputsRef) == 1)

        # Element [1] of the value_info entry is used as the input's shape
        # (presumably a sequence of dimensions -- confirm against the code
        # that builds value_info). Only its length drives the logic here.
        onnx_input_shape = value_info[inputsRef[0]][1]
        needs_reshape = len(onnx_input_shape) >= 4

        relu_input_name = node_name_to_out_var_dict[inputsRef[0]]
        if needs_reshape:
            # Bind the reshaped input to a new variable and feed that to
            # RELU instead of the original input variable.
            relu_input_name = get_new_var_name(out_var_count)
            reshaped_input = get_reshaped_input_ast(inputsRef[0], value_info,
                                                    node_name_to_out_var_dict)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, reshaped_input, relu_input_name,
                mtdAST)
            out_var_count += 1

        seedot_output_ast = AST.Func(SeeDotParser.RELU,
                                     AST.ID(relu_input_name))
        output_name = get_new_var_name(out_var_count)
        innermost_let_ast_node = update_program_with_new_node(
            innermost_let_ast_node, seedot_output_ast, output_name, mtdAST)
        out_var_count += 1

        final_output_name = output_name
        if needs_reshape:
            # Reshape the RELU result for the ONNX output and bind it to
            # its own variable.
            final_output_name = get_new_var_name(out_var_count)
            onnx_output_ast = get_reshaped_output_ast(node.outputs[0],
                                                      value_info, output_name)
            innermost_let_ast_node = update_program_with_new_node(
                innermost_let_ast_node, onnx_output_ast, final_output_name,
                mtdAST)
            out_var_count += 1

        node_name_to_out_var_dict[node.outputs[0]] = final_output_name

        if DEBUG:
            # The previous debug print referenced seedot_input_shape and
            # onnx_output_shape, which are never assigned in this function
            # and raised NameError whenever DEBUG was enabled. Print only
            # values this function actually holds: the input shape and the
            # chain of variable names the data flows through.
            print(node.outputs[0])
            print(onnx_input_shape, ':', relu_input_name, '->', output_name,
                  '->', final_output_name)

        return (innermost_let_ast_node, out_var_count)
Пример #7
0
 def visitFunc(self, ctx: SeeDotParser.FuncContext):
     """Convert a parsed special-function application into an ``AST.Func``.

     The operator is the token type of the first child of the
     ``specialFunc`` sub-context; the argument is whatever visiting the
     ``expr`` sub-context returns.
     """
     specialFuncCtx = ctx.specialFunc()
     opTokenType = specialFuncCtx.getChild(0).symbol.type
     argumentAST = self.visit(ctx.expr())
     return AST.Func(opTokenType, argumentAST)