Esempio n. 1
0
                               output_density=output_density).float()

    # Initialize the policy network: an MLP from D inputs to 2 * U outputs
    # whose output nonlinearity is DiagGaussianDensity over U dimensions
    # (presumably U means plus U scales -- confirm against
    # models.DiagGaussianDensity).
    # dropout_layers gets one entry per element of args.pol_shape (presumably
    # one per hidden layer): a fresh BDropout when args.pol_drop_rate > 0,
    # otherwise a None placeholder (presumably skipped by models.mlp -- confirm).
    pol_model = models.mlp(D,
                           2 * U,
                           args.pol_shape,
                           dropout_layers=[
                               models.modules.BDropout(args.pol_drop_rate)
                               if args.pol_drop_rate > 0 else None
                               for _ in args.pol_shape
                           ],
                           nonlin=torch.nn.ReLU,
                           output_nonlin=partial(models.DiagGaussianDensity,
                                                 U))

    # Wrap the network in models.Policy with maxU / minU (presumably the
    # action bounds) and print the run configuration.
    pol = models.Policy(pol_model, maxU, minU).float()
    print('args\n', args)
    print('Dynamics model\n', dyn)
    print('Policy\n', pol)

    # Initialize the experience dataset.
    exp = utils.ExperienceDataset()

    # When resuming, load_checkpoint presumably restores dyn, pol and exp from
    # the saved checkpoint.
    # NOTE(review): opt1/opt2 are created after this call and are not passed
    # to it, so no optimizer state is restored here -- confirm that is intended.
    if loaded_from is not None:
        utils.load_checkpoint(loaded_from, dyn, pol, exp)

    # Initialize the dynamics optimizer (Adam, learning rate args.dyn_lr).
    opt1 = torch.optim.Adam(dyn.parameters(), args.dyn_lr)

    # Initialize the policy optimizer (Adam, learning rate args.pol_lr).
    opt2 = torch.optim.Adam(pol.parameters(), args.pol_lr)
Esempio n. 2
0
                          Q=Q,
                          angle_dims=angle_dims)

# Init dynamics model (heteroscedastic noise): a dropout MLP built with
# (presumably) input size D + U, output size dynE and hidden sizes [200, 200],
# wrapped by models.DynamicsModel together with reward_func.
# NOTE: create a separate CDropout instance for each layer. `[m] * 2` would
# place the *same* module object in the list twice, so (unless dropout_mlp
# copies its inputs) both layers would share one dropout module and its
# parameters/state.
dyn = models.DynamicsModel(
    models.dropout_mlp(
        D + U,
        dynE,
        [200] * 2,
        dropout_layers=[models.modules.CDropout(0.1) for _ in range(2)],
        nonlin=torch.nn.ReLU),
    reward_func=reward_func).float()
# Bind the dynamics model into `forward` (defined elsewhere in the file).
forward_fn = partial(forward, dynamics=dyn)

# Init policy: a dropout MLP from D to U with a Tanh output nonlinearity,
# wrapped by models.Policy with maxU (presumably the action bound).
# NOTE: create a separate BDropout instance for each of the two layers.
# `[m] * 2` would reuse one module object (and any mask state it keeps) for
# both layers unless dropout_mlp copies its inputs.
pol = models.Policy(
    models.dropout_mlp(
        D,
        U,
        output_nonlin=torch.nn.Tanh,
        dropout_layers=[models.modules.BDropout(0.1) for _ in range(2)]),
    maxU).float()
# Random policy built from the same bound (presumably for exploration/baseline
# rollouts -- confirm against its uses below this excerpt).
randpol = RandPolicy(maxU)

# Create the (initially empty) experience dataset.
exp = ExperienceDataset()

# Policy optimizer: Adam with the AMSGrad variant and learning rate 1e-3,
# applied only to policy parameters that require gradients.
params = (p for p in pol.parameters() if p.requires_grad)
opt = torch.optim.Adam(params, lr=1e-3, amsgrad=True)


def cb(*args, **kwargs):
    """Render ``env``; any positional/keyword arguments are accepted.

    The visible body ignores its arguments and only calls ``env.render()``.
    NOTE(review): presumably used as a callback hook (name ``cb``); the body
    may continue past this excerpt -- confirm.
    """
    env.render()