# Example #1 (예제 #1)
# 0
def setup_shooting_block(integrator=None,
                         in_features=20,
                         shooting_model='updown',
                         parameter_weight=1.0,
                         nr_of_particles=10,
                         inflation_factor=2,
                         nonlinearity='relu',
                         use_particle_rnn_mode=False,
                         use_particle_free_rnn_mode=False,
                         optimize_over_data_initial_conditions=False,
                         optimize_over_data_initial_conditions_type='linear'):
    """Create a ``ShootingBlockBase`` wrapping the selected shooting integrand.

    The integrand always uses ``particle_dimension=1``,
    ``particle_size=in_features`` and ``use_analytic_solution=True``. Its state
    initializer is a ``VectorEvolutionParameterInitializer`` with
    ``only_random_initialization=True`` and
    ``random_initialization_magnitude=0.5``.

    Args:
        integrator: Forwarded unchanged to ``ShootingBlockBase``.
        in_features: Used as both ``in_features`` and ``particle_size`` of the
            integrand.
        shooting_model: Which integrand to build: ``'updown'``,
            ``'updown_universal'``, ``'periodic'`` or ``'simple'``.
        parameter_weight: Forwarded to the integrand.
        nr_of_particles: Forwarded to the integrand.
        inflation_factor: Forwarded to the up-down style integrands only; the
            ``'simple'`` integrand does not receive it.
        nonlinearity: Forwarded to the integrand.
        use_particle_rnn_mode: Forwarded to the integrand as ``use_rnn_mode``.
        use_particle_free_rnn_mode: Forwarded to ``ShootingBlockBase``.
        optimize_over_data_initial_conditions: Forwarded to the up-down style
            integrands only.
        optimize_over_data_initial_conditions_type: Forwarded to the up-down
            style integrands only.

    Returns:
        The constructed ``shooting_blocks.ShootingBlockBase`` instance.

    Raises:
        ValueError: If ``shooting_model`` is not one of the supported names.
    """
    # Keyword arguments every integrand variant receives.
    common_kwargs = dict(in_features=in_features,
                         nonlinearity=nonlinearity,
                         parameter_weight=parameter_weight,
                         nr_of_particles=nr_of_particles,
                         particle_dimension=1,
                         particle_size=in_features,
                         use_analytic_solution=True,
                         use_rnn_mode=use_particle_rnn_mode)

    # The three up-down style integrands additionally take the inflation
    # factor and the data-initial-condition options; 'simple' does not.
    updown_kwargs = dict(
        common_kwargs,
        inflation_factor=inflation_factor,
        optimize_over_data_initial_conditions=
        optimize_over_data_initial_conditions,
        optimize_over_data_initial_conditions_type=
        optimize_over_data_initial_conditions_type)

    if shooting_model == 'updown':
        smodel = shooting_models.AutoShootingIntegrandModelUpDown(
            **updown_kwargs)
    elif shooting_model == 'updown_universal':
        smodel = shooting_models.AutoShootingIntegrandModelUpDownUniversal(
            optional_weight=0.1, **updown_kwargs)
    elif shooting_model == 'periodic':
        smodel = shooting_models.AutoShootingIntegrandModelUpdownPeriodic(
            **updown_kwargs)
    elif shooting_model == 'simple':
        smodel = shooting_models.AutoShootingIntegrandModelSimple(
            **common_kwargs)
    else:
        # Fail fast: otherwise ``smodel`` stays unbound and the call dies
        # later with an error that does not mention the bad model name.
        raise ValueError(
            "Unknown shooting model '{}'; expected one of 'updown', "
            "'updown_universal', 'periodic', 'simple'".format(shooting_model))

    print('Using shooting model {}'.format(shooting_model))

    par_initializer = pi.VectorEvolutionParameterInitializer(
        only_random_initialization=True, random_initialization_magnitude=0.5)
    smodel.set_state_initializer(state_initializer=par_initializer)

    # NOTE(review): the block is named 'simple' regardless of shooting_model;
    # confirm this is intended.
    shooting_block = shooting_blocks.ShootingBlockBase(
        name='simple',
        shooting_integrand=smodel,
        use_particle_free_rnn_mode=use_particle_free_rnn_mode,
        integrator=integrator)

    return shooting_block
    # NOTE(review): everything from here on is unreachable. It follows the
    # unconditional `return shooting_block` above and reads `args`, which is
    # not a parameter of this function. It looks like a fragment of a separate
    # driver script that was pasted in; consider moving it there or deleting
    # it.
    shootingintegrand_kwargs = {
        'in_features':
        1,
        'nonlinearity':
        args.nonlinearity,
        'nr_of_particles':
        args.nr_of_particles,
        'parameter_weight':
        args.pw,
        'particle_dimension':
        1,
        'particle_size':
        1,
        'costate_initializer':
        pi.VectorEvolutionParameterInitializer(
            random_initialization_magnitude=0.1),
        'optimize_over_data_initial_conditions':
        args.optimize_over_data_initial_conditions,
        'optimize_over_data_initial_conditions_type':
        args.optimize_over_data_initial_conditions_type,
        'optional_weight':
        args.optional_weight
    }

    inflation_factor = args.inflation_factor  # for the up-down models (i.e., how much larger is the internal state; default is 5)
    use_particle_rnn_mode = args.use_particle_rnn_mode
    use_particle_free_rnn_mode = args.use_particle_free_rnn_mode

    use_analytic_solution = True  # True is the proper setting here for models that have analytic solutions implemented
    write_out_first_five_gradients = False  # for debugging purposes; use jointly with check_gradient_over_iterations.py
    use_fixed_sample_batch = write_out_first_five_gradients  # has to be set to True if we want to compare autodiff and analytic gradients (as otherwise there will be different random initializations)