Example #1
0
    def test_async_send_file_one_chunk(self):
        """Send a file declared as a one-chunk transfer and verify the received copy."""
        source_file = os.path.join(self.source_path, self.filename)
        dest_file = os.path.join(self.dest_path, self.filename)

        with virtual_network(self.app, self.app2):
            # Register the pending transfer inside the second application.
            with self.app2.app_context():
                pending = Transfer(software=self.filename,
                                   size=self.size,
                                   checksum=self.checksum,
                                   dest_path=self.dest_path,
                                   num_chunks=1,
                                   status=TransferStatus.WAITING_CHUNKS)
                db.session.add(pending)
                db.session.commit()
                transfer_id = pending.id

            # The destination file must not exist before sending.
            self.assertFalse(os.path.exists(dest_file))

            send_coro = async_send_file(dest_server=self.s2,
                                        transfer_id=transfer_id,
                                        file=source_file,
                                        chunk_size=80,
                                        identity=ROOT)
            run(send_coro)

            # After sending, the file exists with the expected size and checksum.
            self.assertTrue(os.path.exists(dest_file))
            self.assertEqual(self.size, os.path.getsize(dest_file))
            self.assertEqual(self.checksum, md5(dest_file))
Example #2
0
    def test_post(self, m):
        """Check that ``post`` and ``async_post`` send packed data and return the unpacked reply."""
        expected_code = 200
        reply = {'new': 'data'}
        payload = {'data': 'some data'}

        def callback(request, **kwargs):
            # One callback serves both mocks: only the ``responses`` mock
            # hands over a PreparedRequest.
            if not isinstance(request, PreparedRequest):
                self.assertDictEqual(payload, unpack_msg(kwargs['json']))
                return CallbackResult(status=expected_code,
                                      payload=pack_msg(reply))
            self.assertDictEqual(payload,
                                 unpack_msg(json.loads(request.body)))
            return expected_code, {}, json.dumps(pack_msg(reply))

        responses.add_callback(responses.POST, self.url, callback=callback)
        m.post(self.url, callback=callback)

        def check(resp):
            self.assertEqual(expected_code, resp.code)
            self.assertDictEqual(reply, resp.msg)

        check(post(self.server, 'home', json=payload))
        check(run(async_post(self.server, 'home', json=payload)))
Example #3
0
 def _notify_cluster_out(self):
     """Post a shutdown notice to every neighbour server.

     The notice is sent to the ``api_1_0.cluster_out`` view with the current
     timestamp as ``death``. Failed deliveries are reported as warnings, but
     only when the logger level is DEBUG or lower.
     """
     with self.dm.flask_app.app_context():
         neighbours = Server.get_neighbours()
         if not neighbours:
             self.logger.debug("No server to send shutdown information")
             return

         names = ', '.join([s.name for s in neighbours])
         self.logger.debug(f"Sending shutdown to {names}")

         current_id = str(Server.get_current().id)
         death_mark = get_now().strftime(defaults.DATEMARK_FORMAT)
         request_all = ntwrk.parallel_requests(
             neighbours,
             'post',
             view_or_url='api_1_0.cluster_out',
             view_data=dict(server_id=current_id),
             json={'death': death_mark},
             timeout=2,
             auth=get_root_auth())
         responses = asyncio.run(request_all)

         # Delivery failures are only reported at DEBUG verbosity.
         if self.logger.level > logging.DEBUG:
             return
         for r in responses:
             if r.ok:
                 continue
             self.logger.warning(f"Unable to send data to {r.server}: {r}")
Example #4
0
    def test_raise_on_error_url(self, m):
        """With no gates on the server, a RuntimeError is raised or stored in the response."""
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        # Remove every gate from the target server.
        self.server.gates = []

        # Synchronous client.
        with self.assertRaises(RuntimeError):
            get(self.server, 'home', raise_on_error=True)
        sync_resp = get(self.server, 'home')
        self.assertIsInstance(sync_resp.exception, RuntimeError)

        # Asynchronous client.
        with self.assertRaises(RuntimeError):
            run(async_get(self.server, 'home', raise_on_error=True))
        async_resp = run(async_get(self.server, 'home'))
        self.assertIsInstance(async_resp.exception, RuntimeError)
Example #5
0
    def test_get_headers_url_response(self, m):
        """Verify the headers sent by ``get``/``async_get`` and the Response they return."""
        msg = {'data': 'content'}
        status = 200

        def assert_headers(headers):
            # Both mocks expose the outgoing headers through ``.get``.
            server_id = str(self.server.id)
            self.assertEqual(server_id, headers.get('D-Destination'))
            self.assertEqual(server_id, headers.get('D-Source'))
            self.assertEqual("True", headers.get('D-Securizer'))
            self.assertEqual(f"Bearer {self.token}",
                             headers.get('Authorization'))

        def callback(request, **kwargs):
            if isinstance(request, PreparedRequest):
                assert_headers(request.headers)
                return status, {}, json.dumps(msg)
            assert_headers(kwargs['headers'])
            return CallbackResult(status=status, payload=msg)

        def expected_response():
            return Response(msg=msg,
                            code=status,
                            exception=None,
                            server=self.server,
                            url=self.url)

        responses.add_callback(responses.GET, self.url, callback=callback)
        m.get(self.url, callback=callback)

        sync_resp = get(self.server,
                        'home',
                        auth=self.auth,
                        headers={'D-Securizer': "True"})
        self.assertEqual(expected_response(), sync_resp)

        async_resp = run(
            async_get(self.server,
                      'home',
                      auth=self.auth,
                      headers={'D-Securizer': "True"}))
        self.assertEqual(expected_response(), async_resp)
    def test_log_sender_file_no_dest_folder(self, mock_pb_rl, mock_pt_uof, mock_isfile, m):
        """A single-file log without ``dest_folder`` is posted under its original path."""
        log_id = 'aaaaaaaa-1234-5678-1234-56781234aaa1'
        target = '/var/log/dimensigon.log'
        encoded_content = base64.b64encode('content'.encode()).decode('ascii')

        def callback(url, **kwargs):
            self.assertDictEqual({"file": target, 'data': encoded_content},
                                 kwargs['json'])
            return CallbackResult('POST', status=200)

        m.post(self.dest.url('api_1_0.logresource', log_id=log_id), callback=callback)

        # The target is reported as a regular file.
        mock_isfile.return_value = True

        db.session.add(
            Log(id=log_id, source_server=self.source,
                target=target,
                destination_server=self.dest, dest_folder=None))

        run(self.log_sender.send_new_data())

        for mocked in (mock_isfile, mock_pb_rl, mock_pt_uof):
            mocked.assert_called_once()
Example #7
0
 def upgrade_process(self):
     """Collect the neighbours' healthcheck data, then check for a new version and update the catalog.

     Raises
     ------
     NoServerFound
         when the healthcheck request returns no data.
     """
     self.logger.debug("Starting check catalog from neighbours")
     # cluster information
     heartbeat_id = get_now().strftime(defaults.DATETIME_FORMAT)
     healthcheck = self._async_get_neighbour_healthcheck(heartbeat_id)
     neighbour_data = asyncio.run(healthcheck)
     if not neighbour_data:
         raise NoServerFound()
     # the version check runs before the catalog update to match database revision
     self.check_new_version(neighbour_data)
     self.catalog_update(neighbour_data)
    def test_log_sender_folder(self, mock_pb_rl, mock_pt_uof, mock_isfile, mock_walk, mock_post, m):
        """A recursive folder log sends each matching file under ``dest_folder``.

        The mocked walk yields ``log1`` and ``dir1/log2`` as candidates; the
        callback checks the destination path and encoded content of each
        request and fails explicitly on any other file.
        """
        log_id = 'aaaaaaaa-1234-5678-1234-56781234aaa1'

        def callback(url, **kwargs):
            sent_file = kwargs['json']['file']
            if sent_file == '/dimensigon/logs/log1':
                self.assertDictEqual(
                    {"file": '/dimensigon/logs/log1', 'data': base64.b64encode('content1'.encode()).decode('ascii')},
                    kwargs['json'])
                return CallbackResult('POST', payload={'offset': 8}, status=200)
            elif sent_file == '/dimensigon/logs/dir1/log2':
                self.assertDictEqual(
                    {"file": '/dimensigon/logs/dir1/log2',
                     'data': base64.b64encode('newcontent2'.encode()).decode('ascii')},
                    kwargs['json'])
                return CallbackResult('POST', payload={'offset': 11}, status=200)
            else:
                # A bare ``raise`` here has no active exception and would only
                # produce an unrelated RuntimeError; fail with a clear message.
                raise AssertionError(f"Unexpected file sent to the log resource: {sent_file!r}")

        mock_post.side_effect = [({'offset': 8}, 200), ({'offset': 11}, 200), ({'offset': 8}, 200)]

        # Register the same endpoint three times, one registration per expected request.
        for _ in range(3):
            m.post(self.dest.url('api_1_0.logresource', log_id=log_id), callback=callback)

        mock_isfile.return_value = False
        mock_walk.side_effect = [
            [('/var/log', ['dir1'], ['log1', 'file']), ('/var/log/dir1', ['dir2'], ['log2'])],
            [('/var/log', ['dir1'], ['log1', 'file']), ('/var/log/dir1', ['dir2'], [])]
        ]

        log = Log(id=log_id, source_server=self.source, target='/var/log',
                  destination_server=self.dest, dest_folder='/dimensigon/logs/', include='^(log|dir)', exclude='^dir2',
                  recursive=True)
        db.session.add(log)

        run(self.log_sender.send_new_data())

        mock_isfile.assert_called_once()
        self.assertEqual(2, mock_pb_rl.call_count)
        self.assertEqual(2, mock_pt_uof.call_count)
Example #9
0
    def test_connection_error(self, m):
        """A ConnectionError is stored in the response, leaving code and msg empty."""
        responses.add(responses.GET, self.url, body=ConnectionError())
        m.get(self.url, exception=ConnectionError())

        def check(resp):
            self.assertIsNone(resp.code)
            self.assertIsNone(resp.msg)
            self.assertIsInstance(resp.exception, ConnectionError)

        # Synchronous client first, then the asynchronous one.
        check(get(self.server, 'home'))
        check(run(async_get(self.server, 'home')))
Example #10
0
    def test_get_error_json(self, m):
        """A 500 reply with a JSON body is exposed through ``code`` and ``msg``."""
        error_body = {'error': 'this is an error message'}
        error_code = 500
        responses.add(responses.GET, self.url, json=error_body, status=error_code)
        m.get(self.url, payload=error_body, status=error_code)

        def check(resp):
            self.assertEqual(error_code, resp.code)
            self.assertDictEqual(error_body, resp.msg)

        # Synchronous client first, then the asynchronous one.
        check(get(self.server, 'home', auth=self.auth))
        check(run(async_get(self.server, 'home', auth=self.auth)))
Example #11
0
    def test_post_no_content_in_response(self, m):
        """A 204 reply without body yields an empty message and the 204 code.

        The unpacked code is bound to its own name so it is compared against
        the expected status; reusing ``status`` for both would make the
        assertion compare a value with itself.
        """
        msg = ''
        status = 204
        responses.add(responses.POST, self.url, status=status)
        m.post(self.url, status=status)

        data, code = post(self.server, 'home')

        self.assertEqual(status, code)
        self.assertEqual(msg, data)

        data, code = run(async_post(self.server, 'home'))

        self.assertEqual(status, code)
        self.assertEqual(msg, data)
Example #12
0
    def test_get_internal_error_server(self, m):
        """A 500 reply with an HTML body is returned as-is in ``msg``."""
        html_body = '<html>Iternal error server</html>'
        error_code = 500

        responses.add(responses.GET, self.url, status=error_code, body=html_body)
        m.get(self.url, status=error_code, body=html_body)

        def check(resp):
            self.assertEqual(error_code, resp.code)
            self.assertEqual(html_body, resp.msg)

        # Synchronous client first, then the asynchronous one.
        check(get(Server.get_current(), 'home'))
        check(run(async_get(Server.get_current(), 'home')))
Example #13
0
    def test_raise_on_error(self, m):
        """With ``raise_on_error`` a timeout surfaces as TimeoutError with a descriptive message."""
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        def expected_message():
            target_url = Server.get_current().url('root.home')
            return (f"Socket timeout reached while trying to connect to {target_url} "
                    f"for 1 seconds")

        # Synchronous client.
        with self.assertRaises(TimeoutError) as sync_ctx:
            get(self.server, 'home', timeout=1, raise_on_error=True)
        self.assertEqual(expected_message(), str(sync_ctx.exception))

        # Asynchronous client.
        with self.assertRaises(TimeoutError) as async_ctx:
            run(async_get(self.server, 'home', timeout=1, raise_on_error=True))
        self.assertEqual(expected_message(), str(async_ctx.exception))
Example #14
0
    def test_timeout(self, m):
        """Without ``raise_on_error`` a timeout is stored as TimeoutError in the response."""
        responses.add(responses.GET, self.url, body=ConnectTimeout())
        m.get(self.url, exception=asyncio.TimeoutError())

        def check(resp):
            self.assertIsNone(resp.code)
            self.assertIsNone(resp.msg)
            self.assertIsInstance(resp.exception, TimeoutError)
            expected_text = (
                f"Socket timeout reached while trying to connect to {self.url} "
                f"for 0.01 seconds")
            self.assertEqual(expected_text, str(resp.exception))

        # Synchronous client first, then the asynchronous one.
        check(get(self.server, 'home', timeout=0.01))
        check(run(async_get(self.server, 'home', timeout=0.01)))
Example #15
0
    def _send_new_data(self):
        """Fetch new data from every tracked file and post it to its destination server.

        After calling ``update_mapper()``, each entry of ``self._mapper`` is
        walked; for every tail object that returns data and whose log
        destination is reported alive by the cluster manager, a POST to the
        ``api_1_0.logresource`` view is prepared with the zlib-compressed,
        base64-encoded data. All prepared requests are then awaited together.
        On a successful response the tail's offset file is updated; on a
        failed one the error is logged and the retry counter of the log's
        blacklist entry is increased.
        """
        self.update_mapper()
        # maps each prepared request to the (pytail, log) pair it belongs to
        tasks = OrderedDict()

        for log_id, pb in self._mapper.items():
            log = self.session.query(Log).get(log_id)
            for pytail in pb:
                data = pytail.fetch()
                # zlib.compress below needs bytes
                data = data.encode() if isinstance(data, str) else data
                if data and log.destination_server.id in self.dm.cluster_manager.get_alive(
                ):
                    # Work out the path reported to the destination, depending on the log mode.
                    if log.mode == Mode.MIRROR:
                        # same path as on the source
                        file = pytail.file
                    elif log.mode == Mode.REPO_ROOT:
                        # path relative to the parent folder of the target, under the '{LOG_REPO}' placeholder
                        path_to_remove = os.path.dirname(log.target)
                        relative = os.path.relpath(pytail.file, path_to_remove)
                        file = os.path.join('{LOG_REPO}', relative)
                    elif log.mode == Mode.FOLDER:
                        # path relative to the parent folder of the target, under log.dest_folder
                        path_to_remove = os.path.dirname(log.target)
                        relative = os.path.relpath(pytail.file, path_to_remove)
                        file = os.path.join(log.dest_folder, relative)
                    else:

                        def get_root(dirname):
                            # climb with os.path.dirname until the path stops changing
                            new_dirname = os.path.dirname(dirname)
                            if new_dirname == dirname:
                                return dirname
                            else:
                                return get_root(new_dirname)

                        # any other mode: path relative to the topmost directory, under '{LOG_REPO}'
                        relative = os.path.relpath(pytail.file,
                                                   get_root(pytail.file))
                        file = os.path.join('{LOG_REPO}', relative)
                    with self.dm.flask_app.app_context():
                        auth = get_root_auth()

                    # the request is only prepared here; it is awaited with the rest below
                    task = ntwrk.async_post(
                        log.destination_server,
                        'api_1_0.logresource',
                        view_data={'log_id': str(log_id)},
                        json={
                            "file":
                            file,
                            'data':
                            base64.b64encode(
                                zlib.compress(data)).decode('ascii'),
                            "compress":
                            True
                        },
                        auth=auth)

                    tasks[task] = (pytail, log)
                    _log_logger.debug(
                        f"Task sending data from '{pytail.file}' to '{log.destination_server}' prepared"
                    )

        if tasks:
            with self.dm.flask_app.app_context():
                responses = asyncio.run(asyncio.gather(*list(tasks.keys())))

            # tasks is ordered, so the responses are paired with the tasks by position
            for task, resp in zip(tasks.keys(), responses):
                pytail, log = tasks[task]
                if resp.ok:
                    pytail.update_offset_file()
                    _log_logger.debug(f"Updated offset from '{pytail.file}'")
                    # NOTE(review): membership is tested on self._blacklist but the entry is
                    # removed from self._blacklist_log; their relationship is not visible
                    # here - confirm this is intended.
                    if log.id not in self._blacklist:
                        self._blacklist_log.pop(log.id, None)
                else:
                    _log_logger.error(
                        f"Unable to send log information from '{pytail.file}' to '{log.destination_server}'. Error: {resp}"
                    )
                    # NOTE(review): a fresh entry (and retry count) is created whenever log.id is
                    # not in self._blacklist, and in the other branch .get() may return None if
                    # log.id is missing from self._blacklist_log - verify against both definitions.
                    if log.id not in self._blacklist:
                        bl = BlacklistEntry()
                        self._blacklist_log[log.id] = bl
                    else:
                        bl = self._blacklist_log.get(log.id)
                    bl.retries += 1
                    if bl.retries >= self.max_allowed_errors:
                        _log_logger.debug(
                            f"Adding server {log.destination_server.id} to the blacklist."
                        )
                        # record the time at which the entry became blacklisted
                        bl.blacklisted = time.time()
Example #16
0
def launch_command():
    """Run the requested shell command on the target servers and gather the output.

    Reads the JSON body of the current request:

    * ``command`` (required): the command line, wrapped with ``wrap_sudo`` for
      the requesting user.
    * ``target`` (optional): ``'all'``, a single search term or an iterable of
      search terms resolved with ``search``. Without it the command runs on
      the current server only.
    * ``input`` (optional): text written to the local process' stdin.
    * ``timeout`` (optional): seconds to wait for the local process; defaults
      to ``defaults.TIMEOUT_COMMAND``.

    The local server executes the command itself; every other selected server
    receives the same request through ``api_1_0.launch_command``.

    Returns
    -------
    tuple
        ``(body, status)``. ``404`` when some target matched no server,
        ``403`` when the command looks like a recursive ``rm``; otherwise
        ``200`` with one entry per server (keyed by name when the ``human``
        parameter is in the URI, by id otherwise) plus ``cmd`` and ``input``.

    Raises
    ------
    errors.EntityNotFound
        when the requesting user cannot be resolved to a user name.
    """
    data = request.get_json()

    server_list = []
    if 'target' in data:
        not_found = []
        servers = Server.query.all()
        target = data['target']
        if target == 'all':
            server_list = servers
        else:
            # a single term is handled like a one-element collection
            terms = target if is_iterable_not_string(target) else [target]
            for term in terms:
                matches = search(term, servers)
                if len(matches) == 0:
                    not_found.append(term)
                else:
                    server_list.extend(matches)
        if not_found:
            return {
                'error':
                "Following granules or ids did not match to any server: " +
                ', '.join(map(str, not_found))
            }, 404
    else:
        server_list.append(g.server)

    if re.search(r'rm\s+((-\w+|--[-=\w]*)\s+)*(-\w*[rR]\w*|--recursive)',
                 data['command']):
        return {'error': 'rm with recursion is not allowed'}, 403
    # the forwarded request must not carry the target again
    data.pop('target', None)

    username = getattr(User.query.get(get_jwt_identity()), 'name', None)
    if not username:
        raise errors.EntityNotFound('User', get_jwt_identity())
    cmd = wrap_sudo(username, data['command'])

    proc = None
    if g.server in server_list:
        # the local node runs the command itself; only the rest is forwarded
        server_list.remove(g.server)
        # NOTE(security): the command comes from the request body and runs
        # through the shell; the regular expression above is the only filter.
        proc = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE if data.get('input', None) else None,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            shell=True,
            close_fds=True,
            encoding='utf-8')

    resp_data = {}
    attr = 'name' if check_param_in_uri("human") else 'id'

    if server_list:
        resp: t.List[ntwrk.Response] = asyncio.run(
            ntwrk.parallel_requests(server_list,
                                    method='POST',
                                    view_or_url='api_1_0.launch_command',
                                    json=data))
        for s, r in zip(server_list, resp):
            key = getattr(s, attr, s.id)
            if r.ok:
                resp_data[key] = r.msg[s.id]
            elif not r.exception:
                resp_data[key] = {
                    'error': {
                        'status_code': r.code,
                        'response': r.msg
                    }
                }
            elif isinstance(r.exception, errors.BaseError):
                resp_data[key] = errors.format_error_content(
                    r.exception, current_app.config['DEBUG'])
            else:
                resp_data[key] = {
                    'error':
                    format_exception(r.exception)
                    if current_app.config['DEBUG'] else
                    str(r.exception) or str(r.exception.__class__.__name__)
                }

    if proc is not None:
        key = getattr(g.server, attr, g.server.id)
        timeout = data.get('timeout', defaults.TIMEOUT_COMMAND)
        try:
            outs, errs = proc.communicate(input=(data.get('input', '') or ''),
                                          timeout=timeout)
        except (TimeoutError, subprocess.TimeoutExpired):
            # kill the process and try to collect whatever it produced
            proc.kill()
            try:
                outs, errs = proc.communicate(timeout=1)
            except Exception:
                resp_data[key] = {
                    'error':
                    f"Command '{cmd}' timed out after {timeout} seconds. Unable to communicate with the process launched."
                }
            else:
                resp_data[key] = {
                    'error':
                    f"Command '{cmd}' timed out after {timeout} seconds",
                    'stdout': outs.split('\n'),
                    'stderr': errs.split('\n')
                }
        except Exception as e:
            current_app.logger.exception(
                "Exception raised while trying to run command")
            # report the exception caught here, not the last remote response
            resp_data[key] = {
                'error':
                traceback.format_exc() if current_app.config['DEBUG'] else
                str(e) or e.__class__.__name__
            }
        else:
            resp_data[key] = {
                'stdout': outs.split('\n'),
                'stderr': errs.split('\n'),
                'returncode': proc.returncode
            }
    resp_data['cmd'] = cmd
    resp_data['input'] = data.get('input', None)
    return resp_data, 200
Example #17
0
def send():
    """Create a transfer on the destination server and send it a file or a software.

    Reads the JSON body of the current request. ``dest_server_id`` is
    required. With ``software_id`` the file is taken from the current
    server's association with that software; when the current server does not
    hold the software, the whole request is forwarded to another server that
    does (the cheapest route from the destination if the routes can be
    fetched, a random holder otherwise). Without ``software_id`` the local
    path in ``file`` is sent to ``dest_path``.

    Optional keys: ``dest_path`` (for software), ``force``, ``max_senders``,
    ``background`` (default ``True``) and ``include_transfer_data``.

    Returns
    -------
    tuple
        ``(msg, status)``: ``202`` when sent in background, ``201`` otherwise;
        or the forwarded server's answer when the request was forwarded.

    Raises
    ------
    errors.NoSoftwareServer
        when no server holds the requested software.
    errors.FileNotFound
        when the file to send does not exist.
    """
    # cost assigned to servers whose route is missing or has no cost
    unknown_cost = 999999

    def search_cost(ssa, route_list):
        costs = [
            route['cost'] for route in route_list
            if str(ssa.server.id) == route['destination_id']
        ]
        if costs and costs[0] is not None:
            return costs[0]
        return unknown_cost

    # Validate Data
    json_data = request.get_json()

    dest_server = Server.query.get_or_raise(json_data['dest_server_id'])

    if 'software_id' in json_data:
        software = Software.query.get_or_raise(json_data['software_id'])

        ssa = SoftwareServerAssociation.query.filter_by(
            server=g.server, software=software).one_or_none()
        # if current server does not have the software, forward request to the closest server who has it
        if not ssa:
            resp = ntwrk.get(dest_server, 'api_1_0.routes', timeout=5)
            if resp.code == 200:
                ssas = copy.copy(software.ssas)
                ssas.sort(key=functools.partial(
                    search_cost, route_list=resp.msg['route_list']))
            else:
                # unable to get route cost, pick any server that has it.
                # random.shuffle works in place and returns None, so the
                # list has to be kept in its own name.
                ssas = list(software.ssas)
                random.shuffle(ssas)
            if not ssas:
                raise errors.NoSoftwareServer(software_id=str(software.id))
            # closest server from dest_server who has the software
            server = ssas[0].server

            resp = ntwrk.post(server, 'api_1_0.send', json=json_data)
            resp.raise_if_not_ok()
            return resp.msg, resp.code

        file = os.path.join(ssa.path, software.filename)
        if not os.path.exists(file):
            raise errors.FileNotFound(file)
        size = ssa.software.size
    else:
        file = json_data['file']
        if not os.path.exists(file):
            raise errors.FileNotFound(file)
        size = os.path.getsize(file)
        checksum = md5(file)

    chunk_size = d.CHUNK_SIZE * 1024 * 1024
    max_senders = min(json_data.get('max_senders', d.MAX_SENDERS),
                      d.MAX_SENDERS)
    chunks = math.ceil(size / chunk_size)

    if 'software_id' in json_data:
        json_msg = dict(software_id=str(software.id), num_chunks=chunks)
        # dest_path is optional when sending a software
        if 'dest_path' in json_data:
            json_msg['dest_path'] = json_data.get('dest_path')
    else:
        json_msg = dict(dest_path=json_data['dest_path'],
                        filename=os.path.basename(file),
                        size=size,
                        checksum=checksum,
                        num_chunks=chunks)

    if 'force' in json_data:
        json_msg['force'] = json_data['force']

    resp = ntwrk.post(dest_server, 'api_1_0.transferlist', json=json_msg)
    resp.raise_if_not_ok()

    transfer_id = resp.msg.get('id')
    current_app.logger.debug(
        f"Transfer {transfer_id} created. Sending {file} to {dest_server}:{json_data.get('dest_path')}."
    )

    background = json_data.get('background', True)
    send_coro = async_send_file(dest_server=dest_server,
                                transfer_id=transfer_id,
                                file=file,
                                chunk_size=chunk_size,
                                max_senders=max_senders,
                                identity=get_jwt_identity())
    if background:
        executor.submit(asyncio.run, send_coro)
    else:
        asyncio.run(send_coro)

    if json_data.get('include_transfer_data', False):
        resp = ntwrk.get(dest_server,
                         "api_1_0.transferresource",
                         view_data=dict(transfer_id=transfer_id))
        # raise on failure first, so msg is always bound afterwards
        resp.raise_if_not_ok()
        msg = resp.msg
    else:
        msg = {'transfer_id': transfer_id}
    return msg, 202 if background else 201
Example #18
0
    def test_async_send_retry(self, m):
        """Send a five-chunk transfer while some POST requests fail with ConnectionError."""
        node2_pattern = re.compile(
            Server.query.filter_by(name='node2').one().url() + '.*')

        def forward_to_client2(method):
            """Build a callback that replays the mocked request on ``self.client2``."""

            def callback(url, **kwargs):
                kwargs.pop('allow_redirects')

                client_call = getattr(self.client2, method.lower())
                r = client_call(url.path,
                                json=kwargs['json'],
                                headers=kwargs['headers'])

                return CallbackResult(method,
                                      status=r.status_code,
                                      body=r.data,
                                      content_type=r.content_type,
                                      headers=r.headers)

            return callback

        # POST requests in registration order: True is forwarded to node2,
        # False raises ConnectionError.
        post_callback_client = forward_to_client2('POST')
        post_outcomes = (True, True, False, False, True, True, False, True)
        for forwarded in post_outcomes:
            if forwarded:
                m.post(node2_pattern, callback=post_callback_client)
            else:
                m.post(node2_pattern, exception=ConnectionError)

        # PATCH and PUT requests are always forwarded.
        m.patch(node2_pattern,
                callback=forward_to_client2('PATCH'),
                repeat=True)
        m.put(node2_pattern,
              callback=forward_to_client2('PUT'),
              repeat=True)

        with self.app2.app_context():
            pending = Transfer(software=self.filename,
                               size=self.size,
                               checksum=self.checksum,
                               dest_path=self.dest_path,
                               num_chunks=5,
                               status=TransferStatus.WAITING_CHUNKS)
            db.session.add(pending)
            db.session.commit()
            transfer_id = pending.id

        dest_file = os.path.join(self.dest_path, self.filename)
        self.assertFalse(os.path.exists(dest_file))

        run(
            async_send_file(dest_server=self.s2,
                            transfer_id=transfer_id,
                            file=os.path.join(self.source_path, self.filename),
                            chunk_size=14,
                            identity=ROOT,
                            retries=3))

        self.assertTrue(os.path.exists(dest_file))
        self.assertEqual(self.size, os.path.getsize(dest_file))
        self.assertEqual(self.checksum, md5(dest_file))
Example #19
0
    def _send_data(self):
        """Send the buffered cluster records to every neighbour server.

        The buffer is swapped out under ``_change_buffer_lock`` and posted to
        the ``api_1_0.cluster`` view of all neighbours. If sending raises, the
        records are put back into the buffer, with records that arrived in the
        meantime taking precedence. If the buffer holds data once the attempt
        is over, a one-second timer re-runs this method; otherwise the timer
        reference is cleared. The session opened here is always closed.
        """
        session = self.Session()

        def log_data(data):
            debug_data = []
            for cr in data:
                server = dict(id=cr.id)
                name = getattr(session.query(Server).get(cr.id), 'name', cr.id)
                if name:
                    server.update(name=name)

                debug_data.append({
                    'server':
                    server,
                    'keepalive':
                    cr.keepalive.strftime(defaults.DATEMARK_FORMAT),
                    'death':
                    cr.death
                })
            return debug_data

        try:
            # time to send data
            with self.dm.flask_app.app_context():
                neighbours = Server.get_neighbours(session=session)
                if not neighbours:
                    self.logger.debug(
                        "No neighbour servers to send cluster information")
                    with self._change_buffer_lock:
                        self._timer = None
                    return

                with self._change_buffer_lock:
                    temp_buffer = dict(self._buffer)
                    self._buffer.clear()

                self.logger.debug(
                    f"Sending cluster information to the following nodes: {', '.join([s.name for s in neighbours])}"
                )
                # the dump needs one query per record: build it only when
                # that level is really going to be emitted
                if self.logger.isEnabledFor(1):
                    self.logger.log(
                        1,
                        f"{json.dumps(log_data(temp_buffer.values()), indent=2)}"
                    )

                auth = get_root_auth()
                try:
                    responses = asyncio.run(
                        ntwrk.parallel_requests(
                            neighbours,
                            'POST',
                            view_or_url='api_1_0.cluster',
                            json=[{
                                'id':
                                e.id,
                                'keepalive':
                                e.keepalive.strftime(defaults.DATEMARK_FORMAT),
                                'death':
                                e.death
                            } for e in temp_buffer.values()],
                            auth=auth,
                            securizer=False))
                except Exception as e:
                    self.logger.error(
                        f"Unable to send cluster information to neighbours: {format_exception(e)}"
                    )
                    # restore data with new data arrived; a plain update()
                    # works whatever the key type is, ``**`` needs str keys
                    with self._change_buffer_lock:
                        temp_buffer.update(self._buffer)
                        self._buffer.clear()
                        self._buffer.update(temp_buffer)
                else:
                    for r in responses:
                        if not r.ok:
                            self.logger.warning(
                                f"Unable to send data to {r.server}: {r}")

                # check if new data arrived during timer execution
                with self._change_buffer_lock:
                    if self._buffer:
                        self._timer = threading.Timer(interval=1,
                                                      function=self._send_data)
                        self._timer.start()
                    else:
                        self._timer = None
        finally:
            session.close()
Example #20
0
def lock_unlock(action: str,
                scope: Scope,
                servers: t.List[Server],
                applicant=None,
                identity=None):
    """
    Ask a set of servers to lock or unlock a scope.

    Unlocking sends a single 'unlock' request to every server. Locking is
    done in two rounds: a 'prevent' request carrying the current catalog
    version as datemark and, only if every server accepted it, a 'lock'
    request.

    Parameters
    ----------
    action
        'U' for unlocking and 'L' for locking
    scope
        scope of the lock
    servers
        servers to ask for a lock
    applicant
        applicant sent with every locker request. When falsy, the list of
        the ids (as strings) of ``servers`` is used.
    identity
        identity used to create the access token. When falsy, the identity
        returned by ``get_jwt_identity()`` is used.

    Raises
    ------
    AssertionError
        if ``action`` is neither 'U' nor 'L'.
    errors.LockError
        if the number of responses with code 200 or 210 differs from the
        number of servers. It is built with the scope, the round that
        failed ('U', 'P' for prevent or 'L') and the responses whose code
        is neither 200 nor 210.

    Returns
    -------
    None
        returns none if all went as expected.
    """

    # Tuple membership on purpose: the substring test ``action in 'UL'`` also
    # accepted '' and 'UL', which then silently ran the lock branch.
    assert action in ('U', 'L'), \
        f"action must be 'U' or 'L', got {action!r}"

    success_codes = (200, 210)
    applicant = applicant or [str(s.id) for s in servers]

    token = create_access_token(identity if identity else get_jwt_identity())
    auth = HTTPBearerAuth(token)

    def ask(locker_action, **extra):
        # Send one locker round to every server and wait for the responses.
        return run(
            request_locker(servers=servers,
                           scope=scope,
                           action=locker_action,
                           applicant=applicant,
                           auth=auth,
                           **extra))

    def all_accepted(responses):
        # Success means as many accepted responses as servers were asked.
        accepted = [r for r in responses if r.code in success_codes]
        return len(servers) == len(accepted)

    if action == 'U':
        pool_responses = ask('unlock')
        if all_accepted(pool_responses):
            return
    else:
        # 'action' tracks the round in progress so the error reports it.
        action = 'P'
        catalog_ver = Catalog.max_catalog(str)
        pool_responses = ask('prevent', datemark=catalog_ver)

        if all_accepted(pool_responses):
            action = 'L'
            pool_responses = ask('lock')
            if all_accepted(pool_responses):
                return

    raise errors.LockError(
        scope, action,
        [r for r in pool_responses if r.code not in success_codes])