示例#1
0
    def test_trace_grad(self):
        """Run the history check with trace and gradient recording enabled."""
        recording = {
            'trace_record': True,
            'trace_record_grad': True,
        }
        self.history_options = HistoryOptions(**recording)

        self.check_history()
示例#2
0
    def test_trace_schi2(self):
        """Run the history check recording schi2 but explicitly not chi2."""
        recording = {
            'trace_record': True,
            'trace_record_chi2': False,
            'trace_record_schi2': True,
        }
        self.history_options = HistoryOptions(**recording)

        self.check_history()
示例#3
0
    def test_trace_grad_integrated(self):
        """Run the history check on the ``integrated=True`` objective.

        Gradients are recorded in the trace, Hessians are explicitly not.
        """
        sensi_setup = rosen_for_sensi(max_sensi_order=2, integrated=True)
        self.obj = sensi_setup['obj']

        recording = {
            'trace_record': True,
            'trace_record_grad': True,
            'trace_record_hess': False,
        }
        self.history_options = HistoryOptions(**recording)

        self.check_history()
示例#4
0
    def test_trace_all(self):
        """Run the history check with every trace quantity recorded."""
        # Every per-quantity recording flag is switched on.
        quantities = ('grad', 'hess', 'res', 'sres', 'chi2', 'schi2')
        recording = {f'trace_record_{name}': True for name in quantities}
        self.history_options = HistoryOptions(trace_record=True, **recording)

        self.fix_pars = False
        self.check_history()
示例#5
0
    def test_trace_all_aggregated(self):
        """Run the all-quantities history check on an aggregated objective.

        The current objective is wrapped twice into a single
        ``AggregatedObjective`` before the check is run.
        """
        # Every per-quantity recording flag is switched on.
        quantities = ('grad', 'hess', 'res', 'sres', 'chi2', 'schi2')
        recording = {f'trace_record_{name}': True for name in quantities}
        self.history_options = HistoryOptions(trace_record=True, **recording)

        single_obj = self.obj
        self.obj = pypesto.objective.AggregatedObjective(
            [single_obj, single_obj]
        )
        self.fix_pars = False
        self.check_history()
示例#6
0
    def test_trace_all(self):
        """Run the all-quantities history check on the integrated objective."""
        sensi_setup = rosen_for_sensi(max_sensi_order=2, integrated=True)
        self.obj = sensi_setup['obj']

        # Every per-quantity recording flag is switched on.
        quantities = ('grad', 'hess', 'res', 'sres', 'chi2', 'schi2')
        recording = {f'trace_record_{name}': True for name in quantities}
        self.history_options = HistoryOptions(trace_record=True, **recording)

        self.fix_pars = False
        self.check_history()
示例#7
0
    def test_trace_all_aggregated(self):
        """Run the all-quantities history check on an aggregated objective.

        The ``integrated=True`` objective is created first and then wrapped
        twice into a single ``AggregatedObjective``.
        """
        sensi_setup = rosen_for_sensi(max_sensi_order=2, integrated=True)
        self.obj = sensi_setup['obj']

        # Every per-quantity recording flag is switched on.
        quantities = ('grad', 'hess', 'res', 'sres', 'chi2', 'schi2')
        recording = {f'trace_record_{name}': True for name in quantities}
        self.history_options = HistoryOptions(trace_record=True, **recording)

        single_obj = self.obj
        self.obj = pypesto.objective.AggregatedObjective(
            [single_obj, single_obj]
        )
        self.fix_pars = False
        self.check_history()
示例#8
0
def train(ae: MechanisticAutoEncoder,
          optimizer: str = 'fides',
          ftol: float = 1e-3,
          maxiter: int = 10000,
          n_starts: int = 1,
          seed: int = 0) -> Result:
    """
    Trains the provided autoencoder by solving the optimization problem
    generated by :py:func:`create_pypesto_problem`

    :param ae:
        Autoencoder that will be trained
    :param optimizer:
        Optimizer string that specifies the optimizer that will be used.
        Supported values are ``'ipopt'``, ``'fides'`` and any string of the
        form ``'NLOpt_<NAME>'``, where ``<NAME>`` is looked up as an
        attribute of the ``nlopt`` module
    :param ftol:
        function tolerance that is used to assess optimizer convergence
    :param maxiter:
        maximum number of optimization iterations. Coerced to ``int``
        before being handed to the optimizer. Only forwarded to the ipopt
        and fides optimizers; the NLopt branch is bounded by its
        ``maxtime`` option instead
    :param n_starts:
        number of local starts that will be performed
    :param seed:
        random seed that will be used to generate the randomly sampled
        initial startpoints. It is also used as the job index in the trace
        file name and as the row index into the pretrained decoder
        parameters, if those are available

    :returns:
        Pypesto optimization results.

    :raises ValueError:
        if ``optimizer`` is not one of the supported optimizer strings
    """
    # Iteration limits are integer options; a float such as ``1e4`` must
    # not be forwarded as-is.
    maxiter = int(maxiter)

    pypesto_problem = create_pypesto_problem(ae)

    if optimizer == 'ipopt':
        opt = IpoptOptimizer(options={
            'maxiter': maxiter,
            'tol': ftol,
            'disp': 5,
        })
    elif optimizer.startswith('NLOpt_'):
        opt = NLoptOptimizer(method=getattr(nlopt,
                                            optimizer.replace('NLOpt_', '')),
                             options={
                                 'maxtime': 3600,
                                 'ftol_abs': ftol,
                             })
    elif optimizer == 'fides':
        # NOTE(review): 'maxtime' is given both as a plain string key and
        # as fides.Options.MAXTIME - presumably redundant; confirm before
        # removing either entry.
        opt = FidesOptimizer(hessian_update=fides.BFGS(),
                             options={
                                 'maxtime': 3600,
                                 fides.Options.FATOL: ftol,
                                 fides.Options.MAXTIME: 3600,
                                 fides.Options.MAXITER: maxiter,
                                 fides.Options.SUBSPACE_DIM:
                                 fides.SubSpaceDim.FULL
                             },
                             verbose=logging.INFO)
    else:
        # Without this branch ``opt`` would be unbound and the failure
        # would only surface as an UnboundLocalError at the minimize call.
        raise ValueError(
            f'Unsupported optimizer {optimizer!r}, expected "ipopt", '
            f'"fides" or "NLOpt_<NAME>".'
        )

    os.makedirs(trace_path, exist_ok=True)

    # The optimizer trace is written to a file below ``trace_path`` whose
    # name is built from ``TRACE_FILE_TEMPLATE``.
    history_options = HistoryOptions(trace_record=True,
                                     trace_record_hess=False,
                                     trace_record_res=False,
                                     trace_record_sres=False,
                                     trace_record_schi2=False,
                                     storage_file=os.path.join(
                                         trace_path,
                                         TRACE_FILE_TEMPLATE.format(
                                             pathway=ae.pathway_name,
                                             data=ae.data_name,
                                             optimizer=optimizer,
                                             n_hidden=ae.n_hidden,
                                             job=seed)),
                                     trace_save_iter=10)

    np.random.seed(seed)

    optimize_options = OptimizeOptions(
        startpoint_resample=False,
        allow_failed_starts=True,
    )

    # Pretrained parameters are optional: they are only used when the
    # corresponding csv file exists.
    decoder_par_pretraining = os.path.join(
        'pretraining', f'{ae.pathway_name}__{ae.data_name}__{ae.n_hidden}'
        '__decoder_inflate.csv')
    if os.path.exists(decoder_par_pretraining):
        decoder_pars = pd.read_csv(decoder_par_pretraining)[ae.x_names]
    else:
        decoder_pars = None

    # The suffix after the last underscore of each parameter name selects
    # the (lower, upper) bound pair.
    lb = np.asarray([
        parameter_boundaries_scales[name.split('_')[-1]][0]
        for name in pypesto_problem.x_names
    ])
    ub = np.asarray([
        parameter_boundaries_scales[name.split('_')[-1]][1]
        for name in pypesto_problem.x_names
    ])

    def startpoint(**kwargs):
        """
        Startpoint method passed to ``minimize``: returns row ``seed`` of
        the pretrained parameters when available, otherwise
        ``kwargs['n_starts']`` uniformly sampled points within ``[lb, ub)``
        """
        if decoder_pars is not None and seed < len(decoder_pars):
            xs = decoder_pars.iloc[seed, :]
        else:
            xs = np.random.random((kwargs['n_starts'],
                                   ae.n_encoder_pars + ae.n_kin_params)) \
                * (ub - lb) + lb
        return xs

    return minimize(pypesto_problem,
                    opt,
                    n_starts=n_starts,
                    options=optimize_options,
                    history_options=history_options,
                    startpoint_method=startpoint)