Example #1
0
    def save_samples(self,
                     model_output_list,
                     file=None,
                     overwrite=False,
                     verbose=False):
        ''' Save model output to file. Append data if data is already present.

        Parameters:
          model_output_list :
            New model outputs; combined via self.stack_model_output_list.
          file : str or None
            Target file. Defaults to self.optim_data_file.
          overwrite : bool
            If True, previously stored samples are removed via
            self.remove_samples instead of being appended to.
          verbose : bool
            Forwarded to self.load_samples / self.remove_samples.
        '''
        if file is None:
            file = self.optim_data_file

        # Stack the new data before touching anything on disk: if stacking
        # fails, the previously stored samples are still intact.
        new_optim_data = self.stack_model_output_list(model_output_list)

        if overwrite:
            self.remove_samples(file=file, verbose=verbose)
            old_optim_data = None
        else:
            # May be None (handled below), e.g. if nothing is stored yet.
            old_optim_data = self.load_samples(file=file, verbose=verbose)

        if old_optim_data is not None:
            optim_data = self.stack_model_output_list(
                [old_optim_data, new_optim_data])
        else:
            optim_data = new_optim_data

        # Save samples.
        data_utils.save_var(optim_data, file)
    def init_rec_data(self,
                      allow_loading=False,
                      force_loading=False,
                      verbose=True,
                      **kwargs):
        ''' Set up self.rec_data, loading it from disk if requested.

        Parameters:
          allow_loading : bool
            Try to load previously saved init data from self.output_folder.
          force_loading : bool
            Also triggers a load attempt; additionally forwarded to
            self.try_load_init_data.
          verbose : bool
            Forwarded to the loading, running and sizing helpers.
          **kwargs
            Forwarded to self.run_init_rec_data.

        If nothing was loaded (loading not requested, or
        self.try_load_init_data returned a falsy value), self.rec_data is
        recomputed with self.run_init_rec_data and written with
        self.save_init_data. In all cases self.set_rec_ex_size is called and
        self.stim_time is saved to optim_data/<output_folder>/stim_time.pkl.
        '''
        # Every cell must have make_cones disabled.
        for each_cell in self.cells:
            assert not each_cell.make_cones

        loaded_data = False
        if allow_loading or force_loading:
            loaded_data = self.try_load_init_data(
                dirname=self.output_folder,
                force_loading=force_loading,
                verbose=verbose,
            )

        if not loaded_data:
            # Nothing usable was loaded: compute the data and persist it.
            self.rec_data = self.run_init_rec_data(verbose=verbose, **kwargs)
            self.save_init_data(self.output_folder)

        self.set_rec_ex_size(verbose=verbose)

        # Persist the stimulation time next to the init data.
        data_utils.save_var(
            self.stim_time,
            file='optim_data/' + self.output_folder + '/stim_time.pkl',
        )
Example #3
0
    def check_parameter_files(cell, params, folder):
        ''' Test if parameters are the same as they were in a previous run.
        This is crucial if you load the data.

        Compares cell.params_default, cell.params_unit and params.p_range
        with the copies stored in folder (cell_params_default.pkl,
        cell_params_unit.pkl, opt_p_range.pkl). Parameters missing from a
        stored copy and changed values are printed. If any value changed,
        the user is asked once per file to confirm by pressing Enter. After
        the check, all three dicts and params.p_names (opt_p_names.pkl) are
        (re)written to folder.
        '''
        files_vs_dicts = {
            'cell_params_default.pkl': cell.params_default,
            'cell_params_unit.pkl': cell.params_unit,
            'opt_p_range.pkl': params.p_range,
        }

        for file, param_dict in files_vs_dicts.items():

            src_file = os.path.join(folder, file)

            if os.path.isfile(src_file):
                loaded_dict = data_utils.load_var(src_file)
                values_changed = False

                for param_name, param_value in param_dict.items():

                    if param_name not in loaded_dict:
                        print(param_name, 'not in loaded_dict params')

                    elif param_value != loaded_dict[param_name]:
                        print(param_name, ':', param_value, '!= ',
                              loaded_dict[param_name])
                        values_changed = True

                # List all differences first, then ask once per file rather
                # than once per differing parameter.
                if values_changed:
                    input("Params in " + file +
                          " are different. Press Enter to overwrite ... ")

            data_utils.save_var(param_dict, src_file)

        # If p_range was fine, p_names is fine too.
        data_utils.save_var(params.p_names,
                            os.path.join(folder, 'opt_p_names.pkl'))
 def save_init_data(self, dirname):
     ''' Write the initial data below optim_data/<dirname>.

     Every cell in self.cells saves its own init data into a sub-folder
     named after its bp_type; self.rec_data is written to
     init_rec_data.pkl. The directories are created first.
     '''
     base_dir = 'optim_data/' + dirname
     data_utils.make_dir(base_dir)

     for each_cell in self.cells:
         cell_dir = base_dir + '/' + each_cell.bp_type
         data_utils.make_dir(cell_dir)
         each_cell.save_init_data(cell_dir)

     rec_data_file = base_dir + '/init_rec_data.pkl'
     data_utils.save_var(self.rec_data, file=rec_data_file)
Example #5
0
 def save_init_data(self, dirname):
     ''' Save pilot run, i.e. initial data.

     Creates optim_data/<dirname>, lets self.cell write its init data there
     and stores self.rec_data as init_rec_data.pkl in the same directory.
     '''
     print('# Save init data to ' + dirname)
     target_dir = os.path.join('optim_data', dirname)
     data_utils.make_dir(target_dir)
     self.cell.save_init_data(target_dir)

     rec_data_file = os.path.join(target_dir, 'init_rec_data.pkl')
     data_utils.save_var(self.rec_data, rec_data_file)
Example #6
0
    def gen_and_save_data(self,
                          method,
                          adaptive,
                          step_param,
                          pert_method,
                          pert_param='auto',
                          stim=None,
                          overwrite=False,
                          allowgenerror=False,
                          plot=False):
        """Generate data and save to file.

        If the data file for this configuration exists and overwrite is
        False, it is loaded with self.load_data_and_check and returned if
        that succeeds. Otherwise the data is generated with self.gen_data,
        saved with data_utils.save_var and returned.

        Parameters
        ----------
        method, adaptive, step_param, pert_method, pert_param :
            Configuration forwarded to self.get_data_folder_and_filename,
            self.load_data_and_check and self.gen_data.
        stim :
            Only forwarded to self.load_data_and_check.
        overwrite : bool
            Regenerate even if a data file exists.
        allowgenerror : bool
            If True, an exception from self.gen_data is reported and None is
            saved and returned instead of re-raising.
        plot : bool
            Forwarded to self.gen_data.

        Returns
        -------
        The loaded or generated data (None if generation failed and
        allowgenerror is True).
        """
        folder, filename = self.get_data_folder_and_filename(
            method, adaptive, step_param, pert_method, pert_param)
        print('/'.join(filename.split('/')[2:]).rjust(60), end=' --> ')
        data_utils.make_dir(folder)

        # data_loaded is either the loaded data or a str giving the reason
        # why the data must be (re)generated.
        if not os.path.isfile(filename):
            data_loaded = 'file was not found.'
        elif overwrite:
            data_loaded = 'overwrite==True.'
        else:
            data_loaded = self.load_data_and_check(method=method,
                                                   adaptive=adaptive,
                                                   step_param=step_param,
                                                   pert_method=pert_method,
                                                   pert_param=pert_param,
                                                   stim=stim,
                                                   filename=filename)

        if not isinstance(data_loaded, str):
            print('Loaded data.')
            return data_loaded

        print('Generate data because', data_loaded)
        # KeyboardInterrupt derives from BaseException, not Exception, so it
        # is not caught here and propagates with its original traceback.
        try:
            data = self.gen_data(method=method,
                                 adaptive=adaptive,
                                 step_param=step_param,
                                 pert_method=pert_method,
                                 pert_param=pert_param,
                                 plot=plot)
        except Exception:
            if allowgenerror:
                print('Data generation failed')
                data = None
            else:
                traceback.print_exc()
                raise

        # NOTE(review): with allowgenerror=True a failed run is saved as None,
        # which a later call may load and return as data. Confirm this
        # caching of failures is intended.
        data_utils.save_var(data, filename)
        return data
Example #7
0
def gen_or_load_samples(optim, opt_params, filename, load):
    """Return model outputs for opt_params, loaded from or saved to filename.

    Parameters
    ----------
    optim :
        Optimizer object; only used when load is False (its init_rec_data
        and run_parallel methods are called).
    opt_params :
        Parameter sets to simulate; opt_params.shape[0] is the expected
        number of samples when loading.
    filename : str
        File to load from (load=True) or to save to (load=False).
    load : bool
        Load existing samples instead of simulating them.

    Returns
    -------
    The list of model outputs.

    Raises
    ------
    AssertionError
        When loading and the file is missing or its sample count differs
        from opt_params.shape[0].
    """
    if not load:
        optim.init_rec_data(allow_loading=False,
                            force_loading=True,
                            verbose=True)
        outputs = optim.run_parallel(opt_params_list=opt_params,
                                     verbose=True)
        data_utils.save_var(outputs, filename)
        return outputs

    assert os.path.isfile(filename), 'File does not exist'
    outputs = data_utils.load_var(filename)
    assert len(outputs) == opt_params.shape[0], \
        'Loaded sample size differs from requested'
    return outputs
Example #8
0
 def save_acc_sols_to_file(self):
     """Save self.acc_sols to <base_folder>[/<subfoldername>]/acc_sols.pkl.

     The sub-folder is only appended when self.subfoldername is not None.
     The target directory is created first.
     """
     path_parts = [f'{self.base_folder}']
     if self.subfoldername is not None:
         path_parts.append(f'{self.subfoldername}')
     target_dir = '/'.join(path_parts)

     data_utils.make_dir(target_dir)
     data_utils.save_var(self.acc_sols, target_dir + '/acc_sols.pkl')
Example #9
0
  def run_SNPE(
      self, n_samples_per_round,
      max_duration_minutes=np.inf, max_rounds=None,
      continue_optimization=False, load_init_tds=False,
    ):
    ''' Run SNPE optimization.
    Add some utility functionality to the vanilla SNPE:
    Store all relevant data to folder.
    Optimization can be continued later, loading those files.
    One can also use only the samples from the prior and rerun the rest.

    Parameters:

    n_samples_per_round : int
      Number of samples generated per NN training round.

    max_duration_minutes : float
      Checked before each round; once exceeded, no new round is started,
      so the last round may finish after the limit.

    max_rounds : int or None
      Maximum number of rounds (compared against self.inf_snpe.round).
      None means no limit.

    continue_optimization : bool
      Load previously produced data and continue. If there are more sample
      files than completed rounds, the newest sample file is reused.

    load_init_tds : bool
      Reuse the round-0 samples (restored from self.backups_folder) as
      training data for the first round instead of simulating them.
      Otherwise algorithm will start from scratch.
      Can be used to play around with NN parameters.
      Cannot be combined with continue_optimization.

    Each round's results are written to self.snpe_folder after the round.
    '''
    assert not(load_init_tds and continue_optimization)

    self.load_random_state()

    #### Initialize.
    if continue_optimization:
      inf_snpes, logs, tds, sample_distributions, n_samples, \
        kernel_bandwidths, pseudo_obs = self.load_SNPE_rounds()
      self.inf_snpe = inf_snpes[-1]
    else:
      self.create_SNPE_rounds_backup()

      assert self.inf_snpe is not None, 'Initialize network first'
      inf_snpes            = [deepcopy(self.inf_snpe)]
      logs                 = []
      tds                  = []
      sample_distributions = [self.prior]
      n_samples            = []
      if self.snpe_type in ['b', 'B']:
        kernel_bandwidths  = []
        pseudo_obs         = []

    ### Load data?
    load_tds = False
    if load_init_tds:
      load_tds = True
      load_file = os.path.join(self.samples_folder, 'delfi_samples_r0.pkl')
      # Restore file from backups, to exactly the path that is loaded below.
      copyfile(os.path.join(self.backups_folder, 'delfi_samples_r0.pkl'),
               load_file)

    elif continue_optimization:
      import re

      def round_sort_key(fname):
        # Order sample files by their round number; a plain string sort
        # would put 'delfi_samples_r10.pkl' before 'delfi_samples_r2.pkl'.
        # Names without a round number sort first, in string order.
        match = re.search(r'_r(\d+)\.pkl$', fname)
        if match is None:
          return (0, fname)
        return (1, int(match.group(1)), fname)

      files = sorted(os.listdir(self.samples_folder), key=round_sort_key)
      # More sample files than completed rounds: reuse the newest one.
      if len(n_samples) < len(files):
        print('Found incomplete round. Will load samples.')
        load_tds = True
        load_file = self.samples_folder + files[-1]

    if load_tds:
      loaded_tds, n_loaded_samples = self.load_tds_from_file(
        file=load_file, params=self.optim.params
      )
    else:
      loaded_tds = None
      n_loaded_samples = 0

    ### Optimize.
    if max_rounds is None: max_rounds = np.inf
    t0 = tm.time()
    i_round = 0

    print()

    while (((tm.time() - t0)/60) < max_duration_minutes) and (self.inf_snpe.round < max_rounds):
      # Loaded training data is only used for the first round of this call.
      if i_round > 0:
        loaded_tds = None
        n_loaded_samples = 0

      # Update output file.
      samples_file = self.samples_folder + 'delfi_samples_r{:d}.pkl'.format(self.inf_snpe.round)
      self.update_samples_file(samples_file)

      print('# Sample output file:', self.samples_file)
      print('# Round', self.inf_snpe.round+1, '- Running for {:.2f} [min]'.format(((tm.time()-t0)/60)))

      ### Run inference method.
      new_log, new_td, new_posterior = self.run_SNPE_round(
        n_train=[n_samples_per_round], proposal=sample_distributions[-1],
        initial_tds=loaded_tds, verbose=True
      )

      if self.post_as_truncated_normal:
        new_posterior = my_delfi_funcs.normal2truncated_normal(
          new_posterior, lower=self.trunc_lower, upper=self.trunc_upper
        )

      ### Append data.
      logs.append(new_log)
      sample_distributions.append(new_posterior)
      tds.append(new_td)
      inf_snpes.append(deepcopy(self.inf_snpe))
      n_samples.append(new_td[1].shape[0])
      if self.snpe_type in ['b', 'B']:
        pseudo_obs.append(self.inf_snpe.pseudo_obs[-1].copy())
        kernel_bandwidths.append(self.inf_snpe.kernel_bandwidth[-1].copy())

      ### Save data to files.
      self.save_random_state()
      data_utils.save_var(inf_snpes,            file=os.path.join(self.snpe_folder, 'inf_snpes.pkl'))
      data_utils.save_var(sample_distributions, file=os.path.join(self.snpe_folder, 'sample_distributions.pkl'))
      data_utils.save_var(logs,                 file=os.path.join(self.snpe_folder, 'logs.pkl'))
      data_utils.save_var(tds,                  file=os.path.join(self.snpe_folder, 'tds.pkl'))
      data_utils.save_var(n_samples,            file=os.path.join(self.snpe_folder, 'n_samples.pkl'))
      if self.snpe_type in ['b', 'B']:
        data_utils.save_var(pseudo_obs,         file=os.path.join(self.snpe_folder, 'pseudo_obs.pkl'))
        data_utils.save_var(kernel_bandwidths,  file=os.path.join(self.snpe_folder, 'kernel_bandwidths.pkl'))

      # Make sure last posterior does exist.
      assert sample_distributions[-1] is not None

      print('\n----------------------------------------------------\n')

      i_round += 1

    print('---> Done!')
Example #10
0
 def save_random_state(self):
   ''' Save numpy random state to folder.

   Writes np.random.get_state() to random_state.pkl in self.general_folder.
   '''
   state_file = os.path.join(self.general_folder, 'random_state.pkl')
   data_utils.save_var(np.random.get_state(), state_file)