Code Example #1
def simple_replay_train(DQN, train_batch):
    x_stack = np.empty(0).reshape(0, DQN.input_size)
    y_stack = np.empty(0).reshape(0, DQN.output_size)

    for state, action, reward, next_state, done in train_batch:
        Q = DQN.predict(state)

        if done:
            Q[0, action] = reward
        else:
            Q[0, action] = reward + dis * np.max(DQN.predict(next_state))

        y_stack = np.vstack([y_stack, Q])
        x_stack = np.vstack([x_stack, state])

    return DQN.update(x_stack, y_stack)
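
Code Example #2 below calls replay_train(mainDQN, targetDQN, minibatch), a target-network variant of simple_replay_train that is not included in these excerpts. A minimal sketch of what such a function could look like, assuming the same DQN interface (predict/update, input_size/output_size) and the global discount factor dis that the excerpts rely on but never define:

import numpy as np

dis = 0.9  # assumed discount factor; the excerpts use a global named dis but do not show its value

def replay_train(mainDQN, targetDQN, train_batch):
    # Same batch construction as simple_replay_train above, but the bootstrap
    # value comes from the separate target network instead of the main network.
    x_stack = np.empty(0).reshape(0, mainDQN.input_size)
    y_stack = np.empty(0).reshape(0, mainDQN.output_size)

    for state, action, reward, next_state, done in train_batch:
        Q = mainDQN.predict(state)

        if done:
            Q[0, action] = reward
        else:
            Q[0, action] = reward + dis * np.max(targetDQN.predict(next_state))

        y_stack = np.vstack([y_stack, Q])
        x_stack = np.vstack([x_stack, state])

    return mainDQN.update(x_stack, y_stack)
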
Code Example #2
def main():
    max_episodes = 5000
    replay_buffer = deque()

    with tf.Session() as sess:
        mainDQN = DQN(sess, input_size, output_size, name='main')
        targetDQN = DQN(sess, input_size, output_size, name='target')
        tf.global_variables_initializer().run()

        copy_ops = get_copy_var_ops(dest_scope_name='target', src_scope_name='main')

        sess.run(copy_ops)

        for episode in range(max_episodes):
            e = 1. / ((episode / 10) + 1)
            done = False
            step_count = 0

            state = env.reset()

            while not done:
                if np.random.rand(1) < e:
                    action = env.action_space.sample()

                else:
                    action = np.argmax(mainDQN.predict(state))

                next_state, reward, done, _ = env.step(action)

                if done:
                    reward = -100

                replay_buffer.append((state, action, reward, next_state, done))

                if len(replay_buffer) > REPLAY_MEMORY:
                    replay_buffer.popleft()

                state = next_state
                step_count += 1
                if step_count > 10000:
                    break

            print("Episode: {} Steps: {}".format(episode, step_count))

            if step_count > 10000:
                pass  # the training loop could be stopped here once the agent survives this long

            if episode % 10 == 1:
                for _ in range(50):
                    minibatch = random.sample(replay_buffer, 10)
                    loss, _ = replay_train(mainDQN, targetDQN, minibatch)

                print("Loss: ", loss)

                sess.run(copy_ops)

        bot_play(mainDQN)
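
Code Example #2 relies on get_copy_var_ops (and Code Example #4 calls the same helper through its m module) to copy the weights of the 'main' network into the 'target' network, but the helper itself is not shown. A minimal TF1 sketch, assuming each DQN builds its variables inside a tf.variable_scope named after its name argument:

import tensorflow as tf

def get_copy_var_ops(dest_scope_name='target', src_scope_name='main'):
    # Build one assign op per trainable variable so that running the returned
    # list overwrites the target network's weights with the main network's.
    op_holder = []
    src_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=src_scope_name)
    dest_vars = tf.get_collection(
        tf.GraphKeys.TRAINABLE_VARIABLES, scope=dest_scope_name)
    for src_var, dest_var in zip(src_vars, dest_vars):
        op_holder.append(dest_var.assign(src_var.value()))
    return op_holder
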
Code Example #3
File: agent.py  Project: fy-meng/XRL_experiments
class DQNAgent(Agent):
    def __init__(self,
                 state_size,
                 num_actions,
                 batch_size=64,
                 gamma=0.999,
                 epsilon=0.9,
                 epsilon_decay=0.99995,
                 **kwargs):
        super(DQNAgent, self).__init__(state_size, num_actions, **kwargs)

        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        self.net = DQN(state_size, num_actions, **kwargs)

    def get_action(self, state: np.ndarray):
        if self.mode == 'train' and np.random.random() < self.epsilon:
            action = np.random.randint(self.num_actions)
        else:
            action = np.argmax(self.get_q_values(state), axis=-1)
        self.epsilon *= self.epsilon_decay
        return action

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        return self.net.predict(
            state).detach().cpu().numpy()  # shape = (b, m, c)

    def optimize(self):
        batch: List[Transition] = self.buffer.sample(self.batch_size)
        if batch is None:
            return

        self.net.optimize(batch, self.gamma)

    def save_model(self, model_save_path: str):
        self.net.save_model(model_save_path)
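
DQNAgent assumes a replay buffer on self.buffer (presumably created by the Agent base class) whose sample() returns None until enough transitions are stored, plus a Transition record type. A minimal sketch of that interface with hypothetical names; the project's actual buffer may differ:

import random
from collections import deque, namedtuple
from typing import List, Optional

# Hypothetical transition record; the project may use a dataclass instead.
Transition = namedtuple('Transition', ['state', 'action', 'reward', 'next_state', 'done'])

class SimpleReplayBuffer:
    def __init__(self, capacity: int = 100000):
        self.storage = deque(maxlen=capacity)

    def push(self, transition: Transition) -> None:
        self.storage.append(transition)

    def sample(self, batch_size: int) -> Optional[List[Transition]]:
        # Mirror the behaviour DQNAgent.optimize() relies on: return None
        # until at least one full batch can be drawn.
        if len(self.storage) < batch_size:
            return None
        return random.sample(list(self.storage), batch_size)
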
Code Example #4
File: main.py  Project: goldkim92/alphachu
def main():
	max_episodes = 20000
	REPLAY_MEMORY = 20000

	env = Env(max_episodes)
	replay_buffer = deque()
	with tf.Session() as sess:
		c, r = env.pixel_size
		mainDQN = DQN(sess, (r, c, env.history_num), env.action_space, name="main")
		targetDQN = DQN(sess, (r, c, env.history_num), env.action_space, name="target")
		tf.global_variables_initializer().run()
		saver = tf.train.Saver()
		saver.restore(sess, "./ckpt/model.ckpt")
		writer = tf.summary.FileWriter("summary/dqn", sess.graph)
		copy_ops = m.get_copy_var_ops(dest_scope_name="target", src_scope_name="main")
		sess.run(copy_ops)						

		set_reward = 0
		total_reward = 0		
		reward = 0
		ready = False
		print("Set standard")
		while not ready:
			mask, img = env.preprocess_img(False)			
			ready = env.get_standard(mask)
		print("Ready")
		for episode in range(4607, max_episodes + 1):	 
			start1 = False
			start2 = False
			restart = False
			end = False
			history = env.history_reset()

			print("Start!")
			time.sleep(env.frame_time * 5)
			starttime = time.time()
			frame = 0
			try:
				while not end:
					predicts = mainDQN.predict(history)
					env.avg_q_max += np.max(predicts)
					if np.random.rand(1) < env.epsilon:
						action = np.random.choice(range(env.action_space))
					else:
						action = np.argmax(predicts)
					env.key_dict[action](env.frame_time)
					mask, img = env.preprocess_img()
					end = env.check_end(mask)
					if end:
						reward = env.reward
					else:		
						reward = 1
					next_history = env.history_update(history, img)
				
					replay_buffer.append((history, action, reward, next_history, end))
					if len(replay_buffer) > REPLAY_MEMORY:
						replay_buffer.popleft()
					history = next_history
					total_reward += reward		
					time.sleep(env.frame_time - ((time.time() - starttime) % env.frame_time))	
					frame += 1
					if env.epsilon > env.epsilon_end:
						env.epsilon = 1 - episode * env.epsilon_decay_step
					# endtime = time.time()
					# print(endtime-starttime)
					# starttime = endtime
			except KeyboardInterrupt:
				control.release()
				break					

			stats = [total_reward, env.avg_q_max / float(frame), frame, env.avg_loss / float(frame)]
			for i in range(len(stats)):
				sess.run(env.update_ops[i], feed_dict={env.summary_placeholders[i]: float(stats[i])})
			summary_str = sess.run(env.summary_op)
			writer.add_summary(summary_str, episode + 1)
			env.avg_q_max, env.avg_loss = 0, 0

			if reward > 0:
				print("Episode: {}, result: Win, reward:{}, frame:{}".format(episode, total_reward, frame))
			else:
				print("Episode: {}, result: Lose, reward:{}, frame:{}".format(episode, total_reward, frame))
			set_reward += total_reward
			total_reward = 0
			print("Not started yet")
			wait = 0
			while not start1 or start2 or not restart:
				start1 = start2
				mask, img = env.preprocess_img()
				start2 = env.check_start(mask)	
				restart = env.get_standard(mask, set_ = False)
				wait += 1
				if wait == 500:
					print("Game set: total reward: {}".format(set_reward))
					set_reward = 0
					for _ in range(50):
						minibatch = random.sample(replay_buffer, 100)
						loss, _ = m.replay_train(mainDQN, targetDQN, minibatch)
						env.avg_loss += loss
					saver.save(sess, "./ckpt/model.ckpt")			
					print("Loss: ", loss)
					print("Wait to restart to game")						
					env.key_dict[6](env.frame_time)	
					time.sleep(env.frame_time)	

				if wait > 750:					
					print("Let's start!")
					env.key_dict[6](env.frame_time)	
					time.sleep(env.frame_time)		
				control.release()

			if episode % mainDQN.update_target_rate == 1:
				sess.run(copy_ops)
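
The Env object in Code Example #4 is expected to expose TF1 summary plumbing (summary_placeholders, update_ops, summary_op) for the per-episode statistics written around writer.add_summary; its implementation is not shown. A common way to wire up such scalar summaries, with illustrative tag names rather than the project's own, looks roughly like this:

import tensorflow as tf

def setup_summary():
    # One scalar summary per statistic; values are fed through placeholders,
    # stored in non-trainable variables, and emitted via one merged summary op.
    tags = ['Total Reward/Episode', 'Average Max Q/Episode',
            'Duration/Episode', 'Average Loss/Episode']
    summary_placeholders = [tf.placeholder(tf.float32) for _ in tags]
    summary_vars = [tf.Variable(0.0, trainable=False) for _ in tags]
    update_ops = [var.assign(ph)
                  for var, ph in zip(summary_vars, summary_placeholders)]
    summary_ops = [tf.summary.scalar(tag, var)
                   for tag, var in zip(tags, summary_vars)]
    summary_op = tf.summary.merge(summary_ops)
    return summary_placeholders, update_ops, summary_op
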
Code Example #5
     )
     # Update episode number
     n_episode += 1
     # Reset environment
     state = env.reset()
     episode_reward = 0
     episode_steps = 0
     total_max_q = 0
     done = False
 # Evaluate epsilon for epsilon-greedy policy
 if step < EPSILON_FALL:
     epsilon = EPSILON_MAX - (EPSILON_MAX - EPSILON_MIN) * step / EPSILON_FALL
 else:
     epsilon = EPSILON_MIN
 # Choose action epsilon-greedily using dqn
 q_values = dqn.predict([state], sess)[0]
 if np.random.random() < epsilon:
     # Choose random action
     a = env.action_space.sample()
 else:
     # Choose best action
     a = np.argmax(q_values)
 # Perform choosen action
 next_state, reward, done, _ = env.step(a)
 episode_reward += reward
 episode_steps += 1
 # Insert into replay buffer
 repbuf.add_sample((state, a, reward, next_state, done))
 state = next_state
 # Stats
 total_max_q += q_values.max()
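
This excerpt sits inside a step loop: the more deeply indented block at the top is the tail of an if done: branch that logs the finished episode and resets the per-episode counters, while everything from the epsilon computation onward runs on every step. The epsilon schedule itself is a plain linear anneal; pulled out as a standalone helper (the default values here are illustrative assumptions, not the project's constants):

def linear_epsilon(step, epsilon_max=1.0, epsilon_min=0.05, epsilon_fall=100000):
    # Anneal epsilon linearly from epsilon_max to epsilon_min over the first
    # epsilon_fall steps, then hold it at epsilon_min, as in the excerpt above.
    if step < epsilon_fall:
        return epsilon_max - (epsilon_max - epsilon_min) * step / epsilon_fall
    return epsilon_min
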
Code Example #6
class Agent:
    """
    Class representing a learning agent acting in an environment.
    """
    def __init__(self,
                 buffer_size,
                 batch_size,
                 alpha,
                 gamma,
                 epsilon,
                 epsilon_min,
                 epsilon_decay,
                 lr,
                 game="CartPole-v1",
                 mean_bound=5,
                 reward_bound=495.0,
                 sync_model=1000,
                 save_model=10):
        """
        Constructor of the agent class.
            - game="CartPole-v1" : Name of the game environment
            - mean_bound=5 : Number of last acquired rewards considered for mean reward
            - reward_bound=495.0 : Reward acquired for completing an episode properly
            - sync_model=1000 : Interval for synchronizing model and target model
            - save_model=10 : Interval for saving model

            - buffer_size : Replay buffer size of the DQN model
            - batch_size : Batch size of the DQN model
            - alpha : Learning rate for Q-Learning
            - gamma : Discount factor for Q-Learning
            - epsilon : Threshold for taking a random action
            - epsilon_min : Minimal value allowed for epsilon
            - epsilon_decay : Decay rate for epsilon
            - lr : Learning rate for the DQN model
        """

        # Environment variables
        self.game = game
        self.env = gym.make(self.game)
        self.num_states = self.env.observation_space.shape[0]
        self.num_actions = self.env.action_space.n

        # Agent variables
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.buffer = ReplayBuffer(self.buffer_size, self.batch_size)
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.mean_bound = mean_bound
        self.reward_bound = reward_bound

        # DQN variables
        self.lr = lr
        self.model = DQN(self.num_states, self.num_actions, self.lr)
        self.target_model = DQN(self.num_states, self.num_actions, self.lr)
        self.target_model.update(self.model)
        self.sync_model = sync_model
        self.save_model = save_model

        # File paths
        dirname = os.path.dirname(__file__)
        self.path_model = os.path.join(dirname, "../models/dqn.h5")
        self.path_plot = os.path.join(dirname, "../plots/dqn.png")

        # Load model, if it already exists
        try:
            self.model.load(self.path_model)
            self.target_model.update(self.model)
        except Exception:
            print("Model does not exist! Creating a new model...")

    def reduce_epsilon(self):
        """
        Decays epsilon by the factor epsilon_decay, but never lets it fall below epsilon_min.
        """

        epsilon = self.epsilon * self.epsilon_decay

        if epsilon >= self.epsilon_min:
            self.epsilon = epsilon
        else:
            self.epsilon = self.epsilon_min

    def get_action(self, state):
        """
        Returns an action for a given state, based on the current policy.
            - state : Current state of the agent
        """

        if np.random.random() < self.epsilon:
            action = self.env.action_space.sample()
        else:
            action = np.argmax(self.model.predict(state))

        return action

    def train(self, num_episodes, report_interval):
        """
        Trains the DQN model for a given number of episodes. Report information is printed every report_interval episodes.
            - num_episodes : Number of episodes to train
            - report_interval : Interval for outputting report information of training
        """

        step = 0
        total_rewards = []

        for episode in range(1, num_episodes + 1):
            if episode % self.save_model == 0:
                self.model.save(self.path_model)

            state = self.env.reset()
            state = state.reshape((1, self.num_states))
            total_reward = 0.0

            while True:
                step += 1

                action = self.get_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state.reshape((1, self.num_states))

                # Penalize agent if pole could not be balanced until end of episode
                if done and reward < 499.0:
                    reward = -100.0

                self.buffer.remember(state, action, reward, next_state, done)
                self.replay()
                self.reduce_epsilon()

                state = next_state
                total_reward += reward

                if step % self.sync_model == 0:
                    self.target_model.update(self.model)

                if done:
                    total_reward += 100.0
                    total_rewards.append(total_reward)
                    mean_reward = np.mean(total_rewards[-self.mean_bound:])

                    if episode % report_interval == 0:
                        print(f"Episode: {episode}/{num_episodes}"
                              f"\tStep: {step}"
                              f"\tMemory Size: {len(self.memory)}"
                              f"\tEpsilon: {self.epsilon : .3f}"
                              f"\tReward: {total_reward}"
                              f"\tLast 5 Mean: {mean_reward : .2f}")

                        self.plot_rewards(total_rewards)

                    if mean_reward > self.reward_bound:
                        self.model.save(self.path_model)
                        return

                    break

        self.model.save(self.path_model)

    def replay(self):
        """
        Samples training data from the replay buffer and fits the DQN model.
        """

        sample_size, states, actions, rewards, next_states, dones = self.buffer.sample()

        q_values = self.model.predict(states)
        next_q_values = self.target_model.predict(next_states)

        for i in range(sample_size):
            action = actions[i]
            done = dones[i]

            if done:
                q_target = rewards[i]
            else:
                q_target = rewards[i] + self.gamma * np.max(next_q_values[i])

            q_values[i][action] = (1 - self.alpha) * \
                q_values[i][action] + self.alpha * q_target

        self.model.fit(states, q_values)

    def play(self, num_episodes):
        """
        Renders the trained agent for a given number of episodes.
            - num_episodes : Number of episodes to render
        """

        self.epsilon = self.epsilon_min

        for episode in range(1, num_episodes + 1):
            state = self.env.reset()
            state = state.reshape((1, self.num_states))
            total_reward = 0.0

            while True:
                self.env.render()
                action = self.get_action(state)
                next_state, reward, done, _ = self.env.step(action)
                next_state = next_state.reshape((1, self.num_states))
                state = next_state
                total_reward += reward

                if done:
                    print(f"Episode: {episode}/{num_episodes}"
                          f"\tTotal Reward: {total_reward : .2f}")

                    break

    def plot_rewards(self, total_rewards):
        """
        Plots the rewards the agent has acquired during training.
            - total_rewards : Rewards the agent has gained per episode
        """

        x = np.arange(len(total_rewards))  # ndarray so that "slope * x" below works
        y = total_rewards

        slope, intercept, _, _, _ = linregress(x, y)

        plt.plot(x, y, linewidth=0.8)
        plt.plot(x, slope * x + intercept, color="red", linestyle="-.")
        plt.xlabel("Episode")
        plt.ylabel("Reward")
        plt.title("DQN-Learning")
        plt.savefig(self.path_plot)
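
The ReplayBuffer this agent constructs is not included in the excerpt, but replay() unpacks sample() into (sample_size, states, actions, rewards, next_states, dones), train() stores transitions with remember(...), and the progress report takes len(self.buffer). A minimal sketch compatible with those calls; the real implementation may differ:

import random
from collections import deque

import numpy as np

class ReplayBuffer:
    def __init__(self, buffer_size, batch_size):
        self.batch_size = batch_size
        self.memory = deque(maxlen=buffer_size)

    def __len__(self):
        return len(self.memory)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample(self):
        # Draw up to batch_size transitions and return them column-wise,
        # matching the unpacking in Agent.replay().
        sample_size = min(len(self.memory), self.batch_size)
        batch = random.sample(list(self.memory), sample_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        states = np.concatenate(states)            # each state is shaped (1, num_states)
        next_states = np.concatenate(next_states)
        return (sample_size, states, np.array(actions),
                np.array(rewards), next_states, np.array(dones))
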
Code Example #7
File: agent.py  Project: fy-meng/XRL_experiments
class DQNCropAgent(CropAgent):
    def __init__(self,
                 state_size,
                 _num_actions,
                 batch_size=64,
                 gamma=0.999,
                 epsilon=0.9,
                 epsilon_decay=0.99995,
                 **kwargs):
        num_actions = len(self.WATER_VALUES) * len(self.NITROGEN_VALUES) \
                      * len(self.PHOSPHORUS_VALUES)
        super(DQNCropAgent, self).__init__(state_size, num_actions, **kwargs)

        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay

        self.net = DQN(state_size, self.num_actions, **kwargs)

    def get_action(self, state: np.ndarray):
        if self.mode == 'train' and np.random.random() < self.epsilon:
            action_idx = np.random.randint(self.num_actions)
        else:
            action_idx = self.get_q_values(state).argmax(axis=-1)

        self.epsilon *= self.epsilon_decay

        # convert action index to actual action values
        action = self.idx_to_action(action_idx)

        return action

    def get_q_values(self, state: np.ndarray) -> np.ndarray:
        return self.net.predict(state).detach().cpu().numpy()

    def get_saliency(self, state: np.ndarray,
                     q_values: np.ndarray) -> np.ndarray:
        assert state.size == self.state_size, "saliency cannot be computed during training"

        self.update_state_value_range(state)

        saliency = np.zeros_like(state)
        action: int = q_values.argmax()
        q_values_dict = {i: q / 100 for i, q in enumerate(q_values.squeeze())}

        for _ in range(self.SALIENCY_TRIALS):
            for i in range(self.state_size):
                perturbed_state = self.perturb(state, i)
                perturbed_q_values = self.get_q_values(perturbed_state)
                perturbed_q_values_dict = {
                    j: q / 100
                    for j, q in enumerate(perturbed_q_values.squeeze())
                }

                saliency[i] += computeSaliencyUsingSarfa(
                    action, q_values_dict,
                    perturbed_q_values_dict)[0] / self.SALIENCY_TRIALS

        return saliency

    def optimize(self):
        batch: List[Transition] = self.buffer.sample(self.batch_size)
        if batch is None:
            return

        self.net.optimize(batch, self.gamma)

    def save_model(self, model_save_path: str):
        self.net.save_model(model_save_path)
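
get_action converts the flat index chosen by the network into actual action values through idx_to_action, which is defined elsewhere (presumably on CropAgent, alongside the WATER_VALUES / NITROGEN_VALUES / PHOSPHORUS_VALUES constants). A sketch of one way such an unraveling could work, written here as a free function over the value lists rather than as the project's method:

import numpy as np

def idx_to_action(action_idx, water_values, nitrogen_values, phosphorus_values):
    # Unravel the flat index into one index per dimension, assuming the same
    # (water, nitrogen, phosphorus) ordering used to compute num_actions above.
    shape = (len(water_values), len(nitrogen_values), len(phosphorus_values))
    w, n, p = np.unravel_index(int(action_idx), shape)
    return water_values[w], nitrogen_values[n], phosphorus_values[p]

# Example: with 3 x 2 x 2 = 12 actions,
# idx_to_action(5, [0, 10, 20], [0, 5], [0, 5])  ->  (10, 0, 5)
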
Code Example #8
File: agent.py  Project: KokoMind/DQN-TF
class Agent:
    """Our Wasted Agent :P """
    def __init__(self, sess, config, environment, evaluation_enviroment):
        # Get the session, config, and environments, and create a replay memory
        self.sess = sess
        self.config = config
        self.environment = environment
        self.evaluation_enviroment = evaluation_enviroment

        if config.prm:
            self.memory = PrioritizedExperienceReplay(sess, config)
        else:
            self.memory = ReplayMemory(config.state_shape, config.rep_max_size)

        self.init_dirs()

        self.init_cur_epsiode()
        self.init_global_step()
        self.init_epsilon()
        self.init_summaries()

        # Initialize the DQN graph, which contains two networks: the target network and the Q-network
        self.estimator = DQN(sess, config, self.environment.n_actions)

        # To initialize all variables
        self.init = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
        self.sess.run(self.init)

        self.saver = tf.train.Saver(max_to_keep=10)
        self.summary_writer = tf.summary.FileWriter(self.summary_dir,
                                                    self.sess.graph)

        if config.is_train and not config.cont_training:
            pass
        elif config.is_train and config.cont_training:
            self.load()
        elif config.is_play:
            self.load()
        else:
            raise Exception("Please Set proper mode for training or playing")

    def load(self):
        latest_checkpoint = tf.train.latest_checkpoint(self.checkpoint_dir)
        if latest_checkpoint:
            print("Loading model checkpoint {}...\n".format(latest_checkpoint))
            self.saver.restore(self.sess, latest_checkpoint)

    def save(self):
        self.saver.save(self.sess, self.checkpoint_dir,
                        self.global_step_tensor)

    def init_dirs(self):
        # Create directories for checkpoints and summaries
        self.checkpoint_dir = os.path.join(self.config.experiment_dir,
                                           "checkpoints/")
        self.summary_dir = os.path.join(self.config.experiment_dir,
                                        "summaries/")

    def init_cur_epsiode(self):
        """Create cur episode tensor to totally save the process of the training"""
        with tf.variable_scope('cur_episode'):
            self.cur_episode_tensor = tf.Variable(-1,
                                                  trainable=False,
                                                  name='cur_epsiode')
            self.cur_epsiode_input = tf.placeholder('int32',
                                                    None,
                                                    name='cur_episode_input')
            self.cur_episode_assign_op = self.cur_episode_tensor.assign(
                self.cur_epsiode_input)

    def init_global_step(self):
        """Create a global step variable to be a reference to the number of iterations"""
        with tf.variable_scope('step'):
            self.global_step_tensor = tf.Variable(0,
                                                  trainable=False,
                                                  name='global_step')
            self.global_step_input = tf.placeholder('int32',
                                                    None,
                                                    name='global_step_input')
            self.global_step_assign_op = self.global_step_tensor.assign(
                self.global_step_input)

    def init_epsilon(self):
        """Create an epsilon variable"""
        with tf.variable_scope('epsilon'):
            self.epsilon_tensor = tf.Variable(self.config.initial_epsilon,
                                              trainable=False,
                                              name='epsilon')
            self.epsilon_input = tf.placeholder('float32',
                                                None,
                                                name='epsilon_input')
            self.epsilon_assign_op = self.epsilon_tensor.assign(
                self.epsilon_input)

    def init_summaries(self):
        """Create the summary part of the graph"""
        with tf.variable_scope('summary'):
            self.summary_placeholders = {}
            self.summary_ops = {}
            self.scalar_summary_tags = [
                'episode.total_reward', 'episode.length',
                'evaluation.total_reward', 'evaluation.length', 'epsilon'
            ]
            for tag in self.scalar_summary_tags:
                self.summary_placeholders[tag] = tf.placeholder('float32',
                                                                None,
                                                                name=tag)
                self.summary_ops[tag] = tf.summary.scalar(
                    tag, self.summary_placeholders[tag])

    def init_replay_memory(self):
        # Populate the replay memory with initial experience
        print("initializing replay memory...")

        state = self.environment.reset()
        for i in itertools.count():
            action = self.take_action(state)
            next_state, reward, done = self.observe_and_save(
                state, self.environment.valid_actions[action])
            if done:
                if self.config.prm:
                    if i >= self.config.prm_init_size:
                        break
                else:
                    if i >= self.config.replay_memory_init_size:
                        break
                state = self.environment.reset()
            else:
                state = next_state
        print("finished initializing replay memory")

    def policy_fn(self, fn_type, estimator, n_actions):
        """Function that contain definitions to various number of policy functions and choose between them"""
        def epsilon_greedy(sess, observation, epsilon):
            actions = np.ones(n_actions, dtype=float) * epsilon / n_actions
            q_values = estimator.predict(np.expand_dims(observation, 0))[0]
            best_action = np.argmax(q_values)
            actions[best_action] += (1.0 - epsilon)
            return actions

        def greedy(sess, observation):
            q_values = estimator.predict(np.expand_dims(observation, 0),
                                         type="target")[0]
            best_action = np.argmax(q_values)
            return best_action

        if fn_type == 'epsilon_greedy':
            return epsilon_greedy
        elif fn_type == 'greedy':
            return greedy
        else:
            raise Exception("Please Select a proper policy function")

    def take_action(self, state):
        """Take the action based on the policy function"""
        action_probs = self.policy(self.sess, state,
                                   self.epsilon_tensor.eval(self.sess))
        action = np.random.choice(np.arange(len(action_probs)), p=action_probs)
        return action

    def observe_and_save(self, state, action):
        """Function that observe the new state , reward and save it in the memory"""
        next_state, reward, done = self.environment.step(action)
        self.memory.push(state, next_state, action, reward, done)
        return next_state, reward, done

    def update_target_network(self):
        """Update Target network By copying paramter between the two networks in DQN"""
        self.estimator.update_target_network()

    def add_summary(self, summaries_dict, step):
        """Add the summaries to tensorboard"""
        summary_list = self.sess.run(
            [self.summary_ops[tag] for tag in summaries_dict.keys()], {
                self.summary_placeholders[tag]: value
                for tag, value in summaries_dict.items()
            })
        for summary in summary_list:
            self.summary_writer.add_summary(summary, step)
        self.summary_writer.flush()

    def train_episodic(self):
        """Train the agent in episodic techniques"""

        # Initialize the epsilon decay step, the policy function, and the replay memory
        self.epsilon_step = (
            self.config.initial_epsilon -
            self.config.final_epsilon) / self.config.exploration_steps
        self.policy = self.policy_fn(self.config.policy_fn, self.estimator,
                                     self.environment.n_actions)
        self.init_replay_memory()

        for cur_episode in range(
                self.cur_episode_tensor.eval(self.sess) + 1,
                self.config.num_episodes, 1):

            # Save the current checkpoint
            self.save()

            # Update the Cur Episode tensor
            self.cur_episode_assign_op.eval(
                session=self.sess,
                feed_dict={
                    self.cur_epsiode_input:
                    self.cur_episode_tensor.eval(self.sess) + 1
                })

            # Evaluate now to see how the agent behaves
            if cur_episode % self.config.evaluate_every == 0:
                self.evaluate(cur_episode // self.config.evaluate_every)

            state = self.environment.reset()
            total_reward = 0

            # Take steps in the environment until the terminal state of the episode
            for t in itertools.count():

                # Update the Global step
                self.global_step_assign_op.eval(
                    session=self.sess,
                    feed_dict={
                        self.global_step_input:
                        self.global_step_tensor.eval(self.sess) + 1
                    })

                # time to update the target estimator
                if self.global_step_tensor.eval(
                        self.sess
                ) % self.config.update_target_estimator_every == 0:
                    self.update_target_network()

                # Calculate epsilon for this time step,
                # then take an action, observe, and save the transition
                self.epsilon_assign_op.eval(
                    {
                        self.epsilon_input:
                        max(
                            self.config.final_epsilon,
                            self.epsilon_tensor.eval(self.sess) -
                            self.epsilon_step)
                    }, self.sess)
                action = self.take_action(state)
                next_state, reward, done = self.observe_and_save(
                    state, self.environment.valid_actions[action])

                # Sample a minibatch from the replay memory
                if self.config.prm:
                    indices_batch, weights_batch, state_batch, next_state_batch, action_batch, reward_batch, done_batch = self.memory.sample(
                    )
                else:
                    state_batch, next_state_batch, action_batch, reward_batch, done_batch = self.memory.get_batch(
                        self.config.batch_size)

                # Calculate targets, then compute the loss
                q_values_next = self.estimator.predict(next_state_batch,
                                                       type="target")
                targets_batch = reward_batch + np.invert(done_batch).astype(
                    np.float32) * self.config.discount_factor * np.amax(
                        q_values_next, axis=1)

                if self.config.prm:
                    _ = self.estimator.update(state_batch, action_batch,
                                              targets_batch, weights_batch)
                else:
                    _ = self.estimator.update(state_batch, action_batch,
                                              targets_batch)

                total_reward += reward

                if done:  # terminal state reached, so exit the episode
                    # Add summaries to tensorboard
                    summaries_dict = {
                        'episode.total_reward': total_reward,
                        'episode.length': t,
                        'epsilon': self.epsilon_tensor.eval(self.sess)
                    }
                    self.add_summary(summaries_dict,
                                     self.global_step_tensor.eval(self.sess))
                    break

                state = next_state

        print("Training Finished")

    def train_continous(self):
        # TODO implement on global step only
        pass

    def play(self, n_episode=10):
        """Function that play greedily on the policy learnt"""
        # Play Greedily
        self.policy = self.policy_fn('greedy', self.estimator,
                                     self.environment.n_actions)

        for cur_episode in range(n_episode):

            state = self.environment.reset()
            total_reward = 0

            for t in itertools.count():

                best_action = self.policy(self.sess, state)
                next_state, reward, done = self.environment.step(
                    self.environment.valid_actions[best_action])

                total_reward += reward

                if done:
                    print("Total Reward in Epsiode " + str(cur_episode) +
                          " = " + str(total_reward))
                    print("Total Length in Epsiode " + str(cur_episode) +
                          " = " + str(t))
                    break

                state = next_state

    def evaluate(self, local_step):

        print('evaluation #{0}'.format(local_step))

        policy = self.policy_fn('greedy', self.estimator,
                                self.evaluation_enviroment.n_actions)

        for cur_episode in range(self.config.evaluation_episodes):

            state = self.evaluation_enviroment.reset()
            total_reward = 0

            for t in itertools.count():

                best_action = policy(self.sess, state)
                next_state, reward, done = self.evaluation_enviroment.step(
                    self.evaluation_enviroment.valid_actions[best_action])

                total_reward += reward

                if done:
                    # Add summaries to tensorboard
                    summaries_dict = {
                        'evaluation.total_reward': total_reward,
                        'evaluation.length': t
                    }
                    self.add_summary(summaries_dict,
                                     local_step * 5 + cur_episode)
                    break

                state = next_state

        print('Finished evaluation #{0}'.format(local_step))
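
The Agent in Code Example #8 is driven entirely by a config object and two environments; nothing in the excerpt shows how it is launched. A hypothetical minimal driver, only to illustrate how the pieces fit together (make_env and any config fields beyond is_train / is_play are assumptions based on the attributes accessed above):

import tensorflow as tf

def run(config, make_env):
    # Build a training environment and a separate evaluation environment,
    # create the agent inside a session, then train or play per the config.
    env = make_env()
    eval_env = make_env()
    with tf.Session() as sess:
        agent = Agent(sess, config, env, eval_env)
        if config.is_train:
            agent.train_episodic()
        elif config.is_play:
            agent.play(n_episode=10)
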