Example #1
def main(_):
  while True:
    try:
      # Client to communicate with the learner.
      client = grpc.Client(FLAGS.server_address)

      env = config.create_environment(FLAGS.task)

      # Unique ID to identify a specific run of an actor.
      run_id = np.random.randint(np.iinfo(np.int64).max)
      observation = env.reset()
      reward = 0.0
      raw_reward = 0.0
      done = False

      while True:
        env_output = utils.EnvOutput(reward, done, np.array(observation))
        action = client.inference((FLAGS.task, run_id, env_output, raw_reward))
        observation, reward, done, info = env.step(action.numpy())
        raw_reward = float(info.get('score_reward', reward))

        if done:
          observation = env.reset()
    except (tf.errors.UnavailableError, tf.errors.CancelledError) as e:
      logging.exception(e)
      env.close()
Example #2
 def __rshift__(self, other: int) -> 'Bits':
   shift_amount = min(other, self.bit_count)
   try:
     return Bits(bit_count=self.bit_count, value=self.value >> shift_amount)
   except OverflowError:
     logging.exception('bit_count %r other %r', self.bit_count, other)
     raise
Example #3
def convex_hull_roc(roc):
    """Returns an roc curve without the points inside the convex hull.

  Points below the fpr=tpr line corresponding to random performance are also
  removed.

  Args:
    roc: A tuple of lists that are all the same length, containing
      (false_positive_rates, true_positive_rates, thresholds). This is the same
      format returned by sklearn.metrics.roc_curve.
  """
    fprs, tprs, thresholds = roc
    if np.isnan(fprs).any() or np.isnan(tprs).any():
        logging.warning("Convex hull solver does not handle NaNs.")
        return roc
    if len(fprs) < 3:
        logging.warning("Convex hull solver does not curves with < 3 points.")
        return roc
    try:
        # Add (fpr=1, tpr=0) to the convex hull to remove any points below the
        # random-performance line.
        hull = scipy.spatial.ConvexHull(np.vstack([fprs + [1], tprs + [0]]).T)
    except scipy.spatial.qhull.QhullError:
        logging.exception("Convex hull solver failed.")
        return roc
    vertices = set(hull.vertices)

    return (
        [fpr for idx, fpr in enumerate(fprs) if idx in vertices],
        [tpr for idx, tpr in enumerate(tprs) if idx in vertices],
        [thresh for idx, thresh in enumerate(thresholds) if idx in vertices],
    )
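A minimal usage sketch for convex_hull_roc, assuming sklearn is available; the labels and scores below are illustrative, and the arrays returned by sklearn.metrics.roc_curve are converted to plain lists, as the docstring requires.

from sklearn import metrics

y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_score = [0.1, 0.4, 0.35, 0.8, 0.65, 0.7, 0.2, 0.5]
fprs, tprs, thresholds = metrics.roc_curve(y_true, y_score)
hull_fprs, hull_tprs, hull_thresholds = convex_hull_roc(
    (fprs.tolist(), tprs.tolist(), thresholds.tolist()))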
Example #4
    def _QueryLdap(self, base_dn, ldap_filter, scope=ldap.SCOPE_SUBTREE):
        """Yields LDAP results for a given filter in a given base DN.

    This method exists primarily to assist with paging through large result
    sets, and to centralize exception handling and retrying.

    Args:
      base_dn: str base distinguishedName to query within.
      ldap_filter: str LDAP filter to query with.
      scope: ldap.SCOPE_SUBTREE (default), SCOPE_ONELEVEL, or SCOPE_BASE.
    Yields:
      LDAP result dictionary.
    Raises:
      ldap.LDAPError: there was an error querying LDAP.
    """
        if not self.conn:
            self._ConnectToAd()

        page_control = controls.SimplePagedResultsControl(True,
                                                          size=self.page_size,
                                                          cookie='')

        failures = 0
        # Iterate over all hosts matching the given ldap_filter, in batches of
        # self.page_size, escrowing each resulting recovery object.
        while True:
            server_controls = [page_control]
            query_id = self.conn.search_ext(base_dn,
                                            scope,
                                            ldap_filter,
                                            serverctrls=server_controls)

            try:
                _, results, _, server_controls = self.conn.result3(query_id)
            except ldap.LDAPError:
                failures += 1
                if failures == MAX_QUERY_FAILURES:
                    raise
                logging.exception('LDAPError on result3() call.')
                time.sleep(5 * failures)
                self._ConnectToAd()
                continue
            else:
                failures = 0

            for result in results:
                yield result[1]

            # Update page_control with server provided control data, otherwise there
            # are no more results, so break.
            cookie = None
            for server_control in server_controls:
                if (server_control.controlType ==
                        controls.SimplePagedResultsControl.controlType):
                    cookie = server_control.cookie
                    if cookie:
                        page_control.cookie = cookie
                    break
            if not cookie:
                break
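A hedged sketch of how another method on the same class might consume the paged generator above; the caller name, base DN, filter, and per-result handling are illustrative only.

    def _LogAllHosts(self):
        # Hypothetical caller, shown only to illustrate iterating the results
        # that _QueryLdap yields one page at a time.
        ldap_filter = '(objectClass=computer)'
        for result in self._QueryLdap('DC=example,DC=com', ldap_filter):
            logging.info('Found host entry: %s', result.get('name'))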
Example #5
  def upload_file(self, source_file_path: str, bucket_name: str,
                  destination_file_path: str) -> None:
    """Uploads file from source file system to Cloud Storage.

    If the bucket doesn't exist in the Cloud Storage, it will be created.

    Args:
      source_file_path: Path to the file to be uploaded, e.g. /tmp/file.txt.
      bucket_name: Cloud Storage bucket to which the file should be uploaded. If
        the Cloud Storage URL is 'gs://bucket1/file1.txt', then the bucket_name
        would be 'bucket1'.
      destination_file_path: Path of the destination blob/object within the
        Cloud Storage bucket. If the Cloud Storage URL is
        'gs://bucket1/dir1/file1.txt', then the destination_file_path would be
        'dir1/file1.txt'.
    Raises:
      FileNotFoundError: If the provided file is not found.
      Error: If the upload was not successful.
    """
    if not os.path.isfile(source_file_path):
      logging.error('The file "%s" could not be found.', source_file_path)
      raise FileNotFoundError(
          f'The file "{source_file_path}" could not be found.')
    try:
      logging.info('Uploading "%s" file to "gs://%s/%s"', source_file_path,
                   bucket_name, destination_file_path)
      bucket = self._get_or_create_bucket(bucket_name)
      self._upload_file(source_file_path, bucket, destination_file_path)
      logging.info('Uploaded "%s" file to "gs://%s/%s"', source_file_path,
                   bucket_name, destination_file_path)
    except exceptions.RetryError:
      error_message = (f'Error when uploading file "{source_file_path}" to '
                       f'"gs://{bucket_name}/{destination_file_path}"')
      logging.exception(error_message)
      raise Error(error_message)
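A hedged usage sketch for the method above; the snippet does not show the class it belongs to, so the CloudStorageClient name below is purely illustrative.

# CloudStorageClient is a hypothetical name for the class defining upload_file.
storage_client = CloudStorageClient()
storage_client.upload_file(
    source_file_path='/tmp/file.txt',
    bucket_name='bucket1',
    destination_file_path='dir1/file1.txt')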
Example #6
def create_optimizer(learning_rate, params):
    """Creates optimized based on the specified flags."""
    if params['optimizer'] == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=params['momentum'])
    elif params['optimizer'] == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate)
    elif params['optimizer'] == 'adadelta':
        optimizer = tf.train.AdadeltaOptimizer(learning_rate)
    elif params['optimizer'] == 'adagrad':
        optimizer = tf.train.AdagradOptimizer(learning_rate)
    elif params['optimizer'] == 'rmsprop':
        optimizer = tf.train.RMSPropOptimizer(learning_rate,
                                              momentum=params['momentum'])
    elif params['optimizer'] == 'lars':
        try:
            from tensorflow.contrib.opt import LARSOptimizer  # pylint: disable=g-import-not-at-top

            optimizer = LARSOptimizer(
                learning_rate,
                momentum=params['momentum'],
                weight_decay=params['lars_weight_decay'],
                skip_list=['batch_normalization', 'bias'])
        except ImportError as e:
            logging.exception('LARSOptimizer is currently not supported '
                              'in TensorFlow 2.x.')
            raise e

    else:
        raise ValueError('Unsupported optimizer type %s.' %
                         params['optimizer'])
    return optimizer
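A short usage sketch for create_optimizer; the params keys below are the ones the function reads on the 'momentum' branch, and the concrete values are illustrative.

params = {'optimizer': 'momentum', 'momentum': 0.9}
optimizer = create_optimizer(learning_rate=0.1, params=params)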
Example #7
    def _ConnectToAd(self):
        """Establish a connection to the Active Directory server."""

        # To enable verbose debug logging, uncomment the following line.
        # ldap.set_option(ldap.OPT_DEBUG_LEVEL, 255)

        ldap.set_option(ldap.OPT_REFERRALS, 0)
        ldap.set_option(ldap.OPT_X_TLS_ALLOW, 1)

        ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND)
        logging.debug('Connecting to Active Directory: %s', self.ldap_url)

        failures = 0
        while True:
            try:
                self.conn = ldap.initialize(self.ldap_url)
                self.conn.protocol_version = ldap.VERSION3
                self.conn.simple_bind_s(self.auth_user, self.auth_password)
                break
            except ldap.LDAPError as e:
                failures += 1
                if failures == MAX_CONNECTION_FAILURES:
                    raise
                logging.exception('LDAPError in ConnectToAd().')
                if e.args and e.args[0].get(
                        'desc') == 'Can\'t contact LDAP server':
                    time.sleep(5 * failures)
                    continue
                raise
Example #8
 def test_version_numbers(self):
     run_config = run_configs.get()
     failures = []
     for game_version, version in sorted(run_config.get_versions().items()):
         try:
             self.assertEqual(game_version, version.game_version)
             log_center("starting version check: %s", game_version)
             with run_config.start(version=game_version,
                                   want_rgb=False) as controller:
                 ping = controller.ping()
                 logging.info("expected: %s", version)
                 logging.info("actual: %s",
                              ", ".join(str(ping).strip().split("\n")))
                 self.assertEqual(version.build_version, ping.base_build)
                 if version.game_version != "latest":
                     self.assertEqual(major_version(ping.game_version),
                                      major_version(version.game_version))
                     self.assertEqual(version.data_version.lower(),
                                      ping.data_version.lower())
             log_center("success: %s", game_version)
         except:  # pylint: disable=bare-except
             log_center("failure: %s", game_version)
             logging.exception("Failed")
             failures.append(game_version)
     self.assertEmpty(failures)
Example #9
def main(argv):
    del argv  # Unused.

    # Initialise Tink
    try:
        jwt.register_jwt_signature()
    except tink.TinkError as e:
        logging.exception('Error initialising Tink: %s', e)
        return 1

    # Read the keyset into a KeysetHandle
    with open(FLAGS.keyset_path, 'rt') as keyset_file:
        try:
            text = keyset_file.read()
            keyset_handle = cleartext_keyset_handle.read(
                tink.JsonKeysetReader(text))
        except tink.TinkError as e:
            logging.exception('Error reading keyset: %s', e)
            return 1

    # Export Public Keyset as JWK set
    public_jwk_set = jwt.jwk_set_from_public_keyset_handle(
        keyset_handle.public_keyset_handle())
    with open(FLAGS.public_jwk_set_path, 'wt') as public_jwk_set_file:
        public_jwk_set_file.write(public_jwk_set)
    logging.info('The public JWK set has been written to %s',
                 FLAGS.public_jwk_set_path)
Example #10
def train(config_info, train_task, s3_path, verbosity="info"):
    if verbosity in VERBOSITY_MAP:
        logging.set_verbosity(VERBOSITY_MAP[verbosity])
    else:
        logging.warning("Unknown logging level: {}".format(verbosity))

    controller = start_train(config_info,
                             train_task,
                             data_url=s3_path,
                             verbosity=verbosity)
    loop_is_end = False
    try:
        controller.tasks_loop()
        loop_is_end = True
    except (KeyboardInterrupt, EOFError):
        logging.warning("Got a KeyboardInterrupt, stopping early.")
    except BaseException as ex:
        logging.exception(ex)
        logging.warning("Got an exception, stopping early.")

    # handle close signal, with cleaning works.
    for _task in controller.tasks:
        _task.train_worker.logger.save_to_json()
    controller.stop()

    # fixme: make close harmonious between controller & broker
    time.sleep(2)
    if loop_is_end:
        logging.info("Finished train job normally.")

    os._exit(0)
Example #11
  def _ConnectToAd(self):
    """Establish a connection to the Active Directory server."""

    # To enable verbose debug logging, uncomment the following line.
    # ldap.set_option(ldap.OPT_DEBUG_LEVEL, 255)

    ldap.set_option(ldap.OPT_REFERRALS, 0)
    ldap.set_option(ldap.OPT_X_TLS_ALLOW, 1)

    ldap.set_option(ldap.OPT_X_TLS_REQUIRE_CERT, ldap.OPT_X_TLS_DEMAND)
    logging.debug('Connecting to Active Directory: %s', self.ldap_url)

    failures = 0
    while True:
      try:
        self.conn = ldap.initialize(self.ldap_url)
        self.conn.protocol_version = ldap.VERSION3
        self.conn.simple_bind_s(self.auth_user, self.auth_password)
        break
      except ldap.LDAPError as e:
        failures += 1
        if failures == MAX_CONNECTION_FAILURES:
          raise
        logging.exception('LDAPError in ConnectToAd().')
        if e.args and e.args[0].get('desc') == 'Can\'t contact LDAP server':
          time.sleep(5 * failures)
          continue
        raise
Example #12
    def post(self):
        """Process an async Action task with the correct Action class."""
        payload = pickle.loads(self.request.body)
        async_actions = payload.pop('async_actions')
        action_name = async_actions.pop(0)
        action_instance = self.actions['async'].get(action_name)
        if action_instance:
            try:
                action_instance.run(**payload)
            # pylint: disable=broad-except, because this logic, in which tasks are
            # responsible for spawning subsequent tasks, creates a chain that could be
            # interrupted by any conceivable exception in an action's run method. This
            # handling ensures any further tasks will run.
            except Exception as error:
                logging.exception(
                    'Failed to run async Action %r due to error: %r',
                    action_name, str(error))
            # pylint: enable=broad-except
        else:
            logging.error('No async Action named %s found.', action_name)

        if async_actions:
            payload['async_actions'] = async_actions
            taskqueue.add(queue_name='process-action',
                          payload=pickle.dumps(payload),
                          target='default')
Example #13
def main(config_file, train_task, s3_path=None, verbosity="info"):
    """do train task with single case """
    broker_master = start_train(config_file,
                                train_task,
                                data_url=s3_path,
                                verbosity=verbosity)
    loop_is_end = False
    try:
        broker_master.main_loop()
        loop_is_end = True
    except (KeyboardInterrupt, EOFError):
        logging.warning("Got a KeyboardInterrupt, stopping early.")
    except BaseException as ex:
        logging.exception(ex)
        logging.warning("Got an exception, stopping early.")

    # handle close signal, with cleaning works.
    broker_master.main_task.train_worker.logger.save_to_json()
    broker_master.stop()

    # fixme: make close harmonious between broker master & slave
    time.sleep(2)
    if loop_is_end:
        logging.info("Finished train job normally.")

    os._exit(0)
Example #14
    def _ValidateWithRetry(self, model_path: Text,
                           serving_binary: serving_bins.ServingBinary,
                           serving_spec: infra_validator_pb2.ServingSpec,
                           validation_spec: infra_validator_pb2.ValidationSpec,
                           requests: List[iv_types.Request]):

        for i in range(validation_spec.num_tries):
            logging.info('Infra validation trial %d/%d start.', i + 1,
                         validation_spec.num_tries)
            try:
                self._ValidateOnce(model_path=model_path,
                                   serving_binary=serving_binary,
                                   serving_spec=serving_spec,
                                   validation_spec=validation_spec,
                                   requests=requests)
                # If validation has passed without any exception, succeeded.
                return True
            except Exception as e:  # pylint: disable=broad-except
                # Exception indicates validation failure. Log the error and retry.
                logging.exception(e)
                if isinstance(e, error_types.DeadlineExceeded):
                    logging.info('Consider increasing the value of '
                                 'ValidationSpec.max_loading_time_seconds.')
                continue

        # Every trial has failed. Marking model as not blessed.
        return False
Example #15
def get_service_account(project_id: str,
                        service_account_name: str) -> Dict[str, Any]:
    """Find the service account with given name.

  Args:
    project_id: GCP project id.
    service_account_name: The service account name.

  Returns:
    service_account: If the service account is found in the cloud project.
    None: If no service account is found.
  """
    try:
        logging.info('Retrieving "%s" service account in "%s" project',
                     service_account_name, project_id)
        name = 'projects/{p}/serviceAccounts/{s}@{p}.iam.gserviceaccount.com'.format(
            p=project_id, s=service_account_name)
        service_account_details = _get_service_account_client().get(
            name=name).execute()
        return service_account_details
    except errors.HttpError as error:
        if error.resp.status == _NOT_FOUND_ERROR_CODE:
            return None  # pytype: disable=bad-return-type
        logging.exception(
            'Error occurred while retrieving service account: "%s".', error)
        raise Error('Error occurred while retrieving service account.')
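A brief usage sketch; the project and service account names are illustrative, and the None check follows the documented return value above.

account = get_service_account('my-project', 'my-service-account')
if account is None:
    logging.info('Service account was not found in the project.')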
Example #16
async def _initialize():
    try:
        await asyncio.gather(CardDb.get().initialize(),
                             ManamojiDb.get().initialize(bot))
    except:
        logging.exception('Bot initialization has failed.')
        logging.fatal('Dying due to failed initialization.')
Example #17
  def setUpClass(cls):
    super().setUpClass()
    cls.compiled_modules = {}
    if cls._modules_to_compile:
      for name, (ctor, exported_names,
                 backends) in cls._modules_to_compile.items():
        # Setup crash reproducer
        crash_reproducer_path = os.path.join(FLAGS.test_tmpdir, cls.__name__,
                                             name + ".mlir")
        try:
          os.makedirs(os.path.dirname(crash_reproducer_path))
        except IOError:
          logging.exception("Error creating crash reproducer dir for: %s",
                            crash_reproducer_path)
        compiler.Context.default_crash_reproducer_path = crash_reproducer_path

        try:
          # Compile.
          if backends is None:
            backends = get_default_test_backends()
          cls.compiled_modules[name] = dict([
              (backend.name, CompiledModule.create(ctor, exported_names,
                                                   backend))
              for backend in backends
          ])
        finally:
          # Disable crash reproducer (to avoid inadvertently overwriting this
          # path on a subsequent interaction).
          compiler.Context.default_crash_reproducer_path = None
Example #18
def _log_wrapper(method, args, kwargs):
    try:
        return_value = method(*args, **kwargs)
    except:
        logging.exception('Exception in %s', method)
        raise
    return return_value
Example #19
 def response_thread_fn():
     """Consumes response iter and exposes the value on corresponding Event."""
     try:
         logging.debug('Response thread: blocking for next response')
         for response in response_iter:
             logging.debug(
                 'Response thread: processing response of type %s, seq_no %s',
                 response.WhichOneof('response'),
                 response.sequence_number)
             # Get the corresponding response Event
             response_event = self._response_event_dict[
                 response.sequence_number]
             # Attach the response as an attribute on the Event
             response_event.response = response
             response_event.set()
         # Set the event indicating the stream has been closed
         self._stream_closed_event.set()
     except grpc.RpcError as error:
         logging.exception('Error calling remote executor: %s', error)
         if _is_retryable_grpc_error(error):  # pytype: disable=attribute-error
             logging.info('gRPC error is retryable')
             error = execution_context.RetryableError(error)
         # Set all response events to errors. This is heavy-handed and
         # potentially can be scaled back.
         for _, response_event in self._response_event_dict.items():
             if not response_event.isSet():
                 response_event.response = error
                 response_event.set()
         self._stream_closed_event.set()
Example #20
  def _ValidateWithRetry(
      self, model_path: Text,
      serving_binary: serving_bins.ServingBinary,
      serving_spec: infra_validator_pb2.ServingSpec,
      validation_spec: infra_validator_pb2.ValidationSpec,
      requests: List[iv_types.Request]):

    for i in range(validation_spec.num_tries):
      logging.info('Starting infra validation (attempt %d/%d).', i + 1,
                   validation_spec.num_tries)
      try:
        self._ValidateOnce(
            model_path=model_path,
            serving_binary=serving_binary,
            serving_spec=serving_spec,
            validation_spec=validation_spec,
            requests=requests)
      except error_types.GracefulShutdown:
        # GracefulShutdown means infra validation aborted. No more retry and
        # escalate the error.
        raise
      except Exception as e:  # pylint: disable=broad-except
        # Other exceptions indicates validation failure. Log the error and
        # retry.
        logging.exception('Infra validation (attempt %d/%d) failed.', i + 1,
                          validation_spec.num_tries)
        if isinstance(e, error_types.DeadlineExceeded):
          logging.info('Consider increasing the value of '
                       'ValidationSpec.max_loading_time_seconds.')
      else:
        # If validation has passed without any exception, succeeded.
        return True

    # Every trial has failed. Marking model as not blessed.
    return False
Example #21
 def _process_exec_node_task(self, scheduler: ts.TaskScheduler,
                             task: task_lib.ExecNodeTask) -> None:
     """Processes an `ExecNodeTask` using the given task scheduler."""
     # This is a blocking call to the scheduler which can take a long time to
     # complete for some types of task schedulers. The scheduler is expected to
     # handle any internal errors gracefully and return the result with an error
     # status. But in case the scheduler raises an exception, it is considered
     # a failed execution and MLMD is updated accordingly.
     try:
         result = scheduler.schedule()
     except Exception as e:  # pylint: disable=broad-except
         logging.exception(
             'Exception raised by task scheduler; node uid: %s',
             task.node_uid)
         result = ts.TaskSchedulerResult(status=status_lib.Status(
             code=status_lib.Code.ABORTED, message=str(e)))
     logging.info(
         'For ExecNodeTask id: %s, task-scheduler result status: %s',
         task.task_id, result.status)
     _publish_execution_results(mlmd_handle=self._mlmd_handle,
                                task=task,
                                result=result)
     with self._publish_time_lock:
         self._last_mlmd_publish_time = time.time()
     with self._tm_lock:
         del self._scheduler_by_node_uid[task.node_uid]
         self._task_queue.task_done(task)
Example #22
    def __next__(self) -> types.NestedArray:
        try:
            if not self.pmapped_user:
                item = next(self.iterator)
                if self.split_fn is None:
                    return jax.device_put(item, self.devices[0])
                item_split = self.split_fn(item)
                return PrefetchingSplit(host=item_split.host,
                                        device=jax.device_put(
                                            item_split.device,
                                            self.devices[0]))

            items = itertools.islice(self.iterator, self.num_devices)
            items = tuple(items)
            if len(items) < self.num_devices:
                raise StopIteration
            if self.split_fn is None:
                return jax.device_put_sharded(tuple(items), self.devices)
            else:
                # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                items_split = (self.split_fn(item) for item in items)
                # (host: (x1, ..., xN), device: (y1, ..., yN)).
                split = tree.map_structure_up_to(PrefetchingSplit(None, None),
                                                 lambda *x: x, *items_split)

                return PrefetchingSplit(host=np.stack(split.host),
                                        device=jax.device_put_sharded(
                                            split.device, self.devices))

        except StopIteration:
            raise

        except Exception:  # pylint: disable=broad-except
            logging.exception('Error for %s', self.iterable)
            raise
Example #23
    def delete_environment(self, environment_name: str) -> None:
        """Deletes an existing Cloud Composer environment.

    Args:
      environment_name: Name of Composer environment.

    Raises:
      Error: If the request was not processed successfully.
    """
        fully_qualified_name = self._get_fully_qualified_env_name(
            environment_name)
        logging.info('Deleting "%s" Composer environment from "%s" project.',
                     fully_qualified_name, self.project_id)
        try:
            request = self.client.projects().locations().environments().delete(
                name=fully_qualified_name)
            operation = utils.execute_request(request)
            operation_client = self.client.projects().locations().operations()
            utils.wait_for_operation(operation_client, operation)
        except errors.HttpError as error:
            if error.__dict__['resp'].status == _HTTP_NOT_FOUND_CODE:
                logging.info('The Composer environment %s does not exist.',
                             fully_qualified_name)
                return
            logging.exception(
                'Error occurred while deleting Composer environment.')
            raise Error('Error occurred while deleting Composer environment.')
Example #24
    def producer():
        """Enqueues batched items from `iterable` on a given thread."""
        try:
            # Build a new iterable for each thread. This is crucial if working with
            # tensorflow datasets because tf.Graph objects are thread local.
            it = iter(iterable)
            while True:
                # islice returns a lazy iterator that is always truthy, so
                # materialize the slice before the emptiness check below.
                items = tuple(itertools.islice(it, len(devices)))
                if not items:
                    break
                if split_fn is None:
                    buffer.put(
                        jax.api.device_put_sharded(tuple(items), devices))
                else:
                    # ((host: x1, device: y1), ..., (host: xN, device: yN)).
                    items_split = (split_fn(item) for item in items)
                    # (host: (x1, ..., xN), device: (y1, ..., yN)).
                    split = tree.map_structure_up_to(
                        PrefetchingSplit(None, None), lambda *x: x,
                        *items_split)

                    buffer.put(
                        PrefetchingSplit(host=np.stack(split.host),
                                         device=jax.api.device_put_sharded(
                                             split.device, devices)))
        except Exception as e:  # pylint: disable=broad-except
            logging.exception('Error in producer thread for %s',
                              iterable.__name__)
            producer_error.append(e)
        finally:
            buffer.put(end)
Example #25
def _encode_multivalent_numeric_feature(
        feature_array: pa.Array, encoding_length: int) -> Optional[List[int]]:
    """Encodes numeric multivalent features into a fixed length representation.

  Numeric multivalent features are encoded using bucketization.
  max_encoding_length bins of equal sized intervals are constructed from the
  feature values. For each example, a histogram is constructed. These bin
  counts represent an encoding for the example.

  Args:
    feature_array: Arrow Array.
    encoding_length: The length of the list containing the encoded feature
      values.

  Returns:
    A list containing the encoded feature values for each example. Returns None
    if unable to encode the feature_array.
  """
    flattened_feature_values = _get_flattened_feature_values_without_nulls(
        feature_array)
    try:
        _, histogram_bin_boundaries = np.histogram(flattened_feature_values,
                                                   bins=encoding_length - 1)
    except IndexError as e:
        # np.histogram cannot handle values > 2**53 if the min and max of the
        # examples are the same. https://github.com/numpy/numpy/issues/8627
        logging.exception("Unable to encode examples: %s with error: %s",
                          flattened_feature_values, e)
        return None
    return _apply_numerical_encoding_to_feature_array(
        feature_array, histogram_bin_boundaries, encoding_length)
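A hedged usage sketch, assuming pyarrow is imported as pa (as the pa.Array annotation above suggests) and that the private helpers referenced inside the function come from the same module; the feature values are illustrative.

import pyarrow as pa

feature = pa.array([[1.0, 2.0, 5.0], [3.0], [7.0, 9.0, 11.0]])
encoded = _encode_multivalent_numeric_feature(feature, encoding_length=4)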
Example #26
 def test_Parallel(self):
     methods_and_args = []
     methods_and_args.append((_CalculateVisqol, [REF_FILE, DEG_FILE]))
     methods_and_args.append((_CalculateVisqol, [REF_FILE, DEG_FILE]))
     results = []
     with futures.ThreadPoolExecutor(
             max_workers=len(methods_and_args)) as executor:
         started_futures = []
         # This is pretty much a map() but we have better control over the chunk
         # size.
         for method, args in methods_and_args:
             future = executor.submit(_log_wrapper, method, args, {})
             started_futures.append((future, method, args))
         last_exception = None
         for f, method, args in started_futures:
             try:
                 results.append(f.result())
             except Exception as ex:  # pylint: disable=broad-except
                 e = ex  # This assignment works around pytype bug http://b/136279340.
                 msg = (
                     'Execution of method %s with args %s in a separate thread failed.'
                     % (method, args))
                 msg += " The last failed method's exception will be re-thrown."
                 last_exception = e
                 logging.exception(msg)
         if last_exception:
             raise last_exception  # Can only be Exception pylint: disable=raising-bad-type
         return results
Example #27
 def _start_session(self):
     try:
         profiler.start(logdir=self._logdir)
         self._session_running = True
         self._session_started = time.time()
     except Exception as e:  # pylint: disable=broad-except
         logging.exception("Could not start profiling: %s", e)
Example #28
 def response_thread_fn():
     """Consumes response iter and exposes the value on corresponding Event."""
     try:
         logging.debug('Response thread: blocking for next response')
         for response in response_iter:
             logging.debug(
                 'Response thread: processing response of type %s, seq_no %s',
                 response.WhichOneof('response'),
                 response.sequence_number)
             # Get the corresponding response Event
             response_event = self._response_event_dict[
                 response.sequence_number]
             # Attach the response as an attribute on the Event
             response_event.response = response
             response_event.set()
         # Set the event indicating the stream has been closed
         self._stream_closed_event.set()
     except Exception as error:  # pylint: disable=broad-except
         logging.exception('Error calling remote executor: %s', error)
         if _is_retryable_grpc_error(error):
             logging.exception('gRPC error is retryable')
             error = execution_context.RetryableError(error)
         with self._response_event_lock:
             self._stream_error = error
             for _, response_event in self._response_event_dict.items():
                 if not response_event.isSet():
                     response_event.response = error
                     response_event.set()
         self._stream_closed_event.set()
Example #29
 def response_thread_fn():
     """Consumes response iter and exposes the value on corresponding Event."""
     try:
         logging.debug('Response thread: blocking for next response')
         for response in response_iter:
             if response.WhichOneof('response') is None:
                 # TODO(b/175927125): We currently pass an empty response in some
                 # error cases and pass a GRPC error back via the ServicerContext in
                 # some others. Unify this error-passing.
                 raise execution_context.RetryableError(
                     'Unknown error on the service side.')
             logging.debug(
                 'Response thread: processing response of type %s, seq_no %s',
                 response.WhichOneof('response'),
                 response.sequence_number)
             # Get the corresponding response Event
             response_event = self._response_event_dict[
                 response.sequence_number]
             # Attach the response as an attribute on the Event
             response_event.response = response
             response_event.set()
         # Set the event indicating the stream has been closed
         self._stream_closed_event.set()
     except Exception as error:  # pylint: disable=broad-except
         logging.exception('Error calling remote executor: %s', error)
         if _is_retryable_grpc_error(error):
             logging.exception('gRPC error is retryable')
             error = execution_context.RetryableError(error)
         with self._response_event_lock:
             self._stream_error = error
             for _, response_event in self._response_event_dict.items():
                 if not response_event.isSet():
                     response_event.response = error
                     response_event.set()
         self._stream_closed_event.set()
Example #30
    def get_environment(self, environment_name: str) -> Dict[str, Any]:
        """Retrieves details of a Composer environment.

    Args:
      environment_name: Name of the existing Composer environment. The fully
        qualified environment name will be constructed as follows -
        'projects/{project_id}/locations/{location}/environments/
        {environment_name}'.

    Returns:
      environment: Details of Composer environment.

    Raises:
      Error: If the request was not processed successfully.
    """
        fully_qualified_name = self._get_fully_qualified_env_name(
            environment_name)
        logging.info('Retrieving Composer environment details for "%s"',
                     fully_qualified_name)
        try:
            request = self.client.projects().locations().environments().get(
                name=fully_qualified_name)
            composer_environment_details = utils.execute_request(request)
            return composer_environment_details
        except errors.HttpError:
            logging.exception(
                'Error while retrieving Composer environment details.')
            raise Error('Error while retrieving Composer environment details.')
Example #31
def ProcessImageList(image_list):
    """Extract labels from image, image resides in CNS or local file.

    Args:
      image_list: (list). Contains image information.

    Returns:
      list: List of tuples (master_id, image labels).

    Raises:
        ValueError: Invalid image list.
    """

    if not image_list:
        raise ValueError('Invalid image list')

    image_labels = []
    with futures.ThreadPoolExecutor(max_workers=_MAX_WORKERS) as executor:
        # Start the load operations and mark each future with its Image id.
        future_to_url = {
            executor.submit(_ExtractLabels, image): image
            for image in image_list
        }
        for future in futures.as_completed(future_to_url):
            image = future_to_url[future]
            try:
                image_labels.append((image.master_id, future.result()))
            except ValueError as e:
                logging.exception('Exception at: %s, %s', image.master_id, e)
    return image_labels
Example #32
  def _QueryLdap(self, base_dn, ldap_filter, scope=ldap.SCOPE_SUBTREE):
    """Yields LDAP results for a given filter in a given base DN.

    This method exists primarily to assist with paging through large result
    sets, and to centralize exception handling and retrying.

    Args:
      base_dn: str base distinguishedName to query within.
      ldap_filter: str LDAP filter to query with.
      scope: ldap.SCOPE_SUBTREE (default), SCOPE_ONELEVEL, or SCOPE_BASE.
    Yields:
      LDAP result dictionary.
    Raises:
      ldap.LDAPError: there was an error querying LDAP.
    """
    if not self.conn:
      self._ConnectToAd()

    page_control = controls.SimplePagedResultsControl(
        True, size=self.page_size, cookie='')

    failures = 0
    # Iterate over all hosts matching the given ldap_filter, in batches of
    # self.page_size, escrowing each resulting recovery object.
    while True:
      server_controls = [page_control]
      query_id = self.conn.search_ext(
          base_dn, scope, ldap_filter, serverctrls=server_controls)

      try:
        _, results, _, server_controls = self.conn.result3(query_id)
      except ldap.LDAPError:
        failures += 1
        if failures == MAX_QUERY_FAILURES:
          raise
        logging.exception('LDAPError on result3() call.')
        time.sleep(5 * failures)
        self._ConnectToAd()
        continue
      else:
        failures = 0

      for result in results:
        yield result[1]

      # Update page_control with server provided control data, otherwise there
      # are no more results, so break.
      cookie = None
      for server_control in server_controls:
        if (server_control.controlType ==
            controls.SimplePagedResultsControl.controlType):
          cookie = server_control.cookie
          if cookie:
            page_control.cookie = cookie
          break
      if not cookie:
        break