pixel_res=256,
    raw=True,
    cond=False,
    half_image_size=False,
    kloss_dataset=True,  # true for mit push
)

# PlaNet baseline filter model, built on the dataset config above and the
# shared hyperparameter config `hy_config`.
model_config = PlaNetBaselineFilterConfig(
    dataset=dataset_config,
    # latent size is taken from the shared hyperparameters
    latent_dim=hy_config.latent_dim,
    latent_obs_dim=32,
    hidden_units=64,
    ctrl_dim=6,
    # (mode, distance) pair; the distance (index 1) is also embedded in the
    # experiment name below
    overshoot=(OverShoot.LATENT, 2),
)

# Experiment settings for `model_config`, followed by the training launch.
# The run name is tagged with the overshoot distance and the launch time.
# NOTE(review): '%a-%H-%M-%S' only encodes weekday and time of day, so two
# launches at the same time on the same weekday (different weeks) produce the
# same name — consider adding the date if names must be unique.
exp_config = ExpConfig(
    # "_" separates the overshoot distance from the timestamp; without it the
    # two were concatenated directly (e.g. "oo_planet_push_2Tue-14-03-22").
    name=(
        f"oo_planet_push_{model_config.overshoot[1]}_"
        f"{datetime.datetime.now().strftime('%a-%H-%M-%S')}"
    ),
    model=model_config,
    ramp_iters=200,
    batch_size=hy_config.batch_size,
    epochs=hy_config.epochs,
    log_iterations_simple=10,
    log_iterations_images=100,
    base_learning_rate=hy_config.learning_rate,
    learning_rate_function=lr5,
    gradient_clip_max_norm=100,
)
train(exp_config)
    dataset=dataset_config,
    dyn_hidden_units=32,
    dyn_layers=3,
    dyn_nonlinearity=nn.Softplus(beta=2, threshold=20),
    obs_hidden_units=32,
    obs_layers=3,
    obs_nonlinearity=nn.Softplus(beta=2, threshold=20),
    is_continuous=False,
    ramp_iters=100,
    burn_in=100,
    dkl_anneal_iter=1000,
    alpha=0.5,
    beta=1.0,
    atol=1e-9,  # default: 1e-9
    rtol=1e-7,  # default: 1e-7
    z_pred=False,
)

# Experiment settings for `model_config`, then launch training.
exp_config = ExpConfig(
    model=model_config,
    # hyperparameters shared through `hy_config`
    batch_size=hy_config.batch_size,
    epochs=hy_config.epochs,
    base_learning_rate=hy_config.learning_rate,
    learning_rate_function=lr1,
    # both the ramp length and the image-logging interval follow the
    # model's own ramp_iters setting
    ramp_iters=model_config.ramp_iters,
    log_iterations_images=model_config.ramp_iters,
    log_iterations_simple=10,
)
train(exp_config)  # start training with the settings above