Beispiel #1
0
class Context(object):
    """Session-wide state of the CLI.

    Holds the server connection, the loaded plugins, the variable store, the
    root namespace and the event routing state. Constructing an instance also
    registers it as ``config.instance``.
    """

    # Task states after which the task will not change any more.
    _TERMINAL_TASK_STATES = ('FINISHED', 'CANCELLED', 'ABORTED', 'FAILED')

    def __init__(self):
        self.hostname = None
        self.connection = Client()
        # Object driving the interactive prompt; this class only uses its
        # blank_readline()/restore_readline() methods and skip_prompt_print flag.
        self.ml = None
        self.logger = logging.getLogger('cli')
        self.plugin_dirs = []
        # task id -> callable invoked with the final task state
        self.task_callbacks = {}
        # plugin file path -> loaded plugin module
        self.plugins = {}
        self.variables = VariableStore()
        self.root_ns = RootNamespace('')
        self.event_masks = ['*']
        # While True, incoming events are queued instead of printed.
        self.event_divert = False
        self.event_queue = six.moves.queue.Queue()
        self.keepalive_timer = None
        self.argparse_parser = None
        config.instance = self

    @property
    def is_interactive(self):
        """True when standard output is attached to a terminal."""
        return os.isatty(sys.stdout.fileno())

    def start(self):
        """Load plugins from the configured directories, then connect."""
        self.discover_plugins()
        self.connect()

    def connect(self):
        """Connect to ``self.hostname``; on a socket error report it and exit.

        Raises:
            SystemExit: with status 1 when the connection attempt fails.
        """
        try:
            self.connection.connect(self.hostname)
        except socket_error as err:
            # Translate the template first, then substitute the values, so the
            # message catalog is looked up with the untranslated template.
            output_msg(
                _("Could not connect to host: {0} due to error: {1}").format(self.hostname, err)
            )
            # The parser is attached from outside and may still be unset.
            if self.argparse_parser is not None:
                self.argparse_parser.print_help()
            sys.exit(1)

    def login(self, user, password):
        """Authenticate, subscribe to events and let plugins react to login.

        Exits the process with status 1 when the server answers with EACCES.
        """
        try:
            self.connection.login_user(user, password)
            self.connection.subscribe_events(*EVENT_MASKS)
            self.connection.on_event(self.handle_event)
            self.connection.on_error(self.connection_error)

        except RpcException as e:
            if e.code == errno.EACCES:
                self.connection.disconnect()
                output_msg(_("Wrong username or password"))
                sys.exit(1)
            # NOTE(review): RPC errors with any other code are ignored here and
            # the plugin login hooks still run - confirm this is intended.

        self.login_plugins()

    def keepalive(self):
        """Ping the server if the connection is open."""
        if self.connection.opened:
            self.connection.call_sync('management.ping')

    def read_middleware_config_file(self, file):
        """Extend ``self.plugin_dirs`` from a JSON config file plus the default.

        When ``file`` is given and its JSON contains ``cli['plugin-dirs']`` as
        a list, those directories are appended. The ``plugins`` directory next
        to this module is then appended as well. If ``cli['plugin-dirs']`` is
        present but is not a list, the method returns without adding anything.
        """
        if file:
            with open(file, 'r') as f:
                data = json.load(f)

            if 'cli' in data and 'plugin-dirs' in data['cli']:
                configured_dirs = data['cli']['plugin-dirs']
                if not isinstance(configured_dirs, list):
                    return

                self.plugin_dirs += configured_dirs

        default_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), 'plugins'
        )
        self.plugin_dirs.append(default_dir)

    def discover_plugins(self):
        """Try to load every ``*.py`` file found in ``self.plugin_dirs``."""
        for plugin_dir in self.plugin_dirs:
            self.logger.debug(_("Searching for plugins in %s"), plugin_dir)
            self.__discover_plugin_dir(plugin_dir)

    def login_plugins(self):
        """Call ``_login(self)`` on every loaded plugin that defines it."""
        # Iterate over a snapshot of the loaded plugins.
        for plugin in list(self.plugins.values()):
            if hasattr(plugin, '_login'):
                plugin._login(self)

    def __discover_plugin_dir(self, dir):
        """Attempt to load each ``*.py`` file directly inside ``dir``."""
        for filename in glob.glob1(dir, "*.py"):
            self.__try_load_plugin(os.path.join(dir, filename))

    def __try_load_plugin(self, path):
        """Load the module at ``path`` and register it if it defines ``_init``."""
        if path in self.plugins:
            return

        self.logger.debug(_("Loading plugin from %s"), path)
        name, ext = os.path.splitext(os.path.basename(path))
        # NOTE(review): the ``imp`` module is deprecated in Python 3 and no
        # longer available in the newest releases; migrate when py2 is dropped.
        plugin = imp.load_source(name, path)

        if hasattr(plugin, '_init'):
            plugin._init(self)
            self.plugins[path] = plugin

    def __try_reconnect(self):
        """Reconnect and re-authenticate after the connection was lost.

        Retries every two seconds for as long as an ``Exception`` is raised.
        Exits the process when re-authentication is rejected; ``connect()``
        itself also exits on a socket error.
        """
        output_lock.acquire()
        try:
            self.ml.blank_readline()
            try:
                output_msg(_('Connection lost! Trying to reconnect...'))
                while True:
                    try:
                        time.sleep(2)
                        self.connect()
                        try:
                            if self.hostname == '127.0.0.1':
                                self.connection.login_user(getpass.getuser(), '')
                            else:
                                self.connection.login_token(self.connection.token)

                            self.connection.subscribe_events(*EVENT_MASKS)
                        except RpcException:
                            output_msg(_("Reauthentication failed (most likely token expired or server was restarted)"))
                            sys.exit(1)
                        break
                    except Exception as e:
                        output_msg(_('Cannot reconnect: {0}').format(str(e)))
            finally:
                self.ml.restore_readline()
        finally:
            # Released even when sys.exit() unwinds through this frame, so
            # other users of the output lock are not left blocked.
            output_lock.release()

    def attach_namespace(self, path, ns):
        """Register ``ns`` under the namespace addressed by ``path``.

        ``path`` is split on '/'; the components between the first and the
        last one are looked up before ``ns`` is registered on the result.
        """
        splitpath = path.split('/')
        ptr = self.root_ns
        ptr_namespaces = ptr.namespaces()

        for n in splitpath[1:-1]:
            # NOTE(review): ``ptr_namespaces`` already holds the result of
            # ``ptr.namespaces()`` yet is called again here, and it is never
            # refreshed after ``ptr`` moves down a level. Confirm what
            # ``namespaces()`` returns; this only runs for nested paths.
            if n not in list(ptr_namespaces().keys()):
                self.logger.warning(_("Cannot attach to namespace %s"), path)
                return

            ptr = ptr_namespaces()[n]

        ptr.register_namespace(ns)

    def connection_error(self, event, **kwargs):
        """Connection error hook: exit on logout, reconnect on a closed link."""
        if event == ClientError.LOGOUT:
            output_msg('Logged out from server.')
            self.connection.disconnect()
            sys.exit(0)

        if event == ClientError.CONNECTION_CLOSED:
            time.sleep(1)
            self.__try_reconnect()
            return

    def handle_event(self, event, data):
        """Dispatch an incoming event to task callbacks and to the printer."""
        if event == 'task.updated':
            if data['id'] in self.task_callbacks:
                self.handle_task_callback(data)

        self.print_event(event, data)

    def handle_task_callback(self, data):
        """Invoke the registered callback once the task reached a final state."""
        if data['state'] in self._TERMINAL_TASK_STATES:
            self.task_callbacks[data['id']](data['state'])

    def print_event(self, event, data):
        """Print a translated event, or queue it while events are diverted.

        ``task.progress`` events are never printed. For events reporting the
        ``FAILED`` state the task status is fetched and its error message is
        printed as well.
        """
        if self.event_divert:
            self.event_queue.put((event, data))
            return

        if event == 'task.progress':
            return

        output_lock.acquire()
        try:
            self.ml.blank_readline()
            try:
                translation = events.translate(self, event, data)
                if translation:
                    output_msg(translation)
                    if 'state' in data and data['state'] == 'FAILED':
                        status = self.connection.call_sync('task.status', data['id'])
                        error = status.get('error')
                        output_msg(_("Task #{0} error: {1}").format(
                            data['id'],
                            status['error'].get('message', '') if error else ''
                        ))

                sys.stdout.flush()
            finally:
                self.ml.restore_readline()
        finally:
            # The RPC call above may raise; never leave the lock held.
            output_lock.release()

    def call_sync(self, name, *args, **kwargs):
        """Call RPC method ``name`` synchronously and return the wrapped result."""
        return wrap(self.connection.call_sync(name, *args, **kwargs))

    def call_task_sync(self, name, *args, **kwargs):
        """Run task ``name`` synchronously and return the wrapped result.

        Prompt printing is suppressed for the duration of the call and is
        re-enabled afterwards, also when the call raises.
        """
        # NOTE(review): ``kwargs`` is accepted but not forwarded to the
        # connection, as before - confirm whether it should be.
        self.ml.skip_prompt_print = True
        try:
            result = self.connection.call_task_sync(name, *args)
        finally:
            self.ml.skip_prompt_print = False

        return wrap(result)

    def submit_task(self, name, *args, **kwargs):
        """Submit task ``name`` with ``args`` and return its task id.

        Keyword arguments:
            callback: callable invoked with the final state of the task; only
                registered when the ``tasks_blocking`` variable is off.
            message_formatter: callable applied to each progress message; only
                used when ``tasks_blocking`` is on.

        With ``tasks_blocking`` on, events are diverted into the event queue
        and a progress bar is shown until the task reaches a final state or
        the user presses Ctrl+C, which sends ``task.abort``.
        """
        callback = kwargs.pop('callback', None)
        message_formatter = kwargs.pop('message_formatter', None)

        if not self.variables.get('tasks_blocking'):
            tid = self.connection.call_sync('task.submit', name, args)
            if callback:
                self.task_callbacks[tid] = callback

            return tid

        output_msg(_("Hit Ctrl+C to terminate task if needed"))
        self.event_divert = True
        try:
            tid = self.connection.call_sync('task.submit', name, args)
            progress = ProgressBar()
            try:
                while True:
                    event, data = self.event_queue.get()

                    if event == 'task.progress' and data['id'] == tid:
                        message = data['message']
                        if callable(message_formatter):
                            message = message_formatter(message)
                        progress.update(percentage=data['percentage'], message=message)

                    if event == 'task.updated' and data['id'] == tid:
                        progress.update(message=data['state'])
                        if data['state'] == 'FINISHED':
                            progress.finish()
                            break

                        # Any other final state ends the wait as well, so a
                        # task cancelled or aborted elsewhere does not block
                        # the prompt forever.
                        if data['state'] in self._TERMINAL_TASK_STATES:
                            print()
                            break
            except KeyboardInterrupt:
                print()
                output_msg(_("User requested task termination. Task abort signal sent"))
                self.call_sync('task.abort', tid)
        finally:
            # Always resume normal event printing, even if submission failed.
            self.event_divert = False

        return tid
Beispiel #2
0
class BaseTestCase(unittest.TestCase):
    """Base class for tests that submit tasks over a ``Client`` connection.

    Each test gets a fresh connection in ``setUp``. Tasks submitted through
    ``submitTask`` are tracked in ``self.tasks`` and updated from the
    ``task.updated`` / ``task.progress`` events received by ``on_event``.
    """

    class TaskState(object):
        """What is known so far about one submitted task."""

        def __init__(self):
            self.tid = None
            self.state = None
            self.message = None
            self.result = None
            # Read by assertTaskCompletion; nothing in this class fills it in.
            self.error = None
            self.name = None
            # Set once the task reached the FINISHED or FAILED state.
            self.ended = Event()

    def __init__(self, methodName):
        super(BaseTestCase, self).__init__(methodName)
        self.tasks = {}
        self.tasks_lock = Lock()
        self.conn = None
        # Seconds to wait for login and for a task to end.
        self.task_timeout = 30

    def setUp(self):
        """Connect and log in using TESTHOST / TESTUSER / TESTPWD.

        Defaults are ``127.0.0.1``, ``root`` and an empty password. Any
        connection or login error propagates to the test runner.
        """
        self.conn = Client()
        self.conn.event_callback = self.on_event
        self.conn.connect(os.getenv('TESTHOST', '127.0.0.1'))
        self.conn.login_user(os.getenv('TESTUSER', 'root'),
                             os.getenv('TESTPWD', ''),
                             timeout=self.task_timeout)
        self.conn.subscribe_events('*')

    def tearDown(self):
        """Close the connection opened in ``setUp``."""
        self.conn.disconnect()

    def submitTask(self, name, *args):
        """Submit task ``name`` with ``args``, start tracking it, return its id.

        Errors raised by the submission call propagate unchanged.
        """
        # The lock is held across the RPC call so that on_event, which takes
        # the same lock, cannot process events for the new task before its
        # TaskState has been registered.
        with self.tasks_lock:
            tid = self.conn.call_sync('task.submit', name, args)

            state = self.TaskState()
            state.tid = tid
            state.name = name
            self.tasks[tid] = state
        return tid

    def assertTaskCompletion(self, tid):
        """Fail unless task ``tid`` ends in state FINISHED within the timeout."""
        t = self.tasks[tid]
        if not t.ended.wait(self.task_timeout):
            self.fail('Task {0} timed out'.format(tid))

        # Choose the most informative text for the assertion message.
        if t.state and 'Executing...' in t.state:
            message = t.error
        elif t.message and 'Executing...' in t.message:
            message = t.state
        else:
            message = t.message
        if not message:
            # Nothing useful was collected from events; ask the server.
            self.query_task(tid)

        self.assertEqual(t.state, 'FINISHED', msg=message)

    def assertTaskFailure(self, tid):
        """Fail if task ``tid`` ends in state FINISHED, or does not end in time."""
        t = self.tasks[tid]
        if not t.ended.wait(self.task_timeout):
            self.fail('Task {0} timed out'.format(tid))

        self.assertNotEqual(t.state, 'FINISHED', msg=t.message)

    def assertSeenEvent(self, name, func=None):
        """Placeholder; not implemented."""
        pass

    def skip(self, reason):
        """Skip the running test, reporting ``reason``."""
        raise unittest.SkipTest(str(reason))

    def getTaskResult(self, tid):
        """Return the result recorded for tracked task ``tid``."""
        return self.tasks[tid].result

    def on_event(self, name, args):
        """Record state, result and progress message of tracked tasks.

        Events for tasks that were not submitted through ``submitTask`` and
        events other than ``task.updated`` / ``task.progress`` are ignored.
        """
        if name not in ('task.updated', 'task.progress'):
            return

        with self.tasks_lock:
            t = self.tasks.get(args['id'])
            if t is None:
                return

            if name == 'task.updated':
                t.state = args['state']
                if t.state in ('FINISHED', 'FAILED'):
                    t.result = args.get('result')
                    t.ended.set()
            else:
                t.message = args['message']

    def on_eventOrig(self, name, args):
        """Older name of the event handler; behaves exactly like ``on_event``."""
        self.on_event(name, args)

    def pretty_print(self, res):
        """Print ``res`` as indented JSON when ``-v`` is on the command line."""
        if '-v' in sys.argv:
            # Parenthesised form works as a statement (py2) and a call (py3).
            print(json.dumps(res, indent=4, sort_keys=True))

    def query_task(self, tid):
        """Fetch task ``tid`` from the server and pretty-print its error."""
        # Makes tests very slow, keep as debug
        query = self.conn.call_sync('task.query', [('id', '=', tid)])
        if not query:
            return
        self.pretty_print(query[0]['error'])
Beispiel #3
0
class BaseTestCase(unittest.TestCase):
    """Base class for tests that submit tasks over a ``Client`` connection.

    Each test gets a fresh connection in ``setUp``. Tasks submitted through
    ``submitTask`` are tracked in ``self.tasks`` and updated from the
    ``task.updated`` / ``task.progress`` events received by ``on_event``.
    """

    class TaskState(object):
        """What is known so far about one submitted task."""

        def __init__(self):
            self.tid = None
            self.state = None
            self.message = None
            self.result = None
            # Read by assertTaskCompletion; nothing in this class fills it in.
            self.error = None
            self.name = None
            # Set once the task reached the FINISHED or FAILED state.
            self.ended = Event()

    def __init__(self, methodName):
        super(BaseTestCase, self).__init__(methodName)
        self.tasks = {}
        self.tasks_lock = Lock()
        self.conn = None
        # Seconds to wait for login and for a task to end.
        self.task_timeout = 30

    def setUp(self):
        """Connect and log in using TESTHOST / TESTUSER / TESTPWD.

        Defaults are ``127.0.0.1``, ``root`` and an empty password. Any
        connection or login error propagates to the test runner.
        """
        self.conn = Client()
        self.conn.event_callback = self.on_event
        self.conn.connect(os.getenv('TESTHOST', '127.0.0.1'))
        self.conn.login_user(os.getenv('TESTUSER', 'root'),
                             os.getenv('TESTPWD', ''),
                             timeout=self.task_timeout)
        self.conn.subscribe_events('*')

    def tearDown(self):
        """Close the connection opened in ``setUp``."""
        self.conn.disconnect()

    def submitTask(self, name, *args):
        """Submit task ``name`` with ``args``, start tracking it, return its id.

        Errors raised by the submission call propagate unchanged.
        """
        # The lock is held across the RPC call so that on_event, which takes
        # the same lock, cannot process events for the new task before its
        # TaskState has been registered.
        with self.tasks_lock:
            tid = self.conn.call_sync('task.submit', name, args)

            state = self.TaskState()
            state.tid = tid
            state.name = name
            self.tasks[tid] = state
        return tid

    def assertTaskCompletion(self, tid):
        """Fail unless task ``tid`` ends in state FINISHED within the timeout."""
        t = self.tasks[tid]
        if not t.ended.wait(self.task_timeout):
            self.fail('Task {0} timed out'.format(tid))

        # Choose the most informative text for the assertion message.
        if t.state and 'Executing...' in t.state:
            message = t.error
        elif t.message and 'Executing...' in t.message:
            message = t.state
        else:
            message = t.message
        if not message:
            # Nothing useful was collected from events; ask the server.
            self.query_task(tid)

        self.assertEqual(t.state, 'FINISHED', msg=message)

    def assertTaskFailure(self, tid):
        """Fail if task ``tid`` ends in state FINISHED, or does not end in time."""
        t = self.tasks[tid]
        if not t.ended.wait(self.task_timeout):
            self.fail('Task {0} timed out'.format(tid))

        self.assertNotEqual(t.state, 'FINISHED', msg=t.message)

    def assertSeenEvent(self, name, func=None):
        """Placeholder; not implemented."""
        pass

    def skip(self, reason):
        """Skip the running test, reporting ``reason``."""
        raise unittest.SkipTest(str(reason))

    def getTaskResult(self, tid):
        """Return the result recorded for tracked task ``tid``."""
        return self.tasks[tid].result

    def on_event(self, name, args):
        """Record state, result and progress message of tracked tasks.

        Events for tasks that were not submitted through ``submitTask`` and
        events other than ``task.updated`` / ``task.progress`` are ignored.
        """
        if name not in ('task.updated', 'task.progress'):
            return

        with self.tasks_lock:
            t = self.tasks.get(args['id'])
            if t is None:
                return

            if name == 'task.updated':
                t.state = args['state']
                if t.state in ('FINISHED', 'FAILED'):
                    t.result = args.get('result')
                    t.ended.set()
            else:
                t.message = args['message']

    def on_eventOrig(self, name, args):
        """Older name of the event handler; behaves exactly like ``on_event``."""
        self.on_event(name, args)

    def pretty_print(self, res):
        """Print ``res`` as indented JSON when ``-v`` is on the command line."""
        if '-v' in sys.argv:
            # Parenthesised form works as a statement (py2) and a call (py3).
            print(json.dumps(res, indent=4, sort_keys=True))

    def query_task(self, tid):
        """Fetch task ``tid`` from the server and pretty-print its error."""
        # Makes tests very slow, keep as debug
        query = self.conn.call_sync('task.query', [('id', '=', tid)])
        if not query:
            return
        self.pretty_print(query[0]['error'])