Ejemplo n.º 1
0
    # End-to-end particle-filter training set read from the fixed
    # gentle_push_1000 HDF5 file (particle_count=30, particle_stddev=(.1, .1),
    # subsequences of length 16); remaining options come from dataset_args.
    gentle_push_path = "data/gentle_push_1000.hdf5"
    e2e_trainset = panda_datasets.PandaParticleFilterDataset(
        gentle_push_path,
        particle_count=30,
        particle_stddev=(.1, .1),
        subsequence_length=16,
        **dataset_args)
elif args.dataset == "omnipush":
    omnipush_train_files = (
        "simpler/train0.hdf5",
        "simpler/train1.hdf5",
        "simpler/train2.hdf5",
        "simpler/train3.hdf5",
        "simpler/train4.hdf5",
        "simpler/train5.hdf5",
    )
    dynamics_trainset = omnipush_datasets.OmnipushDynamicsDataset(
        *omnipush_train_files, **dataset_args)
    dynamics_recurrent_trainset = omnipush_datasets.OmnipushSubsequenceDataset(
        *omnipush_train_files, subsequence_length=16, **dataset_args)
    measurement_trainset = omnipush_datasets.OmnipushMeasurementDataset(
        *omnipush_train_files, samples_per_pair=10, **dataset_args)
    e2e_trainset = omnipush_datasets.OmnipushParticleFilterDataset(
        *omnipush_train_files,
        subsequence_length=16,
        particle_count=30,
        particle_stddev=(.1, .1),
        **dataset_args)

# Pre-train measurement
models = [
    (pf_image_model, 'measurement_image'),
    # (pf_force_model, 'measurement_force'),
Ejemplo n.º 2
0
        # Both Omnipush datasets below read the same six training shards.
        # Keep the list in one place so the two constructors cannot drift.
        omnipush_train_files = (
            "simpler/train0.hdf5",
            "simpler/train1.hdf5",
            "simpler/train2.hdf5",
            "simpler/train3.hdf5",
            "simpler/train4.hdf5",
            "simpler/train5.hdf5",
        )

        # Recurrent dynamics training set with subsequences of length 32.
        dynamics_recurrent_trainset = omnipush_datasets.OmnipushSubsequenceDataset(
            *omnipush_train_files,
            subsequence_length=32,
            **dataset_args)

        # NOTE(review): subsequence_length is passed to OmnipushDynamicsDataset
        # here. Confirm that its constructor actually accepts and uses this
        # argument.
        dataset_dynamics = omnipush_datasets.OmnipushDynamicsDataset(
            *omnipush_train_files,
            subsequence_length=16,
            **dataset_args)
    else:
        # Particle-filter training set on the gentle_push file whose size
        # suffix comes from args.data_size (particle_count=1,
        # particle_stddev=(.03, .03), subsequences of length 16).
        gentle_push_path = "data/gentle_push_{}.hdf5".format(args.data_size)
        e2e_trainset = panda_datasets.PandaParticleFilterDataset(
            gentle_push_path,
            particle_count=1,
            particle_stddev=(.03, .03),
            subsequence_length=16,
            **dataset_args)

        dataset_measurement = panda_datasets.PandaMeasurementDataset(
            "data/gentle_push_{}.hdf5".format(args.data_size),
            subsequence_length=16,