Example #1
 def test_summarizer(self):
     # Bulk Tests
     with tf.Graph().as_default():
         x = tf.placeholder("float", [None, 4])
         W = tf.Variable(tf.random_normal([4, 4]))
         x = tf.nn.tanh(tf.matmul(x, W))
         tf.add_to_collection(tf.GraphKeys.ACTIVATIONS, x)
         import tflearn.helpers.summarizer as s
         s.summarize_variables([W])
         s.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS))
         s.summarize(x, 'histogram', "test_summary")
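The snippet above only builds the summary ops; nothing is written to disk. A minimal sketch of how the same helpers could be evaluated and exported for TensorBoard (TF1-style session API; the log directory and feed values are illustrative):

    import tensorflow as tf
    import tflearn.helpers.summarizer as s

    with tf.Graph().as_default():
        x = tf.placeholder("float", [None, 4])
        W = tf.Variable(tf.random_normal([4, 4]))
        y = tf.nn.tanh(tf.matmul(x, W))
        s.summarize(y, 'histogram', "tanh_activation")
        merged = tf.summary.merge_all()  # the tflearn helper registers a standard tf summary, so merge_all picks it up
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            writer = tf.summary.FileWriter('/tmp/summarizer_demo', sess.graph)
            summary = sess.run(merged, feed_dict={x: [[0.1, 0.2, 0.3, 0.4]]})
            writer.add_summary(summary, global_step=0)
            writer.close()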
Example #3
    def train(self, tf_record_dir, save_dir, num_epochs, batch_size, num_save_every,
              model_file=None, early_stop=False, full_eval_every=0, learning_rate=1e-3,
              lr_steps=0, lr_decay=0.96, avg=False):
        """Trains the Model defined in module on the records in tf_record_dir"""
        train_loss_iterations = {'iteration': [], 'epoch': [], 'train_loss': [], 'train_dice': [],
                                 'train_mse': [], 'val_loss': [], 'val_dice': [], 'val_mse': []}
        meta_data_filepath = os.path.join(tf_record_dir, 'meta_data.txt')
        with open(meta_data_filepath, 'r') as f:
            meta_data = json.load(f)
        num_examples = sum([x[1] for x in meta_data['train_examples'].items()])
        num_batches_per_epoch = num_examples//batch_size
        num_batches = math.ceil(num_epochs*num_batches_per_epoch) if num_epochs else 0
        num_full_validation_every = full_eval_every if full_eval_every else num_batches_per_epoch
        validation_tups = meta_data['validation_tups']
        full_validation_metrics = {k[0]: [] for k in validation_tups}
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
            write_csv(os.path.join(save_dir, 'log.csv'), train_loss_iterations)
            write_csv(os.path.join(save_dir, 'full_validation_log.csv'), full_validation_metrics)
        with tf.Graph().as_default():
            config = tf.compat.v1.ConfigProto()
            config.gpu_options.allow_growth = True
            tflearn.config.init_training_mode()
            with tf.compat.v1.name_scope('training'):
                model = self.module.Model(batch_size, False, tf_record_dir, num_epochs)
            with tf.compat.v1.name_scope('eval'):
                model_eval = self.module.Model(batch_size, True, tf_record_dir,
                                               num_epochs) if validation_tups else None
            model_test = self.module.Model(batch_size, True)
            inferer = model_test.build_full_inferer()
            avg_time = 0
            global_step = model.global_step  # tf.Variable(0, name='global_step', trainable=False)
            lr = model.lr  # tf.train.exponential_decay(learning_rate, global_step, lr_steps, lr_decay, staircase=True) if lr_steps else learning_rate
            optimizer = model.optimizer  # e.g. tf.train.AdamOptimizer(lr)
            tf.compat.v1.summary.scalar('learning_rate', lr)
            update_op = self._avg(model.loss_op, optimizer, global_step) if avg else optimizer.minimize(model.loss_op,
                                                                                                        global_step=global_step)

            with tf.compat.v1.Session(config=config) as sess:
                merged = s.summarize_variables()
                merged = tf.compat.v1.summary.merge_all()
                summary_writer = tf.compat.v1.summary.FileWriter(save_dir, sess.graph)
                sess.run(tf.compat.v1.local_variables_initializer())
                sess.run(tf.compat.v1.global_variables_initializer())
                saver = tf.compat.v1.train.Saver()
                saver_epoch = tf.compat.v1.train.Saver(max_to_keep=None)
                coord = tf.train.Coordinator()
                if model_file:
                    if avg:
                        variable_averages = tf.train.ExponentialMovingAverage(0.999)
                        vars_to_restore = variable_averages.variables_to_restore()
                        vars_to_restore = model.filter_vars(vars_to_restore)
                        restore_saver = (tf.compat.v1.train.Saver(vars_to_restore)
                                         if vars_to_restore else tf.compat.v1.train.Saver())
                        restore_saver.restore(sess, model_file)
                    else:
                        vars_to_restore = model.filter_vars(tf.compat.v1.global_variables())
                        restore_saver = (tf.compat.v1.train.Saver(vars_to_restore)
                                         if vars_to_restore else tf.compat.v1.train.Saver())
                        restore_saver.restore(sess, model_file)
                try:
                    continue_training = True
                    epoch = previous_epoch = 0
                    while continue_training and not coord.should_stop():
                        tflearn.is_training(True)
                        cur_step = sess.run(global_step)
                        previous_epoch = epoch
                        epoch = cur_step//num_batches_per_epoch
                        start = time.time()
                        if epoch != previous_epoch: saver = tf.compat.v1.train.Saver()
                        message_string = ''
                        if num_save_every and cur_step % num_save_every == 0:
                            tflearn.is_training(False)
                            start = time.time()
                            (train_loss, train_dice, val_loss, val_dice, summaries,
                             train_mse, val_mse) = sess.run([model.loss_op,
                                                             model.dice_op,
                                                             model_eval.loss_op,
                                                             model_eval.dice_op,
                                                             merged, model.mse,
                                                             model_eval.mse,
                                                             ] +
                                                            model_eval.metric_update_ops +
                                                            model.metric_update_ops)[0:7]
                            summary_writer.add_summary(summaries, cur_step)
                            message_string = (' val_loss: {:.3f}, val_dice: {:.3f}, mse: {:.5f}'
                                              .format(val_loss, val_dice, val_mse))
                            non_zero_losses = [x for x in train_loss_iterations['val_loss']
                                               if x is not None]
                            non_zero_dices = [x for x in train_loss_iterations['val_dice']
                                              if x is not None]
                            non_zero_val_mses = [x for x in train_loss_iterations['val_mse']
                                                 if x is not None and x > 0]
                            should_save = (val_loss < np.percentile(non_zero_losses, 25)
                                           if non_zero_losses else True)
                            should_save = (should_save and val_dice > np.percentile(non_zero_dices, 75)
                                           if non_zero_dices else True)
                            min_mse = np.min(non_zero_val_mses) if non_zero_val_mses else 0
                            should_save = should_save and min_mse > 0 and val_mse < min_mse
                            train_loss_iterations['val_loss'].append(val_loss)
                            train_loss_iterations['val_dice'].append(val_dice)
                            train_loss_iterations['val_mse'].append(val_mse)
                            if should_save:
                                checkpoint_path = os.path.join(save_dir, 'epoch_'
                                                               + str(epoch) + '_model.ckpt')
                                saver.save(sess, checkpoint_path, global_step=global_step)
                                print("model saved to {}".format(checkpoint_path))
                            non_zero_val_mses = [x for x in train_loss_iterations['val_mse']
                                                 if x is not None and x > 0]
                            if early_stop:
                                continue_training = (np.median(non_zero_val_mses[-10:])
                                                     <= np.min(non_zero_val_mses[:-10])
                                                     if non_zero_val_mses[:-10] and epoch > 0
                                                     else True)
                        else:
                            train_loss_iterations['val_loss'].append(None)
                            train_loss_iterations['val_dice'].append(None)
                            train_loss_iterations['val_mse'].append(None)
                            train_loss, train_dice, train_mse = sess.run([model.loss_op,
                                                                          model.dice_op,
                                                                          model.mse
                                                                          ])[0:3]
                        train_loss_iterations['iteration'].append(cur_step)
                        train_loss_iterations['epoch'].append(epoch)
                        train_loss_iterations['train_loss'].append(train_loss)
                        train_loss_iterations['train_dice'].append(train_dice)
                        train_loss_iterations['train_mse'].append(train_mse)
                        write_csv(os.path.join(save_dir, 'log.csv'),
                                  {k:([v[-1]] if v else [])
                                   for k, v in train_loss_iterations.items()},
                                  mode='a', header=False)
                        end = time.time()
                        avg_time = avg_time + ((end-start)-avg_time)/cur_step if cur_step else end - start
                        time_remaining_s = (num_batches - cur_step)*avg_time if num_epochs else 0
                        t = timedelta(seconds=time_remaining_s)
                        time_remaining_string = ("time left: {}m-{}d {} (h:mm:ss)"
                                                 .format(t.days//30, t.days % 30,
                                                         timedelta(seconds=t.seconds)))
                        message_string = ("{}/{} (epoch {}), train_loss = {:.3f}, dice = {:.3f}, accuracy = {:.3f}, time/batch = {:.3f}, "
                                          .format(cur_step, num_batches, epoch, train_loss, train_dice,
                                                  0, end - start) + time_remaining_string +
                                          message_string)
                        print(message_string)
                        if all([num_full_validation_every,
                                cur_step % num_full_validation_every == 0,
                                cur_step]):
                            tflearn.is_training(False)
                            start = time.time()
                            test_save_dir = os.path.join(save_dir, 'val_preds_' + str(cur_step))
                            if not os.path.exists(test_save_dir): os.makedirs(test_save_dir)
                            for tup in validation_tups:
                                dice = inferer(sess, tup, test_save_dir, model_test)
                                full_validation_metrics[tup[0]].append(dice)
                                print(dice)
                            end = time.time()
                            avg_time = avg_time + ((end-start)/num_full_validation_every
                                                   - avg_time)/cur_step if cur_step else avg_time
                            checkpoint_path = os.path.join(save_dir, 'epoch_model.ckpt')
                            saver_epoch.save(sess, checkpoint_path, global_step=global_step)
                            print("epoch model saved to {}".format(checkpoint_path))
                            write_csv(os.path.join(save_dir, 'full_validation_log.csv'),
                                      {k:([v[-1]] if v else [])
                                       for k, v in full_validation_metrics.items()},
                                      mode='a',
                                      header=False)
                        # Run optimiser after saving/evaluating the model
                        sess.run(update_op)
                except tf.errors.OutOfRangeError:
                    print('Done training')
                finally:
                    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
                    saver.save(sess, checkpoint_path, global_step=global_step)
                    coord.request_stop()
                print('Finished')
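The self._avg(...) call above (used when avg=True) and the matching ExponentialMovingAverage(0.999) restore path suggest the update op couples the optimizer step with a moving average of the weights. The helper itself is not shown here; a hedged sketch of what it might look like, with the 0.999 decay taken from the restore code and everything else assumed:

    def _avg(self, loss_op, optimizer, global_step, decay=0.999):
        """Hypothetical sketch: gradient step followed by an EMA update of the weights."""
        ema = tf.train.ExponentialMovingAverage(decay, num_updates=global_step)
        minimize_op = optimizer.minimize(loss_op, global_step=global_step)
        with tf.control_dependencies([minimize_op]):
            # apply the moving-average update only after the gradient step has run
            return ema.apply(tf.compat.v1.trainable_variables())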
Example #4
    def create_network_graph(self):
        input_shape = self._input_shape
        output_num = self._output_num
        # Input placeholders
        with tf.name_scope('input'):
            # we need to convert the input shape from (batch, channels, height, width)
            # to TensorFlow's expected (batch, height, width, channels)
            self._t_x_input_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            self._t_x_input = tf.cast(tf.transpose(self._t_x_input_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            self._t_x_input_tp1_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input-tp1')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            self._t_x_input_tp1 = tf.cast(tf.transpose(self._t_x_input_tp1_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            self._t_x_actions = tf.placeholder(tf.int32, shape=[None], name='x-actions')
            self._t_x_rewards = tf.placeholder(tf.float32, shape=[None], name='x-rewards')
            self._t_x_terminals = tf.placeholder(tf.bool, shape=[None], name='x-terminals')
            self._t_x_discount = self._q_discount

        # Target network does not reuse variables
        with tf.variable_scope('network') as var_scope:
            self._t_network_output = self._network_generator(self._t_x_input, output_num)

            # get the trainable variables for this network, later used to overwrite target network vars
            self._tf_network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

            # summarize activations
            summarizer.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='network'))

            # if double DQN then we need to create network output for s_tp1
            if self.algorithm_type == 'double' or self.algorithm_type == 'doublenstep':
                var_scope.reuse_variables()
                self._t_network_output_tp1 = self._network_generator(self._t_x_input_tp1, output_num)

            # summarize a histogram of each action output
            for output_ind in range(output_num):
                summarizer.summarize(self._t_network_output[:, output_ind], 'histogram', 'network-output/{0}'.format(output_ind))

            # add network summaries
            summarizer.summarize_variables(train_vars=self._tf_network_trainables)

        with tf.variable_scope('target-network'):
            self._t_target_network_output = self._network_generator(self._t_x_input_tp1, output_num)

            # get trainables for target network, used in assign op for the update target network step
            target_network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='target-network')

        # update target network with network variables
        with tf.name_scope('update-target-network'):
            self._tf_update_target_network_ops = [target_v.assign(v) for v, target_v in zip(self._tf_network_trainables, target_network_trainables)]

        # if double DQN, convenience function to get target values for the online action
        if self.algorithm_type == 'double' or self.algorithm_type == 'doublenstep':
            with tf.name_scope('double_target'):
                # Target = target_Q(s_tp1, argmax(online_Q(s_tp1)))
                argmax_tp1 = tf.argmax(self._t_network_output_tp1, axis=1)
                self._t_target_value_online_action = tf_util.one_hot(self._t_target_network_output, argmax_tp1, output_num)

        # calculate QLoss
        with tf.name_scope('loss'):
            # nstep rewards are calculated outside the gpu/graph because it requires a loop
            if self.algorithm_type != 'nstep' and self.algorithm_type != 'doublenstep':
                with tf.name_scope('estimated-reward-tp1'):
                    if self.algorithm_type == 'double':
                        # Target = target_Q(s_tp1, argmax(online_Q(s_tp1)))
                        target = self._t_target_value_online_action
                    elif self.algorithm_type == 'dqn':
                        # Target = max(target_Q(s_tp1))
                        target = tf.reduce_max(self._t_target_network_output, axis=1)

                    # compute a mask that returns gamma (discount factor) or 0 if terminal
                    terminal_discount_mask = tf.multiply(1.0 - tf.cast(self._t_x_terminals, tf.float32), self._t_x_discount)
                    est_rew_tp1 = tf.multiply(terminal_discount_mask, target)

                y = self._t_x_rewards + tf.stop_gradient(est_rew_tp1)
            # else nstep
            else:
                y = self._t_x_rewards

            with tf.name_scope('estimated-reward'):
                est_rew = tf_util.one_hot(self._t_network_output, self._t_x_actions, output_num)

            with tf.name_scope('qloss'):
                # clip loss but keep linear past clip bounds (huber loss with customizable linear part)
                # REFS: https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py#L108
                # https://github.com/Jabberwockyll/deep_rl_ale/blob/master/q_network.py#L241
                diff = y - est_rew

                if self._loss_clipping > 0.0:
                    abs_diff = tf.abs(diff)
                    # same as min(abs_diff, loss_clipping) because abs_diff can never be negative (definition of abs value)
                    quadratic_part = tf.clip_by_value(abs_diff, 0.0, self._loss_clipping)
                    linear_part = abs_diff - quadratic_part
                    loss = (0.5 * tf.square(quadratic_part)) + (self._loss_clipping * linear_part)
                else:
                    # But why multiply the loss by 0.5 when not clipping? https://groups.google.com/forum/#!topic/deep-q-learning/hKK0ZM_OWd4
                    loss = 0.5 * tf.square(diff)
                # NOTICE: we are summing gradients
                error = tf.reduce_sum(loss)
            summarizer.summarize(error, 'scalar', 'loss')

        # optimizer
        with tf.name_scope('shared-optimizer'):
            self._tf_learning_rate = tf.placeholder(tf.float32)
            optimizer = self._optimizer_fn(learning_rate=self._tf_learning_rate)
            # only train the network vars not the target network
            gradients = optimizer.compute_gradients(error, var_list=self._tf_network_trainables)
            # gradients are stored as a tuple, (gradient, tensor the gradient corresponds to)
            # kinda lame that clip by global norm doesn't accept the list of tuples returned from compute_gradients
            # so we unzip then zip
            tensors = [tensor for gradient, tensor in gradients]
            grads = [gradient for gradient, tensor in gradients]
            clipped_gradients, _ = tf.clip_by_global_norm(grads, self.global_norm_clipping)  # returns list[tensors], norm
            # materialize the zip so both apply_gradients and the summarizer can consume it (zip is a one-shot iterator in Python 3)
            clipped_grads_tensors = list(zip(clipped_gradients, tensors))
            self._tf_train_step = optimizer.apply_gradients(clipped_grads_tensors)
            # tflearn smartly knows how gradients are stored so we just pass in the list of tuples
            summarizer.summarize_gradients(clipped_grads_tensors)

            # tf learn auto merges all summaries so we just have to grab the last output
            self._tf_summaries = summarizer.summarize(self._tf_learning_rate, 'scalar', 'learning-rate')
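tf_util.one_hot is an external helper not shown in this example. Judging from how it is called here and from the inline pattern spelled out in the later examples (tf.one_hot, multiply, reduce-sum), a plausible sketch of it is:

    import tensorflow as tf

    def one_hot(q_values, indices, depth):
        # select q_values[i, indices[i]] per row: mask with a one-hot vector and
        # reduce-sum over the action dimension (sum, not max, since Q-values can be negative)
        mask = tf.one_hot(indices, depth=depth, on_value=1.0, off_value=0.0, dtype=tf.float32)
        return tf.reduce_sum(tf.multiply(q_values, mask), axis=1)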
Example #5
    def create_network_graph(self, input_shape, output_num, network_generator, q_discount, optimizer, loss_clipping):
        # Input placeholders
        with tf.name_scope('input'):
            # we need to convert the input shape from (batch, channels, height, width)
            # to TensorFlow's expected (batch, height, width, channels)
            x_input_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            x_input = tf.cast(tf.transpose(x_input_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            x_input_tp1_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input-tp1')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            x_input_tp1 = tf.cast(tf.transpose(x_input_tp1_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            x_actions = tf.placeholder(tf.int32, shape=[None], name='x-actions')
            x_rewards = tf.placeholder(tf.float32, shape=[None], name='x-rewards')
            x_terminals = tf.placeholder(tf.bool, shape=[None], name='x-terminals')
            x_discount = q_discount

        # Target network does not reuse variables, so we use two different variable scopes
        with tf.variable_scope('network'):
            network_output = network_generator(x_input, output_num)

            # summarize a histogram of each action output
            for output_ind in range(output_num):
                summarizer.summarize(network_output[:, output_ind], 'histogram', 'network-output/{0}'.format(output_ind))

            # get the trainable variables for this network, later used to overwrite target network vars
            network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

            # summarize activations
            summarizer.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='network'))

            # add network summaries
            summarizer.summarize_variables(train_vars=network_trainables)

        with tf.variable_scope('target-network'):
            target_network_output = network_generator(x_input_tp1, output_num)

            # get trainables for target network, used in assign op for the update target network step
            target_network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='target-network')

            # summarize activations
            summarizer.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='target-network'))

            # add network summaries
            summarizer.summarize_variables(train_vars=target_network_trainables)

        # update target network with network variables
        with tf.name_scope('update-target-network'):
            update_target_network_ops = [target_v.assign(v) for v, target_v in zip(network_trainables, target_network_trainables)]

        # calculate QLoss
        with tf.name_scope('loss'):
            with tf.name_scope('estimated-reward-tp1'):
                one_minus_term = tf.multiply(1.0 - tf.cast(x_terminals, tf.float32), x_discount)
                est_rew_tp1 = tf.multiply(one_minus_term, tf.reduce_max(target_network_output, axis=1))

            y = x_rewards + tf.stop_gradient(est_rew_tp1)

            with tf.name_scope('estimated-reward'):
                # Because of https://github.com/tensorflow/tensorflow/issues/206
                # we cannot use numpy-like indexing, so we convert to a one-hot,
                # multiply, then reduce over the last dim
                # NumPy/Theano: est_rew = network_output[:, x_actions]
                x_actions_one_hot = tf.one_hot(x_actions, depth=output_num, name='one-hot',
                                               on_value=1.0, off_value=0.0, dtype=tf.float32)
                # we reduce-sum here because the output could be negative, so we can't take the max;
                # the other indices will be 0
                est_rew = tf.reduce_sum(tf.multiply(network_output, x_actions_one_hot), axis=1)

            with tf.name_scope('qloss'):
                # clip loss but keep linear past clip bounds
                # REFS: https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py#L108
                # https://github.com/Jabberwockyll/deep_rl_ale/blob/master/q_network.py#L241
                diff = y - est_rew

                if loss_clipping > 0.0:
                    abs_diff = tf.abs(diff)
                    # same as min(abs_diff, loss_clipping) because abs_diff can never be negative (definition of abs value)
                    quadratic_part = tf.clip_by_value(abs_diff, 0.0, loss_clipping)
                    linear_part = abs_diff - quadratic_part
                    loss = (0.5 * tf.square(quadratic_part)) + (loss_clipping * linear_part)
                else:
                    # But why multiply the loss by 0.5 when not clipping? https://groups.google.com/forum/#!topic/deep-q-learning/hKK0ZM_OWd4
                    loss = 0.5 * tf.square(diff)
                # NOTICE: we are summing gradients
                error = tf.reduce_sum(loss)
            summarizer.summarize(error, 'scalar', 'loss')

        # optimizer
        with tf.name_scope('shared-optimizer'):
            tf_learning_rate = tf.placeholder(tf.float32)
            optimizer = optimizer(learning_rate=tf_learning_rate)
            # only train the network vars not the target network
            tf_train_step = optimizer.minimize(error, var_list=network_trainables)

            # tf learn auto merges all summaries so we just have to grab the last output
            tf_summaries = summarizer.summarize(tf_learning_rate, 'scalar', 'learning-rate')

        # function to get network output
        def get_output(sess, state):
            feed_dict = {x_input_channel_firstdim: state}
            return sess.run([network_output], feed_dict=feed_dict)

        # function to run a single training step
        def train_step(sess, current_learning_rate, state, action, reward, state_tp1, terminal, summaries=False):
            feed_dict = {x_input_channel_firstdim: state, x_input_tp1_channel_firstdim: state_tp1,
                         x_actions: action, x_rewards: reward, x_terminals: terminal,
                         tf_learning_rate: current_learning_rate}
            if summaries:
                return sess.run([tf_summaries, tf_train_step], feed_dict=feed_dict)[0]
            else:
                return sess.run([tf_train_step], feed_dict=feed_dict)

        def update_target_net(sess):
            return sess.run([update_target_network_ops])

        self._get_output = get_output
        self._train_step = train_step
        self._update_target_network = update_target_net
        self.saver = tf.train.Saver(var_list=network_trainables)
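A brief sketch of how the closures stored above might be driven. Everything here is illustrative: `agent` stands for an instance of this class after create_network_graph has run, and the batch shapes assume an 84x84, 4-channel input with 6 actions:

    import numpy as np
    import tensorflow as tf

    batch = 32
    states = np.zeros((batch, 4, 84, 84), dtype=np.uint8)       # (batch, channels, height, width)
    states_tp1 = np.zeros((batch, 4, 84, 84), dtype=np.uint8)
    actions = np.zeros(batch, dtype=np.int32)
    rewards = np.zeros(batch, dtype=np.float32)
    terminals = np.zeros(batch, dtype=np.bool_)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        agent._update_target_network(sess)              # copy online weights into the target net
        q_values = agent._get_output(sess, states)      # forward pass used for action selection
        agent._train_step(sess, 2.5e-4, states, actions, rewards,
                          states_tp1, terminals, summaries=False)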
Example #6
    def create_network_graph(self):
        input_shape = self._input_shape
        output_num = self._output_num
        # Input placeholders
        with tf.name_scope('input'):
            # we need to convert the input shape from (batch, channels, height, width)
            # to TensorFlow's expected (batch, height, width, channels)
            self._t_x_input_channel_firstdim = tf.placeholder(
                tf.uint8, [None] + input_shape, name='x-input')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            self._t_x_input = tf.cast(
                tf.transpose(self._t_x_input_channel_firstdim,
                             perm=[0, 2, 3, 1]), tf.float32) / 255.0
            self._t_x_input_tp1_channel_firstdim = tf.placeholder(
                tf.uint8, [None] + input_shape, name='x-input-tp1')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            self._t_x_input_tp1 = tf.cast(
                tf.transpose(self._t_x_input_tp1_channel_firstdim,
                             perm=[0, 2, 3, 1]), tf.float32) / 255.0
            self._t_x_actions = tf.placeholder(tf.int32,
                                               shape=[None],
                                               name='x-actions')
            self._t_x_rewards = tf.placeholder(tf.float32,
                                               shape=[None],
                                               name='x-rewards')
            self._t_x_terminals = tf.placeholder(tf.bool,
                                                 shape=[None],
                                                 name='x-terminals')
            self._t_x_discount = self._q_discount

        # Target network does not reuse variables
        with tf.variable_scope('network') as var_scope:
            self._t_network_output = self._network_generator(
                self._t_x_input, output_num)

            # get the trainable variables for this network, later used to overwrite target network vars
            self._tf_network_trainables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

            # summarize activations
            summarizer.summarize_activations(
                tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='network'))

            # if double DQN then we need to create network output for s_tp1
            if self.algorithm_type == 'double' or self.algorithm_type == 'doublenstep':
                var_scope.reuse_variables()
                self._t_network_output_tp1 = self._network_generator(
                    self._t_x_input_tp1, output_num)

            # summarize a histogram of each action output
            for output_ind in range(output_num):
                summarizer.summarize(self._t_network_output[:, output_ind],
                                     'histogram',
                                     'network-output/{0}'.format(output_ind))

            # add network summaries
            summarizer.summarize_variables(
                train_vars=self._tf_network_trainables)

        with tf.variable_scope('target-network'):
            self._t_target_network_output = self._network_generator(
                self._t_x_input_tp1, output_num)

            # get trainables for target network, used in assign op for the update target network step
            target_network_trainables = tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='target-network')

        # update target network with network variables
        with tf.name_scope('update-target-network'):
            self._tf_update_target_network_ops = [
                target_v.assign(v) for v, target_v in zip(
                    self._tf_network_trainables, target_network_trainables)
            ]

        # if double DQN, convenience function to get target values for the online action
        if self.algorithm_type == 'double' or self.algorithm_type == 'doublenstep':
            with tf.name_scope('double_target'):
                # Target = target_Q(s_tp1, argmax(online_Q(s_tp1)))
                argmax_tp1 = tf.argmax(self._t_network_output_tp1, axis=1)
                self._t_target_value_online_action = tf_util.one_hot(
                    self._t_target_network_output, argmax_tp1, output_num)

        # calculate QLoss
        with tf.name_scope('loss'):
            # nstep rewards are calculated outside the gpu/graph because it requires a loop
            if self.algorithm_type != 'nstep' and self.algorithm_type != 'doublenstep':
                with tf.name_scope('estimated-reward-tp1'):
                    if self.algorithm_type == 'double':
                        # Target = target_Q(s_tp1, argmax(online_Q(s_tp1)))
                        target = self._t_target_value_online_action
                    elif self.algorithm_type == 'dqn':
                        # Target = max(target_Q(s_tp1))
                        target = tf.reduce_max(self._t_target_network_output,
                                               axis=1)

                    # compute a mask that returns gamma (discount factor) or 0 if terminal
                    terminal_discount_mask = tf.multiply(
                        1.0 - tf.cast(self._t_x_terminals, tf.float32),
                        self._t_x_discount)
                    est_rew_tp1 = tf.multiply(terminal_discount_mask, target)

                y = self._t_x_rewards + tf.stop_gradient(est_rew_tp1)
            # else nstep
            else:
                y = self._t_x_rewards

            with tf.name_scope('estimated-reward'):
                est_rew = tf_util.one_hot(self._t_network_output,
                                          self._t_x_actions, output_num)

            with tf.name_scope('qloss'):
                # clip loss but keep linear past clip bounds (huber loss with customizable linear part)
                # REFS: https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py#L108
                # https://github.com/Jabberwockyll/deep_rl_ale/blob/master/q_network.py#L241
                diff = y - est_rew

                if self._loss_clipping > 0.0:
                    abs_diff = tf.abs(diff)
                    # same as min(abs_diff, loss_clipping) because abs_diff can never be negative (definition of abs value)
                    quadratic_part = tf.clip_by_value(abs_diff, 0.0,
                                                      self._loss_clipping)
                    linear_part = abs_diff - quadratic_part
                    loss = (0.5 * tf.square(quadratic_part)) + (
                        self._loss_clipping * linear_part)
                else:
                    # But why multiply the loss by 0.5 when not clipping? https://groups.google.com/forum/#!topic/deep-q-learning/hKK0ZM_OWd4
                    loss = 0.5 * tf.square(diff)
                # NOTICE: we are summing gradients
                error = tf.reduce_sum(loss)
            summarizer.summarize(error, 'scalar', 'loss')

        # optimizer
        with tf.name_scope('shared-optimizer'):
            self._tf_learning_rate = tf.placeholder(tf.float32)
            optimizer = self._optimizer_fn(
                learning_rate=self._tf_learning_rate)
            # only train the network vars not the target network
            gradients = optimizer.compute_gradients(
                error, var_list=self._tf_network_trainables)
            # gradients are stored as a tuple, (gradient, tensor the gradient corresponds to)
            # kinda lame that clip by global norm doesn't accept the list of tuples returned from compute_gradients
            # so we unzip then zip
            tensors = [tensor for gradient, tensor in gradients]
            grads = [gradient for gradient, tensor in gradients]
            clipped_gradients, _ = tf.clip_by_global_norm(
                grads,
                self.global_norm_clipping)  # returns list[tensors], norm
            # materialize the zip so both apply_gradients and the summarizer can
            # consume it (zip is a one-shot iterator in Python 3)
            clipped_grads_tensors = list(zip(clipped_gradients, tensors))
            self._tf_train_step = optimizer.apply_gradients(
                clipped_grads_tensors)
            # tflearn smartly knows how gradients are stored so we just pass in the list of tuples
            summarizer.summarize_gradients(clipped_grads_tensors)

            # tf learn auto merges all summaries so we just have to grab the last output
            self._tf_summaries = summarizer.summarize(self._tf_learning_rate,
                                                      'scalar',
                                                      'learning-rate')
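The qloss block above is a Huber-style loss: quadratic for residuals inside the clip bound, linear beyond it. A small standalone NumPy check of that piecewise behaviour (the residual values are arbitrary):

    import numpy as np

    def clipped_loss(diff, clip=1.0):
        # mirrors the qloss block: 0.5*d**2 inside [-clip, clip], clip*(|d| - 0.5*clip) outside
        abs_diff = np.abs(diff)
        quadratic_part = np.clip(abs_diff, 0.0, clip)
        linear_part = abs_diff - quadratic_part
        return 0.5 * np.square(quadratic_part) + clip * linear_part

    print(clipped_loss(np.array([0.5, 1.0, 3.0])))  # values: 0.125, 0.5, 2.5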
Example #7
    def create_network_graph(self, input_shape, output_num, network_generator, q_discount, optimizer, loss_clipping):
        # Input placeholders
        with tf.name_scope('input'):
            # we need to convert the input shape from (batch, channels, height, width)
            # to TensorFlow's expected (batch, height, width, channels)
            x_input_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            x_input = tf.cast(tf.transpose(x_input_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            x_input_tp1_channel_firstdim = tf.placeholder(tf.uint8, [None] + input_shape, name='x-input-tp1')
            # transpose because tf wants channels on last dim and channels are passed in on 2nd dim
            x_input_tp1 = tf.cast(tf.transpose(x_input_tp1_channel_firstdim, perm=[0, 2, 3, 1]), tf.float32) / 255.0
            x_actions = tf.placeholder(tf.int32, shape=[None], name='x-actions')
            # TODO: SARSA only change
            x_actions_tp1 = tf.placeholder(tf.int32, shape=[None], name='x-actions-tp1')
            x_rewards = tf.placeholder(tf.float32, shape=[None], name='x-rewards')
            x_terminals = tf.placeholder(tf.bool, shape=[None], name='x-terminals')
            x_discount = q_discount

        # Target network does not reuse variables, so we use two different variable scopes
        with tf.variable_scope('network'):
            network_output = network_generator(x_input, output_num)

            # summarize a histogram of each action output
            for output_ind in range(output_num):
                summarizer.summarize(network_output[:, output_ind], 'histogram', 'network-output/{0}'.format(output_ind))

            # get the trainable variables for this network, later used to overwrite target network vars
            network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='network')

            # summarize activations
            summarizer.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='network'))

            # add network summaries
            summarizer.summarize_variables(train_vars=network_trainables)

        with tf.variable_scope('target-network'):
            target_network_output = network_generator(x_input_tp1, output_num)

            # get trainables for target network, used in assign op for the update target network step
            target_network_trainables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='target-network')

            # summarize activations
            summarizer.summarize_activations(tf.get_collection(tf.GraphKeys.ACTIVATIONS, scope='target-network'))

            # add network summaries
            summarizer.summarize_variables(train_vars=target_network_trainables)

        # update target network with network variables
        with tf.name_scope('update-target-network'):
            update_target_network_ops = [target_v.assign(v) for v, target_v in zip(network_trainables, target_network_trainables)]

        # calculate QLoss
        with tf.name_scope('loss'):
            with tf.name_scope('estimated-reward-tp1'):
                one_minus_term = tf.multiply(1.0 - tf.cast(x_terminals, tf.float32), x_discount)
                # TODO: SARSA only change
                # SARSA uses the Q estimate of the next state for action_tp1, not the max.
                # We must convert to a one-hot, same as below
                # NumPy/Theano: est_rew_tp1 = target_network_output[:, x_actions_tp1]
                x_actions_tp1_one_hot = tf.one_hot(x_actions_tp1, depth=output_num, name='one-hot-tp1',
                                                   on_value=1.0, off_value=0.0, dtype=tf.float32)
                # we reduce-sum here because the output could be negative, so we can't take the max;
                # the other indices will be 0
                network_est_rew_tp1 = tf.reduce_sum(tf.multiply(target_network_output, x_actions_tp1_one_hot), axis=1)
                est_rew_tp1 = tf.multiply(one_minus_term, network_est_rew_tp1)

            y = x_rewards + tf.stop_gradient(est_rew_tp1)

            with tf.name_scope('estimated-reward'):
                # Because of https://github.com/tensorflow/tensorflow/issues/206
                # we cannot use numpy-like indexing, so we convert to a one-hot,
                # multiply, then reduce over the last dim
                # NumPy/Theano: est_rew = network_output[:, x_actions]
                x_actions_one_hot = tf.one_hot(x_actions, depth=output_num, name='one-hot',
                                               on_value=1.0, off_value=0.0, dtype=tf.float32)
                # we reduce-sum here because the output could be negative, so we can't take the max;
                # the other indices will be 0
                est_rew = tf.reduce_sum(tf.multiply(network_output, x_actions_one_hot), axis=1)

            with tf.name_scope('qloss'):
                # clip loss but keep linear past clip bounds
                # REFS: https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/q_network.py#L108
                # https://github.com/Jabberwockyll/deep_rl_ale/blob/master/q_network.py#L241
                diff = y - est_rew

                if loss_clipping > 0.0:
                    abs_diff = tf.abs(diff)
                    # same as min(abs_diff, loss_clipping) because abs_diff can never be negative (definition of abs value)
                    quadratic_part = tf.clip_by_value(abs_diff, 0.0, loss_clipping)
                    linear_part = abs_diff - quadratic_part
                    loss = (0.5 * tf.square(quadratic_part)) + (loss_clipping * linear_part)
                else:
                    # But why multiply the loss by 0.5 when not clipping? https://groups.google.com/forum/#!topic/deep-q-learning/hKK0ZM_OWd4
                    loss = 0.5 * tf.square(diff)
                # NOTICE: we are summing gradients
                error = tf.reduce_sum(loss)
            summarizer.summarize(error, 'scalar', 'loss')

        # optimizer
        with tf.name_scope('shared-optimizer'):
            tf_learning_rate = tf.placeholder(tf.float32)
            optimizer = optimizer(learning_rate=tf_learning_rate)
            # only train the network vars not the target network
            tf_train_step = optimizer.minimize(error, var_list=network_trainables)

            # tf learn auto merges all summaries so we just have to grab the last output
            tf_summaries = summarizer.summarize(tf_learning_rate, 'scalar', 'learning-rate')

        # function to get network output
        def get_output(sess, state):
            feed_dict = {x_input_channel_firstdim: state}
            return sess.run([network_output], feed_dict=feed_dict)

        # function to run a single training step
        def train_step(sess, current_learning_rate, state, action, reward, state_tp1, action_tp1, terminal, summaries=False):
            feed_dict = {x_input_channel_firstdim: state, x_input_tp1_channel_firstdim: state_tp1,
                         x_actions: action, x_actions_tp1: action_tp1, x_rewards: reward, x_terminals: terminal,
                         # TODO: SARSA only change action_tp1
                         tf_learning_rate: current_learning_rate}
            if summaries:
                return sess.run([tf_summaries, tf_train_step], feed_dict=feed_dict)[0]
            else:
                return sess.run([tf_train_step], feed_dict=feed_dict)

        def update_target_net(sess):
            return sess.run([update_target_network_ops])

        self._get_output = get_output
        self._train_step = train_step
        self._update_target_network = update_target_net
        self.saver = tf.train.Saver(var_list=network_trainables)
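All of the graphs above take a network_generator callable that maps the NHWC input tensor to output_num Q-values, and they expect its layers to register trainable variables and activations in the usual collections. A minimal hypothetical generator using tflearn layers (the layer sizes are illustrative, not taken from the original repos):

    import tflearn

    def network_generator(x_input, output_num):
        # small convolutional Q-network; tflearn layers mark their weights as trainable
        # and (in the versions these examples target) add their outputs to
        # tf.GraphKeys.ACTIVATIONS, which the summarize_activations calls above rely on
        net = tflearn.conv_2d(x_input, 32, 8, strides=4, activation='relu')
        net = tflearn.conv_2d(net, 64, 4, strides=2, activation='relu')
        net = tflearn.fully_connected(net, 256, activation='relu')
        return tflearn.fully_connected(net, output_num, activation='linear')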