Example #1
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        transitions: types.Transition = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(transitions.observation)
            q_t_value = self._target_network(transitions.next_observation)
            q_t_selector = self._network(transitions.next_observation)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(transitions.reward, q_tm1.dtype)
            if self._max_abs_reward:
                r_t = tf.clip_by_value(r_t, -self._max_abs_reward,
                                       self._max_abs_reward)
            d_t = tf.cast(transitions.discount, q_tm1.dtype) * tf.cast(
                self._discount, q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, transitions.action, r_t,
                                             d_t, q_t_value, q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            # Get the importance weights.
            importance_weights = 1. / probs  # [B]
            importance_weights **= self._importance_sampling_exponent
            importance_weights /= tf.reduce_max(importance_weights)

            # Reweight.
            loss *= tf.cast(importance_weights, loss.dtype)  # [B]
            loss = tf.reduce_mean(loss, axis=[0])  # []

        # Do a step of SGD.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        gradients, _ = tf.clip_by_global_norm(gradients,
                                              self._max_gradient_norm)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Get the priorities that we'll use to update.
        priorities = tf.abs(extra.td_error)

        # Periodically update the target network.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Report loss & statistics for logging.
        fetches = {
            'loss': loss,
            'keys': keys,
            'priorities': priorities,
        }

        return fetches
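For reference, the trfl.double_qlearning call above implements the standard double Q-learning target: the online network (q_t_selector) chooses the next action and the target network (q_t_value) evaluates it. A minimal TensorFlow 2 sketch of the equivalent TD-error computation (the function and variable names here are illustrative, not part of the example above):

import tensorflow as tf

def double_q_td_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector):
    """Double Q-learning TD error: q_t_selector picks the next action,
    q_t_value (the target network) evaluates it."""
    best_action = tf.argmax(q_t_selector, axis=1, output_type=tf.int32)  # [B]
    bootstrapped = tf.gather(q_t_value, best_action, batch_dims=1)       # [B]
    target = tf.stop_gradient(r_t + d_t * bootstrapped)                  # [B]
    qa_tm1 = tf.gather(q_tm1, a_tm1, batch_dims=1)                       # [B]
    return target - qa_tm1                                               # td_error, [B]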
Example #2
    def _build_model_for_training(self):
        inputs = self.create_inputs("main", **self.model_kwargs)
        model = self.create_model(inputs, **self.model_kwargs)
        model_vars = model.trainable_weights
        q = model.output

        with tf.name_scope('training'):
            # Input placeholders
            actions = tf.placeholder(tf.int32, (None, ), name="action")
            rewards = tf.placeholder(tf.float32, (None, ), name="reward")
            inputs_next = self.create_inputs("next", **self.model_kwargs)
            terminates = tf.placeholder(tf.bool, (None, ), name="terminate")

            # Target network
            target_model = self.create_model(inputs_next, **self.model_kwargs)
            target_vars = target_model.trainable_weights

            q_next = tf.stop_gradient(target_model.output)
            q_next_online_net = tf.stop_gradient(model(inputs_next))

            # Loss
            pcontinues = (1.0 - tf.to_float(terminates)) * self.gamma
            errors, _info = double_qlearning(q, actions, rewards, pcontinues,
                                             q_next, q_next_online_net)

            td_error = _info.td_error

            loss = K.mean(errors)
            optimizer = tf.train.AdamOptimizer(
                learning_rate=self.learning_rate)
            optimize_expr = optimizer.minimize(loss, var_list=model_vars)

            with tf.control_dependencies([optimize_expr]):
                optimize_expr = tf.group(
                    *[tf.assign(*a) for a in model.updates])

            # update_target_fn will be called periodically to copy Q network to target Q network
            update_target_expr = tf.group(*[
                var_target.assign(var)
                for var, var_target in zip(model_vars, target_vars)
            ])

        # Create callable functions
        train_fn = K.function(
            inputs + [actions, rewards, terminates] + inputs_next,
            outputs=[td_error],
            updates=[optimize_expr])

        act_fn = K.function(inputs=inputs, outputs=[K.argmax(q, axis=1)])
        q_fn = K.function(inputs=inputs, outputs=[q])
        update_fn = K.function([], [], updates=[update_target_expr])

        self._update_parameters = lambda: update_fn([])
        self._train = train_fn
        self._act = lambda x: act_fn([x])[0]
        self._q = lambda x: q_fn([x])[0]
        return model
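The pcontinues term above, (1 - terminate) * gamma, is what zeroes the bootstrapped value on terminal transitions. A tiny NumPy check of that arithmetic (numbers are illustrative):

import numpy as np

gamma = 0.99
terminates = np.array([False, True, False])
pcontinues = (1.0 - terminates.astype(np.float32)) * gamma
print(pcontinues)  # [0.99 0.   0.99] -- terminal steps contribute no bootstrapped value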
Example #3
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(o_tm1)
            q_t_value = self._target_network(o_t)
            q_t_selector = self._network(o_t)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(r_t, q_tm1.dtype)
            r_t = tf.clip_by_value(r_t, -1., 1.)
            d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount,
                                                      q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                             q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            # Get the importance weights.
            importance_weights = 1. / probs  # [B]
            importance_weights **= self._importance_sampling_exponent
            importance_weights /= tf.reduce_max(importance_weights)

            # Reweight.
            loss *= tf.cast(importance_weights, loss.dtype)  # [B]
            loss = tf.reduce_mean(loss, axis=[0])  # []

        # Do a step of SGD.
        gradients = tape.gradient(loss, self._network.trainable_variables)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Update the priorities in the replay buffer.
        if self._replay_client:
            priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
            self._replay_client.update_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                keys=keys,
                priorities=priorities)

        # Periodically update the target network.
        if tf.math.mod(self._num_steps, self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)
        self._num_steps.assign_add(1)

        # Report loss & statistics for logging.
        fetches = {
            'loss': loss,
        }

        return fetches
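The importance weights in Examples #1 and #3 follow the usual prioritized-replay correction (1/p_i)^beta, normalized by the largest weight. A small NumPy illustration with made-up sampling probabilities:

import numpy as np

probs = np.array([0.5, 0.25, 0.25])   # per-sample probabilities from the replay buffer
beta = 0.6                            # importance-sampling exponent
weights = (1.0 / probs) ** beta
weights /= weights.max()              # scale so the largest weight is 1
print(weights)                        # rarely sampled transitions get the largest weight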
Example #4
    def __init__(self,
                 name,
                 learning_rate=0.01,
                 state_size=4,
                 action_size=2,
                 hidden_size=10,
                 batch_size=20):

        with tf.variable_scope(name):
            self._inputs = tf.placeholder(tf.float32, [None, state_size],
                                          name='inputs')

            self._actions = tf.placeholder(tf.int32, [batch_size],
                                           name='actions')

            self.fc1 = tf.contrib.layers.fully_connected(
                self._inputs, hidden_size)
            self.fc2 = tf.contrib.layers.fully_connected(self.fc1, hidden_size)
            self.fc3 = tf.contrib.layers.fully_connected(self.fc2, hidden_size)
            self.fc4 = tf.contrib.layers.fully_connected(self.fc3, hidden_size)
            self.output = tf.contrib.layers.fully_connected(self.fc4,
                                                            action_size,
                                                            activation_fn=None)

            self.name = name

            self._targetQs = tf.placeholder(tf.float32,
                                            [batch_size, action_size],
                                            name='target')
            self.reward = tf.placeholder(tf.float32, [batch_size],
                                         name='reward')
            self.discount = tf.constant(0.99,
                                        shape=[batch_size],
                                        dtype=tf.float32,
                                        name='discount')

            # Note: the same online-network output (for the current state) is
            # reused here as q_t_selector; the target Q-values are fed in
            # through the self._targetQs placeholder.
            q_loss, q_learning = trfl.double_qlearning(
                self.output, self._actions, self.reward, self.discount,
                self._targetQs, self.output)
            self.loss = tf.reduce_mean(q_loss)
            self.opt = tf.train.AdamOptimizer(learning_rate).minimize(
                self.loss)
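A hypothetical driver for this constructor (the class name QNetwork, the session boilerplate, and the replay batch arrays states/actions/rewards/next_states are assumptions for illustration; the example above only defines the graph):

# Hypothetical usage sketch; assumes the __init__ above belongs to a class
# named QNetwork and that replay batches (states, actions, rewards,
# next_states) are available as NumPy arrays of matching shapes.
main_qn = QNetwork(name='main', batch_size=20)
target_qn = QNetwork(name='target', batch_size=20)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    target_qs = sess.run(target_qn.output,
                         feed_dict={target_qn._inputs: next_states})
    loss, _ = sess.run([main_qn.loss, main_qn.opt],
                       feed_dict={main_qn._inputs: states,
                                  main_qn._actions: actions,
                                  main_qn.reward: rewards,
                                  main_qn._targetQs: target_qs})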
Example #5
def q_learning(vision_model_dict, agent_model_dict, target_agent_model_dict,
               inputs, batch_size, kp_type, agent_size, mask_threshold,
               patch_sizes, kpt_encoder_type, mp_steps, img_size, lsp_layers,
               window_size, gamma, double_q, n_step_q):
    """
	:param vision_model_dict:
	:param agent_model_dict:
	:param target_agent_model_dict:
	:param inputs: bottom_up_kpt inputs [batch, T, dims]
	:param batch_size: (int)
	:param kp_type: (str) "transporter" or "permakey" type of keypoint used for bottom-up processing
	:param agent_size: (int) size of agent lstm
	:param mask_threshold: (float)
	:param patch_sizes: (int) size of patch size for "permakey" keypoints
	:param kpt_encoder_type: (str) "cnn" for conv-net "gnn" for graph-net
	:param mp_steps: (int) number of message-passing steps in GNNs
	:param img_size: (int) size of input image (H for H x H img)
	:param lsp_layers: (tuple) of layers for "permakey" keypoints
	:param window_size: (int) size of window used for recurrent q-learning
	:param gamma: (float) discount factor
	:param double_q: (bool) True if using double q-learning
	:param n_step_q: (int) 'n' value used for n-step q-learning
	:return:
	bottom_up_maps: keypoint gaussian masks
	bottom_up_features: bottom-up keypoint features
	"""

    # unpacking elements from sampled trajectories from buffer
    obses_tm1, a_tm1, r_t, dones = (inputs[0][0], inputs[0][1],
                                    inputs[0][2], inputs[0][3])

    obses_tm1 = tf.cast(obses_tm1,
                        dtype=tf.float32) / 255.0  # (batch, T, H, W)

    # reshaping obs tensor (batch, T, H, W, C) -> (batch*T, H, W, C)
    obses_tm1_shape = obses_tm1.shape
    obses_tm1 = tf.reshape(obses_tm1, [
        obses_tm1_shape[0] * obses_tm1_shape[1], obses_tm1_shape[2],
        obses_tm1_shape[3], obses_tm1_shape[4]
    ])

    # 1 single forward pass of kpt-module for T-steps of frames
    vis_forward_start = time.time()
    bottom_up_maps, encoder_features, kpt_centers = vision_forward_pass(
        obses_tm1, vision_model_dict, lsp_layers, kp_type, patch_sizes,
        img_size)

    # reshaping tensors from (b*T, ...) -> (b, T, ...)
    bup_map_shape = bottom_up_maps.shape
    bottom_up_maps = tf.reshape(bottom_up_maps, [
        obses_tm1_shape[0], obses_tm1_shape[1], bup_map_shape[1],
        bup_map_shape[2], bup_map_shape[3]
    ])
    enc_feat_shape = encoder_features.shape
    encoder_features = tf.reshape(encoder_features, [
        obses_tm1_shape[0], obses_tm1_shape[1], enc_feat_shape[1],
        enc_feat_shape[2], enc_feat_shape[3]
    ])
    kpt_c_shape = kpt_centers.shape
    kpt_centers = tf.reshape(kpt_centers, [
        obses_tm1_shape[0], obses_tm1_shape[1], kpt_c_shape[1], kpt_c_shape[2]
    ])

    # splitting outputs into 2 parts  targets = (1:T) and qs = (0:T-1)
    bottom_up_maps_tm1 = bottom_up_maps[:, n_step_q:-1, :, :, :]
    bottom_up_maps_t = bottom_up_maps[:, n_step_q + 1:, :, :, :]
    encoder_features_tm1 = encoder_features[:, n_step_q:-1, :, :, :]
    encoder_features_t = encoder_features[:, n_step_q + 1:, :, :, :]
    kpt_centers_tm1 = kpt_centers[:, n_step_q:-1, :, :]
    kpt_centers_t = kpt_centers[:, n_step_q + 1:, :, :]

    # collecting a_tm1, r_t and dones for n'th step bootstrapping
    a_tm1, r_t = tf.cast(a_tm1, dtype=tf.int32), tf.cast(r_t, dtype=tf.float32)
    a_tm1, r_t = a_tm1[:, n_step_q:-1, :], r_t[:, 0:-1, :]
    dones = tf.cast(dones, dtype=tf.float32)
    dones = dones[:, n_step_q + 1:, 1]  # dones for q_t's
    # switching batch and time axis to align all inputs i.e. (T, b, ..) -> (b, T, ..)
    a_tm1 = tf.transpose(a_tm1, perm=[1, 0, 2])
    dones = tf.transpose(dones, perm=[1, 0])

    # reshaping tensors again (ugh!) (b, T-1, ...) -> (b*(T-1), ...)
    bup_tm1_shape = bottom_up_maps_tm1.shape
    bottom_up_maps_tm1 = tf.reshape(
        bottom_up_maps_tm1,
        [-1, bup_tm1_shape[2], bup_tm1_shape[3], bup_tm1_shape[4]])
    bottom_up_maps_t = tf.reshape(bottom_up_maps_t, bottom_up_maps_tm1.shape)

    enc_tm1_shape = encoder_features_tm1.shape
    encoder_features_tm1 = tf.reshape(
        encoder_features_tm1,
        [-1, enc_tm1_shape[2], enc_tm1_shape[3], enc_tm1_shape[4]])
    encoder_features_t = tf.reshape(encoder_features_t,
                                    encoder_features_tm1.shape)

    kptc_tm1_shape = kpt_centers_tm1.shape
    kpt_centers_tm1 = tf.reshape(kpt_centers_tm1,
                                 [-1, kptc_tm1_shape[2], kptc_tm1_shape[3]])
    kpt_centers_t = tf.reshape(kpt_centers_t, kpt_centers_tm1.shape)

    # compute keypoint encodings
    kpts_features_tm1 = encode_keypoints(
        bottom_up_maps_tm1,
        encoder_features_tm1,
        kpt_centers_tm1,
        mask_threshold,
        kp_type,
        kpt_encoder_type,
        mp_steps,
        True,
        pos_net=agent_model_dict.get("pos_net"),
        kpt_encoder=agent_model_dict.get("kpt_encoder"),
        node_encoder=agent_model_dict.get(
            "node_enc"))  # passes none if not available

    kpts_features_t = encode_keypoints(
        bottom_up_maps_t,
        encoder_features_t,
        kpt_centers_t,
        mask_threshold,
        kp_type,
        kpt_encoder_type,
        mp_steps,
        True,
        pos_net=target_agent_model_dict.get("pos_net"),
        kpt_encoder=target_agent_model_dict.get("kpt_encoder"),
        node_encoder=target_agent_model_dict.get(
            "node_enc"))  # passes none if not available

    # reshaping back the time axis (b*T, dims) -> (b, T, dims)
    kpts_features_tm1 = tf.expand_dims(kpts_features_tm1, axis=1)
    kpts_tm1_shape = kpts_features_tm1.shape
    kpts_features_tm1 = tf.reshape(
        kpts_features_tm1, [batch_size, window_size, kpts_tm1_shape[-1]])

    kpts_features_t = tf.expand_dims(kpts_features_t, axis=1)
    kpts_t_shape = kpts_features_t.shape
    kpts_features_t = tf.reshape(kpts_features_t,
                                 [batch_size, window_size, kpts_t_shape[-1]])

    # RNN computation
    q_tm1_seq = []
    q_t_seq = []
    q_t_selector_seq = []

    # reset lstm state at start of update as in R-DQN random updates
    c_tm1 = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_tm1 = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_t_sel = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    c_t_sel = tf.Variable(tf.zeros((batch_size, agent_size)), trainable=True)
    h_t = tf.Variable(tf.zeros((batch_size, agent_size)),
                      trainable=False)  # td_targets
    c_t = tf.Variable(tf.zeros((batch_size, agent_size)),
                      trainable=False)  # td_targets
    rnn_unroll_start = time.time()

    # RNN unrolling
    for seq_idx in tf.range(window_size):
        s_tm1 = kpts_features_tm1[:, seq_idx, :]
        s_t = kpts_features_t[:, seq_idx, :]
        # double_q action selection step
        if double_q:
            q_t_selector, h_t_sel, c_t_sel = agent_model_dict["agent_net"](
                s_t, [h_t_sel, c_t_sel], training=True)
            q_t_selector_seq.append(q_t_selector)

        q_tm1, h_tm1, c_tm1 = agent_model_dict["agent_net"](s_tm1,
                                                            [h_tm1, c_tm1],
                                                            training=True)
        q_tm1_seq.append(q_tm1)
        q_t, h_t, c_t = target_agent_model_dict["agent_net"](s_t, [h_t, c_t],
                                                             training=False)
        q_t_seq.append(q_t)
    # print("RNN for loop unrolling took %s" % (time.time() - rnn_unroll_start))

    q_tm1 = tf.convert_to_tensor(q_tm1_seq, dtype=tf.float32)
    q_t = tf.convert_to_tensor(q_t_seq, dtype=tf.float32)

    # compute cumm. rew for 'n' steps
    if n_step_q > 1:
        discounts = tf.math.pow(gamma, tf.range(n_step_q, dtype=tf.float32))
        # slice r_t [b, T] into moving windows of [b, t-k, k]  # cumsum over k steps
        r_t = tf.transpose(r_t, perm=[1, 0, 2])
        r_t_sliced = tf.convert_to_tensor(
            [r_t[t:t + n_step_q, :, :] for t in range(window_size)],
            dtype=tf.float32)
        r_t_sliced = tf.squeeze(tf.transpose(r_t_sliced, perm=[0, 2, 1, 3]))
        r_t_sl_shape = r_t_sliced.shape
        # reshape (batch, T, n) -> (batch*T, n)
        r_t_sliced = tf.reshape(
            r_t_sliced, [r_t_sl_shape[0] * r_t_sl_shape[1], r_t_sl_shape[2]])
        # r_t_slices [T*batch, n_steps] x  discounts [n_steps, 1]
        r_t = tf.linalg.matvec(r_t_sliced, discounts)
        r_t = tf.reshape(r_t, [r_t_sl_shape[0], r_t_sl_shape[1]])

    # reshape again to make tensors compatible with trfl API
    q_tm1_shape = q_tm1.shape
    q_tm1 = tf.reshape(q_tm1,
                       [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
    q_t = tf.reshape(q_t, [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
    a_tm1_shape = a_tm1.shape
    a_tm1 = tf.squeeze(
        tf.reshape(a_tm1, [a_tm1_shape[0] * a_tm1_shape[1], a_tm1_shape[2]]))
    r_t_shape = r_t.shape
    r_t = tf.reshape(r_t, [r_t_shape[0] * r_t_shape[1]])
    dones_shape = dones.shape
    dones = tf.reshape(dones, [dones_shape[0] * dones_shape[1]])

    p_cont = 0.0
    if n_step_q == 1:
        # discount factor (at t=1) for bootstrapped value
        p_cont = tf.math.multiply(tf.ones((dones.shape)) - dones, gamma)
    elif n_step_q > 1:
        # discount factor (at t=n+1) accordingly for bootstrapped value
        p_cont = tf.math.multiply(
            tf.ones((dones.shape)) - dones, tf.math.pow(gamma, n_step_q))

    loss, extra = 0.0, None
    if not double_q:
        loss, extra = trfl.qlearning(q_tm1, a_tm1, r_t, p_cont, q_t)
    elif double_q:
        q_t_selector = tf.convert_to_tensor(q_t_selector_seq, dtype=tf.float32)
        q_t_selector = tf.reshape(
            q_t_selector, [q_tm1_shape[0] * q_tm1_shape[1], q_tm1_shape[2]])
        loss, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, p_cont, q_t,
                                            q_t_selector)

    # average over batch_dim = (batch*time)
    loss = tf.reduce_mean(loss, axis=0)
    # print("Inside q_learning bellman updates took %4.5f" % (time.time() - q_backup_start))
    return loss, extra
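The sliced-window matvec above (r_t_sliced times discounts) is just an n-step discounted reward sum per start index. The same arithmetic in a small NumPy sketch with made-up numbers:

import numpy as np

gamma, n_step_q = 0.99, 3
rewards = np.array([1.0, 0.0, 2.0, 1.0, 0.0])     # rewards along one trajectory
discounts = gamma ** np.arange(n_step_q)          # [1, gamma, gamma^2]

# One window of n_step_q rewards per valid start index, then a dot product
# with the discount vector, mirroring tf.linalg.matvec above.
windows = np.stack([rewards[t:t + n_step_q]
                    for t in range(len(rewards) - n_step_q + 1)])
n_step_returns = windows @ discounts
print(n_step_returns)   # e.g. element 0 is r_0 + gamma*r_1 + gamma^2*r_2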
Example #6
def main(unused_argv):
    '''
    check path
    '''
    if FLAGS.data_dir == '' or not os.path.exists(FLAGS.data_dir):
        raise ValueError('invalid data directory {}'.format(FLAGS.data_dir))

    if FLAGS.output_dir == '':
        raise ValueError('invalid output directory {}'.format(
            FLAGS.output_dir))
    elif not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)

    event_log_dir = os.path.join(FLAGS.output_dir, '')
    checkpoint_path = os.path.join(FLAGS.output_dir, 'model.ckpt')
    '''
    setup summaries
    '''
    summ = Summaries()
    '''
    setup the game environment
    '''

    filenames_train = glob.glob(
        os.path.join(FLAGS.data_dir, 'train-{}'.format(FLAGS.sampling_rate),
                     '*.mat'))
    filenames_val = glob.glob(
        os.path.join(FLAGS.data_dir, 'val-{}'.format(FLAGS.sampling_rate),
                     '*.mat'))

    game_env_train = Env(decay=FLAGS.decay)
    game_env_val = Env(decay=FLAGS.decay)

    game_actions = list(game_env_train.actions.keys())
    '''
    setup the transition table for experience replay
    '''

    stateDim = [FLAGS.num_chans, FLAGS.num_points]

    transition_args = {
        'batchSize': FLAGS.batch_size,
        'stateDim': stateDim,
        'numActions': len(game_actions),
        'maxSize': FLAGS.replay_memory,
    }

    transitions = TransitionMemory(transition_args)
    '''
    setup agent
    '''
    s_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size] + stateDim,
                                   's_placeholder')
    s2_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size] + stateDim,
                                    's2_placeholder')
    a_placeholder = tf.placeholder(tf.int32, [FLAGS.batch_size],
                                   'a_placeholder')
    r_placeholder = tf.placeholder(tf.float32, [FLAGS.batch_size],
                                   'r_placeholder')

    pcont_t = tf.constant(FLAGS.discount, tf.float32, [FLAGS.batch_size])

    network = Model(FLAGS.batch_size, len(game_actions), FLAGS.num_chans,
                    FLAGS.sampling_rate, FLAGS.num_filters, FLAGS.num_recurs,
                    FLAGS.pooling_stride, name="network")

    target_network = Model(FLAGS.batch_size, len(game_actions), FLAGS.num_chans,
                           FLAGS.sampling_rate, FLAGS.num_filters,
                           FLAGS.num_recurs, FLAGS.pooling_stride,
                           name="target_n")

    q = network(s_placeholder)
    q2 = target_network(s2_placeholder)
    q_selector = network(s2_placeholder)

    loss, q_learning = trfl.double_qlearning(q, a_placeholder, r_placeholder,
                                             pcont_t, q2, q_selector)
    synchronizer = Synchronizer(network, target_network)
    synchronize_ops = synchronizer()

    training_variables = network.variables

    opt = Adam(FLAGS.learning_rate,
               lr_decay=FLAGS.lr_decay,
               lr_decay_steps=FLAGS.lr_decay_steps,
               lr_decay_factor=FLAGS.lr_decay_factor,
               clip=True)

    reduced_loss = tf.reduce_mean(loss)

    graph_regularizers = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    total_regularization_loss = tf.reduce_sum(graph_regularizers)

    total_loss = reduced_loss + total_regularization_loss

    update_op = opt(total_loss, var_list=training_variables)

    summ_loss_op = tf.summary.scalar('loss', total_loss)

    state_placeholder = tf.placeholder(tf.float32, [1] + stateDim,
                                       'state_placeholder')
    decayed_ep_placeholder = tf.placeholder(tf.float32, [],
                                            'decayed_ep_placeholder')

    action_tensor_egreedy = eGreedy(state_placeholder, network,
                                    len(game_actions), decayed_ep_placeholder,
                                    FLAGS.debug)

    action_tensor_greedy = greedy(state_placeholder, network)
    '''
    setup the training process
    '''
    episode_reward_placeholder = tf.placeholder(tf.float32, [],
                                                "episode_reward_placeholder")
    average_reward_placeholder = tf.placeholder(tf.float32, [],
                                                "average_reward_placeholder")

    summ.register('train', 'episode_reward_train', episode_reward_placeholder)
    summ.register('train', 'average_reward_train', average_reward_placeholder)

    summ.register('val', 'episode_reward_val', episode_reward_placeholder)
    summ.register('val', 'average_reward_val', average_reward_placeholder)

    total_reward_train = 0
    average_reward_train = 0

    total_reward_val = 0
    average_reward_val = 0
    '''
    gathering summary operators
    '''
    train_summ_op = summ('train')
    val_summ_op = summ('val')
    '''
    run the training process
    '''
    transitions.empty()
    # print("game_actions -> {}".format(game_actions))

    writer = tf.summary.FileWriter(event_log_dir, tf.get_default_graph())

    saver = tf.train.Saver(training_variables)

    config = tf.ConfigProto(allow_soft_placement=True,
                            log_device_placement=False)

    assert (FLAGS.gpus != ''), 'invalid GPU specification'
    config.gpu_options.visible_device_list = FLAGS.gpus

    with tf.Session(config=config) as sess:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        val_step = 0

        for step in range(FLAGS.steps):
            print("Iteration: {}".format(step))

            game_env_train.reset(filenames_train[np.random.randint(
                0, len(filenames_train))])

            last_state = None
            last_state_assigned = False
            episode_reward = 0
            action_index = (len(game_actions) >> 2)

            for estep in range(FLAGS.eval_steps):
                # print("Evaluation step: {}".format(estep))

                # print("{} - measured RT: {}".format(estep, game_env_train.measured_rt))
                # print("{} - predicted RT: {}".format(estep, game_env_train.predicted_rt))
                # print("{} - action -> {}".format(estep, game_actions[action]))

                state, reward, terminal = game_env_train.step(
                    game_actions[action_index])

                # game over?
                if terminal:
                    break

                episode_reward += reward

                # Store transition s, a, r, t
                # if last_state_assigned and reward:
                if last_state_assigned:
                    # print("reward -> {}".format(reward))
                    # print("action -> {}".format(game_actions[last_action]))
                    transitions.add(last_state, last_action, reward,
                                    last_terminal)

                # Select action
                # decayed_ep = FLAGS.testing_ep

                decayed_ep = max(0.1,
                                 (FLAGS.steps - step) / FLAGS.steps * FLAGS.ep)

                if not terminal:
                    action_index = sess.run(action_tensor_egreedy,
                                            feed_dict={
                                                state_placeholder:
                                                np.expand_dims(state, axis=0),
                                                decayed_ep_placeholder:
                                                decayed_ep
                                            })
                else:
                    action_index = 0

                # Do some Q-learning updates
                if estep > FLAGS.learn_start and estep % FLAGS.update_freq == 0:
                    summ_str = None
                    for _ in range(FLAGS.n_replay):
                        if transitions.size > FLAGS.batch_size:
                            s, a, r, s2 = transitions.sample()

                            summ_str, _ = sess.run(
                                [summ_loss_op, update_op],
                                feed_dict={
                                    s_placeholder: s,
                                    a_placeholder: a,
                                    r_placeholder: r,
                                    s2_placeholder: s2
                                })

                    if summ_str:
                        writer.add_summary(summ_str,
                                           step * FLAGS.eval_steps + estep)

                last_state = state
                last_state_assigned = True

                last_action = action_index
                last_terminal = terminal

                if estep > FLAGS.learn_start and estep % FLAGS.target_q == 0:
                    # print("duplicate model parameters")
                    sess.run(synchronize_ops)

            total_reward_train += episode_reward
            average_reward_train = total_reward_train / (step + 1)

            train_summ_str = sess.run(train_summ_op,
                                      feed_dict={
                                          episode_reward_placeholder:
                                          episode_reward,
                                          average_reward_placeholder:
                                          average_reward_train
                                      })
            writer.add_summary(train_summ_str, step)

            if FLAGS.validation and step % FLAGS.validation_interval == 0:
                game_env_val.reset(filenames_val[0])

                episode_reward = 0
                count = 0
                action_index = (len(game_actions) >> 2)

                while True:
                    # print("Evaluation step: {}".format(count))
                    # print("action -> {}".format(game_actions[action_index]))

                    state, reward, terminal = game_env_val.step(
                        game_actions[action_index])

                    # game over?
                    if terminal:
                        break

                    episode_reward += reward

                    if not terminal:
                        action_index = sess.run(
                            action_tensor_greedy,
                            feed_dict={
                                state_placeholder: np.expand_dims(state, axis=0)
                            })
                        action_index = np.squeeze(action_index)
                        # print('state -> {}'.format(state))
                        # print('action_index -> {}'.format(action_index))
                    else:
                        action_index = 0

                    count += 1

                total_reward_val += episode_reward
                average_reward_val = total_reward_val / (val_step + 1)
                val_step += 1

                val_summ_str = sess.run(val_summ_op,
                                        feed_dict={
                                            episode_reward_placeholder:
                                            episode_reward,
                                            average_reward_placeholder:
                                            average_reward_val
                                        })
                writer.add_summary(val_summ_str, step)

        tf.logging.info('Saving model.')
        saver.save(sess, checkpoint_path)
        tf.logging.info('Training complete')

    writer.close()
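The exploration schedule used in the training loop above decays epsilon linearly from FLAGS.ep toward a floor of 0.1. A quick standalone check of that formula (FLAGS.steps and FLAGS.ep are command-line flags in the example; the values here are illustrative):

steps, ep = 1000, 1.0
for step in (0, 250, 500, 900, 999):
    decayed_ep = max(0.1, (steps - step) / steps * ep)
    print(step, round(decayed_ep, 3))   # 1.0, 0.75, 0.5, 0.1, 0.1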
Example #7
    def _step(self) -> Dict[str, tf.Tensor]:
        """Do a step of SGD and update the priorities."""

        # Pull out the data needed for updates/priorities.
        inputs = next(self._iterator)
        o_tm1, a_tm1, r_t, d_t, o_t = inputs.data
        keys, probs = inputs.info[:2]

        with tf.GradientTape() as tape:
            # Evaluate our networks.
            q_tm1 = self._network(o_tm1)
            q_t_value = self._target_network(o_t)
            q_t_selector = self._network(o_t)

            # The rewards and discounts have to have the same type as network values.
            r_t = tf.cast(r_t, q_tm1.dtype)
            r_t = tf.clip_by_value(r_t, -1., 1.)
            d_t = tf.cast(d_t, q_tm1.dtype) * tf.cast(self._discount,
                                                      q_tm1.dtype)

            # Compute the loss.
            _, extra = trfl.double_qlearning(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                             q_t_selector)
            loss = losses.huber(extra.td_error, self._huber_loss_parameter)

            if self._alpha:
                policy_probs = self._emp_policy.lookup([str(o) for o in o_tm1])

                # Soft-maximum of the Q-function over actions.
                push_down = tf.reduce_logsumexp(q_tm1 * self._tr, axis=1) / self._tr
                # Expected Q-value under the behavioural policy.
                push_up = tf.reduce_sum(policy_probs * q_tm1, axis=1)

                cql_loss = loss + self._alpha * (push_down - push_up)
            else:
                cql_loss = loss

            cql_loss = tf.reduce_mean(cql_loss, axis=0)

        # Do a step of SGD.
        gradients = tape.gradient(cql_loss, self._network.trainable_variables)
        self._optimizer.apply(gradients, self._network.trainable_variables)

        # Update the priorities in the replay buffer.
        if self._replay_client:
            priorities = tf.cast(tf.abs(extra.td_error), tf.float64)
            self._replay_client.update_priorities(
                table=adders.DEFAULT_PRIORITY_TABLE,
                keys=keys,
                priorities=priorities)

        # Periodically update the target network.
        if tf.math.mod(self._counter.get_counts()['learner_steps'],
                       self._target_update_period) == 0:
            for src, dest in zip(self._network.variables,
                                 self._target_network.variables):
                dest.assign(src)

        # Report loss & statistics for logging.
        fetches = {
            'critic_loss': tf.reduce_mean(loss, axis=0),
            'q_variance': tf.reduce_mean(
                tf.math.reduce_variance(q_tm1, axis=1), axis=0),
            'q_average': tf.reduce_mean(q_tm1),
        }
        if self._alpha:
            fetches.update({
                'push_up': tf.reduce_mean(push_up, axis=0),
                'push_down': tf.reduce_mean(push_down, axis=0),
                'regularizer': tf.reduce_mean(push_down - push_up, axis=0),
                'cql_loss': cql_loss,
            })
        return fetches
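The CQL regularizer in this learner is the gap between a temperature-scaled soft maximum of the Q-values (push_down) and the expected Q-value under the empirical behaviour policy (push_up). A minimal NumPy sketch of that penalty for a single state (the numbers and the temperature value are illustrative):

import numpy as np

q = np.array([[1.0, 2.0, 0.5]])             # Q-values for one state, three actions
policy_probs = np.array([[0.2, 0.5, 0.3]])  # empirical behaviour policy
tr = 10.0                                   # temperature; larger -> closer to a hard max

push_down = np.log(np.sum(np.exp(q * tr), axis=1)) / tr   # soft-maximum of Q
push_up = np.sum(policy_probs * q, axis=1)                # expected Q under the policy
print(push_down - push_up)   # the gap that the alpha-weighted CQL term penalises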