Example #1
import numpy as np
from PIL import Image


def load_chunk(tarfile, size=None):
  """Load a number of images from a single imagenet .tar file.

  This function also converts the image from grayscale to RGB if necessary.

  Args:
    tarfile (tarfile.TarFile): The archive from which the files get loaded.
    size (Optional[Tuple[int, int]]): Resize the image to this size if provided.

  Returns:
    Tuple[numpy.ndarray, List[str]]: The image data in format [batch, h, w, c]
      and the corresponding member filenames.
  """
  result = []
  filenames = []
  for member in tarfile.getmembers():
    filename = member.path
    content = tarfile.extractfile(member)
    img = Image.open(content)
    rgbimg = Image.new("RGB", img.size)
    rgbimg.paste(img)
    if size is not None:
      rgbimg = rgbimg.resize(size, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
    result.append(np.array(rgbimg)[np.newaxis, ...])  # array is already (h, w, 3)
    filenames.append(filename)
  return np.concatenate(result), filenames
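A minimal usage sketch for the function above (the archive name and target size are illustrative, not from the original):

import tarfile

with tarfile.open('imagenet_shard.tar') as tar:  # hypothetical shard name
    batch, filenames = load_chunk(tar, size=(64, 64))
    print(batch.shape)      # e.g. (N, 64, 64, 3)
    print(filenames[:3])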
Example #2
 def createArchive(self, cfgOutName, name='', mode='w:gz'):
     """
     Create the archive to upload.
     """
     import hashlib
     import os

     if not name:
         import uuid
         name = os.path.join(os.getcwd(), str(uuid.uuid4()) + 'default.tgz')

     import tarfile
     print('opening tar file')
     tar = tarfile.open(name=name, mode=mode, dereference=True)
     print('adding %s to the tarball' % cfgOutName)
     tar.add(cfgOutName, arcname='PSet.py')

     # Checksum over the member metadata (name, size, mtime, owner).
     print('calculating the checksum')
     lsl = [(x.name, int(x.size), int(x.mtime), x.uname) for x in tar.getmembers()]
     hasher = hashlib.sha256(str(lsl).encode('utf-8'))
     checksum = hasher.hexdigest()
     tar.close()

     return name, checksum
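The metadata-based checksum is the reusable idea here; a standalone sketch of the same pattern (the path argument is illustrative):

import hashlib
import tarfile

def archive_checksum(path):
    # Hash the (name, size, mtime, owner) tuples of all members, mirroring
    # the checksum computed by createArchive above.
    with tarfile.open(path) as tar:
        lsl = [(m.name, int(m.size), int(m.mtime), m.uname) for m in tar.getmembers()]
    return hashlib.sha256(str(lsl).encode('utf-8')).hexdigest()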
Example #3
 def get_categories(self, tarfile):
     catname = re.compile(r"/(.+)\.rules$")
     for member in tarfile.getmembers():
         if member.name.endswith('.rules'):
             match = catname.search(member.name)
             if match is None:  # top-level .rules file with no directory part
                 continue
             name = match.group(1)
             category = Category.objects.filter(source=self, name=name)
             if not category:
                 category = Category.objects.create(source=self,
                                                    name=name,
                                                    created_date=timezone.now(),
                                                    filename=member.name)
                 category.get_rules(self)
             else:
                 category[0].get_rules(self)
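The regex captures everything after the first slash up to the trailing .rules suffix; a quick isolated check (the member name is made up):

import re

catname = re.compile(r"/(.+)\.rules$")
match = catname.search("rules/emerging-malware.rules")
if match is not None:
    print(match.group(1))  # -> emerging-malware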
Example #4
def _extract_tar_info(tarfile, class_to_idx=None, sort=True):
    files = []
    labels = []
    for ti in tarfile.getmembers():
        if not ti.isfile():
            continue
        dirname, basename = os.path.split(ti.path)
        label = os.path.basename(dirname)
        ext = os.path.splitext(basename)[1]
        if ext.lower() in IMG_EXTENSIONS:
            files.append(ti)
            labels.append(label)
    if class_to_idx is None:
        unique_labels = set(labels)
        sorted_labels = list(sorted(unique_labels, key=natural_key))
        class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
    tarinfo_and_targets = [(f, class_to_idx[l]) for f, l in zip(files, labels)
                           if l in class_to_idx]
    if sort:
        tarinfo_and_targets = sorted(tarinfo_and_targets,
                                     key=lambda k: natural_key(k[0].path))
    return tarinfo_and_targets, class_to_idx
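A hedged driver for the helper above; IMG_EXTENSIONS and natural_key are assumed to be defined by the surrounding module, so simplistic stand-ins are used here:

import tarfile

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')   # stand-in for the module constant
natural_key = str.lower                      # stand-in for real natural sorting

with tarfile.open('train_images.tar') as tar:  # hypothetical archive
    samples, class_to_idx = _extract_tar_info(tar)
    print(len(samples), 'images across', len(class_to_idx), 'classes')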
Example #5
def _analyze_tarfile_for_import(tarfile, project, schema, tmpdir):
    """Validate paths in tarfile.

    Parameters
    ----------
    tarfile : :class:`tarfile.TarFile`
        tarfile to analyze.
    project : :class:`~signac.Project`
        The project to import the data into.
    schema : str or callable
        An optional schema function, which is either a string or a function that accepts a
        path as its first and only argument and returns the corresponding state point as dict.
        (Default value = None).
    tmpdir : :class:`tempfile.TemporaryDirectory`
        Temporary directory, an instance of ``TemporaryDirectory``.

    Yields
    ------
    src : str
        Source path.
    copy_executor : callable
        A callable that uses a provided function to copy to a destination.

    Raises
    ------
    TypeError
        If the schema given is not None, callable, or a string.
    :class:`~signac.errors.DestinationExistsError`
        If a job is already initialized.
    :class:`~signac.errors.StatepointParsingError`
        If the jobs identified with the given schema function are not unique.
    AssertionError
        If ``tmpdir`` given is not a directory.

    """

    def read_sp_manifest_file(path):
        """Read state point from the manifest file.

        Parameters
        ----------
        path : str
            Path to manifest file.

        Returns
        -------
        dict
            state point.

        """
        # Must use forward slashes, not os.path.sep.
        fn_manifest = _tarfile_path_join(path, project.Job.FN_MANIFEST)
        try:
            with closing(tarfile.extractfile(fn_manifest)) as file:
                return json.loads(file.read())
        except KeyError:
            pass

    if schema is None:
        schema_function = read_sp_manifest_file
    elif callable(schema):
        schema_function = _with_consistency_check(schema, read_sp_manifest_file)
    elif isinstance(schema, str):
        schema_function = _with_consistency_check(
            _make_path_based_schema_function(schema), read_sp_manifest_file
        )
    else:
        raise TypeError("The schema variable must be None, callable, or a string.")

    mappings = {}
    skip_subdirs = set()

    dirs = [member.name for member in tarfile.getmembers() if member.isdir()]
    for name in sorted(dirs):
        if (
            os.path.dirname(name) in skip_subdirs
        ):  # skip all sub-dirs of identified dirs
            skip_subdirs.add(name)
            continue

        sp = schema_function(name)
        if sp is not None:
            job = project.open_job(sp)
            if os.path.exists(job.workspace()):
                raise DestinationExistsError(job)
            mappings[name] = job
            skip_subdirs.add(name)

    # Check uniqueness
    if len(set(mappings.values())) != len(mappings):
        raise StatepointParsingError(
            "The jobs identified with the given schema function are not unique!"
        )

    tarfile.extractall(path=tmpdir)
    for path, job in mappings.items():
        assert os.path.isdir(tmpdir)
        src = os.path.join(tmpdir, path)
        assert os.path.isdir(src)
        copy_executor = _CopyFromTarFileExecutor(src, job)
        yield src, copy_executor
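The skip_subdirs bookkeeping above works because sorted() yields a parent directory before any of its children; a stripped-down sketch of just that traversal pattern (is_job_dir is a hypothetical predicate):

import os

def top_level_matches(tar, is_job_dir):
    # Visit directories in sorted order so parents precede children; once a
    # directory matches (or its parent was skipped), its whole subtree is skipped.
    matched, skip = [], set()
    for name in sorted(m.name for m in tar.getmembers() if m.isdir()):
        if os.path.dirname(name) in skip:
            skip.add(name)
            continue
        if is_job_dir(name):
            matched.append(name)
            skip.add(name)
    return matched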
Example #6
import os
import tarfile
import urllib.request

import tensorflow as tf

# Path to frozen detection graph. This is the actual model that is used for the object detection.
path_to_model = model + '/frozen_inference_graph.pb'

# List of the strings that is used to add correct label for each box.
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

#number of classes to be identified
NUM_CLASSES = 90

# Download the model archive and extract the frozen graph
# (urllib.request.urlretrieve replaces the deprecated URLopener).
urllib.request.urlretrieve(download_url + model_tar, model_tar)
tar = tarfile.open(model_tar)  # renamed to avoid shadowing the tarfile module
for member in tar.getmembers():
    file_name = os.path.basename(member.name)
    if 'frozen_inference_graph.pb' in file_name:
        tar.extract(member, os.getcwd())

# ## Loading this TensorFlow model into memory

# Note: this uses the TensorFlow 1.x API (tf.GraphDef, tf.gfile); under
# TensorFlow 2 the equivalents live in tf.compat.v1.
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(path_to_model, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
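# For context, a hedged sketch of running inference with the graph loaded
# above; the tensor names follow the TensorFlow Object Detection API
# convention, and the zero image is only a placeholder input.
import numpy as np

with tf.Session(graph=detection_graph) as sess:
    image = np.zeros((1, 300, 300, 3), dtype=np.uint8)  # placeholder batch
    image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
    boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
    scores = detection_graph.get_tensor_by_name('detection_scores:0')
    boxes_out, scores_out = sess.run([boxes, scores],
                                     feed_dict={image_tensor: image})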

# ## Loading label map
Example #7
    def _package_chart(self, tarfile, version=None, **kwargs):
        '''Internal Helper

        Internal method that makes it easier to handle closing
        the tarfile passed in automatically on exit.
        '''
        def get_data(filename):
            membername = os.path.join(self.name, filename)
            yaml = tarfile.extractfile(membername)
            return membername, ruamel.yaml.load(
                yaml, Loader=ruamel.yaml.RoundTripLoader)

        chart_file, chart_data = get_data('Chart.yaml')
        chart_data['version'] = version

        values_file, values_data = get_data('values.yaml')
        values = self.data.get('values', None)
        if values:
            # TODO(kerrin) expand the amount of data available
            # for users to control
            data = {
                'version': version,
                'name': self.name,
            }
            data.update(kwargs)

            def expand_values(source, expanded):
                for key, value in source.items():
                    if isinstance(value, dict):
                        try:
                            expand_values(value, expanded[key])
                        except KeyError as e:
                            raise windlass.exc.MissingEntryInChartValues(
                                expected_source=source,
                                missing_key=e.args[0],
                                values_filename=values_file,
                                chart_name=self.name)
                    else:
                        newvalue = value.format(**data)
                        expanded[key] = newvalue

            # Update by reference the values_data dictionary based on
            # the format of the supplied values field.
            expand_values(values, values_data)

        with tempfile.NamedTemporaryFile() as tmp_file:
            # TarFile.open is a classmethod, so calling it on the TarFile
            # instance passed in as ``tarfile`` opens a new archive.
            with tarfile.open(tmp_file.name, 'w:gz') as out:
                for member in tarfile.getmembers():
                    if member.name == chart_file:
                        # Override the size of the file
                        datastr = ruamel.yaml.dump(
                            chart_data, Dumper=ruamel.yaml.RoundTripDumper)
                        databytes = datastr.encode('utf-8')
                        member.size = len(databytes)
                        out.addfile(member, io.BytesIO(databytes))
                    elif member.name == values_file:
                        # Override the size of the file
                        datastr = ruamel.yaml.dump(
                            values_data, Dumper=ruamel.yaml.RoundTripDumper)
                        databytes = datastr.encode('utf-8')
                        member.size = len(databytes)
                        out.addfile(member, io.BytesIO(databytes))
                    else:
                        out.addfile(member, tarfile.extractfile(member.name))

            with open(tmp_file.name, 'rb') as fp:
                return fp.read()
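The size override is the crux when swapping a member's content inside a tar stream; a minimal standalone sketch of the same pattern (paths and the target member name are illustrative):

import io
import tarfile

def replace_member(src_path, dst_path, target, new_bytes):
    # Copy src to dst, substituting ``target``'s payload; the TarInfo size
    # must be updated to match the replacement bytes.
    with tarfile.open(src_path) as src, tarfile.open(dst_path, 'w:gz') as dst:
        for member in src.getmembers():
            if member.name == target:
                member.size = len(new_bytes)
                dst.addfile(member, io.BytesIO(new_bytes))
            else:
                dst.addfile(member, src.extractfile(member))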