def save_snapshot(self): """ Save snapshots of the model at specified epochs. If serialize_schedule is a list of ints, it will serialize at those epochs. If serialize_schedule is a single int, it will serialize when epochs_complete is a multiple of serialize_schedule """ if self.serialize_schedule is None: return if not hasattr(self, 'serialized_path'): logger.error('Serialize schedule specified, but no serialize ' 'path provided, not saving') return if not isinstance(self.serialize_schedule, (list, int)): logger.error('Serialize schedule must be a list of epochs or a ' 'single int indicating interval between save epochs') return if isinstance(self.serialize_schedule, list): dosave = self.epochs_complete in self.serialize_schedule else: dosave = self.epochs_complete % self.serialize_schedule == 0 if dosave: serialize(self.get_params(), self.serialized_path)
def write_batches(self, name, start, labels, imfiles, targets=None,
                  is_tar=False):
    pool = Pool(processes=self.num_workers)
    psz = self.batch_size
    osz = self.output_image_size
    npts = (len(imfiles) + psz - 1) / psz
    imfiles = [imfiles[i * psz:(i + 1) * psz] for i in range(npts)]
    if targets is not None:
        targets = [targets[i * psz:(i + 1) * psz].T.copy()
                   for i in range(npts)]
    labels = [{k: v[i * psz:(i + 1) * psz] for k, v in labels.iteritems()}
              for i in range(npts)]
    accum_buf = np.zeros((osz, osz, self.num_channels), dtype=np.int32)
    batch_mean = np.zeros(accum_buf.shape, dtype=np.uint8)
    logger.info("Writing %s batches...", name)
    for i, jpeg_file_batch in enumerate(imfiles):
        t = time()
        if is_tar:
            jpeg_file_batch = [j.read() for j in jpeg_file_batch]
        jpeg_strings = pool.map(
            functools.partial(proc_img, is_string=is_tar), jpeg_file_batch)
        targets_batch = None if targets is None else targets[i]
        labels_batch = labels[i]
        bfile = os.path.join(self.out_dir, 'data_batch_%d' % (start + i))
        serialize({'data': jpeg_strings,
                   'labels': labels_batch,
                   'targets': targets_batch},
                  bfile)
        logger.info("Wrote to %s (%s batch %d of %d) (%.2f sec)",
                    self.out_dir, name, i + 1, len(imfiles), time() - t)
        # get the means and accumulate
        imgworker.calc_batch_mean(jpglist=jpeg_strings, tgt=batch_mean,
                                  orig_size=osz, rgb=self.rgb,
                                  nthreads=self.num_workers)
        # scale for the case where we have an undersized batch
        if len(jpeg_strings) < self.batch_size:
            batch_mean *= len(jpeg_strings) / self.batch_size
        accum_buf += batch_mean
    pool.close()
    mean_buf = self.train_mean if name == 'train' else self.val_mean
    mean_buf[:] = accum_buf / len(imfiles)
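
# Illustrative sketch (assumption, not library code): the slicing above splits
# the file list into ceil(len(imfiles) / batch_size) chunks; the last chunk may
# be smaller than batch_size. `_chunk` is a hypothetical helper.


def _chunk(items, size):
    nchunks = (len(items) + size - 1) // size  # ceiling division
    return [items[i * size:(i + 1) * size] for i in range(nchunks)]


assert _chunk(list(range(10)), 4) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
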
def save_results(self, dataset, setname, data, dataname):
    out_dir = os.path.join(dataset.repo_path, dataset.__class__.__name__)
    if hasattr(dataset, 'save_dir'):
        out_dir = dataset.save_dir
    out_dir = os.path.expandvars(os.path.expanduser(out_dir))
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    filename = os.path.join(out_dir, '{}-{}.pkl'.format(setname, dataname))
    serialize(data.asnumpyarray().T, filename)
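
# Illustrative sketch (hypothetical values): the naming scheme above writes one
# pickle per (setname, dataname) pair, under either dataset.save_dir or
# <repo_path>/<DatasetClassName>; e.g. validation inference results would land
# in <out_dir>/validation-inference.pkl.
import os

out_dir = os.path.expandvars(os.path.expanduser('~/results'))  # assumed dir
filename = os.path.join(out_dir, '{}-{}.pkl'.format('validation', 'inference'))
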
def save_meta(self):
    serialize({'ntrain': self.ntrain,
               'nval': self.nval,
               'train_start': self.train_start,
               'val_start': self.val_start,
               'macro_size': self.batch_size,
               'train_mean': self.train_mean,
               'val_mean': self.val_mean,
               'labels_dict': self.labels_dict,
               'val_nrec': self.val_nrec,
               'train_nrec': self.train_nrec,
               'nclass': self.nclass}, self.stats)
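
# Illustrative sketch (assumption: serialize/deserialize here are thin pickle
# wrappers, so the metadata round-trips as a plain dict). Values and the temp
# file are hypothetical, used only to show the round trip.
import pickle
import tempfile

meta = {'ntrain': 100, 'nval': 10, 'macro_size': 3072, 'nclass': 1000}
with tempfile.NamedTemporaryFile(suffix='.pkl', delete=False) as f:
    pickle.dump(meta, f)
    stats_path = f.name
with open(stats_path, 'rb') as f:
    assert pickle.load(f)['nclass'] == 1000
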
def run(self): """ Actually carry out each of the experiment steps. """ # load the dataset, save it to disk if specified self.dataset.set_batch_size(self.model.batch_size) self.dataset.backend = self.backend self.dataset.load(backend=self.backend, experiment=self) if hasattr(self.dataset, 'serialized_path') and (self.dataset.serialized_path is not None): logger.warning('Ability to serialize dataset has been deprecated.') # fit the model to the data, save it if specified if not hasattr(self.model, 'backend'): self.model.backend = self.backend if not hasattr(self.model, 'epochs_complete'): self.model.epochs_complete = 0 mfile = '' if hasattr(self.model, 'deserialized_path'): mfile = os.path.expandvars( os.path.expanduser(self.model.deserialized_path)) elif hasattr(self.model, 'serialized_path'): mfile = os.path.expandvars( os.path.expanduser(self.model.serialized_path)) elif self.live: raise RuntimeError('Live inference requires a saved model') if os.access(mfile, os.R_OK): if self.backend.is_distributed(): raise NotImplementedError('Deserializing models not supported ' 'in distributed mode') self.model.set_params(deserialize(mfile)) elif mfile != '': logger.info('Unable to find saved model %s, starting over', mfile) if self.live: raise RuntimeError('Live inference requires a saved model') if self.model.epochs_complete >= self.model.num_epochs: return if self.live: return self.model.fit(self.dataset) if hasattr(self.model, 'serialized_path'): if self.backend.rank() == 0: serialize(self.model.get_params(), self.model.serialized_path)
def run(self): """ Actually carry out each of the experiment steps. """ # load the dataset, save it to disk if specified self.dataset.set_batch_size(self.model.batch_size) self.dataset.backend = self.backend self.dataset.load(backend=self.backend, experiment=self) if hasattr(self.dataset, 'serialized_path') and ( self.dataset.serialized_path is not None): logger.warning('Ability to serialize dataset has been deprecated.') # fit the model to the data, save it if specified if not hasattr(self.model, 'backend'): self.model.backend = self.backend if not hasattr(self.model, 'epochs_complete'): self.model.epochs_complete = 0 mfile = '' if hasattr(self.model, 'deserialized_path'): mfile = os.path.expandvars(os.path.expanduser( self.model.deserialized_path)) elif hasattr(self.model, 'serialized_path'): mfile = os.path.expandvars(os.path.expanduser( self.model.serialized_path)) elif self.live: raise RuntimeError('Live inference requires a saved model') if os.access(mfile, os.R_OK): if self.backend.is_distributed(): raise NotImplementedError('Deserializing models not supported ' 'in distributed mode') self.model.set_params(deserialize(mfile)) elif mfile != '': logger.info('Unable to find saved model %s, starting over', mfile) if self.live: raise RuntimeError('Live inference requires a saved model') if self.model.epochs_complete >= self.model.num_epochs: return if self.live: return self.model.fit(self.dataset) if hasattr(self.model, 'serialized_path'): if self.backend.rank() == 0: serialize(self.model.get_params(), self.model.serialized_path)
def write_batches(self, name, start, labels, imfiles, targets=None,
                  is_tar=False):
    pool = Pool(processes=self.num_workers)
    psz = self.batch_size
    osz = self.output_image_size
    npts = (len(imfiles) + psz - 1) // psz
    imfiles = [imfiles[i * psz:(i + 1) * psz] for i in range(npts)]
    if targets is not None:
        targets = [targets[i * psz:(i + 1) * psz].T.copy()
                   for i in range(npts)]
    labels = [{k: v[i * psz:(i + 1) * psz] for k, v in labels.iteritems()}
              for i in range(npts)]
    accum_buf = np.zeros(self.train_mean.shape, dtype=np.int32)
    batch_mean = np.zeros(accum_buf.shape, dtype=np.uint8)
    logger.info("Writing %s batches...", name)
    for i, jpeg_file_batch in enumerate(imfiles):
        t = time()
        if is_tar:
            jpeg_file_batch = [j.read() for j in jpeg_file_batch]
        jpeg_strings = pool.map(
            functools.partial(proc_img, is_string=is_tar), jpeg_file_batch)
        targets_batch = None if targets is None else targets[i]
        labels_batch = labels[i]
        bfile = os.path.join(self.out_dir, 'data_batch_%d' % (start + i))
        serialize({'data': jpeg_strings,
                   'labels': labels_batch,
                   'targets': targets_batch},
                  bfile)
        logger.info("Wrote to %s (%s batch %d of %d) (%.2f sec)",
                    self.out_dir, name, i + 1, len(imfiles), time() - t)
        # get the means and accumulate
        imgworker.calc_batch_mean(jpglist=jpeg_strings, tgt=batch_mean,
                                  orig_size=osz, rgb=self.rgb,
                                  nthreads=self.num_workers)
        # scale for the case where we have an undersized batch
        if len(jpeg_strings) < self.batch_size:
            batch_mean *= len(jpeg_strings) / self.batch_size
        accum_buf += batch_mean
    pool.close()
    mean_buf = self.train_mean if name == 'train' else self.val_mean
    mean_buf[:] = accum_buf / len(imfiles)
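
# Illustrative numpy sketch (assumption, hypothetical values): the accumulation
# above approximates the global pixel mean by averaging per-batch means,
# weighting an undersized final batch by its fraction of a full batch.
import numpy as np

batch_size = 4
per_batch_means = [np.full((2, 2, 3), 10.0), np.full((2, 2, 3), 20.0)]
batch_counts = [4, 2]  # final batch is undersized
accum = np.zeros((2, 2, 3))
for mean, n in zip(per_batch_means, batch_counts):
    accum += mean * (float(n) / batch_size)
global_mean = accum / len(per_batch_means)
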
def train():
    save_file = sys.argv[1]
    if len(sys.argv) > 2:
        model = deserialize(sys.argv[2])
    else:
        layers = get_parameters(n_in=FEATURE_LENGTH,
                                n_hidden_units=[100, 50, NUM_CLASSES])
        # define model
        model = MLP(num_epochs=1, batch_size=MINIBATCH_SIZE,
                    layers=layers, epochs_complete=0)
        model.link()
        # be.configure(model, datapar=False, modelpar=False)
        model.initialize(be)
        model.data_layer = model.layers[0]
        model.cost_layer = model.layers[-1]
    dataset = Fly(backend=be,
                  repo_path=os.path.expanduser('~/flyvfly/'))

    # par related init
    be.actual_batch_size = model.batch_size
    be.mpi_size = 1
    be.mpi_rank = 0
    be.par = NoPar()
    be.par.backend = be

    max_macro_epochs = 1000
    min_err = sys.maxint
    for i in range(max_macro_epochs):
        model.epochs_complete = 0
        dataset.use_set = "train"
        model.fit(dataset)
        # scores, targets = model.predict_fullset(dataset, "validation")
        val_err = get_validation(model, dataset)
        logger.info('epoch: %d, valid error: %0.6f', i, val_err)
        if val_err < min_err:
            serialize(model, save_file)
            min_err = val_err
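
# Illustrative sketch (hypothetical names, not part of the script above): the
# "save only on improvement" pattern used by the training loop, factored out
# for clarity.
def keep_best(val_err, min_err, save_fn):
    # save_fn persists the current model; returns the updated best error
    if val_err < min_err:
        save_fn()
        return val_err
    return min_err


# Example: min_err = keep_best(val_err, min_err,
#                              lambda: serialize(model, save_file))
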
def save_snapshot(self): """ Save snapshots of the model at specified epochs. If serialize_schedule is a list of ints, it will serialize at those epochs. If serialize_schedule is a single int, it will serialize when epochs_complete is a multiple of serialize_schedule """ if self.serialize_schedule is None: return if not hasattr(self, 'serialized_path'): logger.error('Serialize schedule specified, but no serialize ' 'path provided, not saving') return if not isinstance(self.serialize_schedule, (list, int)): logger.error('Serialize schedule must be a list of epochs or a ' 'single int indicating interval between save epochs') return if isinstance(self.serialize_schedule, list): dosave = self.epochs_complete in self.serialize_schedule if dosave: # add 1 to match periodic schedule check_point = \ self.serialize_schedule.index(self.epochs_complete) + 1 else: check_point = None else: dosave = self.epochs_complete % self.serialize_schedule == 0 check_point = self.epochs_complete/self.serialize_schedule if dosave: serialize(self.get_params(), self.serialized_path) if hasattr(self, 'save_checkpoints'): if self.save_checkpoints > 0: # save_checkpoints is the number of previous # checkpoints to save file_parts = os.path.splitext(self.serialized_path) cp_fname_str = file_parts[0] + '_cp%d' + file_parts[1] if os.path.exists(cp_fname_str % check_point): logger.warning( 'Checkpoint file exists, overwriting it. ' + 'Check for stale check point files in ' + 'serialization directory') # will have a duplicate copy of the current file shutil.copy( self.serialized_path, cp_fname_str % check_point) # only keep last "save_checkpoints" files cp_rng = range(check_point-self.save_checkpoints, -1, -1) for cp_ind in cp_rng: # will not run here until at least # min checkspints saved cp_fname = cp_fname_str % cp_ind if os.path.exists(cp_fname): os.remove(cp_fname) else: # all older files should already be deleted # don't need to continue break