"simpler/train0.hdf5", "simpler/train1.hdf5", "simpler/train2.hdf5", "simpler/train3.hdf5", "simpler/train4.hdf5", "simpler/train5.hdf5", ) dynamics_trainset = omnipush_datasets.OmnipushDynamicsDataset( *omnipush_train_files, **dataset_args) dynamics_recurrent_trainset = omnipush_datasets.OmnipushSubsequenceDataset( *omnipush_train_files, subsequence_length=16, **dataset_args) measurement_trainset = omnipush_datasets.OmnipushMeasurementDataset( *omnipush_train_files, samples_per_pair=10, **dataset_args) e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset( *omnipush_train_files, subsequence_length=16, particle_count=30, particle_stddev=(.1, .1), **dataset_args) # Pre-train measurement models = [ (pf_image_model, 'measurement_image'), # (pf_force_model, 'measurement_force'), ] measurement_trainset_loader = torch.utils.data.DataLoader(measurement_trainset, batch_size=32, shuffle=True, num_workers=16) for pf_model, optim_name in models: for i in range(MEASUREMENT_PRETRAIN_EPOCHS): print("Pre-training measurement epoch", i)
buddy.load_checkpoint(path=args.load_checkpoint) if args.module_type == "ekf": buddy.load_checkpoint_module(source="image_model", path=args.load_checkpoint) buddy.load_checkpoint_module(source="force_model", path=args.load_checkpoint) print("Creating dataset...") if args.omnipush: e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset( "simpler/train0.hdf5", "simpler/train1.hdf5", "simpler/train2.hdf5", "simpler/train3.hdf5", "simpler/train4.hdf5", "simpler/train5.hdf5", subsequence_length=16, particle_count=1, particle_stddev=(.03, .03), **dataset_args) dataset_measurement = omnipush_datasets.OmnipushMeasurementDataset( "simpler/train0.hdf5", "simpler/train1.hdf5", "simpler/train2.hdf5", "simpler/train3.hdf5", "simpler/train4.hdf5", "simpler/train5.hdf5", subsequence_length=16, stddev=(0.5, 0.5),