Exemplo n.º 1
0
def test_comparatorListener():
    """Parse a simple ``exist(...)`` rule and walk its tree with the listener.

    Smoke test: building the lexer -> token stream -> parser pipeline and
    walking the resulting ``expression`` tree with ``comparatorListener``
    must not raise.
    """
    from antlr4 import InputStream, ParseTreeWalker
    from antlr4 import CommonTokenStream
    from processor.comparison.comparisonantlr.comparatorLexer import comparatorLexer
    from processor.comparison.comparisonantlr.comparatorParser import comparatorParser
    from processor.comparison.comparisonantlr.comparatorListener import comparatorListener
    input_stream = InputStream('exist({1}.location)')
    lexer = comparatorLexer(input_stream)
    stream = CommonTokenStream(lexer)
    parser = comparatorParser(stream)
    tree = parser.expression()
    assert tree is not None
    # ParseTreeWalker.walk only drives the listener callbacks and returns
    # None, so there is nothing to print; the check is that it does not raise.
    walker = ParseTreeWalker()
    walker.walk(comparatorListener(), tree)
Exemplo n.º 2
0
def main_comparator(code, otherdata):
    """Parse *code* with the comparator grammar and print the comparison result.

    The top-level children of the parsed ``expression`` rule are converted to
    their text and handed to ``RuleInterpreter`` together with the keyword
    arguments in *otherdata*; whatever ``RuleInterpreter.compare()`` returns
    is printed.
    """
    logger.info('#' * 75)
    logger.info('Actual Rule: %s', code)
    # InputStream -> lexer -> token stream -> parser, built inside-out.
    parser = comparatorParser(CommonTokenStream(comparatorLexer(InputStream(code))))
    tree = parser.expression()
    print('#' * 50)
    print(tree.toStringTree(recog=parser))
    tokens = [node.getText() for node in tree.getChildren()]
    logger.info('*' * 50)
    logger.debug("All the parsed tokens: %s", tokens)
    interpreter = RuleInterpreter(tokens, **otherdata)
    print(interpreter.compare())
Exemplo n.º 3
0
def main(argv):
    """Parse (and optionally evaluate) every rule in the file ``argv[1]``.

    For each non-blank line, the rule text, its parse tree and its top-level
    tokens are printed. When ``argv`` has exactly two entries, the rule is
    also evaluated by ``RuleInterpreter`` with ``dbname='validator'``,
    ``container='container1'`` and empty ``snapshots``. The LHS, RHS and a
    PASS/FAIL verdict are then printed.

    Returns True once the whole file has been processed. Returns False after
    printing the first exception raised, including a missing ``argv[1]`` or
    an unreadable file.
    """
    import json
    from antlr4 import InputStream
    from antlr4 import CommonTokenStream
    from processor.comparison.comparisonantlr.comparatorLexer import comparatorLexer
    from processor.comparison.comparisonantlr.comparatorParser import comparatorParser
    try:
        with open(argv[1]) as f:
            for line in f:
                code = line.rstrip()
                if not code:
                    # Blank lines (separators, trailing newline) hold no rule.
                    continue
                print('#' * 75)
                print('Actual Rule: ', code)
                inputStream = InputStream(code)
                lexer = comparatorLexer(inputStream)
                stream = CommonTokenStream(lexer)
                parser = comparatorParser(stream)
                tree = parser.expression()
                print(tree.toStringTree(recog=parser))
                children = [child.getText() for child in tree.getChildren()]
                print("All the parsed tokens: ", children)
                if len(argv) == 2:
                    # Fresh context per rule so nothing leaks between rules.
                    otherdata = {
                        'dbname': 'validator',
                        'container': 'container1',
                        'snapshots': {}
                    }
                    r_i = RuleInterpreter(children, **otherdata)
                    lval = r_i.get_value(r_i.lhs_operand)
                    # Only a missing value prints as "None". Falsy values such
                    # as 0, False, "" or [] are real results and are dumped.
                    print("LHS: {}".format(
                        json.dumps(lval) if lval is not None else "None"))
                    rval = r_i.get_value(r_i.rhs_operand)
                    print("RHS: {}".format(rval))
                    result = r_i.compare()
                    print("RESULT: {}".format("PASS" if result else "FAIL"))
                print('*' * 50)
        return True
    except Exception as ex:
        print("Exception: %s" % ex)
        return False
Exemplo n.º 4
0
def main(argv):
    """Parse every rule in the file ``argv[1]`` and print its tokens.

    Each non-blank line is parsed with the comparator grammar. The parse tree
    and the text of its top-level children are printed, and the tokens are
    passed to ``RuleInterpreter``.

    Returns True once the whole file has been processed. Returns False after
    printing the first exception raised, including a missing ``argv[1]``, an
    unreadable file, or an error from the parser or interpreter.
    """
    try:
        with open(argv[1]) as f:
            for line in f:
                code = line.rstrip()
                if not code:
                    # Blank lines (separators, trailing newline) hold no rule.
                    continue
                print('#' * 75)
                print('Actual Rule: ', code)
                parser = comparatorParser(
                    CommonTokenStream(comparatorLexer(InputStream(code))))
                tree = parser.expression()
                print(tree.toStringTree(recog=parser))
                children = [child.getText() for child in tree.getChildren()]
                print('*' * 50)
                print("All the parsed tokens: ", children)
                # The instance itself is not used. Any exception raised while
                # constructing it makes main() return False.
                RuleInterpreter(children)
        return True
    except Exception as ex:
        # Catch Exception rather than using a bare except so that
        # KeyboardInterrupt/SystemExit still propagate, and report the error
        # instead of hiding it.
        print("Exception: %s" % ex)
        return False
 def validate(self):
     """Run this testcase and return a single result dict.

     TESTCASEV1 testcases load the most recent document (sorted by
     ``timestamp``, descending) for ``self.snapshot_id`` from
     ``self.collection``. ``OPERATORS[self.op]`` is then applied to the
     document's ``json`` payload.

     TESTCASEV2 testcases of type ``rego`` are delegated to
     ``process_rego_test_case()``. Any other TESTCASEV2 rule is parsed with
     the comparator grammar and evaluated by ``RuleInterpreter``. Any other
     format is skipped.

     Returns a dict whose ``"result"`` is ``"passed"``, ``"failed"`` or
     ``"skipped"``. Skipped results give the cause under ``"reason"``. When
     a rule was evaluated, ``"snapshots"`` lists the snapshot(s) involved.
     A TESTCASEV1 operator missing from ``OPERATORS`` leaves the result as
     ``"failed"`` with no further detail.
     """
     result_val = {"result": "failed"}
     if self.format == TESTCASEV1:
         if self.snapshot_id:
             docs = get_documents(self.collection,
                                  dbname=self.dbname,
                                  sort=[('timestamp', pymongo.DESCENDING)],
                                  query={'snapshotId': self.snapshot_id},
                                  limit=1)
             # The falsy check below implies docs may be None; don't call
             # len() on it unguarded.
             logger.info('Number of Snapshot Documents: %s',
                         len(docs) if docs else 0)
             if docs and len(docs):
                 doc = docs[0]
                 self.data = doc['json']
                 if self.op in OPERATORS and OPERATORS[self.op]:
                     result = OPERATORS[self.op](self.data, self.loperand,
                                                 self.roperand, self.is_not,
                                                 self.extras)
                     result_val["result"] = "passed" if result else "failed"
                     result_val["snapshots"] = [{
                         'id': doc['snapshotId'],
                         'path': doc['path'],
                         'structure': doc['structure'],
                         'reference': doc['reference'],
                         'source': doc['source']
                     }]
             else:
                 result_val.update({
                     "result": "skipped",
                     "reason": "Missing documents for the snapshot"
                 })
         else:
             result_val.update({
                 "result": "skipped",
                 "reason": "Missing snapshotId for testcase"
             })
     elif self.format == TESTCASEV2:
         if self.type == 'rego':
             result = self.process_rego_test_case()
             result_val["result"] = "passed" if result else "failed"
             result_val['snapshots'] = self.snapshots
         else:
             logger.info('#' * 75)
             logger.info('Actual Rule: %s', self.rule)
             input_stream = InputStream(self.rule)
             lexer = comparatorLexer(input_stream)
             stream = CommonTokenStream(lexer)
             parser = comparatorParser(stream)
             tree = parser.expression()
             children = [child.getText() for child in tree.getChildren()]
             logger.info('*' * 50)
             logger.debug("All the parsed tokens: %s", children)
             otherdata = {
                 'dbname': self.dbname,
                 'snapshots': self.collection_data,
                 'container': self.container
             }
             r_i = RuleInterpreter(children, **otherdata)
             result = r_i.compare()
             result_val["result"] = "passed" if result else "failed"
             result_val['snapshots'] = r_i.get_snapshots()
     else:
         result_val.update({
             "result": "skipped",
             "reason": "Unsupported testcase format"
         })
     return result_val
Exemplo n.º 6
0
def test_comparatorLexer():
    """The generated comparator lexer can be created without an input stream."""
    from processor.comparison.comparisonantlr.comparatorLexer import comparatorLexer
    assert comparatorLexer() is not None
Exemplo n.º 7
0
def test_comparatorParser():
    """Every sample rule parses into tokens that RuleInterpreter accepts.

    For each rule, the top-level children of the ``expression`` parse tree
    must be non-empty. Constructing a ``RuleInterpreter`` from their text
    must succeed.
    """
    from antlr4 import InputStream
    from antlr4 import CommonTokenStream
    from processor.comparison.comparisonantlr.comparatorLexer import comparatorLexer
    from processor.comparison.comparisonantlr.comparatorParser import comparatorParser
    from processor.comparison.interpreter import RuleInterpreter
    # NOTE(review): a few entries repeat, and 'FAlSE' is mixed-case. Both look
    # deliberate or harmless for a smoke test; confirm before pruning.
    rules = [
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) = 13',
        'count({1}.firewall.rules[]) + count({2}.firewall.rules[]) = 13',
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) > 13',
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) < 13',
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) >= 13',
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) <= 13',
        'count({1}.firewall.rules[] + {2}.firewall.rules[]) != 13',
        'count({1}.firewall.rules[]) = count({2}.firewall.rules[])',
        "{2}.properties.cost=2.34",
        "{2}.properties.addressSpace={'addressPrefixes': ['172.18.116.0/23']}",
        "{1}.[0].name=abcd",
        "{1}.['name' = 'abcd'].location = 'eastus2'",
        '{1}.dns.ip = 1.2.4.5',
        '{1}.dns.ip = 1.2.4.5/32',
        '{1}.location = [1,2,4]',
        "{2}.properties.dhcpOptions.dnsServers[]+{3}.properties.dhcpOptions.dnsServers[]=['172.18.96.214', '172.18.96.216', '172.18.96.214', '172.18.96.216']",
        'count(count(count({1}.location.name[0]))+count(count({2}.location.name[0])))= 13',
        "{1}.firewall.rules['name' = 'rule1'].port = {2}.firewall.rules['name' = 'rule1'].port",
        'count({1}.firewall.rules[]) = count({2}.firewall.rules[])',
        'count(count({1}.firewall.rules[]) + count({1}.firewall.rules[])) = 13',
        'exist({1}.location)',
        'exist({1}.location) = TRUE',
        'exist({1}.location) = true',
        'exist({1}.location) = FALSE',
        'exist({1}.location) = false',
        'exist({1}[0].location)',
        'exist({1}.firewall.location)',
        'exist({1}.firewall.rules[])',
        'count({1}.firewall.rules[]) != 13',
        'count({1}.firewall.rules[]) = 13',
        '{1}.firewall.port = 443',
        "{1}.location = 'eastus2'",
        'exist({1}.location) = FAlSE',
        '{1}.firewall.port = 443',
        "{1}.firewall.rules['name' = 'rule1'].port = 443",
        "{1}.firewall.port = {2}.firewall.port",
        '{1}.firewall.rules[0].port = {2}.firewall.port',
        'exist({1}[0].location)',
        "exist({1}['name' = 'abc'])",
        "{1}.firewall.rules['name' = 'abc'].port = {2}.firewall.port",
        "{1}.firewall.rules['name' = 'abc'].ports[2].port = {2}.firewall.port",
        "{1}.firewall.cost = 443.25",
        "{1}[0].location = 'eastus2'",
    ]
    for rule in rules:
        parser = comparatorParser(
            CommonTokenStream(comparatorLexer(InputStream(rule.rstrip()))))
        tokens = [node.getText() for node in parser.expression().getChildren()]
        assert len(tokens) > 0
        assert RuleInterpreter(tokens) is not None
Exemplo n.º 8
0
 def validate(self):
     """Run this testcase and return a list of result dicts.

     TESTCASEV1 testcases load the most recent document (sorted by
     ``timestamp``, descending) for ``self.snapshot_id``.
     ``OPERATORS[self.op]`` is then applied to the document's ``json``
     payload.

     TESTCASEV2 testcases of type ``rego`` are delegated to
     ``process_rego_test_case()``, giving one entry per returned result; each
     entry is tagged with ``snapshots`` and ``autoRemediate``. Any other
     TESTCASEV2 rule is parsed with the comparator grammar, evaluated by
     ``RuleInterpreter`` and logged. Any other format is skipped.

     Every entry has a ``"result"`` key (``"passed"``, ``"failed"`` or
     ``"skipped"``). Skipped entries give the cause under ``"message"``.
     Side effect: removes ``'dirpath'`` from ``self.testcase`` if present.
     """
     result_val = [{"result": "failed"}]
     if self.format == TESTCASEV1:
         if self.snapshot_id:
             docs = get_documents(self.collection, dbname=self.dbname,
                                  sort=[('timestamp', pymongo.DESCENDING)],
                                  query={'snapshotId': self.snapshot_id},
                                  limit=1)
             # The falsy check below implies docs may be None; don't call
             # len() on it unguarded.
             logger.info('Number of Snapshot Documents: %s',
                         len(docs) if docs else 0)
             if docs and len(docs):
                 doc = docs[0]
                 self.data = doc['json']
                 if self.op in OPERATORS and OPERATORS[self.op]:
                     result = OPERATORS[self.op](self.data, self.loperand, self.roperand,
                                                 self.is_not, self.extras)
                     result_val[0]["result"] = "passed" if result else "failed"
                     snapshot = {
                         'id': doc['snapshotId'],
                         'structure': doc['structure'],
                         'reference': doc['reference'],
                         'source': doc['source'],
                         'collection': doc['collection']
                     }
                     # Copy ``paths`` when the document has it, else ``path``.
                     if "paths" in doc:
                         snapshot["paths"] = doc["paths"]
                     else:
                         snapshot["path"] = doc["path"]
                     result_val[0]["snapshots"] = [snapshot]
             else:
                 result_val[0].update({
                     "result": "skipped",
                     "message": "Missing documents for the snapshot"
                 })
         else:
             result_val[0].update({
                 "result": "skipped",
                 "message": "Missing snapshotId for testcase"
             })
     elif self.format == TESTCASEV2:
         if self.type == 'rego':
             results = self.process_rego_test_case()
             result_val = []
             connector_data = self.get_connector_data()
             auto_remediate = connector_data.get("autoRemediate", False)
             for result in results:
                 result['snapshots'] = self.snapshots
                 result['autoRemediate'] = auto_remediate
                 result_val.append(result)
         else:
             # .get: a missing testId must not abort validation for a log line.
             logger.critical('\tTESTID: %s', self.testcase.get('testId'))
             input_stream = InputStream(self.rule)
             lexer = comparatorLexer(input_stream)
             stream = CommonTokenStream(lexer)
             parser = comparatorParser(stream)
             tree = parser.expression()
             children = [child.getText() for child in tree.getChildren()]
             logger.debug("All the parsed tokens: %s", children)
             otherdata = {'dbname': self.dbname, 'snapshots': self.collection_data, 'container': self.container}
             r_i = RuleInterpreter(children, **otherdata)
             l_val, r_val, result = r_i.compare()
             result_val[0]["result"] = "passed" if result else "failed"
             result_val[0]['snapshots'] = r_i.get_snapshots()
             connector_data = self.get_connector_data()
             result_val[0]['autoRemediate'] = connector_data.get("autoRemediate", False)
             if result_val[0]['snapshots']:
                 snapshot = result_val[0]['snapshots'][0]
                 logger.critical('\t\tSNAPSHOTID: %s', snapshot['id'])
                 logger.critical('\t\tPATHS: ')
                 for path in snapshot.get('paths', []):
                     logger.critical('\t\t\t %s', path)
             if not result:
                 # Show both evaluated sides to make failures diagnosable.
                 logger.critical('\t\tLHS: %s', l_val)
                 logger.critical('\t\tRHS: %s', r_val)
             self.log_result(result_val[0])
     else:
         result_val[0].update({
             "result": "skipped",
             # "message" matches the other skipped branches; "reason" is kept
             # for callers that read the key this branch originally used.
             "message": "Unsupported testcase format",
             "reason": "Unsupported testcase format"
         })

     self.testcase.pop('dirpath', None)
     return result_val