Ejemplo n.º 1
0
    def test_serialize_buffers(self):
        """Test correct handling of files when serializing the buffers.

        Without a target directory, the returned file list contains the
        absolute paths of the original files. With ``CLIENT_DIR`` as target
        directory, the returned paths point into ``CLIENT_DIR`` and each file
        name is prefixed with the hex uid of the object it belongs to.
        """
        # Show full diffs on failure. unittest consults maxDiff when an
        # assertion fails, so it has to be set before the assertions run.
        self.maxDiff = None

        # without providing target path
        with TransportSessionClient(SqliteSession, URI) as session:
            self.setup_buffers1(session)
            _, result = serialize_buffers(
                session,
                buffer_context=BufferContext.USER,
                target_directory=None,
            )
            self.assertEqual(
                sorted(map(os.path.abspath, [FILE_PATHS[0], FILE_PATHS[2]])),
                sorted(result),
            )

        # provide target path --> move files
        with TransportSessionClient(SqliteSession, URI) as session:
            images = self.setup_buffers1(session)
            _, result = serialize_buffers(
                session,
                buffer_context=BufferContext.USER,
                target_directory=CLIENT_DIR,
            )
            # Expected location: CLIENT_DIR/"<uid hex>-<original file name>".
            target_full_path = [
                os.path.join(CLIENT_DIR, "%s-%s" % (image.uid.hex, file))
                for image, file in zip(images, FILES)
            ]
            self.assertEqual(
                sorted([target_full_path[0], target_full_path[2]]),
                sorted(result),
            )
 def _load_from_backend(self, uids, expired=None):
     """Request the objects with the given uids from the backend.

     The uids and the expired uids are packed via ``serialize_buffers``
     (with ``buffer_context=None``) and sent with ``LOAD_COMMAND``. This is a
     generator: it delegates to ``self._engine.send`` via ``yield from``.

     Args:
         uids: The uids to load.
         expired: The expired uids to send along. Any falsy value,
             including the default ``None``, falls back to
             ``self._expired``.
     """
     if not expired:
         expired = self._expired
     payload, file_paths = serialize_buffers(
         self,
         buffer_context=None,
         additional_items={"uids": uids, "expired": expired},
         target_directory=self._file_destination,
     )
     yield from self._engine.send(LOAD_COMMAND, payload, file_paths)
    def _send(self, command, consume_buffers, *args, **kwargs):
        """Serialize the command arguments (and optionally the buffers) and send.

        Args:
            command (str): Name of the command to send.
            consume_buffers (bool): If true, the USER buffers are serialized
                along with the arguments (which empties them); otherwise no
                buffer context is used.
            args (Serializable): Positional arguments of the command.
            kwargs (Serializable): Keyword arguments of the command.

        Returns:
            Serializable: Whatever ``self._engine.send`` returns for the
                command.
        """
        if consume_buffers:
            context = BufferContext.USER
        else:
            context = None
        payload, file_paths = serialize_buffers(
            self,
            buffer_context=context,
            additional_items=dict(args=args, kwargs=kwargs),
            target_directory=self._file_destination,
        )
        return self._engine.send(command, payload, file_paths)
    def _load_from_session(self, data, connection_id, temp_directory=None):
        """Load cuds_objects from the session of a connection.

        Args:
            data (str): Serialized data sent by the client. After
                deserialization, its "uids" entry holds the uids to load.
            connection_id (Hashable): Identifies the connection whose session
                the objects are loaded from.
            temp_directory: Passed through to ``deserialize_buffers``
                (default ``None``).

        Returns:
            tuple: The output of ``serialize_buffers`` for the session with
                the ENGINE buffer context. It holds the serialized data (str),
                with the loaded cuds_objects under "result", and a list of
                file paths.
        """
        session = self.session_objs[connection_id]
        deserialized = deserialize_buffers(
            session,
            buffer_context=None,
            data=data,
            temp_directory=temp_directory,
            target_directory=self._file_destination)
        uids = deserialized["uids"]
        # session.load may return a lazy iterable; materialize it so the
        # objects can be serialized.
        cuds_objects = list(session.load(*uids))
        return serialize_buffers(session,
                                 buffer_context=BufferContext.ENGINE,
                                 additional_items={"result": cuds_objects})
    def _init_session(self, data, connection_id):
        """Start a new session for a connection.

        Args:
            data (str): JSON sent by the user. The decoded dict provides
                "args" and "kwargs" for the session constructor, "hashes"
                (merged into the file hashes of the connection), "root" (the
                serialized root of the new session) and, if the session class
                accepts an ``auth`` argument, "auth".
            connection_id (Hashable): The connection that requests to start a
                new session. It is passed to the session class if its
                constructor accepts a ``connection_id`` argument.

        Raises:
            ValueError: If the server provides fixed session kwargs and the
                user sent args or kwargs anyway. The connection's existing
                session, if any, is left untouched in that case.

        Returns:
            tuple: The output of ``serialize_buffers`` for the new session
                with the ENGINE buffer context (serialized data as str and a
                list of file paths).
        """
        data = json.loads(data)
        # Reject invalid requests before touching the existing session.
        if self._session_kwargs and (data["args"] or data["kwargs"]):
            raise ValueError("This remote session cannot be parameterized by "
                             "the user. Only provide host and port and no "
                             "further arguments.")
        # Unregister the previous session before closing it. Then a failure
        # while building the new session cannot leave a closed session
        # registered for this connection.
        previous = self.session_objs.pop(connection_id, None)
        if previous is not None:
            previous.close()

        # Forward connection_id / auth only to constructors that accept them.
        user_kwargs = dict()
        argspec = inspect.getfullargspec(self.session_cls.__init__)
        args = argspec.kwonlyargs + argspec.args
        if "connection_id" in args:
            user_kwargs["connection_id"] = connection_id
        if "auth" in args:
            user_kwargs["auth"] = data["auth"]
        if self._session_kwargs:
            session = self.session_cls(**self._session_kwargs, **user_kwargs)
        else:
            session = self.session_cls(*data["args"], **data["kwargs"],
                                       **user_kwargs)
        self.com_facility._file_hashes[connection_id].update(data["hashes"])
        self.session_objs[connection_id] = session
        deserialize(data["root"],
                    session=session,
                    buffer_context=BufferContext.USER)
        return serialize_buffers(session, buffer_context=BufferContext.ENGINE)
    def _run_command(self, data, command, connection_id, temp_directory=None):
        """Run a method of the session belonging to a connection.

        Args:
            data (str): Serialized data of the client. After deserialization
                (with the USER buffer context), it provides "args" and
                "kwargs" for the method call.
            command (str): Name of the session method to execute.
            connection_id (Hashable): Identifies the connection whose session
                runs the command.
            temp_directory: Passed through to ``deserialize_buffers``
                (default ``None``).

        Returns:
            tuple: The output of ``serialize_buffers`` with the ENGINE buffer
                context. It holds the serialized data (str), which contains
                the method's return value under "result" unless that value is
                ``None``, and a list of file paths.
        """
        session = self.session_objs[connection_id]
        arguments = deserialize_buffers(
            session,
            buffer_context=BufferContext.USER,
            data=data,
            temp_directory=temp_directory,
            target_directory=self._file_destination)
        # NOTE(review): `command` is used as an attribute name on the session;
        # confirm callers restrict it to the intended public session methods.
        result = getattr(session, command)(*arguments["args"],
                                           **arguments["kwargs"])
        # Only None means "no result". A truthiness test would silently drop
        # legitimate falsy return values such as 0, False, "" or [].
        additional = {"result": result} if result is not None else dict()
        return serialize_buffers(session,
                                 buffer_context=BufferContext.ENGINE,
                                 additional_items=additional)
Ejemplo n.º 7
0
    def test_serialize_buffers(self):
        """Check ``serialize_buffers`` on a wrapper session.

        The same scenario runs twice: once without expired uids and once with
        ``_expired`` set to ``{UUID(int=3)}``. Each run checks two things.
        Serializing with ``buffer_context=None`` must leave the USER buffers
        filled. Serializing with ``BufferContext.USER`` must yield the
        expected JSON-LD, no files, and empty buffers.
        """
        def build_scenario(session):
            # Add Freiburg and reset the USER buffers, then add Paris and
            # remove and prune Freiburg.
            wrapper = city.CityWrapper(session=session, uid=123)
            freiburg = city.City(name="Freiburg", uid=1)
            wrapper.add(freiburg)
            session._reset_buffers(BufferContext.USER)

            paris = city.City(name="Paris", uid=2)
            wrapper.add(paris)
            wrapper.remove(freiburg.uid)
            session.prune()
            # Return the objects so the caller keeps them referenced while
            # the session is open.
            return wrapper, freiburg, paris

        def london_items():
            # Build a new dict for every serialize_buffers call.
            return {"args": [42], "kwargs": {"name": "London"}}

        def assert_user_buffers_filled(session):
            added, updated, deleted = session._buffers[BufferContext.USER]
            self.assertEqual(added.keys(), {uuid.UUID(int=2)})
            self.assertEqual(updated.keys(), {uuid.UUID(int=123)})
            self.assertEqual(deleted.keys(), {uuid.UUID(int=1)})
            self.assertEqual(session._buffers[BufferContext.ENGINE],
                             [dict(), dict(), dict()])

        all_empty = [[dict(), dict(), dict()], [dict(), dict(), dict()]]

        # --- no expiration ---
        with TestWrapperSession() as session:
            _objects = build_scenario(session)
            self.assertEqual(
                ('{"expired": [], "args": [42], "kwargs": {"name": "London"}}',
                 []),
                serialize_buffers(session, buffer_context=None,
                                  additional_items=london_items()))
            assert_user_buffers_filled(session)

            self.maxDiff = None
            serialized = serialize_buffers(session,
                                           buffer_context=BufferContext.USER,
                                           additional_items=london_items())
            assertJsonLdEqual(self, json.loads(serialized[0]),
                              SERIALIZED_BUFFERS)
            self.assertEqual(serialized[1], [])
            self.assertEqual(session._buffers, all_empty)
            session._expired = {uuid.UUID(int=123), uuid.UUID(int=2)}

        # --- with expiration ---
        with TestWrapperSession() as session:
            _objects = build_scenario(session)
            session._expired = {uuid.UUID(int=3)}
            self.assertEqual(
                ('{"expired": [{"UUID": '
                 '"00000000-0000-0000-0000-000000000003"}], '
                 '"args": [42], "kwargs": {"name": "London"}}', []),
                serialize_buffers(session, buffer_context=None,
                                  additional_items=london_items()))
            assert_user_buffers_filled(session)

            self.maxDiff = 3000
            serialized = serialize_buffers(session,
                                           buffer_context=BufferContext.USER,
                                           additional_items=london_items())
            assertJsonLdEqual(self, SERIALIZED_BUFFERS_EXPIRED,
                              json.loads(serialized[0]))
            self.assertEqual([], serialized[1])
            self.assertEqual(session._buffers, all_empty)
            session._expired = {uuid.UUID(int=123), uuid.UUID(int=2)}