Example #1
class Agent():
    def __init__(self, Env_dim, Nb_action):
        self.memory = Buffer(Memory_size)
        self.eval_nn = Network(Env_dim, Nb_action)
        self.target_nn = Network(Env_dim, Nb_action)
        self.optimizer = torch.optim.Adam(self.eval_nn.parameters(),
                                          lr=Learning_rate)
        self.criterion = nn.MSELoss(reduction='sum')
        self.counter = 0
        # Start the target network from the evaluation network's weights
        self.target_nn.load_state_dict(self.eval_nn.state_dict())

    def choose_action(self, s):
        s = torch.unsqueeze(torch.FloatTensor(s), 0)
        return self.eval_nn(s)[0].detach()  # Q-values for state s

    def getSample(self):
        return self.memory.sample(Batch_size)

    def optimize_model(self, file):
        if self.memory.get_nb_elements() >= Batch_size:
            batch = self.memory.sample(Batch_size)
            for s, a, s_, r, done in batch:
                # Q(s, a) predicted by the evaluation network
                qValue = self.eval_nn(torch.tensor(s).float())[a]
                # TD target bootstrapped from the frozen target network
                with torch.no_grad():
                    qValues_ = self.target_nn(torch.tensor(s_).float())
                    target = r + Gamma * torch.max(qValues_) * (1 - done)
                # The criterion already squares the difference, so the error
                # must not be squared before being passed in
                loss = self.criterion(qValue, target)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            self.counter += 1
            if self.counter % Refresh_gap == 0:
                torch.save(self.eval_nn, file)
                self.target_nn.load_state_dict(self.eval_nn.state_dict())

    def store_transition(self, value):
        self.memory.insert(value)
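
Note: below is a minimal usage sketch for the Agent class above. It assumes the classic Gym API (4-tuple step), plus the Buffer/Network classes and module-level constants (Memory_size, Batch_size, Gamma, Learning_rate, Refresh_gap) the snippet relies on; the environment name, exploration rate, and checkpoint path are hypothetical.

import random
import gym
import torch

env = gym.make("CartPole-v1")                        # assumed environment
agent = Agent(env.observation_space.shape[0], env.action_space.n)

for episode in range(200):
    s, done = env.reset(), False
    while not done:
        q_values = agent.choose_action(s)
        # epsilon-greedy on top of the Q-values returned by choose_action
        if random.random() < 0.1:
            a = env.action_space.sample()
        else:
            a = int(torch.argmax(q_values))
        s_, r, done, _ = env.step(a)
        agent.store_transition((s, a, s_, r, float(done)))
        agent.optimize_model("eval_nn.pt")           # hypothetical checkpoint path
        s = s_
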
Example #2
File: HCA.py Project: Pechckin/HCA
class HCA:
    def __init__(self, episodes, trajectory, alpha_actor, alpha_credit, gamma):

        states = env.observation_space.shape[0]
        actions = env.action_space.shape[0]
        self.low = env.action_space.low[0]
        self.high = env.action_space.high[0]

        self.gamma = gamma
        self.alpha_actor = alpha_actor
        self.alpha_credit = alpha_credit
        self.episodes = episodes
        self.memory = Buffer(trajectory)

        self.policy = Policy(states, actions).apply(self.weights)
        self.credit = Credit(states, actions).apply(self.weights)

        self.policy_optim = optim.Adam(self.policy.parameters(),
                                       lr=self.alpha_actor)
        self.credit_optim = optim.Adam(self.credit.parameters(),
                                       lr=self.alpha_credit)
        self.credit_loss = nn.CrossEntropyLoss()

    def discount(self, rewards):
        R = 0.0
        returns = []
        for r in rewards.numpy()[::-1]:
            R = r + self.gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        returns = (returns - returns.mean()) / (returns.std() + 1e-10)
        return returns

    @staticmethod
    def weights(layer):
        if isinstance(layer, nn.Linear):
            nn.init.xavier_normal_(layer.weight)

    def update(self):
        if not self.memory.full():
            return
        batch = self.memory.sample()

        Zs = self.discount(batch.r)

        # Policy
        mu_policy, sigma_policy = self.policy(batch.s)
        log_prob_policy = Normal(mu_policy, sigma_policy).log_prob(
            batch.a).mean(dim=1, keepdim=True)

        # Credit
        mu_credit, sigma_credit = self.credit(batch.s, Zs)
        log_prob_credit = Normal(mu_credit, sigma_credit).log_prob(
            batch.a).mean(dim=1, keepdim=True)

        ratio = torch.exp(log_prob_policy - log_prob_credit.detach())
        A = (1 - ratio) * Zs.unsqueeze(1)
        policy_loss = -(A.T @ log_prob_policy) / batch.r.size(0)
        self.policy_optim.zero_grad()
        policy_loss.backward()

        credit_loss = -torch.mean(log_prob_policy.detach() * log_prob_credit)

        self.credit_optim.zero_grad()
        credit_loss.backward()

        torch.nn.utils.clip_grad_norm_(self.policy.parameters(), 0.7)
        torch.nn.utils.clip_grad_norm_(self.credit.parameters(), 0.7)

        self.policy_optim.step()
        self.credit_optim.step()

    def act(self, state):
        state = torch.FloatTensor(state)
        with torch.no_grad():
            mu, sigma = self.policy(state)
        return torch.clamp(Normal(mu, sigma).sample(),
                           min=self.low,
                           max=self.high)
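
Note: the advantage built in update() appears to follow the return-conditioned hindsight credit assignment estimator A(s, a) = (1 - pi(a|s) / h(a|s, Z)) * Z, where pi is the policy and h is the return-conditioned credit network. A standalone numeric sketch of that computation, with made-up log-probabilities and returns:

import torch

# Hypothetical per-sample log-probabilities and normalized returns
log_prob_policy = torch.tensor([[-1.2], [-0.4], [-2.0]])   # log pi(a|s)
log_prob_credit = torch.tensor([[-1.0], [-0.9], [-1.5]])   # log h(a|s, Z)
Zs = torch.tensor([0.5, -0.3, 1.2])                        # discounted returns

ratio = torch.exp(log_prob_policy - log_prob_credit)       # pi / h
A = (1 - ratio) * Zs.unsqueeze(1)                          # hindsight advantage
policy_loss = -(A.T @ log_prob_policy) / Zs.size(0)        # same form as update()
print(A, policy_loss)
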
Example #3
class RDQN(object):

    # Initialize Buffer, Networks, global variables
    # and configure CUDA
    def __init__(self,
                 env,
                 buffer_size=10000,
                 active_cuda=True,
                 nb_episodes=2000,
                 max_steps=3500,
                 discount_factor=0.995,
                 epsilon_greedy_end=0.01,
                 epsilon_greedy_start=0.1,
                 batch_size=128,
                 update_target=10,
                 env_type="Unity",
                 train=True,
                 save_episode=800,
                 skip_frame=4,
                 stack_size=4,
                 nb_episodes_decay=100,
                 save_path="gym_cartpole",
                 nb_action=2,
                 lr=0.002,
                 weight_decay=1e-6,
                 update_plot=10,
                 rgb=False,
                 seq_len=8,
                 nb_samples_episodes=4):

        # Global parameters
        self.env = env
        self.nb_episodes = nb_episodes
        self.max_steps = max_steps
        self.discount_factor = discount_factor
        self.batch_size = batch_size
        self.update_target = update_target
        self.env_type = env_type
        self.save_episode = save_episode
        self.skip_frame = skip_frame
        self.save_path = save_path
        self.nb_episodes_decay = nb_episodes_decay
        self.nb_action = nb_action
        self.lr = lr
        self.weight_decay = weight_decay
        self.buffer_size = buffer_size
        self.update_plot = update_plot
        self.nb_channel = 3 if rgb else 1
        self.epsilon_greedy_start = epsilon_greedy_start
        self.epsilon_greedy_end = epsilon_greedy_end
        self.seq_len = seq_len
        self.episode_iterator = 0
        self.nb_samples_episodes = nb_samples_episodes
        # Log to track improvement
        self.log_cumulative_reward = []
        self.log_loss = []

        #################### PSEUDO CODE STEPS ############################

        # Initialize replay memory D
        self.buffer = Episode_Buffer(self.buffer_size, self.seq_len)

        # Initialize Q policy network and Q target network
        self.Q_policy_net = RDQN_net(self.nb_action)
        self.Q_target_net = RDQN_net(self.nb_action)

        # Copy policy weight to target weight
        self.Q_target_net.load_state_dict(self.Q_policy_net.state_dict())

        ############### PYTORCH SPECIFIC INITIALIZATION ###################

        # Adapt to cuda
        self.active_cuda = active_cuda
        if active_cuda:
            self.Q_policy_net.cuda()
            self.Q_target_net.cuda()

        self.FloatTensor = torch.cuda.FloatTensor if active_cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if active_cuda else torch.LongTensor
        self.ByteTensor = torch.cuda.ByteTensor if active_cuda else torch.ByteTensor
        self.Tensor = self.FloatTensor

        # Use RMSProp DeepMind's parameters
        self.optimizer = torch.optim.RMSprop(self.Q_policy_net.parameters(),
                                             lr=self.lr,
                                             weight_decay=self.weight_decay)
        # Helper that processes each frame (call gym_screen_processing.get_screen() to get the processed screen)
        self.gym_screen_processing = GymScreenProcessing(self.env, active_cuda)

    def train_loop(self, retrain=False):

        self.update_epislon_greedy()

        print("Train")
        if (self.episode_iterator >= self.nb_episodes):
            if (not retrain):
                print("Please pass retrain=True if you want to retrain the model. "
                      "Warning: you will lose everything if you choose to retrain your network.")
                return
        for current_episode in range(self.episode_iterator, self.nb_episodes):

            self.buffer.new_episode(self.episode_iterator)

            cumulative_reward = 0

            self.env.reset()
            state = self.get_screen()

            print("Episode " + str(self.episode_iterator))

            hx = None
            cx = None

            # Initialize sequence s1 and preprocess (we take the difference between two consecutive frames)
            for t in range(0, self.max_steps):

                if (t % self.skip_frame == 0):
                    # Select epsilon greedy action
                    action, hx, cx = self.select_action(
                        Variable(state, volatile=True), hx, cx)
                    # Process the action to the environment
                    env_action = self.get_env_action(action)

                    _, reward, done, _ = self.env.step(env_action)

                    cumulative_reward += reward

                    reward = self.Tensor([reward])

                    next_state = self.get_screen()

                    if not done:
                        not_done_mask = self.ByteTensor(1).fill_(1)
                    else:
                        next_state = None
                        not_done_mask = self.ByteTensor(1).fill_(0)
                        #reward = self.Tensor([-1])

                    self.buffer.push(state, action, next_state, reward,
                                     not_done_mask, self.episode_iterator)

                    self.learn()

                    state = next_state

                    if done:
                        print("Done")
                        break
                else:
                    self.env.step(env_action)

            print("Episode cumulative reward: ")
            print(cumulative_reward)

            if self.episode_iterator % self.save_episode == 0 and self.episode_iterator != 0:
                print("Save parameters checkpoint:")
                self.save()
                print("End saving")

            if self.episode_iterator % self.update_plot == 0:
                self.save_plot()

            self.episode_iterator += 1
            self.update_epislon_greedy()

            if current_episode % self.update_target == 0:
                self.Q_target_net.load_state_dict(
                    self.Q_policy_net.state_dict())

            self.log_cumulative_reward.append(cumulative_reward)

    ################################################ LEARNING FUNCTIONS ################################################

    # Gradient descent on (yi - Q_target(state))^2
    def learn(self):
        if (self.buffer.hasAtLeast(self.nb_samples_episodes)):

            samples, nb_episodes = self.buffer.sample(self.nb_samples_episodes)

            # At least 1 sampled episode
            if (nb_episodes > 0):

                # Here batches and sequence are mixed like that:
                #  episode 1 t_1
                #  episode 2 t_1
                #  episode.. t_1
                #  episode n t_1
                #  episode 1 t_m
                #  episode 2 t_m
                #  episode.. t_m
                #  episode n t_m
                [
                    batch_state, batch_action, batch_reward, batch_next_state,
                    not_done_batch
                ] = Transition(*zip(*samples))
                batch_state = Variable(torch.cat(batch_state, dim=0))
                batch_action = Variable(torch.cat(batch_action))
                batch_reward = Variable(torch.cat(batch_reward))
                #batch_next_state = Variable(torch.cat(batch_next_state, dim = 0))
                not_done_batch = self.ByteTensor(torch.cat(not_done_batch))

                non_final_next_states = Variable(torch.cat([
                    s if s is not None else torch.zeros(1, 1, 84, 84).type(
                        self.Tensor) for s in batch_next_state
                ]),
                                                 volatile=True)
                Q_s_t_a, (_, _) = self.Q_policy_net(batch_state,
                                                    batch_size=nb_episodes,
                                                    seq_length=self.seq_len)
                Q_s_t_a = Q_s_t_a.gather(1, batch_action)

                Q_s_next_t_a_result, (_, _) = self.Q_target_net(
                    non_final_next_states,
                    batch_size=nb_episodes,
                    seq_length=self.seq_len)
                Q_s_next_t_a = Q_s_next_t_a_result.max(1)[0]
                Q_s_next_t_a[1 - not_done_batch] = 0

                # Target Q_s_t_a value (like supervised learning )
                target_state_value = (Q_s_next_t_a *
                                      self.discount_factor) + batch_reward
                target_state_value.detach_()

                target_state_value = Variable(
                    target_state_value.data).unsqueeze_(1)

                assert Q_s_t_a.shape == target_state_value.shape

                loss = F.smooth_l1_loss(Q_s_t_a, target_state_value)

                # Optimize the model
                self.optimizer.zero_grad()
                loss.backward()

                self.log_loss.append(loss.data[0])

                for param in self.Q_policy_net.parameters():
                    param.grad.data.clamp_(-1, 1)
                self.optimizer.step()

    def select_action(self, state, hx, cx):
        # Greedy action
        if (np.random.uniform() > self.epsilon_greedy):
            Q_policy_values, (hx, cx) = self.Q_policy_net.forward(state,
                                                                  hx=hx,
                                                                  cx=cx)
            action = Q_policy_values.data.max(1)[1].view(1, 1)
            return action, hx, cx
        # Random
        else:
            return self.LongTensor([[random.randrange(self.nb_action)]
                                    ]), hx, cx

    # Called once per episode
    def update_epislon_greedy(self):
        self.epsilon_greedy = self.epsilon_greedy_end + (
            self.epsilon_greedy_start - self.epsilon_greedy_end) * math.exp(
                -1. * self.episode_iterator / self.nb_episodes_decay)

    ##################################################### SAVE/LOAD FUNCTIONS ##########################################

    def save(self):
        temp_env = self.env
        temp_gym_screen_proc = self.gym_screen_processing
        temp_buffer = self.buffer
        self.env = None
        self.gym_screen_processing = None
        self.buffer = None

        with open(self.save_path, 'wb') as output:
            cPickle.dump(self, output)
        self.env = temp_env
        self.gym_screen_processing = temp_gym_screen_proc
        self.buffer = temp_buffer

    def load_env(self, env):
        self.env = env

    def init_buffer(self):
        # Recreate the episode buffer this recurrent agent actually uses
        self.buffer = Episode_Buffer(self.buffer_size, self.seq_len)

    ##################################################### ENVIRONMENT TYPE SPECIFIC ####################################

    def get_env_action(self, action):
        if (self.env_type == "Unity"):
            return action.cpu().numpy()
        else:
            return action.cpu().numpy()[0, 0]

    def get_screen(self):
        if (self.env_type == "Unity"):
            return img_to_tensor(self.env.get_screen())
        elif (self.env_type == "Gridworld"):
            return img_to_tensor(np.expand_dims(self.env.renderEnv(), axis=3))
            #return self.env.renderEnv()
        else:
            # Gym
            return self.gym_screen_processing.get_screen()

    #################################################### PLOT SPECIFIC FUNCTIONS #######################################
    def save_plot(self):
        plt.plot(self.log_cumulative_reward)
        plt.title("DRQN on " + self.save_path)
        plt.xlabel("Episodes")
        plt.ylabel("Cumulative reward")
        plt.savefig("../save/" + self.save_path + "_cumulative_rewards.png")
        plt.clf()
        plt.plot(self.log_loss[100:])
        plt.title("DRQN on " + self.save_path)
        plt.xlabel("Episodes")
        plt.ylabel("Loss")
        plt.savefig("../save/" + self.save_path + "_loss.png")
        plt.clf()
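
Note: update_epislon_greedy anneals exploration exponentially from epsilon_greedy_start toward epsilon_greedy_end with time constant nb_episodes_decay. A standalone sketch of the same schedule using the constructor defaults (0.1, 0.01, 100):

import math

eps_start, eps_end, decay = 0.1, 0.01, 100   # defaults from __init__

def epsilon(episode):
    return eps_end + (eps_start - eps_end) * math.exp(-1. * episode / decay)

for episode in (0, 50, 100, 300):
    print(episode, round(epsilon(episode), 4))
# 0 -> 0.1, 50 -> ~0.0646, 100 -> ~0.0431, 300 -> ~0.0145
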
Example #4
class DQN(object):

    # Initialize Buffer, Networks, global variables
    # and configure CUDA
    def __init__(self,
                 env,
                 buffer_size=10000,
                 active_cuda=True,
                 nb_episodes=2000,
                 max_steps=3500,
                 discount_factor=0.995,
                 epsilon_greedy_end=0.01,
                 epsilon_greedy_start=0.1,
                 batch_size=128,
                 update_target=10,
                 env_type="Unity",
                 train=True,
                 save_episode=800,
                 skip_frame=4,
                 stack_size=4,
                 nb_episodes_decay=100,
                 save_path="gym_cartpole",
                 nb_action=2,
                 lr=0.002,
                 weight_decay=1e-6,
                 update_plot=10):

        # Global parameters
        self.env = env
        self.nb_episodes = nb_episodes
        self.max_steps = max_steps
        self.discount_factor = discount_factor
        self.batch_size = batch_size
        self.update_target = update_target
        self.env_type = env_type
        self.save_episode = save_episode
        self.skip_frame = skip_frame
        self.stack_size = stack_size
        self.stack_frame = StackFrame(self.stack_size)
        self.save_path = save_path
        self.nb_episodes_decay = nb_episodes_decay
        self.nb_action = nb_action
        self.lr = lr
        self.weight_decay = weight_decay
        self.buffer_size = buffer_size
        self.update_plot = update_plot

        self.epsilon_greedy_start = epsilon_greedy_start
        self.epsilon_greedy_end = epsilon_greedy_end

        self.episode_iterator = 0

        # Log to track improvement
        self.log_cumulative_reward = []

        #################### PSEUDO CODE STEPS ############################

        # Initialize replay memory D
        self.buffer = Buffer(self.buffer_size)

        # Initialize Q policy network and Q target network
        self.Q_policy_net = DQN_net(self.nb_action)
        self.Q_target_net = DQN_net(self.nb_action)

        # Copy policy weight to target weight
        self.Q_target_net.load_state_dict(self.Q_policy_net.state_dict())

        ############### PYTORCH SPECIFIC INITIALIZATION ###################

        # Adapt to cuda
        self.active_cuda = active_cuda
        if active_cuda:
            self.Q_policy_net.cuda()
            self.Q_target_net.cuda()

        self.FloatTensor = torch.cuda.FloatTensor if active_cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if active_cuda else torch.LongTensor
        self.ByteTensor = torch.cuda.ByteTensor if active_cuda else torch.ByteTensor
        self.Tensor = self.FloatTensor

        # Use RMSProp DeepMind's parameters
        self.optimizer = torch.optim.RMSprop(self.Q_policy_net.parameters(),
                                             lr=self.lr,
                                             weight_decay=self.weight_decay)
        # Helper that processes each frame (call gym_screen_processing.get_screen() to get the processed screen)
        self.gym_screen_processing = GymScreenProcessing(self.env, active_cuda)

    def train_loop(self, retrain=False):

        self.update_epislon_greedy()

        print("Train")
        if (self.episode_iterator >= self.nb_episodes):
            if (not retrain):
                print("Please pass retrain=True if you want to retrain the model. "
                      "Warning: you will lose everything if you choose to retrain your network.")
                return
        for current_episode in range(self.episode_iterator, self.nb_episodes):

            cumulative_reward = 0

            self.env.reset()
            state = self.get_screen()

            # Init first stack frame
            self.stack_frame.reset_stack()
            # Init 4 first frames
            for i in range(0, self.stack_frame.max_frames):
                self.stack_frame.add_frame(state)

            old_stack = torch.cat(self.stack_frame.get_frames(), dim=1)

            print("Episode " + str(self.episode_iterator))

            # Initialize sequence s1 and preprocess (we take the difference between two consecutive frames)
            for t in range(0, self.max_steps):
                if (t % self.skip_frame == 0):

                    # Select epsilon greedy action
                    action = self.select_action(
                        Variable(old_stack, volatile=True))
                    # Process the action to the environment
                    env_action = self.get_env_action(action)

                    _, reward, done, _ = self.env.step(env_action)

                    cumulative_reward += reward

                    reward = self.Tensor([reward])

                    next_state = self.get_screen()

                    self.stack_frame.add_frame(next_state)

                    if not done:
                        next_stack = torch.cat(self.stack_frame.get_frames(),
                                               dim=1)
                        not_done_mask = self.ByteTensor(1).fill_(1)
                    else:
                        next_stack = None
                        not_done_mask = self.ByteTensor(1).fill_(0)
                        reward = self.Tensor([-1])

                    self.buffer.push(old_stack, action, next_stack, reward,
                                     not_done_mask)

                    self.learn()

                    old_stack = next_stack

                    if done:
                        print("Done")
                        break
                else:
                    self.env.step(env_action)

            print("Episode cumulative reward: ")
            print(cumulative_reward)

            if self.episode_iterator % self.save_episode == 0 and self.episode_iterator != 0:
                print("Save parameters checkpoint:")
                self.save()
                print("End saving")

            if self.episode_iterator % self.update_plot == 0:
                self.save_plot()

            self.episode_iterator += 1
            self.update_epislon_greedy()

            if current_episode % self.update_target == 0:
                self.Q_target_net.load_state_dict(
                    self.Q_policy_net.state_dict())

            self.log_cumulative_reward.append(cumulative_reward)

    ################################################ LEARNING FUNCTIONS ################################################

    # Gradient descent on (yi - Q_target(state))^2
    def learn(self):
        if (self.buffer.hasAtLeast(self.batch_size)):

            [
                batch_state, batch_action, batch_reward, batch_next_state,
                not_done_batch
            ] = Transition(*zip(*self.buffer.sample(self.batch_size)))
            batch_state = Variable(torch.cat(batch_state, dim=0))
            batch_action = Variable(torch.cat(batch_action))
            batch_reward = Variable(torch.cat(batch_reward))
            not_done_batch = self.ByteTensor(torch.cat(not_done_batch))
            non_final_next_states = Variable(torch.cat(
                [s for s in batch_next_state if s is not None]),
                                             volatile=True)

            Q_s_t_a = self.Q_policy_net(batch_state).gather(1, batch_action)

            Q_s_next_t_a = Variable(
                torch.zeros(self.batch_size).type(self.Tensor))
            Q_s_next_t_a[not_done_batch] = self.Q_target_net(
                non_final_next_states).max(1)[0]

            # Target Q_s_t_a value (like supervised learning )
            target_state_value = (Q_s_next_t_a *
                                  self.discount_factor) + batch_reward
            target_state_value = Variable(target_state_value.data)

            loss = F.smooth_l1_loss(Q_s_t_a, target_state_value)

            # Optimize the model
            self.optimizer.zero_grad()
            loss.backward()

            for param in self.Q_policy_net.parameters():
                param.grad.data.clamp_(-1, 1)
            self.optimizer.step()

    def select_action(self, state):
        # Greedy action
        if (np.random.uniform() > self.epsilon_greedy):
            return self.Q_policy_net.forward(state).data.max(1)[1].view(1, 1)
        # Random
        else:
            return self.LongTensor([[random.randrange(self.nb_action)]])

    # Called once per episode
    def update_epislon_greedy(self):
        self.epsilon_greedy = self.epsilon_greedy_end + (
            self.epsilon_greedy_start - self.epsilon_greedy_end) * math.exp(
                -1. * self.episode_iterator / self.nb_episodes_decay)

    ##################################################### SAVE/LOAD FUNCTIONS ##########################################

    def save(self):
        temp_env = self.env
        temp_gym_screen_proc = self.gym_screen_processing
        temp_buffer = self.buffer
        self.env = None
        self.gym_screen_processing = None
        self.buffer = None

        with open(self.save_path, 'wb') as output:
            cPickle.dump(self, output)
        self.env = temp_env
        self.gym_screen_processing = temp_gym_screen_proc
        self.buffer = temp_buffer

    def load_env(self, env):
        self.env = env

    def init_buffer(self):
        self.buffer = Buffer(self.buffer_size)

    ##################################################### ENVIRONMENT TYPE SPECIFIC ####################################

    def get_env_action(self, action):
        if (self.env_type == "Unity"):
            return action.cpu().numpy()
        else:
            return action[0, 0]

    def get_screen(self):
        if (self.env_type == "Unity"):
            return img_to_tensor(self.env.get_screen())
        else:
            # Gym
            return self.gym_screen_processing.get_screen()

    #################################################### PLOT SPECIFIC FUNCTIONS #######################################
    def save_plot(self):
        plt.plot(self.log_cumulative_reward)
        plt.title("DQN on " + self.save_path)
        plt.xlabel("Episodes")
        plt.ylabel("Cumulative reward")
        plt.savefig("save/" + self.save_path + "_cumulative_rewards.png")


#with open('gym_cartpole.pkl', 'rb') as input:
#dqn_train = cPickle.load(input)
#env = UnityEnvironment(file_name="C:/Users/Bureau/Desktop/RL_DQN_FinalProject/POMDP/pomdp")
#dqn_train.load_env(env)
#dqn_train.init_buffer(10000)
#dqn_train.max_steps  = 10000

#dqn_train.train_mode  = False
# dqn_train.nb_episodes = 200000
#    print(dqn_train.episode_iterator)
#dqn_train.train_loop()

#print(dqn_train.log_cumulative_reward)

#reward_every_50 = np.mean(np.array(dqn_train.log_cumulative_reward).reshape(-1, 1), axis=1)
#plt.plot(reward_every_50)
#plt.title("DQN")
#plt.xlabel("Episodes (multiply by 177)")
#plt.ylabel("Cumulative reward")
#plt.show()
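
Note: both learn() methods above use the pre-0.4 PyTorch idioms (Variable, volatile=True, ByteTensor masks). A minimal sketch of the same masked Bellman target in current PyTorch, with made-up batch tensors and a plain nn.Linear standing in for DQN_net:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, state_dim, nb_action, gamma = 4, 8, 2, 0.995
Q_policy_net = nn.Linear(state_dim, nb_action)   # stand-in for DQN_net
Q_target_net = nn.Linear(state_dim, nb_action)
Q_target_net.load_state_dict(Q_policy_net.state_dict())

batch_state = torch.randn(batch_size, state_dim)
batch_action = torch.randint(nb_action, (batch_size, 1))
batch_reward = torch.randn(batch_size)
batch_next_state = torch.randn(batch_size, state_dim)
not_done_mask = torch.tensor([True, True, False, True])

q_s_a = Q_policy_net(batch_state).gather(1, batch_action).squeeze(1)
with torch.no_grad():                            # replaces volatile=True
    next_q = torch.zeros(batch_size)
    next_q[not_done_mask] = Q_target_net(batch_next_state[not_done_mask]).max(1)[0]
    target = batch_reward + gamma * next_q       # zero bootstrap for terminal states

loss = F.smooth_l1_loss(q_s_a, target)
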
Example #5
class Brain:

    def __init__(self, env_dict, params):
        """
        option_num, state_dim, action_dim, action_bound, gamma, learning_rate, replacement,
                 buffer_capacity, epsilon
        gamma: (u_gamma, l_gamma)
        learning_rate: (lr_u_policy, lr_u_critic, lr_option, lr_termin, lr_l_critic)
        """

        # session
        self.sess = tf.Session()

        # environment parameters
        self.sd = env_dict['state_dim']
        self.ad = env_dict['action_dim']
        a_bound = env_dict['action_scale']
        assert a_bound.shape == (self.ad,), 'Action bound does not match action dimension!'

        # hyper parameters
        self.on = params['option_num']
        epsilon = params['epsilon']
        u_gamma = params['upper_gamma']
        l_gamma = params['lower_gamma']
        u_capac = params['upper_capacity']
        l_capac = params['lower_capacity']
        u_lrcri = params['upper_learning_rate_critic']
        l_lrcri = params['lower_learning_rate_critic']
        l_lrpol = params['lower_learning_rate_policy']
        l_lrter = params['lower_learning_rate_termin']

        # the frequency of training termination function
        if params['delay'] == 'inf':
            self.delay = -1
        else:
            self.delay = params['delay']

        # Upper critic and buffer
        self.u_critic = UCritic(session=self.sess, state_dim=self.sd, option_num=self.on,
                                gamma=u_gamma, epsilon=epsilon, learning_rate=u_lrcri)
        self.u_buffer = Buffer(state_dim=self.sd, action_dim=1, capacity=u_capac)

        # Lower critic, options and buffer HER
        self.l_critic = LCritic(session=self.sess, state_dim=self.sd, action_dim=self.ad,
                                gamma=l_gamma, learning_rate=l_lrcri)
        self.l_options = [Option(session=self.sess, state_dim=self.sd, action_dim=self.ad,
                                 ordinal=i, learning_rate=[l_lrpol, l_lrter])
                          for i in range(self.on)]
        self.l_buffers = [Buffer(state_dim=self.sd, action_dim=self.ad, capacity=l_capac)
                          for i in range(self.on)]

        # Initialize all coefficients and saver
        self.sess.run(tf.global_variables_initializer())
        self.saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=100)

        self.tc = 0         # counter for training termination
        self.mc = 0         # counter for model

    def train_policy(self, batch_size):
        """Train upper critic(policy)"""

        if self.u_buffer.isFilled:

            # sample batches
            state_batch, option_batch, reward_batch, next_state_batch = self.u_buffer.sample(batch_size)
            # training
            self.u_critic.train(state_batch, option_batch, reward_batch, next_state_batch)

    def train_option(self, batch_size, option):
        """Train option"""

        # only train after the buffer is filled up
        if self.l_buffers[option].isFilled:

            # sample batches
            state_batch, action_batch, reward_batch, next_state_batch = \
                self.l_buffers[option].sample(batch_size)
            next_action_batch = self.l_options[option].get_target_actions(next_state_batch)

            # train lower critic
            self.l_critic.train(state_batch, action_batch, reward_batch, next_state_batch,
                                next_action_batch)

            # get affiliated batch
            q_gradients_batch = self.l_critic.q_gradients(state_batch, action_batch)
            self.tc += 1
            if self.tc == self.delay:
                advantage_batch = self.l_critic.q_batch(state_batch, action_batch) - \
                                  self._value_batch(state_batch)
                self.tc = 0
            else:
                advantage_batch = None

            # train lower options
            self.l_options[option].train(state_batch, q_gradients_batch, advantage_batch)

            return True

        return False

    def pretrain(self):
        """Pretrain ucritic and termination function"""

        self.u_critic.pretrain()
        for i in range(self.on):
            self.l_options[i].pretrain()
            self.l_options[i].render()
        self.save_model()

    def render(self):
        """Render ucritic and termination function"""

        self.u_critic.render()
        for i in range(self.on):
            self.l_options[i].render()

    def save_model(self):
        """Save current model"""

        if self.mc == 0:
            self.saver.save(self.sess, './model/model.ckpt', global_step=0, write_meta_graph=True)
        else:
            self.saver.save(self.sess, './model/model.ckpt', global_step=self.mc, write_meta_graph=False)
        self.mc += 1

    def restore_model(self, order=None):
        """Restore trained model"""

        ckpt = tf.train.get_checkpoint_state('./model/')
        saver = tf.train.import_meta_graph(ckpt.all_model_checkpoint_paths[0] + '.meta')
        if order is None:
            saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            saver.restore(self.sess, ckpt.all_model_checkpoint_paths[order])
            self.mc = order + 1

    def _value_batch(self, state_batch):
        """The upper policy average of Q value for each option
        :return: the value
        """

        batch_size = state_batch.shape[0]
        value_batch = np.zeros((batch_size, 1))
        action_batch = [self.l_options[i].get_actions(state_batch) for i in range(self.on)]
        q_batch = [self.l_critic.q_batch(state_batch, action_batch[i]) for i in range(self.on)]
        distribution_batch = self.u_critic.get_distribution(state_batch)

        # calculate the value function
        for i in range(batch_size):
            for j in range(self.on):
                value_batch[i] += q_batch[j][i] * distribution_batch[i, j]

        return value_batch
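
Note: _value_batch computes V(s) = sum_j mu(j|s) * Q(s, a_j), the upper policy's expected Q-value over options. The per-sample double loop can be replaced by a single vectorized NumPy expression; a minimal sketch with made-up batches (3 states, 2 options):

import numpy as np

# q_batch[j] holds Q(s, a_j) for option j, shape (batch_size, 1)
q_batch = [np.array([[1.0], [0.5], [2.0]]), np.array([[0.2], [1.5], [0.0]])]
# distribution_batch[i, j] = mu(option j | state i)
distribution_batch = np.array([[0.7, 0.3], [0.4, 0.6], [0.9, 0.1]])

q_matrix = np.hstack(q_batch)                              # (batch_size, option_num)
value_batch = np.sum(q_matrix * distribution_batch, axis=1, keepdims=True)
print(value_batch)                                         # approx. [[0.76], [1.1], [1.8]]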