Example #1
    def Imitation_Learning(self, step_time, data=None, policy=None, learning_start=1000,
                           buffer_size=5000, value_training_round=10, value_training_fre=2500,
                           verbose=2, render=False):
        '''
        Imitation learning guided by either demonstration data or an expert policy (exactly one of the two).
        :param step_time: number of training steps to run
        :param data: a list of transitions, each a dict with the 5 keys s, a, r, s_, tr:
                     sample = {"s": s, "a": a, "s_": s_, "r": r, "tr": done}
        :param policy: a callable policy(env) that returns the expert action for the current state
        :return: None
        '''
        if data is not None and policy is not None:
            raise Exception("Imitation learning needs exactly one guide: pass either data or policy, not both.")

        if data is not None:
            for time in range(step_time):
                self.step += 1
                loss = self.backward(data[time])
                if verbose == 1:
                    logger.record_tabular("steps", self.step)
                    logger.record_tabular("loss", loss)
                    logger.dumpkvs()

        if policy is not None:
            buffer = ReplayMemory(buffer_size)
            s = self.env.reset()
            loss_BC = 0
            ep_step, ep_reward = 0, 0
            for _ in range(step_time):
                self.step += 1
                ep_step += 1
                a = policy(self.env)
                s_, r, done, info = self.env.step(a)
                #print(r,info)
                ep_reward += r
                if render:
                    self.env.render()
                sample = {"s": s, "a": a, "s_": s_, "r": r, "tr": done}
                buffer.push(sample)
                s = s_[:]
                if self.step > learning_start:
                    sample_ = buffer.sample(self.batch_size)
                    loss = self.policy_behavior_clone(sample_)
                    if self.step % value_training_fre == 0:
                        record_sample = {}
                        for key in buffer.memory.keys():
                            record_sample[key] = np.array(buffer.memory[key]).astype(np.float32)[-value_training_fre:]
                        record_sample["value"] = self.value.forward(torch.from_numpy(record_sample["s"]))
                        returns, advants = get_gae(record_sample["r"], record_sample["tr"], record_sample["value"],
                                                   self.gamma, self.lam)
                        record_sample["advs"] = advants
                        record_sample["return"] = returns
                        for round_ in range(value_training_round):
                            loss_value = self.value_pretrain(record_sample, value_training_fre)
                            print(round_, loss_value)

                    if verbose == 1:
                        logger.record_tabular("learning_steps", self.step)
                        logger.record_tabular("loss", loss)
                        logger.record_tabular("rewrad",r)
                        logger.dumpkvs()
                    loss_BC += loss
                if done:
                    if verbose == 2:
                        logger.record_tabular("learning_steps", self.step)
                        logger.record_tabular("step_used", ep_step)
                        logger.record_tabular("loss", loss_BC/ep_step)
                        logger.record_tabular("ep_reward",ep_reward )
                        logger.dumpkvs()

                    s = self.env.reset()
                    loss_BC = 0
                    ep_step, ep_reward = 0, 0
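
The docstring above specifies that each demonstration is a dict with the five keys s, a, r, s_, tr. A minimal sketch of how such a list could be collected, assuming the classic 4-tuple Gym step API used throughout these examples and a purely hypothetical expert_policy callable (neither is part of the repository):

def collect_demonstrations(env, expert_policy, n_steps=1000):
    # Roll out an expert and pack transitions into the dict format expected by Imitation_Learning.
    # Illustrative helper only.
    data = []
    s = env.reset()
    for _ in range(n_steps):
        a = expert_policy(s)
        s_, r, done, info = env.step(a)
        data.append({"s": s, "a": a, "s_": s_, "r": r, "tr": done})
        s = env.reset() if done else s_
    return data

# usage (illustrative):
# data = collect_demonstrations(env, expert_policy=lambda s: env.action_space.sample())
# agent.Imitation_Learning(step_time=len(data), data=data)
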
Example #2
class PPO_Agent(Agent_policy_based):
    def __init__(
            self,
            env,
            policy_model,
            value_model,
            lr=5e-4,
            ent_coef=0.01,
            vf_coef=0.5,
            ## hyper-parameter
            gamma=0.99,
            lam=0.95,
            cliprange=0.2,
            batch_size=64,
            value_train_round=200,
            running_step=2048,
            running_ep=20,
            value_regular=0.01,
            buffer_size=50000,
            ## decay
            decay=False,
            decay_rate=0.9,
            lstm_enable=False,
            ##
            path=None):
        self.gpu = False
        self.env = env
        self.gamma = gamma
        self.lam = lam
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.cliprange = cliprange

        self.value_train_step = value_train_round

        self.sample_rollout = running_step
        self.sample_ep = running_ep
        self.batch_size = batch_size
        self.lstm_enable = lstm_enable
        self.replay_buffer = ReplayMemory(buffer_size,
                                          other_record=["value", "return"])

        self.loss_cal = torch.nn.SmoothL1Loss()

        self.policy = policy_model
        if value_model == "shared":
            self.value = policy_model
        elif value_model == "copy":
            self.value = deepcopy(policy_model)
        else:
            self.value = value_model

        self.dist = make_pdtype(env.action_space, policy_model)

        self.policy_model_optim = Adam(self.policy.parameters(), lr=lr)
        self.value_model_optim = Adam(self.value.parameters(),
                                      lr=lr,
                                      weight_decay=value_regular)
        if decay:
            self.policy_model_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                self.policy_model_optim, decay_rate, last_epoch=-1)
            self.value_model_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                self.value_model_optim, decay_rate, last_epoch=-1)

        #torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 1, norm_type=2)
        #torch.nn.utils.clip_grad_norm_(self.value.parameters(), 1, norm_type=2)

        super(PPO_Agent, self).__init__(path)
        #example_input = Variable(torch.rand((100,)+self.env.observation_space.shape))
        #self.writer.add_graph(self.policy, input_to_model=example_input)

        self.backward_step_show_list = ["pg_loss", "entropy", "vf_loss"]
        self.backward_ep_show_list = ["pg_loss", "entropy", "vf_loss"]

        self.training_round = 0
        self.running_step = 0
        self.record_sample = None
        self.training_step = 0

    def update(self, sample):
        step_len = len(sample["s"])
        for ki in range(step_len):
            sample_ = {
                "s": sample["s"][ki].cpu().numpy(),
                "a": sample["a"][ki].cpu().numpy(),
                "r": sample["r"][ki].cpu().numpy(),
                "tr": sample["tr"][ki].cpu().numpy(),
                "s_": sample["s_"][ki].cpu().numpy(),
                "value": sample["value"][ki].cpu().numpy(),
                "return": sample["return"][ki].cpu().numpy()
            }
            self.replay_buffer.push(sample_)
        '''
        train the value part
        '''
        vfloss_re = []
        for _ in range(self.value_train_step):
            train_value_sample = self.replay_buffer.sample(self.batch_size)
            if self.gpu:
                for key in train_value_sample.keys():
                    train_value_sample[key] = train_value_sample[key].cuda()
            old_value = train_value_sample["value"]
            training_s = train_value_sample["s"]
            R = train_value_sample["return"]
            value_now = self.value.forward(training_s).squeeze()
            # value loss, clipped around the value recorded at rollout time
            value_clip = old_value + torch.clamp(
                value_now - old_value, min=-self.cliprange,
                max=self.cliprange)  # clipped value estimate
            vf_loss1 = self.loss_cal(value_now, R)  # unclipped loss
            vf_loss2 = self.loss_cal(value_clip, R)  # clipped loss
            vf_loss = .5 * torch.max(vf_loss1, vf_loss2)
            self.value_model_optim.zero_grad()
            vf_loss.backward()
            self.value_model_optim.step()
            vfloss_re.append(vf_loss.cpu().detach().numpy())
        '''
        train the policy part
        '''

        for key in sample.keys():
            temp = torch.stack(list(sample[key]), 0).squeeze()
            if self.gpu:
                sample[key] = temp.cuda()
            else:
                sample[key] = temp

        time_round = int(np.ceil(step_len / self.batch_size))
        time_left = time_round * self.batch_size - step_len
        # pad the index list so the last minibatch is full (wrap around to the start)
        array = list(range(step_len)) + list(range(int(time_left)))
        array_index = []
        for train_time in range(int(time_round)):
            array_index.append(
                array[train_time * self.batch_size:(train_time + 1) *
                      self.batch_size])

        loss_re, pgloss_re, enloss_re = [], [], []
        for train_time in range(int(time_round)):
            index = array_index[train_time]
            # for index in range(step_len):
            training_s = sample["s"][index].detach()
            training_a = sample["a"][index].detach()
            old_neglogp = sample["logp"][index].detach()
            advs = sample["advs"][index].detach()

            " CALCULATE THE LOSS"
            " Total loss = Policy gradient loss - entropy * entropy coefficient + Value coefficient * value loss"

            #generate Policy gradient loss
            outcome = self.policy.forward(training_s).squeeze()
            # new_neg_lop = torch.empty(size=(self.batch_size,))
            # for time in range(self.batch_size):
            #     new_policy = self.dist(outcome[time])
            #     new_neg_lop[time] = new_policy.log_prob(training_a[time])
            new_policy = self.dist(outcome)
            new_logp = new_policy.log_prob(training_a)
            ratio = torch.exp(new_logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
            pg_loss1 = -advs * ratio
            pg_loss2 = -advs * torch.clamp(ratio, 1.0 - self.cliprange,
                                           1.0 + self.cliprange)
            pg_loss = .5 * torch.max(pg_loss1, pg_loss2).mean()

            # entropy
            entropy = new_policy.entropy().mean()
            # loss = pg_loss - entropy * self.ent_coef + vf_loss * self.vf_coef
            loss = pg_loss - entropy * self.ent_coef
            self.policy_model_optim.zero_grad()
            loss.backward()
            self.policy_model_optim.step()
            # approxkl = self.loss_cal(neg_log_pac, self.record_sample["neglogp"])
            # self.cliprange = torch.gt(torch.abs(ratio - 1.0).mean(), self.cliprange)
            loss_re.append(loss.cpu().detach().numpy())
            pgloss_re.append(pg_loss.cpu().detach().numpy())
            enloss_re.append(entropy.cpu().detach().numpy())

        return np.sum(loss_re), {
            "pg_loss": np.sum(pgloss_re),
            "entropy": np.sum(enloss_re),
            "vf_loss": np.sum(vfloss_re)
        }

    def load_weights(self, filepath):
        model = torch.load(filepath + "/PPO.pkl")
        self.policy.load_state_dict(model["policy"].state_dict())
        self.value.load_state_dict(model["value"].state_dict())

    def save_weights(self, filepath, overwrite=False):
        torch.save({
            "policy": self.policy,
            "value": self.value
        }, filepath + "/PPO.pkl")

    def policy_behavior_clone(self, sample_):
        action_label = sample_["a"].squeeze()
        if self.gpu:
            action_predict = self.policy(sample_["s"].cuda())
            action_label = action_label.cuda()
        else:
            action_predict = self.policy(sample_["s"])
        loss = self.loss_cal(action_predict, action_label)
        self.policy_model_optim.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.policy.parameters(), 1, norm_type=2)
        self.policy_model_optim.step()
        return loss.cpu().detach().numpy()

    def value_pretrain(self, record_sample, new_sample_len):
        train_times = int(np.floor(new_sample_len / 128))
        round_loss = 0
        for io in range(train_times):
            index = list(range(128 * io, 128 * (io + 1)))
            states = torch.from_numpy(np.array(record_sample["s"])[index])
            label = torch.from_numpy(np.array(record_sample["return"]))[index]
            if self.gpu:
                states = states.cuda()
                label = label.cuda()
            value_now = self.value.forward(states).squeeze()
            # value loss (unclipped)
            vf_loss = self.loss_cal(value_now, label)
            self.value_model_optim.zero_grad()
            vf_loss.backward()
            self.value_model_optim.step()
            round_loss += vf_loss.cpu().detach().numpy()
        return round_loss

    def cuda(self, device=None):
        self.policy.to_gpu(device)
        self.value.to_gpu(device)
        self.loss_cal = self.loss_cal.cuda(device)
        self.gpu = True
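
For reference, the clipped surrogate objective computed inside update() can also be written as a small standalone function. This is a minimal sketch over plain tensors (per-sample log-probabilities and advantages), without the extra 0.5 scaling used above; it is illustrative and not part of the repository:

import torch

def ppo_clip_loss(new_logp, old_logp, advantages, cliprange=0.2):
    # Clipped PPO policy loss; all inputs are 1-D tensors of equal length.
    ratio = torch.exp(new_logp - old_logp)  # pi_new(a|s) / pi_old(a|s)
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    return torch.max(unclipped, clipped).mean()  # pessimistic: take the larger loss
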
Example #3
class DDPG_Agent(Agent_value_based):
    def __init__(
            self,
            env,
            actor_model,
            critic_model,
            actor_lr=1e-4,
            critic_lr=1e-3,
            actor_target_network_update_freq=1000,
            critic_target_network_update_freq=1000,
            actor_training_freq=1,
            critic_training_freq=1,
            separate_critic=False,
            ## hyper-parameter
            gamma=0.99,
            batch_size=32,
            buffer_size=50000,
            learning_starts=1000,
            ## lr_decay
            decay=False,
            decay_rate=0.9,
            critic_l2_reg=1e-2,
            clip_norm=None,
            ##
            path=None):

        self.gpu = False
        self.env = env
        self.separate_critic = separate_critic
        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_starts = learning_starts

        self.replay_buffer = ReplayMemory(buffer_size)

        self.actor_training_freq, self.critic_training_freq = actor_training_freq, critic_training_freq
        self.actor_target_network_update_freq = actor_target_network_update_freq
        self.critic_target_network_update_freq = critic_target_network_update_freq
        self.actor = actor_model
        self.critic = critic_model
        self.target_actor = deepcopy(actor_model)
        self.target_critic = deepcopy(critic_model)

        self.actor_critic = actor_critic(self.actor, self.critic)

        actor_optim = Adam(self.actor.parameters(), lr=actor_lr)
        critic_optim = Adam(self.critic.parameters(),
                            lr=critic_lr,
                            weight_decay=critic_l2_reg)
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        if decay:
            # keep the optimizers and attach the LR schedulers separately, as in
            # PPO_Agent; a scheduler cannot replace the optimizer it wraps
            self.actor_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                actor_optim, decay_rate, last_epoch=-1)
            self.critic_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                critic_optim, decay_rate, last_epoch=-1)

        super(DDPG_Agent, self).__init__(path)
        #example_input = Variable(torch.rand(100, self.env.observation_space.shape[0]))
        #self.writer.add_graph(self.actor_critic, input_to_model=example_input)
        self.forward_step_show_list = []
        self.backward_step_show_list = []
        self.forward_ep_show_list = []
        self.backward_ep_show_list = []

    def forward(self, observation):
        observation = observation[np.newaxis, :].astype(np.float32)
        observation = torch.from_numpy(observation)
        action = self.actor.forward(observation)
        action = torch.normal(action, torch.ones_like(action))
        if self.separate_critic:
            Q = self.critic.forward(observation,
                                    action).squeeze().detach().numpy()
        else:
            Q = self.critic(torch.cat((observation, action),
                                      dim=1)).squeeze().detach().numpy()
        return action.cpu().squeeze(0).detach().numpy(), Q, {}

    def backward(self, sample_):
        self.replay_buffer.push(sample_)
        if self.step > self.learning_starts and self.learning:
            sample = self.replay_buffer.sample(self.batch_size)
            if self.gpu:
                for key in sample.keys():
                    sample[key] = sample[key].cuda()
            assert len(sample["s"]) == self.batch_size
            "update the critic "
            if self.step % self.critic_training_freq == 0:
                if self.separate_critic:
                    Q = self.critic.forward(sample["s"], sample["a"])
                else:
                    critic_input = torch.cat((sample["s"], sample["a"]), -1)
                    Q = self.critic.forward(critic_input)
                target_a = self.target_actor(sample["s_"])
                if self.separate_critic:
                    targetQ = self.target_critic(sample["s_"], target_a)
                else:
                    target_input = torch.cat((sample["s_"], target_a), -1)
                    targetQ = self.target_critic(target_input)
                targetQ = targetQ.squeeze(1)
                Q = Q.squeeze(1)
                expected_q_values = sample["r"] + self.gamma * targetQ * (
                    1.0 - sample["tr"])
                loss = torch.mean(huber_loss(expected_q_values - Q))
                self.critic_optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                               1,
                                               norm_type=2)
                self.critic_optim.step()
            "training the actor"
            if self.step % self.actor_training_freq == 0:
                Q = self.actor_critic.forward(sample["s"])
                Q = -torch.mean(Q)
                self.actor_optim.zero_grad()
                Q.backward()
                torch.nn.utils.clip_grad_norm_(self.actor_critic.parameters(),
                                               1,
                                               norm_type=2)
                self.actor_optim.step()
            if self.step % self.actor_target_network_update_freq == 0:
                self.target_actor_net_update()
            if self.step % self.critic_target_network_update_freq == 0:
                self.target_critic_net_update()
            loss = loss.data.cpu().numpy()
            return loss, {}
        return 0, {}

    def target_actor_net_update(self):
        self.target_actor.load_state_dict(self.actor.state_dict())

    def target_critic_net_update(self):
        self.target_critic.load_state_dict(self.critic.state_dict())

    def load_weights(self, filepath):
        model = torch.load(filepath + "DDPG.pkl")
        self.actor.load_state_dict(model["actor"].state_dict())
        self.critic.load_state_dict(model["critic"].state_dict())
        self.target_actor.load_state_dict(model["target_actor"].state_dict())
        self.target_critic.load_state_dict(model["target_critic"].state_dict())
        self.actor_optim.load_state_dict(model["actor_optim"].state_dict())
        self.critic_optim.load_state_dict(model["critic_optim"].state_dict())

    def save_weights(self, filepath, overwrite=False):
        torch.save(
            {
                "actor": self.actor,
                "critic": self.critic,
                "target_actor": self.target_actor,
                "target_critic": self.target_critic,
                "actor_optim": self.actor_optim,
                "critic_optim": self.critic_optim
            }, filepath + "DDPG.pkl")

    def cuda(self):
        self.actor.to_gpu()
        self.critic.to_gpu()
        self.target_actor = deepcopy(self.actor)
        self.target_critic = deepcopy(self.critic)
        self.gpu = True
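
The critic update in backward() regresses Q(s, a) toward the bootstrapped target r + gamma * Q'(s', mu'(s')) * (1 - done). A minimal standalone sketch with plain tensors, assuming concatenated-input critics as in the non-separate branch above and using torch.nn.functional.smooth_l1_loss in place of the repository's huber_loss helper:

import torch
import torch.nn.functional as F

def ddpg_critic_loss(critic, target_critic, target_actor, batch, gamma=0.99):
    # batch is a dict of tensors with keys "s", "a", "r", "s_", "tr" (done flags as 0/1 floats).
    with torch.no_grad():
        target_a = target_actor(batch["s_"])
        target_q = target_critic(torch.cat((batch["s_"], target_a), dim=-1)).squeeze(1)
        target = batch["r"] + gamma * target_q * (1.0 - batch["tr"])
    q = critic(torch.cat((batch["s"], batch["a"]), dim=-1)).squeeze(1)
    return F.smooth_l1_loss(q, target)
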
Example #4
class DQN_Agent(Agent_value_based):
    def __init__(
            self,
            env,
            model,
            policy,
            ## hyper-parameter
            gamma=0.90,
            lr=1e-3,
            batch_size=32,
            buffer_size=50000,
            learning_starts=1000,
            target_network_update_freq=500,
            ## decay
            decay=False,
            decay_rate=0.9,
            ## Double DQN & Dueling DQN
            double_dqn=True,
            dueling_dqn=False,
            dueling_way="native",
            ## prioritized_replay
            prioritized_replay=False,
            prioritized_replay_alpha=0.6,
            prioritized_replay_beta0=0.4,
            prioritized_replay_beta_iters=None,
            prioritized_replay_eps=1e-6,
            param_noise=False,
            ##
            path=None):
        """

        :param env:      the GYM environment
        :param model:    the Torch NN model
        :param policy:   the policy when choosing action
        :param ep:       the MAX episode time
        :param step:     the MAx step time
         .........................hyper-parameter..................................
        :param gamma:
        :param lr:
        :param batchsize:
        :param buffer_size:
        :param target_network_update_freq:
        .........................further improve way..................................
        :param double_dqn:  whether enable DDQN
        :param dueling_dqn: whether dueling DDQN
        :param dueling_way: the Dueling DQN method
            it can choose the following three ways
            `avg`: Q(s,a;theta) = V(s;theta) + (A(s,a;theta)-Avg_a(A(s,a;theta)))
            `max`: Q(s,a;theta) = V(s;theta) + (A(s,a;theta)-max_a(A(s,a;theta)))
            `naive`: Q(s,a;theta) = V(s;theta) + A(s,a;theta)
        .........................prioritized-part..................................
        :param prioritized_replay: (bool) if True prioritized replay buffer will be used.
        :param prioritized_replay_alpha: (float)alpha parameter for prioritized replay buffer.
        It determines how much prioritization is used, with alpha=0 corresponding to the uniform case.
        :param prioritized_replay_beta0: (float) initial value of beta for prioritized replay buffer
        :param prioritized_replay_beta_iters: (int) number of iterations over which beta will be annealed from initial
            value to 1.0. If set to None equals to max_timesteps.
        :param prioritized_replay_eps: (float) epsilon to add to the TD errors when updating priorities.
        .........................imitation_learning_part..................................
        :param imitation_learning_policy:     To initial the network with the given policy
        which is supervised way to training the network
        :param IL_time:    supervised training times
        :param network_kwargs:
        """
        self.gpu = False
        self.env = env
        self.policy = policy

        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.target_network_update_freq = target_network_update_freq
        self.double_dqn = double_dqn

        if dueling_dqn:
            self.Q_net = Dueling_dqn(model, dueling_way)
        else:
            self.Q_net = model

        self.target_Q_net = deepcopy(self.Q_net)

        q_net_optim = Adam(self.Q_net.parameters(), lr=lr)
        self.optim = q_net_optim
        if decay:
            # keep the optimizer and attach the LR scheduler separately, as in
            # PPO_Agent; a scheduler cannot replace the optimizer it wraps
            self.decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                q_net_optim, decay_rate, last_epoch=-1)

        self.replay_buffer = ReplayMemory(buffer_size)
        self.learning = False
        super(DQN_Agent, self).__init__(path)
        example_input = Variable(
            torch.rand((100, ) + self.env.observation_space.shape))
        self.writer.add_graph(self.Q_net, input_to_model=example_input)
        self.forward_step_show_list = []
        self.backward_step_show_list = []
        self.forward_ep_show_list = []
        self.backward_ep_show_list = []

    def forward(self, observation):
        observation = observation[np.newaxis, :].astype(np.float32)
        observation = torch.from_numpy(observation)
        Q_value = self.Q_net.forward(observation)
        Q_value = Q_value.cpu().squeeze().detach().numpy()
        if self.policy is not None:
            action = self.policy.select_action(Q_value)
        else:
            action = np.argmax(Q_value)
        return action, np.max(Q_value), {}

    def backward(self, sample_):
        self.replay_buffer.push(sample_)
        if self.step > self.learning_starts and self.learning:
            sample = self.replay_buffer.sample(self.batch_size)
            if self.gpu:
                for key in sample.keys():
                    sample[key] = sample[key].cuda()
            assert len(sample["s"]) == self.batch_size
            a = sample["a"].long().unsqueeze(1)
            Q = self.Q_net(sample["s"]).gather(1, a)
            if self.double_dqn:
                _, next_actions = self.Q_net(sample["s_"]).max(1, keepdim=True)
                targetQ = self.target_Q_net(sample["s_"]).gather(
                    1, next_actions)
            else:
                _, next_actions = self.target_Q_net(sample["s_"]).max(
                    1, keepdim=True)
                targetQ = self.target_Q_net(sample["s_"]).gather(
                    1, next_actions)
            targetQ = targetQ.squeeze(1)
            Q = Q.squeeze(1)
            expected_q_values = sample["r"] + self.gamma * targetQ * (
                1.0 - sample["tr"])
            loss = torch.mean(huber_loss(expected_q_values - Q))
            self.optim.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.Q_net.parameters(),
                                           1,
                                           norm_type=2)
            self.optim.step()
            if self.step % self.target_network_update_freq == 0:
                self.target_net_update()
            loss = loss.data.cpu().numpy()
            return loss, {}
        return 0, {}

    def target_net_update(self):
        self.target_Q_net.load_state_dict(self.Q_net.state_dict())

    def load_weights(self, filepath):
        model = torch.load(filepath + 'DQN.pkl')
        self.Q_net.load_state_dict(model["Q_net"].state_dict())
        self.target_Q_net.load_state_dict(model["target_Q_net"].state_dict())
        # self.optim.load_state_dict(model["optim"])

    def save_weights(self, filepath, overwrite=True):
        torch.save(
            {
                "Q_net": self.Q_net,
                "target_Q_net": self.target_Q_net,
                "optim": self.optim
            }, filepath + "DQN.pkl")

    def cuda(self):
        self.Q_net = gpu_foward(self.Q_net)
        self.target_Q_net = deepcopy(self.Q_net)
        self.gpu = True
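
The three aggregation rules listed for dueling_way in the docstring correspond to the following combinations of V(s) and A(s, a). A minimal sketch over plain tensors, independent of the repository's Dueling_dqn wrapper:

import torch

def dueling_aggregate(value, advantage, way="avg"):
    # value: [batch, 1], advantage: [batch, n_actions] -> Q: [batch, n_actions]
    if way == "avg":
        return value + (advantage - advantage.mean(dim=1, keepdim=True))
    if way == "max":
        return value + (advantage - advantage.max(dim=1, keepdim=True).values)
    if way == "naive":
        return value + advantage
    raise ValueError("unknown dueling way: %s" % way)
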
Example #5
class TD3_Agent(Agent_value_based):
    def __init__(
            self,
            env,
            actor_model,
            critic_model,
            actor_lr=1e-4,
            critic_lr=3e-4,
            actor_target_network_update_freq=0.1,
            critic_target_network_update_freq=0.1,
            actor_training_freq=2,
            critic_training_freq=1,
            ## hyper-parameter
            gamma=0.99,
            batch_size=32,
            buffer_size=50000,
            learning_starts=1000,
            ## decay
            decay=False,
            decay_rate=0.9,
            l2_regularization=0.01,
            ##
            path=None):

        self.gpu = False
        self.env = env
        self.gamma = gamma
        self.batch_size = batch_size
        self.learning_starts = learning_starts
        self.actor_training_freq, self.critic_training_freq = actor_training_freq, critic_training_freq
        self.actor_target_network_update_freq = actor_target_network_update_freq
        self.critic_target_network_update_freq = critic_target_network_update_freq

        self.replay_buffer = ReplayMemory(buffer_size)
        self.actor = actor_model
        self.critic = critic_build(critic_model)

        self.actor_critic = actor_critic(self.actor, self.critic)

        self.target_actor = deepcopy(self.actor)
        self.target_critic = deepcopy(self.critic)

        actor_optim = Adam(self.actor.parameters(), lr=actor_lr)
        critic_optim = Adam(self.critic.parameters(),
                            lr=critic_lr,
                            weight_decay=l2_regularization)
        self.actor_optim = actor_optim
        self.critic_optim = critic_optim
        if decay:
            # keep the optimizers and attach the LR schedulers separately, as in
            # PPO_Agent; a scheduler cannot replace the optimizer it wraps
            self.actor_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                actor_optim, decay_rate, last_epoch=-1)
            self.critic_decay_optim = torch.optim.lr_scheduler.ExponentialLR(
                critic_optim, decay_rate, last_epoch=-1)

        # gradient clipping is applied in backward(), after each loss is backpropagated

        super(TD3_Agent, self).__init__(path)
        example_input = Variable(
            torch.rand(100, self.env.observation_space.shape[0]))
        self.writer.add_graph(self.actor_critic, input_to_model=example_input)
        self.forward_step_show_list = []
        self.backward_step_show_list = []
        self.forward_ep_show_list = []
        self.backward_ep_show_list = []

    def forward(self, observation):
        observation = observation.astype(np.float32)
        observation = torch.from_numpy(observation)
        action = self.actor.forward(observation)
        csv_record(action.detach().numpy(), "./")
        action = torch.normal(action, torch.ones_like(action))
        Q, _ = self.critic(torch.cat((observation, action), axis=0))
        action = action.data.numpy()
        return action, Q.detach().numpy(), {}

    def backward(self, sample_):
        self.replay_buffer.push(sample_)
        if self.step > self.learning_starts and self.learning:
            sample = self.replay_buffer.sample(self.batch_size)
            if self.gpu:
                for key in sample.keys():
                    sample[key] = sample[key].cuda()
            assert len(sample["s"]) == self.batch_size
            "update the critic "
            if self.step % self.critic_training_freq == 0:
                target_a = self.target_actor(sample["s_"])
                target_input = torch.cat((sample["s_"], target_a), -1)
                Q1, Q2 = self.target_critic(target_input)
                target_Q = torch.min(Q1, Q2)
                expected_q_values = sample["r"] + self.gamma * target_Q * (
                    1.0 - sample["tr"])

                input = torch.cat((sample["s"], sample["a"]), -1)
                Q1, Q2 = self.critic(input)
                loss = torch.mean(
                    huber_loss(expected_q_values - Q1)) + torch.mean(
                        huber_loss(expected_q_values - Q2))
                self.critic_optim.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.critic.parameters(),
                                               1,
                                               norm_type=2)
                self.critic_optim.step()
            "training the actor"
            if self.step % self.actor_training_freq == 0:
                Q = self.actor_critic(sample["s"])
                Q = -torch.mean(Q)
                self.actor_optim.zero_grad()
                Q.backward()
                torch.nn.utils.clip_grad_norm_(self.actor.parameters(),
                                               1,
                                               norm_type=2)
                self.actor_optim.step()
            self.target_net_update()
            loss = loss.data.cpu().numpy()
            return loss, {}
        return 0, {}

    def target_net_update(self):
        if self.actor_target_network_update_freq > 1:
            if self.step % self.actor_target_network_update_freq == 0:
                self.target_actor.load_state_dict(self.actor.state_dict())
        else:
            for param, target_param in zip(self.actor.parameters(),
                                           self.target_actor.parameters()):
                target_param.data.copy_(
                    self.actor_target_network_update_freq * param.data +
                    (1 - self.actor_target_network_update_freq) *
                    target_param.data)
        if self.critic_target_network_update_freq > 1:
            if self.step % self.critic_target_network_update_freq == 0:
                self.target_critic.load_state_dict(self.critic.state_dict())
        else:
            for param, target_param in zip(self.critic.parameters(),
                                           self.target_critic.parameters()):
                target_param.data.copy_(
                    self.critic_target_network_update_freq * param.data +
                    (1 - self.critic_target_network_update_freq) *
                    target_param.data)

    def load_weights(self, filepath):
        model = torch.load(filepath + "TD3.pkl")
        self.actor.load_state_dict(model["actor"].state_dict())
        self.critic.load_state_dict(model["critic"].state_dict())
        self.target_actor.load_state_dict(model["target_actor"].state_dict())
        self.target_critic.load_state_dict(model["target_critic"].state_dict())
        self.actor_optim.load_state_dict(model["actor_optim"].state_dict())
        self.critic_optim.load_state_dict(model["critic_optim"].state_dict())

    def save_weights(self, filepath, overwrite=False):
        torch.save(
            {
                "actor": self.actor,
                "critic": self.critic,
                "target_actor": self.target_actor,
                "target_critic": self.target_critic,
                "actor_optim": self.actor_optim,
                "critic_optim": self.critic_optim
            }, filepath + "TD3.pkl")

    def cuda(self):
        self.actor = gpu_foward(self.actor)
        self.critic = gpu_foward(self.critic)
        self.target_actor = deepcopy(self.actor)
        self.target_critic = deepcopy(self.critic)

        self.gpu = True
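
When the *_target_network_update_freq arguments are below 1, target_net_update() treats them as the Polyak coefficient tau for a soft update. A minimal standalone sketch of the same rule, illustrative and not part of the repository:

import torch

def soft_update(online, target, tau=0.1):
    # theta_target <- tau * theta_online + (1 - tau) * theta_target
    with torch.no_grad():
        for p, p_target in zip(online.parameters(), target.parameters()):
            p_target.data.copy_(tau * p.data + (1.0 - tau) * p_target.data)

# usage (illustrative):
# soft_update(agent.actor, agent.target_actor, tau=agent.actor_target_network_update_freq)
# soft_update(agent.critic, agent.target_critic, tau=agent.critic_target_network_update_freq)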