class SnapshotterToDB(SnapshotterBase):
    """Takes workflow snapshots to the database via ODBC.
    """
    MAPPING = "odbc"

    WRITE_CODECS = {
        None: lambda n, l: n,
        "": lambda n, l: n,
        "snappy": lambda n, _: SnappyFile(n, "wb"),
        "gz": lambda n, l: gzip.GzipFile(fileobj=n, mode="wb",
                                         compresslevel=l),
        "bz2": lambda n, l: bz2.BZ2File(n, "wb", compresslevel=l),
        "xz": lambda n, l: lzma.LZMAFile(n, "wb", preset=l)
    }

    READ_CODECS = {
        "pickle": lambda n: n,
        "snappy": lambda n: SnappyFile(n, "rb"),
        "gz": lambda n: gzip.GzipFile(fileobj=n, mode="rb"),
        "bz2": lambda n: bz2.BZ2File(n, "rb"),
        "xz": lambda n: lzma.LZMAFile(n, "rb")
    }

    def __init__(self, workflow, **kwargs):
        super(SnapshotterToDB, self).__init__(workflow, **kwargs)
        self._odbc = kwargs["odbc"]
        self._table = kwargs.get("table", "veles")

    @property
    def odbc(self):
        return self._odbc

    @property
    def table(self):
        return self._table

    def initialize(self, **kwargs):
        super(SnapshotterToDB, self).initialize(**kwargs)
        self._db_ = pyodbc.connect(self.odbc)
        self._cursor_ = self._db_.cursor()

    def stop(self):
        if self.odbc is not None:
            self._db_.close()

    def export(self):
        self._destination = ".".join(
            (self.prefix, self.suffix, str(best_protocol)))
        fio = BytesIO()
        self.info("Preparing the snapshot...")
        fout = self._open_fobj(fio)
        pickle.dump(self.workflow, fout, protocol=best_protocol)
        if fout is not fio:
            # Close the codec wrapper to flush the compressed stream; the
            # wrappers leave the underlying BytesIO open. The raw BytesIO
            # itself must stay open for getvalue() below, so it is never
            # closed here (a plain "with" would close it when no
            # compression is used).
            fout.close()
        self.check_snapshot_size(len(fio.getvalue()))
        binary = pyodbc.Binary(fio.getvalue())
        self.info("Executing SQL insert into \"%s\"...", self.table)
        now = datetime.now()
        self._cursor_.execute(
            "insert into %s(timestamp, id, log_id, workflow, name, codec, "
            "data) values (?, ?, ?, ?, ?, ?, ?);" % self.table,
            now, self.launcher.id, self.launcher.log_id,
            self.launcher.workflow.name, self.destination, self.compression,
            binary)
        self._db_.commit()
        self.info("Successfully wrote %d bytes as %s @ %s",
                  len(binary), self.destination, now)

    @staticmethod
    def import_(odbc, table, id_, log_id, name=None):
        conn = pyodbc.connect(odbc)
        cursor = conn.cursor()
        # Only the table name has to be interpolated into the SQL text; the
        # values are passed as query parameters to avoid quoting problems.
        query = "select codec, data from %s where id=? and log_id=?" % table
        params = [id_, log_id]
        if name is not None:
            query += " and name=?"
            params.append(name)
        else:
            query += " order by timestamp desc limit 1"
        cursor.execute(query, params)
        row = cursor.fetchone()
        conn.close()
        codec = SnapshotterToDB.READ_CODECS[row.codec]
        with codec(BytesIO(row.data)) as fin:
            return SnapshotterToDB._import_fobj(fin)

    def get_metric_values(self):
        return {"Snapshot": {"odbc": self.odbc,
                             "table": self.table,
                             "name": self.destination}}

    def _open_fobj(self, fobj):
        return SnapshotterToDB.WRITE_CODECS[self.compression](
            fobj, self.compression_level)
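
# A minimal usage sketch for restoring a snapshot from the database. The
# DSN, table name and ids below are hypothetical placeholders, not values
# defined by this module:
#
#     workflow = SnapshotterToDB.import_(
#         "DSN=veles;UID=veles;PWD=secret",  # any pyodbc connection string
#         "veles",                           # the snapshot table
#         "1a2b3c", "4d5e6f")                # launcher id and log_id
#
# With name=None the newest snapshot for that (id, log_id) pair is loaded;
# passing name= selects a specific snapshot by its destination string.
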
class MinibatchesSaver(Unit):
    """Saves data from Loader to pickle file.
    """
    CODECS = {
        "raw": lambda f, _: f,
        "snappy": lambda f, _: SnappyFile(f, "wb"),
        "gz": lambda f, l: gzip.GzipFile(fileobj=f, mode="wb",
                                         compresslevel=l),
        # "wb" must be passed explicitly: BZ2File defaults to read mode and
        # LZMAFile rejects preset= when opened for reading.
        "bz2": lambda f, l: bz2.BZ2File(f, "wb", compresslevel=l),
        "xz": lambda f, l: lzma.LZMAFile(f, "wb", preset=l)
    }

    def __init__(self, workflow, **kwargs):
        # The view_group default must be set before the base constructor
        # consumes kwargs.
        kwargs["view_group"] = kwargs.get("view_group", "SERVICE")
        super(MinibatchesSaver, self).__init__(workflow, **kwargs)
        self.file_name = os.path.abspath(kwargs.get(
            "file_name",
            os.path.join(root.common.dirs.cache, "minibatches.dat")))
        self.compression = kwargs.get("compression", "snappy")
        self.compression_level = kwargs.get("compression_level", 9)
        self.class_chunk_sizes = kwargs.get("class_chunk_sizes", (0, 0, 1))
        self.offset_table = []
        self.demand(
            "minibatch_data", "minibatch_labels", "minibatch_class",
            "class_lengths", "max_minibatch_size", "minibatch_size",
            "shuffle_limit", "has_labels", "labels_mapping")

    def init_unpickled(self):
        super(MinibatchesSaver, self).init_unpickled()
        self._file_ = None

    @property
    def file(self):
        return self._file_

    @property
    def effective_class_chunk_sizes(self):
        chunk_sizes = []
        for ci, cs in enumerate(self.class_chunk_sizes):
            if cs == 0:
                cs = self.max_minibatch_size
            elif cs > self.max_minibatch_size:
                raise ValueError(
                    "%s's chunk size may not exceed max minibatch size = %d"
                    " (got %d)" %
                    (CLASS_NAME[ci], self.max_minibatch_size, cs))
            chunk_sizes.append(cs)
        return tuple(chunk_sizes)

    def initialize(self, **kwargs):
        if self.shuffle_limit != 0:
            raise error.VelesException(
                "You must disable shuffling in your loader (set "
                "shuffle_limit to 0)")
        self._file_ = open(self.file_name, "wb")
        pickle.dump(self.get_header_data(), self.file,
                    protocol=best_protocol)

    def get_header_data(self):
        return (self.compression, self.class_lengths,
                self.max_minibatch_size, self.effective_class_chunk_sizes,
                self.minibatch_data.shape, self.minibatch_data.dtype,
                self.minibatch_labels.shape if self.has_labels else None,
                self.minibatch_labels.dtype if self.has_labels else None,
                self.labels_mapping)

    def prepare_chunk_data(self):
        self.minibatch_data.map_read()
        self.minibatch_labels.map_read()
        arr_data = numpy.zeros(
            (self.effective_class_chunk_sizes[self.minibatch_class],) +
            self.minibatch_data.shape[1:], dtype=self.minibatch_data.dtype)
        if self.has_labels:
            arr_labels = numpy.zeros(
                (self.effective_class_chunk_sizes[self.minibatch_class],) +
                self.minibatch_labels.shape[1:],
                self.minibatch_labels.dtype)
        else:
            arr_labels = None
        return arr_data, arr_labels

    def fill_chunk_data(self, prepared, interval):
        prepared[0][:] = self.minibatch_data[interval[0]:interval[1]]
        if self.has_labels:
            prepared[1][:] = self.minibatch_labels[interval[0]:interval[1]]

    def run(self):
        prepared = self.prepare_chunk_data()
        chunk_size = self.effective_class_chunk_sizes[self.minibatch_class]
        chunks_number = int(numpy.ceil(self.max_minibatch_size / chunk_size))
        for i in range(chunks_number):
            self.offset_table.append(numpy.uint64(self.file.tell()))
            fout = MinibatchesSaver.CODECS[self.compression](
                self.file, self.compression_level)
            self.fill_chunk_data(
                prepared, (i * chunk_size, (i + 1) * chunk_size))
            pickle.dump(prepared, fout, protocol=best_protocol)
            fout.flush()

    def stop(self):
        if self.file is None or self.file.closed:
            return
        pos = self.file.tell()
        pickle.dump(self.offset_table, self.file, protocol=best_protocol)
        self.debug("Offset table took %d bytes", self.file.tell() - pos)
        self.file.close()
        self.info("Wrote %s", self.file_name)
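
# On-disk layout produced by MinibatchesSaver: a pickled header (see
# get_header_data()), then one compressed pickled chunk per minibatch
# slice, and finally the pickled offset table appended by stop(). A wiring
# sketch, assuming an existing workflow and loader (link_from/link_attrs
# are the usual Veles unit linking calls; the attribute list mirrors
# demand() above):
#
#     saver = MinibatchesSaver(workflow, compression="gz",
#                              compression_level=6)
#     saver.link_from(loader)
#     saver.link_attrs(
#         loader, "minibatch_data", "minibatch_labels", "minibatch_class",
#         "class_lengths", "max_minibatch_size", "minibatch_size",
#         "shuffle_limit", "has_labels", "labels_mapping")
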
class SnapshotterToFile(SnapshotterBase):
    """Takes workflow snapshots to the file system.
    """
    MAPPING = "file"

    WRITE_CODECS = {
        None: lambda n, l: open(n, "wb"),
        "": lambda n, l: open(n, "wb"),
        "snappy": lambda n, _: SnappyFile(n, "wb"),
        "gz": lambda n, l: gzip.GzipFile(n, "wb", compresslevel=l),
        "bz2": lambda n, l: bz2.BZ2File(n, "wb", compresslevel=l),
        "xz": lambda n, l: lzma.LZMAFile(n, "wb", preset=l)
    }

    READ_CODECS = {
        "pickle": lambda n: open(n, "rb"),
        "snappy": lambda n: SnappyFile(n, "rb"),
        "gz": lambda n: gzip.GzipFile(n, "rb"),
        "bz2": lambda n: bz2.BZ2File(n, "rb"),
        "xz": lambda n: lzma.LZMAFile(n, "rb")
    }

    def __init__(self, workflow, **kwargs):
        kwargs["view_group"] = kwargs.get("view_group", "SERVICE")
        super(SnapshotterToFile, self).__init__(workflow, **kwargs)
        self.directory = kwargs.get("directory", root.common.dirs.snapshots)

    def export(self):
        ext = ("." + self.compression) if self.compression else ""
        rel_file_name = "%s_%s.%d.pickle%s" % (
            self.prefix, self.suffix, best_protocol, ext)
        self._destination = os.path.abspath(
            os.path.join(self.directory, rel_file_name))
        self.info("Snapshotting to %s...", self.destination)
        with self._open_file() as fout:
            pickle.dump(self.workflow, fout, protocol=best_protocol)
        self.check_snapshot_size(os.path.getsize(self.destination))
        file_name_link = os.path.join(
            self.directory, "%s_current.%d.pickle%s" % (
                self.prefix, best_protocol, ext))
        # Link creation may fail when several processes do this all at once,
        # so try-except here:
        try:
            os.remove(file_name_link)
        except OSError:
            pass
        try:
            os.symlink(rel_file_name, file_name_link)
        except OSError:
            pass

    @staticmethod
    def import_(file_name):
        file_name = file_name.strip()
        if not os.path.exists(file_name):
            raise FileNotFoundError(file_name)
        _, ext = os.path.splitext(file_name)
        codec = SnapshotterToFile.READ_CODECS[ext[1:]]
        with codec(file_name) as fin:
            return SnapshotterToFile._import_fobj(fin)

    def _open_file(self):
        return SnapshotterToFile.WRITE_CODECS[self.compression](
            self.destination, self.compression_level)
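
# A usage sketch for loading a file snapshot (the path is a hypothetical
# example; export() logs the real destination when it runs):
#
#     workflow = SnapshotterToFile.import_(
#         "/home/user/.veles/snapshots/mnist_validation_25.4.pickle.gz")
#
# The read codec is chosen by the file extension, so a renamed snapshot
# must keep its ".pickle"/".gz"/".bz2"/".xz" suffix.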