import random
from collections import deque

import numpy as np
from keras.layers import Dense
from keras.models import Sequential
from keras.optimizers import Adam

# NOTE: Memory is the project's prioritized-replay buffer (SumTree-based);
# its import path depends on the original project and is not reproduced here.


class A2CAgent:
    def __init__(self,
                 replay_size,
                 memory_size=10000,
                 prioritized=False,
                 load_models=False,
                 actor_model_file='',
                 critic_model_file='',
                 is_eval=False):
        self.state_size = 2
        self.action_size = 3
        self.step = 0
        self.replay_size = replay_size
        self.replay_queue = deque(maxlen=self.replay_size)
        self.memory_size = memory_size
        self.prioritized = prioritized
        if self.prioritized:
            self.memory = Memory(capacity=memory_size)

        # Hyperparameters for learning
        self.value_size = 1
        self.layer_size = 16
        self.discount_factor = 0.99
        self.actor_learning_rate = 0.0005
        self.critic_learning_rate = 0.005
        self.is_eval = is_eval

        # Create actor and critic neural networks
        self.actor = self.build_actor()
        self.critic = self.build_critic()
        #self.actor.summary()

        if load_models:
            if actor_model_file:
                self.actor.load_weights(actor_model_file)
            if critic_model_file:
                self.critic.load_weights(critic_model_file)

    # The actor takes a state and outputs probabilities of each possible action
    def build_actor(self):

        layer1 = Dense(self.layer_size,
                       input_dim=self.state_size,
                       activation='relu',
                       kernel_initializer='he_uniform')
        layer2 = Dense(self.layer_size,
                       input_dim=self.layer_size,
                       activation='relu',
                       kernel_initializer='he_uniform')
        # Use softmax activation so that the sum of probabilities of the actions becomes 1
        layer3 = Dense(self.action_size,
                       activation='softmax',
                       kernel_initializer='he_uniform')  # self.action_size = 3

        actor = Sequential(layers=[layer1, layer2, layer3])

        # Print a summary of the network
        actor.summary()

        # We use categorical crossentropy loss since we have a probability distribution
        actor.compile(loss='categorical_crossentropy',
                      optimizer=Adam(lr=self.actor_learning_rate))
        return actor

    # The critic takes a state and outputs the predicted value of the state
    def build_critic(self):

        layer1 = Dense(self.layer_size,
                       input_dim=self.state_size,
                       activation='relu',
                       kernel_initializer='he_uniform')
        layer2 = Dense(self.layer_size,
                       input_dim=self.layer_size,
                       activation='relu',
                       kernel_initializer='he_uniform')
        layer3 = Dense(self.value_size,
                       activation='linear',
                       kernel_initializer='he_uniform')  # self.value_size = 1

        critic = Sequential(layers=[layer1, layer2, layer3])

        # Print a summary of the network
        critic.summary()

        critic.compile(loss='mean_squared_error',
                       optimizer=Adam(lr=self.critic_learning_rate))
        return critic

    def act(self, state):
        # Get probabilities for each action
        policy = self.actor.predict(np.array([state]), batch_size=1).flatten()

        # During training, sample an action from the policy distribution
        if not self.is_eval:
            return np.random.choice(self.action_size, 1, p=policy).take(0)
        else:
            return np.argmax(policy)  # greedy action for evaluation

    def store_transition(self, s, a, r, s_, dd):
        if self.prioritized:  # prioritized replay
            transition = np.hstack((s, [a, r], s_, dd))
            # new transitions are stored with the highest priority so they are replayed at least once
            self.memory.store(transition)
        else:
            #self.replay_queue.append((s, [a, r], s_, dd))
            transition = np.hstack((s, [a, r], s_, dd))
            self.replay_queue.append(transition)

    def expReplay(self, batch_size=64, lr=1, factor=0.95):
        if self.prioritized:
            tree_idx, batch_memory, ISWeights = self.memory.sample(batch_size)
        else:
            batch_memory = random.sample(self.replay_queue, batch_size)

        s_prevBatch = np.array([replay[[0, 1]] for replay in batch_memory])
        a = np.array([replay[[2]] for replay in batch_memory])
        r = np.array([replay[[3]] for replay in batch_memory])
        s_currBatch = np.array([replay[[4, 5]] for replay in batch_memory])
        d = np.array([replay[[6]] for replay in batch_memory])

        td_error = np.zeros((d.shape[0], ), dtype=float)
        for i in range(d.shape[0]):
            q_prev = self.critic.predict(np.array([s_prevBatch[i, :]]))
            q_curr = self.critic.predict(np.array([s_currBatch[i, :]]))
            if int(d[i]) == 1:
                q_curr = r[i]
            q_realP = float(r[i] + factor * q_curr)
            advantages = np.zeros((1, self.action_size))
            advantages[0, int(a[i])] = q_realP - float(q_prev)

            if self.prioritized:
                td_error[i] = abs(advantages[0, int(a[i])])

            self.actor.fit(np.array([s_prevBatch[i, :]]),
                           advantages,
                           epochs=1,
                           verbose=0)
            self.critic.fit(np.array([s_prevBatch[i, :]]),
                            np.array([[q_realP]]),
                            epochs=1,
                            verbose=0)

        if self.prioritized:
            self.memory.batch_update(tree_idx, td_error)
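
# Usage sketch (not part of the original example): a minimal training loop for
# A2CAgent, assuming OpenAI Gym's MountainCar-v0 and the old gym step/reset API,
# which match the hard-coded state_size=2 / action_size=3 above. The episode
# count and replay threshold are illustrative choices, not taken from the source.
if __name__ == '__main__':
    import gym

    env = gym.make('MountainCar-v0')
    agent = A2CAgent(replay_size=2000)

    for episode in range(300):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # transitions are stored as flat length-7 vectors: (s, a, r, s', done)
            agent.store_transition(state, action, reward, next_state, float(done))
            state = next_state
        # one replay update per episode once enough transitions are stored
        if len(agent.replay_queue) >= 64:
            agent.expReplay(batch_size=64)
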
Example #2
File: ddqn.py  Project: shivamsaboo17/DQN
if __name__ == '__main__':
    memory_size = 100000
    pretrain_length = 100000
    memory = Memory(memory_size)
    env = gym.make('CartPole-v1')
    state = env.reset()
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n
    print('Building randomized priority tree', end='')
    for i in range(pretrain_length):
        action = env.action_space.sample()
        next_state, reward, done, _ = env.step(action)
        state = np.reshape(state, [1, state_size])
        next_state = np.reshape(next_state, [1, state_size])
        memory.store((state, action, reward, next_state, done))
        state = next_state
        if done:
            state = env.reset()
    agent = DQN(state_size, action_size)
    done = False
    batch_size = 32
    EPISODES = 5000
    with tqdm(total=EPISODES) as pbar:
        for e in range(EPISODES):
            state = env.reset()
            state = np.reshape(state, [1, state_size])
            for time in range(500):
                env.render()
                action = agent.select_action(state)
                next_state, reward, done, _ = env.step(action)
import random
from collections import deque

import numpy as np
from tensorflow.keras import layers, models, optimizers

# NOTE: Memory is the project's prioritized-replay buffer; its import path
# depends on the original project and is not reproduced here.


class DoubleDQN(object):
    def __init__(self, replay_size, memory_size=10000, prioritized=False):
        self.step = 0
        self.replay_size = replay_size
        self.replay_queue = deque(maxlen=self.replay_size)
        self.memory_size = memory_size
        self.tau = 1e-2  #MountainCar-v0
        self.model = self.create_model()
        self.prioritized = prioritized
        self.target_model = self.create_model()
        self.target_model.set_weights(self.model.get_weights())
        if self.prioritized:
            self.memory = Memory(capacity=memory_size)

    def create_model(self):

        STATE_DIM, ACTION_DIM = 2, 3
        model = models.Sequential([
            layers.Dense(100, input_dim=STATE_DIM, activation='relu'),
            layers.Dense(ACTION_DIM, activation="linear")
        ])
        model.compile(loss='mean_squared_error',
                      optimizer=optimizers.Adam(0.001))
        return model

    def act(self, s, epsilon=0.1):

        # Epsilon-greedy: explore with a probability that decays as self.step grows
        if np.random.uniform() < epsilon - self.step * 0.0002:
            return np.random.choice([0, 1, 2])
        return np.argmax(self.model.predict(np.array([s]))[0])

    def save_model(self, file_path='MountainCar-v0-Ddqn.h5'):
        print('model saved')
        self.model.save(file_path)

    def store_transition(self, s, a, r, s_, dd):
        if self.prioritized:  # prioritized replay
            transition = np.hstack((s, [a, r], s_, dd))  # length-7 vector: (s, a, r, s_, done)
            # new transitions are stored with the highest priority so they are replayed at least once
            self.memory.store(transition)
        else:
            #self.replay_queue.append((s, [a, r], s_, dd))
            transition = np.hstack((s, [a, r], s_, dd))  # length-7 vector: (s, a, r, s_, done)
            self.replay_queue.append(transition)

    def expReplay(self, batch_size=64, lr=1, factor=0.95):

        if self.prioritized:
            tree_idx, batch_memory, ISWeights = self.memory.sample(batch_size)
        else:
            batch_memory = random.sample(self.replay_queue, batch_size)

        s_batch = np.array([replay[[0, 1]] for replay in batch_memory])
        a = np.array([replay[[2]] for replay in batch_memory])
        r = np.array([replay[[3]] for replay in batch_memory])
        next_s_batch = np.array([replay[[4, 5]] for replay in batch_memory])
        d = np.array([replay[[6]] for replay in batch_memory])

        Q = self.model.predict(s_batch)
        Q_next = self.model.predict(next_s_batch)
        Q_targ = self.target_model.predict(next_s_batch)

        # Double DQN update: the online network selects the best next action,
        # the target network evaluates it
        td_error = np.zeros((d.shape[0], ), dtype=float)
        for i in range(d.shape[0]):
            old_q = Q[i, int(a[i])]
            if int(d[i]) == 1:
                Q[i, int(a[i])] = r[i]
            else:
                next_best_action = np.argmax(Q_next[i, :])
                Q[i, int(a[i])] = r[i] + factor * Q_targ[i, next_best_action]

            if self.prioritized:
                td_error[i] = abs(old_q - Q[i, int(a[i])])

        if self.prioritized:
            self.memory.batch_update(tree_idx, td_error)

        self.model.fit(s_batch, Q, verbose=0)

    def transfer_weights(self):
        """ Transfer Weights from Model to Target at rate Tau
        """
        W = self.model.get_weights()
        tgt_W = self.target_model.get_weights()
        for i in range(len(W)):
            tgt_W[i] = self.tau * W[i] + (1 - self.tau) * tgt_W[i]
        self.target_model.set_weights(tgt_W)
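
# Usage sketch (not part of the original example): a minimal training loop for
# DoubleDQN, again assuming Gym's MountainCar-v0 and the old gym API (the
# network hard-codes STATE_DIM=2, ACTION_DIM=3). Episode count, batch size and
# update frequency are illustrative choices.
if __name__ == '__main__':
    import gym

    env = gym.make('MountainCar-v0')
    agent = DoubleDQN(replay_size=2000)

    for episode in range(200):
        s = env.reset()
        done = False
        while not done:
            a = agent.act(s)
            s_, r, done, _ = env.step(a)
            agent.store_transition(s, a, r, s_, float(done))
            s = s_
            agent.step += 1  # act() uses self.step to decay exploration
            if len(agent.replay_queue) >= 64:
                agent.expReplay(batch_size=64)
                agent.transfer_weights()  # soft update: tau*W + (1-tau)*W_target
    agent.save_model()
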
Example #4
File: QLearning.py  Project: czxttkl/X-AI
import glob
import multiprocessing
import os
import pickle
import time

import numpy
import numpy as np  # the class body uses both the `numpy` and `np` names
import psutil
import tensorflow as tf  # TF1-style API (tf.placeholder, tf.Session, ...)

# NOTE: Memory, TensorboardWriter and Logger are project-specific helpers from
# the czxttkl/X-AI repository; their import paths are not reproduced here.


class QLearning:
    def __init__(
        self,
        k,
        d,
        env_name,
        env_dir,
        env_fixed_xo,
        n_hidden,
        save_and_load_path,
        load,
        tensorboard_path,
        logger_path,
        learn_wall_time_limit,
        prioritized,
        trial_size,
        learning_rate=0.005,
        # we have a finite horizon, so we don't worry about reward explosion
        # see: https://goo.gl/Ew4629 (Other Prediction Problems and Update Rules)
        reward_decay=1.0,
        e_greedy=0.8,
        save_model_iter=5000,
        memory_capacity=300000,
        memory_capacity_start_learning=10000,
        batch_size=64,
        e_greedy_increment=0.0005,
        replace_target_iter=500,
        planning=False,
        random_seed=None,
    ):
        self.env_name = env_name
        self.env, self.n_features, self.n_actions = self.get_env(
            env_name, env_dir, env_fixed_xo, k, d)
        self.save_and_load_path = save_and_load_path
        self.load = load

        self.path_check(load)

        # create a graph for model variables and session
        self.graph = tf.Graph()
        self.sess = tf.Session(graph=self.graph)

        if not load:
            self.random_seed = random_seed
            numpy.random.seed(self.random_seed)
            tf.set_random_seed(self.random_seed)

            self.tensorboard_path = tensorboard_path
            self.logger_path = logger_path
            self.tb_writer = TensorboardWriter(
                folder_name=self.tensorboard_path, session=self.sess)
            self.logger = Logger(self.logger_path)

            self.n_hidden = n_hidden
            self.lr = learning_rate
            self.gamma = reward_decay
            self.epsilon_max = e_greedy
            self.save_model_iter = save_model_iter
            self.memory_capacity = memory_capacity
            self.memory_capacity_start_learning = memory_capacity_start_learning
            self.learn_wall_time_limit = learn_wall_time_limit
            self.batch_size = batch_size
            self.epsilon_increment = e_greedy_increment
            self.epsilon = 0 if e_greedy_increment is not None else self.epsilon_max
            self.prioritized = prioritized  # decide to use prioritized experience replay or not
            self.trial_size = trial_size
            self.replace_target_iter = replace_target_iter
            self.planning = planning  # decide to use planning for additional learning

            with self.graph.as_default():
                self._build_net()
                self.sess.run(tf.global_variables_initializer())
            self.memory = Memory(prioritized=self.prioritized,
                                 capacity=self.memory_capacity,
                                 n_features=self.n_features,
                                 n_actions=self.n_actions,
                                 batch_size=self.batch_size,
                                 planning=self.planning,
                                 qsa_feature_extractor=self.env.step_state,
                                 qsa_feature_extractor_for_all_acts=self.env.all_possible_next_states)
            self.learn_iterations = 0
            self.learn_wall_time = 0.
            self.sample_iterations = 0
            self.sample_wall_time = 0.
            self.last_cpu_time = 0.
            self.last_wall_time = 0.
            self.last_save_time = time.time()
            self.last_test_learn_iterations = 0
        else:
            self.load_model()

        self.memory_lock = multiprocessing.Lock()  # lock for memory modification

    def get_env(self, env_name, env_dir, env_fixed_xo, k, d):
        # n_actions: number of one-card modifications + 1 for not changing any card
        # n_features: input dimension to the q-learning network (x_o and x_p plus the time step as a feature)
        if env_name == 'env_nn':
            from environment.env_nn import Environment
        elif env_name == 'env_nn_noisy':
            from environment.env_nn_noisy import Environment
        elif env_name == 'env_greedymove':
            from environment.env_greedymove import Environment
        elif env_name == 'env_gamestate':
            from environment.env_gamestate import Environment
        else:
            raise ValueError('unknown environment name: {}'.format(env_name))

        if not env_dir:
            raise NotImplementedError(
                'we require that the environment has already been created')
        env = Environment.load(env_dir)
        n_features, n_actions = 2 * env.k + 1, env.d * (env.k - env.d) + 1

        return env, n_features, n_actions

    def path_check(self, load):
        save_and_load_path_dir = os.path.dirname(self.save_and_load_path)
        if load:
            assert os.path.exists(
                save_and_load_path_dir
            ), "model path does not exist: " + save_and_load_path_dir
        else:
            os.makedirs(save_and_load_path_dir, exist_ok=True)
            # remove old existing models if any
            files = glob.glob(save_and_load_path_dir + '/*')
            for file in files:
                os.remove(file)

    def save_model(self):
        # save tensorflow
        with self.graph.as_default():
            saver = tf.train.Saver()
            path = saver.save(self.sess, self.save_and_load_path)
        # save memory
        self.memory_lock.acquire()
        with open(self.save_and_load_path + '_memory.pickle', 'wb') as f:
            pickle.dump(self.memory, f, protocol=-1)  # -1: highest protocol
        self.memory_lock.release()
        # save variables
        with open(self.save_and_load_path + '_variables.pickle', 'wb') as f:
            pickle.dump(
                (self.random_seed, self.tensorboard_path, self.logger_path,
                 self.n_hidden, self.lr, self.gamma, self.epsilon_max,
                 self.save_model_iter, self.memory_capacity,
                 self.memory_capacity_start_learning,
                 self.learn_wall_time_limit, self.batch_size,
                 self.epsilon_increment, self.epsilon, self.prioritized,
                 self.trial_size, self.replace_target_iter, self.planning,
                 self.learn_iterations, self.sample_iterations,
                 self.learn_wall_time, self.sample_wall_time, self.cpu_time,
                 self.wall_time, self.last_test_learn_iterations),
                f,
                protocol=-1)
        self.last_save_time = time.time()
        print('save model to', path)

    def load_model(self):
        # load tensorflow
        with self.graph.as_default():
            saver = tf.train.import_meta_graph(self.save_and_load_path +
                                               '.meta')
            saver.restore(
                self.sess,
                tf.train.latest_checkpoint(
                    os.path.dirname(self.save_and_load_path)))
            # placeholders
            self.s = self.graph.get_tensor_by_name('s:0')  # Q(s,a) feature
            self.s_ = self.graph.get_tensor_by_name('s_:0')  # Q(s',a') feature
            self.rewards = self.graph.get_tensor_by_name('reward:0')  # reward
            self.terminal_weights = self.graph.get_tensor_by_name(
                'terminal:0')  # terminal
            # variables
            self.q_eval = self.graph.get_tensor_by_name('eval_net/q_eval:0')
            self.eval_w1 = self.graph.get_tensor_by_name('eval_net/l1/w1:0')
            self.eval_b1 = self.graph.get_tensor_by_name('eval_net/l1/b1:0')
            self.eval_w2 = self.graph.get_tensor_by_name('eval_net/l2/w2:0')
            self.eval_b2 = self.graph.get_tensor_by_name('eval_net/l2/b2:0')
            self.q_next = self.graph.get_tensor_by_name('eval_net/q_next:0')
            self.q_target = self.graph.get_tensor_by_name("q_target:0")
            self.is_weights = self.graph.get_tensor_by_name("is_weights:0")
            self.loss = self.graph.get_tensor_by_name("loss:0")
            self.abs_errors = self.graph.get_tensor_by_name("abs_errors:0")
            # operations
            self.train_op = self.graph.get_operation_by_name('train_op')
        # load memory
        with open(self.save_and_load_path + '_memory.pickle', 'rb') as f:
            self.memory = pickle.load(f)
        # load variables
        with open(self.save_and_load_path + '_variables.pickle', 'rb') as f:
            self.random_seed, \
            self.tensorboard_path, self.logger_path, \
            self.n_hidden, \
            self.lr, self.gamma, \
            self.epsilon_max, self.save_model_iter, \
            self.memory_capacity, self.memory_capacity_start_learning, \
            self.learn_wall_time_limit, self.batch_size, \
            self.epsilon_increment, self.epsilon, \
            self.prioritized, self.trial_size, \
            self.replace_target_iter, self.planning, \
            self.learn_iterations, \
            self.sample_iterations, \
            self.learn_wall_time, \
            self.sample_wall_time, \
            self.last_cpu_time, \
            self.last_wall_time, \
            self.last_test_learn_iterations = pickle.load(f)

        numpy.random.seed(self.random_seed)
        tf.set_random_seed(self.random_seed)

        self.tb_writer = TensorboardWriter(folder_name=self.tensorboard_path,
                                           session=self.sess)
        self.logger = Logger(self.logger_path)
        self.last_save_time = time.time()

    def _build_net(self):
        self.s = tf.placeholder(tf.float32, [None, self.n_features],
                                name='s')  # Q(s,a) feature
        self.s_ = tf.placeholder(tf.float32,
                                 [None, self.n_actions, self.n_features],
                                 name='s_')  # Q(s',a') feature
        self.rewards = tf.placeholder(tf.float32, [None],
                                      name='reward')  # reward
        self.terminal_weights = tf.placeholder(tf.float32, [None],
                                               name='terminal')  # terminal

        w_initializer, b_initializer = \
            tf.random_normal_initializer(0., 0.3), tf.constant_initializer(0.1)  # config of layers

        # ------------------ build evaluate_net ------------------
        with tf.variable_scope('eval_net'):
            # s is Q(s,a) feature, shape: (n_sample, n_features)
            # s_ Q(s',a') for all a' feature, shape: (n_sample, n_actions, n_features)
            with tf.variable_scope('l1'):
                self.eval_w1 = tf.get_variable(
                    'w1', [self.n_features, self.n_hidden],
                    initializer=w_initializer)
                self.eval_b1 = tf.get_variable('b1', [self.n_hidden],
                                               initializer=b_initializer)
                # l1 shape: (n_sample, n_hidden)
                l1 = tf.nn.relu(tf.matmul(self.s, self.eval_w1) + self.eval_b1)
                # l1_ shape: shape: (n_sample, n_actions, n_hidden)
                l1_ = tf.nn.relu(
                    tf.einsum('ijk,kh->ijh', self.s_, self.eval_w1) +
                    self.eval_b1)
            with tf.variable_scope('l2'):
                self.eval_w2 = tf.get_variable('w2', [self.n_hidden, 1],
                                               initializer=w_initializer)
                self.eval_b2 = tf.get_variable('b2', [1],
                                               initializer=b_initializer)
                # out shape: (n_sample, 1)
                out = tf.matmul(l1, self.eval_w2) + self.eval_b2
                # out_ shape: (n_sample, n_actions, 1), Q(s',a') for all a' feature
                out_ = tf.einsum('ijh,ho->ijo', l1_,
                                 self.eval_w2) + self.eval_b2
            self.q_eval = tf.squeeze(out, name='q_eval')
            self.q_next = tf.squeeze(out_, name='q_next')

        # ------------------ loss function ----------------------
        self.q_target = tf.add(
            self.rewards,
            self.terminal_weights *
            (self.gamma * tf.reduce_max(self.q_next, axis=1)),
            name='q_target')
        # We do not want the target to be used for computing the gradient
        self.q_target = tf.stop_gradient(self.q_target)
        # importance sampling weight
        self.is_weights = tf.placeholder(tf.float32, [None], name='is_weights')
        self.loss = tf.reduce_mean(
            self.is_weights *
            tf.squared_difference(self.q_target, self.q_eval),
            name='loss')
        self.abs_errors = tf.abs(self.q_target - self.q_eval,
                                 name='abs_errors')
        self.train_op = tf.train.RMSPropOptimizer(self.lr).minimize(
            self.loss, name='train_op')

    def store_transition(self, s, a, r, s_, terminal):
        self.memory_lock.acquire()
        # transition is a tuple (current_state, action, reward, next_state, whether_terminal)
        self.memory.store((s, a, r, s_, terminal))
        self.memory_lock.release()

    def update_memory_priority(self, exp_ids, abs_errors):
        """ update memory priority """
        self.memory_lock.acquire()
        self.memory.update_priority(exp_ids, abs_errors)
        self.memory_lock.release()

    def choose_action(self,
                      state,
                      next_possible_states,
                      next_possible_actions,
                      epsilon_greedy=True):
        pred_q_values = self.sess.run(self.q_eval,
                                      feed_dict={
                                          self.s: next_possible_states
                                      }).flatten()
        if not epsilon_greedy or np.random.uniform() < self.epsilon:
            # exploit: the action whose next state has the highest predicted Q value
            action_idx = np.argmax(pred_q_values)
        else:
            # explore: a uniformly random candidate action
            action_idx = np.random.choice(
                numpy.arange(len(next_possible_actions)))
        action = next_possible_actions[action_idx]
        pred_q_value = pred_q_values[action_idx]
        return action, pred_q_value

    # def _replace_target_params(self):
    #     with self.graph.as_default():
    #         t_params = tf.get_collection('target_net_params')
    #         e_params = tf.get_collection('eval_net_params')
    #         self.sess.run([tf.assign(t, e) for t, e in zip(t_params, e_params)])
    #         print('target_params_replaced')

    def planning_learn(self, qsa_next_features, qsa_features):
        """ additional learning from planning """
        raise NotImplementedError

    @property
    def cpu_time(self):
        cpu_time = psutil.Process().cpu_times()
        return cpu_time.user + cpu_time.system + cpu_time.children_system + cpu_time.children_user + self.last_cpu_time

    @property
    def wall_time(self):
        return time.time() - psutil.Process().create_time() + self.last_wall_time

    def learn(self):
        while True:
            if self.wall_time > self.learn_wall_time_limit:
                break

            if self.memory_size() < self.memory_capacity_start_learning:
                print('LEARN:{}:wait for more samples:wall time:{}'.format(
                    self.learn_iterations, self.wall_time))
                time.sleep(2)
                continue

            # don't learn too fast
            if self.learn_iterations > self.sample_iterations > 0:
                time.sleep(0.2)
                continue

            learn_time = time.time()
            qsa_feature, qsa_next_features, rewards, terminal_weights, is_weights, exp_ids \
                = self.memory.sample()

            _, loss, abs_errors = self.sess.run(
                [self.train_op, self.loss, self.abs_errors],
                feed_dict={
                    self.s: qsa_feature,
                    self.s_: qsa_next_features,
                    self.rewards: rewards,
                    self.terminal_weights: terminal_weights,
                    self.is_weights: is_weights
                })

            if self.prioritized:
                self.update_memory_priority(exp_ids, abs_errors)
                mem_total_p = self.memory.memory.tree.total_p
            else:
                mem_total_p = -1

            if self.planning:
                self.planning_learn()

            self.epsilon = self.cur_epsilon()

            learn_time = time.time() - learn_time
            self.learn_iterations += 1
            self.learn_wall_time += learn_time

            print(
                'LEARN:{}:mem_size:{}:virtual:{}:wall_t:{:.2f}:total:{:.2f}:cpu_time:{:.2f}:pid:{}:wall_t:{:.2f}:mem_p:{:.2f}'
                .format(self.learn_iterations, self.memory_size(),
                        self.memory_virtual_size(),
                        learn_time, self.learn_wall_time, self.cpu_time,
                        os.getpid(), self.wall_time, mem_total_p))

    def cur_epsilon(self):
        return self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max

    def tb_write(self, tags, values, step):
        """ write to tensorboard """
        if self.tb_writer:
            self.tb_writer.write(tags, values, step)

    def get_logger(self):
        return self.logger

    def memory_size(self):
        return self.memory.size

    def memory_virtual_size(self):
        return self.memory.virtual_size

    def function_call_counts_training(self):
        """ number of function calls during training, which equals to memory virtual size """
        return self.memory.virtual_size

    def collect_samples(self, EPISODE_SIZE, TEST_PERIOD):
        """ collect samples in a process """
        for i_episode in range(self.sample_iterations, EPISODE_SIZE):
            if self.wall_time > self.learn_wall_time_limit:
                self.save_model()
                break

            # don't sample too fast
            while 0 < self.learn_iterations < self.sample_iterations - 3:
                time.sleep(0.2)

            sample_wall_time = time.time()
            cur_state = self.env.reset()

            for i_episode_step in range(self.trial_size):
                # prevent wall time over limit during sampling
                if self.wall_time > self.learn_wall_time_limit:
                    self.save_model()
                    break

                # save every 6 min
                if time.time() - self.last_save_time > 6 * 60:
                    self.save_model()

                next_possible_states, next_possible_actions = self.env.all_possible_next_state_action(
                    cur_state)
                action, _ = self.choose_action(cur_state,
                                               next_possible_states,
                                               next_possible_actions,
                                               epsilon_greedy=True)
                cur_state_, reward = self.env.step(action)
                terminal = True if i_episode_step == self.trial_size - 1 else False
                self.store_transition(cur_state, action, reward, cur_state_,
                                      terminal)
                cur_state = cur_state_

            sample_wall_time = time.time() - sample_wall_time
            self.sample_iterations += 1
            self.sample_wall_time += sample_wall_time

            # end_state distilled output = reward (might be noisy)
            end_output = self.env.still(reward)
            mem_total_p = -1 if not self.prioritized else self.memory.memory.tree.total_p
            print(
                'SAMPLE:{}:finished output:{:.5f}:cur_epsilon:{:.5f}:mem_size:{}:virtual:{}:wall_t:{:.2f}:total:{:.2f}:pid:{}:wall_t:{:.2f}:mem_p:{:.2f}'
                .format(self.sample_iterations, end_output, self.cur_epsilon(),
                        self.memory_size(), self.memory_virtual_size(),
                        sample_wall_time, self.sample_wall_time, os.getpid(),
                        self.wall_time, mem_total_p))

            # test once in a while
            if self.memory_virtual_size() >= self.memory_capacity_start_learning \
                    and self.learn_iterations % TEST_PERIOD == 0 \
                    and self.learn_iterations > self.last_test_learn_iterations \
                    and self.wall_time < self.learn_wall_time_limit:
                #self.env.test(TRIAL_SIZE, RANDOM_SEED, self.learn_step_counter, self.wall_time, self.env_name,
                #               rl_model=self)
                max_val_rl, max_state_rl, end_val_rl, end_state_rl, duration_rl, _, _ = self.exp_test()

                max_val_mc, max_state_mc, _, _, duration_mc, _ = self.env.monte_carlo()
                self.logger.log_test(output_mc=max_val_mc,
                                     state_mc=max_state_mc,
                                     duration_mc=duration_mc,
                                     output_rl=max_val_rl,
                                     state_rl=max_state_rl,
                                     duration_rl=duration_rl,
                                     learn_step_counter=self.learn_iterations,
                                     wall_time=self.wall_time)

                self.tb_write(
                    tags=[
                        'Prioritized={0}, gamma={1}, seed={2}, env={3}, fixed_xo={4}/(Max_RL-MC)'
                        .format(self.prioritized, self.gamma, self.random_seed,
                                self.env_name, self.env.if_set_fixed_xo()),
                        'Prioritized={0}, gamma={1}, seed={2}, env={3}, fixed_xo={4}/Ending Output (RL)'
                        .format(self.prioritized, self.gamma, self.random_seed,
                                self.env_name, self.env.if_set_fixed_xo()),
                    ],
                    values=[max_val_rl - max_val_mc,
                            end_val_rl],  # note we record end value for RL
                    step=self.learn_iterations)

                self.last_test_learn_iterations = self.learn_iterations

    def exp_test(self, debug=True):
        """
        If debug is true, find the max output along the search.
        If debug is false, only return the end output
        """
        cur_state = self.env.reset()
        duration = time.time()
        start_state = cur_state.copy()
        end_output = max_output = -99999.
        max_state = None

        for i in range(self.trial_size):
            next_possible_states, next_possible_actions = self.env.all_possible_next_state_action(
                cur_state)
            action, q_val = self.choose_action(cur_state,
                                               next_possible_states,
                                               next_possible_actions,
                                               epsilon_greedy=False)
            if debug:
                # reward is noisy output
                cur_state, reward = self.env.step(action)
                # noiseless, stilled end output
                end_output = self.env.still(
                    self.env.output_noiseless(cur_state))
                print(
                    'TEST  :{}:output: {:.5f}, qval: {:.5f}, reward {:.5f}, at {}'
                    .format(i, end_output, q_val, reward, cur_state))
                if end_output > max_output:
                    max_output = end_output
                    max_state = cur_state.copy()
            else:
                cur_state = self.env.step_without_reward(action)
                print('TEST  :{}:qval: {:.5f}, at {}'.format(
                    i, q_val, cur_state))

        duration = time.time() - duration
        end_state = cur_state
        if not debug:
            end_output = self.env.still(self.env.output_noiseless(cur_state))

        if_set_fixed_xo = self.env.if_set_fixed_xo()

        return max_output, max_state, end_output, end_state, duration, if_set_fixed_xo, start_state

    # ad hoc methods to query the environment's information
    def set_env_fixed_xo(self, x_o):
        self.env.set_fixed_xo(x_o)

    def get_env_if_set_fixed_xo(self):
        return self.env.if_set_fixed_xo()

    def get_learn_iteration(self):
        return self.learn_iterations

    def get_wall_time(self):
        return self.wall_time
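
# Shape sketch (not from the source project): plain-numpy stand-ins showing how
# the two tf.einsum calls in QLearning._build_net score Q(s', a') for every
# candidate action in one pass. All sizes below are arbitrary.
import numpy as np

batch, n_actions, n_features, n_hidden = 4, 6, 9, 32

s_next = np.random.rand(batch, n_actions, n_features)  # one feature row per candidate action
w1 = np.random.rand(n_features, n_hidden)
b1 = np.zeros(n_hidden)
w2 = np.random.rand(n_hidden, 1)
b2 = np.zeros(1)

# (batch, n_actions, n_features) x (n_features, n_hidden) -> (batch, n_actions, n_hidden)
l1_ = np.maximum(np.einsum('ijk,kh->ijh', s_next, w1) + b1, 0.0)
# (batch, n_actions, n_hidden) x (n_hidden, 1) -> (batch, n_actions, 1), then squeeze
q_next = np.squeeze(np.einsum('ijh,ho->ijo', l1_, w2) + b2)
# max over candidate actions feeds the Bellman target r + gamma * max_a' Q(s', a')
q_target_part = q_next.max(axis=1)
print(l1_.shape, q_next.shape, q_target_part.shape)  # (4, 6, 32) (4, 6) (4,)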