예제 #1
0
    def test_raise_on_error_url(self, m):
        """A server without gates cannot be contacted and reports a RuntimeError.

        Both the sync ``get`` and the async ``async_get`` must raise
        ``RuntimeError`` when ``raise_on_error=True``. Otherwise they must
        return a response whose ``exception`` is a ``RuntimeError``.
        """
        # Timeout mocks are registered, but the assertions expect a
        # RuntimeError instead, i.e. the call should fail before any timeout.
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        self.server.gates = []

        # synchronous client
        with self.assertRaises(RuntimeError):
            get(self.server, 'home', raise_on_error=True)
        sync_resp = get(self.server, 'home')
        self.assertIsInstance(sync_resp.exception, RuntimeError)

        # asynchronous client
        with self.assertRaises(RuntimeError):
            run(async_get(self.server, 'home', raise_on_error=True))
        async_resp = run(async_get(self.server, 'home'))
        self.assertIsInstance(async_resp.exception, RuntimeError)
예제 #2
0
def get_software(server: Server, auth) -> t.Tuple[t.Optional[str], t.Optional[str]]:
    """Download the dimensigon software package from ``server`` into the local repo.

    Calls the ``api_1_0.software_dimensigon`` endpoint. It base64-decodes the
    ``content`` field of the reply and writes the bytes to
    ``<SOFTWARE_REPO>/dimensigon/<filename>``.

    Args:
        server: server to download the package from.
        auth: authentication object forwarded to ``get``.

    Returns:
        ``(file_path, version)`` when the endpoint answers 200, otherwise
        ``(None, None)``.
    """
    resp = get(server, 'api_1_0.software_dimensigon', auth=auth)
    if resp.code != 200:
        return None, None

    data = resp.msg
    content = base64.b64decode(data.get('content').encode('ascii'))

    # The file name comes from a remote peer. Keep only its final component
    # so a value such as '../../x' cannot escape the repository directory.
    filename = os.path.basename(data.get('filename'))
    file = os.path.join(current_app.config['SOFTWARE_REPO'], 'dimensigon', filename)
    with open(file, 'wb') as fh:
        fh.write(content)
    return file, data.get('version')
예제 #3
0
    def test_get_headers_url_response(self, m):
        """Routing and auth headers are sent, and the reply is wrapped in a Response.

        The same callback serves both mocks. The ``responses`` mock passes a
        ``PreparedRequest`` (sync path). The ``aioresponses`` mock passes the
        headers in ``kwargs`` (async path).
        """
        msg = {'data': 'content'}
        status = 200

        def callback(request, **kwargs):
            is_sync = isinstance(request, PreparedRequest)
            sent = request.headers if is_sync else kwargs['headers']

            self.assertEqual(str(self.server.id), sent.get('D-Destination'))
            self.assertEqual(str(self.server.id), sent.get('D-Source'))
            self.assertEqual("True", sent.get('D-Securizer'))
            self.assertEqual(f"Bearer {self.token}", sent.get('Authorization'))

            if is_sync:
                return status, {}, json.dumps(msg)
            return CallbackResult(status=status, payload=msg)

        responses.add_callback(responses.GET, self.url, callback=callback)
        m.get(self.url, callback=callback)

        def expected_response():
            return Response(msg=msg, code=status, exception=None,
                            server=self.server, url=self.url)

        sync_resp = get(self.server, 'home', auth=self.auth,
                        headers={'D-Securizer': "True"})
        self.assertEqual(expected_response(), sync_resp)

        async_resp = run(async_get(self.server, 'home', auth=self.auth,
                                   headers={'D-Securizer': "True"}))
        self.assertEqual(expected_response(), async_resp)
예제 #4
0
    def make_first_request(self):
        """Poll ``root.home`` on the current server, then announce it is listening.

        Each attempt uses a 1 s request timeout. Attempts are retried every
        0.5 s until one answers ok or about 30 s have elapsed. The "Listening"
        event is published in either case.
        """
        from dimensigon.domain.entities import Server
        import dimensigon.web.network as ntwrk

        with self.flask_app.app_context():
            began = time.time()
            resp = ntwrk.get(Server.get_current(), 'root.home', timeout=1)
            while not resp.ok and time.time() - began < 30:
                time.sleep(0.5)
                resp = ntwrk.get(Server.get_current(), 'root.home', timeout=1)
            self._main_ctx.publish_q.safe_put(EventMessage("Listening", source="Dimensigon"))
예제 #5
0
    def _update_catalog_from_server(self, server):
        """Fetch catalog data from ``server`` and apply it locally.

        The request passes ``self.catalog_ver`` as ``data_mark``. This
        presumably asks only for entries changed since that mark. The work
        runs inside an UPGRADE ``lock_scope`` on ``self.server``.

        Raises:
            CatalogFetchError: when the reply has no code or a code outside 200-299.
        """
        with lock_scope(Scope.UPGRADE, [self.server]):
            data_mark = self.catalog_ver.strftime(defaults.DATEMARK_FORMAT)
            resp = ntwrk.get(server,
                             'api_1_0.catalog',
                             view_data={'data_mark': data_mark},
                             auth=get_root_auth())

            if not (resp.code and 199 < resp.code < 300):
                raise CatalogFetchError(resp)
            self.db_update_catalog(resp.msg)
예제 #6
0
    def test_connection_error(self, m):
        """A connection failure yields no code/msg and carries the exception."""
        responses.add(responses.GET, self.url, body=ConnectionError())
        m.get(self.url, exception=ConnectionError())

        def assert_connection_failed(resp):
            self.assertIsNone(resp.code)
            self.assertIsNone(resp.msg)
            self.assertIsInstance(resp.exception, ConnectionError)

        assert_connection_failed(get(self.server, 'home'))
        assert_connection_failed(run(async_get(self.server, 'home')))
예제 #7
0
    def test_get_error_json(self, m):
        """A JSON error body is decoded into ``resp.msg`` alongside the 500 code."""
        status = 500
        msg = {'error': 'this is an error message'}
        responses.add(responses.GET, self.url, json=msg, status=status)
        m.get(self.url, payload=msg, status=status)

        def check(resp):
            self.assertEqual(status, resp.code)
            self.assertDictEqual(msg, resp.msg)

        check(get(self.server, 'home', auth=self.auth))
        check(run(async_get(self.server, 'home', auth=self.auth)))
예제 #8
0
    def test_get_internal_error_server(self, m):
        """A non-JSON 500 body is returned verbatim as ``resp.msg``."""
        status = 500
        msg = '<html>Iternal error server</html>'

        responses.add(responses.GET, self.url, status=status, body=msg)
        m.get(self.url, status=status, body=msg)

        def check(resp):
            self.assertEqual(status, resp.code)
            self.assertEqual(msg, resp.msg)

        check(get(Server.get_current(), 'home'))
        check(run(async_get(Server.get_current(), 'home')))
예제 #9
0
    def test_raise_on_error(self, m):
        """With ``raise_on_error=True`` a timeout surfaces as ``TimeoutError``.

        The error message must name the target URL and the timeout used.
        """
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        def expected_message():
            return (f"Socket timeout reached while trying to connect to "
                    f"{Server.get_current().url('root.home')} for 1 seconds")

        with self.assertRaises(TimeoutError) as sync_ctx:
            get(self.server, 'home', timeout=1, raise_on_error=True)
        self.assertEqual(expected_message(), str(sync_ctx.exception))

        with self.assertRaises(TimeoutError) as async_ctx:
            run(async_get(self.server, 'home', timeout=1, raise_on_error=True))
        self.assertEqual(expected_message(), str(async_ctx.exception))
예제 #10
0
    def test_timeout(self, m):
        """A timed-out request returns no code/msg and a descriptive TimeoutError."""
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        def assert_timed_out(resp):
            self.assertIsNone(resp.code)
            self.assertIsNone(resp.msg)
            self.assertIsInstance(resp.exception, TimeoutError)
            self.assertEqual(
                f"Socket timeout reached while trying to connect to {self.url} for 0.01 seconds",
                str(resp.exception))

        assert_timed_out(get(self.server, 'home', timeout=0.01))
        assert_timed_out(run(async_get(self.server, 'home', timeout=0.01)))
예제 #11
0
def send():
    """Start a transfer of a software package or a local file to another server.

    Request JSON keys read here:
        dest_server_id: id of the destination server (required).
        software_id: send this Software's file. If the current server has no
            association with it, the request is forwarded to the server that
            has it with the lowest route cost from the destination (random
            choice when routes cannot be fetched).
        file: local path to send when ``software_id`` is absent.
        dest_path: destination directory (required when sending ``file``).
        max_senders: capped at ``d.MAX_SENDERS``.
        force: forwarded to the destination's transfer list.
        background: run the transfer through ``executor`` (default True) or inline.
        include_transfer_data: return the destination's transfer resource
            instead of just ``{'transfer_id': ...}``.

    Returns:
        ``(msg, 202)`` when sent in background, ``(msg, 201)`` otherwise.
        A forwarded request returns the forwarded server's ``(msg, code)``.

    Raises:
        errors.NoSoftwareServer: no server holds the requested software.
        errors.FileNotFound: the file to send does not exist locally.
    """

    def search_cost(ssa, route_list):
        # Route cost to the ssa's server; unknown routes or None costs sort last.
        costs = [route['cost'] for route in route_list
                 if str(ssa.server.id) == route['destination_id']]
        if costs and costs[0] is not None:
            return costs[0]
        return 999999

    # Validate Data
    json_data = request.get_json()

    dest_server = Server.query.get_or_raise(json_data['dest_server_id'])

    if 'software_id' in json_data:
        software = Software.query.get_or_raise(json_data['software_id'])

        ssa = SoftwareServerAssociation.query.filter_by(
            server=g.server, software=software).one_or_none()
        # if current server does not have the software, forward request to the closest server who has it
        if not ssa:
            resp = ntwrk.get(dest_server, 'api_1_0.routes', timeout=5)
            if resp.code == 200:
                ssas = sorted(software.ssas, key=functools.partial(
                    search_cost, route_list=resp.msg['route_list']))
            else:
                # Unable to get route cost: pick a random option.
                # random.shuffle works in place and returns None, so shuffle a copy.
                ssas = list(software.ssas)
                random.shuffle(ssas)
            if not ssas:
                raise errors.NoSoftwareServer(software_id=str(software.id))
            # closest server from dest_server who has the software
            server = ssas[0].server

            resp = ntwrk.post(server, 'api_1_0.send', json=json_data)
            resp.raise_if_not_ok()
            return resp.msg, resp.code

        file = os.path.join(ssa.path, software.filename)
        if not os.path.exists(file):
            raise errors.FileNotFound(file)
        size = ssa.software.size
    else:
        file = json_data['file']
        if not os.path.exists(file):
            raise errors.FileNotFound(file)
        size = os.path.getsize(file)
        checksum = md5(json_data.get('file'))

    chunk_size = d.CHUNK_SIZE * 1024 * 1024
    max_senders = min(json_data.get('max_senders', d.MAX_SENDERS),
                      d.MAX_SENDERS)
    chunks = math.ceil(size / chunk_size)

    if 'software_id' in json_data:
        json_msg = dict(software_id=str(software.id), num_chunks=chunks)
        # dest_path is optional for software transfers; when omitted the
        # destination decides where to store it (NOTE(review): presumably its
        # software repository — confirm on the receiving side).
        if 'dest_path' in json_data:
            json_msg['dest_path'] = json_data.get('dest_path')
    else:
        json_msg = dict(dest_path=json_data['dest_path'],
                        filename=os.path.basename(json_data.get('file')),
                        size=size,
                        checksum=checksum,
                        num_chunks=chunks)

    if 'force' in json_data:
        json_msg['force'] = json_data['force']

    resp = ntwrk.post(dest_server, 'api_1_0.transferlist', json=json_msg)
    resp.raise_if_not_ok()

    transfer_id = resp.msg.get('id')
    current_app.logger.debug(
        f"Transfer {transfer_id} created. Sending {file} to {dest_server}:{json_data.get('dest_path')}."
    )

    background = json_data.get('background', True)
    # The coroutine (and the JWT identity) is built in the request thread
    # and then run either in the executor or inline.
    coro = async_send_file(dest_server=dest_server,
                           transfer_id=transfer_id,
                           file=file,
                           chunk_size=chunk_size,
                           max_senders=max_senders,
                           identity=get_jwt_identity())
    if background:
        executor.submit(asyncio.run, coro)
    else:
        asyncio.run(coro)

    if json_data.get('include_transfer_data', False):
        resp = ntwrk.get(dest_server,
                         "api_1_0.transferresource",
                         view_data=dict(transfer_id=transfer_id))
        if resp.code != 200:
            # Raise on errors. A non-200 success still yields resp.msg below
            # instead of leaving msg unbound.
            resp.raise_if_not_ok()
        msg = resp.msg
    else:
        msg = {'transfer_id': transfer_id}
    return msg, 202 if background else 201
예제 #12
0
    def _execute(self,
                 params: Kwargs,
                 timeout=None,
                 context: Context = None) -> CompletedProcess:
        """Ask the best source server to send a software package to a destination.

        ``params['input']`` keys read here:
            software: Software id, or a software name. For a name without a
                ``version``, the highest ``parsed_version`` is chosen.
            version: exact version to match when ``software`` is a name.
            server: destination server id or name.
            dest_path, chunk_size, max_senders: forwarded in the send request.

        The source is the server holding the software with the lowest route
        cost from the destination. If the routes cannot be fetched, the first
        association is used. It is sent an ``api_1_0.send`` request with
        ``background=False`` and ``include_transfer_data=True``.

        Returns:
            CompletedProcess: ``stdout`` is the response message (JSON-dumped
            when it is a dict) and ``rc`` the HTTP code. ``stderr`` is set on
            lookup failures or when the request carried an exception.
        """
        input_params = params['input']
        cp = CompletedProcess()
        cp.set_start_time()

        # Common parameters. Work on a copy: if system_kwargs is a stored dict,
        # mutating it in place would leak this call's timeout/identity into
        # later calls.
        kwargs = dict(self.system_kwargs)
        kwargs['timeout'] = timeout or kwargs.get('timeout')
        # context defaults to None; only read the executor identity when given
        if context is not None:
            kwargs['identity'] = context.env.get('executor_id')

        def search_cost(ssa, route_list):
            # Route cost to the ssa's server; unknown routes or None costs sort last.
            costs = [route['cost'] for route in route_list
                     if str(ssa.server.id) == route['destination_id']]
            if costs and costs[0] is not None:
                return costs[0]
            return 999999

        software = input_params.get('software', None)
        if is_valid_uuid(software):
            soft = Software.query.get(software)
            if not soft:
                cp.stderr = f"software id '{software}' not found"
                cp.success = False
                cp.set_end_time()
                return cp
        else:
            version = input_params.get('version', None)
            if version:
                parsed_ver = parse(str(version))
                soft_list = [
                    s for s in Software.query.filter_by(name=software).all()
                    if s.parsed_version == parsed_ver
                ]
            else:
                soft_list = sorted(
                    Software.query.filter_by(name=software).all(),
                    key=lambda x: x.parsed_version)
            if soft_list:
                soft = soft_list[-1]
            else:
                cp.stderr = f"No software found for '{software}'" + (
                    f" and version '{version}'" if version else "")
                cp.success = False
                cp.set_end_time()
                return cp

        if not soft.ssas:
            cp.stderr = f"{soft.id} has no server association"
            cp.success = False
            cp.set_end_time()
            return cp

        # Server validation
        server = input_params.get('server', None)
        if is_valid_uuid(server):
            dest_server = Server.query.get(server)
        else:
            dest_server = Server.query.filter_by(name=server).one_or_none()
        if not dest_server:
            cp.stderr = f"destination server {'id ' if is_valid_uuid(server) else ''}'{server}' not found"
            cp.success = False
            cp.set_end_time()
            return cp

        # decide best server source
        resp = ntwrk.get(dest_server, 'api_1_0.routes', timeout=10)
        if resp.code == 200:
            ssas = sorted(soft.ssas, key=functools.partial(
                search_cost, route_list=resp.msg['route_list']))
        else:
            ssas = soft.ssas
        server = ssas[0].server

        # Process kwargs
        data = {
            'software_id': soft.id,
            'dest_server_id': dest_server.id,
            "background": False,
            "include_transfer_data": True,
            "force": True
        }
        for key in ('dest_path', 'chunk_size', 'max_senders'):
            if input_params.get(key, None):
                data[key] = input_params.get(key, None)

        # run request
        resp = ntwrk.post(server, 'api_1_0.send', json=data, **kwargs)
        cp.stdout = flask.json.dumps(resp.msg) if isinstance(
            resp.msg, dict) else resp.msg
        cp.rc = resp.code
        if resp.exception is None:
            self.evaluate_result(cp)
        else:
            # Only report an error when there is one: str(None) is "None".
            # Exceptions with an empty message fall back to their class name.
            cp.stderr = str(resp.exception) or resp.exception.__class__.__name__
        cp.set_end_time()
        return cp
예제 #13
0
    def to_json(self, add_step_exec=False, human=False, split_lines=False):
        """Serialize this execution into a JSON-compatible dict.

        Args:
            add_step_exec: include a ``steps`` list with every step execution
                serialized. A child orchestration execution is embedded from the
                local object, or fetched from the step's server when only its id
                is known (left out if that request fails).
            human: render references for reading. Target ids become server
                strings, executor/service become strings, orchestration becomes
                ``{id, name, version}`` and server becomes ``{id, name}``.
                Otherwise ``*_id`` fields are emitted.
            split_lines: forwarded to nested ``to_json`` calls.

        Returns:
            dict with ids, start/end times, target, params and result fields.
        """
        data = {}
        if self.id:
            data.update(id=str(self.id))
        if self.start_time:
            data.update(
                start_time=self.start_time.strftime(defaults.DATETIME_FORMAT))
        if self.end_time:
            data.update(
                end_time=self.end_time.strftime(defaults.DATETIME_FORMAT))
        if human:
            # convert target ids to server names
            d = {}
            if isinstance(self.target, dict):
                for k, v in self.target.items():
                    if is_iterable_not_string(v):
                        d[k] = [str(Server.query.get(s) or s) for s in v]
                    else:
                        d[k] = str(Server.query.get(v) or v)
            elif isinstance(self.target, list):
                d = [str(Server.query.get(s) or s) for s in self.target]
            else:
                d = str(Server.query.get(self.target) or self.target)
            data.update(target=d)
            if self.executor:
                data.update(executor=str(self.executor))
            if self.service:
                data.update(service=str(self.service))
            if self.orchestration:
                data.update(
                    orchestration=dict(id=str(self.orchestration.id),
                                       name=self.orchestration.name,
                                       version=self.orchestration.version))
            else:
                data.update(orchestration=None)
            if self.server:
                data.update(
                    server=dict(id=str(self.server.id), name=self.server.name))
        else:
            data.update(target=self.target)
            # Prefer the FK column; fall back to the loaded relationship's id.
            orchestration_id = self.orchestration_id or getattr(
                self.orchestration, 'id', None)
            if orchestration_id:
                data.update(orchestration_id=str(orchestration_id))
            executor_id = self.executor_id or getattr(self.executor, 'id', None)
            if executor_id:
                data.update(executor_id=str(executor_id))
            # service_id must come from the service, not from server_id
            service_id = self.service_id or getattr(self.service, 'id', None)
            if service_id:
                data.update(service_id=str(service_id))
            server_id = self.server_id or getattr(self.server, 'id', None)
            if server_id:
                data.update(server_id=str(server_id))
        data.update(params=self.params)
        data.update(success=self.success)
        data.update(undo_success=self.undo_success)
        data.update(message=self.message)

        if self.parent_step_execution_id and not add_step_exec:
            data.update(
                parent_step_execution_id=str(self.parent_step_execution_id))
        if add_step_exec:
            steps = []
            for se in self.step_executions:
                se: StepExecution

                se_json = se.to_json(human, split_lines=split_lines)
                if se.child_orch_execution:
                    se_json[
                        'orch_execution'] = se.child_orch_execution.to_json(
                            add_step_exec=add_step_exec,
                            split_lines=split_lines,
                            human=human)
                elif se.child_orch_execution_id:
                    # Child execution not available locally: ask its server.
                    from dimensigon.web.network import get, Response
                    query_params = ['steps']
                    if human:
                        query_params.append('human')

                    try:
                        resp = get(se.server,
                                   'api_1_0.orchexecutionresource',
                                   view_data=dict(
                                       execution_id=se.child_orch_execution_id,
                                       params=query_params))
                    except Exception as e:
                        current_app.logger.exception(
                            f"Exception while trying to acquire orch execution "
                            f"{se.child_orch_execution_id} from {se.server}")
                        resp = Response(exception=e)

                    if resp.ok:
                        se_json['orch_execution'] = resp.msg
                        se_json.pop('child_orch_execution_id', None)

                steps.append(se_json)
            data.update(steps=steps)
        return data