Example #1
def total_gc_size(self, task):
    return sum([
        task.data_size,
        task.metadata_size if self._include_metadata else Amount(0, Data.BYTES),
        task.log_size if self._include_logs else Amount(0, Data.BYTES)
    ], Amount(0, Data.BYTES))
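A minimal sketch of the pattern above: summing Amount values needs an Amount start value, because sum() would otherwise start from the integer 0. The sizes below are made up for illustration.

from twitter.common.quantity import Amount, Data

sizes = [Amount(10, Data.MB), Amount(512, Data.BYTES)]
total = sum(sizes, Amount(0, Data.BYTES))  # Amount(0, Data.BYTES) keeps the addition well-typed
print(total)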
Example #2
class SchedulerProxy(object):
    """
    This class is responsible for creating a reliable thrift client to the
    twitter scheduler.  Basically all the dirty work needed by the
    AuroraClientAPI.
  """
    CONNECT_MAXIMUM_WAIT = Amount(1, Time.MINUTES)
    RPC_RETRY_INTERVAL = Amount(5, Time.SECONDS)
    RPC_MAXIMUM_WAIT = Amount(10, Time.MINUTES)

    class Error(Exception):
        pass

    class TimeoutError(Error):
        pass

    class TransientError(Error):
        pass

    class AuthError(Error):
        pass

    class APIVersionError(Error):
        pass

    class ThriftInternalError(Error):
        pass

    def __init__(self, cluster, verbose=False, **kwargs):
        self.cluster = cluster
        # TODO(Sathya): Make this a part of cluster trait when authentication is pushed to the transport
        # layer.
        self._client = self._scheduler_client = None
        self.verbose = verbose
        self._lock = threading.RLock()
        self._terminating = threading.Event()
        self._kwargs = kwargs

    def with_scheduler(method):
        """Decorator magic to make sure a connection is made to the scheduler"""
        def _wrapper(self, *args, **kwargs):
            if not self._client:
                self._construct_scheduler()
            return method(self, *args, **kwargs)

        return _wrapper

    def invalidate(self):
        self._client = self._scheduler_client = None

    def terminate(self):
        """Requests immediate termination of any retry attempts and invalidates client."""
        self._terminating.set()
        self.invalidate()

    @with_scheduler
    def client(self):
        return self._client

    @with_scheduler
    def scheduler_client(self):
        return self._scheduler_client

    def _construct_scheduler(self):
        """
      Populates:
        self._scheduler_client
        self._client
    """
        self._scheduler_client = SchedulerClient.get(self.cluster,
                                                     verbose=self.verbose,
                                                     **self._kwargs)
        assert self._scheduler_client, "Could not find scheduler (cluster = %s)" % self.cluster.name
        start = time.time()
        while (time.time() - start) < self.CONNECT_MAXIMUM_WAIT.as_(
                Time.SECONDS):
            try:
                # this can wind up generating any kind of error, because it turns into
                # a call to a dynamically set authentication module.
                self._client = self._scheduler_client.get_thrift_client()
                break
            except SchedulerClient.CouldNotConnect as e:
                log.warning('Could not connect to scheduler: %s' % e)
            except Exception as e:
                # turn any auth module exception into an auth error.
                log.debug(
                    'Warning: got an unknown exception during authentication:')
                log.debug(traceback.format_exc())
                raise self.AuthError('Error connecting to scheduler: %s' % e)
        if not self._client:
            raise self.TimeoutError(
                'Timed out trying to connect to scheduler at %s' %
                self.cluster.name)

    def __getattr__(self, method_name):
        # If the method does not exist, getattr will raise AttributeError for us.
        method = getattr(AuroraAdmin.Client, method_name)
        if not callable(method):
            return method

        @functools.wraps(method)
        def method_wrapper(*args):
            with self._lock:
                start = time.time()
                while not self._terminating.is_set() and (
                        time.time() - start) < self.RPC_MAXIMUM_WAIT.as_(
                            Time.SECONDS):

                    try:
                        method = getattr(self.client(), method_name)
                        if not callable(method):
                            return method

                        resp = method(*args)
                        if resp is not None and resp.responseCode == ResponseCode.ERROR_TRANSIENT:
                            raise self.TransientError(", ".join(
                                [m.message for m in resp.details] if resp.details else []))
                        return resp
                    except TRequestsTransport.AuthError as e:
                        log.error(
                            self.scheduler_client().get_failed_auth_message())
                        raise self.AuthError(e)
                    except (TTransport.TTransportException, self.TimeoutError,
                            self.TransientError) as e:
                        if not self._terminating.is_set():
                            log.warning(
                                'Connection error with scheduler: %s, reconnecting...'
                                % e)
                            self.invalidate()
                            self._terminating.wait(
                                self.RPC_RETRY_INTERVAL.as_(Time.SECONDS))
                    except Exception as e:
                        # Take any error that occurs during the RPC call, and transform it
                        # into something clients can handle.
                        if not self._terminating.is_set():
                            raise self.ThriftInternalError(
                                "Error during thrift call %s to %s: %s" %
                                (method_name, self.cluster.name, e))
                if not self._terminating.is_set():
                    raise self.TimeoutError(
                        'Timed out attempting to issue %s to %s' %
                        (method_name, self.cluster.name))

        return method_wrapper
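A standalone sketch of the lazy-connect decorator idea behind SchedulerProxy.with_scheduler above; DemoProxy and its fake connection step are hypothetical stand-ins, not part of the Aurora API.

import functools

def with_connection(method):
    """Ensure a client exists before the wrapped method runs."""
    @functools.wraps(method)
    def _wrapper(self, *args, **kwargs):
        if self._client is None:
            self._client = object()  # stand-in for an expensive connection step
        return method(self, *args, **kwargs)
    return _wrapper

class DemoProxy(object):
    def __init__(self):
        self._client = None

    @with_connection
    def client(self):
        return self._client

assert DemoProxy().client() is not None  # the connection is made lazily on first use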
Example #3
class HostMaintenance(object):
  """Submit requests to the scheduler to put hosts into and out of maintenance
  mode so they can be operated upon without causing LOST tasks.

  Aurora provides a two-tiered concept of Maintenance. The first step is to initiate maintenance,
  which will ask the Aurora scheduler to de-prioritize scheduling on a large set of hosts (the ones
  that will be operated upon during this maintenance window).  Once all hosts have been tagged in
  this manner, the operator can begin draining individual machines, which will have all user-tasks
  killed and rescheduled.  When the tasks get placed onto a new machine, the scheduler will first
  look for hosts that do not have the maintenance tag, which will help decrease churn and prevent a
  task from being constantly killed as its hosts go down from underneath it.
  """

  SLA_MIN_JOB_INSTANCE_COUNT = 20
  STATUS_POLL_INTERVAL = Amount(5, Time.SECONDS)
  MAX_STATUS_WAIT = Amount(5, Time.MINUTES)

  @classmethod
  def iter_batches(cls, hostnames, grouping_function=DEFAULT_GROUPING):
    groups = group_hosts(hostnames, grouping_function)
    groups = sorted(groups.items(), key=lambda v: v[0])
    for group in groups:
      yield Hosts(group[1])

  def __init__(self, cluster, verbosity, wait_event=None, bypass_leader_redirect=False):
    self._client = make_admin_client(
        cluster=cluster,
        verbose=verbosity == 'verbose',
        bypass_leader_redirect=bypass_leader_redirect)
    self._wait_event = wait_event or Event()

  def _drain_hosts(self, drainable_hosts):
    """"Drains tasks from the specified hosts.

    This will move active tasks on these hosts to the DRAINING state, causing them to be
    rescheduled elsewhere.

    :param drainable_hosts: Hosts that are in maintenance mode and ready to be drained
    :type drainable_hosts: gen.apache.aurora.ttypes.Hosts
    :rtype: set of host names that failed to drain
    """
    check_and_log_response(self._client.drain_hosts(drainable_hosts))
    drainable_hostnames = [hostname for hostname in drainable_hosts.hostNames]

    total_wait = self.STATUS_POLL_INTERVAL
    not_drained_hostnames = set(drainable_hostnames)
    while not self._wait_event.is_set() and not_drained_hostnames:
      log.info('Waiting for hosts to be in DRAINED: %s' % not_drained_hostnames)
      self._wait_event.wait(self.STATUS_POLL_INTERVAL.as_(Time.SECONDS))

      statuses = self.check_status(list(not_drained_hostnames))
      not_drained_hostnames = set(h[0] for h in statuses if h[1] != 'DRAINED')

      total_wait += self.STATUS_POLL_INTERVAL
      if not_drained_hostnames and total_wait > self.MAX_STATUS_WAIT:
        log.warning('Failed to move all hosts into DRAINED within %s:\n%s' %
            (self.MAX_STATUS_WAIT,
            '\n'.join("\tHost:%s\tStatus:%s" % h for h in sorted(statuses) if h[1] != 'DRAINED')))
        break

    return not_drained_hostnames

  def _complete_maintenance(self, drained_hosts):
    """End the maintenance status for a given set of hosts.

    :param drained_hosts: Hosts that are drained and finished being operated upon
    :type drained_hosts: gen.apache.aurora.ttypes.Hosts
    """
    check_and_log_response(self._client.end_maintenance(drained_hosts))
    resp = self._client.maintenance_status(drained_hosts)
    for host_status in resp.result.maintenanceStatusResult.statuses:
      if host_status.mode != MaintenanceMode.NONE:
        log.warning('%s is DRAINING or in DRAINED' % host_status.host)

  def _check_sla(self, hostnames, grouping_function, percentage, duration):
    """Check if the provided list of hosts passes the job uptime SLA check.

    This is an all-or-nothing check, meaning that all provided hosts must pass their job
    SLA check for the maintenance to proceed.

    :param hostnames: list of host names to check SLA for
    :type hostnames: list of strings
    :param grouping_function: grouping function to apply to the given hosts
    :type grouping_function: function
    :param percentage: SLA uptime percentage override
    :type percentage: float
    :param duration: SLA uptime duration override
    :type duration: twitter.common.quantity.Amount
    :rtype: set of unsafe hosts
    """
    vector = self._client.sla_get_safe_domain_vector(self.SLA_MIN_JOB_INSTANCE_COUNT, hostnames)
    host_groups = vector.probe_hosts(
      percentage,
      duration.as_(Time.SECONDS),
      grouping_function)

    unsafe_hostnames = set()
    # Given that maintenance is performed 1 group at a time, any result longer than 1 group
    # should be considered a batch failure.
    if host_groups:
      if len(host_groups) > 1:
        log.error('Illegal multiple groups detected in SLA results. Skipping hosts: %s' % hostnames)
        return set(hostnames)

      results, unsafe_hostnames = format_sla_results(host_groups, unsafe_only=True)
      if results:
        print_results(results)
        return unsafe_hostnames

    return unsafe_hostnames

  def end_maintenance(self, hostnames):
    """Pull a list of hostnames out of maintenance mode.

    :param hostnames: List of hosts to operate upon
    :type hostnames: list of strings
    """
    self._complete_maintenance(Hosts(set(hostnames)))

  def start_maintenance(self, hostnames):
    """Put a list of hostnames into maintenance mode, to de-prioritize scheduling.

    This is part of two-phase draining: tasks will still be running on these hosts until
    drain_hosts is called upon them.

    :param hostnames: List of hosts to set for initial maintenance
    :type hostnames: list of strings
    :rtype: list of hostnames with the maintenance mode set
    """
    resp = self._client.start_maintenance(Hosts(set(hostnames)))
    check_and_log_response(resp)
    result = [host_status.host for host_status in resp.result.startMaintenanceResult.statuses]
    if len(result) != len(hostnames):
      log.warning('Skipping maintenance for unknown hosts: %s' % (set(hostnames) - set(result)))

    return result

  def _operate_on_hosts(self, drained_hosts, callback):
    """Perform a given operation on a list of hosts that are ready for maintenance.

    :param drained_hosts: Hosts that have been drained (via _drain_hosts)
    :type drained_hosts: list of strings
    :param callback: Function to call one hostname at a time
    :type callback: function
    """
    for hostname in drained_hosts:
      callback(hostname)

  def perform_maintenance(self, hostnames, grouping_function=DEFAULT_GROUPING,
                          percentage=None, duration=None, output_file=None, callback=None):
    """Put hosts into maintenance mode and drain them.

    Walk through the process of putting hosts into maintenance and draining them of tasks. The hosts
    will remain in maintenance mode upon completion.


    :param hostnames: A list of hostnames to operate upon
    :type hostnames: list of strings
    :param grouping_function: How to split up the hostname into groups
    :type grouping_function: function
    :param percentage: SLA percentage to use
    :type percentage: float
    :param duration: SLA duration to use
    :type duration: twitter.common.quantity.Time
    :param output_file: file to write hosts that were not drained due to failed SLA check
    :type output_file: string
    :param callback: Function to call once hosts are drained
    :type callback: function
    :rtype: set of host names that were successfully drained
    """
    hostnames = self.start_maintenance(hostnames)
    not_drained_hostnames = set()

    for hosts in self.iter_batches(hostnames, grouping_function):
      log.info('Beginning SLA check for %s' % hosts.hostNames)
      unsafe_hostnames = self._check_sla(
          list(hosts.hostNames),
          grouping_function,
          percentage,
          duration)

      if unsafe_hostnames:
        log.warning('Some hosts did not pass SLA check and will not be drained! '
                    'Skipping hosts: %s' % unsafe_hostnames)
        not_drained_hostnames |= unsafe_hostnames
        drainable_hostnames = hosts.hostNames - unsafe_hostnames
        if not drainable_hostnames:
          continue
        hosts = Hosts(drainable_hostnames)
      else:
        log.info('All hosts passed SLA check.')

      not_drained_hostnames |= self._drain_hosts(hosts)

      if callback:
        self._operate_on_hosts(hosts.hostNames - not_drained_hostnames, callback)

    if not_drained_hostnames:
      output = '\n'.join(list(not_drained_hostnames))
      log.info('The following hosts WERE NOT DRAINED due to failed SLA check or external failures:')
      print(output)
      if output_file:
        try:
          with open(output_file, 'w') as fp:
            fp.write(output)
            fp.write('\n')
          log.info('Written unsafe host names into: %s' % output_file)
        except IOError as e:
          log.error('Failed to write into the output file: %s' % e)

    return set(hostnames) - not_drained_hostnames

  def check_status(self, hostnames):
    """Query the scheduler to determine the maintenance status for a list of hostnames

    :param hostnames: Hosts to query for
    :type hostnames: list of strings
    :rtype: list of 2-tuples, hostname and MaintenanceMode
    """
    resp = self._client.maintenance_status(Hosts(set(hostnames)))
    check_and_log_response(resp)
    statuses = []
    for host_status in resp.result.maintenanceStatusResult.statuses:
      statuses.append((host_status.host, MaintenanceMode._VALUES_TO_NAMES[host_status.mode]))
    return statuses
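A hypothetical driver for the two-phase flow described in the HostMaintenance docstring above; the cluster object and hostnames are placeholders, not real infrastructure.

# Assumes `cluster` is a valid cluster definition and that the hostnames exist.
# perform_maintenance() starts maintenance, runs the SLA check, and drains hosts batch by batch.
maintenance = HostMaintenance(cluster, 'verbose')
hostnames = ['host-1.example.com', 'host-2.example.com']
drained = maintenance.perform_maintenance(hostnames)
print('Successfully drained: %s' % sorted(drained))

# Hosts remain in maintenance mode after draining; release them explicitly when the work is done.
maintenance.end_maintenance(hostnames)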
Example #4
def test_nested_scopes():
    rm = _cleared_rm()
    mg = rm.scope('a').scope('b').scope('c').register('123')
    mg.write(Amount(1, Time.MILLISECONDS))
    assert rm.sample() == {'a.b.c.123': '1 ms'}
    rm.clear()
Example #5
class TaskResourceMonitor(ResourceMonitorBase, ExceptionalThread):
    """ Lightweight thread to aggregate resource consumption for a task's constituent processes.
      Actual resource calculation is delegated to collectors; this class periodically polls the
      collectors and aggregates into a representation for the entire task. Also maintains a limited
      history of previous sample results.
  """

    PROCESS_COLLECTION_INTERVAL = Amount(20, Time.SECONDS)
    HISTORY_TIME = Amount(1, Time.HOURS)

    def __init__(self,
                 task_id,
                 task_monitor,
                 disk_collector_provider=DiskCollectorProvider(),
                 process_collection_interval=PROCESS_COLLECTION_INTERVAL,
                 disk_collection_interval=DiskCollectorSettings.DISK_COLLECTION_INTERVAL,
                 history_time=HISTORY_TIME,
                 history_provider=HistoryProvider()):
        """
      task_monitor: TaskMonitor object specifying the task whose resources should be monitored
      sandbox: Directory for which to monitor disk utilisation
    """
        self._task_monitor = task_monitor  # exposes PIDs, sandbox
        self._task_id = task_id
        log.debug('Initialising resource collection for task %s',
                  self._task_id)
        self._process_collectors = dict()  # ProcessStatus => ProcessTreeCollector

        self._disk_collector_provider = disk_collector_provider
        self._disk_collector = None
        self._process_collection_interval = process_collection_interval.as_(
            Time.SECONDS)
        self._disk_collection_interval = disk_collection_interval.as_(
            Time.SECONDS)
        min_collection_interval = min(self._process_collection_interval,
                                      self._disk_collection_interval)
        self._history = history_provider.provides(history_time,
                                                  min_collection_interval)
        self._kill_signal = threading.Event()
        ExceptionalThread.__init__(self,
                                   name='%s[%s]' %
                                   (self.__class__.__name__, task_id))
        self.daemon = True

    def sample(self):
        if not self.is_alive():
            log.warning(
                "TaskResourceMonitor not running - sample may be inaccurate")
        return self.sample_at(time.time())

    def sample_at(self, timestamp):
        _timestamp, full_resources = self._history.get(timestamp)

        aggregated_procs = sum(
            map(attrgetter('num_procs'), full_resources.proc_usage.values()))
        aggregated_sample = sum(
            map(attrgetter('process_sample'),
                full_resources.proc_usage.values()), ProcessSample.empty())

        return _timestamp, self.AggregateResourceResult(
            aggregated_procs, aggregated_sample, full_resources.disk_usage)

    def sample_by_process(self, process_name):
        try:
            process = [
                process for process in self._get_active_processes()
                if process.process == process_name
            ].pop()
        except IndexError:
            raise ValueError(
                'No active process found with name "%s" in this task' %
                process_name)
        else:
            # Since this might be called out of band (before the main loop is aware of the process)
            if process not in self._process_collectors:
                self._process_collectors[process] = ProcessTreeCollector(
                    process.pid)

            # The sample obtained from history is a tuple of (timestamp, FullResourceResult), and the
            # per-process sample can be looked up from FullResourceResult.
            _, full_resources = self._history.get(time.time())
            if process in full_resources.proc_usage:
                return full_resources.proc_usage[process].process_sample

            self._process_collectors[process].sample()
            return self._process_collectors[process].value

    def _get_active_processes(self):
        """Get a list of ProcessStatus objects representing currently-running processes in the task"""
        return [
            process
            for process, _ in self._task_monitor.get_active_processes()
        ]

    def kill(self):
        """Signal that the thread should cease collecting resources and terminate"""
        self._kill_signal.set()

    def run(self):
        """Thread entrypoint. Loop indefinitely, polling collectors at self._collection_interval and
    collating samples."""

        log.debug('Commencing resource monitoring for task "%s"',
                  self._task_id)
        next_process_collection = 0
        next_disk_collection = 0

        while not self._kill_signal.is_set():
            now = time.time()

            if now > next_process_collection:
                next_process_collection = now + self._process_collection_interval
                actives = set(self._get_active_processes())
                current = set(self._process_collectors)
                for process in current - actives:
                    self._process_collectors.pop(process)
                for process in actives - current:
                    self._process_collectors[process] = ProcessTreeCollector(
                        process.pid)
                for process, collector in self._process_collectors.items():
                    collector.sample()

            if now > next_disk_collection:
                next_disk_collection = now + self._disk_collection_interval
                if not self._disk_collector:
                    sandbox = self._task_monitor.get_sandbox()
                    if sandbox:
                        self._disk_collector = self._disk_collector_provider.provides(
                            sandbox)
                if self._disk_collector:
                    self._disk_collector.sample()
                else:
                    log.debug('No sandbox detected yet for %s', self._task_id)

            try:
                disk_usage = self._disk_collector.value if self._disk_collector else 0

                proc_usage_dict = dict()
                for process, collector in self._process_collectors.items():
                    proc_usage_dict.update({
                        process:
                        self.ProcResourceResult(collector.value,
                                                collector.procs)
                    })

                self._history.add(
                    now, self.FullResourceResult(proc_usage_dict, disk_usage))
            except ValueError as err:
                log.warning("Error recording resource sample: %s", err)

            log.debug(
                "TaskResourceMonitor: finished collection of %s in %.2fs",
                self._task_id, (time.time() - now))

            # Sleep until any of the following conditions are met:
            # - it's time for the next disk collection
            # - it's time for the next process collection
            # - the result from the last disk collection is available via the DiskCollector
            # - the TaskResourceMonitor has been killed via self._kill_signal
            now = time.time()
            next_collection = min(next_process_collection - now,
                                  next_disk_collection - now)

            if self._disk_collector:
                waiter = EventMuxer(self._kill_signal,
                                    self._disk_collector.completed_event)
            else:
                waiter = self._kill_signal

            if next_collection > 0:
                waiter.wait(timeout=next_collection)
            else:
                log.warning(
                    'Task resource collection is backlogged. Consider increasing '
                    'process_collection_interval and disk_collection_interval.'
                )

        log.debug('Stopping resource monitoring for task "%s"', self._task_id)
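A small, self-contained sketch of the interval bookkeeping in TaskResourceMonitor.run() above: Amount-based intervals are converted to seconds once, and the thread waits only until the earlier of the two next-collection deadlines, waking immediately if the kill signal fires. The one-second intervals here are illustrative.

import threading
import time
from twitter.common.quantity import Amount, Time

process_interval = Amount(1, Time.SECONDS).as_(Time.SECONDS)  # stand-in for PROCESS_COLLECTION_INTERVAL
disk_interval = Amount(2, Time.SECONDS).as_(Time.SECONDS)     # stand-in for the disk collection interval

kill_signal = threading.Event()
now = time.time()
next_process_collection = now + process_interval
next_disk_collection = now + disk_interval

# Sleep until whichever collection is due first, or until kill() sets the event.
next_collection = min(next_process_collection - now, next_disk_collection - now)
if next_collection > 0:
    kill_signal.wait(timeout=next_collection)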
Example #6
class FastThermosExecutor(AuroraExecutor):
    STOP_WAIT = Amount(0, Time.SECONDS)
Example #7
class FakeWeb(Web):
  NS_TIMEOUT = Amount(1, Time.MILLISECONDS)
  def _resolves(self, fullurl):
    event.wait()
  def _reachable(self, fullurl):
    return True
Example #8
class RESTfulArtifactCache(ArtifactCache):
  """An artifact cache that stores the artifacts on a RESTful service."""

  READ_SIZE = int(Amount(4, Data.MB).as_(Data.BYTES))

  def __init__(self, log, artifact_root, url_base, compress=True):
    """
    url_base: The prefix for urls on some RESTful service. We must be able to PUT and GET to any
              path under this base.
    compress: Whether to compress the artifacts before storing them.
    """
    ArtifactCache.__init__(self, log, artifact_root)
    parsed_url = urlparse.urlparse(url_base)
    if parsed_url.scheme == 'http':
      self._ssl = False
    elif parsed_url.scheme == 'https':
      self._ssl = True
    else:
      raise ValueError('RESTfulArtifactCache only supports HTTP and HTTPS')
    self._timeout_secs = 4.0
    self._netloc = parsed_url.netloc
    self._path_prefix = parsed_url.path.rstrip('/')
    self.compress = compress

  def try_insert(self, cache_key, paths):
    with temporary_file_path() as tarfile:
      artifact = TarballArtifact(self.artifact_root, tarfile, self.compress)
      artifact.collect(paths)

      with open(tarfile, 'rb') as infile:
        remote_path = self._remote_path_for_key(cache_key)
        if not self._request('PUT', remote_path, body=infile):
          raise self.CacheError('Failed to PUT to %s. Error: 404' % self._url_string(remote_path))

  def has(self, cache_key):
    return self._request('HEAD', self._remote_path_for_key(cache_key)) is not None

  def use_cached_files(self, cache_key):
    # This implementation fetches the appropriate tarball and extracts it.
    remote_path = self._remote_path_for_key(cache_key)
    try:
      # Send an HTTP request for the tarball.
      response = self._request('GET', remote_path)
      if response is None:
        return None

      done = False
      with temporary_file() as outfile:
        total_bytes = 0
        # Read the data in a loop.
        while not done:
          data = response.read(self.READ_SIZE)
          outfile.write(data)
          if len(data) < self.READ_SIZE:
            done = True
          total_bytes += len(data)
        outfile.close()
        self.log.debug('Read %d bytes from artifact cache at %s' %
                       (total_bytes, self._url_string(remote_path)))

        # Extract the tarfile.
        artifact = TarballArtifact(self.artifact_root, outfile.name, self.compress)
        artifact.extract()
        return artifact
    except Exception as e:
      self.log.warn('Error while reading from remote artifact cache: %s' % e)
      return None

  def delete(self, cache_key):
    remote_path = self._remote_path_for_key(cache_key)
    self._request('DELETE', remote_path)

  def prune(self, age_hours):
    # Doesn't make sense for a client to prune a remote server.
    # Better to run tmpwatch on the server.
    pass

  def _remote_path_for_key(self, cache_key):
    # Note: it's important to use the id as well as the hash, because two different targets
    # may have the same hash if both have no sources, but we may still want to differentiate them.
    return '%s/%s/%s%s' % (self._path_prefix, cache_key.id, cache_key.hash,
                               '.tar.gz' if self.compress else '.tar')

  def _connect(self):
    if self._ssl:
      return httplib.HTTPSConnection(self._netloc, timeout=self._timeout_secs)
    else:
      return httplib.HTTPConnection(self._netloc, timeout=self._timeout_secs)

  # Returns a response if we get a 200, None if we get a 404 and raises an exception otherwise.
  def _request(self, method, path, body=None):
    self.log.debug('Sending %s request to %s' % (method, self._url_string(path)))
    # TODO(benjy): Keep connection open and reuse?
    conn = self._connect()
    conn.request(method, path, body=body)
    response = conn.getresponse()
    # Allow all 2XX responses. E.g., nginx returns 201 on PUT. HEAD may return 204.
    if int(response.status / 100) == 2:
      return response
    elif response.status == 404:
      return None
    else:
      raise self.CacheError('Failed to %s %s. Error: %d %s' % (method, self._url_string(path),
                                                               response.status, response.reason))

  def _url_string(self, path):
    return '%s://%s%s' % (('https' if self._ssl else 'http'), self._netloc, path)
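Two details of the cache above, shown in isolation: READ_SIZE converts a data Amount into a raw byte count for chunked reads, and _remote_path_for_key lays entries out as <prefix>/<id>/<hash>.tar(.gz). The CacheKey namedtuple below is only a stand-in for the real cache key type.

from collections import namedtuple
from twitter.common.quantity import Amount, Data

read_size = int(Amount(4, Data.MB).as_(Data.BYTES))  # byte count used as the read chunk size

CacheKey = namedtuple('CacheKey', ['id', 'hash'])  # illustrative stand-in
key = CacheKey(id='src.python.example', hash='0123abcd')
remote_path = '%s/%s/%s%s' % ('/cache', key.id, key.hash, '.tar.gz')
print(remote_path)  # e.g. /cache/src.python.example/0123abcd.tar.gz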
Example #9
def test_comparison_mixed_units():
  assert Amount(1, Time.MINUTES) > Amount(59, Time.SECONDS)
  assert Amount(1, Time.MINUTES) == Amount(60, Time.SECONDS)
  assert Amount(1, Time.MINUTES) < Amount(61, Time.SECONDS)

  assert Amount(59, Time.SECONDS) < Amount(1, Time.MINUTES)
  assert Amount(60, Time.SECONDS) == Amount(1, Time.MINUTES)
  assert Amount(61, Time.SECONDS) > Amount(1, Time.MINUTES)
Example #10
class FastTaskRunner(TaskRunner):
  COORDINATOR_INTERVAL_SLEEP = Amount(1, Time.MILLISECONDS)
Example #11
class FastThermosGCExecutor(ThermosGCExecutor):
  POLL_WAIT = Amount(1, Time.MILLISECONDS)
Example #12
def make_gc_executor_with_timeouts(maximum_executor_lifetime=Amount(1, Time.DAYS)):
  class TimeoutGCExecutor(ThinTestThermosGCExecutor):
    MAXIMUM_EXECUTOR_LIFETIME = maximum_executor_lifetime
  return TimeoutGCExecutor
Example #13
class Updater(object):
    """Performs an update command using a collection of parallel threads.
  The number of parallel threads used is determined by the UpdateConfig.batch_size."""
    class Error(Exception):
        """Updater error wrapper."""
        pass

    RPC_COMPLETION_TIMEOUT_SECS = Amount(120, Time.SECONDS)

    OPERATION_CONFIGS = namedtuple('OperationConfigs',
                                   ['from_config', 'to_config'])
    INSTANCE_CONFIGS = namedtuple(
        'InstanceConfigs',
        ['remote_config_map', 'local_config_map', 'instances_to_process'])

    INSTANCE_DATA = namedtuple('InstanceData',
                               ['instance_id', 'operation_configs'])

    def __init__(self,
                 config,
                 health_check_interval_seconds,
                 scheduler=None,
                 instance_watcher=None,
                 quota_check=None,
                 job_monitor=None,
                 scheduler_mux=None,
                 rpc_completion_timeout=RPC_COMPLETION_TIMEOUT_SECS):
        self._config = config
        self._job_key = JobKey(role=config.role(),
                               environment=config.environment(),
                               name=config.name())
        self._health_check_interval_seconds = health_check_interval_seconds
        self._scheduler = scheduler or SchedulerProxy(config.cluster())
        self._quota_check = quota_check or QuotaCheck(self._scheduler)
        self._scheduler_mux = scheduler_mux or SchedulerMux()
        self._job_monitor = job_monitor or JobMonitor(
            self._scheduler,
            self._config.job_key(),
            scheduler_mux=self._scheduler_mux)
        self._rpc_completion_timeout = rpc_completion_timeout
        try:
            self._update_config = UpdaterConfig(**config.update_config().get())
        except ValueError as e:
            raise self.Error(str(e))
        if self._update_config.pulse_interval_secs:
            raise self.Error(
                'Pulse interval seconds is not supported by the client updater.'
            )
        self._lock = None
        self._thread_lock = threading_lock()
        self._batch_wait_event = Event()
        self._batch_completion_queue = Queue()
        self.failure_threshold = FailureThreshold(
            self._update_config.max_per_instance_failures,
            self._update_config.max_total_failures)
        self._watcher = instance_watcher or InstanceWatcher(
            self._scheduler,
            self._job_key,
            self._update_config.restart_threshold,
            self._update_config.watch_secs,
            self._health_check_interval_seconds,
            scheduler_mux=self._scheduler_mux)
        self._terminating = False

    def _start(self):
        """Starts an update by applying an exclusive lock on a job being updated.

    Returns Response instance from the scheduler call.
    """
        resp = self._scheduler.acquireLock(LockKey(job=self._job_key))
        if resp.responseCode == ResponseCode.OK:
            self._lock = resp.result.acquireLockResult.lock
        return resp

    def _finish(self):
        """Finishes an update by removing an exclusive lock on an updated job.

    Returns Response instance from the scheduler call.
    """
        resp = self._scheduler.releaseLock(self._lock, LockValidation.CHECKED)

        if resp.responseCode == ResponseCode.OK:
            self._lock = None
        else:
            log.error('There was an error finalizing the update: %s' %
                      combine_messages(resp))
        return resp

    def int_handler(self, *args):
        """Ensures keyboard interrupt exception is raised on a main thread."""
        raise KeyboardInterrupt()

    def _update(self, instance_configs):
        """Drives execution of the update logic.

    Performs instance updates in parallel using a number of threads bound by
    the batch_size config option.

    Arguments:
    instance_configs -- list of instance update configurations to go through.

    Returns the set of instances that failed to update.
    """
        # Register signal handler to ensure KeyboardInterrupt is received by a main thread.
        signal.signal(signal.SIGINT, self.int_handler)

        instances_to_update = [
            self.INSTANCE_DATA(
                instance_id,
                self.OPERATION_CONFIGS(
                    from_config=instance_configs.remote_config_map,
                    to_config=instance_configs.local_config_map))
            for instance_id in instance_configs.instances_to_process
        ]

        log.info('Instances to update: %s' %
                 instance_configs.instances_to_process)
        update_queue = self._update_instances_in_parallel(
            self._update_instance, instances_to_update)

        if self._is_failed_update(quiet=False):
            if not self._update_config.rollback_on_failure:
                log.info(
                    'Rollback on failure is disabled in config. Aborting rollback'
                )
                return

            rollback_ids = self._get_rollback_ids(
                instance_configs.instances_to_process, update_queue)
            instances_to_revert = [
                self.INSTANCE_DATA(
                    instance_id,
                    self.OPERATION_CONFIGS(
                        from_config=instance_configs.local_config_map,
                        to_config=instance_configs.remote_config_map))
                for instance_id in rollback_ids
            ]

            log.info('Reverting update for: %s' % rollback_ids)
            self._update_instances_in_parallel(self._revert_instance,
                                               instances_to_revert)

        return not self._is_failed_update()

    def _update_instances_in_parallel(self, target, instances_to_update):
        """Processes instance updates in parallel and waits for completion.

    Arguments:
    target -- target method to handle instance update.
    instances_to_update -- list of InstanceData with update details.

    Returns Queue with non-updated instance data.
    """
        log.info('Processing in parallel with %s worker thread(s)' %
                 self._update_config.batch_size)
        instance_queue = Queue()
        for instance_to_update in instances_to_update:
            instance_queue.put(instance_to_update)

        try:
            threads = []
            for _ in range(self._update_config.batch_size):
                threads.append(
                    spawn_worker(target,
                                 kwargs={'instance_queue': instance_queue}))

            for thread in threads:
                thread.join_and_raise()
        except Exception as e:
            log.debug('Caught unhandled exception: %s' % e)
            self._terminate()
            raise

        return instance_queue

    def _try_reset_batch_wait_event(self, instance_id, instance_queue):
        """Resets batch_wait_event in case the current batch is filled up.

    This is a helper method that separates thread locked logic. Called from
    _wait_for_batch_completion_if_needed() when a given instance update completes.
    Resumes worker threads if all batch instances are updated.

    Arguments:
    instance_id -- Instance ID being processed.
    instance_queue -- Instance update work queue.
    """
        with self._thread_lock:
            log.debug("Instance ID %s: Completion queue size %s" %
                      (instance_id, self._batch_completion_queue.qsize()))
            log.debug("Instance ID %s: Instance queue size %s" %
                      (instance_id, instance_queue.qsize()))
            self._batch_completion_queue.put(instance_id)
            filled_up = (self._batch_completion_queue.qsize()
                         % self._update_config.batch_size == 0)
            all_done = instance_queue.qsize() == 0
            if filled_up or all_done:
                # Required batch size of completed instances has filled up -> unlock waiting threads.
                log.debug('Instance %s completes the batch wait.' %
                          instance_id)
                self._batch_wait_event.set()
                self._batch_wait_event.clear()
                return True

        return False

    def _wait_for_batch_completion_if_needed(self, instance_id,
                                             instance_queue):
        """Waits for batch completion if wait_for_batch_completion flag is set.

    Arguments:
    instance_id -- Instance ID.
    instance_queue -- Instance update work queue.
    """
        if not self._update_config.wait_for_batch_completion:
            return

        if not self._try_reset_batch_wait_event(instance_id, instance_queue):
            # The current batch has not filled up -> block the work thread.
            log.debug('Instance %s is done. Waiting for batch to complete.' %
                      instance_id)
            self._batch_wait_event.wait()

    def _terminate(self):
        """Attempts to terminate all outstanding activities."""
        if not self._terminating:
            log.info('Cleaning up')
            self._terminating = True
            self._scheduler.terminate()
            self._job_monitor.terminate()
            self._scheduler_mux.terminate()
            self._watcher.terminate()
            self._batch_wait_event.set()

    def _update_instance(self, instance_queue):
        """Works through the instance_queue and performs instance updates (one at a time).

    Arguments:
    instance_queue -- Queue of InstanceData to update.
    """
        while not self._terminating and not self._is_failed_update():
            try:
                instance_data = instance_queue.get_nowait()
            except Empty:
                return

            update = True
            restart = False
            while (update or restart and not self._terminating
                   and not self._is_failed_update()):
                instances_to_watch = []
                if update:
                    instances_to_watch += self._kill_and_add_instance(
                        instance_data)
                    update = False
                else:
                    instances_to_watch += self._request_restart_instance(
                        instance_data)

                if instances_to_watch:
                    failed_instances = self._watcher.watch(instances_to_watch)
                    restart = self._is_restart_needed(failed_instances)

            self._wait_for_batch_completion_if_needed(
                instance_data.instance_id, instance_queue)

    def _revert_instance(self, instance_queue):
        """Works through the instance_queue and performs instance rollbacks (one at a time).

    Arguments:
    instance_queue -- Queue of InstanceData to revert.
    """
        while not self._terminating:
            try:
                instance_data = instance_queue.get_nowait()
            except Empty:
                return

            log.info('Reverting instance: %s' % instance_data.instance_id)
            instances_to_watch = self._kill_and_add_instance(instance_data)
            if instances_to_watch and self._watcher.watch(instances_to_watch):
                log.error('Rollback failed for instance: %s' %
                          instance_data.instance_id)

    def _kill_and_add_instance(self, instance_data):
        """Acquires update instructions and performs required kill/add/kill+add sequence.

    Arguments:
    instance_data -- InstanceData to update.

    Returns added instance ID.
    """
        log.info('Examining instance: %s' % instance_data.instance_id)
        to_kill, to_add = self._create_kill_add_lists(
            [instance_data.instance_id], instance_data.operation_configs)
        if not to_kill and not to_add:
            log.info('Skipping unchanged instance: %s' %
                     instance_data.instance_id)
            return to_add

        if to_kill:
            self._request_kill_instance(instance_data)
        if to_add:
            self._request_add_instance(instance_data)

        return to_add

    def _request_kill_instance(self, instance_data):
        """Instructs the scheduler to kill instance and waits for completion.

    Arguments:
    instance_data -- InstanceData to kill.
    """
        log.info('Killing instance: %s' % instance_data.instance_id)
        self._enqueue_and_wait(instance_data, self._kill_instances)
        result = self._job_monitor.wait_until(JobMonitor.terminal,
                                              [instance_data.instance_id],
                                              with_timeout=True)

        if not result:
            raise self.Error('Instance %s was not killed in time' %
                             instance_data.instance_id)
        log.info('Killed: %s' % instance_data.instance_id)

    def _request_add_instance(self, instance_data):
        """Instructs the scheduler to add instance.

    Arguments:
    instance_data -- InstanceData to add.
    """
        log.info('Adding instance: %s' % instance_data.instance_id)
        self._enqueue_and_wait(instance_data, self._add_instances)
        log.info('Added: %s' % instance_data.instance_id)

    def _request_restart_instance(self, instance_data):
        """Instructs the scheduler to restart instance.

    Arguments:
    instance_data -- InstanceData to restart.

    Returns restarted instance ID.
    """
        log.info('Restarting instance: %s' % instance_data.instance_id)
        self._enqueue_and_wait(instance_data, self._restart_instances)
        log.info('Restarted: %s' % instance_data.instance_id)
        return [instance_data.instance_id]

    def _enqueue_and_wait(self, instance_data, command):
        """Queues up the scheduler call and waits for completion.

    Arguments:
    instance_data -- InstanceData to query scheduler for.
    command -- scheduler command to run.
    """
        try:
            self._scheduler_mux.enqueue_and_wait(
                command, instance_data, timeout=self._rpc_completion_timeout)
        except SchedulerMux.Error as e:
            raise self.Error(
                'Failed to complete instance %s operation. Reason: %s' %
                (instance_data.instance_id, e))

    def _is_failed_update(self, quiet=True):
        """Verifies the update status in a thread-safe manner.

    Arguments:
    quiet -- Whether the logging should be suppressed in case of a failed update. Default True.

    Returns True if update failed, False otherwise.
    """
        with self._thread_lock:
            return self.failure_threshold.is_failed_update(
                log_errors=not quiet)

    def _is_restart_needed(self, failed_instances):
        """Checks if there are any failed instances recoverable via restart.

    Arguments:
    failed_instances -- Failed instance IDs.

    Returns True if restart is allowed, False otherwise (i.e. update failed).
    """
        if not failed_instances:
            return False

        log.info('Failed instances: %s' % failed_instances)

        with self._thread_lock:
            unretryable_instances = self.failure_threshold.update_failure_counts(
                failed_instances)
            if unretryable_instances:
                log.warn('Not restarting failed instances %s, which exceeded '
                         'maximum allowed instance failure limit of %s' %
                         (unretryable_instances,
                          self._update_config.max_per_instance_failures))
            return False if unretryable_instances else True

    def _get_rollback_ids(self, update_list, update_queue):
        """Gets a list of instance ids to rollback.

    Arguments:
    update_list -- original list of instances intended for update.
    update_queue -- untouched instances not processed during update.

    Returns sorted list of instance IDs to rollback.
    """
        untouched_ids = []
        while not update_queue.empty():
            untouched_ids.append(update_queue.get_nowait().instance_id)

        return sorted(list(set(update_list) - set(untouched_ids)),
                      reverse=True)

    def _hashable(self, element):
        if isinstance(element, (list, set)):
            return tuple(sorted(self._hashable(item) for item in element))
        elif isinstance(element, dict):
            return tuple(
                sorted((self._hashable(key), self._hashable(value))
                       for (key, value) in element.items()))
        return element

    def _thrift_to_json(self, config):
        return json.loads(
            serialize(
                config,
                protocol_factory=TJSONProtocol.TSimpleJSONProtocolFactory()))

    def _diff_configs(self, from_config, to_config):
        # Thrift objects do not correctly compare against each other due to the unhashable nature
        # of python sets. That results in occasional diff failures with the following symptoms:
        # - Sets are not equal even though their reprs are identical;
        # - Items are reordered within thrift structs;
        # - Items are reordered within sets;
        # To overcome all the above, thrift objects are converted into JSON dicts to flatten out
        # thrift type hierarchy. Next, JSONs are recursively converted into nested tuples to
        # ensure proper ordering on compare.
        return ''.join(
            unified_diff(
                repr(self._hashable(self._thrift_to_json(from_config))),
                repr(self._hashable(self._thrift_to_json(to_config)))))

    def _create_kill_add_lists(self, instance_ids, operation_configs):
        """Determines a particular action (kill or add) to use for every instance in instance_ids.

    Arguments:
    instance_ids -- current batch of IDs to process.
    operation_configs -- OperationConfigs with update details.

    Returns lists of instances to kill and to add.
    """
        to_kill = []
        to_add = []
        for instance_id in instance_ids:
            from_config = operation_configs.from_config.get(instance_id)
            to_config = operation_configs.to_config.get(instance_id)

            if from_config and to_config:
                diff_output = self._diff_configs(from_config, to_config)
                if diff_output:
                    log.debug(
                        'Task configuration changed for instance [%s]:\n%s' %
                        (instance_id, diff_output))
                    to_kill.append(instance_id)
                    to_add.append(instance_id)
            elif from_config and not to_config:
                to_kill.append(instance_id)
            elif not from_config and to_config:
                to_add.append(instance_id)
            else:
                raise self.Error('Instance %s is outside of supported range' %
                                 instance_id)

        return to_kill, to_add

    def _kill_instances(self, instance_data):
        """Instructs the scheduler to batch-kill instances and waits for completion.

    Arguments:
    instance_data -- list of InstanceData to kill.
    """
        instance_ids = [data.instance_id for data in instance_data]
        log.debug('Batch killing instances: %s' % instance_ids)
        query = self._create_task_query(instanceIds=frozenset(
            int(s) for s in instance_ids))
        self._check_and_log_response(
            self._scheduler.killTasks(query, self._lock))
        log.debug('Done batch killing instances: %s' % instance_ids)

    def _add_instances(self, instance_data):
        """Instructs the scheduler to batch-add instances.

    Arguments:
    instance_data -- list of InstanceData to add.
    """
        instance_ids = [data.instance_id for data in instance_data]
        to_config = instance_data[0].operation_configs.to_config

        log.debug('Batch adding instances: %s' % instance_ids)
        add_config = AddInstancesConfig(
            key=self._job_key,
            # instance_ids will always have at least 1 item.
            taskConfig=to_config[instance_ids[0]],
            instanceIds=frozenset(int(s) for s in instance_ids))
        self._check_and_log_response(
            self._scheduler.addInstances(add_config, self._lock))
        log.debug('Done batch adding instances: %s' % instance_ids)

    def _restart_instances(self, instance_data):
        """Instructs the scheduler to batch-restart instances.

    Arguments:
    instance_data -- list of InstanceData to restart.
    """
        instance_ids = [data.instance_id for data in instance_data]
        log.debug('Batch restarting instances: %s' % instance_ids)
        resp = self._scheduler.restartShards(self._job_key, instance_ids,
                                             self._lock)
        self._check_and_log_response(resp)
        log.debug('Done batch restarting instances: %s' % instance_ids)

    def _validate_quota(self, instance_configs):
        """Validates job update will not exceed quota for production tasks.
    Arguments:
    instance_configs -- InstanceConfig with update details.

    Returns Response.OK if quota check was successful.
    """
        instance_operation = self.OPERATION_CONFIGS(
            from_config=instance_configs.remote_config_map,
            to_config=instance_configs.local_config_map)

        def _aggregate_quota(ops_list, config_map):
            request = CapacityRequest()
            for instance in ops_list:
                task = config_map[instance]
                if task.production:
                    request += CapacityRequest.from_task(task)

            return request

        to_kill, to_add = self._create_kill_add_lists(
            instance_configs.instances_to_process, instance_operation)

        return self._quota_check.validate_quota_from_requested(
            self._job_key,
            self._config.job().taskConfig.production,
            _aggregate_quota(to_kill, instance_operation.from_config),
            _aggregate_quota(to_add, instance_operation.to_config))

    def _get_update_instructions(self, instances=None):
        """Loads, validates and populates update working set.

    Arguments:
    instances -- (optional) set of instances to update.

    Returns:
    InstanceConfigs with the following data:
      remote_config_map -- dictionary of {key:instance_id, value:task_config} from scheduler.
      local_config_map  -- dictionary of {key:instance_id, value:task_config} with local
                           task configs validated and populated with default values.
      instances_to_process -- list of instance IDs to go over in update.
    """
        # Load existing tasks and populate remote config map and instance list.
        assigned_tasks = self._get_existing_tasks()
        remote_config_map = {}
        remote_instances = []
        for assigned_task in assigned_tasks:
            remote_config_map[assigned_task.instanceId] = assigned_task.task
            remote_instances.append(assigned_task.instanceId)

        # Validate local job config and populate local task config.
        local_task_config = self._validate_and_populate_local_config()

        # Union of local and remote instance IDs.
        job_config_instances = list(range(self._config.instances()))
        instance_superset = sorted(
            list(set(remote_instances) | set(job_config_instances)))

        # Calculate the update working set.
        if instances is None:
            # Full job update -> union of remote and local instances
            instances_to_process = instance_superset
        else:
            # Partial job update -> validate all instances are recognized
            instances_to_process = instances
            unrecognized = list(set(instances) - set(instance_superset))
            if unrecognized:
                raise self.Error(
                    'Instances %s are outside of supported range' %
                    unrecognized)

        # Populate local config map
        local_config_map = dict.fromkeys(job_config_instances,
                                         local_task_config)

        return self.INSTANCE_CONFIGS(remote_config_map, local_config_map,
                                     instances_to_process)

    def _get_existing_tasks(self):
        """Loads all existing tasks from the scheduler.

    Returns a list of AssignedTasks.
    """
        resp = self._scheduler.getTasksStatus(self._create_task_query())
        self._check_and_log_response(resp)
        return [t.assignedTask for t in resp.result.scheduleStatusResult.tasks]

    def _validate_and_populate_local_config(self):
        """Validates local job configuration and populates local task config with default values.

    Returns a TaskConfig populated with default values.
    """
        resp = self._scheduler.populateJobConfig(self._config.job())
        self._check_and_log_response(resp)
        return resp.result.populateJobResult.taskConfig

    def _replace_template_if_cron(self):
        """Checks if the provided job config represents a cron job and if so, replaces it.

    Returns True if job is cron and False otherwise.
    """
        if self._config.job().cronSchedule:
            resp = self._scheduler.replaceCronTemplate(self._config.job(),
                                                       self._lock)
            self._check_and_log_response(resp)
            return True
        else:
            return False

    def _create_task_query(self, instanceIds=None):
        return TaskQuery(jobKeys=[self._job_key],
                         statuses=ACTIVE_STATES,
                         instanceIds=instanceIds)

    def _failed_response(self, message):
        # TODO(wfarner): Avoid synthesizing scheduler responses, consider using an exception instead.
        return Response(responseCode=ResponseCode.ERROR,
                        details=[ResponseDetail(message=message)])

    def update(self, instances=None):
        """Performs the job update, blocking until it completes.

    A rollback will be performed if the update was considered a failure based on the
    update configuration.

    Arguments:
    instances -- (optional) instances to update. If not specified, all instances will be updated.

    Returns a response object with update result status.
    """
        try:
            resp = self._start()
            if resp.responseCode != ResponseCode.OK:
                return resp

            try:
                # Handle cron jobs separately from other jobs.
                if self._replace_template_if_cron():
                    log.info(
                        'Cron template updated, next run will reflect changes')
                    return self._finish()
                else:
                    try:
                        instance_configs = self._get_update_instructions(
                            instances)
                        self._check_and_log_response(
                            self._validate_quota(instance_configs))
                    except self.Error as e:
                        # Safe to release the lock acquired above as no job mutation has happened yet.
                        self._finish()
                        return self._failed_response(
                            'Unable to start job update: %s' % e)

                    if not self._update(instance_configs):
                        log.warn('Update failures threshold reached')
                        self._finish()
                        return self._failed_response('Update reverted')
                    else:
                        log.info('Update successful')
                        return self._finish()
            except (self.Error, ExecutionError, Exception) as e:
                return self._failed_response(
                    'Aborting update without rollback! Fatal error: %s' % e)
        finally:
            self._scheduler_mux.terminate()

    @classmethod
    def cancel_update(cls, scheduler, job_key):
        """Cancels an update process by removing an exclusive lock on a provided job.

    Arguments:
    scheduler -- scheduler instance to use.
    job_key -- job key to cancel update for.

    Returns a response object with cancel update result status.
    """
        return scheduler.releaseLock(
            Lock(key=LockKey(job=job_key.to_thrift())),
            LockValidation.UNCHECKED)

    def _check_and_log_response(self, resp):
        """Checks scheduler return status, raises Error in case of unexpected response.

    Arguments:
    resp -- scheduler response object.

    Raises Error in case of unexpected response status.
    """
        message = format_response(resp)
        if resp.responseCode == ResponseCode.OK:
            log.debug(message)
        else:
            raise self.Error(message)
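Every scheduler call in the updater above funnels through the same check-and-raise helper: an OK response is logged at debug level, anything else raises and is mapped to a failed-response object by the caller. A minimal, library-free sketch of that pattern (Response/ResponseCode here are illustrative stand-ins, not the Aurora thrift types, and SchedulerError stands in for self.Error):

# Minimal sketch of the check-and-raise response handling used above.
# Response/ResponseCode are illustrative stand-ins, not the Aurora
# thrift-generated types; SchedulerError stands in for self.Error.
import collections
import logging

ResponseCode = collections.namedtuple('Codes', ['OK', 'ERROR'])(OK=0, ERROR=1)
Response = collections.namedtuple('Response', ['responseCode', 'details'])


class SchedulerError(Exception):
    pass


def check_and_log_response(resp):
    """Log OK responses at debug level; raise SchedulerError for anything else."""
    message = 'scheduler returned %s: %s' % (resp.responseCode, resp.details)
    if resp.responseCode == ResponseCode.OK:
        logging.debug(message)
    else:
        raise SchedulerError(message)


# A non-OK response becomes an exception the caller maps to a failed update,
# mirroring the _failed_response() path in update() above.
check_and_log_response(Response(responseCode=ResponseCode.OK, details=[]))
try:
    check_and_log_response(Response(responseCode=ResponseCode.ERROR, details=['quota exceeded']))
except SchedulerError as e:
    logging.error('Unable to start job update: %s' % e)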
Exemple #14
0
class TaskObserver(ExceptionalThread, Lockable):
  """
    The TaskObserver monitors the thermos checkpoint root for active/finished
    tasks.  It serves as the oracle of the state of all thermos tasks on
    a machine.

    It currently returns JSON, but really should just return objects.  We should
    then build an object->json translator.
  """
  class UnexpectedError(Exception): pass
  class UnexpectedState(Exception): pass

  POLLING_INTERVAL = Amount(1, Time.SECONDS)

  def __init__(self, root, resource_monitor_class=TaskResourceMonitor):
    self._pathspec = TaskPath(root=root)
    self._detector = TaskDetector(root)
    if not issubclass(resource_monitor_class, ResourceMonitorBase):
      raise ValueError("resource monitor class must implement ResourceMonitorBase!")
    self._resource_monitor = resource_monitor_class
    self._active_tasks = {}    # task_id => ActiveObservedTask
    self._finished_tasks = {}  # task_id => FinishedObservedTask
    self._stop_event = threading.Event()
    ExceptionalThread.__init__(self)
    Lockable.__init__(self)
    self.daemon = True

  @property
  def active_tasks(self):
    """Return a dictionary of active Tasks"""
    return self._active_tasks

  @property
  def finished_tasks(self):
    """Return a dictionary of finished Tasks"""
    return self._finished_tasks

  @property
  def all_tasks(self):
    """Return a dictionary of all Tasks known by the TaskObserver"""
    return dict(self.active_tasks.items() + self.finished_tasks.items())

  def stop(self):
    self._stop_event.set()

  def start(self):
    ExceptionalThread.start(self)

  @Lockable.sync
  def add_active_task(self, task_id):
    if task_id in self.finished_tasks:
      log.error('Found an active task (%s) in finished tasks?' % task_id)
      return
    task_monitor = TaskMonitor(self._pathspec, task_id)
    if not task_monitor.get_state().header:
      log.info('Unable to load task "%s"' % task_id)
      return
    sandbox = task_monitor.get_state().header.sandbox
    resource_monitor = self._resource_monitor(task_monitor, sandbox)
    resource_monitor.start()
    self._active_tasks[task_id] = ActiveObservedTask(
      task_id=task_id, pathspec=self._pathspec,
      task_monitor=task_monitor, resource_monitor=resource_monitor
    )

  @Lockable.sync
  def add_finished_task(self, task_id):
    self._finished_tasks[task_id] = FinishedObservedTask(
      task_id=task_id, pathspec=self._pathspec
    )

  @Lockable.sync
  def active_to_finished(self, task_id):
    self.remove_active_task(task_id)
    self.add_finished_task(task_id)

  @Lockable.sync
  def remove_active_task(self, task_id):
    task = self.active_tasks.pop(task_id)
    task.resource_monitor.kill()

  @Lockable.sync
  def remove_finished_task(self, task_id):
    self.finished_tasks.pop(task_id)

  def run(self):
    """
      The internal thread for the observer.  This periodically polls the
      checkpoint root for new tasks, or transitions of tasks from active to
      finished state.
    """
    while not self._stop_event.is_set():
      time.sleep(self.POLLING_INTERVAL.as_(Time.SECONDS))

      active_tasks = [task_id for _, task_id in self._detector.get_task_ids(state='active')]
      finished_tasks = [task_id for _, task_id in self._detector.get_task_ids(state='finished')]

      with self.lock:

        # Ensure all tasks currently detected on the system are observed appropriately
        for active in active_tasks:
          if active not in self.active_tasks:
            log.debug('task_id %s (unknown) -> active' % active)
            self.add_active_task(active)
        for finished in finished_tasks:
          if finished in self.active_tasks:
            log.debug('task_id %s active -> finished' % finished)
            self.active_to_finished(finished)
          elif finished not in self.finished_tasks:
            log.debug('task_id %s (unknown) -> finished' % finished)
            self.add_finished_task(finished)

        # Remove ObservedTasks for tasks no longer detected on the system
        for unknown in set(self.active_tasks) - set(active_tasks + finished_tasks):
          log.debug('task_id %s active -> (unknown)' % unknown)
          self.remove_active_task(unknown)
        for unknown in set(self.finished_tasks) - set(active_tasks + finished_tasks):
          log.debug('task_id %s finished -> (unknown)' % unknown)
          self.remove_finished_task(unknown)

  @Lockable.sync
  def process_from_name(self, task_id, process_id):
    if task_id in self.all_tasks:
      task = self.all_tasks[task_id].task
      if task:
        for process in task.processes():
          if process.name().get() == process_id:
            return process

  @Lockable.sync
  def task_count(self):
    """
      Return the count of tasks that could be read properly from disk.
      This may be <= self.task_id_count()
    """
    return dict(
      active=len(self.active_tasks),
      finished=len(self.finished_tasks),
      all=len(self.all_tasks),
    )

  @Lockable.sync
  def task_id_count(self):
    """
      Return the raw count of active and finished task_ids from the TaskDetector.
    """
    num_active = len(list(self._detector.get_task_ids(state='active')))
    num_finished = len(list(self._detector.get_task_ids(state='finished')))
    return dict(active=num_active, finished=num_finished, all=num_active + num_finished)

  def _get_tasks_of_type(self, type):
    """Convenience function to return all tasks of a given type"""
    tasks = {
      'active': self.active_tasks,
      'finished': self.finished_tasks,
      'all': self.all_tasks,
    }.get(type, None)

    if tasks is None:
      log.error('Unknown task type %s' % type)
      return {}

    return tasks

  @Lockable.sync
  def state(self, task_id):
    """Return a dict containing mapped information about a task's state"""
    real_state = self.raw_state(task_id)
    if real_state is None or real_state.header is None:
      return {}
    else:
      return dict(
        task_id=real_state.header.task_id,
        launch_time=real_state.header.launch_time_ms / 1000.0,
        sandbox=real_state.header.sandbox,
        hostname=real_state.header.hostname,
        user=real_state.header.user
      )

  @Lockable.sync
  def raw_state(self, task_id):
    """
      Return the current runner state (thrift blob: gen.twitter.thermos.ttypes.RunnerState)
      of a given task id
    """
    if task_id not in self.all_tasks:
      return None
    return self.all_tasks[task_id].state

  @Lockable.sync
  def _task_processes(self, task_id):
    """
      Return the processes of a task given its task_id.

      Returns a map from state to processes in that state, where possible
      states are: waiting, running, success, failed.
    """
    if task_id not in self.all_tasks:
      return {}
    state = self.raw_state(task_id)
    if state is None or state.header is None:
      return {}

    waiting, running, success, failed, killed = [], [], [], [], []
    for process, runs in state.processes.items():
      # No runs ==> nothing started.
      if len(runs) == 0:
        waiting.append(process)
      else:
        if runs[-1].state in (None, ProcessState.WAITING, ProcessState.LOST):
          waiting.append(process)
        elif runs[-1].state in (ProcessState.FORKED, ProcessState.RUNNING):
          running.append(process)
        elif runs[-1].state == ProcessState.SUCCESS:
          success.append(process)
        elif runs[-1].state == ProcessState.FAILED:
          failed.append(process)
        elif runs[-1].state == ProcessState.KILLED:
          killed.append(process)
        else:
          # TODO(wickman)  Consider log.error instead of raising.
          raise self.UnexpectedState(
            "Unexpected ProcessHistoryState: %s" % state.processes[process].state)

    return dict(waiting=waiting, running=running, success=success, failed=failed, killed=killed)

  @Lockable.sync
  def main(self, type=None, offset=None, num=None):
    """Return a set of information about tasks, optionally filtered

      Args:
        type = (all|active|finished|None) [default: all]
        offset = offset into the list of task_ids [default: 0]
        num = number of results to return [default: 20]

      Tasks are sorted by interest:
        - active tasks are sorted by start time
        - finished tasks are sorted by completion time

      Returns:
        {
          tasks: [task_id_1, ..., task_id_N],
          type: query type,
          offset: next offset,
          num: next num
        }

    """
    type = type or 'all'
    offset = offset or 0
    num = num or 20

    # Get a list of all ObservedTasks of requested type
    tasks = sorted((task for task in self._get_tasks_of_type(type).values()),
                   key=attrgetter('mtime'), reverse=True)

    # Filter by requested offset + number of results
    end = num
    if offset < 0:
      offset = offset % len(tasks) if len(tasks) > abs(offset) else 0
    end += offset
    tasks = tasks[offset:end]

    def task_row(observed_task):
      """Generate an output row for a Task"""
      task = self._task(observed_task.task_id)
      # tasks include those which could not be found properly and are hence empty {}
      if task:
        return dict(
            task_id=observed_task.task_id,
            name=task['name'],
            role=task['user'],
            launch_timestamp=task['launch_timestamp'],
            state=task['state'],
            state_timestamp=task['state_timestamp'],
            ports=task['ports'],
            **task['resource_consumption'])

    return dict(
      tasks=filter(None, map(task_row, tasks)),
      type=type,
      offset=offset,
      num=num,
      task_count=self.task_count()[type],
    )

  def _sample(self, task_id):
    if task_id not in self.active_tasks:
      log.debug("Task %s not found in active tasks" % task_id)
      sample = ProcessSample.empty().to_dict()
      sample['disk'] = 0
    else:
      resource_sample = self.active_tasks[task_id].resource_monitor.sample()[1]
      sample = resource_sample.process_sample.to_dict()
      sample['disk'] = resource_sample.disk_usage
      log.debug("Got sample for task %s: %s" % (task_id, sample))
    return sample

  @Lockable.sync
  def task_statuses(self, task_id):
    """
      Return the sequence of task states.

      [(task_state [string], timestamp), ...]
    """

    # Unknown task_id.
    if task_id not in self.all_tasks:
      return []

    task = self.all_tasks[task_id]
    if task is None:
      return []

    state = self.raw_state(task_id)
    if state is None or state.header is None:
      return []

    # Get the timestamp of the transition into the current state.
    return [
      (TaskState._VALUES_TO_NAMES.get(st.state, 'UNKNOWN'), st.timestamp_ms / 1000)
      for st in state.statuses]

  @Lockable.sync
  def _task(self, task_id):
    """
      Return composite information about a particular task task_id, given the below
      schema.

      {
         task_id: string,
         name: string,
         user: string,
         launch_timestamp: seconds,
         state: string [ACTIVE, SUCCESS, FAILED]
         ports: { name1: 'url', name2: 'url2' }
         resource_consumption: { cpu:, ram:, disk: }
         processes: { -> names only
            waiting: [],
            running: [],
            success: [],
            failed:  []
         }
      }
    """
    # Unknown task_id.
    if task_id not in self.all_tasks:
      return {}

    task = self.all_tasks[task_id].task
    if task is None:
      # TODO(wickman)  Can this happen?
      log.error('Could not find task: %s' % task_id)
      return {}

    state = self.raw_state(task_id)
    if state is None or state.header is None:
      # TODO(wickman)  Can this happen?
      return {}

    # Get the timestamp of the transition into the current state.
    current_state = state.statuses[-1].state
    last_state = state.statuses[0]
    state_timestamp = 0
    for status in state.statuses:
      if status.state == current_state and last_state != current_state:
        state_timestamp = status.timestamp_ms / 1000
      last_state = status.state

    return dict(
       task_id=task_id,
       name=task.name().get(),
       launch_timestamp=state.statuses[0].timestamp_ms / 1000,
       state=TaskState._VALUES_TO_NAMES[state.statuses[-1].state],
       state_timestamp=state_timestamp,
       user=state.header.user,
       resource_consumption=self._sample(task_id),
       ports=state.header.ports,
       processes=self._task_processes(task_id),
       task_struct=task,
    )

  @Lockable.sync
  def _get_process_resource_consumption(self, task_id, process_name):
    if task_id not in self.active_tasks:
      log.debug("Task %s not found in active tasks" % task_id)
      return ProcessSample.empty().to_dict()
    sample = self.active_tasks[task_id].resource_monitor.sample_by_process(process_name).to_dict()
    log.debug('Resource consumption (%s, %s) => %s' % (task_id, process_name, sample))
    return sample

  @Lockable.sync
  def _get_process_tuple(self, history, run):
    """
      Return the basic description of a process run if it exists, otherwise
      an empty dictionary.

      {
        process_name: string
        process_run: int
        state: string [WAITING, FORKED, RUNNING, SUCCESS, KILLED, FAILED, LOST]
        (optional) start_time: seconds from epoch
        (optional) stop_time: seconds from epoch
      }
    """
    if len(history) == 0:
      return {}
    if run >= len(history):
      return {}
    else:
      process_run = history[run]
      run = run % len(history)
      d = dict(
        process_name=process_run.process,
        process_run=run,
        state=ProcessState._VALUES_TO_NAMES[process_run.state],
      )
      if process_run.start_time:
        d.update(start_time=process_run.start_time)
      if process_run.stop_time:
        d.update(stop_time=process_run.stop_time)
      return d

  @Lockable.sync
  def process(self, task_id, process, run=None):
    """
      Returns a process run, where the schema is given below:

      {
        process_name: string
        process_run: int
        used: { cpu: float, ram: int bytes, disk: int bytes }
        start_time: (time since epoch in millis (utc))
        stop_time: (time since epoch in millis (utc))
        state: string [WAITING, FORKED, RUNNING, SUCCESS, KILLED, FAILED, LOST]
      }

      If run is None, return the latest run.
    """
    state = self.raw_state(task_id)
    if state is None or state.header is None:
      return {}
    if process not in state.processes:
      return {}
    history = state.processes[process]
    run = int(run) if run is not None else -1
    tup = self._get_process_tuple(history, run)
    if not tup:
      return {}
    if tup.get('state') == 'RUNNING':
      tup.update(used=self._get_process_resource_consumption(task_id, process))
    return tup

  @Lockable.sync
  def _processes(self, task_id):
    """
      Return
        {
          process1: { ... }
          process2: { ... }
          ...
          processN: { ... }
        }

      where processK is the latest run of processK and in the schema as
      defined by process().
    """

    if task_id not in self.all_tasks:
      return {}
    state = self.raw_state(task_id)
    if state is None or state.header is None:
      return {}

    processes = self._task_processes(task_id)
    d = dict()
    for process_type in processes:
      for process_name in processes[process_type]:
        d[process_name] = self.process(task_id, process_name)
    return d

  @Lockable.sync
  def processes(self, task_ids):
    """
      Given a list of task_ids, returns a map of task_id => processes, where processes
      is defined by the schema in _processes.
    """
    if not isinstance(task_ids, (list, tuple)):
      return {}
    return dict((task_id, self._processes(task_id)) for task_id in task_ids)

  @Lockable.sync
  def get_run_number(self, runner_state, process, run=None):
    if runner_state is not None and runner_state.processes is not None:
      run = run if run is not None else -1
      if run < len(runner_state.processes[process]):
        if len(runner_state.processes[process]) > 0:
          return run % len(runner_state.processes[process])

  @Lockable.sync
  def logs(self, task_id, process, run=None):
    """
      Given a task_id and a process and (optional) run number, return a dict:
      {
        stderr: [dir, filename]
        stdout: [dir, filename]
      }

      If the run number is unspecified, uses the latest run.

      TODO(wickman)  Just return the filenames directly?
    """
    runner_state = self.raw_state(task_id)
    if runner_state is None or runner_state.header is None:
      return {}
    run = self.get_run_number(runner_state, process, run)
    if run is None:
      return {}
    log_path = self._pathspec.given(task_id=task_id, process=process, run=run,
                                    log_dir=runner_state.header.log_dir).getpath('process_logdir')
    return dict(
      stdout=[log_path, 'stdout'],
      stderr=[log_path, 'stderr']
    )

  @staticmethod
  def _sanitize_path(base_path, relpath):
    """
      Attempts to sanitize a path through path normalization, also making sure
      that the relative path is contained inside of base_path.
    """
    if relpath is None:
      relpath = "."
    normalized_base = os.path.realpath(base_path)
    normalized = os.path.realpath(os.path.join(base_path, relpath))
    if normalized.startswith(normalized_base):
      return (normalized_base, os.path.relpath(normalized, normalized_base))
    return (None, None)

  @Lockable.sync
  def valid_file(self, task_id, path):
    """
      Like valid_path, but also verify the given path is a file
    """
    chroot, path = self.valid_path(task_id, path)
    if chroot and path and os.path.isfile(os.path.join(chroot, path)):
      return chroot, path
    return None, None

  @Lockable.sync
  def valid_path(self, task_id, path):
    """
      Given a task_id and a path within that task_id's sandbox, verify:
        (1) it's actually in the sandbox and not outside
        (2) it's a valid, existing path
      Returns chroot and the pathname relative to that chroot.
    """
    runner_state = self.raw_state(task_id)
    if runner_state is None or runner_state.header is None:
      return None, None
    try:
      chroot = runner_state.header.sandbox
    except AttributeError:
      return None, None
    chroot, path = self._sanitize_path(chroot, path)
    if chroot and path:
      return chroot, path
    return None, None

  @Lockable.sync
  def files(self, task_id, path=None):
    """
      Returns dictionary
      {
        task_id: task_id
        chroot: absolute directory on machine
        path: sanitized relative path w.r.t. chroot
        dirs: list of directories
        files: list of files
      }
    """
    # TODO(jon): DEPRECATED: most of the necessary logic is handled directly in the templates.
    # Also, global s/chroot/sandbox/?
    empty = dict(task_id=task_id, chroot=None, path=None, dirs=None, files=None)
    path = path if path is not None else '.'
    runner_state = self.raw_state(task_id)
    if runner_state is None:
      return empty
    try:
      chroot = runner_state.header.sandbox
    except AttributeError:
      return empty
    if chroot is None:  # chroot-less job
      return empty
    chroot, path = self._sanitize_path(chroot, path)
    if (chroot is None or path is None
        or not os.path.isdir(os.path.join(chroot, path))):
      return empty
    names = os.listdir(os.path.join(chroot, path))
    dirs, files = [], []
    for name in names:
      if os.path.isdir(os.path.join(chroot, path, name)):
        dirs.append(name)
      else:
        files.append(name)
    return dict(
      task_id=task_id,
      chroot=chroot,
      path=path,
      dirs=dirs,
      files=files
    )
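TaskObserver.run() above is essentially a reconciliation loop: on each poll it diffs the task ids the TaskDetector reports on disk against what the observer already tracks, promoting unknown ids to active or finished and dropping ids that vanished. A library-free sketch of that diffing step, with plain dicts standing in for the observer's internal maps (all names here are illustrative, not the Thermos API):

# Library-free sketch of the reconciliation step in TaskObserver.run().
# observed_active/observed_finished stand in for the observer's internal maps;
# on_disk_active/on_disk_finished stand in for TaskDetector output.
def reconcile(observed_active, observed_finished, on_disk_active, on_disk_finished):
    # Promote newly detected tasks.
    for task_id in on_disk_active:
        if task_id not in observed_active:
            observed_active[task_id] = 'ActiveObservedTask'  # placeholder value
    for task_id in on_disk_finished:
        if task_id in observed_active:
            observed_finished[task_id] = observed_active.pop(task_id)  # active -> finished
        elif task_id not in observed_finished:
            observed_finished[task_id] = 'FinishedObservedTask'  # placeholder value
    # Drop tasks no longer detected on disk at all.
    known_on_disk = set(on_disk_active) | set(on_disk_finished)
    for task_id in set(observed_active) - known_on_disk:
        del observed_active[task_id]
    for task_id in set(observed_finished) - known_on_disk:
        del observed_finished[task_id]


active, finished = {'t1': 'ActiveObservedTask'}, {}
reconcile(active, finished, on_disk_active=[], on_disk_finished=['t1', 't2'])
assert 't1' in finished and 't2' in finished and not active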
Exemple #15
0
class ProcessBase(object):
    """
    Encapsulate a running process for a task.
  """
    class Error(Exception):
        pass

    class UnknownUserError(Error):
        pass

    class CheckpointError(Error):
        pass

    class UnspecifiedSandbox(Error):
        pass

    class PermissionError(Error):
        pass

    CONTROL_WAIT_CHECK_INTERVAL = Amount(100, Time.MILLISECONDS)
    MAXIMUM_CONTROL_WAIT = Amount(1, Time.MINUTES)

    def __init__(self,
                 name,
                 cmdline,
                 sequence,
                 pathspec,
                 sandbox_dir,
                 user=None,
                 platform=None,
                 logger_mode=LoggerMode.STANDARD,
                 rotate_log_size=None,
                 rotate_log_backups=None):
        """
      required:
        name        = name of the process
        cmdline     = cmdline of the process
        sequence    = the next available sequence number for state updates
        pathspec    = TaskPath object for synthesizing path names
        sandbox_dir = the sandbox in which to run the process
        platform    = Platform providing fork, clock, getpid

      optional:
        user               = the user to run as (if unspecified, will default to current user.)
                             if specified to a user that is not the current user, you must have root
                             access
        logger_mode        = The type of logger to use for the process.
        rotate_log_size    = The maximum size of the rotated stdout/stderr logs.
        rotate_log_backups = The maximum number of rotated stdout/stderr log backups.
    """
        self._name = name
        self._cmdline = cmdline
        self._pathspec = pathspec
        self._seq = sequence
        self._sandbox = sandbox_dir
        if self._sandbox:
            safe_mkdir(self._sandbox)
        self._pid = None
        self._fork_time = None
        self._user = user
        self._ckpt = None
        self._ckpt_head = -1
        if platform is None:
            raise ValueError("Platform must be specified")
        self._platform = platform
        self._logger_mode = logger_mode
        self._rotate_log_size = rotate_log_size
        self._rotate_log_backups = rotate_log_backups

        if not LoggerMode.is_valid(self._logger_mode):
            raise ValueError("Logger mode %s is invalid." % self._logger_mode)

        if self._logger_mode == LoggerMode.ROTATE:
            if self._rotate_log_size.as_(Data.BYTES) <= 0:
                raise ValueError('Log size cannot be less than one byte.')
            if self._rotate_log_backups <= 0:
                raise ValueError('Log backups cannot be less than one.')

    def _log(self, msg):
        log.debug('[process:%5s=%s]: %s' % (self._pid, self.name(), msg))

    def _getpwuid(self):
        """Returns a tuple of the user (i.e. --user) and current user."""
        uid = os.getuid()
        try:
            current_user = pwd.getpwuid(uid)
        except KeyError:
            raise self.UnknownUserError('Unknown uid %s!' % uid)
        try:
            user = pwd.getpwnam(
                self._user) if self._user is not None else current_user
        except KeyError:
            raise self.UnknownUserError('Unable to get pwent information!')
        return user, current_user

    def _ckpt_write(self, msg):
        self._init_ckpt_if_necessary()
        self._log("child state transition [%s] <= %s" %
                  (self.ckpt_file(), msg))
        self._ckpt.write(msg)

    def _write_process_update(self, **kw):
        """Write a process update to the coordinator's checkpoint stream."""
        process_status = ProcessStatus(**kw)
        process_status.seq = self._seq
        process_status.process = self.name()
        self._ckpt_write(RunnerCkpt(process_status=process_status))
        self._seq += 1

    def _write_initial_update(self):
        self._write_process_update(state=ProcessState.FORKED,
                                   fork_time=self._fork_time,
                                   coordinator_pid=self._pid)

    def cmdline(self):
        return self._cmdline

    def name(self):
        return self._name

    def pid(self):
        """pid of the coordinator"""
        return self._pid

    def rebind(self, pid, fork_time):
        """rebind Process to an existing coordinator pid without forking"""
        self._pid = pid
        self._fork_time = fork_time

    def ckpt_file(self):
        return self._pathspec.getpath('process_checkpoint')

    def process_logdir(self):
        return self._pathspec.getpath('process_logdir')

    def _setup_ckpt(self):
        """Set up the checkpoint: must be run on the parent."""
        self._log('initializing checkpoint file: %s' % self.ckpt_file())
        ckpt_fp = lock_file(self.ckpt_file(), "a+")
        if ckpt_fp in (None, False):
            raise self.CheckpointError(
                'Could not acquire checkpoint permission or lock for %s!' %
                self.ckpt_file())
        self._ckpt_head = os.path.getsize(self.ckpt_file())
        ckpt_fp.seek(self._ckpt_head)
        self._ckpt = ThriftRecordWriter(ckpt_fp)
        self._ckpt.set_sync(True)

    def _init_ckpt_if_necessary(self):
        if self._ckpt is None:
            self._setup_ckpt()

    def _wait_for_control(self):
        """Wait for control of the checkpoint stream: must be run in the child."""
        total_wait_time = Amount(0, Time.SECONDS)

        with open(self.ckpt_file(), 'r') as fp:
            fp.seek(self._ckpt_head)
            rr = ThriftRecordReader(fp, RunnerCkpt)
            while total_wait_time < self.MAXIMUM_CONTROL_WAIT:
                ckpt_tail = os.path.getsize(self.ckpt_file())
                if ckpt_tail == self._ckpt_head:
                    self._platform.clock().sleep(
                        self.CONTROL_WAIT_CHECK_INTERVAL.as_(Time.SECONDS))
                    total_wait_time += self.CONTROL_WAIT_CHECK_INTERVAL
                    continue
                checkpoint = rr.try_read()
                if checkpoint:
                    if not checkpoint.process_status:
                        raise self.CheckpointError(
                            'No process status in checkpoint!')
                    if (checkpoint.process_status.process != self.name()
                            or checkpoint.process_status.state !=
                            ProcessState.FORKED
                            or checkpoint.process_status.fork_time !=
                            self._fork_time
                            or checkpoint.process_status.coordinator_pid !=
                            self._pid):
                        self._log('Losing control of the checkpoint stream:')
                        self._log('   fork_time [%s] vs self._fork_time [%s]' %
                                  (checkpoint.process_status.fork_time,
                                   self._fork_time))
                        self._log('   coordinator_pid [%s] vs self._pid [%s]' %
                                  (checkpoint.process_status.coordinator_pid,
                                   self._pid))
                        raise self.CheckpointError(
                            'Lost control of the checkpoint stream!')
                    self._log(
                        'Taking control of the checkpoint stream at record: %s'
                        % checkpoint.process_status)
                    self._seq = checkpoint.process_status.seq + 1
                    return True
        raise self.CheckpointError('Timed out waiting for checkpoint stream!')

    def _prepare_fork(self):
        user, current_user = self._getpwuid()
        if self._user:
            if user != current_user and os.geteuid() != 0:
                raise self.PermissionError(
                    'Must be root to run processes as other users!')
        self._fork_time = self._platform.clock().time()
        self._setup_ckpt()
        # Since the forked process is responsible for creating log files, it needs to own the log dir.
        safe_mkdir(self.process_logdir())
        os.chown(self.process_logdir(), user.pw_uid, user.pw_gid)

    def _finalize_fork(self):
        self._write_initial_update()
        self._ckpt.close()
        self._ckpt = None

    def start(self):
        """
      This is the main call point from the runner, and forks a co-ordinator process to run the
      target process (i.e. self.cmdline())

      The parent returns immediately and populates information about the pid of the co-ordinator.
      The child (co-ordinator) will launch the target process in a subprocess.
    """
        self._prepare_fork()  # calls _setup_ckpt which can raise CheckpointError
        # calls _getpwuid which can raise:
        #    UnknownUserError
        #    PermissionError
        self._pid = self._platform.fork()
        if self._pid == 0:
            self._pid = self._platform.getpid()
            self._wait_for_control()  # can raise CheckpointError
            try:
                self.execute()
            finally:
                self._ckpt.close()
                self.finish()
        else:
            self._finalize_fork()  # can raise CheckpointError

    def execute(self):
        raise NotImplementedError

    def finish(self):
        pass
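_wait_for_control() above is a bounded polling loop: the forked child tails the checkpoint file until the parent's FORKED record appears, validates it, and only then takes over the stream. A minimal, library-free sketch of that bounded poll-then-return shape (plain float seconds instead of Amount; wait_for and ControlTimeout are illustrative names, not part of the Thermos code):

# Library-free sketch of the bounded polling used by _wait_for_control():
# poll a predicate at a fixed interval and give up after a maximum wait.
import time


class ControlTimeout(Exception):
    pass


def wait_for(predicate, max_wait_secs=60.0, interval_secs=0.1, clock=time):
    """Return the first truthy value of predicate(), or raise ControlTimeout."""
    waited = 0.0
    while waited < max_wait_secs:
        result = predicate()
        if result:
            return result
        clock.sleep(interval_secs)
        waited += interval_secs
    raise ControlTimeout('Timed out waiting for checkpoint stream!')


# Usage: wait (briefly) for a record that only shows up on the third poll.
polls = iter([None, None, {'state': 'FORKED', 'seq': 0}])
record = wait_for(lambda: next(polls), max_wait_secs=1.0, interval_secs=0.01)
assert record['seq'] == 0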
def map_to_amount(amtlist):
  return [Amount(x, Time.MILLISECONDS) for x in amtlist]
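map_to_amount is just a list-comprehension wrapper over the Amount constructor. A short usage sketch, assuming the same Amount/Time imports from twitter.common.quantity used throughout these snippets (variable names are illustrative):

# Usage sketch, assuming `from twitter.common.quantity import Amount, Time`
# as in the other snippets; variable names are illustrative.
delays = map_to_amount([100, 250, 500])              # Amounts in milliseconds
delay_secs = [d.as_(Time.SECONDS) for d in delays]   # convert the way the snippets above do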
class MesosMaintenance(object):
    """This class provides more methods to interact with the mesos cluster and perform
  maintenance.
  """

    DEFAULT_GROUPING = 'by_host'
    GROUPING_FUNCTIONS = {
        'by_host': group_by_host,
    }
    START_MAINTENANCE_DELAY = Amount(30, Time.SECONDS)

    @classmethod
    def group_hosts(cls, hostnames, grouping_function=DEFAULT_GROUPING):
        try:
            grouping_function = cls.GROUPING_FUNCTIONS[grouping_function]
        except KeyError:
            raise ValueError('Unknown grouping function %s!' %
                             grouping_function)
        groups = defaultdict(set)
        for hostname in hostnames:
            groups[grouping_function(hostname)].add(hostname)
        return groups

    @classmethod
    def iter_batches(cls,
                     hostnames,
                     batch_size,
                     grouping_function=DEFAULT_GROUPING):
        if batch_size <= 0:
            raise ValueError('Batch size must be > 0!')
        groups = cls.group_hosts(hostnames, grouping_function)
        groups = sorted(groups.items(), key=lambda v: v[0])
        for k in range(0, len(groups), batch_size):
            yield Hosts(
                set.union(*(hostset
                            for (key, hostset) in groups[k:k + batch_size])))

    def __init__(self, cluster, verbosity):
        self._client = AuroraClientAPI(cluster, verbosity == 'verbose')

    def _drain_hosts(self, drainable_hosts, clock=time):
        """This will actively turn down tasks running on hosts."""
        check_and_log_response(self._client.drain_hosts(drainable_hosts))
        not_ready_hosts = [hostname for hostname in drainable_hosts.hostNames]
        while not_ready_hosts:
            log.info("Sleeping for %s." % self.START_MAINTENANCE_DELAY)
            clock.sleep(self.START_MAINTENANCE_DELAY.as_(Time.SECONDS))
            resp = self._client.maintenance_status(Hosts(not_ready_hosts))
            #TODO(jsmith): Workaround until scheduler responds with unknown slaves in MESOS-3454
            if not resp.result.maintenanceStatusResult.statuses:
                not_ready_hosts = None
            for host_status in resp.result.maintenanceStatusResult.statuses:
                if host_status.mode != MaintenanceMode.DRAINED:
                    log.warning(
                        '%s is currently in status %s' %
                        (host_status.host,
                         MaintenanceMode._VALUES_TO_NAMES[host_status.mode]))
                else:
                    not_ready_hosts.remove(host_status.host)

    def _complete_maintenance(self, drained_hosts):
        """End the maintenance status for a give set of hosts."""
        check_and_log_response(self._client.end_maintenance(drained_hosts))
        resp = self._client.maintenance_status(drained_hosts)
        for host_status in resp.result.maintenanceStatusResult.statuses:
            if host_status.mode != MaintenanceMode.NONE:
                log.warning('%s is DRAINING or in DRAINED' % host_status.host)

    def _operate_on_hosts(self, drained_hosts, callback):
        """Perform a given operation on a list of hosts that are ready for maintenance."""
        for host in drained_hosts.hostNames:
            callback(host)

    def end_maintenance(self, hosts):
        """Pull a list of hosts out of maintenance mode."""
        self._complete_maintenance(Hosts(set(hosts)))

    def start_maintenance(self, hosts):
        """Put a list of hosts into maintenance mode, to de-prioritize scheduling."""
        check_and_log_response(
            self._client.start_maintenance(Hosts(set(hosts))))

    def perform_maintenance(self,
                            hosts,
                            batch_size=1,
                            grouping_function=DEFAULT_GROUPING,
                            callback=None):
        """The wrap a callback in between sending hosts into maintenance mode and back.

    Walk through the process of putting hosts into maintenance, draining them of tasks,
    performing an action on them once drained, then removing them from maintenance mode
    so tasks can schedule.
    """
        self._complete_maintenance(Hosts(set(hosts)))
        self.start_maintenance(hosts)

        for hosts in self.iter_batches(hosts, batch_size, grouping_function):
            self._drain_hosts(hosts)
            if callback:
                self._operate_on_hosts(hosts, callback)
            self._complete_maintenance(hosts)

    def check_status(self, hosts):
        resp = self._client.maintenance_status(Hosts(set(hosts)))
        check_and_log_response(resp)
        statuses = []
        for host_status in resp.result.maintenanceStatusResult.statuses:
            statuses.append(
                (host_status.host,
                 MaintenanceMode._VALUES_TO_NAMES[host_status.mode]))
        return statuses
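group_hosts/iter_batches above implement the batching that perform_maintenance walks through: hostnames are grouped by a key function, groups are sorted by key, and consecutive groups are unioned into batches of batch_size. A library-free sketch of that batching, with plain sets standing in for the Hosts thrift wrapper (iter_host_batches and group_key are illustrative names):

# Library-free sketch of the grouping + batching used by iter_batches().
# Plain sets stand in for the Hosts thrift wrapper; group_key is illustrative.
from collections import defaultdict


def iter_host_batches(hostnames, batch_size, group_key=lambda host: host):
    if batch_size <= 0:
        raise ValueError('Batch size must be > 0!')
    groups = defaultdict(set)
    for hostname in hostnames:
        groups[group_key(hostname)].add(hostname)
    ordered = sorted(groups.items(), key=lambda kv: kv[0])
    for i in range(0, len(ordered), batch_size):
        yield set.union(*(hosts for _, hosts in ordered[i:i + batch_size]))


# With the default by-host grouping, batch_size=2 drains two hosts at a time.
batches = list(iter_host_batches(['a', 'b', 'c'], batch_size=2))
assert batches == [{'a', 'b'}, {'c'}]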
def test_reduction():
  minute = Amount(60, Time.SECONDS)
  assert minute._amount == 1 and minute._unit == Time.MINUTES
Exemple #19
0
class FastStatusManager(StatusManager):
    POLL_WAIT = Amount(10, Time.MILLISECONDS)
def test_mul():
  assert 5 * Amount(12, Time.SECONDS) == Amount(12, Time.SECONDS) * 5
  amount = 5 * Amount(12, Time.SECONDS)
  assert amount._amount == 1 and amount._unit == Time.MINUTES
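Both tests above exercise the unit reduction in twitter.common.quantity's Amount: values appear to be normalized to the largest unit that divides them evenly, both at construction and after scalar multiplication. The following extra test builds only on the behavior those two tests already assert:

def test_reduce_and_mul():
  # 60 seconds reduces to 1 minute on construction (per test_reduction), and
  # scalar multiplication re-reduces (per test_mul), so this should hold too.
  five_minutes = 5 * Amount(60, Time.SECONDS)
  assert five_minutes._amount == 5 and five_minutes._unit == Time.MINUTES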
Exemple #21
0
class AuroraExecutor(ExecutorBase, Observable):
    PERSISTENCE_WAIT = Amount(5, Time.SECONDS)
    SANDBOX_INITIALIZATION_TIMEOUT = Amount(10, Time.MINUTES)
    START_TIMEOUT = Amount(2, Time.MINUTES)
    STOP_TIMEOUT = Amount(2, Time.MINUTES)
    STOP_WAIT = Amount(5, Time.SECONDS)

    def __init__(self,
                 runner_provider,
                 status_manager_class=StatusManager,
                 sandbox_provider=DefaultSandboxProvider(),
                 status_providers=(),
                 clock=time,
                 no_sandbox_create_user=False,
                 sandbox_mount_point=None):

        ExecutorBase.__init__(self)
        if not isinstance(runner_provider, TaskRunnerProvider):
            raise TypeError(
                'runner_provider must be a TaskRunnerProvider, got %s' %
                type(runner_provider))
        self._runner = None
        self._runner_provider = runner_provider
        self._clock = clock
        self._task_id = None
        self._status_providers = status_providers
        self._status_manager = None
        self._status_manager_class = status_manager_class
        self._sandbox = None
        self._sandbox_provider = sandbox_provider
        self._no_sandbox_create_user = no_sandbox_create_user
        self._sandbox_mount_point = sandbox_mount_point
        self._kill_manager = KillManager()
        # Events that are exposed for interested entities
        self.runner_aborted = threading.Event()
        self.runner_started = threading.Event()
        self.sandbox_initialized = threading.Event()
        self.sandbox_created = threading.Event()
        self.status_manager_started = threading.Event()
        self.terminated = threading.Event()
        self.launched = threading.Event()

    @property
    def runner(self):
        return self._runner

    def _die(self, driver, status, msg):
        log.fatal(msg)
        self.send_update(driver, self._task_id, status, msg)
        defer(driver.stop, delay=self.STOP_WAIT)

    def _run(self, driver, assigned_task, mounted_volume_paths):
        """
      Commence running a Task.
        - Initialize the sandbox
        - Start the ThermosTaskRunner (fork the Thermos TaskRunner)
        - Set up necessary HealthCheckers
        - Set up StatusManager, and attach HealthCheckers
    """
        self.send_update(driver, self._task_id, mesos_pb2.TASK_STARTING,
                         'Initializing sandbox.')

        if not self._initialize_sandbox(driver, assigned_task,
                                        mounted_volume_paths):
            return

        # start the process on a separate thread and give the message processing thread back
        # to the driver
        try:
            self._runner = self._runner_provider.from_assigned_task(
                assigned_task, self._sandbox)
        except TaskError as e:
            self.runner_aborted.set()
            self._die(driver, mesos_pb2.TASK_FAILED, str(e))
            return

        if not isinstance(self._runner, TaskRunner):
            self._die(driver, mesos_pb2.TASK_FAILED, 'Unrecognized task!')
            return

        if not self._start_runner(driver, assigned_task):
            return

        try:
            self._start_status_manager(driver, assigned_task)
        except Exception:
            log.error(traceback.format_exc())
            self._die(driver, mesos_pb2.TASK_FAILED, "Internal error")

    def _initialize_sandbox(self, driver, assigned_task, mounted_volume_paths):
        self._sandbox = self._sandbox_provider.from_assigned_task(
            assigned_task,
            no_create_user=self._no_sandbox_create_user,
            mounted_volume_paths=mounted_volume_paths,
            sandbox_mount_point=self._sandbox_mount_point)
        self.sandbox_initialized.set()
        try:
            propagate_deadline(self._sandbox.create,
                               timeout=self.SANDBOX_INITIALIZATION_TIMEOUT)
        except Timeout:
            self._die(driver, mesos_pb2.TASK_FAILED,
                      'Timed out waiting for sandbox to initialize!')
            return
        except self._sandbox.Error as e:
            self._die(driver, mesos_pb2.TASK_FAILED,
                      'Failed to initialize sandbox: %s' % e)
            return
        except Exception as e:
            self._die(driver, mesos_pb2.TASK_FAILED,
                      'Unknown exception initializing sandbox: %s' % e)
            return
        self.sandbox_created.set()
        return True

    def _start_runner(self, driver, assigned_task):
        if self.runner_aborted.is_set():
            self._die(driver, mesos_pb2.TASK_KILLED,
                      'Task killed during initialization.')

        try:
            propagate_deadline(self._runner.start, timeout=self.START_TIMEOUT)
        except TaskError as e:
            self._die(driver, mesos_pb2.TASK_FAILED,
                      'Task initialization failed: %s' % e)
            return False
        except Timeout:
            self._die(driver, mesos_pb2.TASK_LOST,
                      'Timed out waiting for task to start!')
            return False

        self.runner_started.set()
        log.debug('Task started.')

        return True

    def _start_status_manager(self, driver, assigned_task):
        status_checkers = [self._kill_manager]
        self.metrics.register_observable(self._kill_manager.name(),
                                         self._kill_manager)

        for status_provider in self._status_providers:
            status_checker = status_provider.from_assigned_task(
                assigned_task, self._sandbox)
            if status_checker is None:
                continue
            status_checkers.append(status_checker)
            self.metrics.register_observable(status_checker.name(),
                                             status_checker)

        self._chained_checker = ChainedStatusChecker(status_checkers)
        self._chained_checker.start()

        # chain the runner to the other checkers, but do not chain .start()/.stop()
        complete_checker = ChainedStatusChecker(
            [self._runner, self._chained_checker])
        self._status_manager = self._status_manager_class(complete_checker,
                                                          self._signal_running,
                                                          self._shutdown,
                                                          clock=self._clock)
        self._status_manager.start()
        self.status_manager_started.set()

    def _signal_running(self, status_result):
        log.info('Send TASK_RUNNING status update. status: %s' % status_result)
        self.send_update(self._driver, self._task_id, mesos_pb2.TASK_RUNNING,
                         status_result.reason)

    def _signal_kill_manager(self, driver, task_id, reason):
        if self._task_id is None:
            log.error('Was asked to kill task but no task running!')
            return
        if task_id != self._task_id:
            log.error('Asked to kill a task other than what we are running!')
            return
        if not self.sandbox_created.is_set():
            log.error(
                'Asked to kill task with incomplete sandbox - aborting runner start'
            )
            self.runner_aborted.set()
            return
        self.log('Activating kill manager.')
        self._kill_manager.kill(reason)

    def _shutdown(self, status_result):
        runner_status = self._runner.status

        try:
            propagate_deadline(self._chained_checker.stop,
                               timeout=self.STOP_TIMEOUT)
        except Timeout:
            log.error('Failed to stop all checkers within deadline.')
        except Exception:
            log.error('Failed to stop health checkers:')
            log.error(traceback.format_exc())

        try:
            propagate_deadline(self._runner.stop, timeout=self.STOP_TIMEOUT)
        except Timeout:
            log.error('Failed to stop runner within deadline.')
        except Exception:
            log.error('Failed to stop runner:')
            log.error(traceback.format_exc())

        # If the runner was alive when _shutdown was called, defer to the status_result,
        # otherwise the runner's terminal state is the preferred state.
        exit_status = runner_status or status_result

        self.send_update(self._driver, self._task_id, exit_status.status,
                         status_result.reason)

        self.terminated.set()
        defer(self._driver.stop, delay=self.PERSISTENCE_WAIT)

    @classmethod
    def validate_task(cls, task):
        try:
            assigned_task = assigned_task_from_mesos_task(task)
            return assigned_task
        except Exception:
            log.fatal('Could not deserialize AssignedTask')
            log.fatal(traceback.format_exc())
            return None

    @classmethod
    def extract_mount_paths_from_task(cls, task):
        if task.executor and task.executor.container:
            return [v.container_path for v in task.executor.container.volumes]

        return None

    """ Mesos Executor API methods follow """

    def launchTask(self, driver, task):
        """
      Invoked when a task has been launched on this executor (initiated via Scheduler::launchTasks).
      Note that this task can be realized with a thread, a process, or some simple computation,
      however, no other callbacks will be invoked on this executor until this callback has returned.
    """
        self.launched.set()
        self.log('TaskInfo: %s' % task)
        self.log('launchTask got task: %s:%s' %
                 (task.name, task.task_id.value))

        # TODO(wickman)  Update the tests to call registered(), then remove this line and issue
        # an assert if self._driver is not populated.
        self._driver = driver

        if self._runner:
            log.error('Already running a task! %s' % self._task_id)
            self.send_update(
                driver, task.task_id.value, mesos_pb2.TASK_LOST,
                "Task already running on this executor: %s" % self._task_id)
            return

        self._slave_id = task.slave_id.value
        self._task_id = task.task_id.value

        assigned_task = self.validate_task(task)
        self.log("Assigned task: %s" % assigned_task)
        if not assigned_task:
            self.send_update(driver, self._task_id, mesos_pb2.TASK_FAILED,
                             'Could not deserialize task.')
            defer(driver.stop, delay=self.STOP_WAIT)
            return

        defer(lambda: self._run(driver, assigned_task,
                                self.extract_mount_paths_from_task(task)))

    def killTask(self, driver, task_id):
        """
     Invoked when a task running within this executor has been killed (via
     SchedulerDriver::killTask). Note that no status update will be sent on behalf of the executor,
     the executor is responsible for creating a new TaskStatus (i.e., with TASK_KILLED) and invoking
     ExecutorDriver::sendStatusUpdate.
    """
        self.log('killTask got task_id: %s' % task_id)
        self._signal_kill_manager(driver, task_id.value,
                                  "Instructed to kill task.")
        self.log('killTask returned.')

    def shutdown(self, driver):
        """
     Invoked when the executor should terminate all of its currently running tasks. Note that after
     Mesos has determined that an executor has terminated any tasks that the executor did not send
     terminal status updates for (e.g., TASK_KILLED, TASK_FINISHED, TASK_FAILED, etc) a TASK_LOST
     status update will be created.

    """
        self.log('shutdown called')
        if self._task_id:
            self.log('shutting down %s' % self._task_id)
            self._signal_kill_manager(driver, self._task_id,
                                      "Told to shut down executor.")
        self.log('shutdown returned')
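launchTask and the shutdown paths above hand long-running work to defer() so the Mesos driver callback can return quickly; driver.stop is deferred rather than called inline. Below is a library-free stand-in that only mirrors the defer(fn, delay=...) call shape used above (the real helper comes from the Aurora/Twitter commons concurrency utilities; defer_call is an illustrative name):

# Library-free stand-in mirroring the defer(fn, delay=...) call shape above:
# run a callable on a daemon timer thread, optionally after a delay in seconds.
import threading


def defer_call(fn, delay_secs=0.0):
    timer = threading.Timer(delay_secs, fn)
    timer.daemon = True
    timer.start()
    return timer


# Usage: trigger a (fake) driver stop shortly after the terminal status update,
# instead of blocking the executor callback that initiated the shutdown.
stopped = threading.Event()
defer_call(stopped.set, delay_secs=0.05)
stopped.wait(timeout=1.0)
assert stopped.is_set()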
class RESTfulArtifactCache(ArtifactCache):
    """An artifact cache that stores the artifacts on a RESTful service."""

    READ_SIZE = int(Amount(4, Data.MB).as_(Data.BYTES))

    def __init__(self, log, artifact_root, url_base, compress=True):
        """
    url_base: The prefix for urls on some RESTful service. We must be able to PUT and GET to any
              path under this base.
    compress: Whether to compress the artifacts before storing them.
    """
        ArtifactCache.__init__(self, log, artifact_root)
        parsed_url = urlparse.urlparse(url_base)
        if parsed_url.scheme == 'http':
            self._ssl = False
        elif parsed_url.scheme == 'https':
            self._ssl = True
        else:
            raise ValueError(
                'RESTfulArtifactCache only supports HTTP and HTTPS')
        self._timeout_secs = 4.0
        self._netloc = parsed_url.netloc
        self._path_prefix = parsed_url.path.rstrip('/')
        self.compress = compress

    def try_insert(self, cache_key, build_artifacts):
        with temporary_file_path() as tarfile:
            # In our tests, gzip is slightly less compressive than bzip2 on .class files,
            # but decompression times are much faster.
            mode = 'w:gz' if self.compress else 'w'
            with open_tar(tarfile, mode, dereference=True) as tarout:
                for artifact in build_artifacts:
                    # Adds dirs recursively.
                    tarout.add(artifact,
                               os.path.relpath(artifact, self.artifact_root))

            with open(tarfile, 'rb') as infile:
                path = self._path_for_key(cache_key)
                if not self._request('PUT', path, body=infile):
                    raise self.CacheError('Failed to PUT to %s. Error: 404' %
                                          self._url_string(path))

    def has(self, cache_key):
        return self._request('HEAD', self._path_for_key(cache_key)) is not None

    def use_cached_files(self, cache_key):
        # This implementation fetches the appropriate tarball and extracts it.
        path = self._path_for_key(cache_key)
        try:
            # Send an HTTP request for the tarball.
            response = self._request('GET', path)
            if response is None:
                return False
            expected_size = int(response.getheader('content-length', -1))
            if expected_size == -1:
                raise self.CacheError(
                    'No content-length header in HTTP response')

            done = False
            self.log.info('Reading %d bytes from artifact cache at %s' %
                          (expected_size, self._url_string(path)))

            with temporary_file() as outfile:
                total_bytes = 0
                # Read the data in a loop.
                while not done:
                    data = response.read(self.READ_SIZE)
                    outfile.write(data)
                    if len(data) < self.READ_SIZE:
                        done = True
                    total_bytes += len(data)
                outfile.close()
                self.log.debug('Read %d bytes' % total_bytes)

                # Check the size.
                if total_bytes != expected_size:
                    raise self.CacheError(
                        'Read only %d bytes from %d expected' %
                        (total_bytes, expected_size))
                # Extract the tarfile.
                with open_tar(outfile.name, 'r') as tarfile:
                    tarfile.extractall(self.artifact_root)
            return True
        except Exception as e:
            self.log.warn('Error while reading from artifact cache: %s' % e)
            return False

    def delete(self, cache_key):
        path = self._path_for_key(cache_key)
        self._request('DELETE', path)

    def _path_for_key(self, cache_key):
        # Note: it's important to use the id as well as the hash, because two different targets
        # may have the same hash if both have no sources, but we may still want to differentiate them.
        return '%s/%s/%s.tar%s' % (self._path_prefix, cache_key.id,
                                   cache_key.hash,
                                   '.gz' if self.compress else '')

    def _connect(self):
        if self._ssl:
            return httplib.HTTPSConnection(self._netloc,
                                           timeout=self._timeout_secs)
        else:
            return httplib.HTTPConnection(self._netloc,
                                          timeout=self._timeout_secs)

    # Returns a response if we get a 200, None if we get a 404 and raises an exception otherwise.
    def _request(self, method, path, body=None):
        self.log.debug('Sending %s request to %s' %
                       (method, self._url_string(path)))
        # TODO(benjy): Keep connection open and reuse?
        conn = self._connect()
        conn.request(method, path, body=body)
        response = conn.getresponse()
        # Allow all 2XX responses. E.g., nginx returns 201 on PUT. HEAD may return 204.
        if int(response.status / 100) == 2:
            return response
        elif response.status == 404:
            return None
        else:
            raise self.CacheError('Failed to %s %s. Error: %d %s' %
                                  (method, self._url_string(path),
                                   response.status, response.reason))

    def _url_string(self, path):
        return '%s://%s%s' % (
            ('https' if self._ssl else 'http'), self._netloc, path)
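use_cached_files() above streams the tarball in READ_SIZE chunks, treating a short read as end-of-stream and cross-checking the byte count against Content-Length before extracting. A library-free sketch of that read loop over an in-memory stream (io.BytesIO stands in for the HTTP response; read_exactly and SizeMismatch are illustrative names):

# Library-free sketch of the chunked read + size check in use_cached_files().
# io.BytesIO stands in for the HTTP response object.
import io

READ_SIZE = 4 * 1024 * 1024  # stand-in for Amount(4, Data.MB).as_(Data.BYTES)


class SizeMismatch(Exception):
    pass


def read_exactly(stream, expected_size, read_size=READ_SIZE):
    """Read until a short read, then verify we received expected_size bytes."""
    chunks, total_bytes, done = [], 0, False
    while not done:
        data = stream.read(read_size)
        chunks.append(data)
        if len(data) < read_size:
            done = True
        total_bytes += len(data)
    if total_bytes != expected_size:
        raise SizeMismatch('Read only %d bytes from %d expected' % (total_bytes, expected_size))
    return b''.join(chunks)


payload = b'x' * (10 * 1024 * 1024)  # a fake 10 MB artifact
assert read_exactly(io.BytesIO(payload), expected_size=len(payload)) == payload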
Exemple #23
0
class TaskRunner(object):
    """
    Run a ThermosTask.

    This class encapsulates the core logic to run and control the state of a Thermos task.
    Typically, it will be instantiated directly to control a new task, but a TaskRunner can also be
    synthesised from an existing task's checkpoint root.
  """
    class Error(Exception):
        pass

    class InternalError(Error):
        pass

    class InvalidTask(Error):
        pass

    class PermissionError(Error):
        pass

    class StateError(Error):
        pass

    # Maximum amount of time we spend waiting for new updates from the checkpoint streams
    # before doing housecleaning (checking for LOST tasks, dead PIDs.)
    MAX_ITERATION_TIME = Amount(10, Time.SECONDS)

    # Minimum amount of time we wait between polls for updates on coordinator checkpoints.
    COORDINATOR_INTERVAL_SLEEP = Amount(1, Time.SECONDS)

    # Amount of time we're willing to wait after forking before we expect the runner to have
    # exec'ed the child process.
    LOST_TIMEOUT = Amount(60, Time.SECONDS)

    # Active task stages
    STAGES = {
        TaskState.ACTIVE: TaskRunnerStage_ACTIVE,
        TaskState.CLEANING: TaskRunnerStage_CLEANING,
        TaskState.FINALIZING: TaskRunnerStage_FINALIZING
    }

    @classmethod
    def get(cls, task_id, checkpoint_root):
        """
      Get a TaskRunner bound to the task_id in checkpoint_root.
    """
        path = TaskPath(root=checkpoint_root, task_id=task_id, state='active')
        task_json = path.getpath('task_path')
        task_checkpoint = path.getpath('runner_checkpoint')
        if not os.path.exists(task_json):
            return None
        task = ThermosConfigLoader.load_json(task_json)
        if task is None:
            return None
        if len(task.tasks()) == 0:
            return None
        try:
            checkpoint = CheckpointDispatcher.from_file(task_checkpoint)
            if checkpoint is None or checkpoint.header is None:
                return None
            return cls(task.tasks()[0].task(),
                       checkpoint_root,
                       checkpoint.header.sandbox,
                       log_dir=checkpoint.header.log_dir,
                       task_id=task_id,
                       portmap=checkpoint.header.ports,
                       hostname=checkpoint.header.hostname)
        except Exception as e:
            log.error(
                'Failed to reconstitute checkpoint in TaskRunner.get: %s',
                e,
                exc_info=True)
            return None

    def __init__(self,
                 task,
                 checkpoint_root,
                 sandbox,
                 log_dir=None,
                 task_id=None,
                 portmap=None,
                 user=None,
                 chroot=False,
                 clock=time,
                 universal_handler=None,
                 planner_class=TaskPlanner,
                 hostname=None,
                 process_logger_destination=None,
                 process_logger_mode=None,
                 rotate_log_size_mb=None,
                 rotate_log_backups=None,
                 preserve_env=False,
                 mesos_containerizer_path=None,
                 container_sandbox=None):
        """
      required:
        task (config.Task) = the task to run
        checkpoint_root (path) = the checkpoint root
        sandbox (path) = the sandbox in which the path will be run
                         [if None, cwd will be assumed, but garbage collection will be
                          disabled for this task.]

      optional:
        log_dir (string)  = directory to house stdout/stderr logs. If not specified, logs will be
                            written into the sandbox directory under .logs/
        task_id (string)  = bind to this task id.  if not specified, will synthesize an id based
                            upon task.name()
        portmap (dict)    = a map (string => integer) from name to port, e.g. { 'http': 80 }
        user (string)     = the user to run the task as.  if not current user, requires setuid
                            privileges.
        chroot (boolean)  = whether or not to chroot into the sandbox prior to exec.
        clock (time interface) = the clock to use throughout
        universal_handler = checkpoint record handler (only used for testing)
        planner_class (TaskPlanner class) = TaskPlanner class to use for constructing the task
                            planning policy.
        process_logger_destination (string) = The destination of logger to use for all processes.
        process_logger_mode (string) = The mode of logger to use for all processes.
        rotate_log_size_mb (integer) = The maximum size of the rotated stdout/stderr logs in MiB.
        rotate_log_backups (integer) = The maximum number of rotated stdout/stderr log backups.
        preserve_env (boolean) = whether or not env variables for the runner should be in the
                                 env for the task being run
        mesos_containerizer_path = the path to the mesos-containerizer executable that will be used
                                   to isolate the task's filesystem (if using a filesystem image).
        container_sandbox = the path within the isolated filesystem where the task's sandbox is
                            mounted.
    """
        if not issubclass(planner_class, TaskPlanner):
            raise TypeError('planner_class must be a TaskPlanner.')
        self._clock = clock
        launch_time = self._clock.time()
        launch_time_ms = '%06d' % int(
            (launch_time - int(launch_time)) * (10**6))
        if not task_id:
            self._task_id = '%s-%s.%s' % (
                task.name(),
                time.strftime('%Y%m%d-%H%M%S',
                              time.localtime(launch_time)), launch_time_ms)
        else:
            self._task_id = task_id
        current_user = TaskRunnerHelper.get_actual_user()
        self._user = user or current_user
        # TODO(wickman) This should be delegated to the ProcessPlatform / Helper
        if self._user != current_user:
            if os.geteuid() != 0:
                raise ValueError(
                    'task specifies user as %s, but %s does not have setuid permission!'
                    % (self._user, current_user))
        self._portmap = portmap or {}
        self._launch_time = launch_time
        self._log_dir = log_dir or os.path.join(sandbox, '.logs')
        self._process_logger_destination = process_logger_destination
        self._process_logger_mode = process_logger_mode
        self._rotate_log_size_mb = rotate_log_size_mb
        self._rotate_log_backups = rotate_log_backups
        self._pathspec = TaskPath(root=checkpoint_root,
                                  task_id=self._task_id,
                                  log_dir=self._log_dir)
        self._hostname = hostname or socket.gethostname()
        try:
            ThermosTaskValidator.assert_valid_task(task)
            ThermosTaskValidator.assert_valid_ports(task, self._portmap)
        except ThermosTaskValidator.InvalidTaskError as e:
            raise self.InvalidTask('Invalid task: %s' % e)
        context = ThermosContext(task_id=self._task_id,
                                 ports=self._portmap,
                                 user=self._user)
        self._task, uninterp = (task %
                                Environment(thermos=context)).interpolate()
        if len(uninterp) > 0:
            raise self.InvalidTask('Failed to interpolate task, missing: %s' %
                                   ', '.join(str(ref) for ref in uninterp))
        try:
            ThermosTaskValidator.assert_same_task(self._pathspec, self._task)
        except ThermosTaskValidator.InvalidTaskError as e:
            raise self.InvalidTask('Invalid task: %s' % e)
        self._plan = None  # plan currently being executed (updated by Handlers)
        self._regular_plan = planner_class(
            self._task,
            clock=clock,
            process_filter=lambda proc: proc.final().get() is False)
        self._finalizing_plan = planner_class(
            self._task,
            clock=clock,
            process_filter=lambda proc: proc.final().get() is True)
        self._chroot = chroot
        self._sandbox = sandbox
        self._container_sandbox = container_sandbox
        self._terminal_state = None
        self._ckpt = None
        self._process_map = dict(
            (p.name().get(), p) for p in self._task.processes())
        self._task_processes = {}
        self._stages = dict(
            (state, stage(self)) for state, stage in self.STAGES.items())
        self._finalization_start = None
        self._preemption_deadline = None
        self._watcher = ProcessMuxer(self._pathspec)
        self._state = RunnerState(processes={})
        self._preserve_env = preserve_env
        self._mesos_containerizer_path = mesos_containerizer_path

        # create runner state
        universal_handler = universal_handler or TaskRunnerUniversalHandler
        self._dispatcher = CheckpointDispatcher()
        self._dispatcher.register_handler(universal_handler(self))
        self._dispatcher.register_handler(TaskRunnerProcessHandler(self))
        self._dispatcher.register_handler(TaskRunnerTaskHandler(self))

        # recover checkpointed runner state and update plan
        self._recovery = True
        self._replay_runner_ckpt()

    @property
    def task(self):
        return self._task

    @property
    def task_id(self):
        return self._task_id

    @property
    def state(self):
        return self._state

    @property
    def processes(self):
        return self._task_processes

    def task_state(self):
        return self._state.statuses[
            -1].state if self._state.statuses else TaskState.ACTIVE

    def close_ckpt(self):
        """Force close the checkpoint stream.  This is necessary for runners terminated through
       exception propagation."""
        log.debug('Closing the checkpoint stream.')
        self._ckpt.close()

    @contextmanager
    def control(self, force=False):
        """
      Bind to the checkpoint associated with this task, position to the end of the log if
      it exists, or create it if it doesn't.  Fails if we cannot get "leadership" i.e. a
      file lock on the checkpoint stream.
    """
        if self.is_terminal():
            raise self.StateError(
                'Cannot take control of a task in terminal state.')
        if self._sandbox:
            safe_mkdir(self._sandbox)
        ckpt_file = self._pathspec.getpath('runner_checkpoint')
        try:
            self._ckpt = TaskRunnerHelper.open_checkpoint(ckpt_file,
                                                          force=force,
                                                          state=self._state)
        except TaskRunnerHelper.PermissionError:
            raise self.PermissionError('Unable to open checkpoint %s' %
                                       ckpt_file)
        log.debug('Flipping recovery mode off.')
        self._recovery = False
        self._set_task_status(self.task_state())
        self._resume_task()
        try:
            yield
        except Exception as e:
            log.error('Caught exception in self.control(): %s', e)
            log.error('  %s', traceback.format_exc())
        self._ckpt.close()

    def _resume_task(self):
        assert self._ckpt is not None
        unapplied_updates = self._replay_process_ckpts()
        if self.is_terminal():
            raise self.StateError('Cannot resume terminal task.')
        self._initialize_ckpt_header()
        self._replay(unapplied_updates)

    def _ckpt_write(self, record):
        """
      Write to the checkpoint stream if we're not in recovery mode.
    """
        if not self._recovery:
            self._ckpt.write(record)

    def _replay(self, checkpoints):
        """
      Replay a sequence of RunnerCkpts.
    """
        for checkpoint in checkpoints:
            self._dispatcher.dispatch(self._state, checkpoint)

    def _replay_runner_ckpt(self):
        """
      Replay the checkpoint stream associated with this task.
    """
        ckpt_file = self._pathspec.getpath('runner_checkpoint')
        if os.path.exists(ckpt_file):
            with open(ckpt_file, 'r') as fp:
                ckpt_recover = ThriftRecordReader(fp, RunnerCkpt)
                for record in ckpt_recover:
                    log.debug('Replaying runner checkpoint record: %s', record)
                    self._dispatcher.dispatch(self._state,
                                              record,
                                              recovery=True)

    def _replay_process_ckpts(self):
        """
      Replay the non-mutating process checkpoints.  Return the unapplied process updates that
      would mutate the runner checkpoint stream.
    """
        process_updates = self._watcher.select()
        unapplied_process_updates = []
        for process_update in process_updates:
            if self._dispatcher.would_update(self._state, process_update):
                unapplied_process_updates.append(process_update)
            else:
                self._dispatcher.dispatch(self._state,
                                          process_update,
                                          recovery=True)
        return unapplied_process_updates

    def _initialize_ckpt_header(self):
        """
      Initializes the RunnerHeader for this checkpoint stream if it has not already
      been constructed.
    """
        if self._state.header is None:
            try:
                uid = pwd.getpwnam(self._user).pw_uid
            except KeyError:
                # This will cause failures downstream, but they will at least be correctly
                # reflected in the process state.
                log.error('Unknown user %s.', self._user)
                uid = None

            header = RunnerHeader(task_id=self._task_id,
                                  launch_time_ms=int(self._launch_time * 1000),
                                  sandbox=self._sandbox,
                                  log_dir=self._log_dir,
                                  hostname=self._hostname,
                                  user=self._user,
                                  uid=uid,
                                  ports=self._portmap)
            runner_ckpt = RunnerCkpt(runner_header=header)
            self._dispatcher.dispatch(self._state, runner_ckpt)

    def _set_task_status(self, state):
        update = TaskStatus(state=state,
                            timestamp_ms=int(self._clock.time() * 1000),
                            runner_pid=os.getpid(),
                            runner_uid=os.getuid())
        runner_ckpt = RunnerCkpt(task_status=update)
        self._dispatcher.dispatch(self._state, runner_ckpt, self._recovery)

    def _finalization_remaining(self):
        # If a preemption deadline has been set, use that.
        if self._preemption_deadline:
            return max(0, self._preemption_deadline - self._clock.time())

        # Otherwise, use the finalization wait provided in the configuration.
        finalization_allocation = self.task.finalization_wait().get()
        if self._finalization_start is None:
            return sys.float_info.max
        else:
            waited = max(0, self._clock.time() - self._finalization_start)
            return max(0, finalization_allocation - waited)

    def _set_process_status(self, process_name, process_state, **kw):
        if 'sequence_number' in kw:
            sequence_number = kw.pop('sequence_number')
            log.debug('_set_process_status(%s <= %s, seq=%s[force])',
                      process_name,
                      ProcessState._VALUES_TO_NAMES.get(process_state),
                      sequence_number)
        else:
            current_run = self._current_process_run(process_name)
            if not current_run:
                assert process_state == ProcessState.WAITING
                sequence_number = 0
            else:
                sequence_number = current_run.seq + 1
            log.debug('_set_process_status(%s <= %s, seq=%s[auto])',
                      process_name,
                      ProcessState._VALUES_TO_NAMES.get(process_state),
                      sequence_number)
        runner_ckpt = RunnerCkpt(
            process_status=ProcessStatus(process=process_name,
                                         state=process_state,
                                         seq=sequence_number,
                                         **kw))
        self._dispatcher.dispatch(self._state, runner_ckpt, self._recovery)

    def _task_process_from_process_name(self, process_name, sequence_number):
        """
      Construct a Process() object from a process_name, populated with its
      correct run number and fully interpolated commandline.
    """
        run_number = len(self.state.processes[process_name]) - 1
        pathspec = self._pathspec.given(process=process_name, run=run_number)
        process = self._process_map.get(process_name)
        if process is None:
            raise self.InternalError('FATAL: Could not find process: %s' %
                                     process_name)

        def close_ckpt_and_fork():
            pid = os.fork()
            if pid == 0 and self._ckpt is not None:
                self._ckpt.close()
            return pid

        (logger_destination, logger_mode, rotate_log_size,
         rotate_log_backups) = self._build_process_logger_args(process)

        return Process(process.name().get(),
                       process.cmdline().get(),
                       sequence_number,
                       pathspec,
                       self._sandbox,
                       self._user,
                       chroot=self._chroot,
                       fork=close_ckpt_and_fork,
                       logger_destination=logger_destination,
                       logger_mode=logger_mode,
                       rotate_log_size=rotate_log_size,
                       rotate_log_backups=rotate_log_backups,
                       preserve_env=self._preserve_env,
                       mesos_containerizer_path=self._mesos_containerizer_path,
                       container_sandbox=self._container_sandbox)

    _DEFAULT_LOGGER = Logger()
    _DEFAULT_ROTATION = RotatePolicy()

    def _build_process_logger_args(self, process):
        """
      Build the appropriate logging configuration based on flags + process
      configuration settings.

      If no configuration (neither flags nor process config), default to
      "standard" mode.
    """

        destination, mode, size, backups = (
            self._DEFAULT_LOGGER.destination().get(),
            self._DEFAULT_LOGGER.mode().get(), None, None)

        logger = process.logger()
        if logger is Empty:
            if self._process_logger_destination:
                destination = self._process_logger_destination
            if self._process_logger_mode:
                mode = self._process_logger_mode
        else:
            destination = logger.destination().get()
            mode = logger.mode().get()

        if mode == LoggerMode.ROTATE:
            size = Amount(self._DEFAULT_ROTATION.log_size().get(), Data.BYTES)
            backups = self._DEFAULT_ROTATION.backups().get()
            if logger is Empty:
                if self._rotate_log_size_mb:
                    size = Amount(self._rotate_log_size_mb, Data.MB)
                if self._rotate_log_backups:
                    backups = self._rotate_log_backups
            else:
                rotate = logger.rotate()
                if rotate is not Empty:
                    size = Amount(rotate.log_size().get(), Data.BYTES)
                    backups = rotate.backups().get()

        return destination, mode, size, backups

    def deadlocked(self, plan=None):
        """Check whether a plan is deadlocked, i.e. there are no running/runnable processes, and the
    plan is not complete."""
        plan = plan or self._regular_plan
        now = self._clock.time()
        running = list(plan.running)
        runnable = list(plan.runnable_at(now))
        waiting = list(plan.waiting_at(now))
        log.debug('running:%d runnable:%d waiting:%d complete:%s',
                  len(running), len(runnable), len(waiting),
                  plan.is_complete())
        return len(running + runnable +
                   waiting) == 0 and not plan.is_complete()

    def is_healthy(self):
        """Check whether the TaskRunner is healthy. A healthy TaskRunner is not deadlocked and has not
    reached its max_failures count."""
        max_failures = self._task.max_failures().get()
        deadlocked = self.deadlocked()
        under_failure_limit = max_failures == 0 or len(
            self._regular_plan.failed) < max_failures
        log.debug(
            'max_failures:%d failed:%d under_failure_limit:%s deadlocked:%s ==> health:%s',
            max_failures, len(self._regular_plan.failed), under_failure_limit,
            deadlocked, not deadlocked and under_failure_limit)
        return not deadlocked and under_failure_limit

    def _current_process_run(self, process_name):
        if process_name not in self._state.processes or len(
                self._state.processes[process_name]) == 0:
            return None
        return self._state.processes[process_name][-1]

    def is_process_lost(self, process_name):
        """Determine whether or not we should mark a task as LOST and do so if necessary."""
        current_run = self._current_process_run(process_name)
        if not current_run:
            raise self.InternalError('No current_run for process %s!' %
                                     process_name)

        def forked_but_never_came_up():
            return current_run.state == ProcessState.FORKED and (
                self._clock.time() - current_run.fork_time >
                self.LOST_TIMEOUT.as_(Time.SECONDS))

        def running_but_coordinator_died():
            if current_run.state != ProcessState.RUNNING:
                return False
            coordinator_pid, _, _ = TaskRunnerHelper.scan_process(
                self.state, process_name)
            if coordinator_pid is not None:
                return False
            elif self._watcher.has_data(process_name):
                return False
            return True

        if forked_but_never_came_up() or running_but_coordinator_died():
            log.info('Detected a LOST task: %s', current_run)
            log.debug('  forked_but_never_came_up: %s',
                      forked_but_never_came_up())
            log.debug('  running_but_coordinator_died: %s',
                      running_but_coordinator_died())
            return True

        return False

    def _run_plan(self, plan):
        log.debug('Schedule pass:')

        running = list(plan.running)
        log.debug('running: %s', ' '.join(running))
        log.debug('finished: %s', ' '.join(plan.finished))

        launched = []
        for process_name in plan.running:
            if self.is_process_lost(process_name):
                self._set_process_status(process_name, ProcessState.LOST)

        now = self._clock.time()
        runnable = list(plan.runnable_at(now))
        waiting = list(plan.waiting_at(now))
        log.debug('runnable: %s', ' '.join(runnable))
        log.debug(
            'waiting: %s',
            ' '.join('%s[T-%.1fs]' % (process, plan.get_wait(process))
                     for process in waiting))

        def pick_processes(process_list):
            if self._task.max_concurrency().get() == 0:
                return process_list
            num_to_pick = max(
                self._task.max_concurrency().get() - len(running), 0)
            return process_list[:num_to_pick]

        for process_name in pick_processes(runnable):
            tp = self._task_processes.get(process_name)
            if tp:
                current_run = self._current_process_run(process_name)
                assert current_run.state == ProcessState.WAITING
            else:
                self._set_process_status(process_name, ProcessState.WAITING)
                tp = self._task_processes[process_name]
            log.info('Forking Process(%s)', process_name)
            try:
                tp.start()
                launched.append(tp)
            except Process.Error as e:
                log.error('Failed to launch process: %s', e)
                self._set_process_status(process_name, ProcessState.FAILED)

        return len(launched) > 0

    def _terminate_plan(self, plan):
        TaskRunnerHelper.terminate_orphans(self.state)

        for process in plan.running:
            last_run = self._current_process_run(process)
            if last_run and last_run.state in (ProcessState.FORKED,
                                               ProcessState.RUNNING):
                TaskRunnerHelper.terminate_process(self.state, process)

    def has_running_processes(self):
        """
      Returns True if any processes associated with this task have active pids.
    """
        process_tree = TaskRunnerHelper.scan_tree(self.state)
        return any(any(process_set) for process_set in process_tree.values())

    def has_active_processes(self):
        """
      Returns True if any processes are in non-terminal states.
    """
        return any(
            not TaskRunnerHelper.is_process_terminal(run.state)
            for run in filter(None, (self._current_process_run(process)
                                     for process in self.state.processes)))

    def collect_updates(self, timeout=None):
        """
      Collects and applies updates from process checkpoint streams.  Returns the number
      of applied process checkpoints.
    """
        if not self.has_active_processes():
            return 0

        sleep_interval = self.COORDINATOR_INTERVAL_SLEEP.as_(Time.SECONDS)
        total_time = 0.0

        while True:
            process_updates = self._watcher.select()
            for process_update in process_updates:
                self._dispatcher.dispatch(self._state, process_update,
                                          self._recovery)
            if process_updates:
                return len(process_updates)
            if timeout is not None and total_time >= timeout:
                return 0
            total_time += sleep_interval
            self._clock.sleep(sleep_interval)

    def is_terminal(self):
        return TaskRunnerHelper.is_task_terminal(self.task_state())

    def terminal_state(self):
        if self._terminal_state:
            log.debug(
                'Forced terminal state: %s' % TaskState._VALUES_TO_NAMES.get(
                    self._terminal_state, 'UNKNOWN'))
            return self._terminal_state
        else:
            return TaskState.SUCCESS if self.is_healthy() else TaskState.FAILED

    def run(self, force=False):
        """
      Entrypoint to runner. Assume control of checkpoint stream, and execute TaskRunnerStages
      until runner is terminal.
    """
        if self.is_terminal():
            return
        with self.control(force):
            self._run()

    def _run(self):
        while not self.is_terminal():
            start = self._clock.time()
            # step 1: execute stage corresponding to the state we're currently in
            runner = self._stages[self.task_state()]
            iteration_wait = runner.run()
            if iteration_wait is None:
                log.debug('Run loop: No more work to be done in state %s' %
                          TaskState._VALUES_TO_NAMES.get(
                              self.task_state(), 'UNKNOWN'))
                self._set_task_status(runner.transition_to())
                continue
            log.debug('Run loop: Work to be done within %.1fs', iteration_wait)
            # step 2: check child process checkpoint streams for updates
            if not self.collect_updates(iteration_wait):
                # If we don't collect any updates, at least 'touch' the checkpoint stream
                # so as to prevent garbage collection.
                elapsed = self._clock.time() - start
                if elapsed < iteration_wait:
                    log.debug(
                        'Update collection only took %.1fs, idling %.1fs',
                        elapsed, iteration_wait - elapsed)
                    self._clock.sleep(iteration_wait - elapsed)
                log.debug(
                    'Run loop: No updates collected, touching checkpoint.')
                os.utime(self._pathspec.getpath('runner_checkpoint'), None)
            # step 3: reap any zombie child processes
            TaskRunnerHelper.reap_children()

    def kill(self,
             force=False,
             terminal_status=TaskState.KILLED,
             preemption_wait=Amount(1, Time.MINUTES)):
        """
      Kill all processes associated with this task and set task/process states as terminal_status
      (defaults to KILLED)
    """
        log.debug('Runner issued kill: force:%s, preemption_wait:%s', force,
                  preemption_wait)
        assert terminal_status in (TaskState.KILLED, TaskState.LOST)
        self._preemption_deadline = self._clock.time() + preemption_wait.as_(
            Time.SECONDS)
        with self.control(force):
            if self.is_terminal():
                log.warning('Task is not in ACTIVE state, cannot issue kill.')
                return
            self._terminal_state = terminal_status
            if self.task_state() == TaskState.ACTIVE:
                self._set_task_status(TaskState.CLEANING)
            self._run()

    def lose(self, force=False):
        """
      Mark a task as LOST and kill any straggling processes.
    """
        self.kill(force,
                  preemption_wait=Amount(0, Time.SECONDS),
                  terminal_status=TaskState.LOST)

    def _kill(self):
        processes = TaskRunnerHelper.scan_tree(self._state)
        for process, pid_tuple in processes.items():
            current_run = self._current_process_run(process)
            coordinator_pid, pid, tree = pid_tuple
            if TaskRunnerHelper.is_process_terminal(current_run.state):
                if coordinator_pid or pid or tree:
                    log.warning(
                        'Terminal process (%s) still has running pids:',
                        process)
                    log.warning('  coordinator_pid: %s', coordinator_pid)
                    log.warning('              pid: %s', pid)
                    log.warning('             tree: %s', tree)
                TaskRunnerHelper.kill_process(self.state, process)
            else:
                if coordinator_pid or pid or tree:
                    log.info('Transitioning %s to KILLED', process)
                    self._set_process_status(process,
                                             ProcessState.KILLED,
                                             stop_time=self._clock.time(),
                                             return_code=-1)
                else:
                    log.info('Transitioning %s to LOST', process)
                    if current_run.state != ProcessState.WAITING:
                        self._set_process_status(process, ProcessState.LOST)
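A hedged usage sketch of the TaskRunner lifecycle above: rebind to an existing task with TaskRunner.get() and drive it to a terminal state. The task id and checkpoint root are placeholders, and Amount/Time are the same twitter.common quantity types used throughout these examples.

runner = TaskRunner.get('hello_world-20240101-120000.000123', '/var/run/thermos')
if runner is None:
    print('No active task found at that checkpoint root.')
elif not runner.is_terminal():
    # Takes the checkpoint lock, then walks the task through CLEANING/FINALIZING to KILLED.
    runner.kill(force=True, preemption_wait=Amount(30, Time.SECONDS))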
Exemple #24
0
def initialize(options):
    path_detector = MesosPathDetector(options.mesos_root)
    polling_interval = Amount(options.polling_interval_secs, Time.SECONDS)
    return TaskObserver(path_detector, interval=polling_interval)
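A minimal sketch of driving the helper above; the options object only needs the two attributes the function reads, and the argparse Namespace plus the sample values are assumptions.

from argparse import Namespace

options = Namespace(mesos_root='/var/lib/mesos', polling_interval_secs=5)
observer = initialize(options)   # returns a TaskObserver polling every 5 seconds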
Exemple #25
0
class ZookeeperSchedulerClient(SchedulerClient):
    SERVERSET_TIMEOUT = Amount(10, Time.SECONDS)

    @classmethod
    def get_scheduler_serverset(cls, cluster, port=2181, verbose=False, **kw):
        if cluster.zk is None:
            raise ValueError('Cluster has no associated zookeeper ensemble!')
        if cluster.scheduler_zk_path is None:
            raise ValueError(
                'Cluster has no defined scheduler path, must specify scheduler_zk_path '
                'in your cluster config!')
        hosts = [h + ':{p}' for h in cluster.zk.split(',')]
        zk = TwitterKazooClient.make(str(','.join(hosts).format(p=port)),
                                     verbose=verbose)
        return zk, ServerSet(zk, cluster.scheduler_zk_path, **kw)

    def __init__(self,
                 cluster,
                 port=2181,
                 verbose=False,
                 _deadline=deadline,
                 **kwargs):
        SchedulerClient.__init__(self, verbose=verbose, **kwargs)
        self._cluster = cluster
        self._zkport = port
        self._endpoint = None
        self._uri = None
        self._deadline = _deadline

    def _resolve(self):
        """Resolve the uri associated with this scheduler from zookeeper."""
        joined = threading.Event()

        def on_join(elements):
            joined.set()

        zk, serverset = self.get_scheduler_serverset(self._cluster,
                                                     verbose=self._verbose,
                                                     port=self._zkport,
                                                     on_join=on_join)

        joined.wait(timeout=self.SERVERSET_TIMEOUT.as_(Time.SECONDS))

        try:
            # Need to perform this operation in a separate thread, because kazoo will wait for the
            # result of this serverset evaluation indefinitely, which would prevent people from
            # killing the client with keyboard interrupts.
            serverset_endpoints = self._deadline(
                lambda: list(serverset),
                timeout=self.SERVERSET_TIMEOUT.as_(Time.SECONDS),
                daemon=True,
                propagate=True)
        except Timeout:
            raise self.CouldNotConnect(
                "Failed to connect to Zookeeper within %d seconds." %
                self.SERVERSET_TIMEOUT.as_(Time.SECONDS))

        if len(serverset_endpoints) == 0:
            raise self.CouldNotConnect('No schedulers detected in %s!' %
                                       self._cluster.name)
        instance = serverset_endpoints[0]
        if 'https' in instance.additional_endpoints:
            endpoint = instance.additional_endpoints['https']
            self._uri = 'https://%s:%s' % (endpoint.host, endpoint.port)
        elif 'http' in instance.additional_endpoints:
            endpoint = instance.additional_endpoints['http']
            self._uri = 'http://%s:%s' % (endpoint.host, endpoint.port)
        zk.stop()

    def _connect(self):
        if self._uri is None:
            self._resolve()
        if self._uri is not None:
            return self._connect_scheduler(urljoin(self._uri, 'api'))

    @property
    def url(self):
        proxy_url = self._cluster.proxy_url
        if proxy_url:
            return proxy_url
        return self.raw_url

    @property
    def raw_url(self):
        if self._uri is None:
            self._resolve()
        if self._uri:
            return self._uri
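A hedged sketch of how the client above is typically used, assuming `cluster` is a Cluster object whose zk and scheduler_zk_path fields are populated as _resolve() expects.

client = ZookeeperSchedulerClient(cluster, port=2181, verbose=True)
print(client.url)                    # proxy_url when configured, else the serverset-resolved URI
thrift = client.get_thrift_client()  # resolves via ZooKeeper, then opens the thrift transport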
Exemple #26
0
class SchedulerProxy(object):
    """
    This class is responsible for creating a reliable thrift client to the
    twitter scheduler.  Basically all the dirty work needed by the
    AuroraClientAPI.
  """
    CONNECT_MAXIMUM_WAIT = Amount(1, Time.MINUTES)
    RPC_RETRY_INTERVAL = Amount(5, Time.SECONDS)
    RPC_MAXIMUM_WAIT = Amount(10, Time.MINUTES)
    UNAUTHENTICATED_RPCS = frozenset([
        'populateJobConfig',
        'getTasksStatus',
        'getJobs',
        'getQuota',
        'getVersion',
    ])

    class Error(Exception):
        pass

    class TimeoutError(Error):
        pass

    class AuthenticationError(Error):
        pass

    class APIVersionError(Error):
        pass

    def __init__(self,
                 cluster,
                 verbose=False,
                 session_key_factory=make_session_key):
        """A callable session_key_factory should be provided for authentication"""
        self.cluster = cluster
        # TODO(Sathya): Make this a part of cluster trait when authentication is pushed to the transport
        # layer.
        self._session_key_factory = session_key_factory
        self._client = self._scheduler = None
        self.verbose = verbose

    def with_scheduler(method):
        """Decorator magic to make sure a connection is made to the scheduler"""
        def _wrapper(self, *args, **kwargs):
            if not self._scheduler:
                self._construct_scheduler()
            return method(self, *args, **kwargs)

        return _wrapper

    def invalidate(self):
        self._client = self._scheduler = None

    @with_scheduler
    def client(self):
        return self._client

    @with_scheduler
    def scheduler(self):
        return self._scheduler

    def session_key(self):
        try:
            return self._session_key_factory(self.cluster.auth_mechanism)
        except SessionKeyError as e:
            raise self.AuthenticationError('Unable to create session key %s' %
                                           e)

    def _construct_scheduler(self):
        """
      Populates:
        self._scheduler
        self._client
    """
        self._scheduler = SchedulerClient.get(self.cluster,
                                              verbose=self.verbose)
        assert self._scheduler, "Could not find scheduler (cluster = %s)" % self.cluster.name
        start = time.time()
        while (time.time() - start) < self.CONNECT_MAXIMUM_WAIT.as_(
                Time.SECONDS):
            try:
                self._client = self._scheduler.get_thrift_client()
                break
            except SchedulerClient.CouldNotConnect as e:
                log.warning('Could not connect to scheduler: %s' % e)
        if not self._client:
            raise self.TimeoutError(
                'Timed out trying to connect to scheduler at %s' %
                self.cluster.name)

        server_version = self._client.getVersion().result.getVersionResult
        if server_version != CURRENT_API_VERSION:
            raise self.APIVersionError(
                "Client Version: %s, Server Version: %s" %
                (CURRENT_API_VERSION, server_version))

    def __getattr__(self, method_name):
        # If the method does not exist, getattr will return AttributeError for us.
        method = getattr(AuroraAdmin.Client, method_name)
        if not callable(method):
            return method

        @functools.wraps(method)
        def method_wrapper(*args):
            start = time.time()
            while (time.time() - start) < self.RPC_MAXIMUM_WAIT.as_(
                    Time.SECONDS):
                auth_args = (
                ) if method_name in self.UNAUTHENTICATED_RPCS else (
                    self.session_key(), )
                try:
                    method = getattr(self.client(), method_name)
                    if not callable(method):
                        return method
                    return method(*(args + auth_args))
                except (TTransport.TTransportException,
                        self.TimeoutError) as e:
                    log.warning(
                        'Connection error with scheduler: %s, reconnecting...'
                        % e)
                    self.invalidate()
                    time.sleep(self.RPC_RETRY_INTERVAL.as_(Time.SECONDS))
            raise self.TimeoutError('Timed out attempting to issue %s to %s' %
                                    (method_name, self.cluster.name))

        return method_wrapper
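A hedged sketch of driving this proxy: attribute access falls through __getattr__, which appends a session key for authenticated RPCs, retries on transport errors, and gives up after RPC_MAXIMUM_WAIT. The `query` argument is a placeholder thrift TaskQuery.

proxy = SchedulerProxy(cluster, verbose=True)
version = proxy.getVersion()            # listed in UNAUTHENTICATED_RPCS, so no session key is added
statuses = proxy.getTasksStatus(query)  # also unauthenticated; other RPCs get self.session_key()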
Exemple #27
0
class SchedulerClient(object):
    THRIFT_RETRIES = 5
    RETRY_TIMEOUT = Amount(1, Time.SECONDS)

    class Error(Exception):
        pass

    class CouldNotConnect(Error):
        pass

    # TODO(wickman) Refactor per MESOS-3005 into two separate classes with separate traits:
    #   ZookeeperClientTrait
    #   DirectClientTrait
    @classmethod
    def get(cls, cluster, auth_factory=get_auth_handler, **kwargs):
        if not isinstance(cluster, Cluster):
            raise TypeError(
                '"cluster" must be an instance of Cluster, got %s' %
                type(cluster))
        cluster = cluster.with_trait(SchedulerClientTrait)
        auth_handler = auth_factory(cluster.auth_mechanism)
        if cluster.zk:
            return ZookeeperSchedulerClient(cluster,
                                            port=cluster.zk_port,
                                            auth=auth_handler,
                                            **kwargs)
        elif cluster.scheduler_uri:
            return DirectSchedulerClient(cluster.scheduler_uri,
                                         auth=auth_handler,
                                         **kwargs)
        else:
            raise ValueError('"cluster" does not specify zk or scheduler_uri')

    def __init__(self, auth, user_agent, verbose=False):
        self._client = None
        self._auth_handler = auth
        self._user_agent = user_agent
        self._verbose = verbose

    def get_thrift_client(self):
        if self._client is None:
            self._client = self._connect()
        return self._client

    def get_failed_auth_message(self):
        return self._auth_handler.failed_auth_message

    # per-class implementation -- mostly meant to set up a valid host/port
    # pair and then delegate the opening to SchedulerClient._connect_scheduler
    def _connect(self):
        return None

    def _connect_scheduler(self, uri, clock=time):
        transport = TRequestsTransport(uri,
                                       auth=self._auth_handler.auth(),
                                       user_agent=self._user_agent)
        protocol = TJSONProtocol.TJSONProtocol(transport)
        schedulerClient = AuroraAdmin.Client(protocol)
        for _ in range(self.THRIFT_RETRIES):
            try:
                transport.open()
                return schedulerClient
            except TTransport.TTransportException:
                clock.sleep(self.RETRY_TIMEOUT.as_(Time.SECONDS))
                continue
            except Exception as e:
                # Monkey-patched proxies, like socks, can generate a proxy error here.
                # without adding a dependency, we can't catch those in a more specific way.
                raise self.CouldNotConnect(
                    'Connection to scheduler failed: %s' % e)
        raise self.CouldNotConnect('Could not connect to %s' % uri)
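A hedged sketch of the factory above: get() inspects the cluster to choose the ZooKeeper or direct client, and get_thrift_client() lazily opens the transport. The user_agent string is an assumption passed through **kwargs to the constructor.

client = SchedulerClient.get(cluster, user_agent='aurora-admin/0.1')
try:
    thrift_client = client.get_thrift_client()
except SchedulerClient.CouldNotConnect as e:
    log.error('Could not reach the scheduler: %s', e)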
Exemple #28
0
class SchedulerClient(object):
    THRIFT_RETRIES = 5
    RETRY_TIMEOUT = Amount(1, Time.SECONDS)

    class CouldNotConnect(Exception):
        pass

    # TODO(wickman) Refactor per MESOS-3005 into two separate classes with separate traits:
    #   ZookeeperClientTrait
    #   DirectClientTrait
    @classmethod
    def get(cls, cluster, **kwargs):
        if not isinstance(cluster, Cluster):
            raise TypeError(
                '"cluster" must be an instance of Cluster, got %s' %
                type(cluster))
        cluster = cluster.with_trait(SchedulerClientTrait)
        if cluster.zk:
            return ZookeeperSchedulerClient(cluster,
                                            port=cluster.zk_port,
                                            ssl=cluster.use_thrift_ssl,
                                            **kwargs)
        elif cluster.scheduler_uri:
            try:
                host, port = cluster.scheduler_uri.split(':', 2)
                port = int(port)
            except ValueError:
                raise ValueError('Malformed Cluster scheduler_uri: %s' %
                                 cluster.scheduler_uri)
            return DirectSchedulerClient(host,
                                         port,
                                         ssl=cluster.use_thrift_ssl)
        else:
            raise ValueError('"cluster" does not specify zk or scheduler_uri')

    def __init__(self, verbose=False, ssl=False):
        self._client = None
        self._verbose = verbose
        self._ssl = ssl

    def get_thrift_client(self):
        if self._client is None:
            self._client = self._connect()
        return self._client

    # per-class implementation -- mostly meant to set up a valid host/port
    # pair and then delegate the opening to SchedulerClient._connect_scheduler
    def _connect(self):
        return None

    @staticmethod
    def _connect_scheduler(host, port, with_ssl=False):
        if with_ssl:
            socket = DelayedHandshakeTSSLSocket(host,
                                                port,
                                                delay_handshake=True,
                                                validate=False)
        else:
            socket = TSocket.TSocket(host, port)
        transport = TTransport.TBufferedTransport(socket)
        protocol = TBinaryProtocol.TBinaryProtocol(transport)
        schedulerClient = AuroraAdmin.Client(protocol)
        for _ in range(SchedulerClient.THRIFT_RETRIES):
            try:
                transport.open()
                return schedulerClient
            except TTransport.TTransportException:
                time.sleep(SchedulerClient.RETRY_TIMEOUT.as_(Time.SECONDS))
                continue
            except Exception as e:
                # Monkey-patched proxies, like socks, can generate a proxy error here.
                # without adding a dependency, we can't catch those in a more specific way.
                raise SchedulerClient.CouldNotConnect(
                    'Connection to scheduler failed: %s' % e)
        raise SchedulerClient.CouldNotConnect('Could not connect to %s:%s' %
                                              (host, port))
Exemple #29
0
class ThermosTaskRunner(TaskRunner):
    ESCALATION_WAIT = Amount(5, Time.SECONDS)
    EXIT_STATE_MAP = {
        TaskState.ACTIVE:
        StatusResult('Runner died while task was active.', ExitState.LOST),
        TaskState.FAILED:
        StatusResult('Task failed.', ExitState.FAILED),
        TaskState.KILLED:
        StatusResult('Task killed.', ExitState.KILLED),
        TaskState.LOST:
        StatusResult('Task lost.', ExitState.LOST),
        TaskState.SUCCESS:
        StatusResult('Task finished.', ExitState.FINISHED),
    }
    MAX_WAIT = Amount(1, Time.MINUTES)
    PEX_NAME = 'thermos_runner.pex'
    POLL_INTERVAL = Amount(500, Time.MILLISECONDS)
    THERMOS_PREEMPTION_WAIT = Amount(1, Time.MINUTES)

    def __init__(self,
                 runner_pex,
                 task_id,
                 task,
                 role,
                 portmap,
                 sandbox,
                 checkpoint_root=None,
                 artifact_dir=None,
                 clock=time):
        """
      runner_pex       location of the thermos_runner pex that this task runner should use
      task_id          task_id assigned by scheduler
      task             thermos pystachio Task object
      role             role to run the task under
      portmap          { name => port } dictionary
      sandbox          the sandbox object
      checkpoint_root  the checkpoint root for the thermos runner
      artifact_dir     scratch space for the thermos runner (basically cwd of thermos.pex)
      clock            clock
    """
        self._runner_pex = runner_pex
        self._task_id = task_id
        self._task = task
        self._popen = None
        self._monitor = None
        self._status = None
        self._ports = portmap
        self._root = sandbox.root
        self._checkpoint_root = checkpoint_root or TaskPath.DEFAULT_CHECKPOINT_ROOT
        self._enable_chroot = sandbox.chrooted
        self._role = role
        self._clock = clock
        self._artifact_dir = artifact_dir or safe_mkdtemp()

        # wait events
        self._dead = threading.Event()
        self._kill_signal = threading.Event()
        self.forking = threading.Event()
        self.forked = threading.Event()

        try:
            with open(os.path.join(self._artifact_dir, 'task.json'),
                      'w') as fp:
                self._task_filename = fp.name
                ThermosTaskWrapper(self._task).to_file(self._task_filename)
        except ThermosTaskWrapper.InvalidTask as e:
            raise TaskError('Failed to load task: %s' % e)

    def _terminate_http(self):
        if 'health' not in self._ports:
            return

        http_signaler = HttpSignaler(self._ports['health'])

        # pass 1
        http_signaler.quitquitquit()
        self._clock.sleep(self.ESCALATION_WAIT.as_(Time.SECONDS))
        if self.status is not None:
            return True

        # pass 2
        http_signaler.abortabortabort()
        self._clock.sleep(self.ESCALATION_WAIT.as_(Time.SECONDS))
        if self.status is not None:
            return True

    @property
    def artifact_dir(self):
        return self._artifact_dir

    def task_state(self):
        return self._monitor.task_state() if self._monitor else None

    @property
    def is_alive(self):
        """
      Is the process underlying the Thermos task runner alive?
    """
        if not self._popen:
            return False
        if self._dead.is_set():
            return False

        # N.B. You cannot mix this code and any code that relies upon os.wait
        # mechanisms with blanket child process collection.  One example is the
        # Thermos task runner which calls os.wait4 -- without refactoring, you
        # should not mix a Thermos task runner in the same process as this
        # thread.
        try:
            pid, _ = os.waitpid(self._popen.pid, os.WNOHANG)
            if pid == 0:
                return True
            else:
                log.info('Detected runner termination: pid=%s' % pid)
        except OSError as e:
            log.error('is_alive got OSError: %s' % e)
            if e.errno != errno.ECHILD:
                raise

        self._dead.set()
        return False

    def compute_status(self):
        if self.is_alive:
            return None
        exit_state = self.EXIT_STATE_MAP.get(self.task_state())
        if exit_state is None:
            log.error('Received unexpected exit state from TaskMonitor.')
        return exit_state

    def terminate_runner(self, as_loss=False):
        """
      Terminate the underlying runner process, if it exists.
    """
        if self._kill_signal.is_set():
            log.warning('Duplicate kill/lose signal received, ignoring.')
            return
        self._kill_signal.set()
        if self.is_alive:
            sig = 'SIGUSR2' if as_loss else 'SIGUSR1'
            log.info('Runner is alive, sending %s' % sig)
            try:
                self._popen.send_signal(getattr(signal, sig))
            except OSError as e:
                log.error('Got OSError sending %s: %s' % (sig, e))
        else:
            log.info('Runner is dead, skipping kill.')

    def kill(self):
        self.terminate_runner()

    def lose(self):
        self.terminate_runner(as_loss=True)

    def quitquitquit(self):
        """Bind to the process tree of a Thermos task and kill it with impunity."""
        try:
            runner = core.TaskRunner.get(self._task_id, self._checkpoint_root)
            if runner:
                log.info('quitquitquit calling runner.kill')
                # Right now preemption wait is hardcoded, though it may become configurable in the future.
                runner.kill(force=True,
                            preemption_wait=self.THERMOS_PREEMPTION_WAIT)
            else:
                log.error('Could not instantiate runner!')
        except core.TaskRunner.Error as e:
            log.error('Could not quitquitquit runner: %s' % e)

    def _cmdline(self):
        params = dict(log_dir=LogOptions.log_dir(),
                      log_to_disk='DEBUG',
                      checkpoint_root=self._checkpoint_root,
                      sandbox=self._root,
                      task_id=self._task_id,
                      thermos_json=self._task_filename)

        if getpass.getuser() == 'root':
            params.update(setuid=self._role)

        cmdline_args = [self._runner_pex]
        cmdline_args.extend('--%s=%s' % (flag, value)
                            for flag, value in params.items())
        if self._enable_chroot:
            cmdline_args.extend(['--enable_chroot'])
        for name, port in self._ports.items():
            cmdline_args.extend(['--port=%s:%s' % (name, port)])
        return cmdline_args

    # --- public interface
    def start(self, timeout=MAX_WAIT):
        """Fork the task runner and return once the underlying task is running, up to timeout."""
        self.forking.set()

        try:
            chmod_plus_x(self._runner_pex)
        except OSError as e:
            if e.errno != errno.EPERM:
                raise TaskError('Failed to chmod +x runner: %s' % e)

        self._monitor = TaskMonitor(TaskPath(root=self._checkpoint_root),
                                    self._task_id)

        cmdline_args = self._cmdline()
        log.info('Forking off runner with cmdline: %s' %
                 ' '.join(cmdline_args))

        try:
            self._popen = subprocess.Popen(cmdline_args)
        except OSError as e:
            raise TaskError(e)

        self.forked.set()

        log.debug('Waiting for task to start.')

        def is_started():
            return self._monitor and (self._monitor.active
                                      or self._monitor.finished)

        waited = Amount(0, Time.SECONDS)
        while not is_started() and waited < timeout:
            log.debug('  - sleeping...')
            self._clock.sleep(self.POLL_INTERVAL.as_(Time.SECONDS))
            waited += self.POLL_INTERVAL

        if not is_started():
            log.error('Task did not start within deadline, forcing loss.')
            self.lose()
            raise TaskError('Task did not start within deadline.')

    def stop(self, timeout=MAX_WAIT):
        """Stop the runner.  If it's already completed, no-op.  If it's still running, issue a kill."""
        log.info('ThermosTaskRunner is shutting down.')

        if not self.forking.is_set():
            raise TaskError('Failed to call TaskRunner.start.')

        log.info('Invoking runner HTTP teardown.')
        self._terminate_http()

        log.info('Invoking runner.kill')
        self.kill()

        waited = Amount(0, Time.SECONDS)
        while self.is_alive and waited < timeout:
            self._clock.sleep(self.POLL_INTERVAL.as_(Time.SECONDS))
            waited += self.POLL_INTERVAL

        if not self.is_alive and self.task_state() != TaskState.ACTIVE:
            return

        log.info('Thermos task did not shut down cleanly, rebinding to kill.')
        self.quitquitquit()

        while not self._monitor.finished and waited < timeout:
            self._clock.sleep(self.POLL_INTERVAL.as_(Time.SECONDS))
            waited += self.POLL_INTERVAL

        if not self._monitor.finished:
            raise TaskError('Task did not stop within deadline.')

    @property
    def status(self):
        """Return the StatusResult of this task runner.  This returns None as
       long as no terminal state is reached."""
        if self._status is None:
            self._status = self.compute_status()
        return self._status
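A hedged lifecycle sketch for the executor-side runner above; every constructor argument is a placeholder, and `task`/`sandbox` are assumed to be the pystachio Task and sandbox objects the docstring describes.

runner = ThermosTaskRunner(runner_pex='thermos_runner.pex',
                           task_id='www-data-prod-hello-0-abcdef',
                           task=task,
                           role='www-data',
                           portmap={'http': 31337},
                           sandbox=sandbox)
runner.start()          # forks the runner pex and polls until the task is active
print(runner.status)    # None until a terminal state is reached
runner.stop()           # HTTP teardown, then SIGUSR1, then quitquitquit escalation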
Exemple #30
0
    def run(self):
        tasks = []
        now = time.time()

        # age: The time (in seconds) since the last task transition to/from ACTIVE/FINISHED
        # metadata_size: The size of the thermos checkpoint records for this task
        # log_size: The size of the stdout/stderr logs for this task's processes
        # data_size: The size of the sandbox of this task.
        TaskTuple = namedtuple(
            'TaskTuple',
            'checkpoint_root task_id age metadata_size log_size data_size')

        for checkpoint_root, task_id in self.get_finished_tasks():
            collector = TaskGarbageCollector(checkpoint_root, task_id)

            age = Amount(int(now - collector.get_age()), Time.SECONDS)
            self.log('Analyzing task %s (age: %s)... ' % (task_id, age))
            metadata_size = Amount(
                sum(sz for _, sz in collector.get_metadata()), Data.BYTES)
            self.log('  metadata %.1fKB ' % metadata_size.as_(Data.KB))
            log_size = Amount(sum(sz for _, sz in collector.get_logs()),
                              Data.BYTES)
            self.log('  logs %.1fKB ' % log_size.as_(Data.KB))
            data_size = Amount(sum(sz for _, sz in collector.get_data()),
                               Data.BYTES)
            self.log('  data %.1fMB ' % data_size.as_(Data.MB))
            tasks.append(
                TaskTuple(checkpoint_root, task_id, age, metadata_size,
                          log_size, data_size))

        gc_tasks = set()
        gc_tasks.update(task for task in tasks if task.age > self._max_age)

        self.log('After age filter: %s tasks' % len(gc_tasks))

        def total_gc_size(task):
            return sum([
                task.data_size, task.metadata_size
                if self._include_metadata else Amount(0, Data.BYTES),
                task.log_size if self._include_logs else Amount(0, Data.BYTES)
            ], Amount(0, Data.BYTES))

        total_used = Amount(0, Data.BYTES)
        for task in sorted(tasks, key=lambda tsk: tsk.age, reverse=True):
            if task not in gc_tasks:
                total_used += total_gc_size(task)
                if total_used > self._max_space:
                    gc_tasks.add(task)
        self.log('After size filter: %s tasks' % len(gc_tasks))

        for task in sorted(tasks, key=lambda tsk: tsk.age, reverse=True):
            if task not in gc_tasks and len(tasks) - len(
                    gc_tasks) > self._max_tasks:
                gc_tasks.add(task)
        self.log('After total task filter: %s tasks' % len(gc_tasks))

        self.log('Deciding to garbage collect the following tasks:')
        if gc_tasks:
            for task in gc_tasks:
                self.log('   %s' % repr(task))
        else:
            self.log('   None.')

        return gc_tasks
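As a tiny worked example of the size accounting used by total_gc_size above, Amount values in different units can be summed directly and converted at the end; the import path assumes the same twitter.common.quantity types used throughout these examples.

from twitter.common.quantity import Amount, Data

sizes = [Amount(700, Data.MB), Amount(2, Data.MB), Amount(300, Data.MB)]
total = sum(sizes, Amount(0, Data.BYTES))
print(total.as_(Data.MB))   # 1002.0 -- the amount charged against the space budget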