Example #1
Example #2
# NOTE(review): the base-class list below is never closed -- at least one
# continuation line (the remaining bases and the closing "):") appears to be
# missing from this copy. A later comment in the class says get_mode() comes
# from dependency.AgencyAgentDependencyMixin, so that mixin is presumably one
# of the missing bases; confirm against the upstream file.
class AgencyAgent(log.LogProxy, log.Logger, manhole.Manhole,

    # Interfaces this class declares.
    implements(IAgencyAgent, IAgencyAgentInternal, ITimeProvider,
               IRecorderNode, IJournalKeeper, ISerializable, IMessagingPeer)

    type_name = "agent-medium" # this is used by ISerializable

    # The module-level error_handler, exposed as a class attribute.
    _error_handler = error_handler

    journal_parent = None

    def __init__(self, agency, factory, descriptor):
        '''Create the medium for a single agent.

        agency is adapted to IAgency and is also used as the logging
        parent. factory is called as factory(self) and its result becomes
        the agent-side instance (self.agent). descriptor is the agent
        descriptor; doc_id, document_type and instance_id are read from it.
        '''
        log.LogProxy.__init__(self, agency)
        log.Logger.__init__(self, self)

        self.journal_keeper = self
        self.agency = IAgency(agency)
        self._descriptor = descriptor
        # Our instance id. It is used to tell the difference between the
        # journal entries coming from different agencies running the same
        # agent. Our value will be stored in descriptor before calling anything
        # on the agent side, although it needs to be set now to produce valid
        # identifiers.
        self._instance_id = descriptor.instance_id + 1

        self.log_name = descriptor.doc_id
        self.log_category = descriptor.document_type

        # The factory receives the medium before the registries below are
        # created, so the statement order here matters.
        self.agent = factory(self)
        self.log('Instantiated the %r instance', self.agent)

        self._protocols = {} # {guid: IAgencyProtocolInternal}
        self._interests = {} # {protocol_type: {protocol_id: IInterest}}
        self._long_running_protocols = [] # Long running protocols

        # The three attributes below are assigned by initiate().
        self._messaging = None
        self._database = None
        self._configuration = None

        # Descriptor updates are queued as (deferred, function, args, kwargs)
        # tuples by update_descriptor().
        self._updating = False
        self._update_queue = []
        # call_id -> (busy, call), filled by _store_delayed_call()
        self._delayed_calls = container.ExpDict(self)
        # Terminating flag, meant to prevent running the
        # termination procedure more than once
        # NOTE(review): nothing in the visible part of the class reads this
        # flag; _terminate_procedure() checks the state machine instead.
        self._terminating = False

        # traversal_id -> True
        self._traversal_ids = container.ExpDict(self)

        # Journal entries recorded since the last snapshot, see new_entry().
        self._entries_since_snapshot = 0

    ### Public Methods ###

    def initiate(self, **kwargs):
        '''Establishes the connections to database and messaging platform,
        taking into account that it might mean performing asynchronous job.

        The optional run_startup keyword (default True) is consumed here and
        passed on to _call_startup() as call_startup; all remaining keywords
        go to _call_initiate(). Returns the Deferred driving the chain; its
        last callback is defer.override_result with this medium.
        '''
        run_startup = kwargs.pop('run_startup', True)

        setter = lambda value, name: setattr(self, name, value)

        d = defer.Deferred()
        d.addCallback(defer.drop_param,
                      self.agency._messaging.get_connection, self)
        d.addCallback(setter, '_messaging')
        # NOTE(review): the step that obtains the database connection is
        # missing from this copy of the file, so the setter below receives
        # whatever the previous callback returned. Restore that step from
        # upstream before relying on self._database.
        d.addCallback(setter, '_database')
        d.addCallback(defer.drop_param, self._store_instance_id)
        d.addCallback(defer.drop_param, self._load_configuration)
        d.addCallback(setter, '_configuration')
        d.addCallback(defer.drop_param,
                      self.join_shard, self._descriptor.shard)
        d.addCallback(defer.drop_param,
                      self._call_initiate, **kwargs)
        d.addCallback(defer.drop_param, self.call_next, self._call_startup,
                      call_startup=run_startup)
        d.addCallback(defer.override_result, self)

        # Ensure the execution chain is broken
        self.call_next(d.callback, None)

        return d

    def get_agent_id(self):
        '''Return the doc_id of the agent's descriptor.'''
        descriptor = self._descriptor
        return descriptor.doc_id

    def get_full_id(self):
        '''Return the descriptor's doc_id and instance_id joined by "/".'''
        descriptor = self._descriptor
        prefix = descriptor.doc_id + u"/"
        return prefix + unicode(descriptor.instance_id)

    def snapshot_agent(self):
        '''Return the tuple (agent, protocols), where protocols lists the
        agent-side object of every registered protocol.'''
        agent_sides = []
        for medium in self._protocols.values():
            agent_sides.append(medium.get_agent_side())
        return (self.agent, agent_sides)

    def journal_agent_created(self):
        # NOTE(review): the line opening the call that takes the arguments
        # below is missing from this copy (presumably a journal call on the
        # agency -- confirm against upstream). As shown, the method does not
        # parse.
        # The factory recorded is the class of the agent-side instance.
        factory = type(self.agent)
            self._descriptor.doc_id, self._instance_id,
            factory, self.snapshot())

    def check_if_should_snapshot(self, force=False):
        '''Take a journal snapshot when forced, or when more than
        MIN_ENTRIES_PER_SNAPSHOT entries were journaled since the last one.
        Otherwise only log that the snapshot is skipped.'''
        entries = self._entries_since_snapshot
        if force or entries > MIN_ENTRIES_PER_SNAPSHOT:
            # This branch used to log "Skipping snapshot" and never took one.
            self.journal_snapshot()
        else:
            self.log('Skipping snapshot, number of entries %d < %d',
                     entries, MIN_ENTRIES_PER_SNAPSHOT)

    def journal_snapshot(self):
        # NOTE(review): this copy seems to have lost lines here. Nothing
        # visible removes entries from a registry as the comment below
        # describes, and the call taking the trailing arguments has no
        # opening line. Restore from upstream; as shown, the method does not
        # parse.
        # Remove all the entries for the agent from the registry,
        # so that snapshot contains full objects not just the references
        agent_id = self._descriptor.doc_id
        # Restart the counter that check_if_should_snapshot() consults.
        self._entries_since_snapshot = 0
            agent_id, self._instance_id, self.snapshot_agent())

    def journal_protocol_created(self, *args, **kwargs):
        # NOTE(review): only the tail of a call survived in this copy; the
        # line naming the callee (and any leading arguments) is missing.
        # Restore from upstream; as shown, the method does not parse.
                                             *args, **kwargs)

    def start_agent(self, desc, **kwargs):
        '''Ask the hosting agency to start an agent for the descriptor desc,
        forwarding the keyword arguments unchanged.'''
        agency = self.agency
        return agency.start_agent(desc, **kwargs)

    def check_if_hosted(self, agent_id):
        '''Return the result of the agency's find_agent(agent_id).'''
        return self.agency.find_agent(agent_id)

    def on_killed(self):
        '''called as part of SIGTERM handler.'''

        def run_killed_hook():
            # call the agent's on_agent_killed() and wait for the protocols
            # to finish the job
            started = defer.succeed(None)
            started.addBoth(self._run_and_wait, self.agent.on_agent_killed)
            return started

        return self._terminate_procedure(run_killed_hook)

    ### IAgencyAgent Methods ###

    def observe(self, _method, *args, **kwargs):
        '''Create a common.Observer for the given call and return it.'''
        return common.Observer(_method, *args, **kwargs)

    def get_hostname(self):
        '''Return the hostname as reported by the hosting agency.'''
        hostname = self.agency.get_hostname()
        return hostname

    def get_ip(self):
        '''Return the IP as reported by the hosting agency.'''
        ip = self.agency.get_ip()
        return ip

    def get_descriptor(self):
        '''Return a deep copy of the descriptor held by the medium.'''
        descriptor_copy = copy.deepcopy(self._descriptor)
        return descriptor_copy

    def get_configuration(self):
        '''Return a deep copy of the loaded configuration document.

        Raises RuntimeError when no configuration has been stored on the
        medium (self._configuration is None).
        '''
        configuration = self._configuration
        if configuration is not None:
            return copy.deepcopy(configuration)

        missing_id = self.agent.configuration_doc_id
        raise RuntimeError(
            'Agent requested to get his configuration, but it was not '
            'found. The metadocument with ID %r is not in database. ' %
            (missing_id, ))

    def update_descriptor(self, function, *args, **kwargs):
        '''Queue a descriptor update and return a Deferred for its outcome.

        function is stored together with args and kwargs; _next_update()
        later calls it with a copy of the descriptor followed by those
        arguments.
        '''
        d = defer.Deferred()
        self._update_queue.append((d, function, args, kwargs))
        # Kick the queue: nothing else visible in the class starts
        # _next_update(), so without this the update would stay queued and
        # the returned Deferred would never fire. _next_update() does
        # nothing when an update is already in progress.
        self._next_update()
        return d

    def join_shard(self, shard):
        '''Bind the agent's own queue and all its interests to shard.

        Returns a DeferredList over the "created" attribute of every
        binding obtained.
        '''
        self.log("Joining shard %r", shard)
        # Rebind agents queue
        bindings = [self.create_binding(self._descriptor.doc_id, shard)]
        # Bind every interest, skipping the ones that yield no binding
        # (per the original comment: private interests return None)
        for interest in self._iter_interests():
            bound = interest.bind(shard)
            if bound:
                bindings.append(bound)
        return defer.DeferredList([b.created for b in bindings])

    def upgrade_agency(self, upgrade_cmd):
        '''Schedule self.agency.upgrade(upgrade_cmd) through call_next().'''
        upgrade = self.agency.upgrade
        self.call_next(upgrade, upgrade_cmd)

    def leave_shard(self, shard):
        '''Revoke every binding the messaging connection reports for shard
        and return a DeferredList of the revoke() results.'''
        self.log("Leaving shard %r", shard)
        revocations = []
        for binding in self._messaging.get_bindings(shard):
            revocations.append(binding.revoke())
        return defer.DeferredList(revocations)

    def register_interest(self, agent_factory, *args, **kwargs):
        '''Register an interest built from agent_factory.

        Returns the new interest, or False when an interest for the same
        protocol type and id is already registered. Raises TypeError when
        the factory's initiator does not implement IFirstMessage.
        '''
        agent_factory = IInterest(agent_factory)
        initiator = agent_factory.initiator
        if not IFirstMessage.implementedBy(initiator):
            raise TypeError(
                "%r.initiator expected to implemented IFirstMessage. Got %r" %
                (agent_factory, initiator, ))
        p_type = agent_factory.protocol_type
        p_id = agent_factory.protocol_id
        registered = self._interests.setdefault(p_type, dict())
        if p_id in registered:
            self.error('Already interested in %s.%s protocol', p_type, p_id)
            return False
        make_interest = IAgencyInterestInternalFactory(agent_factory)
        interest = make_interest(self, *args, **kwargs)
        registered[p_id] = interest
        self.debug('Registered interest in %s.%s protocol', p_type, p_id)
        return interest

    def revoke_interest(self, agent_factory):
        '''Forget a previously registered interest.

        Returns False (after logging an error) when no interest is
        registered for the factory's protocol type and id, True otherwise.
        '''
        agent_factory = IInterest(agent_factory)
        p_type = agent_factory.protocol_type
        p_id = agent_factory.protocol_id
        registered = self._interests.get(p_type)
        if not registered or p_id not in registered:
            self.error('Requested to revoke interest we are not interested in'
                       ' %s.%s', p_type, p_id)
            return False

        # Actually drop the interest. Success used to be reported while the
        # interest stayed registered and kept being offered messages.
        # NOTE(review): any teardown the interest object itself needs (for
        # example releasing its bindings) is not visible in this file --
        # confirm against upstream.
        del registered[p_id]
        return True

    def initiate_protocol(self, factory, *args, **kwargs):
        '''Initiate a protocol for factory; positional and keyword
        arguments are handed over to _initiate_protocol().'''
        started = self._initiate_protocol(factory, args, kwargs)
        return started

    def retrying_protocol(self, factory, recipients=None,
                          max_retries=None, initial_delay=1,
                          max_delay=None, args=None, kwargs=None):
        '''Initiate factory wrapped in a retrying.RetryingProtocolFactory.

        When recipients is given it is put in front of the positional
        arguments passed to the protocol.
        '''
        #FIXME: this is not needed in agency side API, could be in agent
        retrying_factory = retrying.RetryingProtocolFactory(
            factory, max_retries=max_retries,
            initial_delay=initial_delay, max_delay=max_delay)
        if recipients is not None:
            if args:
                args = (recipients, ) + args
            else:
                args = (recipients, )
        return self._initiate_protocol(retrying_factory, args, kwargs)

    def periodic_protocol(self, factory, period, *args, **kwargs):
        '''Initiate factory wrapped in a periodic.PeriodicProtocolFactory
        built with the given period.'''
        #FIXME: this is not needed in agency side API, could be in agent
        periodic_factory = periodic.PeriodicProtocolFactory(factory, period)
        return self._initiate_protocol(periodic_factory, args, kwargs)

    def initiate_task(self, *args, **kwargs):
        '''Deprecated alias of initiate_protocol().'''
        # The warn() call was left unterminated. It is closed here with the
        # DeprecationWarning category; stacklevel=2 attributes the warning
        # to the caller.
        warnings.warn("initiate_task() is deprecated, "
                      "please use initiate_protocol()",
                      DeprecationWarning, stacklevel=2)
        return self.initiate_protocol(*args, **kwargs)

    def retrying_task(self, *args, **kwargs):
        '''Deprecated alias of retrying_protocol().'''
        # The warn() call was left unterminated. It is closed here with the
        # DeprecationWarning category; stacklevel=2 attributes the warning
        # to the caller.
        warnings.warn("retrying_task() is deprecated, "
                      "please use retrying_protocol()",
                      DeprecationWarning, stacklevel=2)
        return self.retrying_protocol(*args, **kwargs)

    def save_document(self, document):
        '''Save document through the medium's database connection.'''
        database = self._database
        return database.save_document(document)

    def get_document(self, document_id):
        '''Fetch the document with document_id through the medium's
        database connection.'''
        database = self._database
        return database.get_document(document_id)

    def reload_document(self, document):
        '''Reload document through the medium's database connection.'''
        database = self._database
        return database.reload_document(document)

    def delete_document(self, document):
        '''Delete document through the medium's database connection.'''
        database = self._database
        return database.delete_document(document)

    def query_view(self, factory, **options):
        '''Query the view given by factory through the medium's database
        connection, forwarding the options unchanged.'''
        database = self._database
        return database.query_view(factory, **options)

    def terminate(self):
        '''Schedule the gentle shutdown implemented by _terminate().'''
        # The body was missing. The call goes through call_next(), the same
        # way upgrade_agency() schedules its work, so the shutdown starts
        # outside of the caller's stack.
        self.call_next(self._terminate)

    # get_mode() comes from dependency.AgencyAgentDependencyMixin

    def call_next(self, method, *args, **kwargs):
        '''Schedule method with a delay of 0; returns the call id.'''
        no_delay = 0
        return self.call_later_ex(no_delay, method, args, kwargs)

    def call_later(self, time_left, method, *args, **kwargs):
        '''Schedule method after time_left; returns the call id.'''
        call_id = self.call_later_ex(time_left, method, args, kwargs)
        return call_id

    def call_later_ex(self, time_left, method,
                      args=None, kwargs=None, busy=True):
        '''Schedule self._call(method, *args, **kwargs) after time_left.

        The delayed call is stored under a freshly generated uuid1 string
        together with the busy flag; that string is returned.
        '''
        positional = args if args else []
        keywords = kwargs if kwargs else {}
        delayed = time.callLater(time_left, self._call, method,
                                 *positional, **keywords)
        call_id = str(uuid.uuid1())
        self._store_delayed_call(call_id, delayed, busy)
        return call_id

    def cancel_delayed_call(self, call_id):
        '''Cancel the delayed call stored under call_id.

        An unknown call_id only produces a warning; a call that is no
        longer active is only logged.
        '''
        try:
            _busy, call = self._delayed_calls.remove(call_id)
        except KeyError:
            self.warning('Tried to cancel nonexisting call id: %r', call_id)
            # without a stored entry there is nothing to cancel
            return

        self.log('Canceling delayed call with id %r (active: %s)',
                 call_id, call.active())
        if not call.active():
            self.log('Tried to cancel nonactive call id: %r', call_id)
            return

        call.cancel()


    def get_machine_state(self):
        '''Public accessor for _get_machine_state().'''
        state = self._get_machine_state()
        return state

    ### ITimeProvider Methods ###

    def get_time(self):
        '''Return the time as reported by the hosting agency.'''
        now = self.agency.get_time()
        return now

    ### IRecorderNode Methods ###

    def generate_identifier(self, recorder):
        '''Return the (doc_id, instance_id) identifier of this medium.

        Meant to be called once per medium; a second call trips the
        assertion. The recorder argument is not used.
        '''
        # The guard used to read a misspelled attribute name that is never
        # set, so it could not fire. It now reads the flag set below.
        assert not getattr(self, '_identifier_generated', False)
        self._identifier_generated = True
        return (self._descriptor.doc_id, self._instance_id, )

    ### IJournalKeeper Methods ###

    def register(self, recorder):
        '''NOTE(review): the body of this method is missing from this copy
        of the file. It sits in the IJournalKeeper section, so it presumably
        registers the recorder with the journal -- restore the
        implementation from upstream.'''

    def new_entry(self, journal_id, function_id, *args, **kwargs):
        '''Count one more journal entry since the last snapshot and pass
        the entry to the agency's journal_new_entry().'''
        self._entries_since_snapshot += 1
        agent_id = self._descriptor.doc_id
        return self.agency.journal_new_entry(agent_id, journal_id,
                                             function_id, *args, **kwargs)

    ### ISerializable Methods ###

    def snapshot(self):
        '''Return the (doc_id, instance_id) pair identifying this medium.'''
        descriptor = self._descriptor
        return (descriptor.doc_id, self._instance_id)

    ### IAgencyAgentInternal Methods ###

    def create_binding(self, key, shard=None):
        '''Used by Interest instances. Returns the personal binding the
        messaging connection creates for key and shard.'''
        messaging = self._messaging
        return messaging.personal_binding(key, shard)

    def register_protocol(self, protocol):
        '''Store protocol (adapted to IAgencyProtocolInternal) under its
        guid and return the adapted object. The guid must not be
        registered yet.'''
        medium = IAgencyProtocolInternal(protocol)
        guid = medium.guid
        self.log('Registering protocol guid: %r', guid)
        assert guid not in self._protocols
        self._protocols[guid] = medium
        return medium

    def unregister_protocol(self, protocol):
        # NOTE(review): this copy appears to have lost two lines: the opening
        # of the call whose arguments sit on the two deeper-indented lines
        # below, and most likely an "else:" in front of self.error() (the
        # message only makes sense when the guid was not found). Restore
        # from upstream; as shown, the method does not parse.
        if protocol.guid in self._protocols:
            self.log('Unregistering protocol guid: %r', protocol.guid)
            # continue with the registered object, not the one passed in
            protocol = self._protocols[protocol.guid]
                self._descriptor.doc_id, self._instance_id,
                protocol.get_agent_side(), protocol.snapshot())
            del self._protocols[protocol.guid]
            self.error('Tried to unregister protocol with guid: %r, '
                        'but not found!', protocol.guid)

    def send_msg(self, recipients, msg, handover=False):
        '''Publish msg to every recipient and return it.

        Unless handover is set, the message gets this medium as reply_to
        and a fresh uuid1 string as message_id. The message must already
        carry an expiration_time.
        '''
        targets = recipient.IRecipients(recipients)
        if not handover:
            msg.reply_to = recipient.IRecipient(self)
            msg.message_id = str(uuid.uuid1())
        assert msg.expiration_time is not None
        for target in targets:
            self.log('Sending message to %r', target)
            self._messaging.publish(target.key, target.shard, msg)
        return msg

    ### IMessagingPeer Methods ###

    def on_message(self, msg):
        '''
        When a message with an already known traversal_id is received,
        we try to build a duplication message and send it in to a protocol
        dependent recipient. This is used in contracts traversing
        the graph, when the contract has rereached the same shard.
        This message is necessary, as silently ignoring the incoming bids
        adds a lot of latency to the nested contracts (it is waiting to
        receive message from all the recipients).

        Returns True when the message was matched to a registered protocol
        or accepted by an interest, False when it was thrown away.
        '''
        self.log('Received message: %r', msg)

        # Check if it isn't expired message
        time_left = time.left(msg.expiration_time)
        if time_left < 0:
            self.log('Throwing away expired message. Time left: %s, '
                     'msg_class: %r', time_left, msg.get_msg_class())
            return False

        # Check for known traversal ids:
        if IFirstMessage.providedBy(msg):
            t_id = msg.traversal_id
            if t_id is None:
                # The opening of this call was missing; warning level is
                # assumed -- confirm against upstream.
                self.warning(
                    "Received corrupted message. The traversal_id is None ! "
                    "Message: %r", msg)
                return False
            if t_id in self._traversal_ids:
                # arguments now follow the order of the placeholders
                self.log('Throwing away already known traversal id %r, '
                         'msg_class: %r', t_id, msg.get_msg_class())
                recp = msg.duplication_recipient()
                if recp:
                    resp = msg.duplication_message()
                    self.send_msg(recp, resp)
                return False
            else:
                # remember the id until the message expires
                self._traversal_ids.set(t_id, True, msg.expiration_time)

        # Handle registered dialog
        if IDialogMessage.providedBy(msg):
            recv_id = msg.receiver_id
            if recv_id is not None and recv_id in self._protocols:
                protocol = self._protocols[recv_id]
                # NOTE(review): the protocol is looked up but the message is
                # never handed to it -- the delivery call appears to be
                # missing from this copy of the file. Restore it from
                # upstream before relying on this path.
                return True

        # Handle new conversation coming in (interest)
        p_type = msg.protocol_type
        if p_type in self._interests:
            p_id = msg.protocol_id
            interest = self._interests[p_type].get(p_id)
            if interest and interest.schedule_message(msg):
                return True

        self.warning("Couldn't find appropriate protocol for message: "
                     "%s", msg.get_msg_class())
        return False

    def get_queue_name(self):
        '''Return the queue name, which is the descriptor's doc_id.'''
        descriptor = self._descriptor
        return descriptor.doc_id

    def get_shard_name(self):
        '''Return the shard stored in the descriptor.'''
        descriptor = self._descriptor
        return descriptor.shard

    ### Introspection Methods ###

    def get_agent(self):
        '''Return the agent-side instance created by the factory.'''
        agent = self.agent
        return agent

    def list_partners(self):
        '''Render a text table with class name, recipient key, recipient
        shard and role of every partner returned by the agent's
        query_partners('all').'''
        table = text_helper.Table(fields=["Partner", "Id", "Shard", "Role"],
                                  lengths=[20, 35, 35, 10])

        partners = self.agent.query_partners('all')
        rows = ((type(partner).__name__, partner.recipient.key,
                 partner.recipient.shard, partner.role)
                for partner in partners)
        return table.render(rows)

    def list_resource(self):
        '''Render a text table with one row per name in the totals
        returned by the agent's list_resource(), showing the total and the
        allocated value.'''
        table = text_helper.Table(fields=["Name", "Totals", "Allocated"],
                                  lengths=[20, 20, 20])
        totals, allocated = self.agent.list_resource()

        def rows():
            for name in totals:
                yield name, totals[name], allocated[name]

        return table.render(rows())

    ### Protected Methods ###

    def wait_for_protocols_finish(self):
        '''Used by tests. Returns a DeferredList covering every interest's
        wait_finished() and every registered protocol's notify_finish().'''

        def finished_ignoring_failure(medium):
            # a ProtocolFailed failure is trapped; anything else propagates
            finished = medium.notify_finish()
            finished.addErrback(Failure.trap, ProtocolFailed)
            return finished

        pending = [interest.wait_finished()
                   for interest in self._iter_interests()]
        for medium in self._protocols.itervalues():
            pending.append(finished_ignoring_failure(medium))
        return defer.DeferredList(pending)

    def is_idle(self):
        '''Tell whether the medium is ready and has no active protocols,
        interests, busy delayed calls or long running protocols.'''
        settled = (self.is_ready()
                   and self.has_empty_protocols()
                   and self.has_all_interests_idle())
        if not settled:
            return settled
        if self.has_busy_calls():
            return False
        return self.has_all_long_running_protocols_idle()

    def is_ready(self):
        '''Compare the current state with AgencyAgentState.ready.'''
        ready_state = AgencyAgentState.ready
        return self._cmp_state(ready_state)

    def has_empty_protocols(self):
        '''Return True when every registered protocol reports is_idle().'''
        active = [medium for medium in self._protocols.itervalues()
                  if not medium.is_idle()]
        return not active

    def has_busy_calls(self):
        '''Return True when a stored delayed call is flagged busy and is
        still active.'''
        return any(busy and call.active()
                   for busy, call in self._delayed_calls.itervalues())

    def has_all_interests_idle(self):
        '''Return True when every registered interest reports is_idle().'''
        for interest in self._iter_interests():
            if not interest.is_idle():
                return False
        return True

    def has_all_long_running_protocols_idle(self):
        '''Return True when every long running protocol reports
        is_idle().'''
        for medium in self._long_running_protocols:
            if not medium.is_idle():
                return False
        return True

    def show_activity(self):
        '''Return a text report of what keeps the medium busy, or None
        when it is idle.

        The report names the agent class, its id and the machine state,
        followed by tables of registered protocols, busy delayed calls and
        interests, each printed only when the matching check is not idle.
        '''
        if self.is_idle():
            return None
        # The format takes three values. The third one (the state) and the
        # closing parenthesis were missing.
        resp = "\n%r id: %r\n state: %r" % \
               (self.agent.__class__.__name__, self.get_descriptor().doc_id,
                self._get_machine_state())
        if not self.has_empty_protocols():
            resp += '\nprotocols: \n'
            t = text_helper.Table(fields=["Class"], lengths = [60])
            resp += t.render((i.get_agent_side().__class__.__name__, ) \
                             for i in self._protocols.itervalues())
        if self.has_busy_calls():
            resp += "\nbusy calls: \n"
            t = text_helper.Table(fields=["Call"], lengths = [60])
            resp += t.render((str(call), ) \
                             for busy, call in self._delayed_calls.itervalues()
                             if busy and call.active())

        if not self.has_all_interests_idle():
            resp += "\nInterests not idle: \n"
            t = text_helper.Table(fields=["Factory"], lengths = [60])
            resp += t.render((str(call.agent_factory), ) \
                             for call in self._iter_interests())
        resp += "#" * 60
        return resp

    def on_disconnect(self):
        '''Move a ready medium into the disconnected state.'''
        if self._cmp_state(AgencyAgentState.ready):
            # The body of this branch was missing. on_reconnect() checks
            # for the disconnected state, so this is where it is entered.
            # NOTE(review): any agent-side notification that belongs here
            # is not visible in this file -- confirm against upstream.
            self._set_state(AgencyAgentState.disconnected)

    def on_reconnect(self):
        '''Move a disconnected medium back into the ready state.'''
        if self._cmp_state(AgencyAgentState.disconnected):
            # The body of this branch was missing. on_disconnect() only
            # acts on the ready state, so reconnecting returns to it.
            # NOTE(review): any agent-side notification that belongs here
            # is not visible in this file -- confirm against upstream.
            self._set_state(AgencyAgentState.ready)

    ### Private Methods ###

    def _initiate_protocol(self, factory, args, kwargs):
        '''Build the agency-side medium for factory and initiate it.

        args and kwargs may be None. Returns whatever the medium's
        initiate() returns.
        '''
        self.log('Initiating protocol for factory: %r, args: %r, kwargs: %r',
                 factory, args, kwargs)
        args = args or ()
        kwargs = kwargs or {}
        factory = IInitiatorFactory(factory)
        medium_factory = IAgencyInitiatorFactory(factory)
        medium = medium_factory(self, *args, **kwargs)
        if ILongRunningProtocol.providedBy(medium):
            # Track the protocol while it runs. Nothing ever added to the
            # list before, and the removal callback was never attached, so
            # the idle check and the cancellation on termination saw an
            # empty list.
            self._long_running_protocols.append(medium)
            cb = lambda _: self._long_running_protocols.remove(medium)
            # NOTE(review): assumes the medium offers notify_finish() like
            # the registered protocols do -- confirm against upstream.
            medium.notify_finish().addBoth(cb)
        return medium.initiate()

    def _subscribe_for_descriptor_changes(self):
        '''Register _descriptor_changed as the changes listener for the
        descriptor's doc_id on the database connection.'''
        watched_ids = (self._descriptor.doc_id, )
        return self._database.changes_listener(
            watched_ids, self._descriptor_changed)

    def _descriptor_changed(self, doc_id, rev):
        '''Log a warning and terminate hard. Passed to the database
        changes listener by _subscribe_for_descriptor_changes(); doc_id and
        rev are not used.'''
        message = ('Received the notification about other database session '
                   'changing our descriptor. This means that I got '
                   'restarted on some other machine and need to commit '
                   'suacide :(. Or you have a bug ;).')
        self.warning(message)
        return self.terminate_hard()

    def _reload_descriptor(self):
        '''Reload the descriptor from the database and keep the reloaded
        document as self._descriptor. Returns the Deferred of the reload.'''

        def setter(value):
            self._descriptor = value

        d = self.reload_document(self._descriptor)
        # The setter was defined but never attached, so the reloaded
        # document was dropped.
        d.addCallback(setter)
        return d

    def _store_instance_id(self):
        '''
        Run at the initialization before calling any code at agent-side.
        Ensures that descriptor holds our value, this effectively creates a
        lock on the descriptor - if other instance is running somewhere out
        there it would get the notification update and suicide.
        '''

        def do_set(desc):
            desc.instance_id = self._instance_id
            desc.under_restart = None

        return self.update_descriptor(do_set)

    def _load_configuration(self):
        '''Fetch the agent's configuration document.

        A failed fetch is only logged as a warning; the errback returns
        nothing, so the Deferred then continues with None.
        '''
        doc_id = self.agent.configuration_doc_id

        def log_missing(fail, missing_id):
            self.warning('Agents configuration not found in database. '
                         'Expected doc_id: %r', missing_id)

        fetched = self.get_document(doc_id)
        fetched.addErrback(log_missing, doc_id)
        return fetched

    def _next_update(self):
        '''Process the next queued descriptor update, unless one is running.

        Takes one (deferred, function, args, kwargs) entry off
        self._update_queue, applies the function to a copy of the
        descriptor and saves that copy. The entry's Deferred fires with the
        function's result once the save succeeds, and errbacks when the
        function raises or the save fails. The next entry is scheduled
        afterwards.
        '''

        def saved(desc, result, d):
            self.log("Updating descriptor: %r", desc)
            self._descriptor = desc
            # hand the function's result to the caller of update_descriptor
            d.callback(result)

        def error_handler(failure, d):
            if failure.check(ConflictError):
                self.warning('Descriptor update conflict, killing the agent.')
                # NOTE(review): the kill itself was missing from this copy;
                # terminate_hard() is assumed -- confirm against upstream.
                self.call_next(self.terminate_hard)
            else:
                self.error("Failed updating descriptor: %s",
                           failure.getErrorMessage())
            d.errback(failure)

        def next_update(any=None):
            self._updating = False
            # continue with the rest of the queue outside of this stack
            self.call_next(self._next_update)
            return any

        if self._updating:
            # Currently updating descriptor
            return

        if not self._update_queue:
            # No more pending updates
            return

        d, fun, args, kwargs = self._update_queue.pop(0)
        self._updating = True
        try:
            desc = self.get_descriptor()
            result = fun(desc, *args, **kwargs)
            assert not isinstance(result, (defer.Deferred, fiber.Fiber))
            save_d = self.save_document(desc)
            save_d.addCallbacks(callback=saved, callbackArgs=(result, d),
                                errback=error_handler, errbackArgs=(d, ))
            save_d.addBoth(next_update)
        except Exception as e:
            # the update function (or starting the save) raised: report it
            # to the caller and release the queue
            d.errback(e)
            next_update()

    def _terminate_procedure(self, body):
        '''Shared shutdown sequence used by every termination variant.

        body is a callable holding the steps specific to the kind of
        shutdown; it runs after interests are revoked, delayed calls are
        cancelled and protocols are killed. Returns the Deferred of the
        whole sequence, or None when termination is already in progress.
        '''
        assert callable(body)

        if self._cmp_state(AgencyAgentState.terminating):
            # already on the way out, do not run the procedure twice
            return
        self._set_state(AgencyAgentState.terminating)

        # Revoke all interests (iterate over a copy, the registry changes)
        for i in list(self._iter_interests()):
            self.revoke_interest(i.agent_factory)

        d = defer.succeed(None)

        # Cancel all long running protocols
        d.addBoth(defer.drop_param, self._cancel_long_running_protocols)
        # Cancel all delayed calls
        d.addBoth(defer.drop_param, self._cancel_all_delayed_calls)
        # Kill all protocols
        d.addBoth(self._kill_all_protocols)
        # Again, just in case
        d.addBoth(defer.drop_param, self._cancel_all_delayed_calls)
        # Run code specific to the given shutdown
        d.addBoth(defer.drop_param, body)
        # Tell the agency we are no more
        d.addBoth(defer.drop_param, self._unregister_from_agency)
        # Close the messaging connection
        d.addBoth(defer.drop_param, self._messaging.disconnect)
        # Close the database connection
        d.addBoth(defer.drop_param, self._database.disconnect)
        d.addBoth(defer.drop_param,
                  self._set_state, AgencyAgentState.terminated)
        return d

    def _handle_failure(self, failure):
        '''Report failure through error.handle_failure().'''
        message = "Failure during termination"
        error.handle_failure(self, failure, message)

    def _unregister_from_agency(self):
        '''NOTE(review): the body of this method is missing from this copy
        of the file. _terminate_procedure() runs it as the "tell the agency
        we are no more" step -- restore the implementation from upstream.'''

    def _cancel_long_running_protocols(self):
        '''Call cancel() on every long running protocol and return a
        DeferredList of the results.'''
        cancellations = []
        for medium in self._long_running_protocols:
            cancellations.append(defer.maybeDeferred(medium.cancel))
        return defer.DeferredList(cancellations)

    def terminate_hard(self):
        '''Kill the agent without notifying anybody.'''

        def run_killed_hook():
            # call the agent's on_agent_killed() and wait for the protocols
            # to finish the job
            started = defer.succeed(None)
            started.addBoth(self._run_and_wait, self.agent.on_agent_killed)
            return started

        return self._terminate_procedure(run_killed_hook)

    def _terminate(self):
        '''Shut the agent down gently: run the agent's shutdown_agent(),
        wait for the protocols, then delete the descriptor document.'''

        def delete_descriptor(_):
            return self.delete_document(self._descriptor)

        def shutdown_and_delete():
            started = defer.succeed(None)
            # Run the agent's shutdown_agent() and wait for
            # the protocols to finish the job
            started.addBoth(self._run_and_wait, self.agent.shutdown_agent)
            # Delete the descriptor
            started.addBoth(delete_descriptor)
            # TODO: delete the queue
            return started

        return self._terminate_procedure(shutdown_and_delete)

    def _run_and_wait(self, _, method, *args, **kwargs):
        '''
        Run an agent-side method and wait for all the protocols
        to finish processing.

        The first argument (the previous callback result) is ignored.
        '''
        d = defer.maybeDeferred(method, *args, **kwargs)
        d.addBoth(defer.drop_param, self.wait_for_protocols_finish)
        return d

    def _iter_interests(self):
        '''Return a generator over every registered interest, across all
        protocol types.'''
        by_type = self._interests.itervalues()
        return (entry
                for by_id in by_type
                for entry in by_id.itervalues())

    def _kill_all_protocols(self, *_):
        '''Run cleanup() on every registered protocol and return a
        DeferredList of the results. ProtocolFailed failures are trapped.
        Positional arguments are ignored.'''

        def cleanup_protocol(medium):
            started = defer.succeed(None)
            started.addCallback(defer.drop_param, medium.cleanup)
            started.addErrback(Failure.trap, ProtocolFailed)
            return started

        cleanups = []
        for medium in self._protocols.values():
            cleanups.append(cleanup_protocol(medium))
        return defer.DeferredList(cleanups)

    def _call_initiate(self, **kwargs):
        '''Run the agent-side initiate_agent(**kwargs), then change the
        medium's state.

        NOTE(review): the addCallback() call below is cut off -- the line
        carrying the target state and the closing parenthesis is missing
        from this copy. Restore from upstream; as shown, the method does
        not parse.
        '''
        d = defer.maybeDeferred(self.agent.initiate_agent, **kwargs)
        d.addCallback(fiber.drop_param, self._set_state,
        return d

    def _call_startup(self, call_startup=True):
        '''Optionally run the agent's startup_agent(), then
        _become_ready(). Returns the Deferred of that chain.'''
        chain = defer.succeed(None)
        if call_startup:
            chain.addCallback(defer.drop_param, self.agent.startup_agent)
        chain.addCallback(fiber.drop_param, self._become_ready)
        return chain

    def _become_ready(self):
        '''Switch the medium into the ready state.'''
        # The body was missing. is_ready() compares the state with
        # AgencyAgentState.ready, so that is the state entered here.
        self._set_state(AgencyAgentState.ready)

    def _startup_error(self, fail):
        '''Report a failure raised while the agent was starting up and
        terminate the agent, as the logged message announces.'''
        self.error("Agent raised an error while starting up. "
                   "He will be punished by terminating. Medium state while "
                   "that happend: %r", self._get_machine_state())
        # The message promised a termination that never happened.
        # NOTE(review): the failure itself is not logged here; confirm
        # against upstream whether it should be.
        self.terminate()

    def _store_delayed_call(self, call_id, call, busy):
        '''Remember an active delayed call under call_id together with its
        busy flag; the third argument passed to the container is the
        call's time plus 1. Inactive calls are not stored.'''
        if not call.active():
            return
        self.log('Storing delayed call with id %r', call_id)
        entry = (busy, call)
        self._delayed_calls.set(call_id, entry, call.getTime() + 1)

    def _cancel_all_delayed_calls(self):
        '''Cancel every stored delayed call that is still active.'''
        for call_id, (_busy, call) in self._delayed_calls.iteritems():
            self.log('Canceling delayed call with id %r (active: %s)',
                     call_id, call.active())
            if call.active():
                # the body of this branch was missing
                call.cancel()
        # NOTE(review): the entries stay in the container after this loop;
        # confirm against upstream whether they should be removed here.

    def _call(self, method, *args, **kwargs):
        '''Run method for a delayed call and return the resulting Deferred.

        A method returning a Fiber is treated as an error: the Deferred
        then fails with RuntimeError.
        '''

        def raise_on_fiber(res):
            if isinstance(res, fiber.Fiber):
                raise RuntimeError("We are not expecting method %r to "
                                   "return a Fiber, which it did!" % method)
            return res

        self.log('Calling method %r, with args: %r, kwargs: %r', method,
                 args, kwargs)
        d = defer.maybeDeferred(method, *args, **kwargs)
        # The check above was defined but never attached to the Deferred.
        d.addCallback(raise_on_fiber)
        # NOTE(review): no errback is attached here, so failures stay on
        # the returned Deferred; the class-level _error_handler may have
        # been meant for this -- confirm against upstream.
        return d