Beispiel #1
0
# Load datasets
# Build the training datasets for the loader family selected by args.dataset.
# Constructors are called in the same order as before in case loading has
# side effects.
# NOTE(review): there is no fallback branch -- any other args.dataset value
# leaves the *_trainset names unbound; confirm argparse restricts choices.
if args.dataset == "mujoco":
    # Every MuJoCo/Panda loader below reads this one trajectory file.
    mujoco_dataset_path = "data/gentle_push_1000.hdf5"

    dynamics_trainset = panda_datasets.PandaDynamicsDataset(
        mujoco_dataset_path, **dataset_args)
    dynamics_recurrent_trainset = panda_datasets.PandaSubsequenceDataset(
        mujoco_dataset_path, subsequence_length=16, **dataset_args)

    # Black images are excluded from measurement-model pretraining when
    # images are rate-limited (args.sequential_image != 1).
    measurement_trainset = panda_datasets.PandaMeasurementDataset(
        mujoco_dataset_path,
        samples_per_pair=10,
        ignore_black_images=(args.sequential_image != 1),
        **dataset_args)
    e2e_trainset = panda_datasets.PandaParticleFilterDataset(
        mujoco_dataset_path,
        subsequence_length=16,
        particle_count=30,
        particle_stddev=(.1, .1),
        **dataset_args)
elif args.dataset == "omnipush":
    # Six training shards: simpler/train0.hdf5 ... simpler/train5.hdf5.
    omnipush_train_files = tuple(
        "simpler/train{}.hdf5".format(shard) for shard in range(6))

    # NOTE(review): within this view the omnipush branch defines only the
    # two dynamics datasets, not measurement_trainset / e2e_trainset.
    dynamics_trainset = omnipush_datasets.OmnipushDynamicsDataset(
        *omnipush_train_files, **dataset_args)
    dynamics_recurrent_trainset = omnipush_datasets.OmnipushSubsequenceDataset(
        *omnipush_train_files, subsequence_length=16, **dataset_args)
Beispiel #2
0
            subsequence_length=32,
            **dataset_args)

        # Positional arguments expand to the six training shards
        # simpler/train0.hdf5 ... simpler/train5.hdf5, in that order.
        # NOTE(review): the OmnipushDynamicsDataset call in the dataset
        # loading code above passes no subsequence_length; confirm this
        # keyword is intended for the dynamics dataset.
        dataset_dynamics = omnipush_datasets.OmnipushDynamicsDataset(
            *("simpler/train{}.hdf5".format(shard) for shard in range(6)),
            subsequence_length=16,
            **dataset_args)
    else:
        # All three Panda loaders read the same trajectory file, whose name
        # is selected by args.data_size.
        panda_dataset_path = "data/gentle_push_{}.hdf5".format(args.data_size)

        # End-to-end training set (subsequence_length=16, particle_count=1).
        e2e_trainset = panda_datasets.PandaParticleFilterDataset(
            panda_dataset_path,
            subsequence_length=16,
            particle_count=1,
            particle_stddev=(.03, .03),
            **dataset_args)

        # NOTE(review): the PandaMeasurementDataset call in the dataset
        # loading code above passes no subsequence_length; confirm the
        # loader accepts and uses it here.
        dataset_measurement = panda_datasets.PandaMeasurementDataset(
            panda_dataset_path,
            subsequence_length=16,
            stddev=(0.5, 0.5),
            samples_per_pair=20,
            **dataset_args)

        # Recurrent dynamics set uses a longer subsequence_length (32).
        dynamics_recurrent_trainset = panda_datasets.PandaSubsequenceDataset(
            panda_dataset_path,
            subsequence_length=32,
            **dataset_args)