Ejemplo n.º 1
0
# Experiment setup: the run name comes from the command line; the dataset
# options switch on proprioception, haptics and vision, and forward the
# vision-related knobs supplied on the command line.
experiment_name = args.experiment_name
dataset_args = {
    'use_proprioception': True,
    'use_haptics': True,
    'use_vision': True,
    'vision_interval': 2,
    'image_blackout_ratio': args.blackout,
    'sequential_image_rate': args.sequential_image,
    'start_timestep': args.start_timestep,
}

# Two single-modality particle filters. Each pairs a fresh dynamics model
# with a measurement model that has one modality marked as missing:
# the image filter masks 'gripper_sensors', the force filter masks 'image'.
# The generator is consumed in order during unpacking, so the image filter
# is fully constructed before any part of the force filter is created.
pf_image_model, pf_force_model = (
    panda_models.PandaParticleFilterNetwork(
        panda_models.PandaDynamicsModel(),
        panda_models.PandaMeasurementModel(units=args.hidden_units,
                                           missing_modalities=[masked]),
    )
    for masked in ('gripper_sensors', 'image')
)

# Cross-modal weighting module combining the two filters; both the softmax
# and log-softmax options are enabled.
weight_model = fusion.CrossModalWeights(
    state_dim=1,
    use_softmax=True,
    use_log_softmax=True,
)
pf_fusion_model = fusion_pf.ParticleFusionModel(
    pf_image_model,
    pf_force_model,
    weight_model,
)

# NOTE(review): this statement is truncated and is not valid Python as
# written. The ``optimizer_names=[`` list below is never completed, and the
# three lines after it appear to be the keyword tail of a
# ``parser.add_argument(...)`` call spliced in from a different script, so
# the opening ``[`` is closed by a ``)``. TODO: restore the missing optimizer
# names plus the closing ``])`` here, and move the argparse fragment back
# into its own ``parser.add_argument(...)`` call.
buddy = fannypack.utils.Buddy(experiment_name,
                              pf_fusion_model,
                              optimizer_names=[
                    # NOTE(review): splice point. Everything from here on
                    # looks like argparse keyword arguments, not optimizer
                    # names; the original head of that call is missing.
                    type=str,
                    choices=["mujoco", "omnipush"],
                    default="mujoco")
# Width of the measurement model (see ``measurement_model`` below).
parser.add_argument(
    "--hidden_units",
    type=int,
    default=64,
)
args = parser.parse_args()

# Some constants. Presumably the number of end-to-end training epochs;
# TODO confirm against where it is used.
E2E_EPOCHS = 10

# The experiment name is taken directly from the command line.
experiment_name = args.experiment_name

# Particle filter: the dynamics model width is fixed at 32 units, while the
# measurement model width follows --hidden_units.
dynamics_model = panda_models.PandaDynamicsModel(units=32)
measurement_model = panda_models.PandaMeasurementModel(
    units=args.hidden_units,
)
pf_model = panda_models.PandaParticleFilterNetwork(
    dynamics_model,
    measurement_model,
)

# The training buddy runs under "<experiment_name>_unfrozen" and registers
# one optimizer per training phase/modality combination.
buddy = fannypack.utils.Buddy(
    experiment_name + "_unfrozen",
    pf_model,
    optimizer_names=[
        # End-to-end phases.
        "e2e_fusion",
        "e2e_image",
        "e2e_force",
        # Dynamics-only phases (single-step and recurrent).
        "dynamics_image",
        "dynamics_force",
        "dynamics_recurrent_image",
        "dynamics_recurrent_force",
        # Measurement-only phases.
        "measurement_image",
        "measurement_force",
    ],
)
# Metadata is loaded from the base experiment name (without the
# "_unfrozen" suffix) rather than from this buddy's own run name.
buddy.load_metadata(experiment_name=experiment_name)