Example #1
        def _update_center():
            with tf.variable_scope('_update_center'):
                index = tf.to_int32(tf.argmax(activation, axis=0))

                index = _ops.print_debug(index, [index],
                                         'update_center',
                                         is_debug=is_debug)
                active_update = active

                activation_used_update = tf.reshape(activation_used[index],
                                                    [1])
                activation_count_update = tf.reshape(activation_count[index],
                                                     [1])

                activation_max = activation[index]
                activation_max = _ops.print_debug(
                    activation_max,
                    [value[index], update_beta, activation_max],
                    'update_center_value_beta_activation_max',
                    is_debug=is_debug)
                value_update = tf.reshape(
                    tf.add(
                        value[index],
                        tf.multiply(update_beta,
                                    tf.multiply(sign_f, activation_max))), [1])
                value_update = _ops.print_debug(value_update,
                                                [value_update, sign_f],
                                                'update_center_value_update',
                                                is_debug=is_debug)
                center_update = tf.reshape(center[index], grad_data_shape)
                sigma_update = tf.reshape(sigma[index], [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)
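
The value update in Example #1 follows the rule value[i] <- value[i] + update_beta * sign_f * activation[i], applied at the most strongly activated center; sign_f is the clipped dot product of the previous and current gradients (see Example #6). A minimal NumPy sketch with hypothetical numbers:

import numpy as np

# Hypothetical toy values; the real tensors come from the surrounding graph.
activation = np.array([0.1, 0.9, 0.3])   # per-center activations
value = np.array([0.5, 0.5, 0.5])        # stored per-center values
update_beta = 0.1
sign_f = -1.0                            # clipped dot product of grad_prev and grad_cur

index = int(np.argmax(activation))                        # most active center: 1
value[index] += update_beta * sign_f * activation[index]  # 0.5 - 0.09 = 0.41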
Example #2
def _eta():
    active_ = _ops.print_debug(active, [active],
                               'compute_eta_active',
                               is_debug=is_debug)
    value_sliced = tf.slice(value, begin=[0], size=[active_])
    activation_ = _ops.print_debug(
        activation, [activation[active_ - 1], value_sliced],
        'compute_eta_activation_value',
        is_debug=is_debug)
    eta = tf.multiply(value_sliced, activation_)
    activation_total = tf.reduce_sum(activation)
    return tf.cond(
        tf.equal(activation_total, 0), _empty,
        lambda: _ops.safe_div(tf.reduce_sum(eta), activation_total))
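
Example #2 computes an activation-weighted average of the stored values over the active slots, guarded against a zero activation total. A minimal NumPy sketch with hypothetical numbers (assumption: _ops.safe_div behaves like ordinary division here, the zero case being handled by the tf.cond guard):

import numpy as np

activation = np.array([0.2, 0.5, 0.3, 0.0])   # per-center activations
value = np.array([1.0, -2.0, 0.5, 9.9])       # stored per-center values
active = 3                                    # only the first `active` slots are live

eta = value[:active] * activation[:active]
activation_total = activation.sum()
result = 0.0 if activation_total == 0 else eta.sum() / activation_total
# result == (0.2 - 1.0 + 0.15) / 1.0 == -0.65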
Example #3
def _add_or_update():
    with tf.variable_scope('_add_or_update'):
        activation_max = tf.reduce_max(activation, axis=0)
        activation_max = _ops.print_debug(
            activation_max, [activation_max],
            'add_or_update_activation_max',
            is_debug=is_debug)
        return tf.cond(tf.less(activation_max, threshold_activation),
                       _add_center_free_or_full, _update_center)
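
The branching in Example #3 reduces to a simple rule: add (or replace) a center only when no existing center responds strongly enough. The same decision in plain Python with hypothetical numbers:

import numpy as np

activation = np.array([0.05, 0.2, 0.1])
threshold_activation = 0.3

if activation.max() < threshold_activation:
    branch = '_add_center_free_or_full'   # nothing matches well -> add or replace
else:
    branch = '_update_center'             # reuse the best-matching center
# branch == '_add_center_free_or_full'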
Example #4
        def _add_center_full():
            with tf.variable_scope('_add_center_full'):

                used_freq = _ops.safe_div(activation_used, activation_count)

                index = tf.to_int32(tf.argmin(used_freq, axis=0))
                index = _ops.print_debug(index, [index],
                                         'add_center_full',
                                         is_debug=is_debug)

                active_update = active

                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)

                value_update = tf.reshape(update_beta, [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)
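
When the memory is full, Example #4 overwrites the slot with the lowest usage frequency, activation_used / activation_count. A minimal NumPy sketch with hypothetical counters:

import numpy as np

activation_used = np.array([3.0, 10.0, 1.0])    # times each center "won"
activation_count = np.array([5.0, 12.0, 9.0])   # times each center was considered

used_freq = activation_used / activation_count  # [0.60, 0.83, 0.11]
index = int(np.argmin(used_freq))               # slot 2 gets replaced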
Example #5
def _do_print():
    sigma_ = sigma_scaled
    sigma_ = _ops.print_debug(sigma_, [active],
                              "mth_rbf",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [sigma_[active - 1]],
                              "rbf_sigma",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [data_tiled[active - 1]],
                              "rbf_x",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_sliced[active - 1]],
                              "rbf_c",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_diff[active - 1]],
                              "rbf_x_minus_c",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_diff_square[active - 1]],
                              "rbf_centerdiff",
                              is_debug=is_debug)
    return sigma_
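
_ops.print_debug itself is not shown in these examples; judging by the call sites, it passes its first argument through and prints the listed tensors when is_debug is set. A minimal sketch of such a helper, assuming it wraps tf.Print (the real implementation may differ):

import tensorflow.compat.v1 as tf

def print_debug(tensor, data, message, is_debug=False):
    """Pass `tensor` through unchanged; print `data` prefixed by `message` when debugging."""
    if not is_debug:
        return tensor
    return tf.Print(tensor, data, message=message)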
Example #6
def update_memory(activation,
                  activation_used,
                  activation_count,
                  grad_cur,
                  grad_prev,
                  grad_data,
                  value,
                  center,
                  sigma,
                  active,
                  threshold_activation,
                  update_beta,
                  sigma_init,
                  name_or_scope=None,
                  is_debug=False):
    """Update the memory by either creating a new center or updating a one.
    """
    with tf.variable_scope(name_or_scope,
                           default_name='update_memory',
                           values=[
                               activation, grad_cur, grad_prev, value, center,
                               sigma, active, threshold_activation,
                               update_beta, sigma_init
                           ]):
        activation.get_shape().assert_has_rank(1)
        activation_used.get_shape().assert_has_rank(1)
        activation_count.get_shape().assert_has_rank(1)
        grad_prev = _ops.print_debug(grad_prev,
                                     [update_beta, grad_cur, grad_prev],
                                     "update_memory_beta_grad_cur_grad_prev",
                                     is_debug=is_debug)
        tf.summary.scalar('update_memory_active', active)
        # tf.summary.scalar('update_memory_grad_cur', grad_cur)
        # Franzi: apparently the gradients have rank 0
        # grad_cur.get_shape().assert_has_rank(0)
        # grad_prev.get_shape().assert_has_rank(0)
        # center.get_shape().assert_has_rank(1)
        sigma.get_shape().assert_has_rank(1)
        active.get_shape().assert_has_rank(0)
        barrier_max = tf.constant(1, dtype=tf.float32)
        barrier_min = -barrier_max

        grad_dotproduct = tf.reduce_sum(tf.multiply(grad_prev, grad_cur))
        grad_data_shape = grad_data.get_shape().as_list()
        sign_f = tf.minimum(barrier_max,
                            tf.maximum(barrier_min, grad_dotproduct))
        sign_f = _ops.print_debug(sign_f, [sign_f, grad_dotproduct],
                                  'sign_f, grad_dotproduct',
                                  is_debug=is_debug)
        tf.summary.scalar('sign_f', sign_f)

        def _add_center_empty():
            with tf.variable_scope('_add_center_empty'):
                index = active
                active_update = tf.add(active, tf.constant(1, dtype=tf.int32))
                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)
                # we should use the value of the closest neighbor
                value_update = tf.reshape(update_beta, [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_free():
            with tf.variable_scope('_add_center_free'):
                index = active
                # pick the index with largest activation
                index_closest = tf.to_int32(tf.argmax(activation, axis=0))

                active_update = tf.add(active, tf.constant(1, dtype=tf.int32))
                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)
                # we should use the value of the closest neighbor
                value_update = tf.reshape(value[index_closest], [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_full():
            with tf.variable_scope('_add_center_full'):

                used_freq = _ops.safe_div(activation_used, activation_count)

                index = tf.to_int32(tf.argmin(used_freq, axis=0))
                index = _ops.print_debug(index, [index],
                                         'add_center_full',
                                         is_debug=is_debug)

                active_update = active

                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)

                value_update = tf.reshape(update_beta, [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _update_center():
            with tf.variable_scope('_update_center'):
                index = tf.to_int32(tf.argmax(activation, axis=0))

                index = _ops.print_debug(index, [index],
                                         'update_center',
                                         is_debug=is_debug)
                active_update = active

                activation_used_update = tf.reshape(activation_used[index],
                                                    [1])
                activation_count_update = tf.reshape(activation_count[index],
                                                     [1])

                activation_max = activation[index]
                activation_max = _ops.print_debug(
                    activation_max,
                    [value[index], update_beta, activation_max],
                    'update_center_value_beta_activation_max',
                    is_debug=is_debug)
                value_update = tf.reshape(
                    tf.add(
                        value[index],
                        tf.multiply(update_beta,
                                    tf.multiply(sign_f, activation_max))), [1])
                value_update = _ops.print_debug(value_update,
                                                [value_update, sign_f],
                                                'update_center_value_update',
                                                is_debug=is_debug)
                center_update = tf.reshape(center[index], grad_data_shape)
                sigma_update = tf.reshape(sigma[index], [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_free_or_full():
            with tf.variable_scope('_add_center_free_or_full'):
                # If all centers are filled we have to update one
                # otherwise just add a center.
                return tf.cond(tf.less(active,
                                       tf.shape(center)[0]), _add_center_free,
                               _add_center_full)

        def _add_or_update():
            with tf.variable_scope('_add_or_update'):
                activation_max = tf.reduce_max(activation, axis=0)
                activation_max = _ops.print_debug(
                    activation_max, [activation_max],
                    'add_or_update_activation_max',
                    is_debug=is_debug)
                return tf.cond(tf.less(activation_max, threshold_activation),
                               _add_center_free_or_full, _update_center)

        (index, active_update, activation_used_update, activation_count_update,
         value_update, center_update,
         sigma_update) = tf.cond(tf.equal(active, 0), _add_center_empty,
                                 _add_or_update)

        index = _ops.print_debug(index, [index],
                                 'update_memory_index',
                                 is_debug=is_debug)
        activation_used_new = tf.scatter_update(activation_used, [index],
                                                activation_used_update,
                                                name='activation_used')
        activation_count_new = tf.scatter_update(activation_count, [index],
                                                 activation_count_update,
                                                 name='activation_count')
        # value_new = tf.scatter_update(
        #     value, [index], tf.maximum(value_update, 0.0001), name='value')
        value_new = tf.scatter_update(value, [index],
                                      value_update,
                                      name='value')
        center_new = tf.scatter_update(center, [index],
                                       center_update,
                                       name='center')
        sigma_new = tf.scatter_update(sigma, [index],
                                      sigma_update,
                                      name='sigma')
        with tf.control_dependencies([
                activation_used_new, activation_count_new, value_new,
                center_new, sigma_new
        ]):
            active_new = tf.assign(active, active_update, name='active')
            return (active_new, center_new, value_new, sigma_new,
                    activation_used_new, activation_count_new)
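
A minimal sketch of how update_memory might be wired up, assuming TF 1.x graph mode; the shapes below are assumptions inferred from the rank asserts and scatter updates in Example #6 (per-coordinate scalar gradients, a fixed number of memory slots), not taken from the original call site:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

num_centers = 16

# Memory state lives in variables so tf.scatter_update / tf.assign can modify it.
activation_used = tf.Variable(tf.zeros([num_centers], tf.int32), name='activation_used')
activation_count = tf.Variable(tf.zeros([num_centers], tf.int32), name='activation_count')
value = tf.Variable(tf.zeros([num_centers]), name='value')
center = tf.Variable(tf.zeros([num_centers]), name='center')
sigma = tf.Variable(tf.ones([num_centers]), name='sigma')
active = tf.Variable(0, dtype=tf.int32, name='active')

activation = tf.placeholder(tf.float32, [num_centers])  # RBF activations of grad_data
grad_cur = tf.placeholder(tf.float32, [])                # current (scalar) gradient
grad_prev = tf.placeholder(tf.float32, [])               # previous gradient
grad_data = tf.placeholder(tf.float32, [1])              # value stored as a new center

update_op = update_memory(activation, activation_used, activation_count,
                          grad_cur, grad_prev, grad_data,
                          value, center, sigma, active,
                          threshold_activation=tf.constant(0.5),
                          update_beta=tf.constant(0.1),
                          sigma_init=tf.constant(1.0))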