Example #1
    def random_batch_rnn(self, batch_size, seq_length=None, epoch=1):
        """
        Provide sequence batches randomly sampled from the trajectory.
        batch shape is (seq_length, batch_size, *)

        Parameters
        ----------
        batch_size : int
        seq_length : int
            Length of each sequence in the batch.
            If seq_length is None, the maximum episode length is used.
        epoch : int

        Returns
        -------
        batch : dict of torch.Tensor
        """

        if seq_length is None:
            seq_length = max([self._epis_index[i+1] - self._epis_index[i]
                              for i in range(len(self._epis_index)-1)])

        for _ in range(epoch):
            seqs = []
            lengths = []
            indices = np.random.randint(
                0, len(self._epis_index)-1, (batch_size,))

            if self.ddp:
                indices = indices[self.rank:len(indices):self.world_size]

            for idx in indices:
                length = min(
                    self._epis_index[idx+1] - self._epis_index[idx], seq_length)
                start = np.random.randint(
                    self._epis_index[idx], self._epis_index[idx+1] - length + 1)
                data_map = dict()
                for key in self.data_map:
                    if length < seq_length:  # pad episodes shorter than seq_length
                        pad = torch.zeros_like(self.data_map[key][:seq_length-length],
                                               dtype=torch.float, device=get_device())
                        data_map[key] = torch.cat(
                            [self.data_map[key][start: start+length], pad])
                    else:
                        data_map[key] = self.data_map[key][start: start+seq_length]
                lengths.append(length)
                seqs.append(data_map)

            batch = dict()
            keys = seqs[0].keys()
            for key in keys:
                batch[key] = torch.stack([seq[key] for seq in seqs])
                # (batch_size, seq_length, *) -> (seq_length, batch_size, *)
                batch[key] = batch[key].transpose(0, 1).to(get_device())
            # len(lengths) == batch_size unless DDP subsampled the indices
            out_masks = torch.ones(
                (seq_length, len(lengths)), dtype=torch.float, device=get_device())
            for i, length in enumerate(lengths):
                out_masks[length:, i] = 0
            batch['out_masks'] = out_masks
            yield batch
Example #2
    def iterate_rnn(self, batch_size, num_epi_per_seq=1, epoch=1):
        """
        Iterate batches for an RNN.
        batch shape is (max_seq, batch_size, *)

        Parameters
        ----------
        batch_size : int
        num_epi_per_seq : int
            Number of episodes in one sequence for rnn.
        epoch : int

        Returns
        -------
        batch : dict of torch.Tensor
        """
        assert batch_size * num_epi_per_seq <= self.num_epi
        for _ in range(epoch):
            epi_count = 0
            all_batch = []
            seq = []
            for epi in self.iterate_epi(shuffle=True):
                seq.append(epi)
                epi_count += 1
                if epi_count >= num_epi_per_seq:
                    _seq = dict()
                    for key in seq[0].keys():
                        _seq[key] = torch.cat([s[key] for s in seq])
                    all_batch.append(_seq)
                    seq = []
                    epi_count = 0
            num_batch = len(all_batch)
            idx = 0
            while idx < num_batch:  # the final batch may be smaller than batch_size
                cur_batch_size = min(batch_size, num_batch - idx)
                batch = all_batch[idx:idx + cur_batch_size]
                idx += cur_batch_size

                lengths = [list(b.values())[0].size(0) for b in batch]
                max_length = max(lengths)
                out_masks = torch.ones((max_length, cur_batch_size),
                                       dtype=torch.float,
                                       device=get_device())
                for i, l in enumerate(lengths):
                    out_masks[l:, i] = 0

                _batch = dict()
                keys = batch[0].keys()
                for key in keys:
                    _batch[key] = pad_sequence(
                        [b[key] for b in batch]).to(get_device())
                _batch['out_masks'] = out_masks.to(get_device())
                yield _batch
Example #3
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rnn=False,
                 data_parallel=False,
                 parallel_dim=0):
        nn.Module.__init__(self)
        self.observation_space = observation_space
        self.action_space = action_space
        self.net = net

        self.rnn = rnn
        self.hs = None

        self.data_parallel = data_parallel
        if data_parallel:
            if data_parallel is True:
                self.dp_net = nn.DataParallel(self.net, dim=parallel_dim)
            elif data_parallel == 'ddp':
                self.net.to(get_device())
                self.dp_net = nn.parallel.DistributedDataParallel(
                    self.net, device_ids=[get_device()], dim=parallel_dim)
            else:
                raise ValueError(
                    'data_parallel must be True, False or the string "ddp".')
        self.dp_run = False
Example #4
def compute_vs(data, vf):
    """
    Compute state values with the value function.

    Parameters
    ----------
    data : Traj or epis (list of dict of ndarray)
    vf : SVFunction

    Returns
    -------
    data : Traj or epis (list of dict of ndarray)
        Same type as the input.
    """
    if isinstance(data, Traj):
        epis = data.current_epis
    else:
        epis = data

    vf.reset()
    with torch.no_grad():
        for epi in epis:
            if vf.rnn:
                obs = torch.tensor(epi['obs'],
                                   dtype=torch.float,
                                   device=get_device()).unsqueeze(1)
            else:
                obs = torch.tensor(epi['obs'],
                                   dtype=torch.float,
                                   device=get_device())
            epi['vs'] = vf(obs)[0].detach().cpu().numpy()

    return data
Example #5
def compute_vs(data, vf):
    """
    Compute state values with the value function.

    Parameters
    ----------
    data : Traj
    vf : SVFunction

    Returns
    -------
    data : Traj
    """
    epis = data.current_epis
    vf.reset()
    with torch.no_grad():
        for epi in epis:
            if vf.rnn:
                obs = torch.tensor(epi['obs'],
                                   dtype=torch.float,
                                   device=get_device()).unsqueeze(1)
            else:
                obs = torch.tensor(epi['obs'],
                                   dtype=torch.float,
                                   device=get_device())
            epi['vs'] = vf(obs)[0].detach().cpu().numpy()

    return data
Example #6
    def max(self, obs):
        """
        Max and argmax of the Q-function.

        Parameters
        ----------
        obs : torch.Tensor

        Returns
        -------
        max_qs, max_acs
        """

        obs = self._check_obs_shape(obs)

        self.batch_size = obs.shape[0]
        self.dim_ob = obs.shape[1]
        high = torch.tensor(self.ac_space.high,
                            dtype=torch.float, device=get_device())
        low = torch.tensor(
            self.ac_space.low, dtype=torch.float, device=get_device())
        init_samples = torch.linspace(0, 1, self.num_sampling, device=get_device()).reshape(
            self.num_sampling, -1) * (high - low) + low  # (self.num_sampling, dim_ac)
        init_samples = self._clamp(init_samples)
        max_qs, max_acs = self._cem(obs, init_samples)
        return max_qs, max_acs
Example #7
    def _clamp(self, samples):
        low = torch.tensor(self.ac_space.low,
                           dtype=torch.float, device=get_device())
        high = torch.tensor(self.ac_space.high,
                            dtype=torch.float, device=get_device())
        samples = (samples - low) / (high - low)
        samples = torch.clamp(samples, 0, 1) * (high - low) + low
        return samples
Example #8
    def max(self, obs):
        """
        Perform max and argmax of Qfunc

        Parameters
        ----------
        obs : torch.Tensor

        Returns
        -------
        max_qs : torch.Tensor
        max_acs : torch.Tensor
        """

        obs = self._check_obs_shape(obs)

        self.dim_ob = obs.shape[1]
        high = torch.tensor(self.action_space.high,
                            dtype=torch.float,
                            device=get_device())
        low = torch.tensor(self.action_space.low,
                           dtype=torch.float,
                           device=get_device())
        init_samples = torch.linspace(0,
                                      1,
                                      self.num_sampling,
                                      device=get_device())
        init_samples = init_samples.reshape(self.num_sampling, -1) * (
            high - low) + low  # (self.num_sampling, dim_ac)
        init_samples = self._clamp(init_samples)
        if not self.save_memory:  # score the whole batch in one forward pass
            self.cem_batch_size = obs.shape[0]
            obs = obs.repeat((1, self.num_sampling)).reshape(
                (self.cem_batch_size * self.num_sampling, self.dim_ob))
            # each observation repeated num_sampling times: (batch * num_sampling, dim_ob)
            init_samples = init_samples.repeat((self.cem_batch_size, 1))
            # candidate actions tiled per observation: (batch * num_sampling, dim_ac)
            max_qs, max_acs = self._cem(obs, init_samples)
        else:  # iterate per observation to save memory
            self.cem_batch_size = 1
            max_acs = []
            max_qs = []
            for ob in obs:
                ob = ob.repeat((1, self.num_sampling)).reshape(
                    (self.cem_batch_size * self.num_sampling, self.dim_ob))
                ob = self._check_obs_shape(ob)
                max_q, max_ac = self._cem(ob, init_samples)
                max_qs.append(max_q)
                max_acs.append(max_ac)
            # each max_q has shape (1,), so concatenation gives (len(obs),)
            max_qs = torch.cat(max_qs, dim=0)
            max_acs = torch.cat(max_acs, dim=0)
        max_acs = self._check_acs_shape(max_acs)
        return max_qs, max_acs
Example #9
def compute_hs(data, func, hs_name='hs', input_acs=False):
    """
    Computing Hidden State of RNN Cell.

    Parameters
    ----------
    data : Traj or epis (list of dict of ndarray)
    func : callable
        Any function with an RNN cell, e.g. a policy, vf or qf.
    hs_name : str
        Key under which the hidden states are stored in each episode.
    input_acs : bool
        If True, actions are passed to func along with observations.

    Returns
    -------
    data : Traj or epis (list of dict of ndarray)
        Same type as the input.
    """
    if isinstance(data, Traj):
        epis = data.current_epis
    else:
        epis = data

    func.reset()
    with torch.no_grad():
        for epi in epis:
            obs = torch.tensor(epi['obs'],
                               dtype=torch.float,
                               device=get_device()).unsqueeze(1)
            time_seq = obs.size()[0]
            if input_acs:
                acs = torch.tensor(epi['acs'],
                                   dtype=torch.float,
                                   device=get_device()).unsqueeze(1)
                hs_seq = [
                    func(obs[i:i + 1], acs[i:i + 1])[-1]['hs']
                    for i in range(time_seq)
                ]
            else:
                hs_seq = [
                    func(obs[i:i + 1])[-1]['hs'] for i in range(time_seq)
                ]
            if isinstance(hs_seq[0], tuple):  # e.g. LSTM: (h, c) per step
                hs = np.array(
                    [[h.squeeze().detach().cpu().numpy() for h in hs]
                     for hs in hs_seq],
                    dtype='float32')
            else:
                hs = np.array(
                    [h.squeeze().detach().cpu().numpy() for h in hs_seq],
                    dtype='float32')
            epi[hs_name] = hs

    return data
Example #10
    def register_epis(self):
        epis = self.current_epis
        keys = epis[0].keys()
        data_map = dict()
        for key in keys:
            if isinstance(epis[0][key], (list, np.ndarray)):
                data_map[key] = torch.tensor(np.concatenate(
                    [epi[key] for epi in epis], axis=0), dtype=torch.float, device=get_device())
            elif isinstance(epis[0][key], dict):
                new_keys = epis[0][key].keys()
                for new_key in new_keys:
                    data_map[new_key] = torch.tensor(np.concatenate(
                        [epi[key][new_key] for epi in epis], axis=0), dtype=torch.float, device=get_device())

        self._concat_data_map(data_map)

        epis_index = []
        index = 0
        for epi in epis:
            l_epi = len(epi['rews'])
            index += l_epi
            epis_index.append(index)
        epis_index = np.array(epis_index) + self._epis_index[-1]
        self._epis_index = np.concatenate([self._epis_index, epis_index])

        self.current_epis = None
Example #11
    def _get_indices(self, indices=None, shuffle=True):
        if indices is None:
            indices = torch.arange(
                self.num_step, device=get_device(), dtype=torch.long)
        if shuffle:
            indices = self._shuffled_indices(indices)
        return indices
Example #12
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rew_func,
                 n_samples=1000,
                 horizon=20,
                 mean_obs=0.,
                 std_obs=1.,
                 mean_acs=0.,
                 std_acs=1.,
                 rnn=False,
                 normalize_ac=True):
        BasePol.__init__(self,
                         observation_space,
                         action_space,
                         net,
                         rnn=rnn,
                         normalize_ac=normalize_ac)
        self.rew_func = rew_func
        self.n_samples = n_samples
        self.horizon = horizon
        self.to(get_device())

        # plain tensor attributes are not moved by .to(), so create them on
        # the target device directly
        self.mean_obs = torch.tensor(mean_obs, dtype=torch.float,
                                     device=get_device()).repeat(n_samples, 1)
        self.std_obs = torch.tensor(std_obs, dtype=torch.float,
                                    device=get_device()).repeat(n_samples, 1)
        self.mean_acs = torch.tensor(mean_acs, dtype=torch.float,
                                     device=get_device()).repeat(n_samples, 1)
        self.std_acs = torch.tensor(std_acs, dtype=torch.float,
                                    device=get_device()).repeat(n_samples, 1)
Example #13
def density_ratio_cross_ent(pol,
                            batch,
                            expert_or_agent,
                            gamma,
                            rewf=None,
                            shaping_vf=None,
                            advf=None):
    obs = batch['obs']
    acs = batch['acs']
    if rewf is not None and shaping_vf is not None:
        next_obs = batch['next_obs']
        dones = batch['dones']
        vs, _ = shaping_vf(obs)
        rews, _ = rewf(obs)
        next_vs, _ = shaping_vf(next_obs)
        energies = rews + (1 - dones) * gamma * next_vs - vs
    elif advf is not None:
        energies, _ = advf(obs, acs)
    else:
        raise ValueError('either (rewf and shaping_vf) or advf must be given')
    with torch.no_grad():
        _, _, params = pol(obs)
        llhs = pol.pd.llh(acs, params)
    logits = energies - llhs
    batch_size = obs.shape[0]  # avoid shadowing the builtin len
    discrim_loss = F.binary_cross_entropy_with_logits(
        logits,
        torch.ones(batch_size, device=get_device()) * expert_or_agent)
    return discrim_loss
Example #14
    def __init__(self, ob_space, ac_space, qfunc, rnn=False, normalize_ac=True,
                 data_parallel=False, parallel_dim=0, eps=0.2):
        BasePol.__init__(self, ob_space, ac_space, None, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.qfunc = qfunc
        self.eps = eps
        self.a_i_shape = (1, )
        self.to(get_device())
Example #15
def compute_pseudo_rews(data, rew_giver, state_only=False):
    epis = data.current_epis
    for epi in epis:
        obs = torch.tensor(epi['obs'], dtype=torch.float, device=get_device())
        with torch.no_grad():  # no graph is needed; .numpy() also requires a detached tensor
            if state_only:
                logits, _ = rew_giver(obs)
            else:
                acs = torch.tensor(epi['acs'],
                                   dtype=torch.float,
                                   device=get_device())
                logits, _ = rew_giver(obs, acs)
            # -log sigmoid(-x) == softplus(x), so the pseudo reward is non-negative
            rews = -F.logsigmoid(-logits).cpu().numpy()
        epi['real_rews'] = copy.deepcopy(epi['rews'])
        epi['rews'] = rews
    return data
Example #16
    def prioritized_random_batch_rnn_once(self, batch_size, seq_length,
                                          return_indices=False, init_beta=0.4,
                                          beta_step=0.00025/4):
        if not hasattr(self, 'pri_beta'):
            self.pri_beta = init_beta
        elif self.pri_beta >= 1.0:
            self.pri_beta = 1.0
        else:
            self.pri_beta += beta_step

        seq_pris = self.data_map['seq_pris'].clone().detach()

        start_indices = list(torch.utils.data.sampler.WeightedRandomSampler(
            seq_pris, batch_size, replacement=True))

        seqs = []
        for start in start_indices:
            data_map = dict()
            for key in self.data_map:
                data_map[key] = self.data_map[key][start: start+seq_length]
            seqs.append(data_map)

        batch = dict()
        keys = seqs[0].keys()
        for key in keys:
            batch[key] = torch.stack([seq[key] for seq in seqs], dim=0)
            # (batch_size, seq_length, *) -> (seq_length, batch_size, *)
            batch[key] = batch[key].transpose(0, 1).to(get_device())

        if return_indices:
            return batch, start_indices
        else:
            return batch
Example #17
    def prioritized_random_batch_once(self, batch_size, return_indices=False,
                                      mode='proportional', alpha=0.6,
                                      init_beta=0.4, beta_step=0.00025/4):
        if not hasattr(self, 'pri_beta'):
            self.pri_beta = init_beta
        elif self.pri_beta >= 1.0:
            self.pri_beta = 1.0
        else:
            self.pri_beta += beta_step

        pris = self.data_map['pris'].cpu().numpy()

        if mode == 'rank_based':
            # rank of each transition (argsort of argsort), highest priority first
            ranks = np.argsort(np.argsort(-pris))
            pris = (ranks.astype(np.float32) + 1) ** -1
            pris = pris ** alpha

        is_weights = (len(pris) * (pris/pris.sum())) ** -self.pri_beta
        is_weights /= np.max(is_weights)
        pris *= is_weights
        pris = torch.tensor(pris)
        indices = list(torch.utils.data.sampler.WeightedRandomSampler(
            pris, batch_size, replacement=True))

        if self.ddp:
            indices = indices[self.rank:len(indices):self.world_size]

        data_map = dict()
        for key in self.data_map:
            data_map[key] = self.data_map[key][indices].to(get_device())
        if return_indices:
            return data_map, indices
        else:
            return data_map
Example #18
    def __init__(self, observation_space, net):
        super().__init__(observation_space, net)
        # plain tensor attributes are not moved by .to(), so create on device
        self.x_mean = torch.zeros(1, device=get_device())
        self.x_std = torch.ones(1, device=get_device())
        self.to(get_device())

        self.normalized = True
Example #19
    def random_batch_once(self,
                          batch_size,
                          indices=None,
                          return_indices=False):
        """
        Providing a batch which is randomly sampled from trajectory.

        Parameters
        ----------
        batch_size : int
        indices : ndarray or torch.Tensor or None
            Selected indices for iteration.
            If None, whole trajectory is selected.
        return_indices : bool
            If True, indices are also returned.

        Returns
        -------
        data_map : dict of torch.Tensor
        """
        if indices is None:
            # torch.randint's upper bound is exclusive, so num_step covers
            # every valid step index
            indices = torch.randint(0, self.num_step, size=(batch_size, ))

        data_map = dict()
        for key in self.data_map:
            data_map[key] = self.data_map[key][indices].to(get_device())
        if return_indices:
            return data_map, indices
        else:
            return data_map
Example #20
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 normalize_ac=True):
        BasePol.__init__(self, observation_space, action_space, net,
                         normalize_ac=normalize_ac)
        self.pd = MixtureGaussianPd()
        self.to(get_device())
Example #21
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rnn=False,
                 normalize_ac=True,
                 data_parallel=False,
                 parallel_dim=0):
        nn.Module.__init__(self)
        self.observation_space = observation_space
        self.action_space = action_space
        self.net = net

        self.rnn = rnn
        self.hs = None

        self.normalize_ac = normalize_ac
        self.data_parallel = data_parallel
        if data_parallel:
            if data_parallel is True:
                self.dp_net = nn.DataParallel(self.net, dim=parallel_dim)
            elif data_parallel == 'ddp':
                self.net.to(get_device())
                self.dp_net = nn.parallel.DistributedDataParallel(
                    self.net, device_ids=[get_device()], dim=parallel_dim)
            else:
                raise ValueError(
                    'data_parallel must be True, False or the string "ddp".')
        self.dp_run = False

        self.discrete = isinstance(
            action_space, (gym.spaces.MultiDiscrete, gym.spaces.Discrete))
        self.multi = isinstance(action_space, gym.spaces.MultiDiscrete)

        if not self.discrete:
            self.a_i_shape = action_space.shape
        else:
            if isinstance(action_space, gym.spaces.MultiDiscrete):
                nvec = action_space.nvec
                assert all(nvec[0] == nv for nv in nvec), \
                    'all dimensions of MultiDiscrete must have the same size'
                self.a_i_shape = (len(nvec), nvec[0])
            elif isinstance(action_space, gym.spaces.Discrete):
                self.a_i_shape = (action_space.n, )
Example #22
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rnn=False,
                 normalize_ac=True):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac)
        self.pd = MultiCategoricalPd()
        self.to(get_device())
Example #23
def cross_ent(discrim, batch, expert_or_agent, ent_beta):
    obs = batch['obs']
    acs = batch['acs']
    batch_size = obs.shape[0]  # avoid shadowing the builtin len
    logits, _ = discrim(obs, acs)
    discrim_loss = F.binary_cross_entropy_with_logits(
        logits, torch.ones(batch_size, device=get_device())*expert_or_agent)
    ent = (1 - torch.sigmoid(logits))*logits - F.logsigmoid(logits)
    discrim_loss -= ent_beta * torch.mean(ent)
    return discrim_loss
Example #24
    def _next_batch(self, batch_size, indices):
        cur_id = self._next_id
        cur_batch_size = min(batch_size, len(indices) - self._next_id)
        self._next_id += cur_batch_size

        data_map = dict()
        for key in self.data_map:
            # index with the given (possibly shuffled) indices, not raw positions
            data_map[key] = self.data_map[key][
                indices[cur_id:cur_id + cur_batch_size]].to(get_device())
        return data_map
Example #25
    def __init__(self, ob_space, ac_space, net, rnn=False, data_parallel=False,
                 parallel_dim=0, num_sampling=64, num_best_sampling=6,
                 num_iter=2, multivari=True, delta=1e-4):
        super().__init__(ob_space, ac_space, net, rnn,
                         data_parallel, parallel_dim)
        self.num_sampling = num_sampling
        self.delta = delta
        self.num_best_sampling = num_best_sampling
        self.num_iter = num_iter
        self.dim_ac = self.ac_space.shape[0]
        self.multivari = multivari
        self.to(get_device())
Example #26
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rnn=False,
                 data_parallel=False,
                 parallel_dim=0):
        super().__init__(observation_space, action_space, net, rnn,
                         data_parallel, parallel_dim)
        self.to(get_device())
Example #27
    def _cem(self, obs, samples):
        """
        Perform cross entropy method

        Parameters
        ----------
        obs : torch.Tensor
        samples : torch.Tensor
            shape (self.num_sampling, dim_ac)

        Returns
        -------
        max_q : torch.Tensor
        max_ac : torch.Tensor
        """
        for i in range(self.num_iter + 1):
            with torch.no_grad():
                qvals, _ = self.forward(obs, samples)
            if i != self.num_iter:
                qvals = qvals.reshape((self.cem_batch_size, self.num_sampling))
                _, indices = torch.sort(qvals, dim=1, descending=True)
                best_indices = indices[:, :self.num_best_sampling]
                best_indices = best_indices + \
                    torch.arange(0, self.num_sampling * self.cem_batch_size,
                                 self.num_sampling, device=get_device()).reshape((self.cem_batch_size, 1))
                best_indices = best_indices.reshape(
                    (self.num_best_sampling * self.cem_batch_size, ))
                # (self.num_best_sampling * self.cem_batch_size,  self.dim_ac)
                best_samples = samples[best_indices, :]
                # (self.cem_batch_size, self.num_best_sampling, self.dim_ac)
                best_samples = best_samples.reshape(
                    (self.cem_batch_size, self.num_best_sampling, self.dim_ac))
                if self.multivari:
                    samples = self._fitting_multivari(best_samples)
                else:
                    samples = self._fitting_diag(best_samples)
        qvals = qvals.reshape((self.cem_batch_size, self.num_sampling))
        samples = samples.reshape(
            (self.cem_batch_size, self.num_sampling, self.dim_ac))
        max_q, ind = torch.max(qvals, dim=1)
        max_ac = samples[
            torch.arange(self.cem_batch_size, device=get_device()), ind]
        return max_q, max_ac
Example #28
def task_oriented_reward(data, rew_giver, state_only=False):
    if isinstance(data, Traj):
        epis = data.current_epis
    else:
        epis = data
    for epi in epis:
        obs = torch.tensor(epi['obs'], dtype=torch.float, device=get_device())
        with torch.no_grad():
            if state_only:
                logits, _ = rew_giver(obs)
            else:
                acs = torch.tensor(epi['acs'],
                                   dtype=torch.float,
                                   device=get_device())
                logits, _ = rew_giver(obs, acs)
            rews = F.logsigmoid(logits).cpu().numpy()
        epi['real_rews'] = copy.deepcopy(epi['rews'])
        epi['rews'] = rews + epi['real_rews']
    return data
Example #29
    def __init__(self,
                 observation_space,
                 action_space,
                 net,
                 rnn=False,
                 normalize_ac=True):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac)
        self.pd = GaussianPd()
        self.to(get_device())
Example #30
def compute_diayn_rews(data, rew_giver):
    epis = data.current_epis
    for epi in epis:
        obs = torch.as_tensor(epi['obs'],
                              dtype=torch.float,
                              device=get_device())
        with torch.no_grad():
            rews, _ = rew_giver(obs)
        epi['rews'] = rews.cpu().numpy()
    return data