Example #1
    def update_backward(self, batch, normalizer=None):
        # Train the backward (inverse) model: predict the object action from
        # consecutive object observations.

        # Offset separating the robot and object parts of batch['o'] (the goal
        # size is subtracted from the configured observation size).
        observation_space = self.observation_space - K.tensor(
            batch['g'], dtype=self.dtype, device=self.device).shape[1]
        action_space = self.action_space[0].shape[0]

        # Object observation + goal for the current and next step. With several
        # objects, get_obj_obs returns a per-object layout indexed by the last
        # dimension (s2[:, :, i_object]).
        if self.n_objects <= 1:
            s2 = K.cat([
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2 = get_obj_obs(
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        # Object part of the joint action.
        a2 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, action_space:]

        if self.n_objects <= 1:
            s2_ = K.cat([
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2_ = get_obj_obs(
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        if normalizer[1] is not None:
            if self.n_objects <= 1:
                s2 = normalizer[1].preprocess(s2)
                s2_ = normalizer[1].preprocess(s2_)
            else:
                for i_object in range(self.n_objects):
                    s2[:, :, i_object] = normalizer[1].preprocess(
                        s2[:, :, i_object])
                    s2_[:, :, i_object] = normalizer[1].preprocess(
                        s2_[:, :, i_object])

        # Inverse-model loss: predict the object action from (s2, s2_) and
        # regress it on the stored action (per object when there are several).
        if self.n_objects <= 1:
            a2_pred = self.backward(s2, s2_)
            loss_backward = self.loss_func(a2_pred, a2)
        else:
            loss_backward = 0.
            n_obj_actions = a2.shape[1] // self.n_objects
            for i_object in range(self.n_objects):
                act_slice = slice(i_object * n_obj_actions,
                                  (i_object + 1) * n_obj_actions)
                a2_pred = self.backward(s2[:, :, i_object],
                                        s2_[:, :, i_object])
                loss_backward += self.loss_func(a2_pred, a2[:, act_slice])

        self.backward_optim.zero_grad()
        loss_backward.backward()
        self.backward_optim.step()

        return loss_backward.item()
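
These snippets are excerpted from a larger module, so they lean on names that are never shown here: the tensor-library alias `K`, NumPy, a module-level `dtype`, and the project helpers `get_obj_obs`/`get_rob_obs`. A minimal sketch of the preamble they appear to assume; the helper bodies below are placeholders, not the project's actual implementations.

import numpy as np
import torch as K          # assumed alias: the snippets call K.tensor, K.cat, K.zeros_like

dtype = K.float32          # assumed module-level default used by the rollout functions


def get_obj_obs(obj_obs, goal, n_object):
    # Placeholder only. Judging from its use, the real helper reshapes the flat
    # object observation (plus goal) into a per-object layout indexed as
    # out[:, :, i_object].
    raise NotImplementedError


def get_rob_obs(obj_obs, n_object):
    # Placeholder only. Used in Example #6 to rebuild a robot-centric
    # observation from the per-object layout.
    raise NotImplementedError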
Example #2
    def update_parameters(self, batch, normalizer=None):
        # One DDPG-style update: fit the critic to a clipped TD target, then
        # take a policy-gradient step on the actor.

        observation_space = self.observation_space - K.tensor(
            batch['g'], dtype=self.dtype, device=self.device).shape[1]
        action_space = self.action_space[0].shape[0]

        V = K.zeros((len(batch['o']), 1), dtype=self.dtype, device=self.device)

        s1 = K.cat([
            K.tensor(batch['o'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                   dim=-1)

        if self.n_objects <= 1:
            s2 = K.cat([
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2 = get_obj_obs(
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        a1 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, 0:action_space]
        a2 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, action_space:]

        s1_ = K.cat([
            K.tensor(batch['o_2'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                    dim=-1)

        if self.n_objects <= 1:
            s2_ = K.cat([
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2_ = get_obj_obs(
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        if normalizer[0] is not None:
            s1 = normalizer[0].preprocess(s1)
            s1_ = normalizer[0].preprocess(s1_)

        if normalizer[1] is not None:
            if self.n_objects <= 1:
                s2 = normalizer[1].preprocess(s2)
                s2_ = normalizer[1].preprocess(s2_)
            else:
                for i_object in range(self.n_objects):
                    s2[:, :, i_object] = normalizer[1].preprocess(
                        s2[:, :, i_object])
                    s2_[:, :, i_object] = normalizer[1].preprocess(
                        s2_[:, :, i_object])

        s, s_, a = (s1, s1_, a1) if self.agent_id == 0 else (s2, s2_, a2)
        a_ = self.actors_target[0](s_)

        r = K.tensor(batch['r'], dtype=self.dtype,
                     device=self.device).unsqueeze(1)
        if self.object_Qfunc is not None:
            # Add the intrinsic reward from the object Q-function, optionally
            # gated by the magnitude of the environment reward.
            if self.n_objects <= 1:
                r_intr = self.get_obj_reward(s2, s2_)
            else:
                r_intr = K.zeros_like(r)
                for i_object in range(self.n_objects):
                    r_intr += self.get_obj_reward(s2[:, :, i_object],
                                                  s2_[:, :, i_object])
            if self.masked_with_r:
                r = r_intr * K.abs(r) + r
            else:
                r = r_intr + r

        Q = self.critics[0](s, a)
        V = self.critics_target[0](s_, a_).detach()

        target_Q = (V * self.gamma) + r
        if self.object_Qfunc is None:
            target_Q = target_Q.clamp(-1. / (1. - self.gamma), 0.)
        else:
            target_Q = target_Q.clamp(
                -(1 + self.n_objects) / (1. - self.gamma), 0.)

        loss_critic = self.loss_func(Q, target_Q)

        self.critics_optim[0].zero_grad()
        loss_critic.backward()
        self.critics_optim[0].step()

        a = self.actors[0](s)

        loss_actor = -self.critics[0](s, a).mean()

        if self.regularization:
            loss_actor += (self.actors[0](s)**2).mean() * 1

        self.actors_optim[0].zero_grad()
        loss_actor.backward()
        self.actors_optim[0].step()

        return loss_critic.item(), loss_actor.item()
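
A hedged sketch of how this update might be driven from a training loop. Only the batch keys ('o', 'g', 'u', 'o_2', 'r') and the two-slot normalizer come from the method above; the agent, buffer, and normalizer names (agent, replay_buffer, obs_normalizer, obj_normalizer) and the loop hyperparameters are illustrative assumptions.

# Hypothetical driver loop; object names are illustrative, not the project's.
n_train_steps, batch_size = 40, 256                  # assumed hyperparameters
for _ in range(n_train_steps):
    batch = replay_buffer.sample(batch_size)         # dict with 'o', 'g', 'u', 'o_2', 'r'
    loss_critic, loss_actor = agent.update_parameters(
        batch, normalizer=(obs_normalizer, obj_normalizer))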
Example #3
    def update_parameters(self, batch, normalizer=None):
        # Variant of the update above that keeps a separate critic per intrinsic
        # (object) reward stream instead of summing the rewards into one signal.

        observation_space = self.observation_space - K.tensor(
            batch['g'], dtype=self.dtype, device=self.device).shape[1]
        action_space = self.action_space[0].shape[0]

        V = K.zeros((len(batch['o']), 1), dtype=self.dtype, device=self.device)

        s1 = K.cat([
            K.tensor(batch['o'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                   dim=-1)

        if self.n_objects <= 1:
            s2 = K.cat([
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2 = get_obj_obs(
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        a1 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, 0:action_space]
        a2 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, action_space:]

        s1_ = K.cat([
            K.tensor(batch['o_2'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                    dim=-1)

        if self.n_objects <= 1:
            s2_ = K.cat([
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2_ = get_obj_obs(
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects)

        if normalizer[0] is not None:
            s1 = normalizer[0].preprocess(s1)
            s1_ = normalizer[0].preprocess(s1_)

        if normalizer[1] is not None:
            if self.n_objects <= 1:
                s2 = normalizer[1].preprocess(s2)
                s2_ = normalizer[1].preprocess(s2_)
            else:
                for i_object in range(self.n_objects):
                    s2[:, :, i_object] = normalizer[1].preprocess(
                        s2[:, :, i_object])
                    s2_[:, :, i_object] = normalizer[1].preprocess(
                        s2_[:, :, i_object])

        s, s_, a = (s1, s1_, a1) if self.agent_id == 0 else (s2, s2_, a2)
        a_ = self.actors_target[0](s_)

        r_all = []
        r = K.tensor(batch['r'], dtype=self.dtype,
                     device=self.device).unsqueeze(1)
        r_all.append(r)
        if self.object_Qfunc is not None:
            # One intrinsic reward stream per object, appended after the
            # environment reward.
            for i_object in range(self.n_objects):
                r_all.append(
                    self.get_obj_reward(s2[:, :, i_object],
                                        s2_[:, :, i_object]))

        # first critic for main rewards
        Q = self.critics[0](s, a)
        V = self.critics_target[0](s_, a_).detach()

        target_Q = (V * self.gamma) + r_all[0]
        target_Q = target_Q.clamp(self.clip_Q_neg, 0.)

        loss_critic = self.loss_func(Q, target_Q)

        self.critics_optim[0].zero_grad()
        loss_critic.backward()
        self.critics_optim[0].step()

        # other critics for intrinsic
        for i_object in range(self.n_objects):
            Q = self.critics[i_object + 1](s, a)
            V = self.critics_target[i_object + 1](s_, a_).detach()

            target_Q = (V * self.gamma) + r_all[i_object + 1]
            target_Q = target_Q.clamp(self.clip_Q_neg, 0.)

            loss_critic = self.loss_func(Q, target_Q)

            self.critics_optim[i_object + 1].zero_grad()
            loss_critic.backward()
            self.critics_optim[i_object + 1].step()

        # actor update
        a = self.actors[0](s)

        loss_actor = -self.critics[0](s, a).mean()
        for i_object in range(self.n_objects):
            loss_actor += -self.critics[i_object + 1](s, a).mean()

        if self.regularization:
            loss_actor += (self.actors[0](s)**2).mean() * 1

        self.actors_optim[0].zero_grad()
        loss_actor.backward()
        self.actors_optim[0].step()

        return loss_critic.item(), loss_actor.item()
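
Every critic update above forms the same clipped one-step TD target before the regression loss; restated as a tiny standalone helper (the function name and signature are mine, not the project's):

def clipped_td_target(reward, next_value, gamma, clip_q_neg):
    # next_value is the detached target-critic estimate Q'(s_, a_); the clamp
    # keeps the target in [clip_q_neg, 0], matching target_Q.clamp(...) above.
    return (next_value * gamma + reward).clamp(clip_q_neg, 0.)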
Example #4
def rollout(env,
            model,
            noise,
            i_env,
            normalizer=None,
            render=False,
            agent_id=0,
            ai_object=False,
            rob_policy=[0., 0.]):
    # Run one episode in env[i_env] and collect a trajectory per agent.
    trajectories = []
    for i_agent in range(2):
        trajectories.append([])

    # monitoring variables
    episode_reward = 0
    frames = []

    env[i_env].env.ai_object = True if agent_id == 1 else ai_object
    env[i_env].env.deactivate_ai_object()
    state_all = env[i_env].reset()

    if agent_id == 1:
        env[i_env].env.activate_ai_object()
    elif agent_id == 0 and ai_object:
        for i_step in range(env[0]._max_episode_steps):
            model.to_cpu()

            obs = [
                K.tensor(obs, dtype=K.float32).unsqueeze(0)
                for obs in state_all['observation']
            ]
            goal = K.tensor(state_all['desired_goal'],
                            dtype=K.float32).unsqueeze(0)

            # Observation normalization
            obs_goal = []
            obs_goal.append(K.cat([obs[0], goal], dim=-1))
            if normalizer[0] is not None:
                obs_goal[0] = normalizer[0].preprocess(obs_goal[0])

            action = model.select_action(obs_goal[0],
                                         noise[0]).cpu().numpy().squeeze(0)

            action_to_env = np.zeros_like(env[0].action_space.sample())
            action_to_env[0:action.shape[0]] = action

            next_state_all, reward, done, info = env[i_env].step(action_to_env)

            # Move to the next state
            state_all = next_state_all

            # Record frames
            if render:
                frames.append(env[i_env].render(mode='rgb_array')[0])

        env[i_env].env.activate_ai_object()

    for i_step in range(env[0]._max_episode_steps):

        model.to_cpu()

        obs = [
            K.tensor(obs, dtype=K.float32).unsqueeze(0)
            for obs in state_all['observation']
        ]
        goal = K.tensor(state_all['desired_goal'],
                        dtype=K.float32).unsqueeze(0)

        # Observation normalization
        obs_goal = []
        obs_goal.append(K.cat([obs[0], goal], dim=-1))
        if normalizer[0] is not None:
            obs_goal[0] = normalizer[0].preprocess_with_update(obs_goal[0])

        if model.n_objects <= 1:
            obs_goal.append(K.cat([obs[1], goal], dim=-1))
            if normalizer[1] is not None:
                obs_goal[1] = normalizer[1].preprocess_with_update(obs_goal[1])
        else:
            obs_goal.append(get_obj_obs(obs[1], goal, model.n_objects))
            if normalizer[1] is not None:
                for i_object in range(model.n_objects):
                    obs_goal[1][:, :, i_object] = normalizer[1].preprocess_with_update(
                        obs_goal[1][:, :, i_object])

        action = model.select_action(obs_goal[agent_id],
                                     noise[agent_id]).cpu().numpy().squeeze(0)

        if agent_id == 0:
            if ai_object:
                # Robot part: scripted policy (scaled random sample plus bias);
                # object part: actions from the object policy.
                action_to_env = (env[0].action_space.sample() * rob_policy[0] +
                                 np.ones_like(env[0].action_space.sample()) *
                                 rob_policy[1])
                if model.n_objects <= 1:
                    action_to_env[action.shape[0]:] = model.get_obj_action(
                        obs_goal[1], noise[1]).cpu().numpy().squeeze(0)
                else:
                    n_obj_actions = len(
                        action_to_env[action.shape[0]:]) // model.n_objects
                    for i_object in range(model.n_objects):
                        act_slice = slice(
                            action.shape[0] + i_object * n_obj_actions,
                            action.shape[0] + (i_object + 1) * n_obj_actions)
                        action_to_env[act_slice] = model.get_obj_action(
                            obs_goal[1][:, :, i_object],
                            noise[1]).cpu().numpy().squeeze(0)
            else:
                action_to_env = np.zeros_like(env[0].action_space.sample())
                action_to_env[0:action.shape[0]] = action
        else:
            # The robot follows the scripted policy; the agent controls the object.
            action_to_env = (env[0].action_space.sample() * rob_policy[0] +
                             np.ones_like(env[0].action_space.sample()) *
                             rob_policy[1])
            action_to_env[-action.shape[0]:] = action

        next_state_all, reward, done, info = env[i_env].step(action_to_env)
        reward = K.tensor(reward, dtype=dtype).view(1, 1)

        next_obs = [
            K.tensor(next_obs, dtype=K.float32).unsqueeze(0)
            for next_obs in next_state_all['observation']
        ]

        # Observation normalization
        next_obs_goal = []
        next_obs_goal.append(K.cat([next_obs[0], goal], dim=-1))
        if normalizer[0] is not None:
            next_obs_goal[0] = normalizer[0].preprocess(next_obs_goal[0])

        if model.n_objects <= 1:
            next_obs_goal.append(K.cat([next_obs[1], goal], dim=-1))
            if normalizer[1] is not None:
                next_obs_goal[1] = normalizer[1].preprocess(next_obs_goal[1])
        else:
            next_obs_goal.append(
                get_obj_obs(next_obs[1], goal, model.n_objects))
            if normalizer[1] is not None:
                for i_object in range(model.n_objects):
                    next_obs_goal[1][:, :, i_object] = normalizer[1].preprocess(
                        next_obs_goal[1][:, :, i_object])

        # for monitoring
        if model.object_Qfunc is None:
            episode_reward += reward
        else:
            if model.n_objects <= 1:
                intrinsic_reward = model.get_obj_reward(obs_goal[1],
                                                        next_obs_goal[1])
            else:
                intrinsic_reward = K.zeros_like(reward)
                for i_object in range(model.n_objects):
                    intrinsic_reward += model.get_obj_reward(
                        obs_goal[1][:, :, i_object],
                        next_obs_goal[1][:, :, i_object])

            if model.masked_with_r:
                episode_reward += intrinsic_reward * K.abs(reward) + reward
            else:
                episode_reward += intrinsic_reward + reward

        for i_agent in range(2):
            state = {
                'observation': state_all['observation'][i_agent],
                'achieved_goal': state_all['achieved_goal'],
                'desired_goal': state_all['desired_goal']
            }
            next_state = {
                'observation': next_state_all['observation'][i_agent],
                'achieved_goal': next_state_all['achieved_goal'],
                'desired_goal': next_state_all['desired_goal']
            }

            trajectories[i_agent].append(
                (state.copy(), action_to_env, reward, next_state.copy(), done))

        # Move to the next state
        state_all = next_state_all

        # Record frames
        if render:
            frames.append(env[i_env].render(mode='rgb_array')[0])

    obs, ags, goals, acts = [], [], [], []

    for trajectory in trajectories:
        obs.append([])
        ags.append([])
        goals.append([])
        acts.append([])
        for i_step in range(env[0]._max_episode_steps):
            obs[-1].append(trajectory[i_step][0]['observation'])
            ags[-1].append(trajectory[i_step][0]['achieved_goal'])
            goals[-1].append(trajectory[i_step][0]['desired_goal'])
            if (i_step < env[0]._max_episode_steps - 1):
                acts[-1].append(trajectory[i_step][1])

    trajectories = {
        'o': np.concatenate(obs, axis=1)[np.newaxis, :],
        'ag': np.asarray(ags)[0:1, ],
        'g': np.asarray(goals)[0:1, ],
        'u': np.asarray(acts)[0:1, ],
    }

    return trajectories, episode_reward, info['is_success'], frames
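
A hedged sketch of calling this rollout and handing the result to experience storage. Only the signature and return values come from the function above; the environment list, per-agent noise, normalizer pair, and buffer (envs, noise, obs_normalizer, obj_normalizer, memory) are assumed names.

# Hypothetical caller; envs is a list of goal-based gym environments indexed by
# i_env, and noise holds one exploration process per agent.
trajectories, episode_reward, success, frames = rollout(
    envs, model, noise, i_env=0,
    normalizer=(obs_normalizer, obj_normalizer),
    render=False, agent_id=0, ai_object=False, rob_policy=[0., 0.])
memory.store_episode(trajectories)   # assumed HER-style replay buffer API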
Example #5
def rollout(
    env,
    model,
    noise,
    config,
    normalizer=None,
    render=False,
):
    # Vectorized variant: env exposes num_envs, and every quantity below
    # carries a leading batch dimension of env.num_envs.

    trajectories = []
    for i_agent in range(2):
        trajectories.append([])

    # monitoring variables
    episode_reward = np.zeros(env.num_envs)
    frames = []

    state_all = env.reset()
    state_all = back_to_dict(state_all, config)

    for i_step in range(config['episode_length']):

        model.to_cpu()

        obs = [
            K.tensor(obs, dtype=K.float32) for obs in state_all['observation']
        ]
        goal = K.tensor(state_all['desired_goal'], dtype=K.float32)

        # Observation normalization
        obs_goal = []
        obs_goal.append(K.cat([obs[0], goal], dim=-1))
        if normalizer[0] is not None:
            obs_goal[0] = normalizer[0].preprocess_with_update(obs_goal[0])

        if model.n_objects <= 1:
            obs_goal.append(K.cat([obs[1], goal], dim=-1))
            if normalizer[1] is not None:
                obs_goal[1] = normalizer[1].preprocess(obs_goal[1])
        else:
            obs_goal.append(get_obj_obs(obs[1], goal, model.n_objects))
            if normalizer[1] is not None:
                for i_object in range(model.n_objects):
                    obs_goal[1][:, :, i_object] = normalizer[1].preprocess(
                        obs_goal[1][:, :, i_object])

        action = model.select_action(obs_goal[0], noise).cpu().numpy()

        action_to_env = np.zeros((len(action), len(env.action_space.sample())))
        action_to_env[:, 0:action.shape[1]] = action

        next_state_all, reward, done, info = env.step(action_to_env)
        next_state_all = back_to_dict(next_state_all, config)
        reward = K.tensor(reward, dtype=dtype).view(-1, 1)

        next_obs = [
            K.tensor(next_obs, dtype=K.float32)
            for next_obs in next_state_all['observation']
        ]

        # Observation normalization
        next_obs_goal = []
        next_obs_goal.append(K.cat([next_obs[0], goal], dim=-1))
        if normalizer[0] is not None:
            next_obs_goal[0] = normalizer[0].preprocess(next_obs_goal[0])

        if model.n_objects <= 1:
            next_obs_goal.append(K.cat([next_obs[1], goal], dim=-1))
            if normalizer[1] is not None:
                next_obs_goal[1] = normalizer[1].preprocess(next_obs_goal[1])
        else:
            next_obs_goal.append(
                get_obj_obs(next_obs[1], goal, model.n_objects))
            if normalizer[1] is not None:
                for i_object in range(model.n_objects):
                    next_obs_goal[1][:, :, i_object] = normalizer[1].preprocess(
                        next_obs_goal[1][:, :, i_object])

        # for monitoring
        if model.object_Qfunc is None:
            episode_reward += reward.squeeze(1).cpu().numpy()
        else:
            if model.n_objects <= 1:
                r_intr = model.get_obj_reward(obs_goal[1], next_obs_goal[1])
            else:
                r_intr = K.zeros_like(reward)
                for i_object in range(model.n_objects):
                    r_intr += model.get_obj_reward(
                        obs_goal[1][:, :, i_object],
                        next_obs_goal[1][:, :, i_object])

            episode_reward += (r_intr + reward).squeeze(1).cpu().numpy()

        for i_agent in range(2):
            state = {
                'observation': state_all['observation'][i_agent],
                'achieved_goal': state_all['achieved_goal'],
                'desired_goal': state_all['desired_goal']
            }
            next_state = {
                'observation': next_state_all['observation'][i_agent],
                'achieved_goal': next_state_all['achieved_goal'],
                'desired_goal': next_state_all['desired_goal']
            }

            trajectories[i_agent].append(
                (state.copy(), action_to_env, reward, next_state.copy(), done))

        # Move to the next state
        state_all = next_state_all

        # Record frames
        if render:
            frames.append(env.render(mode='rgb_array')[0])

    obs, ags, goals, acts = [], [], [], []

    for trajectory in trajectories:
        obs.append([])
        ags.append([])
        goals.append([])
        acts.append([])
        for i_step in range(config['episode_length']):
            obs[-1].append(trajectory[i_step][0]['observation'])
            ags[-1].append(trajectory[i_step][0]['achieved_goal'])
            goals[-1].append(trajectory[i_step][0]['desired_goal'])
            if (i_step < config['episode_length'] - 1):
                acts[-1].append(trajectory[i_step][1])

    trajectories = {
        'o': np.concatenate(obs, axis=-1).swapaxes(0, 1),
        'ag': np.asarray(ags)[0, ].swapaxes(0, 1),
        'g': np.asarray(goals)[0, ].swapaxes(0, 1),
        'u': np.asarray(acts)[0, ].swapaxes(0, 1),
    }

    info = np.asarray([i_info['is_success'] for i_info in info])

    return trajectories, episode_reward, info, frames
Example #6
    def update_parameters(self, batch, normalizer=None):
        # Variant with auxiliary critics (self.n_aux_critics) and a list-valued
        # self.n_objects; an extra robot-centric observation (s3, s3_) feeds the
        # second object Q-function below.

        observation_space = self.observation_space - K.tensor(
            batch['g'], dtype=self.dtype, device=self.device).shape[1]
        action_space = self.action_space[0].shape[0]

        V = K.zeros((len(batch['o']), 1), dtype=self.dtype, device=self.device)

        s1 = K.cat([
            K.tensor(batch['o'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                   dim=-1)

        if self.n_objects[0] <= 1:
            s2 = K.cat([
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2 = get_obj_obs(
                K.tensor(batch['o'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects[0])

        a1 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, 0:action_space]
        a2 = K.tensor(batch['u'], dtype=self.dtype,
                      device=self.device)[:, action_space:]

        s1_ = K.cat([
            K.tensor(batch['o_2'], dtype=self.dtype,
                     device=self.device)[:, 0:observation_space],
            K.tensor(batch['g'], dtype=self.dtype, device=self.device)
        ],
                    dim=-1)

        if self.n_objects[0] <= 1:
            s2_ = K.cat([
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device)
            ], dim=-1)
        else:
            s2_ = get_obj_obs(
                K.tensor(batch['o_2'], dtype=self.dtype,
                         device=self.device)[:, observation_space:],
                K.tensor(batch['g'], dtype=self.dtype, device=self.device),
                n_object=self.n_objects[0])

        if normalizer[0] is not None:
            s1 = normalizer[0].preprocess(s1)
            s1_ = normalizer[0].preprocess(s1_)

        if normalizer[1] is not None:
            if self.n_objects[0] <= 1:
                s2 = normalizer[1].preprocess(s2)
                s2_ = normalizer[1].preprocess(s2_)
            else:
                for i_object in range(self.n_objects[0]):
                    s2[:, :, i_object] = normalizer[1].preprocess(
                        s2[:, :, i_object])
                    s2_[:, :, i_object] = normalizer[1].preprocess(
                        s2_[:, :, i_object])

        s3 = get_obj_obs(K.tensor(batch['o'],
                                  dtype=self.dtype,
                                  device=self.device)[:, 0:observation_space],
                         K.tensor(batch['g'],
                                  dtype=self.dtype,
                                  device=self.device),
                         n_object=self.n_objects[0])
        s3 = s3[:, :, 0:self.n_objects[1]]
        s3 = get_rob_obs(s3, self.n_objects[1])

        s3_ = get_obj_obs(K.tensor(batch['o_2'],
                                   dtype=self.dtype,
                                   device=self.device)[:, 0:observation_space],
                          K.tensor(batch['g'],
                                   dtype=self.dtype,
                                   device=self.device),
                          n_object=self.n_objects[0])
        s3_ = s3_[:, :, 0:self.n_objects[1]]
        s3_ = get_rob_obs(s3_, self.n_objects[1])

        if normalizer[2] is not None:
            s3 = normalizer[2].preprocess(s3)
            s3_ = normalizer[2].preprocess(s3_)

        s, s_, a = (s1, s1_, a1) if self.agent_id == 0 else (s2, s2_, a2)
        a_ = self.actors_target[0](s_)

        r_all = []
        r = K.tensor(batch['r'], dtype=self.dtype,
                     device=self.device).unsqueeze(1)
        r_all.append(r)
        if self.object_Qfunc is not None:
            if len(self.object_Qfunc) > 1:
                # estimated actions
                #r_all.append(self.get_obj_reward(s3, s3_, index=1))
                # actual actions
                r_all.append(self.get_obj_reward(s3, s3_, index=1, action=a1))
            for i_object in range(self.n_objects[2], self.n_objects[0]):
                r_all.append(
                    self.get_obj_reward(s2[:, :, i_object],
                                        s2_[:, :, i_object],
                                        index=0))

        # first critic for main rewards
        Q = self.critics[0](s, a)
        V = self.critics_target[0](s_, a_).detach()

        target_Q = (V * self.gamma) + r_all[0]
        target_Q = target_Q.clamp(self.clip_Q_neg, 0.)

        loss_critic = self.loss_func(Q, target_Q)

        self.critics_optim[0].zero_grad()
        loss_critic.backward()
        self.critics_optim[0].step()

        #r_sum = K.zeros_like(r_all[0])
        #for i_object in range(self.n_objects[2], self.n_objects[1]):
        #    r_sum += r_all[i_object+2]
        #r_mask = r_sum < -0.0001
        #r_all[1] *= K.tensor(r_mask, dtype=r_all[1].dtype, device=r_all[1].device)

        # other critics for intrinsic
        for i_critic in range(1, self.n_aux_critics + 1):
            Q = self.critics[i_critic](s, a)
            V = self.critics_target[i_critic](s_, a_).detach()

            target_Q = (V * self.gamma) + r_all[i_critic]
            target_Q = target_Q.clamp(self.clip_Q_neg, 0.)

            loss_critic = self.loss_func(Q, target_Q)

            self.critics_optim[i_critic].zero_grad()
            loss_critic.backward()
            self.critics_optim[i_critic].step()

        # actor update
        a = self.actors[0](s)

        loss_actor = -self.critics[0](s, a).mean()
        for i_critic in range(1, self.n_aux_critics + 1):
            loss_actor += -self.critics[i_critic](s, a).mean()

        if self.regularization:
            loss_actor += (self.actors[0](s)**2).mean() * 1

        self.actors_optim[0].zero_grad()
        loss_actor.backward()
        self.actors_optim[0].step()

        return loss_critic.item(), loss_actor.item()
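
The updates in these examples read from actors_target and critics_target but never refresh them; in DDPG-style training that refresh is normally a Polyak (soft) update performed elsewhere. A minimal sketch of such an update, assuming the project follows that convention (the function name and tau value are mine):

def soft_update(target_net, source_net, tau=0.005):
    # Polyak averaging: target <- (1 - tau) * target + tau * source.
    for target_param, param in zip(target_net.parameters(),
                                   source_net.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) +
                                param.data * tau)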