Esempio n. 1
0
    def get_sample(self):
        """Return one synthetic sample filled with uniform random values in [0, 1).

        The sequence length comes from ``self.spec['max_seq_len']`` (call it T):

        * ``images``  -- float32 array of shape (T, 3, img_sz, img_sz)
        * ``states``  -- float32 array of shape (T, spec['state_dim'])
        * ``actions`` -- float32 array of shape (T - 1, spec['n_actions'])

        There is one fewer action than there are states.
        """
        spec = self.spec
        seq_len = spec['max_seq_len']
        side = self.img_sz

        sample = AttrDict()
        # The arrays are drawn in the same order as before (images, states, actions),
        # so a seeded global RNG yields the same values.
        sample.images = np.random.rand(seq_len, 3, side, side).astype(np.float32)
        sample.states = np.random.rand(seq_len, spec['state_dim']).astype(np.float32)
        sample.actions = np.random.rand(seq_len - 1, spec['n_actions']).astype(np.float32)
        return sample
Esempio n. 2
0
    def _get_raw_data(self, index):
        """Load the raw trajectory with global position ``index`` from its HDF5 file.

        Trajectories are split across ``self.filenames``, with
        ``self.samples_per_file`` trajectories per file.
        ``index // samples_per_file`` selects the file.
        ``index % samples_per_file`` selects the group ``traj<k>`` inside that file.

        Returns:
            AttrDict holding:

            * any of ``states``, ``actions`` and ``pad_mask`` that exist in the
              group, each cast to float32;
            * ``images``, the stored dataset, or an all-zero uint8 placeholder of
              shape (len(states), 2, 2, 3) when the group has no images.

        Raises:
            ValueError: if opening or reading the file fails for any ordinary
                (``Exception``) reason. The underlying error is chained as
                ``__cause__``.
            IndexError: presumably, if ``index`` maps past the end of
                ``self.filenames``. That lookup happens before the ``try``.
                TODO confirm ``filenames`` is a list.
        """
        data = AttrDict()
        file_index = index // self.samples_per_file
        path = self.filenames[file_index]
        # Name of the trajectory group inside its file, e.g. 'traj3'.
        key = 'traj{}'.format(index % self.samples_per_file)

        try:
            with h5py.File(path, 'r') as h5_file:
                traj = h5_file[key]

                # Only these datasets are loaded; each is cast to float32.
                for name in traj.keys():
                    if name in ('states', 'actions', 'pad_mask'):
                        data[name] = traj[name][()].astype(np.float32)

                if key + '/images' in h5_file:
                    data.images = h5_file[key + '/images'][()]
                else:
                    # No stored images: use a tiny zero placeholder with one frame
                    # per state, so ``data.images`` is always set.
                    data.images = np.zeros((data.states.shape[0], 2, 2, 3), dtype=np.uint8)
        except Exception as err:
            # The original used a bare ``except:``, which also converted
            # KeyboardInterrupt/SystemExit into ValueError and discarded the cause.
            # Any ordinary failure still becomes ValueError (callers rely on that),
            # but the real error is now chained for debugging.
            raise ValueError("Could not load from file {}".format(path)) from err
        return data
Esempio n. 3
0
 def forward(self, *args, **kwargs):
     """Encode the inputs with ``self.net`` and decode them with ``self.gen_head``.

     All positional and keyword arguments are forwarded unchanged to ``self.net``.

     Returns:
         AttrDict with ``feat`` (the output of ``self.net``) and ``images``
         (``self.gen_head`` applied to those features).
     """
     features = self.net(*args, **kwargs)
     result = AttrDict()
     result.feat = features
     result.images = self.gen_head(features)
     return result