Example 1
def get_vars_to_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in the given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  if ckpt is not None:
    ckpt_var_names = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var_names]
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        logging.warn('Missing var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
    model_vars = [v for v in model_vars if v.op.name in ckpt_var_names]
  return model_vars
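A minimal usage sketch (TF1-style): the returned list can be passed straight to `tf.train.Saver` to restore only the variables present in the checkpoint. The checkpoint path is hypothetical and the model graph is assumed to have been built already.

import tensorflow as tf

# Sketch only: assumes the model graph already exists in the default graph,
# and the checkpoint path below is hypothetical.
ckpt_path = '/tmp/model/model.ckpt-100000'
vars_to_restore = get_vars_to_restore(ckpt_path)
saver = tf.train.Saver(var_list=vars_to_restore)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  saver.restore(sess, ckpt_path)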
Example 2
def _WriteFile(output_file, file_string):
  try:
    with open(output_file, 'w') as output:
      logging.info('writing file: %s', output_file)
      output.write(file_string)
  except IOError:
    logging.warn('error while writing file: %s', output_file)
    raise
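A hypothetical call site: since the IOError is re-raised after being logged, the caller decides whether a failed write is fatal. The output path and contents below are made up for illustration.

import sys

try:
  _WriteFile('/tmp/output/filter.acl', 'permit ip any any\n')
except IOError:
  sys.exit('could not write rendered ACL')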
Example 3
  def _TranslatePolicy(self, pol, exp_info):
    self.juniper_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    for header, terms in pol.filters:
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)
      filter_name = header.FilterName(self._PLATFORM)

      # Check for the position independent options and remove them from
      # the list.
      interface_specific = 'not-interface-specific' not in filter_options[1:]
      enable_dsmo = 'enable_dsmo' in filter_options[1:]
      noverbose = 'noverbose' in filter_options[1:]

      if not interface_specific:
        filter_options.remove('not-interface-specific')
      if enable_dsmo:
        filter_options.remove('enable_dsmo')

      # default to inet4 filters
      filter_type = 'inet'
      if len(filter_options) > 1:
        filter_type = filter_options[1]

      term_names = set()
      new_terms = []
      for term in terms:
        term.name = self.FixTermLength(term.name)
        if term.name in term_names:
          raise JuniperDuplicateTermError('You have multiple terms named: %s' %
                                          term.name)
        term_names.add(term.name)

        term = self.FixHighPorts(term, af=filter_type)
        if not term:
          continue

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        new_terms.append(self._TERM(term, filter_type, enable_dsmo, noverbose))

      self.juniper_policies.append((header, filter_name, filter_type,
                                    interface_specific, new_terms))
Example 4
File: util.py Project: pcm17/models
def get_vars_to_save_and_restore(ckpt=None):
  """Returns list of variables that should be saved/restored.

  Args:
    ckpt: Path to existing checkpoint.  If present, returns only the subset of
        variables that exist in the given checkpoint.

  Returns:
    List of all variables that need to be saved/restored.
  """
  model_vars = tf.trainable_variables()
  # Add batchnorm variables.
  bn_vars = [v for v in tf.global_variables()
             if 'moving_mean' in v.op.name or 'moving_variance' in v.op.name or
             'mu' in v.op.name or 'sigma' in v.op.name or
             'global_scale_var' in v.op.name]
  model_vars.extend(bn_vars)
  model_vars = sorted(model_vars, key=lambda x: x.op.name)
  mapping = {}
  if ckpt is not None:
    ckpt_var = tf.contrib.framework.list_variables(ckpt)
    ckpt_var_names = [name for (name, unused_shape) in ckpt_var]
    ckpt_var_shapes = [shape for (unused_name, shape) in ckpt_var]
    not_loaded = list(ckpt_var_names)
    for v in model_vars:
      if v.op.name not in ckpt_var_names:
        # For backward compatibility, try additional matching.
        v_additional_name = v.op.name.replace('egomotion_prediction/', '')
        if v_additional_name in ckpt_var_names:
          # Check if shapes match.
          ind = ckpt_var_names.index(v_additional_name)
          if ckpt_var_shapes[ind] == v.get_shape():
            mapping[v_additional_name] = v
            not_loaded.remove(v_additional_name)
            continue
          else:
            logging.warn('Shape mismatch, will not restore %s.', v.op.name)
        logging.warn('Did not find var %s in checkpoint: %s', v.op.name,
                     os.path.basename(ckpt))
      else:
        # Check if shapes match.
        ind = ckpt_var_names.index(v.op.name)
        if ckpt_var_shapes[ind] == v.get_shape():
          mapping[v.op.name] = v
          not_loaded.remove(v.op.name)
        else:
          logging.warn('Shape mismatch, will not restore %s.', v.op.name)
    if not_loaded:
      logging.warn('The following variables in the checkpoint were not loaded:')
      for varname_not_loaded in not_loaded:
        logging.info('%s', varname_not_loaded)
  else:  # just get model vars.
    for v in model_vars:
      mapping[v.op.name] = v
  return mapping
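Unlike the list returned by the earlier variant, this function returns a dict keyed by checkpoint variable name, which matches the dict form of `tf.train.Saver`'s `var_list` and is what lets the renamed `egomotion_prediction/` variables load correctly. A hedged sketch, with a hypothetical checkpoint path:

import tensorflow as tf

# Sketch only: the checkpoint path is hypothetical.
ckpt_path = '/tmp/depth_model/model.ckpt-250000'
name_to_var = get_vars_to_save_and_restore(ckpt_path)
saver = tf.train.Saver(var_list=name_to_var)  # dict: checkpoint name -> variable
with tf.Session() as sess:
  saver.restore(sess, ckpt_path)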
Example 5
def DescendRecursively(input_dirname, output_dirname, definitions, depth=1):
  """Recursively descend from input_dirname looking for policy files to render.

  Args:
    input_dirname: the base directory.
    output_dirname: where to place the rendered files.
    definitions: naming.Naming object.
    depth: integer, for outputting '---> rendering prod/corp-backbone.jcl'.

  Returns:
    the files that were found
  """
  # p4 complains if you try to edit a file like ./corp//corp-isp.jcl
  input_dirname = input_dirname.rstrip('/')
  output_dirname = output_dirname.rstrip('/')

  files = []
  # calling all directories
  for curdir in [x for x in dircache.listdir(input_dirname) if
                 os.path.isdir(input_dirname + '/' + x)]:
    # be on the lookout for a policy directory
    if curdir == 'pol':
      for input_file in [x for x in dircache.listdir(input_dirname + '/pol')
                         if x.endswith('.pol')]:
        files.append({'in_file': os.path.join(input_dirname, 'pol', input_file),
                      'out_dir': output_dirname,
                      'defs': definitions})
    else:
      # we don't have a policy directory here, so check whether this new
      # directory contains one
      if curdir in FLAGS.ignore_directories:
        continue
      logging.warn('-' * (2 * depth) + '> %s' % (
          input_dirname + '/' + curdir))
      files_found = DescendRecursively(input_dirname + '/' + curdir,
                                       output_dirname + '/' + curdir,
                                       definitions, depth + 1)
      logging.warn('-' * (2 * depth) + '> %s (%d pol files found)' % (
          input_dirname + '/' + curdir, len(files_found)))
      files.extend(files_found)

  return files
Example 6
  def __str__(self):
    """Render the output of the Nsxv policy."""

    target_header = []
    target = []

    # add the p4 tags
    target.append('<!--')
    target.extend(aclgenerator.AddRepositoryTags('\n'))
    target.append('\n')
    target.append('-->')

    for (_, _, _, terms) in self.nsxv_policies:
      section_name = self._FILTER_OPTIONS_DICT['section_name']
      # check section id value
      section_id = self._FILTER_OPTIONS_DICT['section_id']
      if not section_id or section_id == 0:
        logging.warn('WARNING: Section-id is 0. A new Section is created '
                     'for %s. If there is any existing section, it will '
                     'remain unreferenced and should be removed manually.',
                     section_name)
        target.append('<section name="%s">' % (section_name.strip(' \t\n\r')))
      else:
        target.append('<section id="%s" name="%s">' %
                      (section_id, section_name.strip(' \t\n\r')))

      # now add the terms
      for term in terms:
        term_str = str(term)
        if term_str:
          target.append(term_str)

      # ensure that the header is always first
      target = target_header + target
      target.append('%s' % (_XML_TABLE.get('sectionEnd')))
      target.append('\n')

      target_as_xml = xml.dom.minidom.parseString(''.join(target))
    return target_as_xml.toprettyxml(indent='  ', encoding='UTF-8')
Example 7
 def __init__(self, term, filter_name, platform='cisco'):
   self.term = term
   self.filter_name = filter_name
   self.platform = platform
   self.options = []
   self.logstring = ''
   self.dscpstring = ''
   # sanity checking for standard acls
   if self.term.protocol:
     raise StandardAclTermError(
         'Standard ACLs cannot specify protocols')
   if self.term.icmp_type:
     raise StandardAclTermError(
         'ICMP Type specifications are not permissible in standard ACLs')
   if (self.term.source_address
       or self.term.source_address_exclude
       or self.term.destination_address
       or self.term.destination_address_exclude):
     raise StandardAclTermError(
         'Standard ACLs cannot use source or destination addresses')
   if self.term.option:
     raise StandardAclTermError(
         'Standard ACLs prohibit use of options')
   if self.term.source_port or self.term.destination_port:
     raise StandardAclTermError(
         'Standard ACLs prohibit use of port numbers')
   if self.term.logging:
     logging.warn(
         'WARNING: Standard ACL logging is set in filter %s, term %s and '
         'may not be implemented on all IOS versions', self.filter_name,
         self.term.name)
     self.logstring = ' log'
   if self.term.dscp_match:
     logging.warn(
         'WARNING: dscp-match is set in filter %s, term %s and may not be '
         'implemented on all IOS versions', self.filter_name, self.term.name)
     self.dscpstring = ' dscp' + self.term.dscp_match
Example 8
  def __str__(self):
    """Convert term to a rule string.

    Returns:
      A rule as a string.

    Raises:
      NsxvAclTermError: When unknown icmp-types are specified

    """
    # Verify platform specific terms. Skip whole term if platform does not
    # match.
    if self.term.platform:
      if 'nsxv' not in self.term.platform:
        return ''
    if self.term.platform_exclude:
      if 'nsxv' in self.term.platform_exclude:
        return ''

    ret_str = ['']

    # Don't render icmpv6 protocol terms under inet, or icmp under inet6
    if ((self.af == 6 and 'icmp' in self.term.protocol) or
        (self.af == 4 and 'icmpv6' in self.term.protocol)):
      logging.debug(self.NO_AF_LOG_PROTO.substitute(term=self.term.name,
                                                    proto=self.term.protocol,
                                                    af=self.filter_type))
      return ''

    # Term verbatim is not supported
    if self.term.verbatim:
      raise NsxvAclTermError(
          'Verbatim are not implemented in standard ACLs')

    # Term option is not supported
    if self.term.option:
      for opt in [str(single_option) for single_option in self.term.option]:
        if ((opt.find('tcp-established') == 0)
            or (opt.find('established') == 0)):
          return ''
        else:
          raise NsxvAclTermError(
              'Option are not implemented in standard ACLs')

    # check for keywords Nsxv does not support
    term_keywords = self.term.__dict__
    unsupported_keywords = []
    for key in term_keywords:
      if term_keywords[key]:
        # 'translated' is an object attribute, not a keyword
        if ('translated' not in key) and (key not in _NSXV_SUPPORTED_KEYWORDS):
          unsupported_keywords.append(key)
    if unsupported_keywords:
      logging.warn('WARNING: The keywords %s in Term %s are not supported '
                   'in Nsxv.', unsupported_keywords, self.term.name)

    name = '%s%s%s' % (_XML_TABLE.get('nameStart'), self.term.name,
                       _XML_TABLE.get('nameEnd'))

    notes = ''
    if self.term.comment:
      for comment in self.term.comment:
        notes = '%s%s' %(notes, comment)
      notes = '%s%s%s' % (_XML_TABLE.get('noteStart'), notes,
                          _XML_TABLE.get('noteEnd'))

    # protocol
    protocol = None

    if self.term.protocol:
      protocol = map(self.PROTO_MAP.get, self.term.protocol, self.term.protocol)

      # icmp-types
      icmp_types = ['']
      if self.term.icmp_type:
        icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type,
                                             self.term.protocol,
                                             self.af)

    # for mixed filter types, get both IPv4 and IPv6 addresses
    af_list = []
    if self.filter_type == 'mixed':
      af_list = [4, 6]
    else:
      af_list = [self.af]

    source_address = None
    destination_address = None
    source_addr = []
    destination_addr = []

    source_v4_addr = []
    source_v6_addr = []
    dest_v4_addr = []
    dest_v6_addr = []

    for af in af_list:
      # source address
      if self.term.source_address:
        source_address = self.term.GetAddressOfVersion('source_address', af)
        source_address_exclude = self.term.GetAddressOfVersion(
            'source_address_exclude', af)
        if source_address_exclude:
          source_address = nacaddr.ExcludeAddrs(
              source_address,
              source_address_exclude)

        if source_address:
          if af == 4:
            source_v4_addr = source_address
          else:
            source_v6_addr = source_address
        source_addr = source_v4_addr + source_v6_addr

      # destination address
      if self.term.destination_address:
        destination_address = self.term.GetAddressOfVersion(
            'destination_address', af)
        destination_address_exclude = self.term.GetAddressOfVersion(
            'destination_address_exclude', af)
        if destination_address_exclude:
          destination_address = nacaddr.ExcludeAddrs(
              destination_address,
              destination_address_exclude)

        if destination_address:
          if af == 4:
            dest_v4_addr = destination_address
          else:
            dest_v6_addr = destination_address
        destination_addr = dest_v4_addr + dest_v6_addr

    # Check for a source/destination IP version mismatch in mixed filters
    if self.filter_type == 'mixed':
      if source_addr and destination_addr:
        if source_v4_addr and not dest_v4_addr:
          source_addr = source_v6_addr
        elif source_v6_addr and not dest_v6_addr:
          source_addr = source_v4_addr
        elif dest_v4_addr and not source_v4_addr:
          destination_addr = dest_v6_addr
        elif dest_v6_addr and not source_v6_addr:
          destination_addr = dest_v4_addr

        if not source_addr or not destination_addr:
          logging.warn('Term %s will not be rendered as it has IPv4/IPv6 '
                       'mismatch for source/destination for mixed address '
                       'family.', self.term.name)
          return ''

    # ports
    source_port = None
    destination_port = None
    if self.term.source_port:
      source_port = self.term.source_port
    if self.term.destination_port:
      destination_port = self.term.destination_port

    # logging
    log = 'false'
    if self.term.logging:
      log = 'true'

    sources = ''
    if source_addr:
      sources = '<sources excluded="false">'
      for saddr in source_addr:

        # inet4
        if isinstance(saddr, nacaddr.IPv4):
          if saddr.numhosts > 1:
            saddr = '%s%s%s' % (_XML_TABLE.get('srcIpv4Start'),
                                saddr.with_prefixlen,
                                _XML_TABLE.get('srcIpv4End'),)
          else:
            saddr = '%s%s%s' % (_XML_TABLE.get('srcIpv4Start'),
                                saddr.ip,
                                _XML_TABLE.get('srcIpv4End'))
          sources = '%s%s' %(sources, saddr)
        # inet6
        if isinstance(saddr, nacaddr.IPv6):
          if saddr.numhosts > 1:
            saddr = '%s%s%s' % (_XML_TABLE.get('srcIpv6Start'),
                                saddr.with_prefixlen,
                                _XML_TABLE.get('srcIpv6End'),)
          else:
            saddr = '%s%s%s' % (_XML_TABLE.get('srcIpv6Start'),
                                saddr.ip, _XML_TABLE.get('srcIpv6End'))
          sources = '%s%s' %(sources, saddr)
      sources = '%s%s' %(sources, '</sources>')

    destinations = ''
    if destination_addr:
      destinations = '<destinations excluded="false">'
      for daddr in destination_addr:
        # inet4
        if isinstance(daddr, nacaddr.IPv4):
          if daddr.numhosts > 1:
            daddr = '%s%s%s' % (_XML_TABLE.get('destIpv4Start'),
                                daddr.with_prefixlen,
                                _XML_TABLE.get('destIpv4End'),)
          else:
            daddr = '%s%s%s' % (_XML_TABLE.get('destIpv4Start'),
                                daddr.ip,
                                _XML_TABLE.get('destIpv4End'))
          destinations = '%s%s' %(destinations, daddr)
        # inet6
        if isinstance(daddr, nacaddr.IPv6):
          if daddr.numhosts > 1:
            daddr = '%s%s%s' % (_XML_TABLE.get('destIpv6Start'),
                                daddr.with_prefixlen,
                                _XML_TABLE.get('destIpv6End'),)
          else:
            daddr = '%s%s%s' % (_XML_TABLE.get('destIpv6Start'),
                                daddr.ip,
                                _XML_TABLE.get('destIpv6End'))
          destinations = '%s%s' %(destinations, daddr)
      destinations = '%s%s' %(destinations, '</destinations>')

    services = []
    if protocol:
      services.append('<services>')
      for proto in protocol:
        if proto != 'any':
          services.append(self._ServiceToString(proto,
                                                source_port,
                                                destination_port,
                                                icmp_types))
      services.append('</services>')

    service = ''
    for s in services:
      service = '%s%s' % (service, s)

    # applied_to
    applied_to_list = ''
    if self.applied_to:
      applied_to_list = '<appliedToList>'
      applied_to_element = '%s%s%s' % (_XML_TABLE.get('appliedToStart'),
                                       self.applied_to,
                                       _XML_TABLE.get('appliedToEnd'))
      applied_to_list = '%s%s' %(applied_to_list, applied_to_element)
      applied_to_list = '%s%s' %(applied_to_list, '</appliedToList>')

    # action
    action = '%s%s%s' % (_XML_TABLE.get('actionStart'),
                         _ACTION_TABLE.get(str(self.term.action[0])),
                         _XML_TABLE.get('actionEnd'))

    ret_lines = []
    ret_lines.append('<rule logged="%s">%s%s%s%s%s%s%s</rule>' %
                     (log, name, action, sources, destinations, service,
                      applied_to_list, notes))

    # remove any trailing spaces and replace multiple spaces with singles
    stripped_ret_lines = [re.sub(r'\s+', ' ', x).rstrip() for x in ret_lines]
    ret_str.extend(stripped_ret_lines)
    return ''.join(ret_str)
Example 9
 def _WarnIfCustomTarget(self, target):
   """Emit a warning if a policy's default target is not a built-in chain."""
   if target not in self._GOOD_FILTERS:
     logging.warn('Filter is generating a non-standard chain that will not '
                  'apply to traffic unless linked from INPUT, OUTPUT or '
                  'FORWARD filters. New chain name is: %s', target)
Example 10
  def _TranslatePolicy(self, policy, expiration):
    """Translates policy contents to platform specific data structures.

    Args:
      policy: policy object to be translated to platform specific data
        structures.
      expiration: integer number of weeks to be warned about term expiration in.

    Raises:
      InvalidTargetOption: if supplied target options are invalid.

    """
    self.tables = collections.defaultdict(list)

    for header, terms in policy.filters:
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)

      if not filter_options:
        raise InvalidTargetOption('Chain name not specified.')

      if len(filter_options) > 4:
        raise InvalidTargetOption('Too many target options.')

      if len(filter_options) == 1:
        raise InvalidTargetOption(
            'Must have at least hook name')

      # Chain name, mandatory
      chain_name = filter_options[0]

      # Hook name, mandatory
      hook_name = filter_options[1].lower()

      if hook_name not in self._VALID_HOOK_NAMES:
        raise InvalidTargetOption(
            'Specified hook name (%s) is not a valid hook name.' % hook_name)

      # chain priority, mandatory
      chain_priority = None
      if len(filter_options) >= 3:
        try:
          chain_priority = str(int(filter_options[2]))
        except ValueError:
          raise InvalidTargetOption(
              'Specified chain priority is not an integer (%s).'
              % filter_options[2])

      # TODO(castagno): fix this. If you don't have a hook name it never
      # prints anyway, so it's not really optional
      if not hook_name or not chain_priority:
        logging.info('Chain %s is a non-base chain, make sure it is linked.',
                     chain_name)
        raise InvalidTargetOption('A table name is required')

      # Address family, optional, defaults to capirca inet
      af = 'inet'
      if len(filter_options) == 4:
        af = filter_options[3]
        if af not in self._VALID_ADDRESS_FAMILIES:
          raise InvalidTargetOption(
              'Specified address family (%s) is not supported.' % af)

      # Terms
      valid_terms = []
      for term in terms:
        term = self.FixHighPorts(term, af)
        if not term:
          continue

        current_date = datetime.datetime.utcnow().date()
        expiration_date = current_date + datetime.timedelta(weeks=expiration)

        if term.expiration:
          if term.expiration < current_date:
            logging.warn(
                'Term %s in policy %s is expired and will not be rendered.',
                term.name, chain_name)
            continue
          if term.expiration <= expiration_date:
            logging.info('Term %s in policy %s expires in less than %d weeks.',
                         term.name, chain_name, expiration)

        valid_terms.append(self._TERM(term, af))

      # Add to appropriate table
      self.tables[af].append((chain_name, hook_name,
                              chain_priority, valid_terms))
Example 11
  def _TranslatePolicy(self, pol, exp_info):
    # pylint: disable=attribute-defined-outside-init
    """Transform a policy object into a JuniperSRX object.

    Args:
      pol: policy.Policy object
      exp_info: print an info message when a term is set to expire
                in that many weeks

    Raises:
      UnsupportedFilterError: An unsupported filter was specified
      UnsupportedHeader: A header option exists that is not understood/usable
      SRXDuplicateTermError: Two terms were found with same name in same filter
      ConflictingTargetOptions: Two target options are conflicting in the header
      MixedAddrBookTypes: Global and Zone address books in the same policy
      ConflictingApplicationSets: When two duplicate named terms have
                                  conflicting application entries
    """
    self.srx_policies = []
    self.addressbook = collections.OrderedDict()
    self.applications = []
    self.ports = []
    self.from_zone = ''
    self.to_zone = ''
    self.addr_book_type = set()

    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    for header, terms in pol.filters:
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)

      if (len(filter_options) < 4 or filter_options[0] != 'from-zone' or
          filter_options[2] != 'to-zone'):
        raise UnsupportedFilterError('SRX filter arguments must specify '
                                     'from-zone and to-zone.')

      # check if to-zone is not a supported target option
      if filter_options[1] in self._SUPPORTED_TARGET_OPTIONS:
        raise UnsupportedFilterError('to-zone %s cannot be the same as any '
                                     'valid SRX target-options' %
                                     (filter_options[1]))
      else:
        self.from_zone = filter_options[1]

      # check if from-zone is not a supported target option
      if filter_options[3] in self._SUPPORTED_TARGET_OPTIONS:
        raise UnsupportedFilterError('from-zone %s cannot be the same as any '
                                     'valid SRX target-options' %
                                     (filter_options[3]))
      else:
        self.to_zone = filter_options[3]

      # variables used to collect target-options and set defaults
      target_options = []
      filter_type = ''

      # parse srx target options
      extra_options = filter_options[4:]
      if self._SUPPORTED_TARGET_OPTIONS.issubset(extra_options):
        raise ConflictingTargetOptions('only one address-book-type can '
                                       'be specified per header "%s"' %
                                       ' '.join(filter_options))
      else:
        address_book_type = self._SUPPORTED_TARGET_OPTIONS.intersection(
            extra_options)
        if not address_book_type:
          address_book_type = {self._GLOBAL_ADDR_BOOK}
        self.addr_book_type.update(address_book_type)
        if len(self.addr_book_type) > 1:
          raise MixedAddrBookTypes('Global and Zone address-book-types cannot '
                                   'be used in the same policy')
        if self.from_zone == 'all' and self.to_zone == 'all':
          if self._ZONE_ADDR_BOOK in self.addr_book_type:
            raise UnsupportedFilterError('Zone address books cannot be used '
                                         'with a global policy.')
        elif self.from_zone == 'all' or self.to_zone == 'all':
          raise UnsupportedFilterError('The zone name all is reserved for '
                                       'global policies.')

      for filter_opt in filter_options[4:]:

        # validate address families
        if filter_opt in self._SUPPORTED_AF:
          if not filter_type:
            filter_type = filter_opt
          else:
            raise ConflictingTargetOptions('only one address family can be '
                                           'specified per header "%s"' %
                                           ' '.join(filter_options))

        elif filter_opt in self._SUPPORTED_TARGET_OPTIONS:
          target_options.append(filter_opt)

        else:
          raise UnsupportedHeader('SRX Generator currently does not support '
                                  '%s as a header option "%s"' %
                                  (filter_opt, ' '.join(filter_options)))

      # if address-family and address-book-type have not been set then default
      if not filter_type:
        filter_type = 'mixed'

      term_dup_check = set()
      new_terms = []
      self._FixLargePolices(terms, filter_type)
      for term in terms:
        term.name = self.FixTermLength(term.name)
        if term.name in term_dup_check:
          raise SRXDuplicateTermError('You have a duplicate term: %s'
                                      % term.name)
        term_dup_check.add(term.name)

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s>%s expires '
                         'in less than two weeks.', term.name, self.from_zone,
                         self.to_zone)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s>%s is expired.',
                         term.name, self.from_zone, self.to_zone)
            continue

        # SRX address books leverage network token names for IPs.
        # When excluding addresses, we lose those distinct names so we need
        # to create a new unique name based off the term name before excluding.
        if term.source_address_exclude:
          # If we have a naked source_exclude, we need something to exclude from
          if not term.source_address:
            term.source_address = [nacaddr.IP('0.0.0.0/0',
                                              term.name.upper(),
                                              term.name.upper())]
          # Use the term name as the token & parent_token
          new_src_parent_token = term.name.upper() + '_SRC_EXCLUDE'
          new_src_token = new_src_parent_token
          for i in term.source_address_exclude:
            term.source_address = nacaddr.RemoveAddressFromList(
                term.source_address, i)
            for i in term.source_address:
              i.token = new_src_token
              i.parent_token = new_src_parent_token

        if term.destination_address_exclude:
          if not term.destination_address:
            term.destination_address = [nacaddr.IP('0.0.0.0/0',
                                                   term.name.upper(),
                                                   term.name.upper())]
          new_dst_parent_token = term.name.upper() + '_DST_EXCLUDE'
          new_dst_token = new_dst_parent_token
          for i in term.destination_address_exclude:
            term.destination_address = nacaddr.RemoveAddressFromList(
                term.destination_address, i)
            for i in term.destination_address:
              i.token = new_dst_token
              i.parent_token = new_dst_parent_token

        # SRX policies are controlled by addresses that are used within, so
        # policy can be at the same time inet and inet6.
        if self._GLOBAL_ADDR_BOOK in self.addr_book_type:
          for zone in self.addressbook:
            for unused_name, ips in sorted(self.addressbook[zone].iteritems()):
              ips = [i for i in ips]
              if term.source_address == ips:
                term.source_address = ips
              if term.destination_address == ips:
                term.destination_address = ips
        for addr in term.source_address:
          if addr.version in self._AF_MAP[filter_type]:
            self._BuildAddressBook(self.from_zone, addr)
        for addr in term.destination_address:
          if addr.version in self._AF_MAP[filter_type]:
            self._BuildAddressBook(self.to_zone, addr)

        new_term = Term(term, filter_options)
        new_terms.append(new_term)

        # Because SRX terms can contain inet and inet6 addresses. We have to
        # have ability to recover proper AF for ICMP type we need.
        # If protocol is empty or we cannot map to inet or inet6 we insert bogus
        # af_type name which will cause new_term.NormalizeIcmpTypes to fail.
        if not term.protocol:
          icmp_af_type = 'unknown_af_icmp'
        else:
          icmp_af_type = self._AF_ICMP_MAP.get(
              term.protocol[0], 'unknown_af_icmp')
        tmp_icmptype = new_term.NormalizeIcmpTypes(
            term.icmp_type, term.protocol, icmp_af_type)
        # NormalizeIcmpTypes returns [''] for empty, convert to [] for eval
        normalized_icmptype = tmp_icmptype if tmp_icmptype != [''] else []
        # rewrites the protocol icmpv6 to icmp6
        if 'icmpv6' in term.protocol:
          protocol = list(term.protocol)
          protocol[protocol.index('icmpv6')] = 'icmp6'
        else:
          protocol = term.protocol
        new_application_set = {'sport': self._BuildPort(term.source_port),
                               'dport': self._BuildPort(term.destination_port),
                               'protocol': protocol,
                               'icmp-type': normalized_icmptype,
                               'timeout': term.timeout}

        for application_set in self.applications:
          if all(item in application_set.items() for item in
                 new_application_set.items()):
            new_application_set = ''
            term.replacement_application_name = application_set['name']
            break
          if (term.name == application_set['name'] and
              new_application_set != application_set):
            raise ConflictingApplicationSets(
                'Application set %s has a conflicting entry' % term.name)

        if new_application_set:
          new_application_set['name'] = term.name
          self.applications.append(new_application_set)

      self.srx_policies.append((header, new_terms, filter_options))
Example 12
def main(args):
    FLAGS(args)
    if FLAGS.verbose:
        logging.basicConfig(level=logging.INFO)
    if FLAGS.debug:
        logging.basicConfig(level=logging.DEBUG)
    logging.debug(
        'binary: %s\noptimize: %d\nbase_directory: %s\n'
        'policy_file: %s\nrendered_acl_directory: %s', str(sys.argv[0]),
        int(FLAGS.optimize), str(FLAGS.base_directory), str(FLAGS.policy_file),
        str(FLAGS.output_directory))

    definitions = None
    try:
        definitions = naming.Naming(FLAGS.definitions_directory)
    except naming.NoDefinitionsError:
        logging.fatal('bad definitions directory: %s',
                      FLAGS.definitions_directory)
        sys.exit(1)

    # thread-safe list for storing files to write
    manager = multiprocessing.Manager()
    write_files = manager.list()

    with_errors = False
    if FLAGS.policy_file:
        # render just one file
        logging.info('rendering one file')
        RenderFile(FLAGS.policy_file, FLAGS.output_directory, definitions,
                   FLAGS.exp_info, write_files)
    else:
        # render all files in parallel
        logging.info('finding policies...')
        pols = []
        pols.extend(
            DescendRecursively(FLAGS.base_directory, FLAGS.output_directory,
                               definitions))

        pool = multiprocessing.Pool(processes=FLAGS.max_renderers)
        results = []
        for x in pols:
            results.append(
                pool.apply_async(RenderFile,
                                 args=(x.get('in_file'), x.get('out_dir'),
                                       definitions, FLAGS.exp_info,
                                       write_files)))
        pool.close()
        pool.join()

        for result in results:
            try:
                result.get()
            except (ACLParserError, ACLGeneratorError) as e:
                with_errors = True
                logging.warn(
                    '\n\nerror encountered in rendering process:\n%s\n\n', e)

    # actually write files to disk
    WriteFiles(write_files)

    if with_errors:
        logging.warn('done, with errors.')
        sys.exit(1)
    else:
        logging.info('done.')
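The error handling above relies on multiprocessing's `AsyncResult.get()` re-raising, in the parent process, any exception raised in a worker. A stripped-down sketch of that pattern (the `render` function and the file names are hypothetical stand-ins for RenderFile and the policy files):

import multiprocessing


def render(name):  # hypothetical stand-in for RenderFile
    if name == 'bad.pol':
        raise ValueError('cannot parse %s' % name)
    return name


if __name__ == '__main__':
    pool = multiprocessing.Pool(processes=4)
    results = [pool.apply_async(render, args=(name,))
               for name in ['a.pol', 'bad.pol', 'c.pol']]
    pool.close()
    pool.join()
    for result in results:
        try:
            result.get()  # re-raises any exception from the worker
        except ValueError as e:
            print('error encountered in rendering process: %s' % e)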
Example 13
  def _TranslatePolicy(self, pol, exp_info):
    self.pcap_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    good_afs = ['inet', 'inet6', 'mixed']
    good_options = ['in', 'out']
    direction = ''

    for header, terms in pol.filters:
      filter_type = None
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)[1:]
      filter_name = header.FilterName(self._PLATFORM)

      # ensure all options after the filter name are expected
      for opt in filter_options:
        if opt not in good_afs + good_options:
          raise UnsupportedTargetOption('%s %s %s %s' % (
              '\nUnsupported option found in', self._PLATFORM,
              'target definition:', opt))

      if 'in' in filter_options:
        direction = 'in'
      elif 'out' in filter_options:
        direction = 'out'

      # Check for matching af
      for address_family in good_afs:
        if address_family in filter_options:
          # should not specify more than one AF in options
          if filter_type is not None:
            raise aclgenerator.UnsupportedFilterError('%s %s %s %s' % (
                '\nMay only specify one of', good_afs, 'in filter options:',
                filter_options))
          filter_type = address_family
      if filter_type is None:
        filter_type = 'mixed'

      # add the terms
      accept_terms = []
      deny_terms = []
      term_names = set()
      for term in terms:
        if term.name in term_names:
          raise aclgenerator.DuplicateTermError(
              'You have a duplicate term: %s' % term.name)

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        if not term:
          continue

        if term.action[0] == 'accept':
          accept_terms.append(self._TERM(term, filter_name, filter_type,
                                         direction))
        elif term.action[0] == 'deny' or term.action[0] == 'reject':
          deny_terms.append(self._TERM(term, filter_name, filter_type,
                                       direction))

      self.pcap_policies.append((header, filter_name, filter_type, accept_terms,
                                 deny_terms))
Example 14
 def log_prob(self, x):
     """Redirects to log_prob_elbo with a warning."""
     logging.warn('log_prob is actually a lower bound')
     return self.log_prob_elbo(x)
Example 15
def RenderFile(input_file, output_directory, definitions, exp_info,
               write_files):
    """Render a single file.

    Args:
      input_file: the name of the input policy file.
      output_directory: the directory in which we place the rendered file.
      definitions: the definitions from naming.Naming().
      exp_info: print an info message when a term is set to expire
                in that many weeks.
      write_files: a list of file tuples, (output_file, acl_text), to write.
    """
    logging.debug('rendering file: %s into %s', input_file, output_directory)
    pol = None
    jcl = False
    acl = False
    asacl = False
    aacl = False
    bacl = False
    eacl = False
    gcefw = False
    ips = False
    ipt = False
    spd = False
    nsx = False
    pcap_accept = False
    pcap_deny = False
    pf = False
    srx = False
    jsl = False
    nft = False
    win_afw = False
    xacl = False
    paloalto = False

    try:
        conf = open(input_file).read()
        logging.debug('opened and read %s', input_file)
    except IOError as e:
        logging.warn('bad file: \n%s', e)
        raise

    try:
        pol = policy.ParsePolicy(conf,
                                 definitions,
                                 optimize=FLAGS.optimize,
                                 base_dir=FLAGS.base_directory,
                                 shade_check=FLAGS.shade_check)
    except policy.ShadingError as e:
        logging.warn('shading errors for %s:\n%s', input_file, e)
        return
    except (policy.Error, naming.Error):
        raise ACLParserError(
            'Error parsing policy file %s:\n%s%s' %
            (input_file, sys.exc_info()[0], sys.exc_info()[1]))

    platforms = set()
    for header in pol.headers:
        platforms.update(header.platforms)

    if 'juniper' in platforms:
        jcl = copy.deepcopy(pol)
    if 'cisco' in platforms:
        acl = copy.deepcopy(pol)
    if 'ciscoasa' in platforms:
        asacl = copy.deepcopy(pol)
    if 'brocade' in platforms:
        bacl = copy.deepcopy(pol)
    if 'arista' in platforms:
        eacl = copy.deepcopy(pol)
    if 'aruba' in platforms:
        aacl = copy.deepcopy(pol)
    if 'ipset' in platforms:
        ips = copy.deepcopy(pol)
    if 'iptables' in platforms:
        ipt = copy.deepcopy(pol)
    if 'nsxv' in platforms:
        nsx = copy.deepcopy(pol)
    if 'packetfilter' in platforms:
        pf = copy.deepcopy(pol)
    if 'pcap' in platforms:
        pcap_accept = copy.deepcopy(pol)
        pcap_deny = copy.deepcopy(pol)
    if 'speedway' in platforms:
        spd = copy.deepcopy(pol)
    if 'srx' in platforms:
        srx = copy.deepcopy(pol)
    if 'srxlo' in platforms:
        jsl = copy.deepcopy(pol)
    if 'windows_advfirewall' in platforms:
        win_afw = copy.deepcopy(pol)
    if 'ciscoxr' in platforms:
        xacl = copy.deepcopy(pol)
    if 'nftables' in platforms:
        nft = copy.deepcopy(pol)
    if 'gce' in platforms:
        gcefw = copy.deepcopy(pol)
    if 'paloalto' in platforms:
        paloalto = copy.deepcopy(pol)

    if not output_directory.endswith('/'):
        output_directory += '/'

    try:
        if jcl:
            acl_obj = juniper.Juniper(jcl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if srx:
            acl_obj = junipersrx.JuniperSRX(srx, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if acl:
            acl_obj = cisco.Cisco(acl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if asacl:
            acl_obj = ciscoasa.CiscoASA(asacl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if aacl:
            acl_obj = aruba.Aruba(aacl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if bacl:
            acl_obj = brocade.Brocade(bacl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if eacl:
            acl_obj = arista.Arista(eacl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if ips:
            acl_obj = ipset.Ipset(ips, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if ipt:
            acl_obj = iptables.Iptables(ipt, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if nsx:
            acl_obj = nsxv.Nsxv(nsx, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if spd:
            acl_obj = speedway.Speedway(spd, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if pcap_accept:
            acl_obj = pcap.PcapFilter(pcap_accept, exp_info)
            RenderACL(str(acl_obj), '-accept' + acl_obj.SUFFIX,
                      output_directory, input_file, write_files)
        if pcap_deny:
            acl_obj = pcap.PcapFilter(pcap_deny, exp_info, invert=True)
            RenderACL(str(acl_obj), '-deny' + acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if pf:
            acl_obj = packetfilter.PacketFilter(pf, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if win_afw:
            acl_obj = windows_advfirewall.WindowsAdvFirewall(win_afw, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if jsl:
            acl_obj = srxlo.SRXlo(jsl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if xacl:
            acl_obj = ciscoxr.CiscoXR(xacl, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if nft:
            acl_obj = nftables.Nftables(nft, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if gcefw:
            acl_obj = gce.GCE(gcefw, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
        if paloalto:
            acl_obj = paloaltofw.PaloAltoFW(paloalto, exp_info)
            RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                      input_file, write_files)
    # TODO(robankeny) add additional errors.
    except (juniper.Error, junipersrx.Error, cisco.Error, ipset.Error,
            iptables.Error, speedway.Error, pcap.Error, aclgenerator.Error,
            aruba.Error, nftables.Error, gce.Error) as e:
        raise ACLGeneratorError('Error generating target ACL for %s:\n%s' %
                                (input_file, e))
Example 16
    def __init__(self,
                 root_dir,
                 train_step,
                 agent,
                 experience_dataset_fn=None,
                 after_train_strategy_step_fn=None,
                 triggers=None,
                 checkpoint_interval=100000,
                 summary_interval=1000,
                 max_checkpoints_to_keep=3,
                 use_kwargs_in_agent_train=False,
                 strategy=None,
                 run_optimizer_variable_init=True,
                 use_reverb_v2=False,
                 experience_dataset_options=None,
                 strategy_run_options=None):
        """Initializes a Learner instance.

    Args:
      root_dir: Main directory path where checkpoints, saved_models, and
        summaries will be written to.
      train_step: a scalar tf.int64 `tf.Variable` which will keep track of the
        number of train steps. This is used for artifacts created like
        summaries, or outputs in the root_dir.
      agent: `tf_agent.TFAgent` instance to train with.
      experience_dataset_fn: a function that will create an instance of a
        tf.data.Dataset used to sample experience for training. Required for
        using the Learner as is. Optional for subclass learners which take a new
        iterator each time when `learner.run` is called.
      after_train_strategy_step_fn: (Optional) callable of the form `fn(sample,
        loss)` which can be used for example to update priorities in a replay
        buffer where sample is pulled from the `experience_iterator` and loss is
        a `LossInfo` named tuple returned from the agent. This is called after
        every train step. It runs using `strategy.run(...)`.
      triggers: List of callables of the form `trigger(train_step)`. After every
        `run` call every trigger is called with the current `train_step` value
        as an np scalar.
      checkpoint_interval: Number of train steps in between checkpoints. Note
        these are placed into triggers and so a check to generate a checkpoint
        only occurs after every `run` call. Set to -1 to disable (this is not
        recommended, because it means that if the pipeline gets preempted, all
        previous progress is lost). This only takes care of the checkpointing
        the training process.  Policies must be explicitly exported through
        triggers.
      summary_interval: Number of train steps in between summaries. Note these
        are placed into triggers and so a check to generate a checkpoint only
        occurs after every `run` call.
      max_checkpoints_to_keep: Maximum number of checkpoints to keep around.
        These are used to recover from pre-emptions when training.
      use_kwargs_in_agent_train: If True the experience from the replay buffer
        is passed into the agent as kwargs. This requires samples from the RB to
        be of the form `dict(experience=experience, kwarg1=kwarg1, ...)`. This
        is useful if you have an agent with a custom argspec.
      strategy: (Optional) `tf.distribute.Strategy` to use during training.
      run_optimizer_variable_init: Specifies if the variables of the optimizer
        are initialized before checkpointing. This should be almost always
        `True` (default) to ensure that the state of the optimizer is
        checkpointed properly. The initialization of the optimizer variables
        happens by building the Tensorflow graph. This is done by calling a
        `get_concrete_function` on the agent's `train` method which requires
        passing some input. Since no real data is available at this point, we
        use the batched form of `training_data_spec` to achieve this (standard
        technique). The problem arises when the agent expects some agent
        specific batching of the input. In this case, there is no _general_ way
        at this point in the learner to batch the impacted specs properly. To
        avoid breaking the code in these specific cases, we recommend turning
        off initialization of the optimizer variables by setting the value of
        this field to `False`.
      use_reverb_v2: If True then we expect the dataset samples to return a
        named_tuple with a data and an info field. If False we expect a
        tuple(data, info).
      experience_dataset_options: (Optional) `tf.distribute.InputOptions` passed
        to `strategy.distribute_datasets_from_function`, used to control options
        on how this dataset is distributed.
      strategy_run_options: (Optional) `tf.distribute.RunOptions` passed to
        `strategy.run`. This is passed to every strategy.run invocation by the
        learner.
    """
        if checkpoint_interval < 0:
            logging.warning(
                'Warning: checkpointing the training process is manually '
                'disabled. This means training progress will NOT be '
                'automatically restored if the job gets preempted.')

        self._train_dir = os.path.join(root_dir, TRAIN_DIR)
        self._use_reverb_v2 = use_reverb_v2
        if summary_interval:
            self.train_summary_writer = tf.compat.v2.summary.create_file_writer(
                self._train_dir, flush_millis=10000)
        else:
            self.train_summary_writer = tf.summary.create_noop_writer()

        self.train_step = train_step
        self._agent = agent
        self.use_kwargs_in_agent_train = use_kwargs_in_agent_train
        self.strategy = strategy or tf.distribute.get_strategy()

        dataset = None
        if experience_dataset_fn:
            with self.strategy.scope():
                dataset = self.strategy.distribute_datasets_from_function(
                    lambda _: experience_dataset_fn(),
                    options=experience_dataset_options)
                self._experience_iterator = iter(dataset)

        self.after_train_strategy_step_fn = after_train_strategy_step_fn
        self.triggers = triggers or []

        # Prevent autograph from going into the agent.
        self._agent.train = tf.autograph.experimental.do_not_convert(
            agent.train)

        self._strategy_run_options = strategy_run_options

        checkpoint_dir = os.path.join(self._train_dir, POLICY_CHECKPOINT_DIR)
        with self.strategy.scope():
            agent.initialize()

            if run_optimizer_variable_init:
                # Force a concrete function creation inside of the strategy scope to
                # ensure that all variables, including optimizer slot variables, are
                # created. This has to happen before the checkpointer is created.
                if dataset is not None:
                    if use_reverb_v2:
                        batched_specs = dataset.element_spec.data
                    else:
                        # Assumes (experience, sample_info) = next(iterator)
                        batched_specs, _ = dataset.element_spec
                else:
                    batched_specs = tensor_spec.add_outer_dims_nest(
                        self._agent.training_data_spec,
                        (None, self._agent.train_sequence_length))
                if self.use_kwargs_in_agent_train:
                    batched_specs = dict(experience=batched_specs)

                @common.function
                def _create_variables(specs):
                    # TODO(b/170516529): Each replica has to be in the same graph.
                    # This can be ensured by placing the `strategy.run(...)` call inside
                    # the `tf.function`.
                    if self.use_kwargs_in_agent_train:
                        return self.strategy.run(
                            self._agent.train,
                            kwargs=specs,
                            options=self._strategy_run_options)
                    return self.strategy.run(
                        self._agent.train,
                        args=(specs, ),
                        options=self._strategy_run_options)

                _create_variables.get_concrete_function(batched_specs)
            else:
                # TODO(b/186052656) Update clients.
                logging.warn(
                    'run_optimizer_variable_init = False is Deprecated')

            self._checkpointer = common.Checkpointer(
                checkpoint_dir,
                max_to_keep=max_checkpoints_to_keep,
                agent=self._agent,
                train_step=self.train_step)
            self._checkpointer.initialize_or_restore()  # pytype: disable=attribute-error

        for trigger in self.triggers:
            if hasattr(trigger, 'set_start'):
                trigger.set_start(self.train_step.numpy())

        self.triggers.append(self._get_checkpoint_trigger(checkpoint_interval))
        self.summary_interval = tf.constant(summary_interval, dtype=tf.int64)
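For context, a minimal construction sketch using only arguments documented above. `my_agent` and `make_dataset` are hypothetical placeholders for an already-built `TFAgent` and an experience-dataset factory, and the import path assumes the tf_agents `Learner` this constructor belongs to.

import tensorflow as tf
from tf_agents.train import learner as learner_lib

# Sketch only: my_agent and make_dataset are assumed to exist elsewhere.
train_step = tf.Variable(0, dtype=tf.int64)
agent_learner = learner_lib.Learner(
    root_dir='/tmp/agent_run',           # checkpoints/summaries written here
    train_step=train_step,
    agent=my_agent,                      # hypothetical TFAgent instance
    experience_dataset_fn=make_dataset,  # hypothetical dataset factory
    checkpoint_interval=100000,
    summary_interval=1000)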
Example 17
    def __str__(self):
        output = []

        # Don't render term if not in platforms or in excluded platforms
        if self.term.platform and self._PLATFORM not in self.term.platform:
            return ''
        if (self.term.platform_exclude
                and self._PLATFORM in self.term.platform_exclude):
            return ''

        # Don't render icmpv6 protocol terms under inet, or icmp under inet6
        # Does not currently support mixed family.
        if ((self.af == 'inet6' and 'icmp' in self.term.protocol)
                or (self.af == 'inet' and 'icmpv6' in self.term.protocol)):
            logging.debug(
                self.NO_AF_LOG_PROTO.substitute(term=self.term.name,
                                                proto=self.term.protocol,
                                                af=self.af))
            return ''

        # Term verbatim output - this will skip over most normal term
        # creation code by returning early. Warnings provided in policy.py.
        if self.term.verbatim:
            for verbatim_line in self.term.verbatim:
                platform, contents = verbatim_line.value
                if platform == self._PLATFORM:
                    output.append(str(contents))
            return '\n'.join(output)

        # Source address
        if self.term.source_address or self.term.source_address_exclude:
            src_addrs = self._CalculateAddrs(self.term.source_address,
                                             self.term.source_address_exclude)
            if not src_addrs:
                logging.warn(
                    self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                   direction='source',
                                                   af=self.af))
                return ''
            # TODO(castagno): Add support for ipv6
            output.append('ip saddr %s' % self._FormatMatch(src_addrs))

        # Destination address
        if self.term.destination_address or self.term.destination_address_exclude:
            dst_addrs = self._CalculateAddrs(
                self.term.destination_address,
                self.term.destination_address_exclude)
            if not dst_addrs:
                logging.warn(
                    self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                   direction='destination',
                                                   af=self.af))
                return ''
            # TODO(castagno): Add support for ipv6
            output.append('ip daddr %s' % self._FormatMatch(dst_addrs))

        # Protocol
        #
        # The nft interpreter shortcuts the protocol specification if there are more
        # specific matches. At the moment, these are:
        # * source port
        # * destination port
        # * ICMP type
        if self.term.protocol and not (self.term.source_port
                                       or self.term.destination_port
                                       or self.term.icmp_type):
            output.append('ip protocol %s' %
                          self._FormatMatch(self.term.protocol))

        # Source port
        if self.term.source_port:
            output.append('%s sport %s' % (self._FormatMatch(
                self.term.protocol), self._FormatMatch(self.term.source_port)))

        # Destination port
        if self.term.destination_port:
            output.append('%s dport %s' %
                          (self._FormatMatch(self.term.protocol),
                           self._FormatMatch(self.term.destination_port)))

        # Icmp type
        if self.term.icmp_type:
            icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type,
                                                 self.term.protocol, self.af)
            if icmp_types != ['']:
                # The nft interpreter requires ICMP types to be spelled out
                icmp_name_types = self.ICMP_TYPE[self.AF_MAP[self.af]]
                icmp_type_names = dict(
                    (v, k) for k, v in six.iteritems(icmp_name_types))
                output.append('icmp type %s' % self._FormatMatch(
                    [icmp_type_names[icmp_type] for icmp_type in icmp_types]))
        # Counter
        # This does not use the value that was passed in the term.
        if self.term.counter:
            output.append('counter')

        # Log
        # Set up the logic so that only one log statement is printed.
        if self.term.logging and not self.term.log_name:
            output.append('log')
        elif (self.term.logging and self.term.log_name) or self.term.log_name:
            # nftables only supports log prefixes of up to 128 characters; truncate to
            # 126 to leave room for the suffix that is appended.
            output.append('log prefix "%s: "' % self.term.log_name[:126])

        # Action
        output.append(self._ACTIONS[self.term.action[0]])

        # Owner (implement as comment)
        if self.term.owner:
            self.term.comment.append('Owner: %s' % self.term.owner)

        # Comment
        if self.term.comment:
            comment_data = ' '.join(self.term.comment)
            # Truncate to MAX_CHARACTERS characters due to an NFTables limitation.
            if len(comment_data) > self.MAX_CHARACTERS:
                # Keep only the first MAX_CHARACTERS characters.
                comment_data = comment_data[:self.MAX_CHARACTERS]
                logging.warn(
                    'Term %s in policy is too long (>%d characters) '
                    'and will be truncated', self.term.name,
                    self.MAX_CHARACTERS)

            output.append('comment "%s"' % comment_data)

        return ' '.join(output)
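The pieces appended above are finally joined with single spaces, so a rendered term is a single nft match-and-action line. A hedged illustration (the address, port and action values are made up, not taken from any real policy) of the kind of string __str__ produces for a TCP allow term:

parts = ['ip saddr 10.0.0.0/8', 'tcp dport 22', 'counter', 'log', 'accept']
print(' '.join(parts))  # -> ip saddr 10.0.0.0/8 tcp dport 22 counter log accept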
Esempio n. 18
0
    def _TranslatePolicy(self, policy, expiration):
        """Translates policy contents to platform specific data structures.

    Args:
      policy: policy object to be translated to platform-specific data
        structures.
      expiration: integer number of weeks within which to warn about term
        expiration.

    Raises:
      InvalidTargetOption: if supplied target options are invalid.

    """
        self.tables = collections.defaultdict(list)

        for header, terms in policy.filters:
            if self._PLATFORM not in header.platforms:
                continue

            filter_options = header.FilterOptions(self._PLATFORM)

            if not filter_options:
                raise InvalidTargetOption('Chain name not specified.')

            if len(filter_options) > 4:
                raise InvalidTargetOption('Too many target options.')

            if len(filter_options) == 1:
                raise InvalidTargetOption('Must have at least hook name')

            # Chain name, mandatory
            chain_name = filter_options[0]

            # Hook name, mandatory
            hook_name = filter_options[1].lower()

            if hook_name not in self._VALID_HOOK_NAMES:
                raise InvalidTargetOption(
                    'Specified hook name (%s) is not a valid hook name.' %
                    hook_name)

            # chain priority, mandatory
            chain_priority = None
            if len(filter_options) >= 3:
                try:
                    chain_priority = str(int(filter_options[2]))
                except ValueError:
                    raise InvalidTargetOption(
                        'Specified chain priority is not an integer (%s).' %
                        filter_options[2])

            # TODO(castagno): fix this. If you don't have a hook name it never prints
            # anyway, so it's not really optional.
            if not hook_name or not chain_priority:
                logging.info(
                    'Chain %s is a non-base chain, make sure it is linked.',
                    chain_name)
                raise InvalidTargetOption('A table name is required')

            # Address family, optional, defaults to capirca inet
            af = 'inet'
            if len(filter_options) == 4:
                af = filter_options[3]
                if af not in self._VALID_ADDRESS_FAMILIES:
                    raise InvalidTargetOption(
                        'Specified address family (%s) is not supported.' % af)

            # Terms
            valid_terms = []
            for term in terms:
                term = self.FixHighPorts(term, af)
                if not term:
                    continue

                current_date = datetime.datetime.utcnow().date()
                expiration_date = current_date + datetime.timedelta(
                    weeks=expiration)

                if term.expiration:
                    if term.expiration < current_date:
                        logging.warn(
                            'Term %s in policy %s is expired and will not be rendered.',
                            term.name, chain_name)
                        continue
                    if term.expiration <= expiration_date:
                        logging.info(
                            'Term %s in policy %s expires in less than %d weeks.',
                            term.name, chain_name, expiration)

                valid_terms.append(self._TERM(term, af))

            # Add to appropriate table
            self.tables[af].append(
                (chain_name, hook_name, chain_priority, valid_terms))
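For reference, a hedged sketch of how the positional filter options parsed above map onto the chain attributes (the header string and the option values are hypothetical, e.g. something along the lines of "nftables chain1 input 0 inet"):

filter_options = ['chain1', 'input', '0', 'inet']
chain_name = filter_options[0]                 # mandatory chain name
hook_name = filter_options[1].lower()          # mandatory hook name
chain_priority = str(int(filter_options[2]))   # mandatory priority, validated as an integer
af = filter_options[3] if len(filter_options) == 4 else 'inet'  # optional address family
print(chain_name, hook_name, chain_priority, af)  # chain1 input 0 inet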
Esempio n. 19
0
def main(_):
    """Runs fine-tuning and inference.

    There are three categories of images.
    1) Images where we have previous and next frame, and that are not filtered
       out by the heuristic. For them, we will use the fine-tuned predictions.
    2) Images where we have previous and next frame, but that were filtered out
       by our heuristic. For them, we will use the ordinary prediction instead.
    3) Images where we have at least one missing adjacent frame. For them, we will
       use the ordinary prediction as indicated by triplet_list_file_remains (if
       provided). They will also not be part of the generated inference list in
       the first place.

    Raises:
       ValueError: Invalid parameters have been passed.
    """

    if FLAGS.handle_motion and FLAGS.joint_encoder:
        raise ValueError(
            'Using a joint encoder is currently not supported when '
            'modeling object motion.')
    if FLAGS.handle_motion and FLAGS.seq_length != 3:
        raise ValueError(
            'The current motion model implementation only supports '
            'using a sequence length of three.')
    if FLAGS.handle_motion and not FLAGS.compute_minimum_loss:
        raise ValueError(
            'Computing the minimum photometric loss is required when '
            'enabling object motion handling.')
    if FLAGS.size_constraint_weight > 0 and not FLAGS.handle_motion:
        raise ValueError('To enforce object size constraints, enable motion '
                         'handling.')
    if FLAGS.icp_weight > 0.0:
        raise ValueError('ICP is currently not supported.')
    if FLAGS.compute_minimum_loss and FLAGS.seq_length % 2 != 1:
        raise ValueError(
            'Compute minimum loss requires using an odd number of '
            'images in a sequence.')
    if FLAGS.compute_minimum_loss and FLAGS.exhaustive_mode:
        raise ValueError(
            'Exhaustive mode has no effect when compute_minimum_loss '
            'is enabled.')
    if FLAGS.img_width % (2**5) != 0 or FLAGS.img_height % (2**5) != 0:
        logging.warn(
            'Image size is not divisible by 2^5. For the architecture '
            'employed, this could cause artefacts due to resizing at '
            'lower dimensions.')

    if FLAGS.output_dir.endswith('/'):
        FLAGS.output_dir = FLAGS.output_dir[:-1]

    # Create file lists to prepare fine-tuning, save it to unique_file.
    unique_file_name = (str(datetime.datetime.now().date()) + '_' +
                        str(datetime.datetime.now().time()).replace(':', '_'))
    unique_file = os.path.join(FLAGS.data_dir, unique_file_name + '.txt')
    with gfile.FastGFile(FLAGS.triplet_list_file, 'r') as f:
        files_to_process = f.readlines()
        files_to_process = [line.rstrip() for line in files_to_process]
        files_to_process = [line for line in files_to_process if len(line)]
    logging.info('Creating unique file list %s with %s entries.', unique_file,
                 len(files_to_process))
    with gfile.FastGFile(unique_file, 'w') as f_out:
        fetches_network = FLAGS.num_steps * FLAGS.batch_size
        fetches_saves = FLAGS.batch_size * int(
            np.floor(FLAGS.num_steps / SAVE_EVERY))
        repetitions = fetches_network + 3 * fetches_saves
        for i in range(len(files_to_process)):
            for _ in range(repetitions):
                f_out.write(files_to_process[i] + '\n')

    # Read remaining files.
    remaining = []
    if gfile.Exists(FLAGS.triplet_list_file_remains):
        with gfile.FastGFile(FLAGS.triplet_list_file_remains, 'r') as f:
            remaining = f.readlines()
            remaining = [line.rstrip() for line in remaining]
            remaining = [line for line in remaining if len(line)]
    logging.info('Running fine-tuning on %s files, %s files are remaining.',
                 len(files_to_process), len(remaining))

    # Run fine-tuning process and save predictions in id-folders.
    tf.set_random_seed(FIXED_SEED)
    np.random.seed(FIXED_SEED)
    random.seed(FIXED_SEED)
    flipping_mode = reader.FLIP_ALWAYS if FLAGS.flip else reader.FLIP_NONE
    train_model = model.Model(
        data_dir=FLAGS.data_dir,
        file_extension=FLAGS.file_extension,
        is_training=True,
        learning_rate=FLAGS.learning_rate,
        beta1=FLAGS.beta1,
        reconstr_weight=FLAGS.reconstr_weight,
        smooth_weight=FLAGS.smooth_weight,
        ssim_weight=FLAGS.ssim_weight,
        icp_weight=FLAGS.icp_weight,
        batch_size=FLAGS.batch_size,
        img_height=FLAGS.img_height,
        img_width=FLAGS.img_width,
        seq_length=FLAGS.seq_length,
        architecture=FLAGS.architecture,
        imagenet_norm=FLAGS.imagenet_norm,
        weight_reg=FLAGS.weight_reg,
        exhaustive_mode=FLAGS.exhaustive_mode,
        random_scale_crop=FLAGS.random_scale_crop,
        flipping_mode=flipping_mode,
        random_color=False,
        depth_upsampling=FLAGS.depth_upsampling,
        depth_normalization=FLAGS.depth_normalization,
        compute_minimum_loss=FLAGS.compute_minimum_loss,
        use_skip=FLAGS.use_skip,
        joint_encoder=FLAGS.joint_encoder,
        build_sum=False,
        shuffle=False,
        input_file=unique_file_name,
        handle_motion=FLAGS.handle_motion,
        size_constraint_weight=FLAGS.size_constraint_weight,
        train_global_scale_var=False)

    failed_heuristic_ids = finetune_inference(train_model, FLAGS.model_ckpt,
                                              FLAGS.output_dir + '_ft')
    logging.info(
        'Fine-tuning completed, %s files were filtered out by '
        'heuristic.', len(failed_heuristic_ids))
    for failed_id in failed_heuristic_ids:
        failed_entry = files_to_process[failed_id]
        remaining.append(failed_entry)
    logging.info('In total, %s images were fine-tuned, while %s were not.',
                 len(files_to_process) - len(failed_heuristic_ids),
                 len(remaining))

    # Copy all results to have the same structural output as running ordinary
    # inference.
    for i in range(len(files_to_process)):
        if files_to_process[i] not in remaining:  # Use fine-tuned result.
            elements = files_to_process[i].split(' ')
            source_file = os.path.join(
                FLAGS.output_dir + '_ft', FLAGS.ft_name + 'id_' + str(i),
                str(FLAGS.num_steps).zfill(10) +
                ('_flip' if FLAGS.flip else ''))
            if len(elements) == 2:  # No differing mapping defined.
                target_dir = os.path.join(FLAGS.output_dir + '_ft',
                                          elements[0])
                target_file = os.path.join(
                    target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
            else:  # Other mapping for file defined, copy to this location instead.
                target_dir = os.path.join(FLAGS.output_dir + '_ft',
                                          os.path.dirname(elements[2]))
                target_file = os.path.join(
                    target_dir,
                    os.path.basename(elements[2]) +
                    ('_flip' if FLAGS.flip else ''))
            if not gfile.Exists(target_dir):
                gfile.MakeDirs(target_dir)
            logging.info('Copy refined result %s to %s.', source_file,
                         target_file)
            gfile.Copy(source_file + '.npy',
                       target_file + '.npy',
                       overwrite=True)
            gfile.Copy(source_file + '.txt',
                       target_file + '.txt',
                       overwrite=True)
            gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
                       target_file + '.%s' % FLAGS.file_extension,
                       overwrite=True)
    for j in range(len(remaining)):
        elements = remaining[j].split(' ')
        if len(elements) == 2:  # No differing mapping defined.
            target_dir = os.path.join(FLAGS.output_dir + '_ft', elements[0])
            target_file = os.path.join(
                target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
        else:  # Other mapping for file defined, copy to this location instead.
            target_dir = os.path.join(FLAGS.output_dir + '_ft',
                                      os.path.dirname(elements[2]))
            target_file = os.path.join(
                target_dir,
                os.path.basename(elements[2]) +
                ('_flip' if FLAGS.flip else ''))
        if not gfile.Exists(target_dir):
            gfile.MakeDirs(target_dir)
        source_file = target_file.replace('_ft', '')
        logging.info('Copy unrefined result %s to %s.', source_file,
                     target_file)
        gfile.Copy(source_file + '.npy', target_file + '.npy', overwrite=True)
        gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
                   target_file + '.%s' % FLAGS.file_extension,
                   overwrite=True)
    logging.info('Done, predictions saved in %s.', FLAGS.output_dir + '_ft')
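A quick hedged check of the repetition count used when the unique file list is written near the start of main (the flag values below are hypothetical): each triplet is repeated once per network fetch plus three times per save fetch.

import numpy as np

num_steps, batch_size, save_every = 100, 4, 25  # hypothetical values
fetches_network = num_steps * batch_size
fetches_saves = batch_size * int(np.floor(num_steps / save_every))
repetitions = fetches_network + 3 * fetches_saves
print(repetitions)  # 400 + 3 * 16 = 448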
Esempio n. 20
0
def _warn(msg):
  print(f'Warning: {msg}')
  logging.warn(msg)
Esempio n. 21
0
  def __init__(self, config, task_id, ps_tasks, num_workers, is_chief=True,
               summary_writer=None,
               dtype=tf.float32,
               summary_interval=1,
               run_number=0,
               logging_dir='/tmp', model_v=0):
    self.config = config
    self.data_manager = data.DataManager(
        config, run_number=run_number,
        do_code_simplification=not FLAGS.stop_on_success)
    self.task_id = task_id
    self.ps_tasks = ps_tasks
    self.is_chief = is_chief
    if ps_tasks == 0:
      assert task_id == 0, 'No parameter servers specified. Expecting 1 task.'
      assert num_workers == 1, (
          'No parameter servers specified. Expecting 1 worker.')
      worker_device = '/job:localhost/replica:%d/task:0/cpu:0' % task_id
      # worker_device = '/cpu:0'
      # ps_device = '/cpu:0'
    else:
      assert num_workers > 0, 'There must be at least 1 training worker.'
      worker_device = '/job:worker/replica:%d/task:0/cpu:0' % task_id
      # ps_device = '/job:ps/replica:0/task:0/cpu:0'
    logging.info('worker_device: %s', worker_device)

    logging_file = os.path.join(
        logging_dir, 'solutions_%d.txt' % task_id)
    experience_replay_file = os.path.join(
        logging_dir, 'replay_buffer_%d.pickle' % task_id)
    self.topk_file = os.path.join(
        logging_dir, 'topk_buffer_%d.pickle' % task_id)

    tf.get_variable_scope().set_use_resource(True)

    # global model
    with tf.device(tf.train.replica_device_setter(ps_tasks,
                                                  ps_device='/job:ps/replica:0',
                                                  worker_device=worker_device)):
      with tf.variable_scope('global'):
        global_model = agent_lib.LMAgent(config, dtype=dtype, is_local=False)
        global_params_dict = {p.name: p
                              for p in global_model.sync_variables}
        self.global_model = global_model
        self.global_step = make_initialized_variable(
            0, 'global_step', dtype=tf.int64)

        self.global_best_reward = make_initialized_variable(
            -10.0, 'global_best_reward', dtype=tf.float64)
        self.is_best_model = make_initialized_variable(
            False, 'is_best_model', dtype=tf.bool)
        self.reset_is_best_model = self.is_best_model.assign(False)
        self.global_best_reward_placeholder = tf.placeholder(
            tf.float64, [], name='global_best_reward_placeholder')
        self.assign_global_best_reward_op = tf.group(
            self.global_best_reward.assign(
                self.global_best_reward_placeholder),
            self.is_best_model.assign(True))
        def assign_global_best_reward_fn(session, reward):
          reward = round(reward, 10)
          best_reward = round(session.run(self.global_best_reward), 10)
          is_best = reward > best_reward
          if is_best:
            session.run(self.assign_global_best_reward_op,
                        {self.global_best_reward_placeholder: reward})
          return is_best
        self.assign_global_best_reward_fn = assign_global_best_reward_fn

        # Any worker will set this to true when it finds a solution.
        self.found_solution_flag = make_initialized_variable(
            False, 'found_solution_flag', dtype=tf.bool)
        self.found_solution_op = self.found_solution_flag.assign(True)

        self.run_number = make_initialized_variable(
            run_number, 'run_number', dtype=tf.int32)

        # Store a solution when found.
        self.code_solution_variable = tf.get_variable(
            'code_solution', [], tf.string,
            initializer=tf.constant_initializer(''))
        self.code_solution_ph = tf.placeholder(
            tf.string, [], name='code_solution_ph')
        self.code_solution_assign_op = self.code_solution_variable.assign(
            self.code_solution_ph)
        def assign_code_solution_fn(session, code_solution_string):
          session.run(self.code_solution_assign_op,
                      {self.code_solution_ph: code_solution_string})
        self.assign_code_solution_fn = assign_code_solution_fn

        # Count all programs sampled from policy. This does not include
        # programs sampled from replay buffer.
        # This equals NPE (number of programs executed). Only programs sampled
        # from the policy need to be executed.
        self.program_count = make_initialized_variable(
            0, 'program_count', dtype=tf.int64)

    # local model
    with tf.device(worker_device):
      with tf.variable_scope('local'):
        self.model = model = agent_lib.LMAgent(
            config,
            task_id=task_id,
            logging_file=logging_file,
            experience_replay_file=experience_replay_file,
            dtype=dtype,
            global_best_reward_fn=self.assign_global_best_reward_fn,
            found_solution_op=self.found_solution_op,
            assign_code_solution_fn=self.assign_code_solution_fn,
            program_count=self.program_count,
            stop_on_success=FLAGS.stop_on_success,
            verbose_level=model_v)
        local_params = model.trainable_variables
        local_params_dict = {p.name: p for p in local_params}

    # Pull global params to local model.
    def _global_to_local_scope(name):
      assert name.startswith('global/')
      return 'local' + name[6:]
    sync_dict = {
        local_params_dict[_global_to_local_scope(p_name)]: p
        for p_name, p in global_params_dict.items()}
    self.sync_op = tf.group(*[v_local.assign(v_global)
                              for v_local, v_global
                              in sync_dict.items()])

    # Pair local gradients with global params.
    grad_var_dict = {
        gradient: sync_dict[local_var]
        for local_var, gradient in model.gradients_dict.items()}

    # local model
    model.make_summary_ops()  # Don't put summaries under 'local' scope.
    with tf.variable_scope('local'):
      self.train_op = model.optimizer.apply_gradients(
          grad_var_dict.items(), global_step=self.global_step)
      self.local_init_op = tf.variables_initializer(
          tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                            tf.get_variable_scope().name))

    self.local_step = 0
    self.last_summary_time = time.time()
    self.summary_interval = summary_interval
    self.summary_writer = summary_writer
    self.cached_global_step = -1
    self.cached_global_npe = -1

    logging.info('summary_interval: %d', self.summary_interval)

    # Load top-k buffer.
    if self.model.top_episodes is not None and tf.gfile.Exists(self.topk_file):
      try:
        with tf.gfile.FastGFile(self.topk_file, 'r') as f:
          self.model.top_episodes = cPickle.loads(f.read())
        logging.info(
            'Loaded top-k buffer from disk with %d items. Location: "%s"',
            len(self.model.top_episodes), self.topk_file)
      except (cPickle.UnpicklingError, EOFError) as e:
        logging.warn(
            'Failed to load existing top-k buffer from disk. Removing bad file.'
            '\nLocation: "%s"\nException: %s', self.topk_file, str(e))
        tf.gfile.Remove(self.topk_file)
Esempio n. 22
0
def trim_and_pack_dataset(
    dataset: tf.data.Dataset,
    feature_lengths: Mapping[str, int],
    use_custom_ops: bool = False
) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Modified from the tensor2tensor library.

  This is meant to replace the irritation of having to create a separate
  "packed" version of a dataset to train efficiently on TPU.

  Each example in the output dataset represents several examples in the
  input dataset.

  For each key in the input dataset that also exists in `feature_lengths`, two
  additional keys are created:
    <key>_segment_id: an int32 tensor identifying the parts
       representing the original example.
    <key>_position: an int32 tensor identifying the position within the original
       example.

  Features that are not in `feature_lengths` will be removed.

  Example:
    Two input examples get combined to form an output example.
    The input examples are:
    {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0], "idx": 0}
    {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1], "idx": 1}
    The output example is:
    {
                   "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
      "inputs_segment_id": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
          "inputs_position": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
                  "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
     "targets_segment_id": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
         "targets_position": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
    }

    0 represents padding in both the inputs and the outputs.

    Sequences in the incoming examples are truncated to length in
    `feature_lengths`, and the sequences in the output examples all have this
    fixed (padded) length. Features not in `feature_lengths` (e.g., "idx") are
    removed.

  Args:
    dataset: a tf.data.Dataset
    feature_lengths: map from feature key to final length. Other features will
      be discarded.
    use_custom_ops: a boolean - custom ops are faster but require a custom-built
      binary, which is not currently possible on cloud-tpu.

  Returns:
    a tf.data.Dataset
  """
  element_spec = dataset.element_spec
  # Make sure that the dataset contains all keys in `feature_lengths`.
  for k in feature_lengths:
    if k not in element_spec:
      raise ValueError(
          f"Feature '{k}' not found in dataset. Available keys are "
          f"{list(element_spec.keys())}")
    if not element_spec[k].shape.is_compatible_with(tf.TensorShape([None])):
      raise ValueError(
          f"Features to be packed must be one-dimensional. '{k}' is not.'")
  # Warn if there are any additional keys that will be removed.
  additional_keys = set(element_spec) - set(feature_lengths)
  if additional_keys:
    logging.warn(
        "Features not in `features_length` will be removed during packing: %s",
        additional_keys)

  ds = dataset.map(
      lambda x: {k: x[k][:l] for k, l in feature_lengths.items()},
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >=1) are sufficient to fill at least one packed example.
  batch_size = max(feature_lengths.values())
  ds = ds.padded_batch(
      batch_size, padded_shapes={k: [-1] for k in feature_lengths})

  if use_custom_ops and len(feature_lengths) <= 2:
    ds = _pack_with_custom_ops(ds, feature_lengths)
  else:
    ds = _pack_with_tf_ops(ds, feature_lengths)

  # Set the Tensor shapes correctly since they get lost in the process.
  def _set_shape(x):
    for k, v in x.items():
      v.set_shape([feature_lengths[_strip_packed_feature_key(k)]])
    return x
  return ds.map(_set_shape, num_parallel_calls=tf.data.experimental.AUTOTUNE)
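A hedged usage sketch of the function above, packing the two toy examples from the docstring (the data and lengths are illustrative only, and the helpers it relies on, such as _pack_with_tf_ops, are assumed to be available in the same module):

import tensorflow as tf

toy = tf.data.Dataset.from_generator(
    lambda: iter([{'inputs': [8, 7, 1], 'targets': [4, 1]},
                  {'inputs': [2, 3, 4, 1], 'targets': [5, 6, 1]}]),
    output_signature={'inputs': tf.TensorSpec([None], tf.int32),
                      'targets': tf.TensorSpec([None], tf.int32)})
packed = trim_and_pack_dataset(toy, feature_lengths={'inputs': 10, 'targets': 10})
for example in packed:
  print({k: v.numpy().tolist() for k, v in example.items()})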
Esempio n. 23
0
def get_paths_to_events(root_dir,
                        restrict_to_architectures,
                        restrict_to_pretrained_source,
                        restrict_to_variants=None):
    """Returns a dict that maps each variant name to its event file.

  The name of the variant is the basename of the directory where it's stored.
  Assumes the following directory organization: root_dir contains a sub-directory
  for every variant, in which event files can be found.

  There may be more than one event file for each variant, e.g. a new one will be
  created upon restarting an experiment that was pre-empted. So later event
  files contain the summaries for larger values of 'step'. We need all of them
  for determining the global 'best'.

  Args:
    root_dir: A str. The root directory of experiments of all models variants.
    restrict_to_architectures: A list of names of architectures to restrict to
      when choosing the best variant.
    restrict_to_pretrained_source: A string. The pretrained_source to restrict
      to when choosing the best variant.
    restrict_to_variants: Optionally, a set of variant names to restrict to.
  """
    params_dir = os.path.join(root_dir, 'params')
    summary_dir = os.path.join(root_dir, 'summaries')

    def get_variant_architecture(name):
        """Get the architecture of the given variant if it's recorded, o/w None."""
        variant_params_dir = os.path.join(params_dir, name)
        variant_params_file = os.path.join(variant_params_dir, 'params.pkl')
        return get_value_from_pkl(variant_params_file,
                                  '_gin.LearnerConfig.embedding_network')

    def get_variant_pretrained_source(name):
        variant_params_dir = os.path.join(params_dir, name)
        variant_params_file = os.path.join(variant_params_dir, 'params.pkl')
        return get_value_from_pkl(variant_params_file,
                                  '_gin.LearnerConfig.pretrained_source')

    def keep_variant(name):
        """Determine if the variant in directory name should be considered."""
        value_error_msg = (
            'Requested to restrict to an architecture or '
            'pretrained_source but the given experiment does not '
            'have its params recorded. Looked in: {}'.format(params_dir))

        if restrict_to_architectures:
            architecture = get_variant_architecture(name)
            if architecture is None:
                raise ValueError(value_error_msg)
        valid_architecture = (not restrict_to_architectures
                              or architecture in restrict_to_architectures)

        if restrict_to_pretrained_source:
            pretrained_source = get_variant_pretrained_source(name)
            if pretrained_source is None:
                raise ValueError(value_error_msg)
        valid_pretrained_source = (not restrict_to_pretrained_source
                                   or pretrained_source
                                   == restrict_to_pretrained_source)

        valid_variant_name = True
        if restrict_to_variants is not None:
            valid_variant_name = name in restrict_to_variants

        return (valid_architecture and valid_pretrained_source
                and valid_variant_name)

    variant_names = [
        fname for fname in tf.io.gfile.listdir(summary_dir)
        if tf.io.gfile.isdir(os.path.join(summary_dir, fname))
    ]

    # Further filter variant names based on the given restrictions.
    variant_names = [name for name in variant_names if keep_variant(name)]

    if not variant_names:
        raise ValueError('Found no subdirectories in {}. Was expecting a '
                         'subdirectory per variant.'.format(summary_dir))
    variant_paths = [
        os.path.join(summary_dir, variant_dir) for variant_dir in variant_names
    ]

    event_paths = {}
    for variant_path, variant_name in zip(variant_paths, variant_names):
        event_filenames = [
            f_name for f_name in tf.io.gfile.listdir(variant_path)
            if f_name.startswith('events.out.tfevents')
        ]

        if len(event_filenames) < 1:
            logging.warn('Skipping empty variant %s.', variant_path)
            logging.info(
                'Was expecting at least one event file '
                'in directory %s. Instead, found %d.', variant_path,
                len(event_filenames))
            continue
        event_paths[variant_name] = [
            os.path.join(variant_path, event_filename)
            for event_filename in event_filenames
        ]

    logging.info('Found event files for variants: %s',
                 list(event_paths.keys()))
    return event_paths
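A hedged usage sketch of the function above (the directory, architecture name and pretrained source are hypothetical):

event_paths = get_paths_to_events(
    root_dir='/tmp/experiments/all_variants',
    restrict_to_architectures=['resnet'],
    restrict_to_pretrained_source='imagenet')
for variant_name, paths in event_paths.items():
  print(variant_name, len(paths), 'event file(s)')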
Esempio n. 24
0
  def build_ensemble_spec(self,
                          name,
                          candidate,
                          ensembler,
                          subnetwork_specs,
                          summary,
                          features,
                          mode,
                          iteration_number,
                          labels=None,
                          previous_ensemble_spec=None):
    """Builds an `_EnsembleSpec` with the given `adanet.ensemble.Candidate`.

    Args:
      name: The string name of the ensemble. Typically the name of the builder
        that returned the given `Subnetwork`.
      candidate: The `adanet.ensemble.Candidate` for this spec.
      ensembler: The :class:`adanet.ensemble.Ensembler` to use to ensemble a
        group of subnetworks.
      subnetwork_specs: Iterable of `_SubnetworkSpecs` for this iteration.
      summary: A `_ScopedSummary` instance for recording ensemble summaries.
      features: Input `dict` of `Tensor` objects.
      mode: Estimator `ModeKeys` indicating training, evaluation, or inference.
      iteration_number: Integer current iteration number.
      labels: Labels `Tensor` or a dictionary of string label name to `Tensor`
        (for multi-head).
      previous_ensemble_spec: Link the rest of the `_EnsembleSpec` from
        iteration t-1. Used for creating the subnetwork train_op.

    Returns:
      An `_EnsembleSpec` instance.
    """

    with tf_compat.v1.variable_scope("ensemble_{}".format(name)):
      step = tf_compat.v1.get_variable(
          "step",
          shape=[],
          initializer=tf_compat.v1.zeros_initializer(),
          trainable=False,
          dtype=tf.int64)
      # Convert to tensor so that users cannot mutate it.
      step_tensor = tf.convert_to_tensor(value=step)
      with summary.current_scope():
        summary.scalar("iteration_step/adanet/iteration_step", step_tensor)
      architecture = _Architecture(candidate.name, ensembler.name)
      previous_subnetworks = []
      subnetwork_builders = []
      previous_ensemble = None
      if previous_ensemble_spec:
        previous_ensemble = previous_ensemble_spec.ensemble
        previous_architecture = previous_ensemble_spec.architecture
        keep_indices = range(len(previous_ensemble.subnetworks))
        if len(candidate.subnetwork_builders) == 1 and previous_ensemble:
          # Prune previous ensemble according to the subnetwork.Builder for
          # backwards compatibility.
          subnetwork_builder = candidate.subnetwork_builders[0]
          prune_previous_ensemble = getattr(subnetwork_builder,
                                            "prune_previous_ensemble", None)
          if callable(prune_previous_ensemble):
            logging.warn(
                "Using an `adanet.subnetwork.Builder#prune_previous_ensemble` "
                "is deprecated. Please use a custom `adanet.ensemble.Strategy` "
                "instead.")
            keep_indices = prune_previous_ensemble(previous_ensemble)
        for i, builder in enumerate(previous_ensemble_spec.subnetwork_builders):
          if i not in keep_indices:
            continue
          if builder not in candidate.previous_ensemble_subnetwork_builders:
            continue
          previous_subnetworks.append(previous_ensemble.subnetworks[i])
          subnetwork_builders.append(builder)
          architecture.add_subnetwork(*previous_architecture.subnetworks[i])
      for builder in candidate.subnetwork_builders:
        architecture.add_subnetwork(iteration_number, builder.name)
        subnetwork_builders.append(builder)
      subnetwork_map = {s.builder.name: s.subnetwork for s in subnetwork_specs}
      subnetworks = [
          subnetwork_map[s.name] for s in candidate.subnetwork_builders
      ]
      ensemble_scope = tf_compat.v1.get_variable_scope()
      before_var_list = tf_compat.v1.trainable_variables()
      with summary.current_scope(), _monkey_patch_context(
          iteration_step_scope=ensemble_scope,
          scoped_summary=summary,
          trainable_vars=[]):
        ensemble = ensembler.build_ensemble(
            subnetworks,
            previous_ensemble_subnetworks=previous_subnetworks,
            features=features,
            labels=labels,
            logits_dimension=self._head.logits_dimension,
            training=mode == tf.estimator.ModeKeys.TRAIN,
            iteration_step=step_tensor,
            summary=summary,
            previous_ensemble=previous_ensemble)
      ensemble_var_list = _new_trainable_variables(before_var_list)

      estimator_spec = _create_estimator_spec(self._head, features, labels,
                                              mode, ensemble.logits,
                                              self._use_tpu)

      ensemble_loss = estimator_spec.loss
      adanet_loss = None
      if mode != tf.estimator.ModeKeys.PREDICT:
        # TODO: Support any kind of Ensemble. Use a moving average of
        # their train loss for the 'adanet_loss'.
        if not isinstance(ensemble, ComplexityRegularized):
          raise ValueError(
              "Only ComplexityRegularized ensembles are supported.")
        adanet_loss = estimator_spec.loss + ensemble.complexity_regularization

      ensemble_metrics = _EnsembleMetrics(use_tpu=self._use_tpu)
      if mode == tf.estimator.ModeKeys.EVAL:
        ensemble_metrics.create_eval_metrics(
            features=features,
            labels=labels,
            estimator_spec=estimator_spec,
            metric_fn=self._metric_fn,
            architecture=architecture)

      if mode == tf.estimator.ModeKeys.TRAIN:
        with summary.current_scope():
          summary.scalar("loss", estimator_spec.loss)

      # Create train ops for training subnetworks and ensembles.
      train_op = None
      if mode == tf.estimator.ModeKeys.TRAIN:
        # Note that these mixture weights are on top of the last_layer of the
        # subnetwork constructed in TRAIN mode, which means that dropout is
        # still applied when the mixture weights are being trained.
        ensemble_scope = tf_compat.v1.get_variable_scope()
        with tf_compat.v1.variable_scope("train_mixture_weights"):
          with summary.current_scope(), _monkey_patch_context(
              iteration_step_scope=ensemble_scope,
              scoped_summary=summary,
              trainable_vars=ensemble_var_list):
            # For backwards compatibility.
            subnetwork_builder = candidate.subnetwork_builders[0]
            old_train_op_fn = getattr(subnetwork_builder,
                                      "build_mixture_weights_train_op", None)
            if callable(old_train_op_fn):
              logging.warn(
                  "The `build_mixture_weights_train_op` method is deprecated. "
                  "Please use the `Ensembler#build_train_op` instead.")
              train_op = _to_train_op_spec(
                  subnetwork_builder.build_mixture_weights_train_op(
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      logits=ensemble.logits,
                      labels=labels,
                      iteration_step=step_tensor,
                      summary=summary))
            else:
              train_op = _to_train_op_spec(
                  ensembler.build_train_op(
                      ensemble=ensemble,
                      loss=adanet_loss,
                      var_list=ensemble_var_list,
                      labels=labels,
                      iteration_step=step_tensor,
                      summary=summary,
                      previous_ensemble=previous_ensemble))
    return _EnsembleSpec(
        name=name,
        architecture=architecture,
        subnetwork_builders=subnetwork_builders,
        ensemble=ensemble,
        predictions=estimator_spec.predictions,
        step=step,
        loss=ensemble_loss,
        adanet_loss=adanet_loss,
        train_op=train_op,
        eval_metrics=ensemble_metrics.eval_metrics_tuple(),
        export_outputs=estimator_spec.export_outputs)
Esempio n. 25
0
  def _TranslatePolicy(self, pol, exp_info):
    self.cisco_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    # a mixed filter outputs both ipv4 and ipv6 acls in the same output file
    good_filters = ['extended', 'standard', 'object-group', 'inet6',
                    'mixed', 'enable_dsmo']

    for header, terms in pol.filters:
      if self._PLATFORM not in header.platforms:
        continue

      obj_target = ObjectGroup()

      filter_options = header.FilterOptions(self._PLATFORM)
      filter_name = header.FilterName(self._PLATFORM)

      # extended is the most common filter type.
      filter_type = 'extended'
      if len(filter_options) > 1:
        filter_type = filter_options[1]

      # check if filter type is renderable
      if filter_type not in good_filters:
        raise UnsupportedCiscoAccessListError(
            'access list type %s not supported by %s (good types: %s)' % (
                filter_type, self._PLATFORM, str(good_filters)))

      filter_list = [filter_type]
      if filter_type == 'mixed':
        # Loop through filter and generate output for inet and inet6 in sequence
        filter_list = ['extended', 'inet6']

      for next_filter in filter_list:
        # Numeric access lists can be extended or standard, but have specific
        # known ranges.
        if next_filter == 'extended' and filter_name.isdigit():
          if int(filter_name) in range(1, 100) + range(1300, 2000):
            raise UnsupportedCiscoAccessListError(
                'Access lists between 1-99 and 1300-1999 are reserved for '
                'standard ACLs')
        if next_filter == 'standard' and filter_name.isdigit():
          if int(filter_name) not in range(1, 100) + range(1300, 2000):
            raise UnsupportedCiscoAccessListError(
                'Standard access lists must be numeric in the range of 1-99'
                ' or 1300-1999.')

        term_dup_check = set()
        new_terms = []
        for term in terms:
          if term.name in term_dup_check:
            raise CiscoDuplicateTermError('You have a duplicate term: %s' %
                                          term.name)
          term_dup_check.add(term.name)

          term.name = self.FixTermLength(term.name)
          af = 'inet'
          if next_filter == 'inet6':
            af = 'inet6'
          term = self.FixHighPorts(term, af=af)
          if not term:
            continue

          if term.expiration:
            if term.expiration <= exp_info_date:
              logging.info('INFO: Term %s in policy %s expires '
                           'in less than two weeks.', term.name, filter_name)
            if term.expiration <= current_date:
              logging.warn('WARNING: Term %s in policy %s is expired and '
                           'will not be rendered.', term.name, filter_name)
              continue

          # render terms based on filter type
          if next_filter == 'standard':
            # keep track of sequence numbers across terms
            new_terms.append(TermStandard(term, filter_name, self._PLATFORM))
          elif next_filter == 'extended':
            enable_dsmo = (len(filter_options) > 2 and
                           filter_options[2] == 'enable_dsmo')
            new_terms.append(
                Term(term, proto_int=self._PROTO_INT, enable_dsmo=enable_dsmo,
                     term_remark=self._TERM_REMARK, platform=self._PLATFORM))
          elif next_filter == 'object-group':
            obj_target.AddTerm(term)
            self._SetObjectGroupProtos(ObjectGroupTerm)
            obj_group_term = ObjectGroupTerm(term, filter_name)
            new_terms.append(obj_group_term)
          elif next_filter == 'inet6':
            new_terms.append(
                Term(
                    term, 6, proto_int=self._PROTO_INT,
                    platform=self._PLATFORM))

        # cisco requires different name for the v4 and v6 acls
        if filter_type == 'mixed' and next_filter == 'inet6':
          filter_name = 'ipv6-%s' % filter_name
        self.cisco_policies.append((header, filter_name, [next_filter],
                                    new_terms, obj_target))
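Note that range(1, 100) + range(1300, 2000) only concatenates under Python 2, where range returns a list. A hedged Python 3 rewrite of the same reserved-number check (the helper name is hypothetical):

def _is_reserved_standard_acl_number(number):
  """True for numeric ACL names reserved for standard ACLs (1-99, 1300-1999)."""
  return number in range(1, 100) or number in range(1300, 2000)

assert _is_reserved_standard_acl_number(42)
assert not _is_reserved_standard_acl_number(150)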
Esempio n. 26
0
    def _TranslatePolicy(self, pol, exp_info):
        # pylint: disable=attribute-defined-outside-init
        """Transform a policy object into a JuniperSRX object.

    Args:
      pol: policy.Policy object
      exp_info: print an info message when a term is set to expire
                in that many weeks

    Raises:
      UnsupportedFilterError: An unsupported filter was specified
      UnsupportedHeader: A header option exists that is not understood/usable
      SRXDuplicateTermError: Two terms were found with same name in same filter
      ConflictingTargetOptions: Two target options are conflicting in the header
      MixedAddrBookTypes: Global and Zone address books in the same policy
      ConflictingApplicationSets: When two duplicate named terms have
                                  conflicting application entries
    """
        self.srx_policies = []
        self.addressbook = collections.OrderedDict()
        self.applications = []
        self.ports = []
        self.from_zone = ''
        self.to_zone = ''
        self.addr_book_type = set()

        current_date = datetime.datetime.utcnow().date()
        exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

        for header, terms in pol.filters:
            if self._PLATFORM not in header.platforms:
                continue

            filter_options = header.FilterOptions(self._PLATFORM)

            if (len(filter_options) < 4 or filter_options[0] != 'from-zone'
                    or filter_options[2] != 'to-zone'):
                raise UnsupportedFilterError(
                    'SRX filter arguments must specify '
                    'from-zone and to-zone.')

            # check if to-zone is not a supported target option
            if filter_options[1] in self._SUPPORTED_TARGET_OPTIONS:
                raise UnsupportedFilterError(
                    'to-zone %s cannot be the same as any '
                    'valid SRX target-options' % (filter_options[1]))
            else:
                self.from_zone = filter_options[1]

            # check if from-zone is not a supported target option
            if filter_options[3] in self._SUPPORTED_TARGET_OPTIONS:
                raise UnsupportedFilterError(
                    'from-zone %s cannot be the same as any '
                    'valid SRX target-options' % (filter_options[3]))
            else:
                self.to_zone = filter_options[3]

            # variables used to collect target-options and set defaults
            target_options = []
            filter_type = ''

            # parse srx target options
            extra_options = filter_options[4:]
            if self._SUPPORTED_TARGET_OPTIONS.issubset(extra_options):
                raise ConflictingTargetOptions(
                    'only one address-book-type can '
                    'be specified per header "%s"' % ' '.join(filter_options))
            else:
                address_book_type = self._SUPPORTED_TARGET_OPTIONS.intersection(
                    extra_options)
                if len(address_book_type) == 0:
                    address_book_type = {self._GLOBAL_ADDR_BOOK}
                self.addr_book_type.update(address_book_type)
                if len(self.addr_book_type) > 1:
                    raise MixedAddrBookTypes(
                        'Global and Zone address-book-types cannot '
                        'be used in the same policy')
                if self.from_zone == 'all' and self.to_zone == 'all':
                    if self._ZONE_ADDR_BOOK in self.addr_book_type:
                        raise UnsupportedFilterError(
                            'Zone address books cannot be used '
                            'with a global policy.')
                elif self.from_zone == 'all' or self.to_zone == 'all':
                    raise UnsupportedFilterError(
                        'The zone name all is reserved for '
                        'global policies.')

            for filter_opt in filter_options[4:]:

                # validate address families
                if filter_opt in self._SUPPORTED_AF:
                    if not filter_type:
                        filter_type = filter_opt
                    else:
                        raise ConflictingTargetOptions(
                            'only one address family can be '
                            'specified per header "%s"' %
                            ' '.join(filter_options))

                elif filter_opt in self._SUPPORTED_TARGET_OPTIONS:
                    target_options.append(filter_opt)

                else:
                    raise UnsupportedHeader(
                        'SRX Generator currently does not support '
                        '%s as a header option "%s"' %
                        (filter_opt, ' '.join(filter_options)))

            # if address-family and address-book-type have not been set then default
            if not filter_type:
                filter_type = 'mixed'

            term_dup_check = set()
            new_terms = []
            self._FixLargePolices(terms, filter_type)
            for term in terms:
                term.name = self.FixTermLength(term.name)
                if term.name in term_dup_check:
                    raise SRXDuplicateTermError(
                        'You have a duplicate term: %s' % term.name)
                term_dup_check.add(term.name)

                if term.expiration:
                    if term.expiration <= exp_info_date:
                        logging.info(
                            'INFO: Term %s in policy %s>%s expires '
                            'in less than two weeks.', term.name,
                            self.from_zone, self.to_zone)
                    if term.expiration <= current_date:
                        logging.warn(
                            'WARNING: Term %s in policy %s>%s is expired.',
                            term.name, self.from_zone, self.to_zone)
                        continue

                # SRX address books leverage network token names for IPs.
                # When excluding addresses, we lose those distinct names so we need
                # to create a new unique name based off the term name before excluding.
                if term.source_address_exclude:
                    # If we have a naked source_exclude, we need something to exclude from
                    if not term.source_address:
                        term.source_address = [
                            nacaddr.IP('0.0.0.0/0', term.name.upper(),
                                       term.name.upper())
                        ]
                    # Use the term name as the token & parent_token
                    new_src_parent_token = term.name.upper() + '_SRC_EXCLUDE'
                    new_src_token = new_src_parent_token
                    for i in term.source_address_exclude:
                        term.source_address = nacaddr.RemoveAddressFromList(
                            term.source_address, i)
                        for i in term.source_address:
                            i.token = new_src_token
                            i.parent_token = new_src_parent_token

                if term.destination_address_exclude:
                    if not term.destination_address:
                        term.destination_address = [
                            nacaddr.IP('0.0.0.0/0', term.name.upper(),
                                       term.name.upper())
                        ]
                    new_dst_parent_token = term.name.upper() + '_DST_EXCLUDE'
                    new_dst_token = new_dst_parent_token
                    for i in term.destination_address_exclude:
                        term.destination_address = nacaddr.RemoveAddressFromList(
                            term.destination_address, i)
                        for i in term.destination_address:
                            i.token = new_dst_token
                            i.parent_token = new_dst_parent_token

                # SRX policies are controlled by the addresses used within them, so a
                # policy can be inet and inet6 at the same time.
                if self._GLOBAL_ADDR_BOOK in self.addr_book_type:
                    for zone in self.addressbook:
                        for unused_name, ips in sorted(
                                self.addressbook[zone].iteritems()):
                            ips = [i for i in ips]
                            if term.source_address == ips:
                                term.source_address = ips
                            if term.destination_address == ips:
                                term.destination_address = ips
                for addr in term.source_address:
                    if addr.version in self._AF_MAP[filter_type]:
                        self._BuildAddressBook(self.from_zone, addr)
                for addr in term.destination_address:
                    if addr.version in self._AF_MAP[filter_type]:
                        self._BuildAddressBook(self.to_zone, addr)

                new_term = Term(term, filter_options)
                new_terms.append(new_term)

                # Because SRX terms can contain both inet and inet6 addresses, we need
                # the ability to recover the proper AF for the ICMP type.
                # If protocol is empty, or cannot be mapped to inet or inet6, we insert a
                # bogus af_type name, which will cause new_term.NormalizeIcmpTypes to fail.
                if not term.protocol:
                    icmp_af_type = 'unknown_af_icmp'
                else:
                    icmp_af_type = self._AF_ICMP_MAP.get(
                        term.protocol[0], 'unknown_af_icmp')
                tmp_icmptype = new_term.NormalizeIcmpTypes(
                    term.icmp_type, term.protocol, icmp_af_type)
                # NormalizeIcmpTypes returns [''] for empty, convert to [] for eval
                normalized_icmptype = tmp_icmptype if tmp_icmptype != [
                    ''
                ] else []
                # rewrites the protocol icmpv6 to icmp6
                if 'icmpv6' in term.protocol:
                    protocol = list(term.protocol)
                    protocol[protocol.index('icmpv6')] = 'icmp6'
                else:
                    protocol = term.protocol
                new_application_set = {
                    'sport': self._BuildPort(term.source_port),
                    'dport': self._BuildPort(term.destination_port),
                    'protocol': protocol,
                    'icmp-type': normalized_icmptype,
                    'timeout': term.timeout
                }

                for application_set in self.applications:
                    if all(item in application_set.items()
                           for item in new_application_set.items()):
                        new_application_set = ''
                        term.replacement_application_name = application_set[
                            'name']
                        break
                    if (term.name == application_set['name']
                            and new_application_set != application_set):
                        raise ConflictingApplicationSets(
                            'Application set %s has a conflicting entry' %
                            term.name)

                if new_application_set:
                    new_application_set['name'] = term.name
                    self.applications.append(new_application_set)

            self.srx_policies.append((header, new_terms, filter_options))
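A hedged illustration of the duplicate application-set check above: a candidate set is treated as a duplicate when every one of its items already appears in an existing set (the port and protocol values below are made up):

existing = {'name': 'ssh-app', 'sport': None, 'dport': [(22, 22)],
            'protocol': ['tcp'], 'icmp-type': [], 'timeout': None}
candidate = {'sport': None, 'dport': [(22, 22)], 'protocol': ['tcp'],
             'icmp-type': [], 'timeout': None}
is_duplicate = all(item in existing.items() for item in candidate.items())
print(is_duplicate)  # True -> the existing set's name would be reused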
Esempio n. 27
0
  def _FixLargePolices(self, terms, address_family):
    """Loops over all terms finding terms exceeding SRXs policy limit.

    Args:
      terms: List of terms from a policy.
      address_family: Tuple containing address family versions.

    See the following URL for more information
    http://www.juniper.net/techpubs/en_US/junos12.1x44/topics/reference/
    general/address-address-sets-limitations.html
    """

    def Chunks(l):
      """Splits a list of IP addresses into smaller lists based on byte size."""
      return_list = [[]]
      counter = 0
      index = 0
      for i in l:
        # Size is split in half due to the max size being a sum of src and dst.
        if counter > (self._ADDRESS_LENGTH_LIMIT/2):
          counter = 0
          index += 1
          return_list.append([])
        if i.version == 6:
          counter += self._IPV6_SIZE
        else:
          counter += 1
        return_list[index].append(i)
      return return_list

    expanded_terms = []
    for term in terms:
      if (term.AddressesByteLength(
          self._AF_MAP[address_family]) > self._ADDRESS_LENGTH_LIMIT):
        logging.warn('LARGE TERM ENCOUNTERED')
        src_chunks = Chunks(term.source_address)
        counter = 0
        for chunk in src_chunks:
          for ip in chunk:
            ip.parent_token = 'src_' + term.name + str(counter)
          counter += 1
        dst_chunks = Chunks(term.destination_address)
        counter = 0
        for chunk in dst_chunks:
          for ip in chunk:
            ip.parent_token = 'dst_' + term.name + str(counter)
          counter += 1

        src_dst_products = itertools.product(src_chunks, dst_chunks)
        counter = 0
        for src_dst_list in src_dst_products:
          new_term = copy.copy(term)
          new_term.source_address = src_dst_list[0]
          new_term.destination_address = src_dst_list[1]
          new_term.name = new_term.name + '_' + str(counter)
          expanded_terms.append(new_term)
          counter += 1
      else:
        expanded_terms.append(term)
    if expanded_terms:
      del terms[:]
      terms.extend(expanded_terms)
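A self-contained toy version of the Chunks helper above; the limit and the per-address weight are made-up stand-ins for the class constants (_ADDRESS_LENGTH_LIMIT, _IPV6_SIZE).

import ipaddress

ADDRESS_LENGTH_LIMIT = 8  # hypothetical; the real limit is a platform constant
IPV6_SIZE = 4             # hypothetical weight of one IPv6 address

def chunks(addresses, limit=ADDRESS_LENGTH_LIMIT, v6_size=IPV6_SIZE):
  """Splits addresses into lists whose weighted size stays under limit / 2."""
  result, counter = [[]], 0
  for addr in addresses:
    if counter > limit / 2:  # half the limit: src and dst share the budget
      result.append([])
      counter = 0
    counter += v6_size if addr.version == 6 else 1
    result[-1].append(addr)
  return result

v4_addrs = [ipaddress.ip_address('10.0.0.%d' % i) for i in range(6)]
print([len(c) for c in chunks(v4_addrs)])  # [5, 1] with the toy limit above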
Esempio n. 28
0
    def _FixLargePolices(self, terms, address_family):
        """Loops over all terms finding terms exceeding SRXs policy limit.

    Args:
      terms: List of terms from a policy.
      address_family: Tuple containing address family versions.

    See the following URL for more information
    http://www.juniper.net/techpubs/en_US/junos12.1x44/topics/reference/
    general/address-address-sets-limitations.html
    """
        def Chunks(l):
            """Splits a list of IP addresses into smaller lists based on byte size."""
            return_list = [[]]
            counter = 0
            index = 0
            for i in l:
                # Size is split in half due to the max size being a sum of src and dst.
                if counter > (self._ADDRESS_LENGTH_LIMIT / 2):
                    counter = 0
                    index += 1
                    return_list.append([])
                if i.version == 6:
                    counter += self._IPV6_SIZE
                else:
                    counter += 1
                return_list[index].append(i)
            return return_list

        expanded_terms = []
        for term in terms:
            if (term.AddressesByteLength(self._AF_MAP[address_family]) >
                    self._ADDRESS_LENGTH_LIMIT):
                logging.warn('LARGE TERM ENCOUNTERED')
                src_chunks = Chunks(term.source_address)
                counter = 0
                for chunk in src_chunks:
                    for ip in chunk:
                        ip.parent_token = 'src_' + term.name + str(counter)
                    counter += 1
                dst_chunks = Chunks(term.destination_address)
                counter = 0
                for chunk in dst_chunks:
                    for ip in chunk:
                        ip.parent_token = 'dst_' + term.name + str(counter)
                    counter += 1

                src_dst_products = itertools.product(src_chunks, dst_chunks)
                counter = 0
                for src_dst_list in src_dst_products:
                    new_term = copy.copy(term)
                    new_term.source_address = src_dst_list[0]
                    new_term.destination_address = src_dst_list[1]
                    new_term.name = new_term.name + '_' + str(counter)
                    expanded_terms.append(new_term)
                    counter += 1
            else:
                expanded_terms.append(term)
        if expanded_terms:
            del terms[:]
            terms.extend(expanded_terms)
Esempio n. 29
0
  def __str__(self):
    output = []

    # Don't render term if not in platforms or in excluded platforms
    if self.term.platform and self._PLATFORM not in self.term.platform:
      return ''
    if (self.term.platform_exclude and
        self._PLATFORM in self.term.platform_exclude):
      return ''

    # Don't render icmpv6 protocol terms under inet, or icmp under inet6
    # Does not currently support mixed family.
    if ((self.af == 'inet6' and 'icmp' in self.term.protocol) or
        (self.af == 'inet' and 'icmpv6' in self.term.protocol)):
      logging.debug(self.NO_AF_LOG_PROTO.substitute(term=self.term.name,
                                                    proto=self.term.protocol,
                                                    af=self.af))
      return ''

    # Term verbatim output - this will skip over most normal term
    # creation code by returning early. Warnings provided in policy.py.
    if self.term.verbatim:
      for verbatim_line in self.term.verbatim:
        platform, contents = verbatim_line.value
        if platform == self._PLATFORM:
          output.append(str(contents))
      return '\n'.join(output)

    # Source address
    if self.term.source_address or self.term.source_address_exclude:
      src_addrs = self._CalculateAddrs(self.term.source_address,
                                       self.term.source_address_exclude)
      if not src_addrs:
        logging.warn(self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                    direction='source',
                                                    af=self.af))
        return ''
      # TODO(castagno): Add support for ipv6
      output.append('ip saddr %s' % self._FormatMatch(src_addrs))

    # Destination address
    if self.term.destination_address or self.term.destination_address_exclude:
      dst_addrs = self._CalculateAddrs(self.term.destination_address,
                                       self.term.destination_address_exclude)
      if not dst_addrs:
        logging.warn(self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                    direction='destination',
                                                    af=self.af))
        return ''
      # TODO(castagno): Add support for ipv6
      output.append('ip daddr %s' % self._FormatMatch(dst_addrs))

    # Protocol
    #
    # nft interpreter shortcuts protocol specification if there are more specific
    # matches. At the moment, these are:
    # * source port
    # * destination port
    # * ICMP type
    if self.term.protocol and not (self.term.source_port or
                                   self.term.destination_port or
                                   self.term.icmp_type):
      output.append('ip protocol %s' % self._FormatMatch(self.term.protocol))

    # Source port
    if self.term.source_port:
      output.append('%s sport %s' %
                    (self._FormatMatch(self.term.protocol),
                     self._FormatMatch(self.term.source_port)))

    # Destination port
    if self.term.destination_port:
      output.append('%s dport %s' %
                    (self._FormatMatch(self.term.protocol),
                     self._FormatMatch(self.term.destination_port)))

    # Icmp type
    if self.term.icmp_type:
      icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type,
                                           self.term.protocol,
                                           self.af)
      if icmp_types != ['']:
        # nft interpreter requires ICMP types to be spelled out
        icmp_name_types = self.ICMP_TYPE[self.AF_MAP[self.af]]
        icmp_type_names = dict((v, k) for k, v in icmp_name_types.iteritems())
        output.append('icmp type %s' %
                      self._FormatMatch([icmp_type_names[icmp_type] for
                                         icmp_type in icmp_types]))
    # Counter
    # This does not use the value that was passed in the term.
    if self.term.counter:
      output.append('counter')

    # Log
    # Setup logic so that only one log statement is printed.
    if self.term.logging and not self.term.log_name:
      output.append('log')
    elif (self.term.logging and self.term.log_name) or self.term.log_name:
      # nftables only supports log prefixes of up to 128 characters; truncate
      # to 126 to leave room for the ': ' suffix added below.
      output.append('log prefix "%s: "' % self.term.log_name[:126])

    # Action
    output.append(self._ACTIONS[self.term.action[0]])

    # Owner (implement as comment)
    if self.term.owner:
      self.term.comment.append('Owner: %s' % self.term.owner)

    # Comment
    if self.term.comment:
      comment_data = ' '.join(self.term.comment)
      # Comments longer than MAX_CHARACTERS must be truncated (nftables limit).
      if len(comment_data) > self.MAX_CHARACTERS:
        # Keep only the first MAX_CHARACTERS characters.
        comment_data = comment_data[:self.MAX_CHARACTERS]
        logging.warn(
            'Term %s in policy is too long (>%d characters) '
            'and will be truncated', self.term.name, self.MAX_CHARACTERS)

      output.append('comment "%s"' % comment_data)

    return ' '.join(output)
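A rough, standalone illustration of how the match fragments above are composed into a single rule string; the helper and its arguments are invented, and only the append-and-join pattern follows the term renderer.

def build_rule(saddr=None, daddr=None, protocol=None, action='accept'):
  """Joins whichever match fragments are present into one nft-style rule."""
  parts = []
  if saddr:
    parts.append('ip saddr %s' % saddr)
  if daddr:
    parts.append('ip daddr %s' % daddr)
  if protocol:
    parts.append('ip protocol %s' % protocol)
  parts.append(action)
  return ' '.join(parts)

print(build_rule(saddr='10.0.0.0/8', protocol='tcp', action='drop'))
# ip saddr 10.0.0.0/8 ip protocol tcp drop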
Esempio n. 30
0
def start_aip_training(input_dict: Dict[Text, List[types.Artifact]],
                       output_dict: Dict[Text, List[types.Artifact]],
                       exec_properties: Dict[Text, Any],
                       executor_class_path: Text,
                       training_inputs: Dict[Text, Any],
                       job_id: Optional[Text]):
  """Start a trainer job on AI Platform (AIP).

  This is done by forwarding the inputs/outputs/exec_properties to the
  tfx.scripts.run_executor module on an AI Platform training job interpreter.

  Args:
    input_dict: Passthrough input dict for tfx.components.Trainer.executor.
    output_dict: Passthrough input dict for tfx.components.Trainer.executor.
    exec_properties: Passthrough input dict for tfx.components.Trainer.executor.
    executor_class_path: class path for TFX core default trainer.
    training_inputs: Training input argument for AI Platform training job.
      'pythonModule', 'pythonVersion' and 'runtimeVersion' will be inferred. For
      the full set of parameters, refer to
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#TrainingInput
    job_id: Job ID for AI Platform Training job. If not supplied,
      system-determined unique ID is given. Refer to
    https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs#resource-job

  Returns:
    None
  """
  training_inputs = training_inputs.copy()

  json_inputs = artifact_utils.jsonify_artifact_dict(input_dict)
  logging.info('json_inputs=\'%s\'.', json_inputs)
  json_outputs = artifact_utils.jsonify_artifact_dict(output_dict)
  logging.info('json_outputs=\'%s\'.', json_outputs)
  json_exec_properties = json.dumps(exec_properties, sort_keys=True)
  logging.info('json_exec_properties=\'%s\'.', json_exec_properties)

  # We use custom containers to launch training on AI Platform, which invokes
  # the specified image using the container's entrypoint. The default
  # entrypoint for TFX containers is to call scripts/run_executor.py. The
  # arguments below are passed to this run_executor entry to run the executor
  # specified in `executor_class_path`.
  container_command = _CONTAINER_COMMAND + [
      '--executor_class_path',
      executor_class_path,
      '--inputs',
      json_inputs,
      '--outputs',
      json_outputs,
      '--exec-properties',
      json_exec_properties,
  ]

  if not training_inputs.get('masterConfig'):
    training_inputs['masterConfig'] = {
        'imageUri': _TFX_IMAGE,
    }

  # Always use our own entrypoint instead of relying on container default.
  if 'containerCommand' in training_inputs['masterConfig']:
    logging.warn('Overriding custom value of containerCommand')
  training_inputs['masterConfig']['containerCommand'] = container_command

  # Pop project_id so AIP doesn't complain about an unexpected parameter.
  # It's been a stowaway in aip_args and has finally reached its destination.
  project = training_inputs.pop('project')
  with telemetry_utils.scoped_labels(
      {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
    job_labels = telemetry_utils.get_labels_dict()

  # 'tfx_YYYYmmddHHMMSS' is the default job ID if not explicitly specified.
  job_id = job_id or 'tfx_{}'.format(
      datetime.datetime.now().strftime('%Y%m%d%H%M%S'))

  _launch_aip_training(
      job_id=job_id,
      project=project,
      training_input=training_inputs,
      job_labels=job_labels)
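A stripped-down sketch of the argument assembly above using only the standard library; _CONTAINER_COMMAND and the flag names mirror the snippet, everything else is hypothetical and no AI Platform call is made.

import datetime
import json

_CONTAINER_COMMAND = ['python', '-m', 'tfx.scripts.run_executor']  # assumed entrypoint

def build_training_call(executor_class_path, inputs, outputs, exec_properties,
                        job_id=None):
  """Returns (job_id, container_command) roughly as the launcher above builds them."""
  command = _CONTAINER_COMMAND + [
      '--executor_class_path', executor_class_path,
      '--inputs', json.dumps(inputs, sort_keys=True),
      '--outputs', json.dumps(outputs, sort_keys=True),
      '--exec-properties', json.dumps(exec_properties, sort_keys=True),
  ]
  job_id = job_id or 'tfx_{}'.format(
      datetime.datetime.now().strftime('%Y%m%d%H%M%S'))
  return job_id, command

job_id, cmd = build_training_call('my.module.Executor', {}, {}, {'steps': 10})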
Esempio n. 31
0
  def _TranslatePolicy(self, pol, exp_info):
    """Translate a policy from objects into strings."""
    self.iptables_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    default_action = None
    good_default_actions = ['ACCEPT', 'DROP']
    good_afs = ['inet', 'inet6']
    good_options = ['nostate', 'abbreviateterms', 'truncateterms', 'noverbose']
    all_protocols_stateful = True
    self.verbose = True

    for header, terms in pol.filters:
      filter_type = None
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)[1:]
      filter_name = header.FilterName(self._PLATFORM)

      self._WarnIfCustomTarget(filter_name)

      # ensure all options after the filter name are expected
      for opt in filter_options:
        if opt not in good_default_actions + good_afs + good_options:
          raise UnsupportedTargetOption('%s %s %s %s' % (
              '\nUnsupported option found in', self._PLATFORM,
              'target definition:', opt))

      # disable stateful?
      if 'nostate' in filter_options:
        all_protocols_stateful = False
      if 'noverbose' in filter_options:
        self.verbose = False

      # Check for matching af
      for address_family in good_afs:
        if address_family in filter_options:
          # should not specify more than one AF in options
          if filter_type is not None:
            raise UnsupportedFilterError('%s %s %s %s' % (
                '\nMay only specify one of', good_afs, 'in filter options:',
                filter_options))
          filter_type = address_family
      if filter_type is None:
        filter_type = 'inet'

      if self._PLATFORM == 'iptables' and filter_name == 'FORWARD':
        default_action = 'DROP'

      # does this policy override the default filter actions?
      for next_target in header.target:
        if next_target.platform == self._PLATFORM:
          if len(next_target.options) > 1:
            for arg in next_target.options:
              if arg in good_default_actions:
                default_action = arg
      if default_action and default_action not in good_default_actions:
        raise UnsupportedDefaultAction('%s %s %s %s %s' % (
            '\nOnly', ', '.join(good_default_actions),
            'default filter action allowed;', default_action, 'used.'))

      # add the terms
      new_terms = []
      term_names = set()
      for term in terms:
        term.name = self.FixTermLength(term.name,
                                       'abbreviateterms' in filter_options,
                                       'truncateterms' in filter_options)
        if term.name in term_names:
          raise aclgenerator.DuplicateTermError(
              'You have a duplicate term: %s' % term.name)
        term_names.add(term.name)

        term = self.FixHighPorts(term, af=filter_type,
                                 all_protocols_stateful=all_protocols_stateful)
        if not term:
          continue

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        new_terms.append(self._TERM(term, filter_name, all_protocols_stateful,
                                    default_action, filter_type, self.verbose))

      self.iptables_policies.append((header, filter_name, filter_type,
                                     default_action, new_terms))
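An isolated sketch of the option validation performed above: every token after the filter name must be a known default action, address family, or option, and at most one address family may be given. The function name is hypothetical.

GOOD_DEFAULT_ACTIONS = ['ACCEPT', 'DROP']
GOOD_AFS = ['inet', 'inet6']
GOOD_OPTIONS = ['nostate', 'abbreviateterms', 'truncateterms', 'noverbose']

def validate_filter_options(options):
  """Checks the tokens after the filter name and returns the address family."""
  allowed = GOOD_DEFAULT_ACTIONS + GOOD_AFS + GOOD_OPTIONS
  bad = [opt for opt in options if opt not in allowed]
  if bad:
    raise ValueError('Unsupported option(s) in target definition: %s' % bad)
  afs = [opt for opt in options if opt in GOOD_AFS]
  if len(afs) > 1:
    raise ValueError('May only specify one of %s, got %s' % (GOOD_AFS, afs))
  return afs[0] if afs else 'inet'  # default to IPv4, as the generator does

print(validate_filter_options(['nostate', 'inet6']))  # inet6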
Esempio n. 32
0
def main(_):
  """Runs fine-tuning and inference.

  There are three categories of images.
  1) Images where we have previous and next frame, and that are not filtered
     out by the heuristic. For them, we will use the fine-tuned predictions.
  2) Images where we have previous and next frame, but that were filtered out
     by our heuristic. For them, we will use the ordinary prediction instead.
  3) Images where we have at least one missing adjacent frame. For them, we will
     use the ordinary prediction as indicated by triplet_list_file_remains (if
     provided). They will also not be part of the generated inference list in
     the first place.

  Raises:
     ValueError: Invalid parameters have been passed.
  """

  if FLAGS.handle_motion and FLAGS.joint_encoder:
    raise ValueError('Using a joint encoder is currently not supported when '
                     'modeling object motion.')
  if FLAGS.handle_motion and FLAGS.seq_length != 3:
    raise ValueError('The current motion model implementation only supports '
                     'using a sequence length of three.')
  if FLAGS.handle_motion and not FLAGS.compute_minimum_loss:
    raise ValueError('Computing the minimum photometric loss is required when '
                     'enabling object motion handling.')
  if FLAGS.size_constraint_weight > 0 and not FLAGS.handle_motion:
    raise ValueError('To enforce object size constraints, enable motion '
                     'handling.')
  if FLAGS.icp_weight > 0.0:
    raise ValueError('ICP is currently not supported.')
  if FLAGS.compute_minimum_loss and FLAGS.seq_length % 2 != 1:
    raise ValueError('Compute minimum loss requires using an odd number of '
                     'images in a sequence.')
  if FLAGS.compute_minimum_loss and FLAGS.exhaustive_mode:
    raise ValueError('Exhaustive mode has no effect when compute_minimum_loss '
                     'is enabled.')
  if FLAGS.img_width % (2 ** 5) != 0 or FLAGS.img_height % (2 ** 5) != 0:
    logging.warn('Image size is not divisible by 2^5. For the architecture '
                 'employed, this could cause artefacts caused by resizing in '
                 'lower dimensions.')

  if FLAGS.output_dir.endswith('/'):
    FLAGS.output_dir = FLAGS.output_dir[:-1]

  # Create file lists to prepare fine-tuning, save it to unique_file.
  unique_file_name = (str(datetime.datetime.now().date()) + '_' +
                      str(datetime.datetime.now().time()).replace(':', '_'))
  unique_file = os.path.join(FLAGS.data_dir, unique_file_name + '.txt')
  with gfile.FastGFile(FLAGS.triplet_list_file, 'r') as f:
    files_to_process = f.readlines()
    files_to_process = [line.rstrip() for line in files_to_process]
    files_to_process = [line for line in files_to_process if len(line)]
  logging.info('Creating unique file list %s with %s entries.', unique_file,
               len(files_to_process))
  with gfile.FastGFile(unique_file, 'w') as f_out:
    fetches_network = FLAGS.num_steps * FLAGS.batch_size
    fetches_saves = FLAGS.batch_size * int(np.floor(FLAGS.num_steps/SAVE_EVERY))
    repetitions = fetches_network + 3 * fetches_saves
    for i in range(len(files_to_process)):
      for _ in range(repetitions):
        f_out.write(files_to_process[i] + '\n')

  # Read remaining files.
  remaining = []
  if gfile.Exists(FLAGS.triplet_list_file_remains):
    with gfile.FastGFile(FLAGS.triplet_list_file_remains, 'r') as f:
      remaining = f.readlines()
      remaining = [line.rstrip() for line in remaining]
      remaining = [line for line in remaining if len(line)]
  logging.info('Running fine-tuning on %s files, %s files are remaining.',
               len(files_to_process), len(remaining))

  # Run fine-tuning process and save predictions in id-folders.
  tf.set_random_seed(FIXED_SEED)
  np.random.seed(FIXED_SEED)
  random.seed(FIXED_SEED)
  flipping_mode = reader.FLIP_ALWAYS if FLAGS.flip else reader.FLIP_NONE
  train_model = model.Model(data_dir=FLAGS.data_dir,
                            file_extension=FLAGS.file_extension,
                            is_training=True,
                            learning_rate=FLAGS.learning_rate,
                            beta1=FLAGS.beta1,
                            reconstr_weight=FLAGS.reconstr_weight,
                            smooth_weight=FLAGS.smooth_weight,
                            ssim_weight=FLAGS.ssim_weight,
                            icp_weight=FLAGS.icp_weight,
                            batch_size=FLAGS.batch_size,
                            img_height=FLAGS.img_height,
                            img_width=FLAGS.img_width,
                            seq_length=FLAGS.seq_length,
                            architecture=FLAGS.architecture,
                            imagenet_norm=FLAGS.imagenet_norm,
                            weight_reg=FLAGS.weight_reg,
                            exhaustive_mode=FLAGS.exhaustive_mode,
                            random_scale_crop=FLAGS.random_scale_crop,
                            flipping_mode=flipping_mode,
                            random_color=False,
                            depth_upsampling=FLAGS.depth_upsampling,
                            depth_normalization=FLAGS.depth_normalization,
                            compute_minimum_loss=FLAGS.compute_minimum_loss,
                            use_skip=FLAGS.use_skip,
                            joint_encoder=FLAGS.joint_encoder,
                            build_sum=False,
                            shuffle=False,
                            input_file=unique_file_name,
                            handle_motion=FLAGS.handle_motion,
                            size_constraint_weight=FLAGS.size_constraint_weight,
                            train_global_scale_var=False)

  failed_heuristic_ids = finetune_inference(train_model, FLAGS.model_ckpt,
                                            FLAGS.output_dir + '_ft')
  logging.info('Fine-tuning completed, %s files were filtered out by '
               'heuristic.', len(failed_heuristic_ids))
  for failed_id in failed_heuristic_ids:
    failed_entry = files_to_process[failed_id]
    remaining.append(failed_entry)
  logging.info('In total, %s images were fine-tuned, while %s were not.',
               len(files_to_process)-len(failed_heuristic_ids), len(remaining))

  # Copy all results to have the same structural output as running ordinary
  # inference.
  for i in range(len(files_to_process)):
    if files_to_process[i] not in remaining:  # Use fine-tuned result.
      elements = files_to_process[i].split(' ')
      source_file = os.path.join(FLAGS.output_dir + '_ft', FLAGS.ft_name +
                                 'id_' + str(i),
                                 str(FLAGS.num_steps).zfill(10) +
                                 ('_flip' if FLAGS.flip else ''))
      if len(elements) == 2:  # No differing mapping defined.
        target_dir = os.path.join(FLAGS.output_dir + '_ft', elements[0])
        target_file = os.path.join(
            target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
      else:  # Other mapping for file defined, copy to this location instead.
        target_dir = os.path.join(
            FLAGS.output_dir + '_ft', os.path.dirname(elements[2]))
        target_file = os.path.join(
            target_dir,
            os.path.basename(elements[2]) + ('_flip' if FLAGS.flip else ''))
      if not gfile.Exists(target_dir):
        gfile.MakeDirs(target_dir)
      logging.info('Copy refined result %s to %s.', source_file, target_file)
      gfile.Copy(source_file + '.npy', target_file + '.npy', overwrite=True)
      gfile.Copy(source_file + '.txt', target_file + '.txt', overwrite=True)
      gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
                 target_file + '.%s' % FLAGS.file_extension, overwrite=True)
  for j in range(len(remaining)):
    elements = remaining[j].split(' ')
    if len(elements) == 2:  # No differing mapping defined.
      target_dir = os.path.join(FLAGS.output_dir + '_ft', elements[0])
      target_file = os.path.join(
          target_dir, elements[1] + ('_flip' if FLAGS.flip else ''))
    else:  # Other mapping for file defined, copy to this location instead.
      target_dir = os.path.join(
          FLAGS.output_dir + '_ft', os.path.dirname(elements[2]))
      target_file = os.path.join(
          target_dir,
          os.path.basename(elements[2]) + ('_flip' if FLAGS.flip else ''))
    if not gfile.Exists(target_dir):
      gfile.MakeDirs(target_dir)
    source_file = target_file.replace('_ft', '')
    logging.info('Copy unrefined result %s to %s.', source_file, target_file)
    gfile.Copy(source_file + '.npy', target_file + '.npy', overwrite=True)
    gfile.Copy(source_file + '.%s' % FLAGS.file_extension,
               target_file + '.%s' % FLAGS.file_extension, overwrite=True)
  logging.info('Done, predictions saved in %s.', FLAGS.output_dir + '_ft')
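A plain-Python sketch of the repeated file list written above: each triplet line is duplicated once per expected batch fetch so the fine-tuning input never runs out. The helper name is made up; the arithmetic follows the snippet.

def write_repeated_list(files_to_process, out_path, num_steps, batch_size,
                        save_every):
  """Writes every input line num_steps*batch_size + 3*save-fetches times."""
  fetches_network = num_steps * batch_size
  fetches_saves = batch_size * (num_steps // save_every)
  repetitions = fetches_network + 3 * fetches_saves
  with open(out_path, 'w') as f_out:
    for line in files_to_process:
      for _ in range(repetitions):
        f_out.write(line + '\n')
  return repetitions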
Esempio n. 33
0
  def _TranslatePolicy(self, pol, exp_info):
    self.nsxv_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    for header, terms in pol.filters:
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)
      if len(filter_options) >= 2:
        filter_name = filter_options[1]

      # get filter type, section id and applied To
      self._ParseFilterOptions(filter_options)

      filter_type = self._FILTER_OPTIONS_DICT['filter_type']
      applied_to = self._FILTER_OPTIONS_DICT['applied_to']

      term_names = set()
      new_terms = []
      for term in terms:
        # Check for duplicate terms
        if term.name in term_names:
          raise NsxvDuplicateTermError('There are multiple terms named: %s' %
                                       term.name)
        term_names.add(term.name)

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue
        # Get the mapped action value.
        # If there is no mapped action value, the term is not rendered.
        mapped_action = _ACTION_TABLE.get(str(term.action[0]))
        if not mapped_action:
          logging.warn('WARNING: Action %s in Term %s is not valid and '
                       'will not be rendered.', term.action, term.name)
          continue

        term.name = self.FixTermLength(term.name)

        if filter_type == 'inet':
          af = 'inet'
          term = self.FixHighPorts(term, af=af)
          if not term:
            continue
          new_terms.append(Term(term, filter_type, applied_to, 4))

        if filter_type == 'inet6':
          af = 'inet6'
          term = self.FixHighPorts(term, af=af)
          if not term:
            continue
          new_terms.append(Term(term, filter_type, applied_to, 6))

        if filter_type == 'mixed':
          if 'icmpv6' not in term.protocol:
            inet_term = self.FixHighPorts(term, 'inet')
            if not inet_term:
              continue
            new_terms.append(Term(inet_term, filter_type, applied_to, 4))
          else:
            inet6_term = self.FixHighPorts(term, 'inet6')
            if not inet6_term:
              continue
            new_terms.append(Term(inet6_term, filter_type, applied_to, 6))

      self.nsxv_policies.append((header, filter_name, [filter_type],
                                 new_terms))
Esempio n. 34
0
    def _TranslatePolicy(self, pol, exp_info):
        self.cisco_policies = []
        current_date = datetime.datetime.utcnow().date()
        exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

        # a mixed filter outputs both ipv4 and ipv6 acls in the same output file
        good_filters = [
            'extended', 'standard', 'object-group', 'inet6', 'mixed',
            'enable_dsmo'
        ]

        for header, terms in pol.filters:
            if self._PLATFORM not in header.platforms:
                continue
            obj_target = ObjectGroup()

            filter_options = header.FilterOptions(self._PLATFORM)
            filter_name = header.FilterName(self._PLATFORM)

            self.verbose = True
            if 'noverbose' in filter_options:
                filter_options.remove('noverbose')
                self.verbose = False

            # extended is the most common filter type.
            filter_type = 'extended'
            if len(filter_options) > 1:
                filter_type = filter_options[1]

            # check if filter type is renderable
            if filter_type not in good_filters:
                raise UnsupportedCiscoAccessListError(
                    'access list type %s not supported by %s (good types: %s)'
                    % (filter_type, self._PLATFORM, str(good_filters)))

            filter_list = [filter_type]
            if filter_type == 'mixed':
                # Loop through filter and generate output for inet and inet6 in sequence
                filter_list = ['extended', 'inet6']

            for next_filter in filter_list:
                # Numeric access lists can be extended or standard, but have specific
                # known ranges.
                if next_filter == 'extended' and filter_name.isdigit():
                    if int(filter_name) in list(range(1, 100)) + list(
                            range(1300, 2000)):
                        raise UnsupportedCiscoAccessListError(
                            'Access lists between 1-99 and 1300-1999 are reserved for '
                            'standard ACLs')
                if next_filter == 'standard' and filter_name.isdigit():
                    if (int(filter_name) not in list(range(1, 100)) +
                            list(range(1300, 2000))):
                        raise UnsupportedCiscoAccessListError(
                            'Standard access lists must be numeric in the range of 1-99'
                            ' or 1300-1999.')

                term_dup_check = set()
                new_terms = []
                for term in terms:
                    if term.name in term_dup_check:
                        raise CiscoDuplicateTermError(
                            'You have a duplicate term: %s' % term.name)
                    term_dup_check.add(term.name)

                    term.name = self.FixTermLength(term.name)
                    af = 'inet'
                    if next_filter == 'inet6':
                        af = 'inet6'
                    term = self.FixHighPorts(term, af=af)
                    if not term:
                        continue

                    if term.expiration:
                        if term.expiration <= exp_info_date:
                            logging.info(
                                'INFO: Term %s in policy %s expires '
                                'in less than two weeks.', term.name,
                                filter_name)
                        if term.expiration <= current_date:
                            logging.warn(
                                'WARNING: Term %s in policy %s is expired and '
                                'will not be rendered.', term.name,
                                filter_name)
                            continue

                    # render terms based on filter type
                    if next_filter == 'standard':
                        # keep track of sequence numbers across terms
                        new_terms.append(
                            TermStandard(term, filter_name, self._PLATFORM,
                                         self.verbose))
                    elif next_filter == 'extended':
                        enable_dsmo = (len(filter_options) > 2
                                       and filter_options[2] == 'enable_dsmo')
                        new_terms.append(
                            Term(term,
                                 proto_int=self._PROTO_INT,
                                 enable_dsmo=enable_dsmo,
                                 term_remark=self._TERM_REMARK,
                                 platform=self._PLATFORM,
                                 verbose=self.verbose))
                    elif next_filter == 'object-group':
                        obj_target.AddTerm(term)
                        new_terms.append(
                            self._GetObjectGroupTerm(term,
                                                     filter_name,
                                                     verbose=self.verbose))
                    elif next_filter == 'inet6':
                        new_terms.append(
                            Term(term,
                                 6,
                                 proto_int=self._PROTO_INT,
                                 platform=self._PLATFORM,
                                 verbose=self.verbose))

                # cisco requires different name for the v4 and v6 acls
                if filter_type == 'mixed' and next_filter == 'inet6':
                    filter_name = 'ipv6-%s' % filter_name
                self.cisco_policies.append((header, filter_name, [next_filter],
                                            new_terms, obj_target))
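A standalone sketch of the numbered-ACL range check used above: 1-99 and 1300-1999 are reserved for standard ACLs, so extended ACLs must avoid those numbers and standard ACLs must use them. The helper name is hypothetical.

STANDARD_RANGES = list(range(1, 100)) + list(range(1300, 2000))

def check_numbered_acl(filter_name, filter_type):
  """Raises if a numeric ACL name falls in the wrong range for its type."""
  if not filter_name.isdigit():
    return
  number = int(filter_name)
  if filter_type == 'extended' and number in STANDARD_RANGES:
    raise ValueError('ACLs 1-99 and 1300-1999 are reserved for standard ACLs')
  if filter_type == 'standard' and number not in STANDARD_RANGES:
    raise ValueError('Standard ACLs must be numbered 1-99 or 1300-1999')

check_numbered_acl('110', 'extended')  # ok
check_numbered_acl('10', 'standard')   # ok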
Esempio n. 35
0
def deploy_model_for_aip_prediction(api: discovery.Resource,
                                    serving_path: Text,
                                    model_version: Text,
                                    ai_platform_serving_args: Dict[Text, Any],
                                    job_labels: Dict[Text, Text],
                                    skip_model_creation: bool = False,
                                    set_default_version: bool = True) -> None:
  """Deploys a model for serving with AI Platform.

  Args:
    api: Google API client resource.
    serving_path: The path to the model. Must be a GCS URI.
    model_version: Version of the model being deployed. Must be different from
      what is currently being served.
    ai_platform_serving_args: Dictionary containing arguments for pushing to AI
      Platform. The full set of parameters supported can be found at
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions#Version.
      Most keys are forwarded as-is, but the following keys are handled specially:
        - name: this must be empty (and will be filled by pusher).
        - deployment_uri: this must be empty (and will be filled by pusher).
        - python_version: when left empty, this will be filled by python version
            of the environment being used.
        - runtime_version: when left empty, this will be filled by TensorFlow
            version from the environment.
        - labels: a list of job labels will be merged with user's input.
    job_labels: The dict of labels that will be attached to this job. They are
      merged with optional labels from `ai_platform_serving_args`.
    skip_model_creation: If true, the method assumes the model already exists
      in AI Platform and therefore skips model creation.
    set_default_version: Whether to set the newly deployed model version as the
      default version.

  Raises:
    RuntimeError: if an error is encountered when trying to push.
  """
  logging.info(
      'Deploying model with version %s to AI Platform for serving: %s',
      model_version, ai_platform_serving_args)

  model_name = ai_platform_serving_args['model_name']
  project_id = ai_platform_serving_args['project_id']
  default_runtime_version = _get_tf_runtime_version(tf.__version__)
  runtime_version = ai_platform_serving_args.get('runtime_version',
                                                 default_runtime_version)
  python_version = _get_caip_python_version(runtime_version)

  if not skip_model_creation:
    create_model_for_aip_prediction_if_not_exist(api, job_labels,
                                                 ai_platform_serving_args)
  version_body = dict(ai_platform_serving_args)
  for model_only_key in ['model_name', 'project_id', 'regions']:
    version_body.pop(model_only_key, None)
  version_body['name'] = model_version
  version_body['deployment_uri'] = serving_path
  version_body['runtime_version'] = version_body.get('runtime_version',
                                                     runtime_version)
  version_body['python_version'] = version_body.get('python_version',
                                                    python_version)
  version_body['labels'] = {**version_body.get('labels', {}), **job_labels}
  logging.info(
      'Creating new version of model_name %s in project %s, request body: %s',
      model_name, project_id, version_body)

  # Push to AIP, and record the operation name so we can poll for its state.
  model_name = 'projects/{}/models/{}'.format(project_id, model_name)
  try:
    operation = api.projects().models().versions().create(
        body=version_body, parent=model_name).execute()
    _wait_for_operation(api, operation, 'projects.models.versions.create')
  except errors.HttpError as e:
    # If the error is caused by trying to create a model version that already
    # exists, it is safe to ignore.
    if e.resp.status == 409:
      logging.warn('Model version %s already exists', model_version)
    else:
      raise RuntimeError('Creating model version on AI Platform failed: {}'
                         .format(e))

  if set_default_version:
    # Set the new version as default.
    # By API specification, if Long-Running-Operation is done and there is
    # no error, 'response' is guaranteed to exist.
    api.projects().models().versions().setDefault(name='{}/versions/{}'.format(
        model_name, model_version)).execute()

  logging.info(
      'Successfully deployed model %s with version %s, serving from %s',
      model_name, model_version, serving_path)
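A condensed sketch of the version-body construction above, reduced to plain dict handling with no API calls; the default runtime/python values are invented placeholders.

def build_version_body(serving_args, model_version, serving_path,
                       default_runtime='2.1', default_python='3.7',
                       job_labels=None):
  """Builds the request body roughly as the deploy helper above does."""
  body = dict(serving_args)
  for model_only_key in ['model_name', 'project_id', 'regions']:
    body.pop(model_only_key, None)  # these belong to the model, not the version
  body['name'] = model_version
  body['deployment_uri'] = serving_path
  body.setdefault('runtime_version', default_runtime)
  body.setdefault('python_version', default_python)
  body['labels'] = {**body.get('labels', {}), **(job_labels or {})}
  return body

print(build_version_body({'model_name': 'm', 'project_id': 'p'}, 'v1',
                         'gs://bucket/serving', job_labels={'team': 'ml'}))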
Esempio n. 36
0
  def utt_to_samples(self, args):
    '''
    Utt to samples of (feat_chunk, spk_id, sample_key).
    Will be run in a process pool so restrictions apply.
    '''
    result_queue, utt_info = args
    logging.debug(utt_info)

    # TODO: wrap into a function or something
    utt_key, utt_meta = utt_info

    # Load features and select voiced frames.
    feat_scp = utt_meta['feat']
    feat_mat = kaldiio.load_mat(feat_scp)
    num_frames_feat, feat_dim = feat_mat.shape
    vad_scp = utt_meta['vad']

    if vad_scp:
      vad_mat = kaldiio.load_mat(vad_scp)
      num_frames_vad = vad_mat.shape[0]
      logging.debug('feat_mat: %s, vad_mat: %s' %
                    (str(feat_mat.shape), str(vad_mat.shape)))
      if num_frames_feat != num_frames_vad:
        logging.debug('num_frames_feat != num_frames_vad: %d vs %d' %
                      (num_frames_feat, num_frames_vad))
        return None
      voiced_frames_index = np.where(vad_mat == 1)[0]
      logging.debug('voiced_frames_index: %s' %
                    (str(voiced_frames_index.shape)))
      feat_mat_voiced = feat_mat[voiced_frames_index, :]
    else:
      # If no VAD info was found, the entire utt will be used.
      feat_mat_voiced = feat_mat
    num_frames_voiced = feat_mat_voiced.shape[0]
    logging.debug('feat_mat_voiced: %s' % (str(feat_mat_voiced.shape)))

    spk_id = utt_meta['spkid']

    logging.debug('Chunk size: %d' % (self.chunk_size))

    results = []
    chunk_idx = 0
    if self.add_random_offset:
      random_offset = np.random.randint(0, self.chunk_size)
    else:
      random_offset = 0
    for offset in range(random_offset, num_frames_voiced, self.chunk_size):
      if self.single_chunk:
        available = num_frames_voiced - self.chunk_size
        if available < 0:
          # No padding.
          logging.warn('Single chunk mode: available < 0.')
          break
        offset = random.randint(0, available)
      logging.debug('offset = %d' % (offset))
      feat_chunk = feat_mat_voiced[offset:offset + self.chunk_size, :]
      unpadded_frames = feat_chunk.shape[0]
      if self.pad_chunks and unpadded_frames < self.chunk_size:
        rel_chunk_len = float(unpadded_frames) / self.chunk_size
        if rel_chunk_len < self.drop_short_chunks:
          continue
        logging.debug('Padding chunk of frames %d ...' % (unpadded_frames))
        padded = np.zeros((self.chunk_size, feat_dim), dtype=feat_chunk.dtype)
        padded[:unpadded_frames, :] = feat_chunk
        feat_chunk = padded
      feat_chunk = np.expand_dims(feat_chunk, axis=2)  # TODO: not here
      sample_key = '%s_chunk%02d' % (utt_key, chunk_idx)
      sample = (feat_chunk, spk_id, sample_key)
      chunk_idx += 1
      results.append(sample)
      if self.single_chunk:
        break
    if result_queue:
      # queue mode
      result_queue.put(results)
      return None
    # imap mode
    return results
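A self-contained numpy sketch of the chunk-and-pad loop above, without the Kaldi I/O or speaker bookkeeping; the chunk size and drop threshold are arbitrary.

import numpy as np

def chunk_features(feat, chunk_size, drop_short=0.5):
  """Yields fixed-size chunks of feat, zero-padding a trailing short chunk."""
  num_frames, feat_dim = feat.shape
  for offset in range(0, num_frames, chunk_size):
    chunk = feat[offset:offset + chunk_size, :]
    if chunk.shape[0] < chunk_size:
      if chunk.shape[0] / float(chunk_size) < drop_short:
        continue  # mirror drop_short_chunks: discard chunks that are too short
      padded = np.zeros((chunk_size, feat_dim), dtype=chunk.dtype)
      padded[:chunk.shape[0], :] = chunk
      chunk = padded
    yield chunk

feats = np.random.rand(230, 40)
print([c.shape for c in chunk_features(feats, 100)])  # [(100, 40), (100, 40)]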
Esempio n. 37
0
def main(_):
  # Fixed seed for repeatability
  seed = 8964
  tf.set_random_seed(seed)
  np.random.seed(seed)
  random.seed(seed)

  if FLAGS.handle_motion and FLAGS.joint_encoder:
    raise ValueError('Using a joint encoder is currently not supported when '
                     'modeling object motion.')
  if FLAGS.handle_motion and FLAGS.seq_length != 3:
    raise ValueError('The current motion model implementation only supports '
                     'using a sequence length of three.')
  if FLAGS.handle_motion and not FLAGS.compute_minimum_loss:
    raise ValueError('Computing the minimum photometric loss is required when '
                     'enabling object motion handling.')
  if FLAGS.size_constraint_weight > 0 and not FLAGS.handle_motion:
    raise ValueError('To enforce object size constraints, enable motion '
                     'handling.')
  if FLAGS.imagenet_ckpt and not FLAGS.imagenet_norm:
    logging.warn('When initializing with an ImageNet-pretrained model, it is '
                 'recommended to normalize the image inputs accordingly using '
                 'imagenet_norm.')
  if FLAGS.compute_minimum_loss and FLAGS.seq_length % 2 != 1:
    raise ValueError('Compute minimum loss requires using an odd number of '
                     'images in a sequence.')
  if FLAGS.architecture != nets.RESNET and FLAGS.imagenet_ckpt:
    raise ValueError('Can only load weights from pre-trained ImageNet model '
                     'when using ResNet-architecture.')
  if FLAGS.compute_minimum_loss and FLAGS.exhaustive_mode:
    raise ValueError('Exhaustive mode has no effect when compute_minimum_loss '
                     'is enabled.')
  if FLAGS.img_width % (2 ** 5) != 0 or FLAGS.img_height % (2 ** 5) != 0:
    logging.warn('Image size is not divisible by 2^5. For the architecture '
                 'employed, this could cause artefacts caused by resizing in '
                 'lower dimensions.')
  if FLAGS.icp_weight > 0.0:
    # TODO(casser): Change ICP interface to take matrix instead of vector.
    raise ValueError('ICP is currently not supported.')

  if not gfile.Exists(FLAGS.checkpoint_dir):
    gfile.MakeDirs(FLAGS.checkpoint_dir)

  train_model = model.Model(data_dir=FLAGS.data_dir,
                            file_extension=FLAGS.file_extension,
                            is_training=True,
                            learning_rate=FLAGS.learning_rate,
                            beta1=FLAGS.beta1,
                            reconstr_weight=FLAGS.reconstr_weight,
                            smooth_weight=FLAGS.smooth_weight,
                            ssim_weight=FLAGS.ssim_weight,
                            icp_weight=FLAGS.icp_weight,
                            batch_size=FLAGS.batch_size,
                            img_height=FLAGS.img_height,
                            img_width=FLAGS.img_width,
                            seq_length=FLAGS.seq_length,
                            architecture=FLAGS.architecture,
                            imagenet_norm=FLAGS.imagenet_norm,
                            weight_reg=FLAGS.weight_reg,
                            exhaustive_mode=FLAGS.exhaustive_mode,
                            random_scale_crop=FLAGS.random_scale_crop,
                            flipping_mode=FLAGS.flipping_mode,
                            depth_upsampling=FLAGS.depth_upsampling,
                            depth_normalization=FLAGS.depth_normalization,
                            compute_minimum_loss=FLAGS.compute_minimum_loss,
                            use_skip=FLAGS.use_skip,
                            joint_encoder=FLAGS.joint_encoder,
                            handle_motion=FLAGS.handle_motion,
                            equal_weighting=FLAGS.equal_weighting,
                            size_constraint_weight=FLAGS.size_constraint_weight)

  train(train_model, FLAGS.pretrained_ckpt, FLAGS.imagenet_ckpt,
        FLAGS.checkpoint_dir, FLAGS.train_steps, FLAGS.summary_freq)
Esempio n. 38
0
def RenderFile(input_file, output_directory, definitions,
               exp_info, write_files):
  """Render a single file.

  Args:
    input_file: the name of the input policy file.
    output_directory: the directory in which we place the rendered file.
    definitions: the definitions from naming.Naming().
    exp_info: print an info message when a term is set to expire
              in that many weeks.
    write_files: a list of file tuples, (output_file, acl_text), to write
  """
  logging.debug('rendering file: %s into %s', input_file,
                output_directory)
  pol = None
  jcl = False
  acl = False
  asacl = False
  aacl = False
  bacl = False
  eacl = False
  gcefw = False
  ips = False
  ipt = False
  spd = False
  nsx = False
  pcap_accept = False
  pcap_deny = False
  pf = False
  srx = False
  jsl = False
  nft = False
  win_afw = False
  xacl = False
  paloalto = False

  try:
    conf = open(input_file).read()
    logging.debug('opened and read %s', input_file)
  except IOError as e:
    logging.warn('bad file: \n%s', e)
    raise

  try:
    pol = policy.ParsePolicy(
        conf, definitions, optimize=FLAGS.optimize,
        base_dir=FLAGS.base_directory, shade_check=FLAGS.shade_check)
  except policy.ShadingError as e:
    logging.warn('shading errors for %s:\n%s', input_file, e)
    return
  except (policy.Error, naming.Error):
    raise ACLParserError('Error parsing policy file %s:\n%s%s' % (
        input_file, sys.exc_info()[0], sys.exc_info()[1]))

  platforms = set()
  for header in pol.headers:
    platforms.update(header.platforms)

  if 'juniper' in platforms:
    jcl = copy.deepcopy(pol)
  if 'cisco' in platforms:
    acl = copy.deepcopy(pol)
  if 'ciscoasa' in platforms:
    asacl = copy.deepcopy(pol)
  if 'brocade' in platforms:
    bacl = copy.deepcopy(pol)
  if 'arista' in platforms:
    eacl = copy.deepcopy(pol)
  if 'aruba' in platforms:
    aacl = copy.deepcopy(pol)
  if 'ipset' in platforms:
    ips = copy.deepcopy(pol)
  if 'iptables' in platforms:
    ipt = copy.deepcopy(pol)
  if 'nsxv' in platforms:
    nsx = copy.deepcopy(pol)
  if 'packetfilter' in platforms:
    pf = copy.deepcopy(pol)
  if 'pcap' in platforms:
    pcap_accept = copy.deepcopy(pol)
    pcap_deny = copy.deepcopy(pol)
  if 'speedway' in platforms:
    spd = copy.deepcopy(pol)
  if 'srx' in platforms:
    srx = copy.deepcopy(pol)
  if 'srxlo' in platforms:
    jsl = copy.deepcopy(pol)
  if 'windows_advfirewall' in platforms:
    win_afw = copy.deepcopy(pol)
  if 'ciscoxr' in platforms:
    xacl = copy.deepcopy(pol)
  if 'nftables' in platforms:
    nft = copy.deepcopy(pol)
  if 'gce' in platforms:
    gcefw = copy.deepcopy(pol)
  if 'paloalto' in platforms:
    paloalto = copy.deepcopy(pol)

  if not output_directory.endswith('/'):
    output_directory += '/'

  try:
    if jcl:
      acl_obj = juniper.Juniper(jcl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if srx:
      acl_obj = junipersrx.JuniperSRX(srx, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if acl:
      acl_obj = cisco.Cisco(acl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if asacl:
      acl_obj = ciscoasa.CiscoASA(asacl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if aacl:
      acl_obj = aruba.Aruba(aacl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if bacl:
      acl_obj = brocade.Brocade(bacl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if eacl:
      acl_obj = arista.Arista(eacl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if ips:
      acl_obj = ipset.Ipset(ips, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if ipt:
      acl_obj = iptables.Iptables(ipt, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if nsx:
      acl_obj = nsxv.Nsxv(nsx, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if spd:
      acl_obj = speedway.Speedway(spd, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if pcap_accept:
      acl_obj = pcap.PcapFilter(pcap_accept, exp_info)
      RenderACL(str(acl_obj), '-accept' + acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if pcap_deny:
      acl_obj = pcap.PcapFilter(pcap_deny, exp_info, invert=True)
      RenderACL(str(acl_obj), '-deny' + acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if pf:
      acl_obj = packetfilter.PacketFilter(pf, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if win_afw:
      acl_obj = windows_advfirewall.WindowsAdvFirewall(win_afw, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if jsl:
      acl_obj = srxlo.SRXlo(jsl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if xacl:
      acl_obj = ciscoxr.CiscoXR(xacl, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if nft:
      acl_obj = nftables.Nftables(nft, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if gcefw:
      acl_obj = gce.GCE(gcefw, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
    if paloalto:
      acl_obj = paloaltofw.PaloAltoFW(paloalto, exp_info)
      RenderACL(str(acl_obj), acl_obj.SUFFIX, output_directory,
                input_file, write_files)
  # TODO(robankeny) add additional errors.
  except (juniper.Error, junipersrx.Error, cisco.Error, ipset.Error,
          iptables.Error, speedway.Error, pcap.Error,
          aclgenerator.Error, aruba.Error, nftables.Error, gce.Error) as e:
    raise ACLGeneratorError(
        'Error generating target ACL for %s:\n%s' % (input_file, e))
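The long if-chain above could also be written as a platform-to-generator mapping; this is only an illustrative alternative sketch, not how aclgen itself is structured, and it assumes the same capirca generator modules imported by the surrounding file.

import copy

PLATFORM_GENERATORS = {
    'juniper': juniper.Juniper,
    'srx': junipersrx.JuniperSRX,
    'cisco': cisco.Cisco,
    'iptables': iptables.Iptables,
    # remaining platforms would be registered the same way
}

def render_all(pol, platforms, exp_info):
  """Returns {platform: (acl_text, suffix)} for every requested platform."""
  rendered = {}
  for platform in platforms:
    generator_cls = PLATFORM_GENERATORS.get(platform)
    if generator_cls is None:
      continue
    acl_obj = generator_cls(copy.deepcopy(pol), exp_info)
    rendered[platform] = (str(acl_obj), acl_obj.SUFFIX)
  return rendered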
Esempio n. 39
0
  def _TranslatePolicy(self, pol, exp_info):
    self.pf_policies = []
    self.address_book = {}
    self.def_short_to_long = {}
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    good_afs = ['inet', 'inet6', 'mixed']
    good_options = ['in', 'out', 'nostate']
    all_protocols_stateful = True

    for header, terms in pol.filters:
      filter_type = None
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)[1:]
      filter_name = header.FilterName(self._PLATFORM)
      direction = ''

      # ensure all options after the filter name are expected
      for opt in filter_options:
        if opt not in good_afs + good_options:
          raise aclgenerator.UnsupportedTargetOption('%s %s %s %s' % (
              '\nUnsupported option found in', self._PLATFORM,
              'target definition:', opt))

      # pf will automatically add 'keep state flags S/SA' to all TCP connections
      # by default.
      if 'nostate' in filter_options:
        all_protocols_stateful = False

      if 'in' in filter_options:
        direction = 'in'
      elif 'out' in filter_options:
        direction = 'out'

      # Check for matching af
      for address_family in good_afs:
        if address_family in filter_options:
          # should not specify more than one AF in options
          if filter_type is not None:
            raise aclgenerator.UnsupportedFilterError('%s %s %s %s' % (
                '\nMay only specify one of', good_afs, 'in filter options:',
                filter_options))
          filter_type = address_family
      if filter_type is None:
        filter_type = 'inet'

      # add the terms
      new_terms = []
      term_names = set()

      for term in terms:
        term.name = self.FixTermLength(term.name)
        if term.name in term_names:
          raise DuplicateTermError(
              'You have a duplicate term: %s' % term.name)
        term_names.add(term.name)

        for source_addr in term.source_address:
          src_token = source_addr.parent_token[:self._DEF_MAX_LENGTH]

          if (src_token in self.def_short_to_long and
              self.def_short_to_long[src_token] != source_addr.parent_token):
            raise DuplicateShortenedTableName(
                'There is a shortened name conflict between names %s and %s '
                '(different named objects would conflict when shortened to %s)'
                % (self.def_short_to_long[src_token],
                   source_addr.parent_token,
                   src_token))
          else:
            self.def_short_to_long[src_token] = source_addr.parent_token

          if src_token not in self.address_book:
            self.address_book[src_token] = set([source_addr])
          else:
            self.address_book[src_token].add(source_addr)

        for dest_addr in term.destination_address:
          dst_token = dest_addr.parent_token[:self._DEF_MAX_LENGTH]

          if (dst_token in self.def_short_to_long and
              self.def_short_to_long[dst_token] != dest_addr.parent_token):
            raise DuplicateShortenedTableName(
                'There is a shortened name conflict between names %s and %s '
                '(different named objects would conflict when shortened to %s)'
                %(self.def_short_to_long[dst_token],
                  dest_addr.parent_token,
                  dst_token))
          else:
            self.def_short_to_long[dst_token] = dest_addr.parent_token

          if dst_token not in self.address_book:
            self.address_book[dst_token] = set([dest_addr])
          else:
            self.address_book[dst_token].add(dest_addr)

        if not term:
          continue

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        new_terms.append(self._TERM(term, filter_name, all_protocols_stateful,
                                    filter_type, direction))

      self.pf_policies.append((header, filter_name, filter_type, new_terms))
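A standalone sketch of the shortened-name bookkeeping above: parent tokens are truncated to a fixed length, and two different long names that collapse to the same short name are rejected. MAX_LEN and the helper name are hypothetical.

MAX_LEN = 10  # stand-in for self._DEF_MAX_LENGTH

def shorten(name, short_to_long, max_len=MAX_LEN):
  """Returns the truncated name, raising on a truncation collision."""
  short = name[:max_len]
  if short in short_to_long and short_to_long[short] != name:
    raise ValueError('Shortened name conflict between %s and %s (both become %s)'
                     % (short_to_long[short], name, short))
  short_to_long[short] = name
  return short

table = {}
shorten('internal-dns-servers', table)     # 'internal-d'
shorten('internal-dns-servers', table)     # same long name: fine
# shorten('internal-dhcp-servers', table)  # would raise: also shortens to 'internal-d'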
Esempio n. 40
0
  def __str__(self):
    # Verify platform specific terms. Skip whole term if platform does not
    # match.
    if self.term.platform:
      if self._PLATFORM not in self.term.platform:
        return ''
    if self.term.platform_exclude:
      if self._PLATFORM in self.term.platform_exclude:
        return ''

    config = Config(indent=self._DEFAULT_INDENT)
    from_str = []
    # Don't render icmpv6 protocol terms under inet, or icmp under inet6
    if ((self.term_type == 'inet6' and 'icmp' in self.term.protocol) or
        (self.term_type == 'inet' and 'icmpv6' in self.term.protocol)):
      logging.debug(self.NO_AF_LOG_PROTO.substitute(term=self.term.name,
                                                    proto=self.term.protocol,
                                                    af=self.term_type))
      return ''

    # comment
    # this deals just fine with multi line comments, but we could probably
    # output them a little cleaner; do things like make sure the
    # len(output) < 80, etc. Note, if 'noverbose' is set for the filter, skip
    # all comment processing.
    if self.term.owner and not self.noverbose:
      self.term.comment.append('Owner: %s' % self.term.owner)
    if self.term.comment and not self.noverbose:
      config.Append('/*')
      for comment in self.term.comment:
        for line in comment.split('\n'):
          config.Append('** ' + line)
      config.Append('*/')

    # Term verbatim output - this will skip over normal term creation
    # code.  Warning generated from policy.py if appropriate.
    if self.term.verbatim:
      for next_term in self.term.verbatim:
        if next_term.value[0] == self._PLATFORM:
          config.Append(str(next_term.value[1]), verbatim=True)
      return str(config)

    # Helper for per-address-family keywords.
    family_keywords = self._TERM_TYPE.get(self.term_type)

    # option
    # this is going to be a little ugly because there are a few messed-up
    # options we have to deal with.
    if self.term.option:
      for opt in [str(x) for x in self.term.option]:
        # there should be a better way to search the array of protocols
        if opt.startswith('sample'):
          self.extra_actions.append('sample')

        # only append tcp-established for option established when
        # tcp is the only protocol, otherwise other protos break on juniper
        elif opt.startswith('established'):
          if self.term.protocol == ['tcp']:
            if 'tcp-established;' not in from_str:
              from_str.append(family_keywords['tcp-est'] + ';')

        # if tcp-established specified, but more than just tcp is included
        # in the protocols, raise an error
        elif opt.startswith('tcp-established'):
          flag = family_keywords['tcp-est'] + ';'
          if self.term.protocol == ['tcp']:
            if flag not in from_str:
              from_str.append(flag)
          else:
            raise TcpEstablishedWithNonTcp(
                'tcp-established can only be used with tcp protocol in term %s'
                % self.term.name)
        elif opt.startswith('rst'):
          from_str.append('tcp-flags "rst";')
        elif opt.startswith('initial') and 'tcp' in self.term.protocol:
          from_str.append('tcp-initial;')
        elif opt.startswith('first-fragment'):
          from_str.append('first-fragment;')

        # we don't have a special way of dealing with this, so we output it and
        # hope the user knows what they're doing.
        else:
          from_str.append('%s;' % opt)

    # term name
    config.Append('term %s {' % self.term.name)

    # a default action term doesn't have any from { clause
    has_match_criteria = (self.term.address or
                          self.term.dscp_except or
                          self.term.dscp_match or
                          self.term.destination_address or
                          self.term.destination_port or
                          self.term.destination_prefix or
                          self.term.destination_prefix_except or
                          self.term.ether_type or
                          self.term.flexible_match_range or
                          self.term.forwarding_class or
                          self.term.forwarding_class_except or
                          self.term.fragment_offset or
                          self.term.hop_limit or
                          self.term.next_ip or
                          self.term.port or
                          self.term.precedence or
                          self.term.protocol or
                          self.term.protocol_except or
                          self.term.source_address or
                          self.term.source_port or
                          self.term.source_prefix or
                          self.term.source_prefix_except or
                          self.term.traffic_type or
                          self.term.ttl)

    if has_match_criteria:
      config.Append('from {')

      term_af = self.AF_MAP.get(self.term_type)

      # address
      address = self.term.GetAddressOfVersion('address', term_af)
      if self.enable_dsmo:
        address = summarizer.Summarize(address)

      if address:
        config.Append('%s {' % family_keywords['addr'])
        for addr in address:
          if self.enable_dsmo:
            config.Append('%s/%s;' % summarizer.ToDottedQuad(addr,
                                                             nondsm=True))
          else:
            for comment in self._Comment(addr):
              config.Append('%s' % comment)
            config.Append('%s;' % addr)
        config.Append('}')
      elif self.term.address:
        logging.debug(self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                     af=self.term_type))
        return ''

      # source address
      src_addr = self.term.GetAddressOfVersion('source_address', term_af)
      src_addr_ex = self.term.GetAddressOfVersion('source_address_exclude',
                                                  term_af)
      if self.enable_dsmo:
        src_addr = summarizer.Summarize(src_addr)
        src_addr_ex = summarizer.Summarize(src_addr_ex)
      else:
        src_addr, src_addr_ex = self._MinimizePrefixes(src_addr, src_addr_ex)

      if src_addr:
        config.Append('%s {' % family_keywords['saddr'])
        for addr in src_addr:
          if self.enable_dsmo:
            config.Append('%s/%s;' % summarizer.ToDottedQuad(addr,
                                                             nondsm=True))
          else:
            for comment in self._Comment(addr):
              config.Append('%s' % comment)
            config.Append('%s;' % addr)
        for addr in src_addr_ex:
          if self.enable_dsmo:
            config.Append('%s/%s except;' %
                          summarizer.ToDottedQuad(addr, nondsm=True))
          else:
            for comment in self._Comment(addr, exclude=True):
              config.Append('%s' % comment)
            config.Append('%s except;' % addr)
        config.Append('}')
      elif self.term.source_address:
        logging.debug(self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                     direction='source',
                                                     af=self.term_type))
        return ''

      # destination address
      dst_addr = self.term.GetAddressOfVersion('destination_address', term_af)
      dst_addr_ex = self.term.GetAddressOfVersion('destination_address_exclude',
                                                  term_af)
      if self.enable_dsmo:
        dst_addr = summarizer.Summarize(dst_addr)
        dst_addr_ex = summarizer.Summarize(dst_addr_ex)
      else:
        dst_addr, dst_addr_ex = self._MinimizePrefixes(dst_addr, dst_addr_ex)

      if dst_addr:
        config.Append('%s {' % family_keywords['daddr'])
        for addr in dst_addr:
          if self.enable_dsmo:
            config.Append('%s/%s;' % summarizer.ToDottedQuad(addr,
                                                             nondsm=True))
          else:
            for comment in self._Comment(addr):
              config.Append('%s' % comment)
            config.Append('%s;' % addr)
        for addr in dst_addr_ex:
          if self.enable_dsmo:
            config.Append('%s/%s except;' %
                          summarizer.ToDottedQuad(addr, nondsm=True))
          else:
            for comment in self._Comment(addr, exclude=True):
              config.Append('%s' % comment)
            config.Append('%s except;' % addr)
        config.Append('}')
      elif self.term.destination_address:
        logging.debug(self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                                     direction='destination',
                                                     af=self.term_type))
        return ''

      # forwarding-class
      if self.term.forwarding_class:
        config.Append('forwarding-class %s' % self._Group(
            self.term.forwarding_class, lc=False))

      # forwarding-class-except
      if self.term.forwarding_class_except:
        config.Append('forwarding-class-except %s' % self._Group(
            self.term.forwarding_class_except, lc=False))

      # source prefix <except> list
      if self.term.source_prefix or self.term.source_prefix_except:
        config.Append('source-prefix-list {')
        for pfx in self.term.source_prefix:
          config.Append(pfx + ';')
        for epfx in self.term.source_prefix_except:
          config.Append(epfx + ' except;')
        config.Append('}')

      # destination prefix <except> list
      if self.term.destination_prefix or self.term.destination_prefix_except:
        config.Append('destination-prefix-list {')
        for pfx in self.term.destination_prefix:
          config.Append(pfx + ';')
        for epfx in self.term.destination_prefix_except:
          config.Append(epfx + ' except;')
        config.Append('}')

      if self.term.ttl:
        config.Append('ttl %s;' % self.term.ttl)

      # protocol
      if self.term.protocol:
        # both are supported on JunOS, but only icmp6 is supported
        # on SRX loopback stateless filter
        config.Append(family_keywords['protocol'] +
                      ' ' + self._Group(self.term.protocol))

      # protocol-except
      if self.term.protocol_except:
        # same as above
        config.Append(family_keywords['protocol-except'] + ' '
                      + self._Group(self.term.protocol_except))

      # port
      if self.term.port:
        config.Append('port %s' % self._Group(self.term.port))

      # source port
      if self.term.source_port:
        config.Append('source-port %s' % self._Group(self.term.source_port))

      # destination port
      if self.term.destination_port:
        config.Append('destination-port %s' %
                      self._Group(self.term.destination_port))

      # append any options belonging in the from {} section
      for next_str in from_str:
        config.Append(next_str)

      # packet length
      if self.term.packet_length:
        config.Append('packet-length %s;' % self.term.packet_length)

      # fragment offset
      if self.term.fragment_offset:
        config.Append('fragment-offset %s;' % self.term.fragment_offset)

      # icmp-types
      icmp_types = ['']
      if self.term.icmp_type:
        icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type,
                                             self.term.protocol, self.term_type)
      if icmp_types != ['']:
        config.Append('icmp-type %s' % self._Group(icmp_types))
      if self.term.icmp_code:
        config.Append('icmp-code %s' % self._Group(self.term.icmp_code))
      if self.term.ether_type:
        config.Append('ether-type %s' %
                      self._Group(self.term.ether_type))

      if self.term.traffic_type:
        config.Append('traffic-type %s' %
                      self._Group(self.term.traffic_type))

      if self.term.precedence:
        # precedence may be a single integer, or a space separated list
        policy_precedences = set()
        # precedence values may only be 0 through 7
        for precedence in self.term.precedence:
          if int(precedence) in range(0, 8):
            policy_precedences.add(precedence)
          else:
            raise PrecedenceError('Precedence value %s is out of bounds in %s' %
                                  (precedence, self.term.name))
        config.Append('precedence %s' % self._Group(sorted(policy_precedences)))

      # DSCP Match
      if self.term.dscp_match:
        if self.term_type == 'inet6':
          config.Append('traffic-class [ %s ];' % (
              ' '.join(self.term.dscp_match)))
        else:
          config.Append('dscp [ %s ];' % ' '.join(self.term.dscp_match))

      # DSCP Except
      if self.term.dscp_except:
        if self.term_type == 'inet6':
          config.Append('traffic-class-except [ %s ];' % (
              ' '.join(self.term.dscp_except)))
        else:
          config.Append('dscp-except [ %s ];' % ' '.join(self.term.dscp_except))

      if self.term.hop_limit:
        # Only generate a hop-limit if inet6; inet4 has no hop-limit.
        if self.term_type == 'inet6':
          config.Append('hop-limit %s;' % (self.term.hop_limit))

      # flexible-match
      if self.term.flexible_match_range:
        config.Append('flexible-match-range {')
        for fm_opt in self.term.flexible_match_range:
          config.Append('%s %s;' % (fm_opt[0], fm_opt[1]))
        config.Append('}')

      config.Append('}')  # end from { ... }

    ####
    # ACTIONS go below here
    ####

    # If the action is only one line, include it in the same line as "then "
    # statement.
    # For example, if the action is only accept, it should be:
    # "then accept;" rather than:
    # "then {
    #     accept;
    # }"
    #
    unique_actions = set(self.extra_actions)
    if not self.term.routing_instance:
      unique_actions.update(self.term.action)
    if len(unique_actions) <= 1:
      for action in [self.term.logging, self.term.routing_instance,
                     self.term.counter, self.term.policer, self.term.qos,
                     self.term.loss_priority, self.term.dscp_set,
                     self.term.next_ip, self.term.traffic_class_count]:
        if action:
          try:
            unique_actions.update(action)
          except TypeError:
            unique_actions.add(action)
          if len(unique_actions) > 1:
            break

    if len(unique_actions) == 1:
      # b/21795531: Juniper device treats a set of IPv4 actions differently
      # than any other actions.
      # For example, if the term is in IPv4 and the action is only discard,
      # it should be:
      # "then {
      #     discard;
      # }" rather than:
      # "then discard;"
      current_action = self.ACTIONS.get(unique_actions.pop(), 'next_ip')
      if (self.term_type == 'inet' and
          current_action in ['discard', 'reject', 'reject tcp-reset']
         ) or (self.term_type == 'inet6' and current_action in
               ['reject', 'reject tcp-reset']):
        config.Append('then {')
        config.Append('%s;' % current_action)
        config.Append('}')
      elif current_action == 'next_ip':
        self.NextIpCheck(self.term.next_ip, self.term.name)
        config.Append('then {')
        if self.term.next_ip[0].version == 4:
          config.Append('next-ip %s;' % str(self.term.next_ip[0]))
        else:
          config.Append('next-ip6 %s;' % str(self.term.next_ip[0]))
        config.Append('}')
      else:
        config.Append('then %s;' % current_action)
    elif len(unique_actions) > 1:
      config.Append('then {')
      # logging
      if self.term.logging:
        for log_target in self.term.logging:
          if str(log_target) == 'local':
            config.Append('log;')
          else:
            config.Append('syslog;')

      if self.term.routing_instance:
        config.Append('routing-instance %s;' % self.term.routing_instance)

      if self.term.counter:
        config.Append('count %s;' % self.term.counter)

      if self.term.traffic_class_count:
        config.Append('traffic-class-count %s;' % self.term.traffic_class_count)

      oid_length = 128
      if self.term.policer:
        config.Append('policer %s;' % self.term.policer)
        if len(self.term.policer) > oid_length:
          logging.warn('WARNING: %s is longer than %d bytes. Due to a limitation'
                       ' in JUNOS, OIDs longer than %dB can cause SNMP '
                       'timeout issues.',
                       self.term.policer, oid_length, oid_length)

      if self.term.qos:
        config.Append('forwarding-class %s;' % self.term.qos)

      if self.term.loss_priority:
        config.Append('loss-priority %s;' % self.term.loss_priority)
      if self.term.next_ip:
        self.NextIpCheck(self.term.next_ip, self.term.name)
        if self.term.next_ip[0].version == 4:
          config.Append('next-ip %s;' % str(self.term.next_ip[0]))
        else:
          config.Append('next-ip6 %s;' % str(self.term.next_ip[0]))
      for action in self.extra_actions:
        config.Append(action + ';')

      # If there is a routing-instance defined, skip reject/accept/etc actions.
      if not self.term.routing_instance:
        for action in self.term.action:
          config.Append(self.ACTIONS.get(action) + ';')

      # DSCP SET
      if self.term.dscp_set:
        if self.term_type == 'inet6':
          config.Append('traffic-class %s;' % self.term.dscp_set)
        else:
          config.Append('dscp %s;' % self.term.dscp_set)

      config.Append('}')  # end then{...}

    config.Append('}')  # end term accept-foo-to-bar { ... }

    return str(config)
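
The single-action branch above encodes the Juniper quirk called out in the b/21795531 comment: a lone discard/reject in an inet term (or reject in inet6) must be emitted in block form rather than as an inline then statement. A small standalone sketch of just that branching rule (render_then is hypothetical, not the generator's API):

# Sketch of the inline-vs-block 'then' decision above; render_then is
# hypothetical, only the branching rule is taken from the generator.
def render_then(term_type, action):
  block_only = {
      'inet': {'discard', 'reject', 'reject tcp-reset'},
      'inet6': {'reject', 'reject tcp-reset'},
  }
  if action in block_only.get(term_type, set()):
    return ['then {', '    %s;' % action, '}']
  return ['then %s;' % action]


print('\n'.join(render_then('inet', 'discard')))  # block form
print('\n'.join(render_then('inet', 'accept')))   # inline "then accept;"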
Esempio n. 41
0
File: parser.py Progetto: vruge/docs
def _generate_signature(func, reverse_index):
  """Given a function, returns a list of strings representing its args.

  This function produces a list of strings representing the arguments to a
  python function. It uses tf_inspect.getfullargspec, which does not
  generalize well to Python 3.x, where *args and **kwargs handling is more
  flexible. This is not a problem in TF, since it must remain compatible
  with Python 2.7 anyway.

  This function uses `__name__` for callables if it is available. This can lead
  to poor results for functools.partial and other callable objects.

  The returned string is Python code, so if it is included in a Markdown
  document, it should be typeset as code (using backticks), or escaped.

  Args:
    func: A function, method, or functools.partial to extract the signature for.
    reverse_index: A map from object ids to canonical full names to use.

  Returns:
    A list of strings representing the argument signature of `func` as python
    code.
  """

  args_list = []

  argspec = tf_inspect.getfullargspec(func)
  first_arg_with_default = (
      len(argspec.args or []) - len(argspec.defaults or []))

  # Python documentation skips `self` when printing method signatures.
  # Note we cannot test for ismethod here since unbound methods do not register
  # as methods (in Python 3).
  first_arg = 1 if 'self' in argspec.args[:1] else 0

  # Add all args without defaults.
  for arg in argspec.args[first_arg:first_arg_with_default]:
    args_list.append(arg)

  # Add all args with defaults.
  if argspec.defaults:
    try:
      source = _remove_first_line_indent(tf_inspect.getsource(func))
      func_ast = ast.parse(source)
      ast_defaults = func_ast.body[0].args.defaults
    except IOError:  # If this is a builtin, getsource fails with IOError
      # If we cannot get the source, assume the AST would be equal to the repr
      # of the defaults.
      ast_defaults = [None] * len(argspec.defaults)

    for arg, default, ast_default in zip(
        argspec.args[first_arg_with_default:], argspec.defaults, ast_defaults):
      if id(default) in reverse_index:
        default_text = reverse_index[id(default)]
      elif ast_default is not None:
        default_text = (
            astor.to_source(ast_default).rstrip('\n').replace('\t', '\\t')
            .replace('\n', '\\n').replace('"""', "'"))
        default_text = PAREN_NUMBER_RE.sub('\\1', default_text)

        if default_text != repr(default):
          # This may be an internal name. If so, handle the ones we know about.
          # TODO(wicke): This should be replaced with a lookup in the index.
          # TODO(wicke): (replace first ident with tf., check if in index)
          internal_names = {
              'ops.GraphKeys': 'tf.GraphKeys',
              '_ops.GraphKeys': 'tf.GraphKeys',
              'init_ops.zeros_initializer': 'tf.zeros_initializer',
              'init_ops.ones_initializer': 'tf.ones_initializer',
              'saver_pb2.SaverDef': 'tf.train.SaverDef',
          }
          full_name_re = '^%s(.%s)+' % (IDENTIFIER_RE, IDENTIFIER_RE)
          match = re.match(full_name_re, default_text)
          if match:
            lookup_text = default_text
            for internal_name, public_name in six.iteritems(internal_names):
              if match.group(0).startswith(internal_name):
                lookup_text = public_name + default_text[len(internal_name):]
                break
            if default_text is lookup_text:
              logging.warn(
                  'WARNING: Using default arg, failed lookup: %s, repr: %r',
                  default_text, default)
            else:
              default_text = lookup_text
      else:
        default_text = repr(default)

      args_list.append('%s=%s' % (arg, default_text))

  # Add *args and **kwargs.
  if argspec.varargs:
    args_list.append('*' + argspec.varargs)
  if argspec.varkw:
    args_list.append('**' + argspec.varkw)

  return args_list
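
The core of the function above is the split between positional and defaulted arguments, computed as len(args) - len(defaults). The same arithmetic can be illustrated with the standard library's inspect module (a simplified sketch; the real code additionally resolves defaults through reverse_index and the source AST):

import inspect


def simple_signature(func):
  """Sketch of the args/defaults split used above, via stdlib inspect."""
  spec = inspect.getfullargspec(func)
  args = spec.args or []
  defaults = spec.defaults or ()
  first_with_default = len(args) - len(defaults)
  out = list(args[:first_with_default])
  out.extend('%s=%r' % (name, value)
             for name, value in zip(args[first_with_default:], defaults))
  if spec.varargs:
    out.append('*' + spec.varargs)
  if spec.varkw:
    out.append('**' + spec.varkw)
  return out


def example(a, b, c=3, d='x', *rest, **options):
  return a, b, c, d, rest, options


print(simple_signature(example))
# ['a', 'b', 'c=3', "d='x'", '*rest', '**options']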
Esempio n. 42
0
    def _TranslatePolicy(self, pol, exp_info):
        """Translate a policy from objects into strings."""
        current_date = datetime.datetime.utcnow().date()
        exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

        default_action = None
        good_default_actions = ['ACCEPT', 'DROP']
        good_afs = ['inet', 'inet6']
        all_protocols_stateful = True
        self.verbose = True

        for header, terms in pol.filters:
            filter_type = None
            if self._PLATFORM not in header.platforms:
                continue

            self.filter_options = header.FilterOptions(self._PLATFORM)[1:]
            filter_name = header.FilterName(self._PLATFORM)

            self._WarnIfCustomTarget(filter_name)

            # ensure all options after the filter name are expected
            for opt in self.filter_options:
                if opt not in good_default_actions + good_afs + self._GOOD_OPTIONS:
                    raise UnsupportedTargetOptionError(
                        '%s %s %s %s' %
                        ('\nUnsupported option found in', self._PLATFORM,
                         'target definition:', opt))

            # disable stateful?
            if 'nostate' in self.filter_options:
                all_protocols_stateful = False
            if 'noverbose' in self.filter_options:
                self.verbose = False

            # Check for matching af
            for address_family in good_afs:
                if address_family in self.filter_options:
                    # should not specify more than one AF in options
                    if filter_type is not None:
                        raise UnsupportedFilterError(
                            '%s %s %s %s' %
                            ('\nMay only specify one of', good_afs,
                             'in filter options:', self.filter_options))
                    filter_type = address_family
            if filter_type is None:
                filter_type = 'inet'

            if self._PLATFORM == 'iptables' and filter_name == 'FORWARD':
                default_action = 'DROP'

            # does this policy override the default filter actions?
            for next_target in header.target:
                if next_target.platform == self._PLATFORM:
                    if len(next_target.options) > 1:
                        for arg in next_target.options:
                            if arg in good_default_actions:
                                default_action = arg
            if default_action and default_action not in good_default_actions:
                raise UnsupportedDefaultActionError(
                    '%s %s %s %s %s' %
                    ('\nOnly', ', '.join(good_default_actions),
                     'default filter action allowed;', default_action,
                     'used.'))

            # add the terms
            new_terms = []
            term_names = set()
            for term in terms:
                term.name = self.FixTermLength(
                    term.name, 'abbreviateterms' in self.filter_options,
                    'truncateterms' in self.filter_options)
                if term.name in term_names:
                    raise aclgenerator.DuplicateTermError(
                        'You have a duplicate term: %s' % term.name)
                term_names.add(term.name)
                if not term.logging and term.log_limit:
                    raise LimitButNoLogError(
                        'Term %s: Cannot use log-limit without logging' %
                        term.name)

                term = self.FixHighPorts(
                    term,
                    af=filter_type,
                    all_protocols_stateful=all_protocols_stateful)
                if not term:
                    continue

                if term.expiration:
                    if term.expiration <= exp_info_date:
                        logging.info(
                            'INFO: Term %s in policy %s expires '
                            'in less than two weeks.', term.name, filter_name)
                    if term.expiration <= current_date:
                        logging.warn(
                            'WARNING: Term %s in policy %s is expired and '
                            'will not be rendered.', term.name, filter_name)
                        continue

                new_terms.append(
                    self._TERM(term, filter_name, all_protocols_stateful,
                               default_action, filter_type, self.verbose))

            self.iptables_policies.append(
                (header, filter_name, filter_type, default_action, new_terms))
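
The validation above whitelists every token after the filter name against the allowed default actions, address families, and platform options, and tolerates at most one address family. A compact standalone sketch of that check (the option lists here are illustrative, not the generator's actual _GOOD_OPTIONS):

# Sketch of the target-option validation above; option lists are illustrative.
GOOD_DEFAULT_ACTIONS = ['ACCEPT', 'DROP']
GOOD_AFS = ['inet', 'inet6']
GOOD_OPTIONS = ['nostate', 'noverbose', 'abbreviateterms', 'truncateterms']


def validate_filter_options(options):
  allowed = set(GOOD_DEFAULT_ACTIONS + GOOD_AFS + GOOD_OPTIONS)
  for opt in options:
    if opt not in allowed:
      raise ValueError('unsupported option in target definition: %s' % opt)
  # At most one address family may be named; default to inet.
  afs = [opt for opt in options if opt in GOOD_AFS]
  if len(afs) > 1:
    raise ValueError('may only specify one of %s' % GOOD_AFS)
  return afs[0] if afs else 'inet'


print(validate_filter_options(['inet6', 'nostate']))  # 'inet6'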
Esempio n. 43
0
  def _TranslatePolicy(self, pol, exp_info):
    """Translate a policy from objects into strings."""
    self.windows_policies = []
    current_date = datetime.datetime.utcnow().date()
    exp_info_date = current_date + datetime.timedelta(weeks=exp_info)

    default_action = None
    good_default_actions = ['permit', 'block']
    good_options = []

    for header, terms in pol.filters:
      filter_type = None
      if self._PLATFORM not in header.platforms:
        continue

      filter_options = header.FilterOptions(self._PLATFORM)[1:]
      filter_name = header.FilterName(self._PLATFORM)

      # ensure all options after the filter name are expected
      for opt in filter_options:
        if opt not in good_default_actions + self._GOOD_AFS + good_options:
          raise aclgenerator.UnsupportedTargetOption('%s %s %s %s' % (
              '\nUnsupported option found in', self._PLATFORM,
              'target definition:', opt))

      # Check for matching af
      for address_family in self._GOOD_AFS:
        if address_family in filter_options:
          # should not specify more than one AF in options
          if filter_type is not None:
            raise aclgenerator.UnsupportedFilterError('%s %s %s %s' % (
                '\nMay only specify one of', self._GOOD_AFS,
                'in filter options:', filter_options))
          filter_type = address_family
      if filter_type is None:
        filter_type = 'inet'

      # does this policy override the default filter actions?
      for next_target in header.target:
        if next_target.platform == self._PLATFORM:
          if len(next_target.options) > 1:
            for arg in next_target.options:
              if arg in good_default_actions:
                default_action = arg
      if default_action and default_action not in good_default_actions:
        raise aclgenerator.UnsupportedDefaultAction('%s %s %s %s %s' % (
            '\nOnly', ', '.join(good_default_actions),
            'default filter action allowed;', default_action, 'used.'))

      # add the terms
      new_terms = []
      term_names = set()
      for term in terms:
        if term.name in term_names:
          raise aclgenerator.DuplicateTermError(
              'You have a duplicate term: %s' % term.name)
        term_names.add(term.name)

        if term.expiration:
          if term.expiration <= exp_info_date:
            logging.info('INFO: Term %s in policy %s expires '
                         'in less than two weeks.', term.name, filter_name)
          if term.expiration <= current_date:
            logging.warn('WARNING: Term %s in policy %s is expired and '
                         'will not be rendered.', term.name, filter_name)
            continue

        new_terms.append(self._TERM(term, filter_name, default_action,
                                    filter_type))

      self.windows_policies.append((header, filter_name, filter_type,
                                    default_action, new_terms))
Esempio n. 44
0
File: cell.py Progetto: yyht/lamb
    def make_rhn_column():
        init_params = collections.OrderedDict([
            ('B_c', {
                'initializer': tf.constant_initializer(forget_bias)
            }),
        ])

        if overlay_rank > 0:
            assert sparsity_ratio < 0
            # TODO(melisgl): Specify initializers for the shared matrices.
            tiled_linear_class = tiled_linear.OverlayedTiledLinear
            init_params.update(
                collections.OrderedDict([
                    ('W_x_h', {
                        'overlay_sharing_key': 'W_x_any',
                        'overlay_rank': overlay_rank
                    }),
                    ('W_x_c', {
                        'overlay_sharing_key': 'W_x_any',
                        'overlay_rank': overlay_rank
                    }),
                    ('W_x_t', {
                        'overlay_sharing_key': 'W_x_any',
                        'overlay_rank': overlay_rank
                    }),
                    ('W_s_h', {
                        'overlay_sharing_key': 'W_s_any',
                        'overlay_rank': overlay_rank
                    }),
                    ('W_s_c', {
                        'overlay_sharing_key': 'W_s_any',
                        'overlay_rank': overlay_rank
                    }),
                    ('W_s_t', {
                        'overlay_sharing_key': 'W_s_any',
                        'overlay_rank': overlay_rank
                    }),
                ]))
        elif sparsity_ratio >= 0.0:
            assert overlay_rank == -1
            tiled_linear_class = tiled_linear.SparseTiledLinear
            sparse_initializer = tf.truncated_normal_initializer(
                stddev=math.sqrt(
                    cell_init_factor / sparsity_ratio /
                    # TODO(melisgl): This is off if the input
                    # embedding size is different from the hidden
                    # size.
                    hidden_size))
            init_params.update(
                collections.OrderedDict([
                    ('W_x_.*', {
                        'sparse_indices_sharing_key': 'W_x'
                    }),
                    ('W_s_.*', {
                        'sparse_indices_sharing_key': 'W_s'
                    }),
                    ('W_x', {
                        'sparsity_ratio': sparsity_ratio,
                        'initializer': sparse_initializer
                    }),
                    ('W_s', {
                        'sparsity_ratio': sparsity_ratio,
                        'initializer': sparse_initializer
                    }),
                ]))
        else:
            tiled_linear_class = tiled_linear.TiledLinear
            init_params.update(
                collections.OrderedDict([
                    ('W_.*', {
                        'initializer': cell_initializer
                    }),
                ]))
        logging.info('Creating RHN of depth %s', num_layers)
        if layer_norm:
            logging.warn('RHN does not support layer normalization.')
        cell = TiledRHNCell(hidden_size,
                            depth=num_layers,
                            tie_gates=tie_forget_and_input_gates,
                            input_transform=input_transforms[layer_index],
                            state_transform=state_transforms[layer_index],
                            update_transform=update_transforms[layer_index],
                            tiled_linear_class=tiled_linear_class,
                            tiled_linear_var_init_params=init_params,
                            cell_clip=cell_clip if cell_clip > 0 else None,
                            activation=eval(activation_fn))  # pylint: disable=eval-used
        return cell
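
The branching above selects one of three tiled-linear implementations and asserts that the overlay and sparsity knobs are mutually exclusive. A minimal sketch of that dispatch (the returned names are stand-ins for the lamb tiled_linear classes):

# Sketch of the tiled-linear dispatch above; names stand in for
# tiled_linear.OverlayedTiledLinear / SparseTiledLinear / TiledLinear.
def pick_tiled_linear(overlay_rank, sparsity_ratio):
  if overlay_rank > 0:
    assert sparsity_ratio < 0, 'overlay and sparsity are mutually exclusive'
    return 'OverlayedTiledLinear'
  if sparsity_ratio >= 0.0:
    assert overlay_rank == -1, 'overlay and sparsity are mutually exclusive'
    return 'SparseTiledLinear'
  return 'TiledLinear'


print(pick_tiled_linear(overlay_rank=-1, sparsity_ratio=-1.0))  # TiledLinear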
Esempio n. 45
0
def main(args):
  FLAGS(args)
  if FLAGS.verbose:
    logging.basicConfig(level=logging.INFO)
  if FLAGS.debug:
    logging.basicConfig(level=logging.DEBUG)
  logging.debug('binary: %s\noptimize: %d\nbase_directory: %s\n'
                'policy_file: %s\nrendered_acl_directory: %s',
                str(sys.argv[0]),
                int(FLAGS.optimize),
                str(FLAGS.base_directory),
                str(FLAGS.policy_file),
                str(FLAGS.output_directory))

  definitions = None
  try:
    definitions = naming.Naming(FLAGS.definitions_directory)
  except naming.NoDefinitionsError:
    err_msg = 'bad definitions directory: %s' % FLAGS.definitions_directory
    logging.fatal(err_msg)
    sys.exit(1)

  # thread-safe list for storing files to write
  manager = multiprocessing.Manager()
  write_files = manager.list()

  with_errors = False
  if FLAGS.policy_file:
    # render just one file
    logging.info('rendering one file')
    RenderFile(FLAGS.policy_file, FLAGS.output_directory, definitions,
               FLAGS.exp_info, write_files)
  else:
    # render all files in parallel
    logging.info('finding policies...')
    pols = []
    pols.extend(DescendRecursively(FLAGS.base_directory, FLAGS.output_directory,
                                   definitions))

    pool = multiprocessing.Pool(processes=FLAGS.max_renderers)
    results = []
    for x in pols:
      results.append(pool.apply_async(RenderFile,
                                      args=(x.get('in_file'),
                                            x.get('out_dir'),
                                            definitions,
                                            FLAGS.exp_info,
                                            write_files)))
    pool.close()
    pool.join()

    for result in results:
      try:
        result.get()
      except (ACLParserError, ACLGeneratorError) as e:
        with_errors = True
        logging.warn('\n\nerror encountered in rendering process:\n%s\n\n', e)

  # actually write files to disk
  WriteFiles(write_files)

  if with_errors:
    logging.warn('done, with errors.')
    sys.exit(1)
  else:
    logging.info('done.')
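
The parallel-rendering block above follows the usual apply_async pattern: queue every job, close and join the pool, then call result.get() on each handle so worker exceptions resurface in the parent process. A generic sketch of that pattern (render_one and the job names are placeholders, not aclgen's RenderFile API):

import multiprocessing


def render_one(name):
  """Placeholder worker; the real code calls RenderFile."""
  if name.endswith('.bad'):
    raise ValueError('cannot render %s' % name)
  return 'rendered %s' % name


def render_all(names, processes=4):
  pool = multiprocessing.Pool(processes=processes)
  results = [pool.apply_async(render_one, args=(n,)) for n in names]
  pool.close()
  pool.join()
  errors = []
  for result in results:
    try:
      result.get()  # re-raises any exception from the worker
    except ValueError as e:
      errors.append(e)
  return errors


if __name__ == '__main__':
  print(render_all(['a.pol', 'b.bad', 'c.pol']))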
Esempio n. 46
0
    def __str__(self):
        # Verify platform specific terms. Skip whole term if platform does not
        # match.
        if self.term.platform:
            if self._PLATFORM not in self.term.platform:
                return ''
        if self.term.platform_exclude:
            if self._PLATFORM in self.term.platform_exclude:
                return ''

        ret_str = []

        # Don't render icmpv6 protocol terms under inet, or icmp under inet6
        if ((self.af == 'inet6' and 'icmp' in self.term.protocol)
                or (self.af == 'inet' and 'icmpv6' in self.term.protocol)):
            logging.debug(
                self.NO_AF_LOG_PROTO.substitute(term=self.term.name,
                                                proto=', '.join(
                                                    self.term.protocol),
                                                af=self.af))
            return ''

        # Term verbatim output - this will skip over most normal term
        # creation code by returning early. Warnings provided in policy.py
        if self.term.verbatim:
            for next_verbatim in self.term.verbatim:
                if next_verbatim[0] == self._PLATFORM:
                    ret_str.append(str(next_verbatim[1]))
            return '\n'.join(ret_str)

        # Create a new term
        self._SetDefaultAction()
        if self._TERM_FORMAT:
            ret_str.append(self._TERM_FORMAT.substitute(term=self.term_name))

        if self._PREJUMP_FORMAT:
            ret_str.append(
                self._PREJUMP_FORMAT.substitute(filter=self.filter,
                                                term=self.term_name))

        if self.verbose:
            if self.term.owner:
                self.term.comment.append('Owner: %s' % self.term.owner)
            # reformat long comments, if needed
            #
            # iptables allows individual comments up to 256 chars.
            # But our generator will limit a single comment line to < 120, using:
            # max = 119 - 27 (static chars in comment command) - [length of term name]
            comment_max_width = 92 - len(self.term_name)
            if comment_max_width < 40:
                comment_max_width = 40
            comments = aclgenerator.WrapWords(self.term.comment,
                                              comment_max_width)
            # append comments to output
            if comments and comments[0]:
                for line in comments:
                    if not line:
                        continue  # iptables-restore does not like 0-length comments.
                    # term comments
                    # Strip out quotes as iptables can't have nested quotes
                    ret_str.append(
                        self._COMMENT_FORMAT.substitute(
                            filter=self.filter,
                            term=self.term_name,
                            comment=str(line).replace('\"', '')))

        # Unsupported configuration; in the case of 'accept' or 'next', we
        # skip the rule.  In other cases, we blow up (raise an exception)
        # to ensure that this is not considered valid configuration.
        if self.term.source_prefix or self.term.destination_prefix:
            if str(self.term.action[0]) not in set(['accept', 'next']):
                raise UnsupportedFilterError(
                    '%s %s %s %s %s %s %s %s' %
                    ('\nTerm', self.term.name, 'has action',
                     str(self.term.action[0]),
                     'with source_prefix or destination_prefix,',
                     ' which is unsupported in', self._PLATFORM,
                     'iptables output.'))
            return ('# skipped %s due to source or destination prefix rule' %
                    self.term.name)

        # protocol
        if self.term.protocol:
            protocol = self.term.protocol
        else:
            protocol = ['all']
        if 'hopopt' in protocol and self.af == 'inet':
            logging.warn('Term %s is using hopopt in IPv4 context.',
                         self.term_name)
            return ''

        (term_saddr, exclude_saddr, term_daddr,
         exclude_daddr) = self._CalculateAddresses(
             self.term.source_address, self.term.source_address_exclude,
             self.term.destination_address,
             self.term.destination_address_exclude)
        if not term_saddr:
            logging.debug(
                self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                               direction='source',
                                               af=self.af))
            return ''
        if not term_daddr:
            logging.debug(
                self.NO_AF_LOG_ADDR.substitute(term=self.term.name,
                                               direction='destination',
                                               af=self.af))
            return ''

        # ports
        source_port = []
        destination_port = []
        if self.term.source_port:
            source_port = self.term.source_port
        if self.term.destination_port:
            destination_port = self.term.destination_port

        # icmp-types
        icmp_types = ['']
        if self.term.icmp_type:
            icmp_types = self.NormalizeIcmpTypes(self.term.icmp_type, protocol,
                                                 self.af)

        source_interface = ''
        if self.term.source_interface:
            source_interface = self.term.source_interface

        destination_interface = ''
        if self.term.destination_interface:
            destination_interface = self.term.destination_interface

        log_hits = False
        if self.term.logging:
            # Iptables sends logs to the host's configured syslog
            log_hits = True

        # options
        tcp_flags = []
        tcp_track_options = []
        for next_opt in [str(x) for x in self.term.option]:
            #
            # Sanity checking and high-ports are added as appropriate in
            # pre-processing that is done in __str__ within class Iptables.
            # Option established will add destination port high-ports if protocol
            # contains only tcp, udp or both.  This is done earlier in class Iptables.
            #
            if ((next_opt.find('established') == 0
                 or next_opt.find('tcp-established') == 0)
                    and 'ESTABLISHED' not in [x.strip()
                                              for x in self.options]):
                if next_opt.find(
                        'tcp-established') == 0 and protocol != ['tcp']:
                    raise TcpEstablishedError('%s %s %s' % (
                        '\noption tcp-established can only be applied for proto tcp.',
                        '\nError in term:', self.term.name))

                if self.trackstate:
                    # Use nf_conntrack to track state -- works with any proto
                    self.options.append('-m state --state ESTABLISHED,RELATED')
                elif protocol == ['tcp']:
                    # Simple established-only rule for TCP: Must have ACK field
                    # (SYN/ACK or subsequent ACK), or RST and no other flags.
                    tcp_track_options = [(['ACK'], ['ACK']),
                                         (['SYN', 'FIN', 'ACK',
                                           'RST'], ['RST'])]

            # Iterate through flags table, and create list of tcp-flags to append
            for next_flag in self._TCP_FLAGS_TABLE:
                if next_opt.find(next_flag) == 0:
                    tcp_flags.append(self._TCP_FLAGS_TABLE.get(next_flag))
            if next_opt in self._KNOWN_OPTIONS_MATCHERS:
                self.options.append(self._KNOWN_OPTIONS_MATCHERS[next_opt])
        if self.term.packet_length:
            # Policy format is "#-#", but iptables format is "#:#"
            self.options.append('-m length --length %s' %
                                self.term.packet_length.replace('-', ':'))
        if self.term.fragment_offset:
            self.options.append('-m u32 --u32 4&0x1FFF=%s' %
                                self.term.fragment_offset.replace('-', ':'))
        icmp_code = ['']
        if self.term.icmp_code:
            icmp_code = self.term.icmp_code

        for saddr in exclude_saddr:
            ret_str.extend(
                self._FormatPart('', saddr, '', '', '', '', '', '', '', '', '',
                                 '', '', self._action_table.get('next')))
        for daddr in exclude_daddr:
            ret_str.extend(
                self._FormatPart('', '', '', daddr, '', '', '', '', '', '', '',
                                 '', '', self._action_table.get('next')))

        for saddr in term_saddr:
            for daddr in term_daddr:
                for icmp in icmp_types:
                    for code in icmp_code:
                        for proto in protocol:
                            for tcp_matcher in tcp_track_options or (([],
                                                                      []), ):
                                ret_str.extend(
                                    self._FormatPart(
                                        str(proto), saddr, source_port, daddr,
                                        destination_port, self.options,
                                        tcp_flags, icmp, code, tcp_matcher,
                                        source_interface,
                                        destination_interface, log_hits,
                                        self._action_table.get(
                                            str(self.term.action[0]))))

        if self._POSTJUMP_FORMAT:
            ret_str.append(
                self._POSTJUMP_FORMAT.substitute(filter=self.filter,
                                                 term=self.term_name))

        return '\n'.join(str(v) for v in ret_str if v != '')
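
The nested loops near the end expand a single term into the Cartesian product of source addresses, destination addresses, ICMP types and codes, protocols, and TCP state matchers, emitting one iptables rule per combination. A back-of-the-envelope sketch of the resulting rule count (the values are purely illustrative):

# Illustrative rule-count estimate for the Cartesian expansion above.
import itertools

saddrs = ['10.0.0.0/8', '172.16.0.0/12']
daddrs = ['192.0.2.1/32']
icmp_types = ['']          # non-ICMP term
icmp_codes = ['']
protocols = ['tcp', 'udp']
tcp_matchers = [([], [])]  # no extra tcp-flag tracking

combos = list(itertools.product(saddrs, daddrs, icmp_types,
                                icmp_codes, protocols, tcp_matchers))
print(len(combos))  # 2 * 1 * 1 * 1 * 2 * 1 = 4 rules for this term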