Example 1: DDPG update with a shared actor-critic network and a single optimizer
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * self.args['episode_len']):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done),
                          volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        self.optim.zero_grad()

        # train p network
        q_val = self.net(obs_n, action=None, output_critic=True)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.net.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)
        p_loss.backward()
        self.net.clear_critic_specific_grad()  # no need to compute critic-specific grads for the actor update
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        # train q network
        self.optim.zero_grad()
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_q_next = self.target_net(obs_next_n, output_critic=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.net(obs_n, action=full_act_n, output_critic=True)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber
        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)
        #q_loss = q_loss * 50
        q_loss.backward()

        # total_loss = q_loss + p_loss
        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net,
                        self.target_net,
                        rate=self.target_update_rate)

        common.debugger.print('Stats of Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
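
Every example in this listing finishes with make_update_exp(...) to make the target network softly track the online network, but its definition is not shown here. The following is a minimal sketch of the usual Polyak-averaging implementation, assuming rate is the weight given to the online parameters; the exact convention in this codebase may differ, so the function is named as a stand-in.

import torch

def make_update_exp_sketch(src, target, rate=1e-3):
    # Soft target update: target <- (1 - rate) * target + rate * src.
    # Hypothetical stand-in for the make_update_exp() used above.
    with torch.no_grad():
        for p_src, p_tgt in zip(src.parameters(), target.parameters()):
            p_tgt.mul_(1.0 - rate).add_(p_src, alpha=rate)
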
Example 2: DDPG update with separate policy (p) and critic (q) networks and optimizers
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * self.args['episode_len']):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        full_act_n = Variable(torch.from_numpy(full_act)).type(FloatTensor)
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done),
                          volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_act_next = self.target_p(obs_next_n)
        target_q_next = self.target_q(obs_next_n, target_act_next)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.q(obs_n, full_act_n)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)

        self.q_optim.zero_grad()
        q_loss.backward()

        common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.q)

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        self.q_optim.step()

        # train p network
        new_act_n = self.p(obs_n)  # NOTE: maybe use <gumbel_noise=None> ?
        q_val = self.q(obs_n, new_act_n)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.p.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration

        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)

        self.p_optim.zero_grad()
        self.q_optim.zero_grad()  # important!! clear the grad in Q
        p_loss.backward()

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        self.p_optim.step()

        common.debugger.print(
            'Stats of Q Network (in the phase of P-Update)....', False)
        utils.log_parameter_stats(common.debugger, self.q)
        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.p)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
        make_update_exp(self.q, self.target_q, rate=self.target_update_rate)

        common.debugger.print('Stats of Q Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_q)
        common.debugger.print('Stats of P Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_p)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
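
Example 2 relies on the legacy volatile=True / target_q.volatile = False idiom to keep the bootstrap target out of the autograd graph, and it calls q_optim.zero_grad() before stepping the actor because p_loss backpropagates through Q into the critic's parameters. For reference, here is a sketch of the same critic-target computation in current PyTorch, where torch.no_grad() replaces the volatile flag; the function and argument names are illustrative, not this codebase's API.

import torch
import torch.nn.functional as F

def critic_loss_sketch(q, target_p, target_q, obs, act, rew, obs_next, done,
                       gamma, critic_penalty):
    # Bootstrap target computed without building a graph, then treated as a constant.
    with torch.no_grad():
        act_next = target_p(obs_next)
        q_next = target_q(obs_next, act_next)
        target = rew + gamma * (1.0 - done) * q_next
    current_q = q(obs, act)
    q_norm = (current_q * current_q).mean()          # L2 penalty on Q magnitudes
    return F.smooth_l1_loss(current_q, target) + critic_penalty * q_norm
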
Example 3: sequence-based DDPG update over padded episodes with a step mask
    def update(self):
        if (self.a is not None) or \
           not self.replay_buffer.can_sample(self.batch_size * 4):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, full_act, rew, msk, done, total_length = \
            self.replay_buffer.sample(self.batch_size, seq_len=self.batch_len)
        total_length = float(total_length)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        _full_obs_n = self._process_frames(
            obs, merge_dim=False,
            return_variable=False)  # [batch, seq_len+1, ...]
        batch = _full_obs_n.size(0)
        seq_len = _full_obs_n.size(1) - 1
        full_obs_n = Variable(_full_obs_n, volatile=True)
        obs_n = Variable(
            _full_obs_n[:, :-1, ...]).contiguous()  # [batch, seq_len, ...]
        obs_next_n = Variable(_full_obs_n[:, 1:, ...],
                              volatile=True).contiguous()
        img_c, img_h, img_w = obs_n.size(-3), obs_n.size(-2), obs_n.size(-1)
        packed_obs_n = obs_n.view(-1, img_c, img_h, img_w)
        packed_obs_next_n = obs_next_n.view(-1, img_c, img_h, img_w)
        full_act_n = Variable(torch.from_numpy(full_act)).type(
            FloatTensor)  # [batch, seq_len, ...]
        act_padding = Variable(
            torch.zeros(self.batch_size, 1,
                        full_act_n.size(-1))).type(FloatTensor)
        pad_act_n = torch.cat([act_padding, full_act_n],
                              dim=1)  # [batch, seq_len+1, ...]
        rew_n = Variable(torch.from_numpy(rew),
                         volatile=True).type(FloatTensor)
        msk_n = Variable(torch.from_numpy(msk)).type(
            FloatTensor)  # [batch, seq_len]
        done_n = Variable(torch.from_numpy(done)).type(
            FloatTensor)  # [batch, seq_len]

        time_counter[0] += time.time() - tt
        tt = time.time()

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)

        full_target_act, _ = self.target_p(
            full_obs_n, act=pad_act_n)  # list([batch, seq_len+1, act_dim])
        target_act_next = torch.cat(full_target_act, dim=-1)[:, 1:, :]
        act_dim = target_act_next.size(-1)
        target_act_next = target_act_next.resize(batch * seq_len, act_dim)

        target_q_next = self.target_q(packed_obs_next_n,
                                      act=target_act_next)  #[batch * seq_len]
        target_q_next = target_q_next.view(batch, seq_len)  # .view() is not in-place; keep the reshaped result
        target_q = (rew_n + self.gamma * done_n * target_q_next) * msk_n
        target_q = target_q.view(-1)
        target_q.volatile = False

        current_q = self.q(packed_obs_n, act=full_act_n.view(
            -1, act_dim)) * msk_n.view(-1)
        q_norm = (current_q * current_q).sum() / total_length  # l2 norm
        q_loss = F.smooth_l1_loss(current_q, target_q, size_average=False) / total_length \
                 + self.args['critic_penalty']*q_norm  # huber

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)

        self.q_optim.zero_grad()
        q_loss.backward()

        common.debugger.print('Stats of Q Network (*before* clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.q)

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
        self.q_optim.step()

        # train p network
        new_act_n, _ = self.p(
            obs_n, act=pad_act_n[:, :-1, :])  # [batch, seq_len, act_dim]
        new_act_n = torch.cat(new_act_n, dim=-1)
        new_act_n = new_act_n.view(-1, act_dim)
        q_val = self.q(packed_obs_n, new_act_n) * msk_n.view(-1)
        p_loss = -q_val.sum() / total_length
        p_ent = self.p.entropy(weight=msk_n).sum() / total_length
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration

        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)

        self.p_optim.zero_grad()
        self.q_optim.zero_grad()  # important!! clear the grad in Q
        p_loss.backward()

        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.p.parameters(), self.grad_norm_clip)
        self.p_optim.step()

        common.debugger.print(
            'Stats of Q Network (in the phase of P-Update)....', False)
        utils.log_parameter_stats(common.debugger, self.q)
        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.p)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.p, self.target_p, rate=self.target_update_rate)
        make_update_exp(self.q, self.target_q, rate=self.target_update_rate)

        common.debugger.print('Stats of Q Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_q)
        common.debugger.print('Stats of P Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_p)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
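
Example 3 pads variable-length episodes to a fixed seq_len and uses msk_n to zero out the padded steps; both the Huber loss and the policy/entropy terms are summed and divided by total_length (the number of valid steps) rather than averaged over every padded position. A minimal sketch of that masked-loss pattern, with hypothetical names:

import torch
import torch.nn.functional as F

def masked_huber_sketch(current_q, target_q, msk, total_length):
    # current_q, target_q, msk are flat [batch * seq_len] tensors; msk is 0 on padding.
    # Summing and dividing by the count of valid steps keeps padded positions from
    # diluting the gradient, matching the size_average=False / total_length idiom above.
    loss = F.smooth_l1_loss(current_q * msk, target_q * msk, reduction='sum')
    return loss / total_length
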
Example 4: DQN update with double Q-learning and optional multi-target conditioning
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * min(self.args['update_freq'], 20)):
            return None
        self._update_counter += 1
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        if self.multi_target:
            target_idx = self.target_buffer[self.replay_buffer._idxes]
            targets = np.zeros((self.batch_size, common.n_target_instructions), dtype=np.uint8)
            targets[list(range(self.batch_size)), target_idx] = 1
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        act_n = Variable(torch.from_numpy(act)).type(LongTensor)
        rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)
        if self.multi_target:
            target_n = Variable(torch.from_numpy(targets).type(FloatTensor))
        else:
            target_n = None

        time_counter[0] += time.time() - tt
        tt = time.time()

        # compute critic loss
        target_q_val_next = self.target_net(obs_next_n, only_q_value=True, target=target_n)
        # double Q learning
        target_act_next = torch.max(self.net(obs_next_n, only_q_value=True, target=target_n), dim=1, keepdim=True)[1]
        target_q_next = torch.gather(target_q_val_next, 1, target_act_next).squeeze()
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q_val = self.net(obs_n, only_q_value=True, target=target_n)
        current_q = torch.gather(current_q_val, 1, act_n.view(-1, 1)).squeeze()
        q_norm = (current_q * current_q).mean().squeeze()
        q_loss = F.smooth_l1_loss(current_q, target_q)

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
        common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)

        total_loss = q_loss.mean()
        if self.args['critic_penalty'] > 1e-10:
            total_loss += self.args['critic_penalty']*q_norm

        # compute gradient
        self.optim.zero_grad()
        #autograd.backward([total_loss, current_act], [torch.ones(1), None])
        total_loss.backward()
        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()
        common.debugger.print('Stats of Model (*after* clip and opt)....', False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        if self.target_net_update_freq is not None:
            if self._update_counter == self.target_net_update_freq:
                self._update_counter = 0
                self.target_net.load_state_dict(self.net.state_dict())
        else:
            make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
        common.debugger.print('Stats of Target Network (After Update)....', False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
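
Example 4 is a DQN-style update with double Q-learning: the online network selects the greedy next action and the target network evaluates it, which reduces the overestimation bias of a plain max-backup. A simplified sketch of that target computation (the real networks above also take a target=... instruction embedding, which is omitted here):

import torch

def double_q_target_sketch(net, target_net, obs_next, rew, done, gamma):
    # net / target_net are assumed to map observations to [batch, n_actions] Q-values.
    with torch.no_grad():
        next_act = net(obs_next).argmax(dim=1, keepdim=True)          # online net picks the action
        next_q = target_net(obs_next).gather(1, next_act).squeeze(1)  # target net evaluates it
        return rew + gamma * (1.0 - done) * next_q
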
Example 5: DDPG update on ELF-style GPU batches with a shared network and a single optimizer
    def update(self, cpu_batch, gpu_batch):

        #print('[elf_ddpg] update!!!!')
        self.update_counter += 1
        self.train()
        tt = time.time()

        obs_n, obs_next_n, full_act_n, rew_n, done_n = self._process_elf_frames(
            gpu_batch, keep_time=False)  # collapse all the samples
        obs_n = (obs_n.type(FloatTensor) - 128.0) / 256.0
        obs_n = Variable(obs_n)
        obs_next_n = (obs_next_n.type(FloatTensor) - 128.0) / 256.0
        obs_next_n = Variable(obs_next_n, volatile=True)
        full_act_n = Variable(full_act_n)
        rew_n = Variable(rew_n, volatile=True)
        done_n = Variable(done_n, volatile=True)

        self.sample_counter += obs_n.size(0)

        time_counter[0] += time.time() - tt

        #print('[elf_ddpg] data loaded!!!!!')

        tt = time.time()

        self.optim.zero_grad()

        # train p network
        q_val = self.net(obs_n, action=None, output_critic=True)
        p_loss = -q_val.mean().squeeze()
        p_ent = self.net.entropy().mean().squeeze()
        if self.args['ent_penalty'] is not None:
            p_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()),
                              False)
        p_loss.backward()
        self.net.clear_critic_specific_grad()  # no need to compute critic-specific grads for the actor update

        # train q network
        common.debugger.print('Grad Stats of Q Update ...', False)
        target_q_next = self.target_net(obs_next_n, output_critic=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_q = self.net(obs_n, action=full_act_n, output_critic=True)
        q_norm = (current_q * current_q).mean().squeeze()  # l2 norm
        q_loss = F.smooth_l1_loss(
            current_q,
            target_q) + self.args['critic_penalty'] * q_norm  # huber
        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()),
                              False)
        q_loss = q_loss * self.q_loss_coef
        q_loss.backward()

        # total_loss = q_loss + p_loss
        # grad clip
        if self.grad_norm_clip is not None:
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()

        common.debugger.print('Stats of P Network (after clip and opt)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net,
                        self.target_net,
                        rate=self.target_update_rate)

        common.debugger.print('Stats of Target Network (After Update)....',
                              False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        stats = dict(policy_loss=p_loss.data.cpu().numpy()[0],
                     policy_entropy=p_ent.data.cpu().numpy()[0],
                     critic_norm=q_norm.data.cpu().numpy()[0],
                     critic_loss=q_loss.data.cpu().numpy()[0] /
                     self.q_loss_coef,
                     eplen=cpu_batch[-1]['stats_eplen'].mean(),
                     avg_rew=cpu_batch[-1]['stats_rew'].mean())
        self.print_log(stats)
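
Example 5 consumes GPU batches from the ELF pipeline, so the raw byte frames are normalized inline with (x - 128) / 256, mapping pixel values into roughly [-0.5, 0.5). Note also that the optimized critic loss is scaled by q_loss_coef while the reported critic_loss divides the scale back out, keeping the logged value comparable to the other examples. A one-line sketch of the normalization, assuming uint8 frames:

import torch

def normalize_frames_sketch(frames_uint8):
    # Maps byte pixels [0, 255] into roughly [-0.5, 0.5), as done inline above.
    return (frames_uint8.float() - 128.0) / 256.0
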
Example 6: actor-critic update with an advantage-weighted log-probability policy loss
    def update(self):
        if (self.sample_counter < self.args['update_freq']) or \
           not self.replay_buffer.can_sample(self.batch_size * min(self.args['update_freq'], 10)):
            return None
        self.sample_counter = 0
        self.train()
        tt = time.time()

        obs, act, rew, obs_next, done = \
            self.replay_buffer.sample(self.batch_size)
        #act = split_batched_array(full_act, self.act_shape)
        time_counter[-1] += time.time() - tt
        tt = time.time()

        # convert to variables
        obs_n = self._process_frames(obs)
        obs_next_n = self._process_frames(obs_next, volatile=True)
        act_n = torch.from_numpy(act).type(LongTensor)
        rew_n = Variable(torch.from_numpy(rew), volatile=True).type(FloatTensor)
        done_n = Variable(torch.from_numpy(done), volatile=True).type(FloatTensor)

        time_counter[0] += time.time() - tt
        tt = time.time()

        # compute critic loss
        target_q_next = self.target_net(obs_next_n, only_value=True)
        target_q = rew_n + self.gamma * (1.0 - done_n) * target_q_next
        target_q.volatile = False
        current_act, current_q = self.net(obs_n, return_value=True)
        q_norm = (current_q * current_q).mean().squeeze()
        q_loss = F.smooth_l1_loss(current_q, target_q)

        common.debugger.print('>> Q_Loss = {}'.format(q_loss.data.mean()), False)
        common.debugger.print('>> Q_Norm = {}'.format(q_norm.data.mean()), False)

        total_loss = q_loss.mean()
        if self.args['critic_penalty'] > 1e-10:
            total_loss += self.args['critic_penalty']*q_norm

        # compute policy loss
        # NOTE: currently 1-step lookahead!!! TODO: multiple-step lookahead
        raw_adv_ts = (rew_n - current_q).data
        #raw_adv_ts = (target_q - current_q).data   # use estimated advantage??
        adv_ts = (raw_adv_ts - raw_adv_ts.mean()) / (raw_adv_ts.std() + 1e-15)
        #current_act.reinforce(adv_ts)
        p_ent = self.net.entropy().mean()
        p_loss = self.net.logprob(act_n)
        p_loss = p_loss * Variable(adv_ts)
        p_loss = p_loss.mean()
        total_loss -= p_loss
        if self.args['ent_penalty'] is not None:
            total_loss -= self.args['ent_penalty'] * p_ent  # encourage exploration
        common.debugger.print('>> P_Loss = {}'.format(p_loss.data.mean()), False)
        common.debugger.print('>> P_Entropy = {}'.format(p_ent.data.mean()), False)

        # compute gradient
        self.optim.zero_grad()
        #autograd.backward([total_loss, current_act], [torch.ones(1), None])
        total_loss.backward()
        if self.grad_norm_clip is not None:
            #nn.utils.clip_grad_norm(self.q.parameters(), self.grad_norm_clip)
            utils.clip_grad_norm(self.net.parameters(), self.grad_norm_clip)
        self.optim.step()
        common.debugger.print('Stats of Model (*after* clip and opt)....', False)
        utils.log_parameter_stats(common.debugger, self.net)

        time_counter[1] += time.time() - tt
        tt = time.time()

        # update target networks
        make_update_exp(self.net, self.target_net, rate=self.target_update_rate)
        common.debugger.print('Stats of Target Network (After Update)....', False)
        utils.log_parameter_stats(common.debugger, self.target_net)

        time_counter[2] += time.time() - tt

        return dict(policy_loss=p_loss.data.cpu().numpy()[0],
                    policy_entropy=p_ent.data.cpu().numpy()[0],
                    critic_norm=q_norm.data.cpu().numpy()[0],
                    critic_loss=q_loss.data.cpu().numpy()[0])
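
Example 6 trains the policy with an advantage-weighted log-probability loss: the one-step advantage rew_n - current_q is detached from the graph (via .data), standardized to zero mean and unit variance for variance reduction, and used to weight logprob(act_n); the entropy bonus is folded into the same total loss as in the other examples. A minimal sketch of that policy term with illustrative names:

import torch

def policy_loss_sketch(logprob, rew, value, eps=1e-15):
    # One-step advantage A = r - V(s), treated as a constant (detach ~ legacy .data),
    # then standardized before weighting the log-probabilities.
    adv = (rew - value).detach()
    adv = (adv - adv.mean()) / (adv.std() + eps)
    return -(logprob * adv).mean()
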