Example #1
0
 def testGoodHeader(self):
   """Render a two-header policy and check the ip and ip6 table output."""
   # A single-header policy only has to construct without raising.
   single_header = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_1,
                                      self.mock_naming)
   nftables.Nftables(single_header, EXP_INFO)

   two_headers = policy.ParsePolicy(
       GOOD_HEADER_1 + GOOD_TERM_1 + GOOD_HEADER_2 + IPV6_TERM_2,
       self.mock_naming)
   rendered = str(nftables.Nftables(two_headers, EXP_INFO))

   # Each table must be flushed and then defined with its chain and verdict.
   expected_fragments = (
       'flush table ip table_filter',
       ('table ip table_filter {\n\tchain chain_name {\n\t\ttype filter '
        'hook input priority 0;\n\t\taccept\n\t}\n}'),
       'flush table ip6 table_filter',
       ('table ip6 table_filter {\n\tchain chain_name {\n\t\ttype filter '
        'hook input priority 0;\n\t\tdrop\n\t}\n}'),
   )
   for fragment in expected_fragments:
     self.assertIn(fragment, rendered)
Example #2
0
 def testExpired(self, mock_logging_warn):
     """Building a policy with EXPIRED_TERM logs exactly one expiry warning."""
     expired_policy = policy.ParsePolicy(GOOD_HEADER_1 + EXPIRED_TERM,
                                         self.mock_naming)
     nftables.Nftables(expired_policy, EXP_INFO)

     # The generator is expected to pass the format string and its arguments
     # separately (lazy logging), so they are asserted as separate values.
     warning_format = ('Term %s in policy %s is expired '
                       'and will not be rendered.')
     mock_logging_warn.assert_called_once_with(
         warning_format, 'is_expired', 'chain_name')
Example #3
0
 def testGoodHeader(self, mock_logging_info):
   """Check the non-base chain notice and the rendered ip/ip6 filter tables."""
   # GOOD_HEADER_2 is expected to trigger the single "non-base chain" notice.
   non_base = policy.ParsePolicy(GOOD_HEADER_2 + GOOD_TERM_1,
                                 self.mock_naming)
   nftables.Nftables(non_base, EXP_INFO)
   notice_format = ('Chain %s is a non-base '
                    'chain, make sure it is linked.')
   mock_logging_info.assert_called_once_with(notice_format, 'chain_name')

   two_tables = policy.ParsePolicy(
       GOOD_HEADER_1 + GOOD_TERM_1 + GOOD_HEADER_3 + IPV6_TERM_2,
       self.mock_naming)
   rendered = str(nftables.Nftables(two_tables, EXP_INFO))

   # Each table must be flushed and then defined with its chain and verdict.
   expected_fragments = (
       'flush table ip filter',
       ('table ip filter {\n\tchain chain_name {\n\t\ttype filter '
        'hook input priority 0;\n\t\taccept\n\t}\n}'),
       'flush table ip6 filter',
       ('table ip6 filter {\n\tchain chain_name {\n\t\ttype filter '
        'hook input priority 0;\n\t\tdrop\n\t}\n}'),
   )
   for fragment in expected_fragments:
     self.assertIn(fragment, rendered)
Example #4
0
 def testVerbatimTerm(self):
   """Only the expected verbatim text appears in the rendered output."""
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_10, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   self.assertIn('mary had a little lamb', rendered)
   # Verbatim text meant for other platforms must not leak into this output.
   for foreign_text in ('mary had a second lamb', 'mary had a third lamb'):
     self.assertNotIn(foreign_text, rendered)
Example #5
0
 def testBuildWarningTokens(self):
     """Check that _BuildTokens() returns the expected token collections."""
     pol1 = nftables.Nftables(
         policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_12, self.mock_naming),
         EXP_INFO)
     st, sst = pol1._BuildTokens()
     # assertEqual replaces the deprecated assertEquals alias, which was
     # removed from unittest in Python 3.12.
     self.assertEqual(st, SUPPORTED_TOKENS)
     self.assertEqual(sst, SUPPORTED_SUB_TOKENS)
Example #6
0
 def testSourceDestExcludeFromAllIps(self):
   """Excluding a /27 with no include list renders the remaining networks."""
   excluded_source = [nacaddr.IPv4('192.168.0.0/27')]
   excluded_destination = [nacaddr.IPv4('10.0.0.0/27')]
   # GetNetAddr answers are consumed in order: source, source-exclude,
   # destination, destination-exclude. Empty lists stand for "no include".
   self.mock_naming.GetNetAddr.side_effect = [
       [],
       excluded_source,
       [],
       excluded_destination,
   ]

   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_11, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   expected_saddr = (
       'ip saddr { 0.0.0.0/1, 128.0.0.0/2, 192.0.0.0/9, '
       '192.128.0.0/11, 192.160.0.0/13, 192.168.0.32/27, '
       '192.168.0.64/26, 192.168.0.128/25, 192.168.1.0/24, '
       '192.168.2.0/23, 192.168.4.0/22, 192.168.8.0/21, '
       '192.168.16.0/20, 192.168.32.0/19, 192.168.64.0/18, '
       '192.168.128.0/17, 192.169.0.0/16, 192.170.0.0/15, '
       '192.172.0.0/14, 192.176.0.0/12, 192.192.0.0/10, '
       '193.0.0.0/8, 194.0.0.0/7, 196.0.0.0/6, 200.0.0.0/5, '
       '208.0.0.0/4, 224.0.0.0/3}')
   expected_daddr = (
       'ip daddr { 0.0.0.0/5, 8.0.0.0/7, 10.0.0.32/27, '
       '10.0.0.64/26, 10.0.0.128/25, 10.0.1.0/24, 10.0.2.0/23, '
       '10.0.4.0/22, 10.0.8.0/21, 10.0.16.0/20, 10.0.32.0/19, '
       '10.0.64.0/18, 10.0.128.0/17, 10.1.0.0/16, 10.2.0.0/15, '
       '10.4.0.0/14, 10.8.0.0/13, 10.16.0.0/12, 10.32.0.0/11, '
       '10.64.0.0/10, 10.128.0.0/9, 11.0.0.0/8, 12.0.0.0/6, '
       '16.0.0.0/4, 32.0.0.0/3, 64.0.0.0/2, 128.0.0.0/1}')
   self.assertIn(expected_saddr, rendered)
   self.assertIn(expected_daddr, rendered)
Example #7
0
 def testCommentOwner(self):
     """The rendered comment joins the comment lines and the owner."""
     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_4,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     # NOTE(review): the owner value below reads '[email protected]', which
     # looks like an e-mail obfuscation artifact -- confirm against the
     # GOOD_TERM_4 fixture before relying on it.
     expected_comment = (
         'comment "comment first line comment second line '
         'Owner: [email protected]"')
     self.assertIn(expected_comment, rendered)
Example #8
0
 def testMultiSport(self):
     """Several source ports render as a set with consecutive ports ranged."""
     self.mock_naming.GetServiceByProto.return_value = [
         '25', '80', '6610', '6611', '6612']

     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_8,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     # 6610, 6611 and 6612 are expected to collapse into one range.
     self.assertIn('sport { 25, 80, 6610-6612}', rendered)
Example #9
0
 def testCommentTruncate(self, mock_logging_warn):
   """An over-long comment is truncated and one warning is logged."""
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_17, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   warning_format = ('Term %s in policy is too long (>%d characters) and will be'
                     ' truncated')
   mock_logging_warn.assert_called_once_with(
       warning_format, 'good-term-17', nftables.Term.MAX_CHARACTERS)

   # Ensure that the truncate did happen and stripped off the ':'
   truncated_comment = 'comment "%(long_line)s' % {'long_line': 'A' * 127}
   self.assertIn(truncated_comment, rendered)
Example #10
0
 def testIcmpv6InetMismatch(self, mock_logging_debug):
   """An icmpv6 term under GOOD_HEADER_1 logs one 'not rendered' message."""
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + IPV6_TERM_1, self.mock_naming)
   # Only the logging side effect matters; the rendered text is discarded.
   str(nftables.Nftables(parsed, EXP_INFO))

   expected_message = ('Term inet6-icmp will not be '
                       'rendered, as it has '
                       '[\'icmpv6\'] match specified '
                       'but the ACL is of inet address '
                       'family.')
   mock_logging_debug.assert_called_once_with(expected_message)
Example #11
0
 def testSingleDport(self):
     """A single destination port renders as a bare 'dport' match."""
     self.mock_naming.GetServiceByProto.return_value = ['25']

     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_7,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     self.assertIn('dport 25', rendered)
Example #12
0
 def testSingleSourceDestIp(self):
   """One source and one destination network render without set braces."""
   # GetNetAddr answers are consumed in order: source first, then destination.
   self.mock_naming.GetNetAddr.side_effect = [
       [nacaddr.IPv4('172.16.0.0/24')],
       [nacaddr.IPv4('10.0.0.0/24')],
   ]

   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_2, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   self.assertIn('ip saddr 172.16.0.0/24 ip daddr 10.0.0.0/24', rendered)
Example #13
0
 def testSingleSport(self):
     """A single source port renders as a bare 'sport' match."""
     self.mock_naming.GetServiceByProto.return_value = ['25']

     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_8,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     self.assertIn('sport 25', rendered)
Example #14
0
 def testExpiring(self, mock_logging_info):
   """A term expiring EXP_INFO weeks from today logs one expiry notice."""
   expiry_window = datetime.timedelta(weeks=EXP_INFO)
   exp_date = datetime.date.today() + expiry_window
   # Fill the expiration date into the term template before parsing.
   expiring_term = EXPIRING_TERM % exp_date.strftime('%Y-%m-%d')

   parsed = policy.ParsePolicy(GOOD_HEADER_1 + expiring_term, self.mock_naming)
   nftables.Nftables(parsed, EXP_INFO)

   notice_format = ('Term %s in policy %s '
                    'expires in less than %d weeks.')
   mock_logging_info.assert_called_once_with(
       notice_format, 'is_expiring', 'chain_name', EXP_INFO)
Example #15
0
 def testSourceDestExclude(self):
   """Excluding a /27 from a /24 renders the remaining three subnets."""
   # GetNetAddr answers are consumed in order: source, source-exclude,
   # destination, destination-exclude.
   self.mock_naming.GetNetAddr.side_effect = [
       [nacaddr.IPv4('192.168.0.0/24')],
       [nacaddr.IPv4('192.168.0.0/27')],
       [nacaddr.IPv4('10.0.0.0/24')],
       [nacaddr.IPv4('10.0.0.0/27')],
   ]

   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_11, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   expected_saddr = ('ip saddr { 192.168.0.32/27, 192.168.0.64/26, '
                     '192.168.0.128/25}')
   expected_daddr = ('ip daddr { 10.0.0.32/27, 10.0.0.64/26, '
                     '10.0.0.128/25}')
   self.assertIn(expected_saddr, rendered)
   self.assertIn(expected_daddr, rendered)
Example #16
0
def RenderFile(input_file, output_directory, definitions, exp_info,
               write_files):
    """Render a single policy file for every platform its headers name.

    Args:
      input_file: the name of the input policy file.
      output_directory: the directory in which we place the rendered file.
      definitions: the definitions from naming.Naming().
      exp_info: print a info message when a term is set to expire
                in that many weeks.
      write_files: a list of file tuples, (output_file, acl_text), to write

    Raises:
      IOError: if input_file cannot be opened or read (logged, then re-raised).
      ACLParserError: if parsing raises policy.Error or naming.Error.
      ACLGeneratorError: if a generator raises one of the Error types caught
        at the bottom of this function.
    """
    logging.debug('rendering file: %s into %s', input_file, output_directory)

    try:
        # The context manager closes the handle even when read() fails.
        with open(input_file) as policy_file:
            conf = policy_file.read()
        logging.debug('opened and read %s', input_file)
    except IOError as e:
        logging.warning('bad file: \n%s', e)
        raise

    try:
        pol = policy.ParsePolicy(conf,
                                 definitions,
                                 optimize=FLAGS.optimize,
                                 base_dir=FLAGS.base_directory,
                                 shade_check=FLAGS.shade_check)
    except policy.ShadingError as e:
        # Shading problems are reported but do not abort the overall run.
        logging.warning('shading errors for %s:\n%s', input_file, e)
        return
    except (policy.Error, naming.Error):
        raise ACLParserError(
            'Error parsing policy file %s:\n%s%s' %
            (input_file, sys.exc_info()[0], sys.exc_info()[1]))

    platforms = set()
    for header in pol.headers:
        platforms.update(header.platforms)

    # (platform keyword, generator factory, output-suffix prefix), listed in
    # the order the outputs are rendered. Factories are lambdas so a
    # generator module is only touched when its platform is requested.
    # Driving every target from one table also guarantees each generator
    # receives the policy copy made for its own platform (the ciscoasa
    # target previously received the cisco copy by mistake).
    renderers = (
        ('juniper', lambda p: juniper.Juniper(p, exp_info), ''),
        ('srx', lambda p: junipersrx.JuniperSRX(p, exp_info), ''),
        ('cisco', lambda p: cisco.Cisco(p, exp_info), ''),
        ('ciscoasa', lambda p: ciscoasa.CiscoASA(p, exp_info), ''),
        ('aruba', lambda p: aruba.Aruba(p, exp_info), ''),
        ('brocade', lambda p: brocade.Brocade(p, exp_info), ''),
        ('arista', lambda p: arista.Arista(p, exp_info), ''),
        ('ipset', lambda p: ipset.Ipset(p, exp_info), ''),
        ('iptables', lambda p: iptables.Iptables(p, exp_info), ''),
        ('nsxv', lambda p: nsxv.Nsxv(p, exp_info), ''),
        ('speedway', lambda p: speedway.Speedway(p, exp_info), ''),
        # pcap produces two outputs: the filter itself and its inverse.
        ('pcap', lambda p: pcap.PcapFilter(p, exp_info), '-accept'),
        ('pcap', lambda p: pcap.PcapFilter(p, exp_info, invert=True),
         '-deny'),
        ('packetfilter',
         lambda p: packetfilter.PacketFilter(p, exp_info), ''),
        ('windows_advfirewall',
         lambda p: windows_advfirewall.WindowsAdvFirewall(p, exp_info), ''),
        ('srxlo', lambda p: srxlo.SRXlo(p, exp_info), ''),
        ('ciscoxr', lambda p: ciscoxr.CiscoXR(p, exp_info), ''),
        ('nftables', lambda p: nftables.Nftables(p, exp_info), ''),
        ('gce', lambda p: gce.GCE(p, exp_info), ''),
        ('paloalto', lambda p: paloaltofw.PaloAltoFW(p, exp_info), ''),
    )

    # Every requested output gets its own deep copy of the parsed policy,
    # made up front before any rendering starts, so no generator sees
    # another generator's view of the policy.
    pending = [(copy.deepcopy(pol), build, suffix_prefix)
               for platform, build, suffix_prefix in renderers
               if platform in platforms]

    if not output_directory.endswith('/'):
        output_directory += '/'

    try:
        for platform_pol, build, suffix_prefix in pending:
            acl_obj = build(platform_pol)
            suffix = acl_obj.SUFFIX
            if suffix_prefix:
                suffix = suffix_prefix + suffix
            RenderACL(str(acl_obj), suffix, output_directory,
                      input_file, write_files)
    # TODO(robankeny) add additional errors.
    except (juniper.Error, junipersrx.Error, cisco.Error, ipset.Error,
            iptables.Error, speedway.Error, pcap.Error, aclgenerator.Error,
            aruba.Error, nftables.Error, gce.Error):
        raise ACLGeneratorError(
            'Error generating target ACL for %s:\n%s%s' %
            (input_file, sys.exc_info()[0], sys.exc_info()[1]))
Example #17
0
 def testLogAndLogNameTerm(self):
     """The rendered output carries the expected log prefix statement."""
     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_15,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     # The trailing space after the closing quote is part of the expectation.
     self.assertIn('log prefix "my log prefix: " ', rendered)
Example #18
0
 def testCounterTerm(self):
     """The rendered rule places 'counter' between the match and verdict."""
     parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_16,
                                 self.mock_naming)
     rendered = str(nftables.Nftables(parsed, EXP_INFO))

     self.assertIn(' icmp counter accept', rendered)
Example #19
0
 def testIcmpType(self):
   """ICMP types render as a brace-delimited 'icmp type' set."""
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_9, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   self.assertIn('icmp type { echo-reply, echo-request}', rendered)
Example #20
0
 def testMultiProtocol(self):
   """GOOD_TERM_6 renders an 'ip protocol ah' match."""
   # NOTE(review): the test name says "multi" but the expectation is a
   # single protocol -- confirm GOOD_TERM_5 and GOOD_TERM_6 are not swapped
   # with testSingleProtocol.
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_6, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   self.assertIn('ip protocol ah', rendered)
Example #21
0
 def testSingleProtocol(self):
   """GOOD_TERM_5 renders an 'ip protocol { ah, esp}' set match."""
   # NOTE(review): the test name says "single" but the expectation is a
   # two-protocol set -- confirm GOOD_TERM_5 and GOOD_TERM_6 are not swapped
   # with testMultiProtocol.
   parsed = policy.ParsePolicy(GOOD_HEADER_1 + GOOD_TERM_5, self.mock_naming)
   rendered = str(nftables.Nftables(parsed, EXP_INFO))

   self.assertIn('ip protocol { ah, esp}', rendered)