コード例 #1
0
    def save_snapshot(self):
        """
        Save a snapshot of the model parameters if the current epoch is due.

        ``self.serialize_schedule`` decides whether ``self.epochs_complete``
        is due:

        * ``None``: never save.
        * a list or tuple of ints: save when ``epochs_complete`` is listed.
        * a positive int: save when ``epochs_complete`` is a multiple of it.

        When due, ``self.get_params()`` is serialized to
        ``self.serialized_path``.  Misconfiguration (no serialize path,
        schedule of the wrong type, non-positive interval) is logged and
        nothing is written.
        """
        schedule = self.serialize_schedule
        if schedule is None:
            return

        if not hasattr(self, 'serialized_path'):
            logger.error('Serialize schedule specified, but no serialize '
                         'path provided, not saving')
            return

        if isinstance(schedule, (list, tuple)):
            dosave = self.epochs_complete in schedule
        elif isinstance(schedule, int):
            # an interval of 0 would raise ZeroDivisionError in the modulo
            if schedule <= 0:
                logger.error('Serialize schedule interval must be a positive '
                             'int, got %d, not saving', schedule)
                return
            dosave = self.epochs_complete % schedule == 0
        else:
            logger.error('Serialize schedule must be a list of epochs or a '
                         'single int indicating interval between save epochs')
            return

        if dosave:
            serialize(self.get_params(), self.serialized_path)
コード例 #2
0
ファイル: batch_writer.py プロジェクト: zz119/neon
    def write_batches(self,
                      name,
                      start,
                      labels,
                      imfiles,
                      targets=None,
                      is_tar=False):
        """
        Serialize ``imfiles`` as macro batches and store their mean image.

        The inputs are cut into consecutive chunks of ``self.batch_size``
        items.  Chunk ``i`` is mapped through ``proc_img`` in a process pool
        and written, together with its labels and targets, to
        ``<self.out_dir>/data_batch_<start + i>``.  Per-batch mean images are
        accumulated, weighted by how full each batch is, and the overall mean
        is written in place into ``self.train_mean`` when ``name == 'train'``
        and into ``self.val_mean`` otherwise.

        Arguments:
            name (str): set name; used in log messages and to pick the mean
                buffer.
            start (int): number of the first batch file to write.
            labels (dict): maps label names to sliceable sequences aligned
                with ``imfiles``.
            imfiles (list): image paths, or file-like objects whose
                ``read()`` yields the image bytes when ``is_tar`` is True.
            targets (optional): sliceable array aligned with ``imfiles``
                along axis 0; each batch slice is transposed and copied.
            is_tar (bool): read raw bytes from the ``imfiles`` entries and
                pass those to ``proc_img`` instead of the entries themselves.

        If ``imfiles`` is empty no batches are written and the mean buffer is
        left unchanged.
        """
        psz = self.batch_size
        osz = self.output_image_size
        nimgs = len(imfiles)
        # ceil(nimgs / psz); floor division keeps this an int for range()
        npts = (nimgs + psz - 1) // psz

        imfiles = [imfiles[i * psz:(i + 1) * psz] for i in range(npts)]

        if targets is not None:
            targets = [
                targets[i * psz:(i + 1) * psz].T.copy() for i in range(npts)
            ]

        # items() (unlike the Python 2-only iteritems()) works on 2 and 3
        labels = [{k: v[i * psz:(i + 1) * psz]
                   for k, v in labels.items()} for i in range(npts)]

        # Accumulate in floating point: scaling the uint8 batch mean in place
        # truncated (and, with Python 2 integer division, zeroed) the
        # contribution of an undersized final batch.
        accum_buf = np.zeros((osz, osz, self.num_channels), dtype=np.float64)
        batch_mean = np.zeros(accum_buf.shape, dtype=np.uint8)
        logger.info("Writing %s batches...", name)
        pool = Pool(processes=self.num_workers)
        try:
            for i, jpeg_file_batch in enumerate(imfiles):
                t = time()
                if is_tar:
                    jpeg_file_batch = [j.read() for j in jpeg_file_batch]
                jpeg_strings = pool.map(
                    functools.partial(proc_img, is_string=is_tar),
                    jpeg_file_batch)
                targets_batch = None if targets is None else targets[i]
                labels_batch = labels[i]
                bfile = os.path.join(self.out_dir,
                                     'data_batch_%d' % (start + i))
                serialize(
                    {
                        'data': jpeg_strings,
                        'labels': labels_batch,
                        'targets': targets_batch
                    }, bfile)
                logger.info("Wrote to %s (%s batch %d of %d) (%.2f sec)",
                            self.out_dir, name, i + 1, len(imfiles),
                            time() - t)

                # get the means and accumulate; calc_batch_mean presumably
                # overwrites batch_mean (passed as tgt) with this batch's
                # mean image -- confirm against imgworker
                imgworker.calc_batch_mean(jpglist=jpeg_strings,
                                          tgt=batch_mean,
                                          orig_size=osz,
                                          rgb=self.rgb,
                                          nthreads=self.num_workers)

                # weight by the fraction of a full batch actually present
                # (exactly 1.0 for full batches)
                accum_buf += batch_mean * (float(len(jpeg_strings)) / psz)
        finally:
            # release the worker processes even if a batch fails
            pool.close()
            pool.join()

        if nimgs == 0:
            return
        mean_buf = self.train_mean if name == 'train' else self.val_mean
        # divide by the number of full-batch equivalents (nimgs / psz); this
        # equals the batch count when every batch is full
        mean_buf[:] = accum_buf * (float(psz) / nimgs)
コード例 #3
0
ファイル: fit_predict_err.py プロジェクト: Eynaliyev/neon
 def save_results(self, dataset, setname, data, dataname):
     """
     Serialize the transpose of ``data`` to ``<out_dir>/<setname>-<dataname>.pkl``.

     ``out_dir`` is ``dataset.save_dir`` when the dataset defines it and
     ``<dataset.repo_path>/<dataset class name>`` otherwise.  ``~`` and
     environment variables in it are expanded, and the directory is created
     if it does not exist yet.

     Arguments:
         dataset: dataset object providing ``save_dir`` or ``repo_path``.
         setname (str): partition name used in the file name.
         data: object exposing ``asnumpyarray()``; its transpose is saved.
         dataname (str): name of the saved quantity, used in the file name.
     """
     if hasattr(dataset, 'save_dir'):
         out_dir = dataset.save_dir
     else:
         # repo_path is only consulted when no explicit save_dir is set
         out_dir = os.path.join(dataset.repo_path,
                                dataset.__class__.__name__)
     out_dir = os.path.expandvars(os.path.expanduser(out_dir))
     try:
         os.makedirs(out_dir)
     except OSError:
         # tolerate a directory that already exists (or was created
         # concurrently); re-raise any other failure
         if not os.path.isdir(out_dir):
             raise
     filename = os.path.join(out_dir, '{}-{}.pkl'.format(setname, dataname))
     serialize(data.asnumpyarray().T, filename)
コード例 #4
0
ファイル: fit_predict_err.py プロジェクト: zz119/neon
 def save_results(self, dataset, setname, data, dataname):
     """
     Serialize the transpose of ``data`` to ``<out_dir>/<setname>-<dataname>.pkl``.

     ``out_dir`` is ``dataset.save_dir`` when the dataset defines it and
     ``<dataset.repo_path>/<dataset class name>`` otherwise.  ``~`` and
     environment variables in it are expanded, and the directory is created
     if it does not exist yet.

     Arguments:
         dataset: dataset object providing ``save_dir`` or ``repo_path``.
         setname (str): partition name used in the file name.
         data: object exposing ``asnumpyarray()``; its transpose is saved.
         dataname (str): name of the saved quantity, used in the file name.
     """
     if hasattr(dataset, 'save_dir'):
         out_dir = dataset.save_dir
     else:
         # repo_path is only consulted when no explicit save_dir is set
         out_dir = os.path.join(dataset.repo_path,
                                dataset.__class__.__name__)
     out_dir = os.path.expandvars(os.path.expanduser(out_dir))
     try:
         os.makedirs(out_dir)
     except OSError:
         # tolerate a directory that already exists (or was created
         # concurrently); re-raise any other failure
         if not os.path.isdir(out_dir):
             raise
     filename = os.path.join(out_dir, '{}-{}.pkl'.format(setname, dataname))
     serialize(data.asnumpyarray().T, filename)
コード例 #5
0
ファイル: batch_writer.py プロジェクト: nkhuyu/neon
 def save_meta(self):
     """
     Write the batch writer's metadata to ``self.stats`` via ``serialize``.

     The saved dict holds the train/validation counts and start offsets,
     the batch size (under the key ``'macro_size'``), the train and
     validation mean buffers, ``labels_dict``, the per-set ``*_nrec``
     values and ``nclass``.
     """
     meta = dict(ntrain=self.ntrain,
                 nval=self.nval,
                 train_start=self.train_start,
                 val_start=self.val_start,
                 macro_size=self.batch_size,
                 train_mean=self.train_mean,
                 val_mean=self.val_mean,
                 labels_dict=self.labels_dict,
                 val_nrec=self.val_nrec,
                 train_nrec=self.train_nrec,
                 nclass=self.nclass)
     serialize(meta, self.stats)
コード例 #6
0
 def save_meta(self):
     """
     Write the batch writer's metadata to ``self.stats`` via ``serialize``.

     The saved dict holds the train/validation counts and start offsets,
     the batch size (under the key ``'macro_size'``), the train and
     validation mean buffers, ``labels_dict``, the per-set ``*_nrec``
     values and ``nclass``.
     """
     meta = dict(ntrain=self.ntrain,
                 nval=self.nval,
                 train_start=self.train_start,
                 val_start=self.val_start,
                 macro_size=self.batch_size,
                 train_mean=self.train_mean,
                 val_mean=self.val_mean,
                 labels_dict=self.labels_dict,
                 val_nrec=self.val_nrec,
                 train_nrec=self.train_nrec,
                 nclass=self.nclass)
     serialize(meta, self.stats)
コード例 #7
0
    def run(self):
        """
        Execute the experiment: load data, restore or train the model, save it.

        In order:

        1. Give ``self.dataset`` the model's batch size and the backend, load
           it, and warn if it still carries a (deprecated) ``serialized_path``.
        2. Default ``model.backend`` and ``model.epochs_complete`` when unset.
        3. Pick a model file from ``deserialized_path`` (preferred) or
           ``serialized_path``; if it is readable, restore the parameters.
        4. Unless training is already complete or running live, fit the model
           and, on rank 0, serialize its parameters to ``serialized_path``.

        Raises:
            RuntimeError: live mode without a loadable saved model.
            NotImplementedError: restoring a model on a distributed backend.
        """
        # -- dataset ---------------------------------------------------------
        self.dataset.set_batch_size(self.model.batch_size)
        self.dataset.backend = self.backend
        self.dataset.load(backend=self.backend, experiment=self)
        if getattr(self.dataset, 'serialized_path', None) is not None:
            logger.warning('Ability to serialize dataset has been deprecated.')

        # -- model defaults --------------------------------------------------
        model = self.model
        if not hasattr(model, 'backend'):
            model.backend = self.backend
        if not hasattr(model, 'epochs_complete'):
            model.epochs_complete = 0

        # -- locate a saved model (first matching attribute wins) ------------
        mfile = ''
        for path_attr in ('deserialized_path', 'serialized_path'):
            if hasattr(model, path_attr):
                mfile = os.path.expandvars(
                    os.path.expanduser(getattr(model, path_attr)))
                break
        else:
            if self.live:
                raise RuntimeError('Live inference requires a saved model')

        # os.access('') is False, so an empty mfile falls through both arms
        if os.access(mfile, os.R_OK):
            if self.backend.is_distributed():
                raise NotImplementedError('Deserializing models not supported '
                                          'in distributed mode')
            model.set_params(deserialize(mfile))
        elif mfile:
            logger.info('Unable to find saved model %s, starting over', mfile)
            if self.live:
                raise RuntimeError('Live inference requires a saved model')

        # -- train and persist -----------------------------------------------
        if model.epochs_complete >= model.num_epochs or self.live:
            return

        model.fit(self.dataset)

        if hasattr(model, 'serialized_path') and self.backend.rank() == 0:
            serialize(model.get_params(), model.serialized_path)
コード例 #8
0
ファイル: fit.py プロジェクト: JesseLivezey/neon
    def run(self):
        """
        Execute the experiment: load data, restore or train the model, save it.

        In order:

        1. Give ``self.dataset`` the model's batch size and the backend, load
           it, and warn if it still carries a (deprecated) ``serialized_path``.
        2. Default ``model.backend`` and ``model.epochs_complete`` when unset.
        3. Pick a model file from ``deserialized_path`` (preferred) or
           ``serialized_path``; if it is readable, restore the parameters.
        4. Unless training is already complete or running live, fit the model
           and, on rank 0, serialize its parameters to ``serialized_path``.

        Raises:
            RuntimeError: live mode without a loadable saved model.
            NotImplementedError: restoring a model on a distributed backend.
        """
        # -- dataset ---------------------------------------------------------
        self.dataset.set_batch_size(self.model.batch_size)
        self.dataset.backend = self.backend
        self.dataset.load(backend=self.backend, experiment=self)
        if getattr(self.dataset, 'serialized_path', None) is not None:
            logger.warning('Ability to serialize dataset has been deprecated.')

        # -- model defaults --------------------------------------------------
        model = self.model
        if not hasattr(model, 'backend'):
            model.backend = self.backend
        if not hasattr(model, 'epochs_complete'):
            model.epochs_complete = 0

        # -- locate a saved model (first matching attribute wins) ------------
        mfile = ''
        for path_attr in ('deserialized_path', 'serialized_path'):
            if hasattr(model, path_attr):
                mfile = os.path.expandvars(
                    os.path.expanduser(getattr(model, path_attr)))
                break
        else:
            if self.live:
                raise RuntimeError('Live inference requires a saved model')

        # os.access('') is False, so an empty mfile falls through both arms
        if os.access(mfile, os.R_OK):
            if self.backend.is_distributed():
                raise NotImplementedError('Deserializing models not supported '
                                          'in distributed mode')
            model.set_params(deserialize(mfile))
        elif mfile:
            logger.info('Unable to find saved model %s, starting over', mfile)
            if self.live:
                raise RuntimeError('Live inference requires a saved model')

        # -- train and persist -----------------------------------------------
        if model.epochs_complete >= model.num_epochs or self.live:
            return

        model.fit(self.dataset)

        if hasattr(model, 'serialized_path') and self.backend.rank() == 0:
            serialize(model.get_params(), model.serialized_path)
コード例 #9
0
ファイル: batch_writer.py プロジェクト: nkhuyu/neon
    def write_batches(self, name, start, labels, imfiles, targets=None,
                      is_tar=False):
        """
        Serialize ``imfiles`` as macro batches and store their mean image.

        The inputs are cut into consecutive chunks of ``self.batch_size``
        items.  Chunk ``i`` is mapped through ``proc_img`` in a process pool
        and written, together with its labels and targets, to
        ``<self.out_dir>/data_batch_<start + i>``.  Per-batch mean images are
        accumulated, weighted by how full each batch is, and the overall mean
        is written in place into ``self.train_mean`` when ``name == 'train'``
        and into ``self.val_mean`` otherwise.

        Arguments:
            name (str): set name; used in log messages and to pick the mean
                buffer.
            start (int): number of the first batch file to write.
            labels (dict): maps label names to sliceable sequences aligned
                with ``imfiles``.
            imfiles (list): image paths, or file-like objects whose
                ``read()`` yields the image bytes when ``is_tar`` is True.
            targets (optional): sliceable array aligned with ``imfiles``
                along axis 0; each batch slice is transposed and copied.
            is_tar (bool): read raw bytes from the ``imfiles`` entries and
                pass those to ``proc_img`` instead of the entries themselves.

        If ``imfiles`` is empty no batches are written and the mean buffer is
        left unchanged.
        """
        psz = self.batch_size
        osz = self.output_image_size
        nimgs = len(imfiles)
        npts = (nimgs + psz - 1) // psz

        imfiles = [imfiles[i*psz: (i+1)*psz] for i in range(npts)]

        if targets is not None:
            targets = [targets[i*psz: (i+1)*psz].T.copy() for i in range(npts)]

        # items() (unlike the Python 2-only iteritems()) works on 2 and 3
        labels = [{k: v[i*psz: (i+1)*psz] for k, v in labels.items()}
                  for i in range(npts)]

        # Accumulate in floating point: scaling the uint8 batch mean in place
        # truncated (and, with Python 2 integer division, zeroed) the
        # contribution of an undersized final batch.
        accum_buf = np.zeros(self.train_mean.shape, dtype=np.float64)
        batch_mean = np.zeros(accum_buf.shape, dtype=np.uint8)
        logger.info("Writing %s batches...", name)
        pool = Pool(processes=self.num_workers)
        try:
            for i, jpeg_file_batch in enumerate(imfiles):
                t = time()
                if is_tar:
                    jpeg_file_batch = [j.read() for j in jpeg_file_batch]
                jpeg_strings = pool.map(
                    functools.partial(proc_img, is_string=is_tar),
                    jpeg_file_batch)
                targets_batch = None if targets is None else targets[i]
                labels_batch = labels[i]
                bfile = os.path.join(self.out_dir,
                                     'data_batch_%d' % (start + i))
                serialize({'data': jpeg_strings,
                           'labels': labels_batch,
                           'targets': targets_batch},
                          bfile)
                logger.info("Wrote to %s (%s batch %d of %d) (%.2f sec)",
                            self.out_dir, name, i + 1, len(imfiles),
                            time() - t)

                # get the means and accumulate; calc_batch_mean presumably
                # overwrites batch_mean (passed as tgt) with this batch's
                # mean image -- confirm against imgworker
                imgworker.calc_batch_mean(jpglist=jpeg_strings, tgt=batch_mean,
                                          orig_size=osz, rgb=self.rgb,
                                          nthreads=self.num_workers)

                # weight by the fraction of a full batch actually present
                # (exactly 1.0 for full batches)
                accum_buf += batch_mean * (float(len(jpeg_strings)) / psz)
        finally:
            # release the worker processes even if a batch fails
            pool.close()
            pool.join()

        if nimgs == 0:
            return
        mean_buf = self.train_mean if name == 'train' else self.val_mean
        # divide by the number of full-batch equivalents (nimgs / psz); this
        # equals the batch count when every batch is full
        mean_buf[:] = accum_buf * (float(psz) / nimgs)
コード例 #10
0
ファイル: train.py プロジェクト: eyrun/deepfly
def train(max_macro_epochs=1000):
    """
    Train an MLP on the Fly dataset, keeping the best model on disk.

    Command line: ``SAVE_FILE [MODEL_FILE]``.  With MODEL_FILE the model is
    deserialized from it.  Otherwise a new MLP (hidden units
    100, 50, NUM_CLASSES) is built and initialized on the module-level
    backend ``be``.  Each of ``max_macro_epochs`` rounds resets
    ``epochs_complete``, fits the model on the training set, and computes
    the validation error.  The model is serialized to SAVE_FILE whenever
    that error reaches a new minimum.

    Arguments:
        max_macro_epochs (int, optional): number of fit/validate rounds.
            Defaults to 1000.

    Raises:
        SystemExit: when SAVE_FILE is missing from the command line.
    """
    if len(sys.argv) < 2:
        raise SystemExit('usage: train.py SAVE_FILE [MODEL_FILE]')
    save_file = sys.argv[1]
    if len(sys.argv) > 2:
        model = deserialize(sys.argv[2])
    else:
        layers = get_parameters(n_in=FEATURE_LENGTH,
                                n_hidden_units=[100, 50, NUM_CLASSES])
        # define model
        model = MLP(num_epochs=1, batch_size=MINIBATCH_SIZE,
                    layers=layers, epochs_complete=0)
        model.link()
        model.initialize(be)
        model.data_layer = model.layers[0]
        model.cost_layer = model.layers[-1]

    dataset = Fly(backend=be, repo_path=os.path.expanduser('~/flyvfly/'))

    # par related init: configure the backend by hand for a single process
    # (mpi_size 1, rank 0, no parallelism)
    be.actual_batch_size = model.batch_size
    be.mpi_size = 1
    be.mpi_rank = 0
    be.par = NoPar()
    be.par.backend = be

    # float('inf') works on Python 2 and 3; sys.maxint is Python 2 only
    min_err = float('inf')
    for i in range(max_macro_epochs):
        # reset the epoch counter so each fit() call trains again
        # (presumably fit runs until epochs_complete reaches num_epochs)
        model.epochs_complete = 0
        dataset.use_set = "train"
        model.fit(dataset)
        val_err = get_validation(model, dataset)
        logger.info('epoch: %d,  valid error: %0.6f', i, val_err)
        if val_err < min_err:
            serialize(model, save_file)
            min_err = val_err
コード例 #11
0
ファイル: model.py プロジェクト: xiaoyunwu/neon
    def save_snapshot(self):
        """
        Save a snapshot of the model parameters, optionally rotating copies.

        ``self.serialize_schedule`` decides whether ``self.epochs_complete``
        is due for a snapshot:

        * ``None``: never save.
        * a list or tuple of ints: save when ``epochs_complete`` is listed;
          the checkpoint number is its 1-based position in the schedule.
        * a positive int: save when ``epochs_complete`` is a multiple of it;
          the checkpoint number is ``epochs_complete // schedule``.

        When due, ``self.get_params()`` is serialized to
        ``self.serialized_path``.  If ``self.save_checkpoints`` is a positive
        int, that file is also copied to ``<root>_cp<N><ext>``.  Checkpoints
        numbered ``N - save_checkpoints`` and below are then deleted, so only
        the most recent ``save_checkpoints`` copies remain.

        Misconfiguration (no serialize path, schedule of the wrong type,
        non-positive interval) is logged and nothing is written.
        """
        schedule = self.serialize_schedule
        if schedule is None:
            return

        if not hasattr(self, 'serialized_path'):
            logger.error('Serialize schedule specified, but no serialize '
                         'path provided, not saving')
            return

        if isinstance(schedule, (list, tuple)):
            dosave = self.epochs_complete in schedule
            # add 1 to match periodic schedule
            check_point = (schedule.index(self.epochs_complete) + 1
                           if dosave else None)
        elif isinstance(schedule, int):
            # an interval of 0 would raise ZeroDivisionError in the modulo
            if schedule <= 0:
                logger.error('Serialize schedule interval must be a positive '
                             'int, got %d, not saving', schedule)
                return
            dosave = self.epochs_complete % schedule == 0
            # floor division keeps check_point an int (range() below needs
            # one on Python 3); dosave means the division is exact anyway
            check_point = self.epochs_complete // schedule
        else:
            logger.error('Serialize schedule must be a list of epochs or a '
                         'single int indicating interval between save epochs')
            return

        if not dosave:
            return

        serialize(self.get_params(), self.serialized_path)

        if not hasattr(self, 'save_checkpoints'):
            return
        # save_checkpoints is the number of previous checkpoints to keep
        keep = self.save_checkpoints
        if not keep > 0:
            return

        root, ext = os.path.splitext(self.serialized_path)

        def cp_fname(index):
            # substitute the path as a value so a '%' inside it is never
            # interpreted as a format directive
            return '%s_cp%d%s' % (root, index, ext)

        current = cp_fname(check_point)
        if os.path.exists(current):
            logger.warning('Checkpoint file exists, overwriting it.  '
                           'Check for stale check point files in '
                           'serialization directory')

        # will have a duplicate copy of the current file
        shutil.copy(self.serialized_path, current)

        # only keep the last `keep` files: walk from the newest stale index
        # downwards and stop at the first gap, since older files should
        # already have been deleted by earlier calls
        for cp_ind in range(check_point - keep, -1, -1):
            stale = cp_fname(cp_ind)
            if not os.path.exists(stale):
                break
            os.remove(stale)