def end_episode(self,
                clear_buffers: bool = True,
                timeout_ms: Optional[int] = None):
  """Flush all pending items and generate a new episode ID.

  Configurations that are conditioned to only be applied on episode end are
  applied (assuming all other conditions are fulfilled) and the items
  inserted before flush is called.

  Args:
    clear_buffers: Whether the history should be cleared or not. Buffers
      should only not be cleared when trajectories spanning multiple
      episodes are used.
    timeout_ms: (optional, default is no timeout) Maximum time to block for
      before unblocking and raising a `DeadlineExceededError` instead. Note
      that although the block is interrupted, the buffers and episode ID are
      reset all the same and the insertion of the items will proceed in the
      background thread.

  Raises:
    DeadlineExceededError: If operation did not complete before the timeout.
  """
  try:
    self._writer.EndEpisode(clear_buffers, timeout_ms)
  except RuntimeError as e:
    # The C++ layer reports deadlines through a generic RuntimeError; map it
    # to the library's typed error (chained, so the original traceback is
    # preserved) so callers can catch it specifically.
    if 'Timeout exceeded' in str(e) and timeout_ms is not None:
      raise errors.DeadlineExceededError(
          f'End episode call did not complete within provided timeout of '
          f'{datetime.timedelta(milliseconds=timeout_ms)}') from e
    raise
def server_info(
    self,
    timeout: Optional[int] = None
) -> Dict[str, reverb_types.TableInfo]:
  """Get table metadata information.

  Args:
    timeout: Timeout in seconds to wait for server response. By default no
      deadline is set and call will block indefinitely until server
      responds.

  Returns:
    A dictionary mapping table names to their associated `TableInfo`
    instances, which contain metadata about the table.

  Raises:
    errors.DeadlineExceededError: If timeout provided and exceeded.
  """
  try:
    # A timeout of 0 tells the C++ client to block without a deadline.
    info_proto_strings = self._client.ServerInfo(timeout or 0)
  except RuntimeError as e:
    if 'Deadline Exceeded' in str(e) and timeout is not None:
      raise errors.DeadlineExceededError(
          f'ServerInfo call did not complete within provided timeout of '
          f'{timeout}s') from e
    raise
  # Deserialize each proto and key the result by table name.
  infos = [
      reverb_types.TableInfo.from_serialized_proto(proto_string)
      for proto_string in info_proto_strings
  ]
  return {info.name: info for info in infos}
def flush(self,
          block_until_num_items: int = 0,
          timeout_ms: Optional[int] = None):
  """Block until all but `block_until_num_items` confirmed by the server.

  There are two ways that an item could be "pending":

    1. Some of the data elements referenced by the item have not yet been
       finalized (and compressed) as a `ChunkData`.
    2. The item has been written to the gRPC stream but the response
       confirming the insertion has not yet been received.

  Type 1 pending items are transformed into type 2 when flush is called by
  forcing (premature) chunk finalization of the data elements referenced by
  the items. This will allow the background worker to write the data and
  items to the gRPC stream and turn them into type 2 pending items.

  The time it takes for type 2 pending items to be confirmed is primarily
  due to the state of the table rate limiter. After the items have been
  written to the gRPC stream then all we can do is wait (GIL is not held).

  Args:
    block_until_num_items: If > 0 then this many pending items will be
      allowed to remain as type 1. If the number of type 1 pending items is
      less than `block_until_num_items` then we simply wait until the total
      number of pending items is <= `block_until_num_items`.
    timeout_ms: (optional, default is no timeout) Maximum time to block for
      before unblocking and raising a `DeadlineExceededError` instead. Note
      that although the block is interrupted, the insertion of the items
      will proceed in the background.

  Raises:
    ValueError: If block_until_num_items < 0.
    DeadlineExceededError: If operation did not complete before the timeout.
  """
  if block_until_num_items < 0:
    raise ValueError(
        f'block_until_num_items must be >= 0, got {block_until_num_items}')

  try:
    # The C++ writer interprets -1 as "no deadline". Do NOT overwrite
    # `timeout_ms` itself: the guard below must still see the caller's
    # original value. (The previous version rebound timeout_ms to -1,
    # which made the `is not None` check always true and mislabelled
    # unrelated timeout errors as a caller deadline of "-1 ms".)
    self._writer.Flush(block_until_num_items,
                       -1 if timeout_ms is None else timeout_ms)
  except RuntimeError as e:
    if 'Timeout exceeded' in str(e) and timeout_ms is not None:
      raise errors.DeadlineExceededError(
          f'Flush call did not complete within provided timeout of '
          f'{datetime.timedelta(milliseconds=timeout_ms)}') from e
    raise
def server_info(
    self,
    timeout: Optional[int] = None
) -> Dict[str, reverb_types.TableInfo]:
  """Get table metadata information.

  Args:
    timeout: Timeout in seconds to wait for server response. By default no
      deadline is set and call will block indefinitely until server
      responds.

  Returns:
    A dictionary mapping table names to their associated `TableInfo`
    instances, which contain metadata about the table.

  Raises:
    errors.DeadlineExceededError: If timeout provided and exceeded.
  """
  try:
    # A timeout of 0 tells the C++ client to block without a deadline.
    info_proto_strings = self._client.ServerInfo(timeout or 0)
  except RuntimeError as e:
    if 'Deadline Exceeded' in str(e) and timeout is not None:
      raise errors.DeadlineExceededError(
          f'ServerInfo call did not complete within provided timeout of '
          f'{timeout}s') from e
    raise

  table_info = {}
  for proto_string in info_proto_strings:
    proto = schema_pb2.TableInfo.FromString(proto_string)
    # The signature field holds an encoded structure; decode it when present.
    if proto.HasField('signature'):
      signature = nested_structure_coder.StructureCoder().decode_proto(
          proto.signature)
    else:
      signature = None
    # Copy every proto field into a plain dict, replacing the raw signature
    # with its decoded form, then build the TableInfo keyed by table name.
    info_dict = {
        descr.name: getattr(proto, descr.name)
        for descr in proto.DESCRIPTOR.fields
    }
    info_dict['signature'] = signature
    table_info[str(info_dict['name'])] = reverb_types.TableInfo(**info_dict)
  return table_info
def test_exit_does_not_flush_on_reverb_error(self):
  # A clean exit from the writer context must flush pending items.
  with mock.patch.object(self.writer, 'flush') as flush_mock:
    with self.writer:
      pass
  flush_mock.assert_called_once()

  # Errors unrelated to Reverb should still result in a flush.
  with mock.patch.object(self.writer, 'flush') as flush_mock, \
      self.assertRaises(ValueError):
    with self.writer:
      raise ValueError('Test')
  flush_mock.assert_called_once()

  # An error raised by Reverb itself must suppress the flush.
  with mock.patch.object(self.writer, 'flush') as flush_mock, \
      self.assertRaises(errors.ReverbError):
    with self.writer:
      raise errors.DeadlineExceededError('Test')
  flush_mock.assert_not_called()