Esempio n. 1
0
def write_batches(args, macrodir, datadir, val_pct):
    """
    Write image batches into macrodir unless that path already exists.

    Arguments:
        args: object whose ``data_dir`` attribute is the root directory that
              ``datadir`` is joined onto.
        macrodir: output directory handed to BatchWriter; when it already
                  exists the function returns without doing anything.
        datadir: image directory, relative to ``args.data_dir``.
        val_pct: forwarded to BatchWriter as ``validation_pct``.
    """
    # An existing output path is treated as "already written": nothing to do.
    if not os.path.exists(macrodir):
        print('Writing batches to %s' % macrodir)
        source_images = os.path.join(args.data_dir, datadir)
        writer = BatchWriter(out_dir=macrodir,
                             image_dir=source_images,
                             target_size=32,
                             macro_size=1024,
                             file_pattern='*.jpg',
                             validation_pct=val_pct)
        writer.run()
Esempio n. 2
0
    def load(self, backend=None, experiment=None):
        '''
        Load batch metadata for this imageset from the cache in save_dir.

        Imageset only supports nervanagpu based backends.

        If ``dataset_cache.pkl`` is not found under the user-expanded
        ``self.save_dir``, the user is prompted on stdin: answering ``Y``
        writes the batches first.  In either case the process then exits via
        ``sys.exit()`` and has to be rerun to train.

        Otherwise the cached statistics are read, the highest train and
        validation batch indexes are derived from them, and any cached entry
        that is not already an instance attribute is copied onto ``self``.

        Arguments:
            backend: not used by this method; the backend check is made
                     against ``self.backend``.
            experiment: not used by this method.

        Raises:
            DeprecationWarning: if ``self.backend`` has no ``ng`` attribute.
            NotImplementedError: if the cached macro size differs from
                                 ``self.macro_size``.
            SystemExit: if the cache file did not exist.
        '''
        if not hasattr(self.backend, 'ng'):
            raise DeprecationWarning("Only nervanagpu-based backends "
                                     "supported.  For using cudanet backend, "
                                     "revert to neon 0.8.2 ")

        bdir = os.path.expanduser(self.save_dir)
        cachefile = os.path.join(bdir, 'dataset_cache.pkl')
        if not os.path.exists(cachefile):
            logger.error("Batch dir cache not found in %s:", cachefile)
            response = raw_input("Press Y to create, otherwise exit: ")
            if response == 'Y':
                from neon.util.batch_writer import (BatchWriter,
                                                    BatchWriterImagenet)

                # The writer is configured from this object's own attributes.
                if self.imageset.startswith('I1K'):
                    self.bw = BatchWriterImagenet(**self.__dict__)
                else:
                    self.bw = BatchWriter(**self.__dict__)
                self.bw.run()
                logger.error('Done writing batches - please rerun to train.')
            else:
                logger.error('Exiting...')
            # Exit whether or not batches were written: the current run never
            # continues on to load a cache that was missing at startup.
            sys.exit()

        # NOTE(review): the cache is a .pkl file, so deserialize presumably
        # unpickles it - only point save_dir at trusted locations; confirm.
        cstats = deserialize(cachefile, verbose=False)
        if cstats['macro_size'] != self.macro_size:
            # Exceptions do not interpolate their arguments the way logging
            # calls do, so the message is formatted before raising.  %s is
            # used so an unexpected cached type cannot raise a TypeError that
            # hides the real problem.
            raise NotImplementedError(
                "Cached macro size %s different from specified %s, "
                "delete save_dir %s and try again."
                % (cstats['macro_size'], self.macro_size, self.save_dir))

        # Set the max indexes of batches for each from the cache file
        self.maxval = cstats['nval'] + cstats['val_start'] - 1
        self.maxtrain = cstats['ntrain'] + cstats['train_start'] - 1

        # Instance attributes take precedence: overlay them on the cached
        # values, then write the merged result back, so only properties that
        # were not already set on the instance are filled in from the cache.
        cstats.update(self.__dict__)
        self.__dict__.update(cstats)
        # Should also put (in addition to nclass), number of train/val images
        req_param(self, ['ntrain', 'nval', 'train_start', 'val_start',
                         'train_mean', 'val_mean', 'labels_dict'])