# Disabled tail of an earlier copy of the joint-training loop. Its loop
# header lies before this excerpt, so it is kept (still commented out)
# rather than removed.
#     print("Training E2E (joint) epoch", i)
#     panda_training.train_e2e(
#         buddy,
#         pf_fusion_model,
#         e2e_trainset_loader,
#         loss_type="mse",
#         optim_name=optim_name)
# buddy.save_checkpoint("phase_4_e2e_joint")
# buddy.save_checkpoint()

# E2E train (joint)
# NOTE(review): line breaks and indentation were lost in this excerpt. The
# layout below assumes the two checkpoint calls run once, after the final
# epoch, rather than inside the epoch loop -- confirm against the original.
optim_name = "e2e_fusion"

# Freeze flags exactly as the original sets them: the fusion model's image
# and force models are unfrozen, while the dynamics model of each of those
# sub-models stays frozen.
pf_fusion_model.freeze_image_model = False
pf_fusion_model.freeze_force_model = False
pf_fusion_model.image_model.freeze_dynamics_model = True
pf_fusion_model.force_model.freeze_dynamics_model = True

# Use the same optimizer name for the learning-rate change and for training,
# so the two cannot drift apart if the name is ever edited.
buddy.set_learning_rate(1e-5, optimizer_name=optim_name)

e2e_trainset_loader = torch.utils.data.DataLoader(
    e2e_trainset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

for i in range(E2E_EPOCHS):
    print("Training E2E (joint) epoch", i)
    panda_training.train_e2e(
        buddy,
        pf_fusion_model,
        e2e_trainset_loader,
        loss_type="mse",
        optim_name=optim_name,
    )

# First call saves under the phase label; second call uses buddy's default
# checkpoint naming.
buddy.save_checkpoint("phase_4_e2e_joint")
buddy.save_checkpoint()
models = [ (pf_image_model, 'e2e_image'), # (pf_force_model, 'e2e_force'), ] for pf_model, optim_name in models: pf_model.freeze_measurement_model = False pf_model.freeze_dynamics_model = True e2e_trainset_loader = torch.utils.data.DataLoader(e2e_trainset, batch_size=32, shuffle=True, num_workers=2) for i in range(E2E_INDIVIDUAL_EPOCHS): print(f"E2E individual training epoch {optim_name}", i) panda_training.train_e2e(buddy, pf_model, e2e_trainset_loader, loss_type="mse", optim_name=optim_name) buddy.save_checkpoint("phase_3_e2e_individual") buddy.save_checkpoint() # E2E train (joint) optim_name = "e2e_fusion" buddy.set_learning_rate(1e-5, optimizer_name=optim_name) pf_fusion_model.freeze_image_model = False pf_fusion_model.freeze_force_model = False pf_fusion_model.image_model.freeze_dynamics_model = True pf_fusion_model.force_model.freeze_dynamics_model = True e2e_trainset_loader = torch.utils.data.DataLoader(e2e_trainset, batch_size=32, shuffle=True,
"data/gentle_push_1000.hdf5", subsequence_length=16, particle_count=30, particle_stddev=(.1, .1), **dataset_args) elif args.dataset == "omnipush": e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset( *omnipush_train_files, subsequence_length=16, particle_count=30, particle_stddev=(.1, .1), **dataset_args) # E2E training pf_model.freeze_measurement_model = False pf_model.freeze_dynamics_model = False e2e_trainset_loader = torch.utils.data.DataLoader(e2e_trainset, batch_size=32, shuffle=True, num_workers=8) for i in range(E2E_EPOCHS): print("E2E training epoch", i) panda_training.train_e2e(buddy, pf_model, e2e_trainset_loader, loss_type="mse", resample=False) buddy.save_checkpoint("phase_3_end_to_end_trained") buddy.save_checkpoint()