Esempio n. 1
0
    def get_data_config(self, conf_module):
        """Build the data config for the dataset named by ``conf_module``.

        The defaults come from the ``dataset_spec`` attribute of the
        ``dataset_spec.py`` file inside the dataset directory. They are then
        overridden by ``conf_module.data_config`` if the module defines it.
        Otherwise, when a dataset name is set, they are overridden by the
        ``config`` attribute of ``gcp.datasets.configs.<dataset_name>``.
        A ``dataset_spec`` override is merged key-by-key into the default spec.
        Every other key replaces the default outright. ``fps`` falls back to 4
        when no source sets it.

        Args:
            conf_module: experiment config module. Must provide
                ``configuration['dataset_name']``; may provide ``data_config``.

        Returns:
            AttrDict containing at least ``dataset_spec`` and ``fps``.
        """
        import importlib.util
        import sys

        # Read the dataset name once, with one access style, so both uses agree.
        dataset_name = conf_module.configuration['dataset_name']

        # Load the default spec file as module 'dataset_spec'. This is the
        # documented importlib replacement for the deprecated
        # ``imp.load_source`` (``imp`` was removed in Python 3.12). Like
        # load_source, it registers the module in sys.modules.
        path = os.path.join(get_dataset_path(dataset_name), 'dataset_spec.py')
        spec = importlib.util.spec_from_file_location('dataset_spec', path)
        data_conf_file = importlib.util.module_from_spec(spec)
        sys.modules['dataset_spec'] = data_conf_file
        spec.loader.exec_module(data_conf_file)

        data_conf = AttrDict()
        data_conf.dataset_spec = AttrDict(data_conf_file.dataset_spec)

        # Update with custom params if available.
        update_data_conf = {}
        if hasattr(conf_module, 'data_config'):
            update_data_conf = conf_module.data_config
        elif dataset_name is not None:
            update_data_conf = importlib.import_module(
                'gcp.datasets.configs.' + dataset_name).config

        for key in update_data_conf:
            if key == "dataset_spec":
                # Merge spec overrides rather than replacing the whole spec.
                # Item access works for plain dicts as well as AttrDicts.
                data_conf.dataset_spec.update(update_data_conf[key])
            else:
                data_conf[key] = update_data_conf[key]

        if 'fps' not in data_conf:
            data_conf.fps = 4
        return data_conf
Esempio n. 2
0
    def get_data_config(self, conf_module):
        """Build the data config for the dataset named by ``conf_module``.

        The defaults come from the ``dataset_spec`` attribute of the
        ``dataset_spec.py`` file inside the dataset directory. If
        ``conf_module`` defines ``data_config``, its ``dataset_spec`` entry is
        merged key-by-key into the default spec and every other key replaces
        the default outright. Without ``data_config`` the defaults are used
        unchanged. ``fps`` falls back to 4 when no source sets it.

        Args:
            conf_module: experiment config module. Must provide
                ``configuration['dataset_name']``; may provide ``data_config``.

        Returns:
            AttrDict containing at least ``dataset_spec`` and ``fps``.
        """
        import importlib.util
        import sys

        # Load the default spec file as module 'dataset_spec'. This is the
        # documented importlib replacement for the deprecated
        # ``imp.load_source`` (``imp`` was removed in Python 3.12). Like
        # load_source, it registers the module in sys.modules.
        path = os.path.join(
            get_dataset_path(conf_module.configuration['dataset_name']),
            'dataset_spec.py')
        spec = importlib.util.spec_from_file_location('dataset_spec', path)
        data_conf_file = importlib.util.module_from_spec(spec)
        sys.modules['dataset_spec'] = data_conf_file
        spec.loader.exec_module(data_conf_file)

        data_conf = AttrDict()
        data_conf.dataset_spec = AttrDict(data_conf_file.dataset_spec)

        # Update with custom params if available. Defaulting to an empty dict
        # avoids the UnboundLocalError the previous try/except-pass produced
        # when conf_module had no data_config.
        update_data_conf = getattr(conf_module, 'data_config', {})

        for key in update_data_conf:
            if key == "dataset_spec":
                # Merge spec overrides rather than replacing the whole spec.
                # Item access works for plain dicts as well as AttrDicts.
                data_conf.dataset_spec.update(update_data_conf[key])
            else:
                data_conf[key] = update_data_conf[key]

        if 'fps' not in data_conf:
            data_conf.fps = 4
        return data_conf
Esempio n. 3
0
    'nz_mid_lstm': 512,
    'n_lstm_layers': 3,
    'nz_mid': 128,
    'nz_enc': 128,
    'nz_vae': 256,
    'regress_length': True,
    'attach_state_regressor': True,
    'attach_cost_mdl': True,
    'cost_mdl_params': AttrDict(
        cost_fcn=EuclideanPathLength,
    ),
    'attach_inv_mdl': True,
    'inv_mdl_params': AttrDict(
        n_actions=2,
        use_convs=False,
        build_encoder=False,
    ),
    'decoder_distribution': 'discrete_logistic_mixture',
})
# Remove this option from model_config. pop() is called without a default,
# so a missing key raises KeyError.
model_config.pop("add_weighted_pixel_copy")


## Dataset
data_config = AttrDict(
    dataset_spec=AttrDict(
        max_seq_len=100,
        dataset_class=MazeTopRenderedGlobalSplitVarLenVideoDataset,
        # train / val / test fractions; the test share is set to zero.
        split=AttrDict(train=0.994, val=0.006, test=0.00),
    ),
    n_rooms=configuration['n_rooms'],
    crop_window=40,
)