Exemplo n.º 1
0
        "simpler/train0.hdf5",
        "simpler/train1.hdf5",
        "simpler/train2.hdf5",
        "simpler/train3.hdf5",
        "simpler/train4.hdf5",
        "simpler/train5.hdf5",
    )
    # Build the four training-set views over the same Omnipush HDF5 files.
    # Every view receives the shared `dataset_args` keyword arguments;
    # only the view-specific options differ between them.

    # Single-step dynamics samples.
    dynamics_trainset = omnipush_datasets.OmnipushDynamicsDataset(
        *omnipush_train_files,
        **dataset_args
    )

    # Fixed-length (16-step) subsequences for recurrent dynamics training.
    dynamics_recurrent_trainset = (
        omnipush_datasets.OmnipushSubsequenceDataset(
            *omnipush_train_files,
            subsequence_length=16,
            **dataset_args
        )
    )

    # Measurement-model samples; `samples_per_pair=10` is passed through to
    # the dataset class (its exact meaning is defined there).
    measurement_trainset = omnipush_datasets.OmnipushMeasurementDataset(
        *omnipush_train_files,
        samples_per_pair=10,
        **dataset_args
    )

    # End-to-end particle-filter subsequences: 16 steps, 30 particles,
    # particle_stddev=(.1, .1).
    e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset(
        *omnipush_train_files,
        subsequence_length=16,
        particle_count=30,
        particle_stddev=(.1, .1),
        **dataset_args
    )

# Pre-train measurement
# Models to pre-train, each paired with the name that the loop below binds
# to `optim_name`. The `pf_force_model` entry is currently disabled;
# uncomment it to pre-train that model as well.
models = [
    (pf_image_model, "measurement_image"),
    # (pf_force_model, "measurement_force"),
]

# Shuffled mini-batches of 32 measurement samples, loaded by 16 worker
# processes.
measurement_trainset_loader = torch.utils.data.DataLoader(
    measurement_trainset,
    batch_size=32,
    shuffle=True,
    num_workers=16,
)
for pf_model, optim_name in models:
    for i in range(MEASUREMENT_PRETRAIN_EPOCHS):
        print("Pre-training measurement epoch", i)
Exemplo n.º 2
0
            buddy.load_checkpoint(path=args.load_checkpoint)
        # For the "ekf" module type, restore the image and force sub-modules
        # one at a time from the same checkpoint path (image first, then
        # force, as before).
        if args.module_type == "ekf":
            for ekf_checkpoint_source in ("image_model", "force_model"):
                buddy.load_checkpoint_module(source=ekf_checkpoint_source,
                                             path=args.load_checkpoint)

    # Progress message printed before the dataset objects below are built.
    print("Creating dataset...")

    if args.omnipush:
        e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset(
            "simpler/train0.hdf5",
            "simpler/train1.hdf5",
            "simpler/train2.hdf5",
            "simpler/train3.hdf5",
            "simpler/train4.hdf5",
            "simpler/train5.hdf5",
            subsequence_length=16,
            particle_count=1,
            particle_stddev=(.03, .03),
            **dataset_args)

        dataset_measurement = omnipush_datasets.OmnipushMeasurementDataset(
            "simpler/train0.hdf5",
            "simpler/train1.hdf5",
            "simpler/train2.hdf5",
            "simpler/train3.hdf5",
            "simpler/train4.hdf5",
            "simpler/train5.hdf5",
            subsequence_length=16,
            stddev=(0.5, 0.5),