def _update_center():
    # Update the existing center whose activation is highest, instead of
    # allocating a new one. Returns the per-center slices that the caller
    # scatters back into the memory variables.
    # NOTE(review): relies on closure variables (`activation`, `active`,
    # `value`, `center`, `sigma`, `sign_f`, `update_beta`,
    # `grad_data_shape`, `is_debug`) from an enclosing scope not visible
    # in this fragment — confirm against the full `update_memory` body.
    with tf.variable_scope('_update_center'):
        # Index of the most strongly activated center (the one to update).
        index = tf.to_int32(tf.argmax(activation, axis=0))
        index = _ops.print_debug(index, [index], 'update_center',
                                 is_debug=is_debug)
        # Number of active centers is unchanged; we only modify in place.
        active_update = active
        activation_used_update = tf.reshape(activation_used[index], [1])
        activation_count_update = tf.reshape(activation_count[index], [1])
        activation_max = activation[index]
        activation_max = _ops.print_debug(
            activation_max,
            [value[index], update_beta, activation_max],
            'update_center_value_beta_activation_max',
            is_debug=is_debug)
        # New value: old value plus a step of size
        # update_beta * sign_f * activation_max (sign_f acts as the
        # clipped gradient-agreement signal — see enclosing scope).
        value_update = tf.reshape(
            tf.add(
                value[index],
                tf.multiply(update_beta,
                            tf.multiply(sign_f, activation_max))), [1])
        value_update = _ops.print_debug(value_update,
                                        [value_update, sign_f],
                                        'update_center_value_updata',
                                        is_debug=is_debug)
        # Center position and sigma are kept as-is (reshaped for scatter).
        center_update = tf.reshape(center[index], grad_data_shape)
        sigma_update = tf.reshape(sigma[index], [1])
        return (index, active_update, activation_used_update,
                activation_count_update, value_update, center_update,
                sigma_update)
def _eta():
    # Compute eta as the activation-weighted sum of the stored values,
    # normalized by the total activation; falls back to `_empty` (defined
    # outside this fragment — verify) when total activation is zero.
    active_ = _ops.print_debug(active, [active], 'compute_eta_active',
                               is_debug=is_debug)
    # Restrict the values to the first `active_` (in-use) centers.
    value_sliced = tf.slice(value, begin=[0], size=[active_])
    activation_ = _ops.print_debug(
        activation, [activation[active_ - 1], value_sliced],
        'compute_eta_activation_value', is_debug=is_debug)
    # NOTE(review): `value_sliced` has length `active_` while
    # `activation_` is the full activation vector — this multiply only
    # broadcasts if the shapes agree; confirm all trailing activations
    # are zero/unused at the caller.
    eta = tf.multiply(value_sliced, activation_)
    activation_total = tf.reduce_sum(activation)
    # Guard against division by zero when no center is activated.
    return tf.cond(
        tf.equal(activation_total, 0), _empty,
        lambda: _ops.safe_div(tf.reduce_sum(eta), activation_total))
def _add_or_update():
    # Decide between adding a center and updating an existing one:
    # if even the best-matching center's activation is below the
    # threshold, no center represents this point well, so add one;
    # otherwise update the closest (most activated) center.
    # Both branch callables are closures from the enclosing scope.
    with tf.variable_scope('_add_or_update'):
        activation_max = tf.reduce_max(activation, axis=0)
        activation_max = _ops.print_debug(
            activation_max, [activation_max],
            'add_or_update_activation_max', is_debug=is_debug)
        return tf.cond(tf.less(activation_max, threshold_activation),
                       _add_center_free_or_full, _update_center)
def _add_center_full():
    # Memory is full: evict the least useful center and reinitialize it
    # with the current data point. Returns the per-center slices the
    # caller scatters back into the memory variables.
    with tf.variable_scope('_add_center_full'):
        # Usefulness = how often the center was used relative to how
        # often it was considered; evict the lowest ratio.
        used_freq = _ops.safe_div(activation_used, activation_count)
        index = tf.to_int32(tf.argmin(used_freq, axis=0))
        index = _ops.print_debug(index, [index], 'add_center_full',
                                 is_debug=is_debug)
        # Count of active centers does not change — we reuse a slot.
        active_update = active
        # Reset usage statistics for the recycled slot.
        activation_used_update = tf.constant([1], dtype=tf.int32)
        activation_count_update = tf.constant([1], dtype=tf.int32)
        value_update = tf.reshape(update_beta, [1])
        # New center sits at the current data point with initial sigma.
        center_update = grad_data
        sigma_update = tf.reshape(sigma_init, [1])
        return (index, active_update, activation_used_update,
                activation_count_update, value_update, center_update,
                sigma_update)
def _do_print():
    # Debug-only branch: thread `sigma_scaled` through a chain of
    # print_debug ops so the RBF intermediates for the last active
    # center are printed, then return it unchanged in value.
    # NOTE(review): all referenced tensors (`sigma_scaled`,
    # `data_tiled`, `center_sliced`, `center_diff`,
    # `center_diff_square`, `active`) are closures from an enclosing
    # function not visible here — confirm there.
    sigma_ = sigma_scaled
    sigma_ = _ops.print_debug(sigma_, [active], "mth_rbf",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [sigma_[active - 1]], "rbf_sigma",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [data_tiled[active - 1]], "rbf_x",
                              is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_sliced[active - 1]],
                              "rbf_c", is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_diff[active - 1]],
                              "rbf_x_minus_c", is_debug=is_debug)
    sigma_ = _ops.print_debug(sigma_, [center_diff_square[active - 1]],
                              "rbf_centerdiff", is_debug=is_debug)
    return sigma_
def update_memory(activation, activation_used, activation_count, grad_cur,
                  grad_prev, grad_data, value, center, sigma, active,
                  threshold_activation, update_beta, sigma_init,
                  name_or_scope=None, is_debug=False):
    """Update the memory by either creating a new center or updating an existing one.

    Exactly one memory slot is written per call. The branch is chosen by:
      * memory empty (active == 0)      -> _add_center_empty
      * best activation < threshold     -> add a center (free slot if any,
                                           otherwise evict the least-used one)
      * otherwise                       -> update the most activated center

    Args:
      activation: rank-1 float tensor, activation of each center.
      activation_used: rank-1 int32 variable, per-center use counts.
      activation_count: rank-1 int32 variable, per-center consideration counts.
      grad_cur: current gradient tensor.
      grad_prev: previous gradient tensor.
      grad_data: data point used as a new center position.
      value: rank-1 variable of per-center values.
      center: variable of center positions.
      sigma: rank-1 variable of per-center bandwidths.
      active: rank-0 int32 variable, number of centers currently in use.
      threshold_activation: scalar; below it a new center is added.
      update_beta: scalar step size for value updates.
      sigma_init: scalar initial bandwidth for new centers.
      name_or_scope: optional variable scope.
      is_debug: if True, emit debug prints via _ops.print_debug.

    Returns:
      Tuple (active_new, center_new, value_new, sigma_new,
      activation_used_new, activation_count_new) of post-update tensors;
      the assign to `active` is ordered after all scatter updates.
    """
    with tf.variable_scope(
            name_or_scope,
            default_name='update_memory',
            values=[activation, grad_cur, grad_prev, value, center, sigma,
                    active, threshold_activation, update_beta, sigma_init]):
        activation.get_shape().assert_has_rank(1)
        activation_used.get_shape().assert_has_rank(1)
        activation_count.get_shape().assert_has_rank(1)
        grad_prev = _ops.print_debug(grad_prev,
                                     [update_beta, grad_cur, grad_prev],
                                     "update_memory_beta_grad_cur_grad_prev",
                                     is_debug=is_debug)
        tf.summary.scalar('update_memory_active', active)
        # tf.summary.scalar('update_memory_grad_cur', grad_cur)
        # Franzi: apparently the gradients have rank 0
        # grad_cur.get_shape().assert_has_rank(0)
        # grad_prev.get_shape().assert_has_rank(0)
        # center.get_shape().assert_has_rank(1)
        sigma.get_shape().assert_has_rank(1)
        active.get_shape().assert_has_rank(0)

        # sign_f is the gradient-agreement signal: the dot product of the
        # previous and current gradients, clipped to [-1, 1]. Positive when
        # consecutive gradients agree in direction.
        barrier_max = tf.constant(1, dtype=tf.float32)
        barrier_min = -barrier_max
        grad_dotproduct = tf.reduce_sum(tf.multiply(grad_prev, grad_cur))
        grad_data_shape = grad_data.get_shape().as_list()
        sign_f = tf.minimum(barrier_max,
                            tf.maximum(barrier_min, grad_dotproduct))
        sign_f = _ops.print_debug(sign_f, [sign_f, grad_dotproduct],
                                  'sign_f, grad_dotproduct',
                                  is_debug=is_debug)
        tf.summary.scalar('sign_f', sign_f)

        def _add_center_empty():
            # First ever center: append at slot `active` and bump the count.
            # BUGFIX: scope was '_add_center_free', colliding with the scope
            # of _add_center_free below; no variables are created here, so
            # renaming the scope is safe.
            with tf.variable_scope('_add_center_empty'):
                index = active
                active_update = tf.add(active, tf.constant(1, dtype=tf.int32))
                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)
                # we should use the value of the closest neighbor
                value_update = tf.reshape(update_beta, [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_free(): 
            # A free slot exists: append a center there, seeding its value
            # from the closest (most activated) existing center.
            with tf.variable_scope('_add_center_free'):
                index = active
                # pick the index with largest activation
                index_closest = tf.to_int32(tf.argmax(activation, axis=0))
                active_update = tf.add(active, tf.constant(1, dtype=tf.int32))
                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)
                # we should use the value of the closest neighbor
                value_update = tf.reshape(value[index_closest], [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_full():
            # Memory full: recycle the slot with the lowest used/considered
            # ratio and reinitialize it at the current data point.
            with tf.variable_scope('_add_center_full'):
                used_freq = _ops.safe_div(activation_used, activation_count)
                index = tf.to_int32(tf.argmin(used_freq, axis=0))
                index = _ops.print_debug(index, [index], 'add_center_full',
                                         is_debug=is_debug)
                active_update = active
                activation_used_update = tf.constant([1], dtype=tf.int32)
                activation_count_update = tf.constant([1], dtype=tf.int32)
                value_update = tf.reshape(update_beta, [1])
                center_update = grad_data
                sigma_update = tf.reshape(sigma_init, [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _update_center():
            # Update the most activated center's value by a step of
            # update_beta * sign_f * activation; position/sigma unchanged.
            with tf.variable_scope('_update_center'):
                index = tf.to_int32(tf.argmax(activation, axis=0))
                index = _ops.print_debug(index, [index], 'update_center',
                                         is_debug=is_debug)
                active_update = active
                activation_used_update = tf.reshape(activation_used[index],
                                                    [1])
                activation_count_update = tf.reshape(activation_count[index],
                                                     [1])
                activation_max = activation[index]
                activation_max = _ops.print_debug(
                    activation_max,
                    [value[index], update_beta, activation_max],
                    'update_center_value_beta_activation_max',
                    is_debug=is_debug)
                value_update = tf.reshape(
                    tf.add(
                        value[index],
                        tf.multiply(update_beta,
                                    tf.multiply(sign_f, activation_max))),
                    [1])
                value_update = _ops.print_debug(value_update,
                                                [value_update, sign_f],
                                                'update_center_value_updata',
                                                is_debug=is_debug)
                center_update = tf.reshape(center[index], grad_data_shape)
                sigma_update = tf.reshape(sigma[index], [1])
                return (index, active_update, activation_used_update,
                        activation_count_update, value_update, center_update,
                        sigma_update)

        def _add_center_free_or_full():
            with tf.variable_scope('_add_center_free_or_full'):
                # If all centers are filled we have to update one
                # otherwise just add a center.
                return tf.cond(tf.less(active, tf.shape(center)[0]),
                               _add_center_free, _add_center_full)

        def _add_or_update():
            # Add a center when even the best match is below threshold,
            # otherwise update the best-matching one.
            with tf.variable_scope('_add_or_update'):
                activation_max = tf.reduce_max(activation, axis=0)
                activation_max = _ops.print_debug(
                    activation_max, [activation_max],
                    'add_or_update_activation_max', is_debug=is_debug)
                return tf.cond(tf.less(activation_max, threshold_activation),
                               _add_center_free_or_full, _update_center)

        # Select the single slot `index` to write and its new contents.
        (index, active_update, activation_used_update,
         activation_count_update, value_update, center_update,
         sigma_update) = tf.cond(tf.equal(active, 0), _add_center_empty,
                                 _add_or_update)
        index = _ops.print_debug(index, [index], 'update_memory_index',
                                 is_debug=is_debug)

        # Scatter the chosen slot's updates into the memory variables.
        activation_used_new = tf.scatter_update(activation_used, [index],
                                                activation_used_update,
                                                name='activation_used')
        activation_count_new = tf.scatter_update(activation_count, [index],
                                                 activation_count_update,
                                                 name='activation_count')
        # value_new = tf.scatter_update(
        #     value, [index], tf.maximum(value_update, 0.0001), name='value')
        value_new = tf.scatter_update(value, [index], value_update,
                                      name='value')
        center_new = tf.scatter_update(center, [index], center_update,
                                       name='center')
        sigma_new = tf.scatter_update(sigma, [index], sigma_update,
                                      name='sigma')
        # Only advance `active` after all slot writes have taken effect.
        with tf.control_dependencies([
                activation_used_new, activation_count_new, value_new,
                center_new, sigma_new
        ]):
            active_new = tf.assign(active, active_update, name='active')
        return (active_new, center_new, value_new, sigma_new,
                activation_used_new, activation_count_new)