Example #1
    def to_csv(self, file_name, fieldnames=None, batch_size=200, filters=None, iterator=None):
        """Write all records to a CSV file.

        Args:
            file_name: the path to the CSV file to create.
            fieldnames: the list of fields to save. If not set, it will get
                them from the first record and sort them alphabetically.
            batch_size: the size of the batch of records to download.
            filters: optional filters, to avoid requesting the whole resource.
            iterator: a wrapper around the iterator on records, so that you can
                modify the records or just keep track of progress.
        """
        records = self.records(batch_size=batch_size, filters=filters)
        need_utf8_encode = six.PY2
        if not fieldnames:
            first = records.peek_first()
            keys = set(_strip_bom(k) for k in first.keys())
            fieldnames = sorted(keys - set(['_id']))
        if need_utf8_encode:
            fieldnames = [six.ensure_str(f, encoding='utf-8') for f in fieldnames]
        with open(file_name, 'wt') as csvfile:
            csv_writer = csv.DictWriter(
                csvfile, fieldnames, extrasaction='ignore')
            csv_writer.writeheader()
            for record in iterator(records) if iterator else records:
                record = {_strip_bom(k): v for k, v in record.items()}
                if need_utf8_encode:
                    record = {
                        six.ensure_str(k, encoding='utf-8'):
                        six.ensure_str(six.text_type(v), encoding='utf-8')
                        for k, v in record.items()}
                csv_writer.writerow(record)
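
The `_strip_bom` helper used above is not shown in this excerpt; a minimal sketch of what such a helper might look like (an assumption, not the original implementation):

import codecs

BOM = codecs.BOM_UTF8.decode('utf-8')  # u'\ufeff'

def _strip_bom(key):
    # Hypothetical helper: drop a leading UTF-8 BOM that some exports
    # prepend to the first field name of a record.
    return key[len(BOM):] if key.startswith(BOM) else key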
Example #2
 def test_ensure_str(self):
     converted_unicode = six.ensure_str(self.UNICODE_EMOJI, encoding='utf-8', errors='strict')
     converted_binary = six.ensure_str(self.BINARY_EMOJI, encoding="utf-8", errors='strict')
     if six.PY2:
         # PY2: unicode -> str
         assert converted_unicode == self.BINARY_EMOJI and isinstance(converted_unicode, str)
         # PY2: str -> str
         assert converted_binary == self.BINARY_EMOJI and isinstance(converted_binary, str)
     else:
         # PY3: str -> str
         assert converted_unicode == self.UNICODE_EMOJI and isinstance(converted_unicode, str)
         # PY3: bytes -> str
         assert converted_binary == self.UNICODE_EMOJI and isinstance(converted_binary, str)
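
A minimal standalone sketch of the conversion asserted above: `ensure_str` always returns the native `str` type, decoding bytes on Python 3 and encoding text on Python 2 (assumes only that `six` is installed):

import six

assert isinstance(six.ensure_str(u"caf\u00e9"), str)    # text -> native str
assert isinstance(six.ensure_str(b"caf\xc3\xa9"), str)  # UTF-8 bytes -> native str
assert six.ensure_str(b"caf\xc3\xa9") == six.ensure_str(u"caf\u00e9")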
Example #3
  def _MatchesExcludedPattern(self, blr):
    """Checks bucket listing reference against patterns to exclude.

    Args:
      blr: BucketListingRef to check.

    Returns:
      True if reference matches a pattern and should be excluded.
    """
    if self.exclude_patterns:
      tomatch = six.ensure_str(blr.url_string)
      for pattern in self.exclude_patterns:
        if fnmatch.fnmatch(tomatch, six.ensure_str(pattern)):
          return True
    return False
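
A small usage sketch of the same exclusion check; the bucket URL and pattern below are made up for illustration:

import fnmatch
import six

url = six.ensure_str(b"gs://my-bucket/tmp/cache.txt")
pattern = six.ensure_str(u"gs://my-bucket/tmp/*")
print(fnmatch.fnmatch(url, pattern))  # True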
Example #4
def _ProcessUnknownEnums(message, encoded_message):
    """Add unknown enum values from encoded_message as unknown fields.

    ProtoRPC diverges from the usual protocol buffer behavior here and
    doesn't allow unknown fields. Throwing on unknown fields makes it
    impossible to let servers add new enum values and stay compatible
    with older clients, which isn't reasonable for us. We simply store
    unrecognized enum values as unknown fields, and all is well.

    Args:
      message: Proto message we've decoded thus far.
      encoded_message: JSON string we're decoding.

    Returns:
      message, with any unknown enums stored as unrecognized fields.
    """
    if not encoded_message:
        return message
    decoded_message = json.loads(six.ensure_str(encoded_message))
    for field in message.all_fields():
        if (isinstance(field, messages.EnumField) and
                field.name in decoded_message and
                message.get_assigned_value(field.name) is None):
            message.set_unrecognized_field(
                field.name, decoded_message[field.name], messages.Variant.ENUM)
    return message
Example #5
def _ProcessUnknownMessages(message, encoded_message):
    """Store any remaining unknown fields as strings.

    ProtoRPC currently ignores unknown values for which no type can be
    determined (and logs a "No variant found" message). For the purposes
    of reserializing, this is quite harmful (since it throws away
    information). Here we simply add those as unknown fields of type
    string (so that they can easily be reserialized).

    Args:
      message: Proto message we've decoded thus far.
      encoded_message: JSON string we're decoding.

    Returns:
      message, with any remaining unrecognized fields saved.
    """
    if not encoded_message:
        return message
    decoded_message = json.loads(six.ensure_str(encoded_message))
    message_fields = [x.name for x in message.all_fields()] + list(
        message.all_unrecognized_fields())
    missing_fields = [x for x in decoded_message.keys()
                      if x not in message_fields]
    for field_name in missing_fields:
        message.set_unrecognized_field(field_name, decoded_message[field_name],
                                       messages.Variant.STRING)
    return message
Example #6
    def decode_message(self, message_type, encoded_message):
        """Merge JSON structure to Message instance.

        Args:
          message_type: Message to decode data to.
          encoded_message: JSON encoded version of message.

        Returns:
          Decoded instance of message_type.

        Raises:
          ValueError: If encoded_message is not valid JSON.
          messages.ValidationError if merged message is not initialized.
        """
        encoded_message = six.ensure_str(encoded_message)
        if not encoded_message.strip():
            return message_type()

        dictionary = json.loads(encoded_message)
        message = self.__decode_dictionary(message_type, dictionary)
        message.check_initialized()
        return message
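
The decoders above share one idiom: normalising the payload with `ensure_str` so that `json.loads` accepts it whether it arrives as bytes or text (which matters on Python 2 and on Python 3 versions before 3.6). A minimal sketch with a made-up payload:

import json
import six

payload = b'{"status": "ok"}'  # may arrive as bytes from the wire
data = json.loads(six.ensure_str(payload))
assert data["status"] == "ok"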
Example #7
def _ReadJSONKeystore(ks_contents, passwd=None):
  """Read the client email and private key from a JSON keystore.

  Assumes this keystore was downloaded from the Cloud Platform Console.
  By default, JSON keystore private keys from the Cloud Platform Console
  aren't encrypted so the passwd is optional as load_privatekey will
  prompt for the PEM passphrase if the key is encrypted.

  Arguments:
    ks_contents: JSON formatted string representing the keystore contents. Must
                 be a valid JSON string and contain the fields 'private_key'
                 and 'client_email'.
    passwd: Passphrase for encrypted private keys.

  Returns:
    key: Parsed private key from the keystore.
    client_email: The email address for the service account.

  Raises:
    ValueError: If unable to parse ks_contents or keystore is missing
                required fields.
  """
  # ensuring that json.loads receives unicode in Python 3 and bytes in Python 2
  # Prior to Python 3.6, there was no automatic conversion and str was required.
  ks = json.loads(six.ensure_str(ks_contents))

  if 'client_email' not in ks or 'private_key' not in ks:
    raise ValueError('JSON keystore doesn\'t contain required fields')

  client_email = ks['client_email']
  if passwd:
    key = load_privatekey(FILETYPE_PEM, ks['private_key'], passwd)
  else:
    key = load_privatekey(FILETYPE_PEM, ks['private_key'])

  return key, client_email
Example #8
def get_all_recommended_series_from_cache(indexers):
    """
    Retrieve all recommended show objects from the dogpile cache for a specific indexer or a number of indexers.

    For example: `get_all_recommended_series_from_cache(['imdb', 'anidb'])` will return all recommended show objects, for the
    indexers imdb and anidb.

    :param indexers: indexer or list of indexers. Indexers need to be passed as a string. For example: 'imdb', 'anidb' or 'trakt'.
    :return: List of recommended show objects.
    """
    indexers = ensure_list(indexers)
    all_series = []
    for indexer in indexers:
        index = recommended_series_cache.get(ensure_str(indexer))
        if not index:
            continue

        for index_item in index:
            key = '{indexer}_{series_id}'.format(indexer=indexer, series_id=index_item)
            series = recommended_series_cache.get(key)
            if series:
                all_series.append(series)

    return all_series
Example #9
def _convert_tf1_model(flags):
  """Calls function to convert the TensorFlow 1.X model into a TFLite model.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Create converter.
  converter = _get_toco_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)

  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops
  if flags.custom_opdefs:
    converter._custom_opdefs = _parse_array(flags.custom_opdefs)  # pylint: disable=protected-access
  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_spec.supported_ops = set()
    for option in six.ensure_str(flags.target_ops).split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_spec.supported_ops.add(lite.OpsSet(option))

  if flags.post_training_quantize:
    converter.optimizations = [lite.Optimize.DEFAULT]
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
      converter.inference_type = lite_constants.FLOAT

  if flags.quantize_to_float16:
    converter.target_spec.supported_types = [lite.constants.FLOAT16]
    if not flags.post_training_quantize:
      print("--quantize_to_float16 will only take effect with the "
            "--post_training_quantize flag enabled.")

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    converter.dump_graphviz_video = flags.dump_graphviz_video
  if flags.conversion_summary_dir:
    converter.conversion_summary_dir = flags.conversion_summary_dir

  # TODO(b/145312675): Enable the new converter by default. This requires
  # adding a new command line argument like `experimental_legacy_converter`.
  converter.experimental_new_converter = flags.experimental_new_converter

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(six.ensure_binary(output_data))
Example #10
    def parseIsoggTable(self):
        'parses ISOGG table'

        # input reader
        utils.checkFileExistence(self.config.isoggFN, 'Isogg')
        isoggInFile = open(self.config.isoggFN, 'r')
        isoggReader = csv.reader(isoggInFile, delimiter='\t')
        next(isoggReader)  # ignore header

        # output file handles
        if self.config.suppressOutputAndLog:
            isoggOutFile = None
            isoggDropOutFile = None
        else:
            isoggOutFile = open(self.config.cleanedIsoggFN, 'w')
            isoggDropOutFile = open(self.config.droppedIsoggFN, 'w')

        droppedMarkerList = list()

        for lineList in isoggReader:
            self.isoggCountsDict['read'] += 1

            # clean up data row and extract values
            lineList = [element.strip() for element in lineList]
            if lineList[1] == '':  # when present, remove extra tab after snp name
                del lineList[1]
            if len(lineList) != 6:
                self.isoggCountsDict['badLines'] += 1
                continue
            name, haplogroup, _, _, position, mutation = lineList

            # apply corrections
            if name in self.isoggCorrectionDict:
                haplogroup, position, mutation = self.isoggCorrectionDict[name]
                self.numSNPsCorrected += 1

            # identify markers to drop
            recordIsBad, markerIsOkToRepresentNode = (self.checkIsoggRecord(
                name, haplogroup, position, mutation))
            if recordIsBad:
                self.isoggCountsDict['dropped'] += 1
                if isoggDropOutFile:
                    isogg_drop_output = six.ensure_str(
                        '%-10s %-25s %8s %s\n' %
                        (six.ensure_text(name), haplogroup, position,
                         mutation))
                    isoggDropOutFile.write(isogg_drop_output)
                if markerIsOkToRepresentNode:
                    droppedMarkerList.append(DroppedMarker(name, haplogroup))
                continue

            # process retained SNPs
            self.isoggCountsDict['retained'] += 1
            position = int(position)
            if isoggOutFile:
                isoggOutFile.write('%-10s %-25s %8d %s\n' %
                                   (name, haplogroup, position, mutation))
            self.constructSNP(name, haplogroup, position, mutation)

        self.addDroppedMarkersToNodes(droppedMarkerList)
        utils.closeFiles([isoggInFile, isoggOutFile, isoggDropOutFile])
Example #11
def current_branch_name():
    # type: () -> str
    branch_name = subprocess.check_output(
        ["git", "rev-parse", "--abbrev-ref", "HEAD"])
    return six.ensure_str(branch_name).split("\n", 1)[0]
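
The same bytes-to-text step applies to any `subprocess.check_output` result; a short sketch (assumes only that `git` is on the PATH):

import subprocess
import six

version = subprocess.check_output(["git", "--version"])
print(six.ensure_str(version).strip())  # e.g. "git version 2.39.2"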
Example #12
 def getA(name, default=None):
     name = six.ensure_binary(name)
     ret = req.args.get(name)
     return [six.ensure_str(x) for x in ret] if ret else default
Example #13
def get_url(address, user, password):
    request = Request(address)
    base64string = base64.encodebytes(ensure_binary(
        '%s:%s' % (user, password))).replace(b'\n', b'')
    request.add_header("Authorization", "Basic %s" % ensure_str(base64string))
    return urlopen(request)
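
A sketch of the header construction in isolation, using `b64encode` (which, unlike `encodebytes`, does not insert newlines); the credentials are placeholders:

import base64
from six import ensure_binary, ensure_str

user, password = "alice", "s3cret"
token = base64.b64encode(ensure_binary("%s:%s" % (user, password)))
auth_header = "Basic %s" % ensure_str(token)
print(auth_header)  # Basic YWxpY2U6czNjcmV0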
Example #14
def filterName(name, encode=True):
	if name is not None:
		name = six.ensure_str(removeBadChars(six.ensure_binary(name)))
		if encode is True:
			return html_escape(name, quote=True)
	return name
Example #15
def CreateTestProcesses(parallel_tests,
                        test_index,
                        process_list,
                        process_done,
                        max_parallel_tests,
                        root_coverage_file=None):
  """Creates test processes to run tests in parallel.

  Args:
    parallel_tests: List of all parallel tests.
    test_index: List index of last created test before this function call.
    process_list: List of running subprocesses. Created processes are appended
                  to this list.
    process_done: List of booleans indicating process completion. One 'False'
                  will be added per process created.
    max_parallel_tests: Maximum number of tests to run in parallel.
    root_coverage_file: The root .coverage filename if coverage is requested.

  Returns:
    Index of last created test.
  """
  orig_test_index = test_index
  # checking to see if test was invoked from a par file (bundled archive)
  # if not, add python executable path to ensure correct version of python
  # is used for testing
  executable_prefix = [sys.executable] if not InvokedFromParFile() else []
  s3_argument = ['-s'] if tests.util.RUN_S3_TESTS else []
  multiregional_buckets = ['-b'] if tests.util.USE_MULTIREGIONAL_BUCKETS else []
  project_id_arg = []
  try:
    project_id_arg = [
        '-o', 'GSUtil:default_project_id=%s' % PopulateProjectId()
    ]
  except ProjectIdException:
    # If we don't have a project ID, unit tests should still be able to pass.
    pass

  process_create_start_time = time.time()
  last_log_time = process_create_start_time
  while (CountFalseInList(process_done) < max_parallel_tests and
         test_index < len(parallel_tests)):
    env = os.environ.copy()
    if root_coverage_file:
      env['GSUTIL_COVERAGE_OUTPUT_FILE'] = root_coverage_file
    envstr = dict()
    # constructing command list and ensuring each part is str
    cmd = [
        six.ensure_str(part) for part in list(
            executable_prefix +
            [gslib.GSUTIL_PATH] +
            project_id_arg +
            ['test'] +
            s3_argument +
            multiregional_buckets +
            ['--' + _SEQUENTIAL_ISOLATION_FLAG] +
            [parallel_tests[test_index][len('gslib.tests.test_'):]]
        )
    ]  # yapf: disable
    for k, v in six.iteritems(env):
      envstr[six.ensure_str(k)] = six.ensure_str(v)
    process_list.append(
        subprocess.Popen(cmd,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         env=envstr))
    test_index += 1
    process_done.append(False)
    if time.time() - last_log_time > 5:
      print(('Created %d new processes (total %d/%d created)' %
             (test_index - orig_test_index, len(process_list),
              len(parallel_tests))))
      last_log_time = time.time()
  if test_index == len(parallel_tests):
    print(('Test process creation finished (%d/%d created)' %
           (len(process_list), len(parallel_tests))))
  return test_index
Example #16
 def _get_variable_name(self, param_name):
     """Get the variable name from the tensor name."""
     m = re.match("^(.*):\\d+$", six.ensure_str(param_name))
     if m is not None:
         param_name = m.group(1)
     return param_name
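
A quick usage sketch of the same stripping logic on a made-up tensor name:

import re
import six

param_name = six.ensure_str(b"bert/embeddings/word_embeddings:0")
m = re.match("^(.*):\\d+$", param_name)
print(m.group(1) if m else param_name)  # bert/embeddings/word_embeddings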
Example #17
class IniConfig(object):
    VIRTUALENV_CONFIG_FILE_ENV_VAR = six.ensure_str("VIRTUALENV_CONFIG_FILE")
    STATE = {None: "failed to parse", True: "active", False: "missing"}

    section = "virtualenv"

    def __init__(self):
        config_file = os.environ.get(self.VIRTUALENV_CONFIG_FILE_ENV_VAR, None)
        self.is_env_var = config_file is not None
        self.config_file = Path(config_file) if config_file is not None else (
            default_config_dir() / "virtualenv.ini")
        self._cache = {}

        exception = None
        self.has_config_file = None
        try:
            self.has_config_file = self.config_file.exists()
        except OSError as exc:
            exception = exc
        else:
            if self.has_config_file:
                self.config_file = self.config_file.resolve()
                self.config_parser = ConfigParser.ConfigParser()
                try:
                    self._load()
                    self.has_virtualenv_section = self.config_parser.has_section(
                        self.section)
                except Exception as exc:
                    exception = exc
        if exception is not None:
            logging.error("failed to read config file %s because %r",
                          config_file, exception)

    def _load(self):
        with self.config_file.open("rt") as file_handler:
            reader = getattr(self.config_parser,
                             "read_file" if PY3 else "readfp")
            reader(file_handler)

    def get(self, key, as_type):
        cache_key = key, as_type
        if cache_key in self._cache:
            return self._cache[cache_key]
        # noinspection PyBroadException
        try:
            source = "file"
            raw_value = self.config_parser.get(self.section, key.lower())
            value = convert(raw_value, as_type, source)
            result = value, source
        except Exception:
            result = None
        self._cache[cache_key] = result
        return result

    def __bool__(self):
        return bool(self.has_config_file) and bool(self.has_virtualenv_section)

    @property
    def epilog(self):
        msg = "{}config file {} {} (change{} via env var {})"
        return msg.format(
            os.linesep,
            self.config_file,
            self.STATE[self.has_config_file],
            "d" if self.is_env_var else "",
            self.VIRTUALENV_CONFIG_FILE_ENV_VAR,
        )
Example #18
    def _output_analysed_ruleset(self,
                                 all_rulesets,
                                 rulespec,
                                 svc_desc_or_item,
                                 svc_desc,
                                 known_settings=None):
        if known_settings is None:
            known_settings = self._PARAMETERS_UNKNOWN

        def rule_url(rule):
            return watolib.folder_preserving_link([
                ('mode', 'edit_rule'),
                ('varname', varname),
                ('rule_folder', rule.folder.path()),
                ('rulenr', rule.index()),
                ('host', self._hostname),
                ('item', ensure_str(watolib.mk_repr(svc_desc_or_item)) if svc_desc_or_item else ''),
                ('service', ensure_str(watolib.mk_repr(svc_desc)) if svc_desc else ''),
            ])

        varname = rulespec.name
        valuespec = rulespec.valuespec

        url = watolib.folder_preserving_link([
            ('mode', 'edit_ruleset'),
            ('varname', varname),
            ('host', self._hostname),
            ('item', ensure_str(watolib.mk_repr(svc_desc_or_item))),
            ('service', ensure_str(watolib.mk_repr(svc_desc))),
        ])

        forms.section(html.render_a(rulespec.title, url))

        ruleset = all_rulesets.get(varname)
        setting, rules = ruleset.analyse_ruleset(self._hostname, svc_desc_or_item, svc_desc)

        html.open_table(class_="setting")
        html.open_tr()
        html.open_td(class_="reason")

        # Show reason for the determined value
        if len(rules) == 1:
            rule_folder, rule_index, rule = rules[0]
            url = rule_url(rule)
            html.a(_("Rule %d in %s") % (rule_index + 1, rule_folder.title()), href=rule_url(rule))

        elif len(rules) > 1:
            html.a("%d %s" % (len(rules), _("Rules")), href=url)

        else:
            html.i(_("Default Value"))
        html.close_td()

        # Show the resulting value or factory setting
        html.open_td(class_=["settingvalue", "used" if len(rules) > 0 else "unused"])

        if isinstance(known_settings, dict) and "tp_computed_params" in known_settings:
            computed_at = known_settings["tp_computed_params"]["computed_at"]
            html.write_text(
                _("Timespecific parameters computed at %s") %
                cmk.utils.render.date_and_time(computed_at))
            html.br()
            known_settings = known_settings["tp_computed_params"]["params"]

        # In some cases we know the settings from a check_mk automation
        if known_settings is self._PARAMETERS_OMIT:
            return

        # Special handling for logwatch: The check parameter is always None. The actual
        # patterns are configured in logwatch_rules. We do not have access to the actual
        # patterns here but just to the useless "None". In order not to complicate things
        # we simply display nothing here.
        if varname == "logwatch_rules":
            pass

        elif known_settings is not self._PARAMETERS_UNKNOWN:
            try:
                html.write(valuespec.value_to_text(known_settings))
            except Exception as e:
                if config.debug:
                    raise
                html.write_text(_("Invalid parameter %r: %s") % (known_settings, e))

        else:
            # For match type "dict" the rule may define only some of the keys
            # while the other keys are taken from the factory defaults. We need to
            # show the complete resulting value here.
            if rules and ruleset.match_type() == "dict":
                if rulespec.factory_default is not watolib.Rulespec.NO_FACTORY_DEFAULT \
                    and rulespec.factory_default is not watolib.Rulespec.FACTORY_DEFAULT_UNUSED:
                    fd = rulespec.factory_default.copy()
                    fd.update(setting)
                    setting = fd

            if valuespec and not rules:  # show the default value
                if rulespec.factory_default is watolib.Rulespec.FACTORY_DEFAULT_UNUSED:
                    # Some rulesets are ineffective if they are empty
                    html.write_text(_("(unused)"))

                elif rulespec.factory_default is not watolib.Rulespec.NO_FACTORY_DEFAULT:
                    # If there is a factory default then show that one
                    setting = rulespec.factory_default
                    html.write(valuespec.value_to_text(setting))

                elif ruleset.match_type() in ("all", "list"):
                    # Rulesets that build lists are empty if no rule matches
                    html.write_text(_("(no entry)"))

                else:
                    # Else we use the default value of the valuespec
                    html.write(valuespec.value_to_text(valuespec.default_value()))

            # We have a setting
            elif valuespec:
                if ruleset.match_type() == "all":
                    html.write(", ".join(valuespec.value_to_text(s) for s in setting))
                else:
                    html.write(valuespec.value_to_text(setting))

            # Binary rule, no valuespec, outcome is True or False
            else:
                icon_name = "rule_%s%s" % ("yes" if setting else "no", "_off" if not rules else '')
                html.icon(title=_("yes") if setting else _("no"), icon=icon_name)
        html.close_td()
        html.close_tr()
        html.close_table()
Example #19
 def test_ensure_binary_raise_type_error(self):
     with py.test.raises(TypeError):
         six.ensure_str(8)
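
A minimal sketch of the behaviour the test above relies on: non-string input raises `TypeError`:

import six

try:
    six.ensure_str(8)
except TypeError as exc:
    print(exc)  # e.g. "not expecting type '<class 'int'>'"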
Example #20
def test_get_str_input_mandatory_non_ascii():
    assert html.request.get_str_input_mandatory("abc") == six.ensure_str(
        u"äbc")
  def RunGsUtil(self,
                cmd,
                return_status=False,
                return_stdout=False,
                return_stderr=False,
                expected_status=0,
                stdin=None,
                env_vars=None):
    """Runs the gsutil command.

    Args:
      cmd: The command to run, as a list, e.g. ['cp', 'foo', 'bar']
      return_status: If True, the exit status code is returned.
      return_stdout: If True, the standard output of the command is returned.
      return_stderr: If True, the standard error of the command is returned.
      expected_status: The expected return code. If not specified, defaults to
                       0. If the return code is a different value, an exception
                       is raised.
      stdin: A string of data to pipe to the process as standard input.
      env_vars: A dictionary of variables to extend the subprocess's os.environ
                with.

    Returns:
      If multiple return_* values were specified, this method returns a tuple
      containing the desired return values specified by the return_* arguments
      (in the order those parameters are specified in the method definition).
      If only one return_* value was specified, that value is returned directly
      rather than being returned within a 1-tuple.
    """
    cmd = [
        gslib.GSUTIL_PATH, '--testexceptiontraces', '-o',
        'GSUtil:default_project_id=' + PopulateProjectId()
    ] + cmd
    if stdin is not None:
      if six.PY3:
        if not isinstance(stdin, bytes):
          stdin = stdin.encode(UTF8)
      else:
        stdin = stdin.encode(UTF8)
    # checking to see if test was invoked from a par file (bundled archive)
    # if not, add python executable path to ensure correct version of python
    # is used for testing
    cmd = [str(sys.executable)] + cmd if not InvokedFromParFile() else cmd
    env = os.environ.copy()
    if env_vars:
      env.update(env_vars)
    # Ensuring correct text types
    envstr = dict()
    for k, v in six.iteritems(env):
      envstr[six.ensure_str(k)] = six.ensure_str(v)
    cmd = [six.ensure_str(part) for part in cmd]
    # executing command
    p = subprocess.Popen(cmd,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         stdin=subprocess.PIPE,
                         env=envstr)
    c_out = p.communicate(stdin)
    try:
      c_out = [six.ensure_text(output) for output in c_out]
    except UnicodeDecodeError:
      c_out = [
          six.ensure_text(output, locale.getpreferredencoding(False))
          for output in c_out
      ]
    stdout = c_out[0].replace(os.linesep, '\n')
    stderr = c_out[1].replace(os.linesep, '\n')
    status = p.returncode

    if expected_status is not None:
      cmd = map(six.ensure_text, cmd)
      self.assertEqual(
          int(status),
          int(expected_status),
          msg='Expected status {}, got {}.\nCommand:\n{}\n\nstderr:\n{}'.format(
              expected_status, status, ' '.join(cmd), stderr))

    toreturn = []
    if return_status:
      toreturn.append(status)
    if return_stdout:
      toreturn.append(stdout)
    if return_stderr:
      toreturn.append(stderr)

    if len(toreturn) == 1:
      return toreturn[0]
    elif toreturn:
      return tuple(toreturn)
Example #22
 def get_rp_uuid_from_obj(self, obj):
     return str(uuid.uuid3(uuid.NAMESPACE_DNS, six.ensure_str(obj.name)))
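
A standalone sketch: `uuid.uuid3` expects a native `str` name, so `ensure_str` normalises whatever `obj.name` holds (the name below is made up):

import uuid
import six

name = six.ensure_str(u"compute-node-01")
print(str(uuid.uuid3(uuid.NAMESPACE_DNS, name)))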
Example #23
def git_repository_root():
    return six.ensure_str(
        subprocess.check_output(['git', 'rev-parse',
                                 '--show-toplevel']).strip())
Example #24
def check_logwatch_generic(
    *,
    item: str,
    patterns,
    loglines,
    found: bool,
    max_filesize: int,
) -> CheckResult:
    logmsg_dir = pathlib.Path(cmk.utils.paths.var_dir, "logwatch", host_name())

    logmsg_dir.mkdir(parents=True, exist_ok=True)

    logmsg_file_path = logmsg_dir / item.replace("/", "\\")

    # Logfile (=item) section not found and no local file found. This usually
    # means that the corresponding logfile also vanished on the target host.
    if not found and not logmsg_file_path.exists():
        yield Result(state=state.UNKNOWN, summary="log not present anymore")
        return

    block_collector = LogwatchBlockCollector()

    logmsg_file_exists = logmsg_file_path.exists()
    logmsg_file_handle = logmsg_file_path.open(
        "r+" if logmsg_file_exists else "w", encoding="utf-8")

    # TODO: repr() of a dict may change.
    pattern_hash = hashlib.sha256(repr(patterns).encode()).hexdigest()
    if not logmsg_file_exists:
        output_size = 0
        reclassify = True
    else:  # parse cached log lines
        reclassify = _patterns_changed(logmsg_file_handle, pattern_hash)

        if not reclassify and _truncate_way_too_large_result(
                logmsg_file_path, max_filesize):
            yield _dropped_msg_result(max_filesize)
            return

        block_collector.extend(
            _extract_blocks(logmsg_file_handle, patterns, reclassify))

        if reclassify:
            output_size = block_collector.size
        else:
            output_size = logmsg_file_handle.tell()
            # when skipping reclassification, output lines contain only headers anyway
            block_collector.clear_lines()

    header = time.strftime("<<<%Y-%m-%d %H:%M:%S UNKNOWN>>>\n")
    output_size += len(header)
    header = six.ensure_str(header)

    # process new input lines - but only when there is some room left in the file
    block_collector.extend(
        _extract_blocks([header] + loglines,
                        patterns,
                        False,
                        limit=max_filesize - output_size))

    # when reclassifying, rewrite the whole file, otherwise append
    if reclassify and block_collector.get_lines():
        logmsg_file_handle.seek(0)
        logmsg_file_handle.truncate()
        logmsg_file_handle.write("[[[%s]]]\n" % pattern_hash)

    for line in block_collector.get_lines():
        logmsg_file_handle.write(line)
    # correct output size
    logmsg_file_handle.close()

    if not block_collector.saw_lines:
        logmsg_file_path.unlink(missing_ok=True)

    # if logfile has reached maximum size, abort with critical state
    if logmsg_file_path.exists(
    ) and logmsg_file_path.stat().st_size > max_filesize:
        yield _dropped_msg_result(max_filesize)
        return

    #
    # Render output
    #

    if block_collector.worst <= 0:
        yield Result(state=state.OK, summary="No error messages")
        return

    info = block_collector.get_count_info()
    if LOGWATCH_SERVICE_OUTPUT == "default":
        info += ' (Last worst: "%s")' % block_collector.last_worst_line

    summary, details = info, None
    if "\n" in info.strip():
        summary, details = info.split("\n", 1)

    yield Result(
        state=state(block_collector.worst),
        summary=summary,
        details=details,
    )
Example #25
    def imdb_person_list(self, url):
        try:
            result = client.request(url)
            items = client.parseDOM(result, 'div', attrs={'class': '.+?etail'})
        except:
            return

        try:
            result = result.replace(r'"class=".*?ister-page-nex',
                                    '" class="lister-page-nex')
            next = client.parseDOM(result,
                                   'a',
                                   ret='href',
                                   attrs={'class': r'.*?ister-page-nex.*?'})

            if len(next) == 0:
                next = client.parseDOM(result,
                                       'div',
                                       attrs={'class': u'pagination'})[0]
                next = zip(client.parseDOM(next, 'a', ret='href'),
                           client.parseDOM(next, 'a'))
                next = [i[0] for i in next if 'Next' in i[1]]

            next = url.replace(
                urllib_parse.urlparse(url).query,
                urllib_parse.urlparse(next[0]).query)
            next = client.replaceHTMLCodes(next)
            next = six.ensure_str(next, errors='ignore')
        except:
            next = ''

        for item in items:
            try:
                name = client.parseDOM(item, 'img', ret='alt')[0]
                name = six.ensure_str(name, errors='ignore')

                id = client.parseDOM(item, 'a', ret='href')[0]
                id = re.findall(r'(nm\d*)', id, re.I)[0]
                id = client.replaceHTMLCodes(id)
                id = six.ensure_str(id, errors='replace')

                try:
                    image = client.parseDOM(item, 'img', ret='src')[0]
                    image = re.sub(
                        r'(?:_SX|_SY|_UX|_UY|_CR|_AL)(?:\d+|_).+?\.',
                        '_SX500.', image)
                    image = client.replaceHTMLCodes(image)
                    image = six.ensure_str(image, errors='replace')
                    if '/sash/' in image or '/nopicture/' in image:
                        raise Exception()
                except:
                    image = 'person.png'

                try:
                    info = client.parseDOM(item, 'p')
                    info = '[I]%s[/I][CR]%s' % (info[0].split('<')[0].strip(),
                                                info[1])
                    info = client.replaceHTMLCodes(info)
                    info = six.ensure_str(info, errors='ignore')
                    info = re.sub(r'<.*?>', '', info)
                except:
                    info = ''

                self.list.append({
                    'name': name,
                    'id': id,
                    'image': image,
                    'plot': info,
                    'next': next
                })
            except:
                pass

        return self.list
Example #26
    :param validate_func: SwaggerFormat.validate function
    :return: wrapped callable
    """
    @functools.wraps(validate_func)
    def wrapper(validatable_primitive):
        # type: (typing.Callable[[typing.Any], typing.Any]) -> bool
        validate_func(validatable_primitive)
        return True

    return wrapper


BASE64_BYTE_FORMAT = SwaggerFormat(
    format='byte',
    # Note: In Python 3, this requires a bytes-like object as input
    to_wire=lambda b: six.ensure_str(base64.b64encode(b),
                                     encoding=str('ascii')),
    to_python=lambda s: base64.b64decode(
        six.ensure_binary(s, encoding=str('ascii'))),
    validate=NO_OP,  # jsonschema validates string
    description='Converts [wire]string:byte <=> python bytes',
)

DEFAULT_FORMATS = {
    'byte':
    SwaggerFormat(
        format='byte',
        to_wire=lambda b: b if isinstance(b, str) else str(b),
        to_python=lambda s: s if isinstance(s, str) else str(s),
        validate=NO_OP,  # jsonschema validates string
        description='Converts [wire]string:byte <=> python byte',
    ),
Example #27
def test_compute_log_manager(mock_create_blob_client, mock_generate_blob_sas,
                             storage_account, container, credential):
    mock_generate_blob_sas.return_value = "fake-url"
    fake_client = FakeBlobServiceClient(storage_account)
    mock_create_blob_client.return_value = fake_client

    @pipeline
    def simple():
        @solid
        def easy(context):
            context.log.info("easy")
            print(HELLO_WORLD)
            return "easy"

        easy()

    with seven.TemporaryDirectory() as temp_dir:
        run_store = SqliteRunStorage.from_local(temp_dir)
        event_store = SqliteEventLogStorage(temp_dir)
        manager = AzureBlobComputeLogManager(
            storage_account=storage_account,
            container=container,
            prefix="my_prefix",
            local_dir=temp_dir,
            secret_key=credential,
        )
        instance = DagsterInstance(
            instance_type=InstanceType.PERSISTENT,
            local_artifact_storage=LocalArtifactStorage(temp_dir),
            run_storage=run_store,
            event_storage=event_store,
            compute_log_manager=manager,
            run_launcher=SyncInMemoryRunLauncher(),
        )
        result = execute_pipeline(simple, instance=instance)
        compute_steps = [
            event.step_key for event in result.step_event_list
            if event.event_type == DagsterEventType.STEP_START
        ]
        assert len(compute_steps) == 1
        step_key = compute_steps[0]

        stdout = manager.read_logs_file(result.run_id, step_key,
                                        ComputeIOType.STDOUT)
        assert stdout.data == HELLO_WORLD + SEPARATOR

        stderr = manager.read_logs_file(result.run_id, step_key,
                                        ComputeIOType.STDERR)
        for expected in EXPECTED_LOGS:
            assert expected in stderr.data

        # Check ADLS2 directly
        adls2_object = fake_client.get_blob_client(
            container=container,
            blob="{prefix}/storage/{run_id}/compute_logs/easy.compute.err".
            format(prefix="my_prefix", run_id=result.run_id),
        )
        adls2_stderr = six.ensure_str(adls2_object.download_blob().readall())
        for expected in EXPECTED_LOGS:
            assert expected in adls2_stderr

        # Check download behavior by deleting locally cached logs
        compute_logs_dir = os.path.join(temp_dir, result.run_id,
                                        "compute_logs")
        for filename in os.listdir(compute_logs_dir):
            os.unlink(os.path.join(compute_logs_dir, filename))

        stdout = manager.read_logs_file(result.run_id, step_key,
                                        ComputeIOType.STDOUT)
        assert stdout.data == HELLO_WORLD + SEPARATOR

        stderr = manager.read_logs_file(result.run_id, step_key,
                                        ComputeIOType.STDERR)
        for expected in EXPECTED_LOGS:
            assert expected in stderr.data
Example #28
def _get_user_edit_page(app):
    user = factories.User()
    env = {"REMOTE_USER": six.ensure_str(user["name"])}
    response = app.get(url=url_for("user.edit"), extra_environ=env)
    return env, response, user
Example #29
 def get_user(self, handler):
     user_cookie = handler.get_secure_cookie("user")
     if user_cookie:
         #return json.loads(user_cookie)
         return ensure_str(user_cookie)
     return None
Example #30
def check_mk_local_automation_serialized(
    *,
    command: str,
    args: Optional[Sequence[str]] = None,
    indata: Any = "",
    stdin_data: Optional[str] = None,
    timeout: Optional[int] = None,
) -> Tuple[Sequence[str], SerializedResult]:
    if args is None:
        args = []
    new_args = [ensure_str(a) for a in args]

    if stdin_data is None:
        stdin_data = repr(indata)

    if timeout:
        new_args = ["--timeout", "%d" % timeout] + new_args

    cmd = ["check_mk"]

    if auto_logger.isEnabledFor(logging.DEBUG):
        cmd.append("-vv")
    elif auto_logger.isEnabledFor(VERBOSE):
        cmd.append("-v")

    cmd += ["--automation", command] + new_args

    if command in ["restart", "reload"]:
        call_hook_pre_activate_changes()

    cmd = [ensure_str(a) for a in cmd]
    try:
        # This debug output causes problems when doing bulk inventory, because
        # it garbles the non-HTML response output
        # if config.debug:
        #     html.write_text("<div class=message>Running <tt>%s</tt></div>\n" % subprocess.list2cmdline(cmd))
        auto_logger.info("RUN: %s" % subprocess.list2cmdline(cmd))
        p = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            close_fds=True,
            encoding="utf-8",
        )
    except Exception as e:
        raise local_automation_failure(command=command, cmdline=cmd, exc=e)

    assert p.stdin is not None
    assert p.stdout is not None
    assert p.stderr is not None

    auto_logger.info("STDIN: %r" % stdin_data)
    p.stdin.write(stdin_data)
    p.stdin.close()

    outdata = p.stdout.read()
    exitcode = p.wait()
    auto_logger.info("FINISHED: %d" % exitcode)
    auto_logger.debug("OUTPUT: %r" % outdata)
    errdata = p.stderr.read()
    if errdata:
        auto_logger.warning("'%s' returned '%s'" % (" ".join(cmd), errdata))
    if exitcode != 0:
        auto_logger.error("Error running %r (exit code %d)" %
                          (subprocess.list2cmdline(cmd), exitcode))
        raise local_automation_failure(command=command,
                                       cmdline=cmd,
                                       code=exitcode,
                                       out=outdata,
                                       err=errdata)

    # On successful "restart" command execute the activate changes hook
    if command in ["restart", "reload"]:
        call_hook_activate_changes()

    return cmd, SerializedResult(outdata)
Example #31
 def _get_message_unparsed(flag, orig_flag, new_flag):
   if six.ensure_str(flag).startswith(orig_flag):
     return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
   return ""
Example #32
 def crash_dir(self, ident_text: Optional[str] = None) -> Path:
     """Returns the path to the crash directory of the current or given crash report"""
     if ident_text is None:
         ident_text = self.ident_to_text()
     return cmk.utils.paths.crash_dir / ensure_str(self.type()) / ensure_str(ident_text)
Example #33
def _parse_array(values, type_fn=str):
  if values is not None:
    return [type_fn(val) for val in six.ensure_str(values).split(",") if val]
  return None
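
A usage note for the helper above (assuming `_parse_array` and its module imports are in scope): it accepts bytes or text and returns a typed list, or None when the flag was not set.

print(_parse_array(b"1,2,3", type_fn=int))     # [1, 2, 3]
print(_parse_array("0.5,1.0", type_fn=float))  # [0.5, 1.0]
print(_parse_array(None))                      # None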
Example #34
def api_str_type(s):
    if not is_gui_py3():
        return six.ensure_binary(s)
    return six.ensure_str(s)
Example #35
def _parse_set(values):
  if values is not None:
    return set([item for item in six.ensure_str(values).split(",") if item])
  return None
Example #36
 def get(name, default=None):
     name = six.ensure_binary(name)
     ret = req.args.get(name)
     return six.ensure_str(ret[0]) if ret else default
Example #37
def format_local_vars(local_vars):
    return ensure_str(base64.b64decode(local_vars))
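
A round-trip sketch of the decode step; the encoded blob below is fabricated for illustration:

import base64
from six import ensure_str

encoded = base64.b64encode(b"{'local_var': 1}")
print(ensure_str(base64.b64decode(encoded)))  # {'local_var': 1}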
Example #38
    def render(self, req):
        for key, value in six.iteritems(req.args):
            key = six.ensure_str(key)
            if value:
                value = value[0]
                value = six.ensure_str(value)
            if key == "autopoll":
                config.plugins.autotimer.autopoll.value = True if value == "true" else False
            elif key == "interval":
                config.plugins.autotimer.interval.value = int(value)
            elif key == "refresh":
                config.plugins.autotimer.refresh.value = value
            elif key == "try_guessing":
                config.plugins.autotimer.try_guessing.value = True if value == "true" else False
            elif key == "editor":
                config.plugins.autotimer.editor.value = value
            elif key == "disabled_on_conflict":
                config.plugins.autotimer.disabled_on_conflict.value = True if value == "true" else False
            elif key == "addsimilar_on_conflict":
                config.plugins.autotimer.addsimilar_on_conflict.value = True if value == "true" else False
            elif key == "show_in_extensionsmenu":
                config.plugins.autotimer.show_in_extensionsmenu.value = True if value == "true" else False
            elif key == "show_in_furtheroptionsmenu":
                config.plugins.autotimer.show_in_furtheroptionsmenu.value = True if value == "true" else False
            elif key == "fastscan":
                config.plugins.autotimer.fastscan.value = True if value == "true" else False
            elif key == "notifconflict":
                config.plugins.autotimer.notifconflict.value = True if value == "true" else False
            elif key == "notifsimilar":
                config.plugins.autotimer.notifsimilar.value = True if value == "true" else False
            elif key == "notiftimers":
                config.plugins.autotimer.notiftimers.value = True if value == "true" else False
            elif key == "maxdaysinfuture":
                config.plugins.autotimer.maxdaysinfuture.value = int(value)
            elif key == "add_autotimer_to_tags":
                config.plugins.autotimer.add_autotimer_to_tags.value = True if value == "true" else False
            elif key == "add_name_to_tags":
                config.plugins.autotimer.add_name_to_tags.value = True if value == "true" else False
            elif key == "timeout":
                config.plugins.autotimer.timeout.value = int(value)
            elif key == "delay":
                config.plugins.autotimer.delay.value = int(value)
            elif key == "editdelay":
                config.plugins.autotimer.editdelay.value = int(value)
            elif key == "skip_during_records":
                config.plugins.autotimer.skip_during_records.value = True if value == "true" else False
            elif key == "skip_during_epgrefresh":
                config.plugins.autotimer.skip_during_epgrefresh.value = True if value == "true" else False
            elif key == "check_eit_and_removeh":
                config.plugins.autotimer.check_eit_and_remove.value = True if value == "true" else False
            elif key == "onlyinstandby":
                config.plugins.autotimer.onlyinstandby.value = True if value == "true" else False
            elif key == "add_to_channelselection":
                config.plugins.autotimer.add_to_channelselection.value = True if value == "true" else False
            elif key == "add_to_epgselection":
                config.plugins.autotimer.add_to_epgselection.value = True if value == "true" else False
            elif key == "add_to_multiepgselection":
                config.plugins.autotimer.add_to_multiepgselection.value = True if value == "true" else False
            elif key == "log_write":
                config.plugins.autotimer.log_write.value = True if value == "true" else False
            elif key == "log_shell":
                config.plugins.autotimer.log_shell.value = True if value == "true" else False

        if config.plugins.autotimer.autopoll.value:
            if plugin.autopoller is None:
                from .AutoPoller import AutoPoller
                plugin.autopoller = AutoPoller()
            plugin.autopoller.start(initial=False)
        else:
            if plugin.autopoller is not None:
                plugin.autopoller.stop()
                plugin.autopoller = None

        return self.returnResult(req, True, _("config changed."))
Example #39
def resolve(regex):
    try:
        vanilla = re.compile('(<regex>.+)',
                             re.MULTILINE | re.DOTALL).findall(regex)[0]
        cddata = re.compile(r'<!\[CDATA\[(.+?)\]\]>',
                            re.MULTILINE | re.DOTALL).findall(regex)
        for i in cddata:
            regex = regex.replace('<![CDATA[' + i + ']]>',
                                  urllib_parse.quote_plus(i))

        regexs = re.compile('(<regex>.+)',
                            re.MULTILINE | re.DOTALL).findall(regex)[0]
        regexs = re.compile('<regex>(.+?)</regex>',
                            re.MULTILINE | re.DOTALL).findall(regexs)
        regexs = [
            re.compile('<(.+?)>(.*?)</.+?>',
                       re.MULTILINE | re.DOTALL).findall(i) for i in regexs
        ]

        regexs = [
            dict([(client.replaceHTMLCodes(x[0]),
                   client.replaceHTMLCodes(urllib_parse.unquote_plus(x[1])))
                  for x in i]) for i in regexs
        ]
        regexs = [(i['name'], i) for i in regexs]
        regexs = dict(regexs)

        url = regex.split('<regex>', 1)[0].strip()
        url = client.replaceHTMLCodes(url)
        url = six.ensure_str(url)

        r = getRegexParsed(regexs, url)

        try:
            ln = ''
            ret = r[1]
            listrepeat = r[2]['listrepeat']
            regexname = r[2]['name']

            for obj in ret:
                try:
                    item = listrepeat
                    for i in list(range(len(obj) + 1)):
                        item = item.replace(
                            '[%s.param%s]' % (regexname, str(i)), obj[i - 1])

                    item2 = vanilla
                    for i in list(range(len(obj) + 1)):
                        item2 = item2.replace(
                            '[%s.param%s]' % (regexname, str(i)), obj[i - 1])

                    item2 = re.compile('(<regex>.+?</regex>)',
                                       re.MULTILINE | re.DOTALL).findall(item2)
                    item2 = [
                        x for x in item2
                        if not '<name>%s</name>' % regexname in x
                    ]
                    item2 = ''.join(item2)

                    ln += '\n<item>%s\n%s</item>\n' % (item, item2)
                except:
                    pass

            return ln
        except:
            pass

        if r[1] == True:
            return r[0]
    except:
        return
Example #40
  def CreateTempFile(self,
                     tmpdir=None,
                     contents=None,
                     file_name=None,
                     mtime=None,
                     mode=NA_MODE,
                     uid=NA_ID,
                     gid=NA_ID):
    """Creates a temporary file on disk.

    Note: if mode, uid, or gid are present, they must be validated by
    ValidateFilePermissionAccess and ValidatePOSIXMode before calling this
    function.

    Args:
      tmpdir: The temporary directory to place the file in. If not specified, a
              new temporary directory is created.
      contents: The contents to write to the file. If not specified, a test
                string is constructed and written to the file. Since the file
                is opened 'wb', the contents must be bytes.
      file_name: The name to use for the file. If not specified, a temporary
                 test file name is constructed. This can also be a tuple, where
                 ('dir', 'foo') means to create a file named 'foo' inside a
                 subdirectory named 'dir'.
      mtime: The modification time of the file in POSIX time (seconds since
             UTC 1970-01-01). If not specified, this defaults to the current
             system time.
      mode: The POSIX mode for the file. Must be a base-8 3-digit integer
            represented as a string.
      uid: A POSIX user ID.
      gid: A POSIX group ID.

    Returns:
      The path to the new temporary file.
    """

    tmpdir = six.ensure_str(tmpdir or self.CreateTempDir())
    file_name = file_name or self.MakeTempName(str('file'))
    if isinstance(file_name, (six.text_type, six.binary_type)):
      fpath = os.path.join(tmpdir, six.ensure_str(file_name))
    else:
      file_name = map(six.ensure_str, file_name)
      fpath = os.path.join(tmpdir, *file_name)
    if not os.path.isdir(os.path.dirname(fpath)):
      os.makedirs(os.path.dirname(fpath))
    if isinstance(fpath, six.binary_type):
      fpath = fpath.decode(UTF8)

    with open(fpath, 'wb') as f:
      contents = (contents if contents is not None else self.MakeTempName(
          str('contents')))
      if isinstance(contents, bytearray):
        contents = bytes(contents)
      else:
        contents = six.ensure_binary(contents)
      f.write(contents)
    if mtime is not None:
      # Set the atime and mtime to be the same.
      os.utime(fpath, (mtime, mtime))
    if uid != NA_ID or int(gid) != NA_ID:
      os.chown(fpath, uid, int(gid))
    if int(mode) != NA_MODE:
      os.chmod(fpath, int(mode, 8))
    return fpath
Example #41
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "lcqmc_pair": LCQMCPairClassificationProcessor
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict' must be True."
        )

    albert_config = modeling.AlbertConfig.from_json_file(
        FLAGS.albert_config_file)

    if FLAGS.max_seq_length > albert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the ALBERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, albert_config.max_position_embeddings))

    tf.gfile.MakeDirs(FLAGS.output_dir)

    task_name = FLAGS.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()

    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case,
                                           spm_model_file=FLAGS.spm_model_file)

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            num_shards=FLAGS.num_tpu_cores,
            per_host_input_for_training=is_per_host))

    train_examples = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        num_train_steps = int(
            len(train_examples) / FLAGS.train_batch_size *
            FLAGS.num_train_epochs)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    model_fn = model_fn_builder(albert_config=albert_config,
                                num_labels=len(label_list),
                                init_checkpoint=FLAGS.init_checkpoint,
                                learning_rate=FLAGS.learning_rate,
                                num_train_steps=num_train_steps,
                                num_warmup_steps=num_warmup_steps,
                                use_tpu=FLAGS.use_tpu,
                                use_one_hot_embeddings=FLAGS.use_tpu)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.output_dir, "train.tf_record")
        file_based_convert_examples_to_features(train_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        eval_examples = processor.get_dev_examples(FLAGS.data_dir)
        num_actual_eval_examples = len(eval_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on. These do NOT count towards the metric (all tf.metrics
            # support a per-instance weight, and these get a weight of 0.0).
            while len(eval_examples) % FLAGS.eval_batch_size != 0:
                eval_examples.append(PaddingInputExample())

        eval_file = os.path.join(FLAGS.output_dir, "eval.tf_record")
        file_based_convert_examples_to_features(eval_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, eval_file)

        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(eval_examples), num_actual_eval_examples,
                        len(eval_examples) - num_actual_eval_examples)
        tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            assert len(eval_examples) % FLAGS.eval_batch_size == 0
            eval_steps = int(len(eval_examples) // FLAGS.eval_batch_size)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=eval_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        #######################################################################################################################
        # evaluate all checkpoints; you can use the checkpoint with the best dev accuracy
        steps_and_files = []
        filenames = tf.gfile.ListDirectory(FLAGS.output_dir)
        for filename in filenames:
            if filename.endswith(".index"):
                ckpt_name = filename[:-6]
                cur_filename = os.path.join(FLAGS.output_dir, ckpt_name)
                global_step = int(cur_filename.split("-")[-1])
                tf.logging.info("Add {} to eval list.".format(cur_filename))
                steps_and_files.append([global_step, cur_filename])
        steps_and_files = sorted(steps_and_files, key=lambda x: x[0])

        output_eval_file = os.path.join(FLAGS.data_dir,
                                        "eval_results_albert_zh.txt")
        print("output_eval_file:", output_eval_file)
        tf.logging.info("output_eval_file:" + output_eval_file)
        with tf.gfile.GFile(output_eval_file, "w") as writer:
            for global_step, filename in sorted(steps_and_files,
                                                key=lambda x: x[0]):
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=eval_steps,
                                            checkpoint_path=filename)

                tf.logging.info("***** Eval results %s *****" % (filename))
                writer.write("***** Eval results %s *****\n" % (filename))
                for key in sorted(result.keys()):
                    tf.logging.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
        #######################################################################################################################
        # result = estimator.evaluate(input_fn=eval_input_fn, steps=eval_steps)
        # output_eval_file = os.path.join(FLAGS.output_dir, "eval_results.txt")
        # with tf.gfile.GFile(output_eval_file, "w") as writer:
        #  tf.logging.info("***** Eval results *****")
        #  for key in sorted(result.keys()):
        #    tf.logging.info("  %s = %s", key, str(result[key]))
        #    writer.write("%s = %s\n" % (key, str(result[key])))

    if FLAGS.do_predict:
        predict_examples = processor.get_test_examples(FLAGS.data_dir)
        num_actual_predict_examples = len(predict_examples)
        if FLAGS.use_tpu:
            # TPU requires a fixed batch size for all batches, therefore the number
            # of examples must be a multiple of the batch size, or else examples
            # will get dropped. So we pad with fake examples which are ignored
            # later on.
            while len(predict_examples) % FLAGS.predict_batch_size != 0:
                predict_examples.append(PaddingInputExample())

        predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
        file_based_convert_examples_to_features(predict_examples, label_list,
                                                FLAGS.max_seq_length,
                                                tokenizer, predict_file)

        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d (%d actual, %d padding)",
                        len(predict_examples), num_actual_predict_examples,
                        len(predict_examples) - num_actual_predict_examples)
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_drop_remainder = True if FLAGS.use_tpu else False
        predict_input_fn = file_based_input_fn_builder(
            input_file=predict_file,
            seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=predict_drop_remainder)

        result = estimator.predict(input_fn=predict_input_fn)

        output_predict_file = os.path.join(FLAGS.output_dir,
                                           "test_results.tsv")
        output_submit_file = os.path.join(FLAGS.output_dir,
                                          "submit_results.tsv")
        with tf.gfile.GFile(output_predict_file, "w") as pred_writer,\
            tf.gfile.GFile(output_submit_file, "w") as sub_writer:
            num_written_lines = 0
            tf.logging.info("***** Predict results *****")
            for (i, (example, prediction)) in\
                enumerate(zip(predict_examples, result)):
                probabilities = prediction["probabilities"]
                if i >= num_actual_predict_examples:
                    break
                output_line = "\t".join(
                    str(class_probability)
                    for class_probability in probabilities) + "\n"
                pred_writer.write(output_line)

                actual_label = label_list[int(prediction["predictions"])]
                sub_writer.write(
                    six.ensure_str(example.guid) + "\t" + actual_label + "\n")
                num_written_lines += 1
        assert num_written_lines == num_actual_predict_examples
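
The excerpt above ends with main() itself; in the standard BERT/ALBERT run_classifier scripts this function is launched through tf.app.run(). A sketch of that entry point, under the assumption that the module defines flags = tf.flags and FLAGS = flags.FLAGS in the usual way (the exact set of required flags is an assumption, not taken from this excerpt):

if __name__ == "__main__":
    # Fail fast if the flags that main() depends on were not supplied.
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("task_name")
    flags.mark_flag_as_required("vocab_file")
    flags.mark_flag_as_required("albert_config_file")
    flags.mark_flag_as_required("output_dir")
    tf.app.run()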
Пример #42
0
  def ReportMetrics(self,
                    wait_for_report=False,
                    log_level=None,
                    log_file_path=None):
    """Reports the collected metrics using a separate async process.

    Args:
      wait_for_report: bool, True if the main process should wait for the
        subprocess to exit for testing purposes.
      log_level: int, The subprocess logger's level of debugging for testing
        purposes.
      log_file_path: str, The file that the metrics_reporter module should
        write its logs to. If not supplied, the metrics_reporter module will
        use a predetermined default path. This parameter is intended for use
        by tests that need to evaluate the contents of the file at this path.
    """
    self._CollectCommandAndErrorMetrics()
    self._CollectPerformanceSummaryMetric()
    if not self._metrics:
      return

    if not log_level:
      log_level = self.logger.getEffectiveLevel()
    # If this is a testing subprocess, we don't want to write to the log file.
    if os.environ.get('GSUTIL_TEST_ANALYTICS') == '2':
      log_level = logging.WARN

    temp_metrics_file = tempfile.NamedTemporaryFile(delete=False)
    temp_metrics_file_name = six.ensure_str(temp_metrics_file.name)
    with temp_metrics_file:
      pickle.dump(self._metrics, temp_metrics_file)
    logging.debug(self._metrics)
    self._metrics = []

    if log_file_path is not None:
      # If the path is not None, we'll need to surround the path with quotes
      # so that the path is passed as a string to the metrics_reporter module.
      log_file_path = six.ensure_str('r"%s"' % log_file_path)

    reporting_code = six.ensure_str(
        'from gslib.metrics_reporter import ReportMetrics; '
        'ReportMetrics(r"{0}", {1}, log_file_path={2})'.format(
            temp_metrics_file_name, log_level, log_file_path))
    execution_args = [sys.executable, '-c', reporting_code]
    exec_env = os.environ.copy()
    exec_env['PYTHONPATH'] = os.pathsep.join(sys.path)
    # Ensuring submodule (sm) environment keys and values are all str.
    sm_env = dict()
    for k, v in six.iteritems(exec_env):
      sm_env[six.ensure_str(k)] = six.ensure_str(v)
    try:
      # In order for Popen to work correctly on Windows with Python 3, shell
      # needs to be True.
      p = subprocess.Popen(execution_args,
                           env=sm_env,
                           shell=(six.PY3 and system_util.IS_WINDOWS))
      self.logger.debug('Metrics reporting process started...')

      if wait_for_report:
        # NOTE: p.wait() can cause a deadlock. p.communicate() is recommended.
        # See python docs for more information.
        p.communicate()
        self.logger.debug('Metrics reporting process finished.')
    except OSError:
      # This can happen specifically if the Python executable moves between the
      # start of this process and now.
      self.logger.debug('Metrics reporting process failed to start.')
      # Delete the tempfile that would normally be cleaned up in the subprocess.
      try:
        os.unlink(temp_metrics_file.name)
      except:  # pylint: disable=bare-except
        pass
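
The same spawn-a-worker pattern, reduced to a standalone sketch outside gsutil (the function and payload names below are illustrative and not part of gslib): the parent pickles its payload to a named temp file, then hands the file path to a short-lived `python -c` child that loads and consumes it.

import os
import pickle
import subprocess
import sys
import tempfile


def report_async(payload):
  """Pickles payload to a temp file and processes it in a child process."""
  tmp = tempfile.NamedTemporaryFile(delete=False)
  with tmp:
    pickle.dump(payload, tmp)
  # The child unpickles the file, does its work (here it just prints the item
  # count), and removes the temp file when it is done.
  child_code = ('import os, pickle, sys; '
                'path = sys.argv[1]; '
                'data = pickle.load(open(path, "rb")); '
                'print("reported %d metrics" % len(data)); '
                'os.unlink(path)')
  env = os.environ.copy()
  env['PYTHONPATH'] = os.pathsep.join(sys.path)
  return subprocess.Popen([sys.executable, '-c', child_code, tmp.name], env=env)


# p = report_async([{'event': 'command', 'name': 'cp'}])
# p.communicate()  # Wait only when the caller needs the result, e.g. in tests.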
Пример #43
0
 def data_to_value(dat):
     if isinstance(dat, bytes):
         dat = ensure_str(dat)  # Python 2/3 compatibility: decode bytes to text.
     return dat
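
A quick usage sketch for this helper (it assumes the module imports ensure_str from six, which the excerpt relies on but does not show):

# data_to_value normalizes bytes to text and leaves other values untouched.
assert data_to_value(b'hello') == 'hello'   # bytes are decoded to str
assert data_to_value(u'hello') == u'hello'  # text passes through unchanged
assert data_to_value(42) == 42              # non-string data is returned as-is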