Ejemplo n.º 1
0
def setup_shooting_block(integrator=None,
                         in_features=20,
                         shooting_model='updown',
                         parameter_weight=1.0,
                         nr_of_particles=10,
                         inflation_factor=2,
                         nonlinearity='relu',
                         use_particle_rnn_mode=False,
                         use_particle_free_rnn_mode=False,
                         optimize_over_data_initial_conditions=False,
                         optimize_over_data_initial_conditions_type='linear'):
    """Build a ``ShootingBlockBase`` around the requested shooting integrand.

    Args:
        integrator: Integrator object handed unchanged to the shooting block.
        in_features: Passed to the integrand both as ``in_features`` and as
            ``particle_size``.
        shooting_model: Which integrand to instantiate. One of ``'updown'``,
            ``'updown_universal'``, ``'periodic'`` or ``'simple'``.
        parameter_weight: Forwarded to the integrand.
        nr_of_particles: Forwarded to the integrand.
        inflation_factor: Forwarded to the integrand for every model except
            ``'simple'``.
        nonlinearity: Forwarded to the integrand.
        use_particle_rnn_mode: Forwarded to the integrand as ``use_rnn_mode``.
        use_particle_free_rnn_mode: Forwarded to the shooting block.
        optimize_over_data_initial_conditions: Forwarded to the integrand for
            every model except ``'simple'``.
        optimize_over_data_initial_conditions_type: Forwarded to the integrand
            for every model except ``'simple'``.

    Returns:
        The constructed ``shooting_blocks.ShootingBlockBase`` (named
        ``'simple'`` regardless of the chosen model).

    Raises:
        ValueError: If ``shooting_model`` is not one of the supported names.
    """
    # Keyword arguments shared by every supported integrand.
    common_kwargs = dict(
        in_features=in_features,
        nonlinearity=nonlinearity,
        parameter_weight=parameter_weight,
        nr_of_particles=nr_of_particles,
        particle_dimension=1,
        particle_size=in_features,
        use_analytic_solution=True,
        use_rnn_mode=use_particle_rnn_mode)

    # Additional keyword arguments used by the up-down family of integrands;
    # the 'simple' model does not receive these.
    updown_kwargs = dict(
        common_kwargs,
        inflation_factor=inflation_factor,
        optimize_over_data_initial_conditions=
        optimize_over_data_initial_conditions,
        optimize_over_data_initial_conditions_type=
        optimize_over_data_initial_conditions_type)

    if shooting_model == 'updown':
        smodel = shooting_models.AutoShootingIntegrandModelUpDown(
            **updown_kwargs)
    elif shooting_model == 'updown_universal':
        smodel = shooting_models.AutoShootingIntegrandModelUpDownUniversal(
            optional_weight=0.1, **updown_kwargs)
    elif shooting_model == 'periodic':
        smodel = shooting_models.AutoShootingIntegrandModelUpdownPeriodic(
            **updown_kwargs)
    elif shooting_model == 'simple':
        smodel = shooting_models.AutoShootingIntegrandModelSimple(
            **common_kwargs)
    else:
        # Without this guard an unknown name would leave `smodel` unbound and
        # fail further down with an unrelated error.
        raise ValueError(
            "Unknown shooting model {!r}; expected one of 'updown', "
            "'updown_universal', 'periodic', 'simple'".format(shooting_model))

    print('Using shooting model {}'.format(shooting_model))

    par_initializer = pi.VectorEvolutionParameterInitializer(
        only_random_initialization=True, random_initialization_magnitude=0.5)
    smodel.set_state_initializer(state_initializer=par_initializer)

    shooting_block = shooting_blocks.ShootingBlockBase(
        name='simple',
        shooting_integrand=smodel,
        use_particle_free_rnn_mode=use_particle_free_rnn_mode,
        integrator=integrator)

    return shooting_block
    inflation_factor = args.inflation_factor  # for the up-down models (i.e., how much larger is the internal state; default is 5)
    use_particle_rnn_mode = args.use_particle_rnn_mode
    use_particle_free_rnn_mode = args.use_particle_free_rnn_mode

    use_analytic_solution = True  # True is the proper setting here for models that have analytic solutions implemented
    write_out_first_five_gradients = False  # for debugging purposes; use jointly with check_gradient_over_iterations.py
    use_fixed_sample_batch = write_out_first_five_gradients  # has to be set to True if we want to compare autodiff and analytic gradients (as otherwise there will be different random initializations

    if write_out_first_five_gradients and not use_fixed_sample_batch:
        print(
            'WARNING: if you want to compare autodiff/analytic gradient then use_fixed_sample_batch should be set to True'
        )

    if args.shooting_model == 'simple':
        smodel = smodels.AutoShootingIntegrandModelSimple(
            **shootingintegrand_kwargs,
            use_analytic_solution=use_analytic_solution,
            use_rnn_mode=use_particle_rnn_mode)
    elif args.shooting_model == '2nd_order':
        smodel = smodels.AutoShootingIntegrandModelSecondOrder(
            **shootingintegrand_kwargs, use_rnn_mode=use_particle_rnn_mode)
    elif args.shooting_model == 'updown':
        smodel = smodels.AutoShootingIntegrandModelUpDown(
            **shootingintegrand_kwargs,
            use_analytic_solution=use_analytic_solution,
            inflation_factor=inflation_factor,
            use_rnn_mode=use_particle_rnn_mode)
    elif args.shooting_model == 'updown_universal':
        smodel = smodels.AutoShootingIntegrandModelUpDownUniversal(
            **shootingintegrand_kwargs,
            use_analytic_solution=use_analytic_solution,
            inflation_factor=inflation_factor,
# Counter of attempted tests, starting at zero.
number_of_tests_attempted = 0
# NOTE(review): presumably the largest discrepancy the checks accept — confirm
# against where `tolerance` is used.
tolerance = 5e-3

# Integrator built with the 'odeint' library and the 'rk4' method, with adjoint
# integration disabled. `integrator_options` is defined outside this excerpt.
integrator = generic_integrator.GenericIntegrator(
    integrator_library='odeint',
    integrator_name='rk4',
    use_adjoint_integration=False,
    integrator_options=integrator_options)

for current_model in check_models:

    if current_model == 'simple':
        shooting_model = shooting_models.AutoShootingIntegrandModelSimple(
            in_features=in_features_size,
            nonlinearity=nonlinearity,
            nr_of_particles=nr_of_particles,
            particle_dimension=1,
            particle_size=in_features_size,
            parameter_weight=parameter_weight)
    elif current_model == 'universal':
        shooting_model = shooting_models.AutoShootingIntegrandModelUniversal(
            in_features=in_features_size,
            nonlinearity=nonlinearity,
            nr_of_particles=nr_of_particles,
            particle_dimension=1,
            particle_size=in_features_size,
            parameter_weight=parameter_weight,
            inflation_factor=5)

    elif current_model == 'updown':
        shooting_model = shooting_models.AutoShootingIntegrandModelUpDown(
    return initial_conditions, assembly_plans

# Create a shooting integrand and the initial state/costate/data dictionaries.

parameter_weight = 1.0
in_features = 3
nr_of_particles = 4

# This is needed to determine how many Lagrangian multipliers there are,
# as we add them via their mean.
nr_of_particle_parameters = in_features*nr_of_particles

# Simple integrand with a tanh nonlinearity; the analytic solution is switched
# off here (use_analytic_solution=False).
shooting_integrand = shooting_models.AutoShootingIntegrandModelSimple(
            in_features=in_features,
            particle_dimension=1,
            particle_size=in_features,
            nonlinearity='tanh',
            nr_of_particles=nr_of_particles,
            parameter_weight=parameter_weight,
            use_analytic_solution=False)

# Passed as `set_to_zero` when creating the initial state parameters below.
keep_state_parameters_at_zero = False

# Get the state and costate dictionaries; the costates are created from the
# state dictionary.
state_dict = shooting_integrand.create_initial_state_parameters_if_needed(set_to_zero=keep_state_parameters_at_zero)
costate_dict = shooting_integrand.create_initial_costate_parameters(state_dict=state_dict)

# Create some initial data: a random normal tensor of shape (10, 1, in_features).
# NOTE(review): presumably (batch, particle_dimension, feature) — confirm.
x = torch.randn([10,1,in_features])
block_name = 'test_block'
effective_data_dict = shooting_integrand.get_initial_data_dict_from_data_tensor(x)
# Wrap the data dictionary in a SortedDict keyed by the block name.
effective_data_dict_of_dicts = SortedDict({block_name: effective_data_dict})