def tar_add_link(tar: tarfile.TarFile, filename: str):
    '''Add filename as symlink to tarfile'''
    entry = tarfile.TarInfo(filename)
    # Archive members must be relative paths; drop a single leading slash.
    if entry.name[0] == '/':
        entry.name = entry.name[1:]
    entry.type = tarfile.SYMTYPE
    # Record the link target exactly as the filesystem reports it.
    entry.linkname = os.readlink(filename)
    tar.addfile(entry)
def add_sqlite_file(ctx: BackupContext, tar: TarFile,
                    src: t.Union[os.PathLike, str], arcname: str) -> None:
    """Add a consistent snapshot of the SQLite database at *src* to *tar*.

    The database is dumped with ``sqlite3 .backup`` into a temporary file so
    a live database is archived in a consistent state, then stored under
    *arcname* carrying the original file's metadata.  Does nothing when
    *src* is not a file.
    """
    if not path.isfile(src):
        return
    with NamedTemporaryFile(mode='rb', suffix='.db') as db_backup_file:
        ctx.logger.debug('Dumping database from %s to %s',
                         src, db_backup_file.name)
        subprocess.run(
            ['sqlite3', str(src), f".backup '{db_backup_file.name}'"],
            check=True)
        # Use the original file's metadata for the archive entry; the size
        # still comes from the backup copy, which is what gets stored.
        src_stat = os.stat(src)
        ti = tar.gettarinfo(db_backup_file.name, arcname)
        ti.mtime = src_stat.st_mtime
        ti.mode = src_stat.st_mode
        ti.uid = src_stat.st_uid
        ti.gid = src_stat.st_gid
        ti.uname = ''
        ti.gname = ''
        # BUG FIX: archive the freshly dumped backup copy, not the live
        # source file.  The entry size was taken from the backup, and the
        # live database may change while being read, so reading `src` here
        # could corrupt the archive and defeats the point of the dump.
        with open(db_backup_file.name, 'rb') as f:
            tar.addfile(ti, f)
def tarball_images(
    images: List[Image.Image],
    *,
    name: str = None,
    animated: bool = False,
    format: str = "png",
    extras: List[Tuple[str, BytesIO]],
) -> BytesIO:
    """Pack rendered images (plus optional extra files) into an in-memory tar.

    :param images: frames to store; when ``animated`` each element is itself
        a frame sequence (first frame saved with ``append_images``).
    :param name: optional basename prefix for the image entries.
    :param animated: save each entry as a multi-frame image when true.
    :param format: image format / file extension passed to Pillow.
    :param extras: ``(filename, payload)`` pairs appended verbatim; an empty
        filename falls back to ``_.txt``.
    :return: the finished archive, rewound to position 0.
    """
    fp = BytesIO()
    tar = TarFile(mode="w", fileobj=fp)
    for idx, image in enumerate(images):
        f = BytesIO()
        if animated:
            image[0].save(f, format, append_images=image[1:], save_all=True, loop=0)
        else:
            image.save(f, format)
        f.seek(0)
        if name:
            info = TarInfo(f"{name}_{idx}.{format}")
        else:
            info = TarInfo(f"{idx}.{format}")
        info.size = len(f.getbuffer())
        tar.addfile(info, fileobj=f)
    for extra_name, extra_buf in extras:
        info = TarInfo(extra_name or "_.txt")
        info.size = len(extra_buf.getbuffer())
        tar.addfile(info, fileobj=extra_buf)
    # BUG FIX: close the archive so the end-of-archive blocks are written;
    # without this the returned tar is truncated and some readers reject it.
    tar.close()
    fp.seek(0)
    return fp
def add_buf_to_tar(tar: TarFile, filename: str, buf: BytesIO):
    """Append the full contents of *buf* to *tar* under *filename*.

    The buffer is flushed and rewound first so the archive entry covers
    everything written to it so far.
    """
    buf.flush()
    entry = TarInfo(name=filename)
    entry.size = len(buf.getvalue())
    buf.seek(0)
    tar.addfile(entry, buf)
def copy_to_container(container: "Container", source_path: str, target_path: str) -> None:
    """
    Copy a file into a Docker container
    :param container: Container object
    :param source_path: Source file path
    :param target_path: Target file path (in the container)
    :return:
    """
    # docker-py has no direct file-copy API, so wrap the file in a
    # single-entry tar stream and push it with put_archive.
    # https://github.com/docker/docker-py/issues/1771
    with open(source_path, "rb") as src:
        payload = src.read()
    info = TarInfo(name=os.path.basename(target_path))
    info.size = len(payload)
    info.mtime = int(time.time())
    archive = BytesIO()
    with TarFile(fileobj=archive, mode="w") as tar:
        tar.addfile(info, BytesIO(payload))
    archive.seek(0)
    container.put_archive(path=os.path.dirname(target_path), data=archive.read())
def move_certs(self, paths):
    """Stage internal SSL certificate files into a Docker volume.

    Coroutine-style method (uses ``yield self.docker(...)``): pulls the
    staging image, creates (or reuses) the certs volume, packs the given
    cert files into an in-memory tar, and unpacks it at ``/certs`` inside
    a throwaway container mounting that volume.

    :param paths: mapping of cert key -> file path on the hub host.
    :return: mapping of the same keys to the container-side paths
        (``/certs/<basename>``) where the certs were staged.
    """
    self.log.info("Staging internal ssl certs for %s", self._log_name)
    yield self.pull_image(self.move_certs_image)
    # create the volume
    volume_name = self.format_volume_name(self.certs_volume_name, self)
    # create volume passes even if it already exists
    self.log.info("Creating ssl volume %s for %s", volume_name, self._log_name)
    yield self.docker('create_volume', volume_name)
    # create a tar archive of the internal cert files
    # docker.put_archive takes a tarfile and a running container
    # and unpacks the archive into the container
    nb_paths = {}
    tar_buf = BytesIO()
    archive = TarFile(fileobj=tar_buf, mode='w')
    for key, hub_path in paths.items():
        fname = os.path.basename(hub_path)
        nb_paths[key] = '/certs/' + fname
        with open(hub_path, 'rb') as f:
            content = f.read()
        tarinfo = TarInfo(name=fname)
        tarinfo.size = len(content)
        # keep the original file's mtime; make the certs world-readable
        tarinfo.mtime = os.stat(hub_path).st_mtime
        tarinfo.mode = 0o644
        archive.addfile(tarinfo, BytesIO(content))
    archive.close()
    tar_buf.seek(0)
    # run a container to stage the certs,
    # mounting the volume at /certs/
    host_config = self.client.create_host_config(
        binds={
            volume_name: {"bind": "/certs", "mode": "rw"},
        },
    )
    container = yield self.docker('create_container',
        self.move_certs_image,
        volumes=["/certs"],
        host_config=host_config,
    )
    container_id = container['Id']
    self.log.debug(
        "Container %s is creating ssl certs for %s",
        container_id[:12],
        self._log_name,
    )
    # start the container
    yield self.docker('start', container_id)
    # stage the archive to the container
    try:
        yield self.docker(
            'put_archive',
            container=container_id,
            path='/certs',
            data=tar_buf,
        )
    finally:
        # always clean up the staging container, even if put_archive fails
        yield self.docker('remove_container', container_id)
    return nb_paths
def generate_tar(entries):
    """Build an in-memory tar archive from a mapping of path -> bytes.

    :param entries: mapping of archive member names to their byte contents.
    :return: a fresh BytesIO positioned at 0 containing the finished archive.
    """
    tar_buf = BytesIO()
    tar_file = TarFile(mode="w", fileobj=tar_buf)
    for path, contents in entries.items():
        tar_info = TarInfo(name=path)
        tar_info.size = len(contents)
        tar_file.addfile(tar_info, fileobj=BytesIO(contents))
    # BUG FIX: close the archive so the end-of-archive marker is written;
    # otherwise the returned buffer holds a truncated tar.
    tar_file.close()
    return BytesIO(tar_buf.getvalue())
def add_string_to_tarfile(self, tar: tarfile.TarFile, name: str, string: str):
    """Store *string* (UTF-8 encoded) in *tar* under *name*, stamped with the current time."""
    payload = string.encode('utf-8')
    entry = tarfile.TarInfo(name=name)
    entry.size = len(payload)
    entry.mtime = time.time()
    tar.addfile(tarinfo=entry, fileobj=io.BytesIO(payload))
def generate_tar(entries):
    '''Build an in-memory tar archive from a mapping of path -> bytes.

    :param entries: mapping of archive member names to their byte contents.
    :return: a fresh BytesIO positioned at 0 containing the finished archive.
    '''
    tar_buf = BytesIO()
    tar_file = TarFile(mode='w', fileobj=tar_buf)
    for path, contents in entries.items():
        tar_info = TarInfo(name=path)
        tar_info.size = len(contents)
        tar_file.addfile(tar_info, fileobj=BytesIO(contents))
    # BUG FIX: close the archive so the end-of-archive marker is written;
    # otherwise the returned buffer holds a truncated tar.
    tar_file.close()
    return BytesIO(tar_buf.getvalue())
def write_tardata(output: tarfile.TarFile, name: str, data: bytes):
    """Write *data* into *output* as a regular file *name* (mode 0o644, current mtime)."""
    entry = tarfile.TarInfo()
    entry.name = name
    entry.size = len(data)
    entry.mtime = time.time()
    entry.mode = 0o644  # rw-r--r--, same value as int('644', base=8)
    output.addfile(entry, io.BytesIO(data))
def _unpack_info_file(self, tar: TarFile, member: TarInfo, fileobj: io.BytesIO):
    """Re-home an info/control entry under ./var/lib/dpkg/info/<package>.<name> and add it to *tar*."""
    info_dir = Path("var", "lib", "dpkg", "info").as_posix()
    basename = member.name.lstrip("./")
    member.name = f"./{info_dir}/{self.package.name}.{basename}"
    tar.addfile(member, fileobj)
def compactar(idfiltro):
    """Pack every document of the Filtro *idfiltro* into a tar archive and
    store it in the model's ``saida`` field, updating progress as it goes.

    Does nothing when the filter is already being processed by an executor.
    Member names are ``<classe>/<tipo>-<numero>-<ordem>.txt``, where *ordem*
    disambiguates repeated document numbers.
    """
    m_filtro = Filtro.objects.get(pk=idfiltro)
    if m_filtro.situacao in SITUACOES_EXECUTORES:
        return
    # situacao "6" = compressing; reset the progress counter.
    m_filtro.situacao = "6"
    m_filtro.percentual_atual = 0
    m_filtro.save()
    slug_classificador = slugify(m_filtro.nome)
    documentos = m_filtro.documento_set.all()
    qtd_documentos = documentos.count()
    # Build the stream file in memory.
    nometar = "%s.tar.bz2" % slug_classificador
    # Occurrence count per document number, so duplicates get unique names.
    numeros_documentos = defaultdict(int)
    with BytesIO() as arquivotar:
        # NOTE(review): despite the ``.tar.bz2`` name the archive is written
        # uncompressed (mode "w"); kept as-is to preserve the output format.
        tar = TarFile(name=nometar, mode="w", fileobj=arquivotar)
        for contador, documento in enumerate(documentos):
            numero = documento.numero
            numeros_documentos[numero] += 1
            ordem = numeros_documentos[numero]
            with BytesIO() as conteudo_documento:
                conteudo_documento.write(
                    documento.conteudo.encode("latin1", "ignore"))
                conteudo_documento.seek(0)
                if documento.classe_filtro:
                    classe = slugify(documento.classe_filtro.nome)
                else:
                    classe = slugify("Não Identificado")
                if documento.tipo_movimento:
                    tipo = slugify(documento.tipo_movimento.nome)
                else:
                    tipo = "documento"
                tarinfo = TarInfo(name="%s/%s-%s-%s.txt"
                                  % (classe, tipo, numero, ordem))
                tarinfo.size = len(conteudo_documento.getvalue())
                conteudo_documento.seek(0)
                tar.addfile(fileobj=conteudo_documento, tarinfo=tarinfo)
            m_filtro.percentual_atual = contador / qtd_documentos * 100
            logger.info("Percentual %s" % m_filtro.percentual_atual)
            m_filtro.save()
        # BUG FIX: close the archive before saving it so the end-of-archive
        # blocks are written; previously the stored tar was truncated.
        tar.close()
        arquivotar.seek(0)
        m_filtro.saida.save(nometar, File(arquivotar))
    # situacao "7" = finished.
    m_filtro.situacao = "7"
    m_filtro.save()
def create_nested_tar(tar: tarfile.TarFile):
    """Add each (path, size) test file from TAR_FILES to *tar* and yield
    ``(fileobj, path, size, mtime)`` for every member added."""
    for member_path, member_size in TAR_FILES:
        entry = tarfile.TarInfo()
        entry.mtime = TIME
        entry.size = member_size
        entry.path = member_path
        payload = ReversibleTestFile(member_size)
        tar.addfile(entry, payload)
        yield payload, member_path, member_size, TIME
def add_bytes_to_tar(tar: TarFile, filename: str, data: bytes):
    """Store the byte string *data* in *tar* under *filename*."""
    entry = TarInfo(name=filename)
    entry.size = len(data)
    payload = BytesIO()
    payload.write(data)
    payload.flush()
    payload.seek(0)
    tar.addfile(entry, payload)
def add_to_tar(tar: TarFile, data: bytes, filename: str):
    """Add *data* to *tar* as a root-owned regular file named *filename*."""
    entry = TarInfo(name=filename)
    entry.size = len(data)
    # Naive UTC time run through timestamp() (local-time interpretation),
    # exactly as before.
    entry.mtime = int(datetime.timestamp(datetime.utcnow()))
    entry.mode = 436  # 0o664: rw-rw-r--
    entry.type = b'0'  # regular file (tarfile.REGTYPE)
    entry.uid = entry.gid = 0
    entry.uname = entry.gname = "0"
    tar.addfile(entry, BytesIO(data))
def add_data_as_file(tf: tarfile.TarFile, arcname: str, data: Union[str, bytes]) -> tarfile.TarInfo:
    """Add *data* (str is UTF-8 encoded) to *tf* under *arcname*.

    The entry gets a fixed mtime of 0 and is passed through
    normalize_tarinfo() before being written.  Returns the TarInfo used.
    """
    entry = tarfile.TarInfo(arcname)
    entry.mtime = 0
    entry.type = tarfile.REGTYPE
    payload = data.encode('utf-8') if isinstance(data, str) else data
    entry.size = len(payload)
    normalize_tarinfo(entry)
    tf.addfile(entry, io.BytesIO(payload))
    return entry
def tar_add_file(tar: tarfile.TarFile, filename: str):
    '''Add filename to tarfile, file size expected to be invalid'''
    try:
        # Slurp the whole file so the entry size reflects what was read.
        with open(filename, "rb") as fp:
            contents = fp.read()
    except OSError as e:
        print(f"ERROR: (unknown): {e}", file=sys.stderr)
        return
    entry = tarfile.TarInfo(filename)
    # Archive members must be relative; drop a single leading slash.
    if entry.name[0] == '/':
        entry.name = entry.name[1:]
    entry.size = len(contents)
    tar.addfile(entry, io.BytesIO(contents))
def _add_pyproject(self, tar: tarfile.TarFile, tar_dir: str) -> None:
    """Rewrites the pyproject.toml before adding to tarball.

    This is mainly aiming at fixing the version number in pyproject.toml
    """
    doc = toml.loads(self.meta.filepath.read_text("utf-8"))
    # Pin a non-string (dynamic) version to the resolved value.
    if not isinstance(self.meta._metadata.get("version", ""), str):
        self.meta._metadata["version"] = self.meta.version
    doc["project"] = self.meta._metadata
    basename = self.meta.filepath.name
    info = tar.gettarinfo(basename, os.path.join(tar_dir, basename))
    payload = io.BytesIO(toml.dumps(doc).encode("utf-8"))
    info.size = len(payload.getvalue())
    tar.addfile(info, payload)
def replace_or_append_file_to_layer(file_to_replace: str,
                                    content_or_path: bytes,
                                    img: tarfile.TarFile):
    """Store *content_or_path* in *img* at *file_to_replace*.

    When the value names an existing filesystem path the file is added from
    disk; otherwise the value itself is treated as the raw file contents.
    """
    if os.path.exists(content_or_path):
        # It is a path on disk: archive the real file.
        img.add(content_or_path, file_to_replace)
    else:
        # It is raw content: wrap it in an in-memory entry.
        entry = tarfile.TarInfo(file_to_replace)
        entry.size = len(content_or_path)
        img.addfile(entry, io.BytesIO(content_or_path))
def add_file(tar: tarfile.TarFile, file_name: str) -> Tuple[int, int, datetime, Optional[str]]:
    """Append *file_name* to *tar*, streaming the payload manually.

    The header is emitted via TarFile.addfile() and the file body is then
    written straight to the archive's underlying fileobj in BLOCK_SIZE
    chunks (NUL-padded to a full tar block) while computing its MD5, so the
    file never has to fit in memory.  ``tar.offset`` is advanced by hand to
    keep the TarFile's internal position consistent with the bytes written.

    :return: (offset of the member header in the archive, payload size,
        mtime as a naive UTC datetime, hex MD5 of the payload -- or None for
        members with no payload, such as directories and symlinks).
    """
    # FIXME: error: "TarFile" has no attribute "offset"
    offset: int = tar.offset  # type: ignore
    tarinfo: tarfile.TarInfo = tar.gettarinfo(file_name)
    # Change the size of any hardlinks from 0 to the size of the actual file
    if tarinfo.islnk():
        tarinfo.size = os.path.getsize(file_name)
    # Add the file to the tar (header only; payload is streamed below)
    tar.addfile(tarinfo)
    md5: Optional[str] = None
    # Only add files or hardlinks.
    # (So don't add directories or softlinks.)
    if tarinfo.isfile() or tarinfo.islnk():
        f: _io.BufferedReader = open(file_name, "rb")
        hash_md5: _hashlib.HASH = hashlib.md5()
        if tar.fileobj is not None:
            fileobj: _io.BufferedWriter = tar.fileobj
        else:
            raise TypeError("Invalid tar.fileobj={}".format(tar.fileobj))
        while True:
            s: bytes = f.read(BLOCK_SIZE)
            if len(s) > 0:
                # If the block read in is non-empty, write it to fileobj and update the hash
                fileobj.write(s)
                hash_md5.update(s)
            if len(s) < BLOCK_SIZE:
                # If the block read in is smaller than BLOCK_SIZE,
                # then we have reached the end of the file.
                # blocks = how many blocks of tarfile.BLOCKSIZE fit in tarinfo.size
                # remainder = how much more content is required to reach tarinfo.size
                blocks: int
                remainder: int
                blocks, remainder = divmod(tarinfo.size, tarfile.BLOCKSIZE)
                if remainder > 0:
                    null_bytes: bytes = tarfile.NUL
                    # Write null_bytes to get the last block to tarfile.BLOCKSIZE
                    fileobj.write(null_bytes * (tarfile.BLOCKSIZE - remainder))
                    blocks += 1
                # Increase the offset by the amount already saved to the tar
                # FIXME: error: "TarFile" has no attribute "offset"
                tar.offset += blocks * tarfile.BLOCKSIZE  # type: ignore
                break
        f.close()
        md5 = hash_md5.hexdigest()
    size: int = tarinfo.size
    mtime: datetime = datetime.utcfromtimestamp(tarinfo.mtime)
    return offset, size, mtime, md5
def add_to_tarfile(tar: tarfile.TarFile, fetch_result: FetchResult):
    """
    Adds an entity fetched from object storage to a tarfile.

    The name (path), size, and mtime of the record in the tarfile will be
    set to match the information retrieved from object storage.

    :param tar: The tarfile that should be appended to
    :param fetch_result: The metadata and payload fetched from object storage
    """
    # Always release the payload stream, even if addfile() fails.
    with closing(fetch_result.payload):
        entry = tarfile.TarInfo()
        entry.name = fetch_result.info.path
        entry.size = fetch_result.info.size
        entry.mtime = fetch_result.info.last_modified.timestamp()
        tar.addfile(entry, fetch_result.payload)
def _add_file(
    archive: tarfile.TarFile, name: str, mode: int, epoch: int, data: bytes
) -> None:
    """
    Add an in-memory file into a tar archive.

    :param archive: archive to append to
    :param name: name of the file to add
    :param mode: permissions of the file
    :param epoch: fixed modification time to set
    :param data: file contents
    """
    entry = tarfile.TarInfo("./" + name)
    entry.mode = mode
    entry.size = len(data)
    # _clean_info() pins the remaining metadata (including the epoch-based
    # mtime) to reproducible values.
    archive.addfile(_clean_info(None, epoch, entry), BytesIO(data))
def _add_virtual_file_to_archive(cls, archive: tarfile.TarFile, filename: str,
                                 filedata: dict) -> None:
    """
    Add filedata to a stream of in-memory bytes and add these bytes to the archive.

    Args:
        archive (TarFile): The archive object to add the virtual file to.
        filename (str): The name of the virtual file.
        filedata (dict): The data to add to the bytes stream.
    """
    payload = json.dumps(filedata).encode("utf-8")
    tarinfo = tarfile.TarInfo(filename)
    # BUG FIX: size must be the encoded byte length, not the character count
    # of the JSON string -- the two differ whenever non-ASCII bytes appear.
    tarinfo.size = len(payload)
    archive.addfile(tarinfo, BytesIO(payload))
def create_archive(filepaths):
    """Bundle the given files into an in-memory (uncompressed) tar stream.

    Each file is stored under its basename with the current time as mtime.

    :param filepaths: iterable of paths to include.
    :return: a BytesIO holding the finished archive, rewound to position 0.
    """
    tarstream = BytesIO()
    archive = TarFile(fileobj=tarstream, mode='w')
    for filepath in filepaths:
        # BUG FIX: read in binary mode -- the previous text-mode read
        # produced a str whose length need not equal the byte size, and
        # BytesIO() rejects str entirely on Python 3.  Also close the file
        # deterministically via the context manager.
        with open(filepath, 'rb') as f:
            file_data = f.read()
        tarinfo = TarInfo(name=basename(filepath))
        tarinfo.size = len(file_data)
        tarinfo.mtime = time()
        archive.addfile(tarinfo, BytesIO(file_data))
    archive.close()
    tarstream.seek(0)
    return tarstream
def _write_result_part_to_tarfile(tf: tarfile.TarFile, result: ResultPart) -> None:
    """Append bytes to `tf`."""
    entry = tarfile.TarInfo(result.name)
    entry.size = len(result.body)
    # Attach fetch metadata as PAX extended headers when present.
    attr_to_header = (
        ("mtime", "mtime"),
        ("api_endpoint", "cjw:apiEndpoint"),
        ("api_params", "cjw:apiParams"),
        ("http_status", "cjw:httpStatus"),
        ("n_tweets", "cjw:nTweets"),
    )
    for attr_name, header_name in attr_to_header:
        value = getattr(result, attr_name)
        if value is not None:
            entry.pax_headers[header_name] = str(value)
    tf.addfile(entry, io.BytesIO(result.body))
def run(self, args, argv):
    """Rebuild the Docker image *args.tag* in place: pack the configured
    emulator binary (at its binfmt_misc path, when available) plus its
    shared-library dependencies into a build context and re-run a minimal
    ``ADD . /`` Dockerfile build on top of the existing tag.

    :return: 0 on success, 1 when binfmt_misc support is not enabled.
    """
    # Create a temporary tarball with our whole build context and
    # dockerfile for the update
    tmp = tempfile.NamedTemporaryFile(suffix="dckr.tar.gz")
    tmp_tar = TarFile(fileobj=tmp, mode='w')

    # Add the executable to the tarball, using the current
    # configured binfmt_misc path. If we don't get a path then we
    # only need the support libraries copied
    ff, enabled = _check_binfmt_misc(args.executable)
    if not enabled:
        print("binfmt_misc not enabled, update disabled")
        return 1
    if ff:
        tmp_tar.add(args.executable, arcname=ff)

    # Add any associated libraries
    libs = _get_so_libs(args.executable)
    if libs:
        for l in libs:
            tmp_tar.add(os.path.realpath(l), arcname=l)

    # Create a Docker buildfile
    df = StringIO()
    df.write(u"FROM %s\n" % args.tag)
    df.write(u"ADD . /\n")
    # tarfile needs a byte stream, so encode the Dockerfile text first
    df_bytes = BytesIO(bytes(df.getvalue(), "UTF-8"))
    df_tar = TarInfo(name="Dockerfile")
    df_tar.size = df_bytes.getbuffer().nbytes
    tmp_tar.addfile(df_tar, fileobj=df_bytes)

    tmp_tar.close()

    # reset the file pointers
    tmp.flush()
    tmp.seek(0)

    # Run the build with our tarball context
    dkr = Docker()
    dkr.update_image(args.tag, tmp, quiet=args.quiet)
    return 0
def _add_pyproject(self, tar: tarfile.TarFile, tar_dir: str) -> None:
    """Rewrites the pyproject.toml before adding to tarball.

    This is mainly aiming at fixing the version number in pyproject.toml
    """
    with self.meta.filepath.open("rb") as src:
        doc = tomli.load(src)
    # Pin a dynamic version to the resolved value and drop it from `dynamic`.
    if self.meta.dynamic and "version" in self.meta.dynamic:
        self.meta._metadata["version"] = self.meta.version
        self.meta._metadata["dynamic"].remove("version")
    doc["project"] = self.meta._metadata
    basename = self.meta.filepath.name
    info = tar.gettarinfo(basename, os.path.join(tar_dir, basename))
    payload = io.BytesIO()
    tomli_w.dump(doc, payload)
    info.size = len(payload.getvalue())
    payload.seek(0)
    tar.addfile(info, payload)
def _add_pyproject(self, tar: tarfile.TarFile, tar_dir: str) -> None:
    """Rewrites the pyproject.toml before adding to tarball.

    This is mainly aiming at fixing the version number in pyproject.toml
    """
    content = self._meta.filepath.read_text()
    # A non-string version means it is dynamic: substitute the resolved one.
    if not isinstance(self._meta._metadata.get("version", ""), str):
        content = re.sub(
            r"^version *= *.+?$",
            f'version = "{self._meta.version}"',
            content,
            flags=re.M,
        )
    name = "pyproject.toml"
    info = tar.gettarinfo(name, os.path.join(tar_dir, name))
    payload = io.BytesIO(content.encode("utf-8"))
    info.size = len(payload.getvalue())
    tar.addfile(info, payload)
def run(self, args, argv):
    """Rebuild the Docker image *args.tag* in place with the configured
    emulator binary (at its binfmt_misc path) and its shared libraries.

    :return: 0 on success, 1 when binfmt_misc support is not enabled.
    """
    from io import BytesIO  # local: used for the Dockerfile byte stream

    # Create a temporary tarball with our whole build context and
    # dockerfile for the update
    tmp = tempfile.NamedTemporaryFile(suffix="dckr.tar.gz")
    tmp_tar = TarFile(fileobj=tmp, mode='w')

    # Add the executable to the tarball, using the current
    # configured binfmt_misc path. If we don't get a path then we
    # only need the support libraries copied
    ff, enabled = _check_binfmt_misc(args.executable)
    if not enabled:
        print("binfmt_misc not enabled, update disabled")
        return 1
    if ff:
        tmp_tar.add(args.executable, arcname=ff)

    # Add any associated libraries
    libs = _get_so_libs(args.executable)
    if libs:
        for l in libs:
            tmp_tar.add(os.path.realpath(l), arcname=l)

    # Create a Docker buildfile
    df = StringIO()
    df.write("FROM %s\n" % args.tag)
    df.write("ADD . /\n")

    # BUG FIX: StringIO has no ``buf`` attribute on Python 3, and tarfile
    # needs a byte stream; encode the Dockerfile text and measure that.
    df_bytes = BytesIO(df.getvalue().encode("utf-8"))
    df_tar = TarInfo(name="Dockerfile")
    df_tar.size = df_bytes.getbuffer().nbytes
    tmp_tar.addfile(df_tar, fileobj=df_bytes)

    tmp_tar.close()

    # reset the file pointers
    tmp.flush()
    tmp.seek(0)

    # Run the build with our tarball context
    dkr = Docker()
    dkr.update_image(args.tag, tmp, quiet=args.quiet)
    return 0
def save_to_file(self, f): tar = TarFile(f, "w") # save info file f = StringIO(repr((self.agedesc, self.generation))) info = tar.gettarinfo(None, "info.py", f) tar.addfile(info, f) f.close() # save agents for i in range(len(self.agents)): f = StringIO() self.agents[i].save_to_file(f) info = tar.gettarinfo(None, str(i) + ".agt", f) tar.addfile(info, f) f.close() tar.close()
def save_to_file(self, f): tar = TarFile(f, "w") # save info file f = StringIO(repr((self.agedesc, self.generation))) info = tar.gettarinfo(None, "info.py", f) tar.addfile(info, f) f.close() # save agents for i in range(len(self.agents)): f = StringIO() self.agents[i].save_to_file(f) info = tar.gettarinfo(None, str(i)+".agt", f) tar.addfile(info, f) f.close() tar.close()
def _unpack_data(self, tar: TarFile, data_archive: TarFile):
    """Merge *data_archive* into *tar*, after recording a dpkg-style
    ``list`` info file naming every member of the data archive."""
    listed = [m.name.lstrip(".") for m in data_archive if m.name.lstrip(".")]
    with io.BytesIO(("\n".join(listed) + "\n").encode()) as listing:
        entry = TarInfo("list")
        entry.size = listing.getbuffer().nbytes
        self._unpack_info_file(tar, entry, listing)
    # Skip members already present (e.g. the "list" entry just added).
    existing = tar.getnames()
    for member in data_archive:
        if member.name in existing:
            continue
        if member.islnk() or member.issym() or member.isdir():
            # Links and directories carry no payload.
            tar.addfile(member)
        else:
            with data_archive.extractfile(member) as payload:
                tar.addfile(member, payload)
def write(self, file_name):
    """Write this package as a gzipped tar to *file_name*.

    The archive holds ./CONTROL (JSON metadata), optionally ./INIT (the
    init script), and ./DATA (an inner tar of the data directory).
    Requires ``self.data`` to name an existing directory.
    """
    if not self.data or not os.path.isdir(self.data):
        raise Exception('Must set data before building')
    gzfile = GzipFile(file_name, 'w')
    tar = TarFile(fileobj=gzfile, mode='w')

    def _add_member(name, payload):
        # Store an in-memory stream under ./<name>.
        info = TarInfo(name='./' + name)
        info.size = payload.getbuffer().nbytes
        tar.addfile(tarinfo=info, fileobj=payload)

    _add_member('CONTROL', BytesIO(json.dumps(self.control).encode()))
    if self.init is not None:
        _add_member('INIT', BytesIO(self.init.encode()))

    # Pack the data directory into an inner tar and embed it whole.
    inner = BytesIO()
    datatar = TarFile(fileobj=inner, mode='w')
    datatar.add(self.data, '/')
    datatar.close()
    inner.seek(0)
    _add_member('DATA', inner)

    tar.close()
    gzfile.close()
def run(self, args, argv):
    """Update the Docker image *args.tag* in place with the emulator binary
    (installed as /usr/bin/<basename>) and its shared libraries.

    :return: 0 on success.
    """
    from io import BytesIO  # local: used for the Dockerfile byte stream

    # Create a temporary tarball with our whole build context and
    # dockerfile for the update
    tmp = tempfile.NamedTemporaryFile(suffix="dckr.tar.gz")
    tmp_tar = TarFile(fileobj=tmp, mode='w')

    # Add the executable to the tarball
    bn = os.path.basename(args.executable)
    ff = "/usr/bin/%s" % bn
    tmp_tar.add(args.executable, arcname=ff)

    # Add any associated libraries
    libs = _get_so_libs(args.executable)
    if libs:
        for l in libs:
            tmp_tar.add(os.path.realpath(l), arcname=l)

    # Create a Docker buildfile
    df = StringIO()
    df.write("FROM %s\n" % args.tag)
    df.write("ADD . /\n")

    # BUG FIX: StringIO has no ``buf`` attribute on Python 3, and tarfile
    # needs a byte stream; encode the Dockerfile text and measure that.
    df_bytes = BytesIO(df.getvalue().encode("utf-8"))
    df_tar = TarInfo(name="Dockerfile")
    df_tar.size = df_bytes.getbuffer().nbytes
    tmp_tar.addfile(df_tar, fileobj=df_bytes)

    tmp_tar.close()

    # reset the file pointers
    tmp.flush()
    tmp.seek(0)

    # Run the build with our tarball context
    dkr = Docker()
    dkr.update_image(args.tag, tmp, quiet=args.quiet)
    return 0
def plot_predictions(self): epoch, batch, data = self.get_next_batch(train=False) # get a test batch num_classes = self.test_data_provider.get_num_classes() NUM_ROWS = 2 NUM_COLS = 4 NUM_IMGS = NUM_ROWS * NUM_COLS if not self.save_preds else data[0].shape[1] NUM_TOP_CLASSES = min(num_classes, 5) # show this many top labels NUM_OUTPUTS = self.model_state["layers"][self.softmax_name]["outputs"] PRED_IDX = 1 label_names = [lab.split(",")[0] for lab in self.test_data_provider.batch_meta["label_names"]] if self.only_errors: preds = n.zeros((data[0].shape[1], NUM_OUTPUTS), dtype=n.single) else: preds = n.zeros((NUM_IMGS, NUM_OUTPUTS), dtype=n.single) # rand_idx = nr.permutation(n.r_[n.arange(1), n.where(data[1] == 552)[1], n.where(data[1] == 795)[1], n.where(data[1] == 449)[1], n.where(data[1] == 274)[1]])[:NUM_IMGS] rand_idx = nr.randint(0, data[0].shape[1], NUM_IMGS) if NUM_IMGS < data[0].shape[1]: data = [n.require(d[:, rand_idx], requirements="C") for d in data] # data += [preds] # Run the model print [d.shape for d in data], preds.shape self.libmodel.startFeatureWriter(data, [preds], [self.softmax_name]) IGPUModel.finish_batch(self) print preds data[0] = self.test_data_provider.get_plottable_data(data[0]) if self.save_preds: if not gfile.Exists(self.save_preds): gfile.MakeDirs(self.save_preds) preds_thresh = preds > 0.5 # Binarize predictions data[0] = data[0] * 255.0 data[0][data[0] < 0] = 0 data[0][data[0] > 255] = 255 data[0] = n.require(data[0], dtype=n.uint8) dir_name = "%s_predictions_batch_%d" % (os.path.basename(self.save_file), batch) tar_name = os.path.join(self.save_preds, "%s.tar" % dir_name) tfo = gfile.GFile(tar_name, "w") tf = TarFile(fileobj=tfo, mode="w") for img_idx in xrange(NUM_IMGS): img = data[0][img_idx, :, :, :] imsave = Image.fromarray(img) prefix = ( "CORRECT" if data[1][0, img_idx] == preds_thresh[img_idx, PRED_IDX] else "FALSE_POS" if preds_thresh[img_idx, PRED_IDX] == 1 else "FALSE_NEG" ) file_name = "%s_%.2f_%d_%05d_%d.png" % ( prefix, 
preds[img_idx, PRED_IDX], batch, img_idx, data[1][0, img_idx], ) # gf = gfile.GFile(file_name, "w") file_string = StringIO() imsave.save(file_string, "PNG") tarinf = TarInfo(os.path.join(dir_name, file_name)) tarinf.size = file_string.tell() file_string.seek(0) tf.addfile(tarinf, file_string) tf.close() tfo.close() # gf.close() print "Wrote %d prediction PNGs to %s" % (preds.shape[0], tar_name) else: fig = pl.figure(3, figsize=(12, 9)) fig.text(0.4, 0.95, "%s test samples" % ("Mistaken" if self.only_errors else "Random")) if self.only_errors: # what the net got wrong if NUM_OUTPUTS > 1: err_idx = [i for i, p in enumerate(preds.argmax(axis=1)) if p not in n.where(data[2][:, i] > 0)[0]] else: err_idx = n.where(data[1][0, :] != preds[:, 0].T)[0] print err_idx err_idx = r.sample(err_idx, min(len(err_idx), NUM_IMGS)) data[0], data[1], preds = data[0][:, err_idx], data[1][:, err_idx], preds[err_idx, :] import matplotlib.gridspec as gridspec import matplotlib.colors as colors cconv = colors.ColorConverter() gs = gridspec.GridSpec(NUM_ROWS * 2, NUM_COLS, width_ratios=[1] * NUM_COLS, height_ratios=[2, 1] * NUM_ROWS) # print data[1] for row in xrange(NUM_ROWS): for col in xrange(NUM_COLS): img_idx = row * NUM_COLS + col if data[0].shape[0] <= img_idx: break pl.subplot(gs[(row * 2) * NUM_COLS + col]) # pl.subplot(NUM_ROWS*2, NUM_COLS, row * 2 * NUM_COLS + col + 1) pl.xticks([]) pl.yticks([]) img = data[0][img_idx, :, :, :] img = img.squeeze() if len(img.shape) > 2: # more than 2 dimensions if img.shape[2] is 2: # if two channels # copy 2nd to 3rd channel for visualization a1 = img a2 = img[:, :, 1] a2 = a2[:, :, n.newaxis] img = n.concatenate((a1, a2), axis=2) pl.imshow(img, interpolation="lanczos") else: pl.imshow(img, interpolation="lanczos", cmap=pl.gray()) show_title = data[1].shape[0] == 1 true_label = [int(data[1][0, img_idx])] if show_title else n.where(data[1][:, img_idx] == 1)[0] # print true_label # print preds[img_idx,:].shape # print preds[img_idx,:].max() 
true_label_names = [label_names[i] for i in true_label] img_labels = sorted(zip(preds[img_idx, :], label_names), key=lambda x: x[0])[-NUM_TOP_CLASSES:] # print img_labels axes = pl.subplot(gs[(row * 2 + 1) * NUM_COLS + col]) height = 0.5 ylocs = n.array(range(NUM_TOP_CLASSES)) * height pl.barh( ylocs, [l[0] for l in img_labels], height=height, color=["#ffaaaa" if l[1] in true_label_names else "#aaaaff" for l in img_labels], ) # pl.title(", ".join(true_labels)) if show_title: pl.title(", ".join(true_label_names), fontsize=15, fontweight="bold") else: print true_label_names pl.yticks( ylocs + height / 2, [l[1] for l in img_labels], x=1, backgroundcolor=cconv.to_rgba("0.65", alpha=0.5), weight="bold", ) for line in enumerate(axes.get_yticklines()): line[1].set_visible(False) # pl.xticks([width], ['']) # pl.yticks([]) pl.xticks([]) pl.ylim(0, ylocs[-1] + height) pl.xlim(0, 1)