Example no. 1
    def getTable(self, ids, loc, features):
        '''
        ids:[N]
        loc:[N]
        '''
        #indices: [N, 2]
        loc = tf.clip_by_value(loc, 0, self.max_volume)
        indices = tf.stack(
            [tf.range(0, limit=tf.shape(ids)[0], dtype='int32'), loc], axis=1)
        gallery_binit = tf.gather(self.gallery_binit, ids)
        exclude_gallery_binit = tf.tensor_scatter_nd_update(
            gallery_binit, indices, tf.zeros_like(ids, dtype='float32'))
        #[N, V, dims]
        gallery_table = tf.gather(self.gallery_table, ids)
        exclude_feature_table = gallery_table[:, :self.max_volume, :] * tf.expand_dims(
            exclude_gallery_binit[:, :self.max_volume], axis=2)
        #[N, dims]
        feature_table = tf.reduce_sum(
            exclude_feature_table, axis=1) / tf.clip_by_value(
                tf.reduce_sum(exclude_gallery_binit, axis=1, keepdims=True), 1,
                1e10)
        self.table = tf.scatter_update(self.table, ids, feature_table)
        self.binit = tf.scatter_update(
            self.binit, ids,
            tf.reduce_max(exclude_gallery_binit[:, :self.max_volume], axis=1))

        indices = tf.stack([ids, loc], axis=1)
        self.gallery_table = tf.scatter_nd_update(self.gallery_table, indices,
                                                  features)
        self.gallery_binit = tf.scatter_nd_update(
            self.gallery_binit, indices, tf.ones_like(ids, dtype='float32'))
Example no. 2
        def _fn(i):
            bbmin_i = tf.gather(bbmin, i)
            bbmax_i = tf.gather(bbmax, i)
            verts_i = [tf.gather(verts[0], i),
                       tf.gather(verts[1], i),
                       tf.gather(verts[2], i)]

            x, y = tf.meshgrid(tf.range(bbmin_i[0], bbmax_i[0]),
                               tf.range(bbmin_i[1], bbmax_i[1]))

            num_frags = tf.reduce_prod(tf.shape(x))
            p = tf.stack([tf.reshape(x, [-1]),
                          tf.reshape(y, [-1]),
                          tf.zeros([num_frags], dtype=tf.float32)], axis=1)

            bc, valid = barycentric(verts_i, p)

            p = tf.boolean_mask(p, valid)
            bc = [tf.boolean_mask(bc[k], valid) for k in range(3)]
            z = utils.tri_dot([verts_i[k][2] for k in range(3)], bc)

            inds = tf.to_int32(tf.stack([p[:, 1], p[:, 0]], axis=1))
            cur_z = tf.gather_nd(self.depth, inds)
            visible = tf.less_equal(cur_z, z)

            inds = tf.boolean_mask(inds, visible)
            bc = [tf.boolean_mask(bc[k], visible) for k in range(3)]
            z = tf.boolean_mask(z, visible)

            c = utils.pack_colors(shader.fragment(bc, i), 1)

            updates = [
                tf.scatter_nd_update(self.color, inds, c, use_locking=False),
                tf.scatter_nd_update(self.depth, inds, z, use_locking=False)]
            return updates
Example no. 3
def unpool(prev_layer, pooling_layer, switches):

    print(switches[0, 0, 0, 0])

    switches = unravel_argmax(switches, prev_layer.get_shape().as_list())

    print(switches)
    print(switches.shape)
    print(switches[:, 0, 0, 0, 0])

    unpool = tf.Variable(initial_value=tf.zeros_like(prev_layer))
    pooling_layer_shape = pooling_layer.shape
    for instance in range(pooling_layer_shape[0]):
        for height in range(pooling_layer_shape[1]):
            for width in range(pooling_layer_shape[2]):
                for channel in range(pooling_layer_shape[3]):
                    # print(instance * height * width * channel)
                    index = switches[:, instance, height, width, channel]
                    max_value = pooling_layer[instance, height, width, channel]
                    tf.scatter_nd_update(
                        unpool,
                        tf.reverse(index, [-1]),
                        updates=tf.convert_to_tensor(max_value))

    return unpool
Example no. 4
 def update_centroids(self, centroid):
     """compute updated values for centroids as mean of assigned samples
         :param centroid - centroid to be updated
         :returns updated centroids Tensor"""
     sample = self.data_queue_2.dequeue()
     # update per centroid count
     per_centroid_count = tf.scatter_nd_add(self.samples_per_centroid,
                                            indices=[[centroid]],
                                            updates=[1],
                                            name="incrementPerCenterCount")
     # update per center learning rate
     with tf.control_dependencies([per_centroid_count]):
         learning_rate = tf.squeeze(
             tf.cast(1 / tf.slice(per_centroid_count, [centroid], [1]),
                     tf.float64))
         # learning_rate = tf.Print(learning_rate, [learning_rate], message="learning rate: ")
     update_lr = tf.scatter_nd_update(self.learning_rate, [[centroid]],
                                      [learning_rate], name="updateLearningRate")
     # compute new centroids
     updated_centroids = tf.scatter_nd_update(
         self.centroids,
         indices=[centroid],
         updates=tf.add(
             tf.scalar_mul(scalar=(1 - learning_rate),
                           x=tf.slice(input_=self.centroids,
                                      begin=[centroid, 0],
                                      size=[1, self.n_features])),
             tf.scalar_mul(scalar=learning_rate, x=sample)))
     with tf.control_dependencies([updated_centroids, update_lr]):
         return tf.identity(centroid)
Example no. 5
  def update_edge(self, edge: np.ndarray, change: float) -> None:
    """ The callback to receive notifications about edge changes in the graph.

     This method is called from the Graph when an addition or deletion is
     performed on the edge set, so it is probably necessary to recompute the
     transition matrix.

     Args:
       edge (:obj:`np.ndarray`): A 1-D `np.ndarray` that represents the edge that
         changes in the graph, where `edge[0]` is the source vertex, and
         `edge[1]` the destination vertex.
       change (float): The variation of the edge weight. If the final value is
         0.0 then the edge is removed.

     Returns:
       This method returns nothing.

     """

    if change > 0.0:
      self.run_tf(tf.scatter_nd_update(
        self.transition, [[edge[0]]],
        tf.div(self.G.A_tf_vertex(edge[0]),
               self.G.out_degrees_tf_vertex(edge[0]))))
    else:
      self.run_tf(tf.scatter_nd_update(
        self.transition, [[edge[0]]],
        tf.where(self.G.is_not_sink_tf_vertex(edge[0]),
                 tf.div(self.G.A_tf_vertex(edge[0]),
                        self.G.out_degrees_tf_vertex(edge[0])),
                 tf.fill([1, self.G.n], tf.pow(self.G.n_tf, -1)))))
    self._notify(edge, change)
Example no. 6
 def push(self, x):
     count = tf.shape(x)[0]
     indices = tf.random.uniform(minval=0,
                                 maxval=self._count,
                                 shape=[count],
                                 dtype=tf.int32)
     indices = tf.expand_dims(indices, 1)
     # returning the op lets the caller run the pool update
     return tf.scatter_nd_update(self._pool, indices, x)
Example no. 7
 def get_A(batch_item, R):
     I = tf.eye(int(R.shape[0]))
     R = tf.scatter_nd_update(R, list(zip(*pairs)),
                              tf.squeeze(batch_item))
     R = tf.scatter_nd_update(R, list(zip(*pairs[::-1])),
                              tf.squeeze(batch_item))
     D = tf.diag(tf.reduce_sum(R, axis=1))
     return I + D - R
Example no. 8
  def testBooleanScatterUpdate(self):
    with self.test_session(use_gpu=False) as session:
      var = tf.Variable([True, False])
      update0 = tf.scatter_nd_update(var, [[1]], [True])
      update1 = tf.scatter_nd_update(
          var, tf.constant(
              [[0]], dtype=tf.int64), [False])
      var.initializer.run()

      session.run([update0, update1])

      self.assertAllEqual([False, True], var.eval())
Example no. 9
    def testBooleanScatterUpdate(self):
        with self.test_session(use_gpu=False) as session:
            var = tf.Variable([True, False])
            update0 = tf.scatter_nd_update(var, [[1]], [True])
            update1 = tf.scatter_nd_update(var,
                                           tf.constant([[0]], dtype=tf.int64),
                                           [False])
            var.initializer.run()

            session.run([update0, update1])

            self.assertAllEqual([False, True], var.eval())
Example no. 10
  def testRank3InvalidShape2(self):
    indices = tf.zeros([2, 2, 1], tf.int32)
    updates = tf.zeros([2, 2], tf.int32)
    shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of output\\.shape="):
      tf.scatter_nd(indices, updates, shape)

    ref = tf.Variable(tf.zeros(shape, tf.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of ref\\.shape="):
      tf.scatter_nd_update(ref, indices, updates)
Example no. 11
    def testRank3InvalidShape2(self):
        indices = tf.zeros([2, 2, 1], tf.int32)
        updates = tf.zeros([2, 2], tf.int32)
        shape = np.array([2, 2, 2])
        with self.assertRaisesWithPredicateMatch(
                ValueError, "The inner \\d+ dimensions of output\\.shape="):
            tf.scatter_nd(indices, updates, shape)

        ref = tf.Variable(tf.zeros(shape, tf.int32))
        with self.assertRaisesWithPredicateMatch(
                ValueError, "The inner \\d+ dimensions of ref\\.shape="):
            tf.scatter_nd_update(ref, indices, updates)
Example no. 12
 def update_history_ops(self, batch_variables, batch_gradients_sign,
                        batch_variables_history, batch_grad_sign_history,
                        history_ptr):
     history_ops = []
     shape = batch_variables.shape[0].value
     indices = [[i, history_ptr] for i in range(shape)]
     history_ops.append(
         tf.scatter_nd_update(batch_variables_history, indices,
                              tf.reshape(batch_variables, [shape])))
     history_ops.append(
         tf.scatter_nd_update(batch_grad_sign_history, indices,
                              tf.reshape(batch_gradients_sign, [shape])))
     return history_ops
Example no. 13
    def initialize(self, layer_dims, optimizer, learner_target_inputs=None):
        def _make(flow):
            for i, size in enumerate(layer_dims):
                flow = fullyConnected(
                    "layer%i" % i, flow, size, tf.nn.relu)

            return fullyConnected(
                "output_layer", flow, self.action_dim)

        learner_target_inputs = learner_target_inputs or [
            self.state, self.state]
        with tf.variable_scope('learner'):
            self.action_value = _make(learner_target_inputs[0])
        with tf.variable_scope('target'):
            self.target_action_value = _make(learner_target_inputs[1])

        row = tf.range(tf.shape(self.action_value)[0])
        indexes = tf.stack([row, self.action_ph], axis=1)

        updated = tf.Variable([], trainable=False, validate_shape=False)
        updated = tf.assign(updated, self.action_value, validate_shape=False)
        action_value = tf.scatter_nd_update(
            updated, indexes, self.action_value_ph)

        self._loss = tf.losses.huber_loss(
            self.action_value, action_value)

        self.policy_action = tf.argmax(self.action_value, axis=1)
        self.update_op = copyScopeVars('learner', 'target')

        self.train_op = optimizer.minimize(
            self._loss, var_list=getScopeParameters('learner'))
Example no. 14
def update_windows(x, centers, updates, mask, system_shape, window_shape):
    """
    Update windows around centers with updates at rows where mask is True.

    Parameters
    ----------
    x : Variable tensor of shape (N,) + system_shape
    centers : tensor of shape (N, N_DIMS)
    updates : Tensor of shape (N,) + window_shape
    mask : boolean tensor of shape (N,)
    system_shape : tuple
    window_shape : tuple

    Returns
    -------
    x : Ref to updated variable
    """
    window_size = np.prod(window_shape)
    batch_size = tf.shape(x)[0]
    index_matrix = tf.constant(create_index_matrix(system_shape, window_shape))
    window_range = tf.range(batch_size, dtype=tf.int32)[:, None] * \
        tf.ones(window_size, dtype=tf.int32)[None, :]
    indices = tf.stack((window_range, tf.gather(index_matrix, centers)), 2)
    return tf.scatter_nd_update(x, tf.boolean_mask(indices, mask),
                                tf.boolean_mask(updates, mask))
Example no. 15
    def Add(self, tokens, state_vectors, attention_vectors):
        '''
        Adds the vectors to the cache.
        state_vectors: tensor of size (batch_size x ... x hidden_size)
        attention_vectors: tensor of size (batch_size x ... x hidden_size)
        tokens: tensor of size (batch_size x ...)

        return: tf.float32(0)
        '''

        indices, alphas = tf.py_func(self._AddPy, [tokens],
                                     (tf.int32, tf.float32))

        indices.set_shape((state_vectors.shape[0] * state_vectors.shape[1], 2))
        alphas.set_shape((state_vectors.shape[0] * state_vectors.shape[1], ))

        updates = tf.multiply(alphas[:, None], tf.gather_nd(self.state_tensor_, indices)) + \
            tf.multiply(1 - alphas[:, None], tf.reshape(state_vectors, (-1, tf.shape(state_vectors)[-1])))

        self.state_tensor_ = tf.scatter_nd_update(self.state_tensor_, indices,
                                                  updates)

        #self.state_tensor_[tf.range(self.batch_size_)[:, None], indeces] = \
        #    state_vectors * alphas[:, :, None] + \
        #    self.state_tensor_[tf.range(self.batch_size_)[
        #       :, None], indeces] * (1 - alphas[:, :, None])

        #self.attention_tensor_[tf.range(self.batch_size_)[:, None], indeces] = \
        #    attention_vectors * alphas[:, :, None] + \
        #    self.attention_tensor_[tf.range(self.batch_size_)[
        #        :, None], indeces] * (1 - alphas[:, :, None])
        return tf.cast(0 * tf.reduce_sum(self.state_tensor_), tf.float32)
Example no. 16
    def _if_not_empty_lexicon_state(self, i, j, char_inputs, state_inputs,
                                    char_inputs_indices_for_lexicon, state_inputs_indices_for_lexicon, new_c_in):
        new_c_with_lexicon = self._new_c_with_lexicon(i=i, j=j, char_inputs=char_inputs, state_inputs=state_inputs,
                                                      indices_tensor=state_inputs_indices_for_lexicon)
        new_c_out = tf.scatter_nd_update(new_c_in, indices=char_inputs_indices_for_lexicon, updates=new_c_with_lexicon)

        return new_c_out
Example no. 17
    def call(self, x):

        # mean
        one = tf.constant([1.0])
        sample = tf.shape(x)[0]
        sample_float = tf.cast(sample, 'float')

        partition = tf.divide(one, sample_float)
        xbar = K.transpose(x) - partition * tf.matmul(
            K.transpose(x), tf.ones([sample, sample]))
        R = tf.matmul(xbar, tf.transpose(xbar))
        Rs = tf.Variable(tf.zeros_like(R))
        indices = []
        values = []
        for i in range(self.N):
            for j in range(self.f):
                for k in range(self.f):
                    indices.append([j + i * self.f, k + i * self.f])
                    values.append(R[j + i * self.f, k + i * self.f])

        S = tf.scatter_nd_update(Rs, indices, values)
        T = tf.matmul(
            tf.linalg.inv(S + tf.constant([1e-6]) * tf.eye(self.f * self.N)),
            R - S)

        U, V = tf.linalg.eigh(T)
        U_sort, _ = tf.nn.top_k(U, 1)
        corr = K.sum(K.sqrt(U_sort))

        return -corr
Example no. 18
 def inner_loop(self, i, j, _):
     body = tf.cond(
         tf.greater(self.array[j - 1], self.array[j]),
         lambda: tf.scatter_nd_update(self.array, [[j - 1], [j]],
                                      [self.array[j], self.array[j - 1]]),
         lambda: self.array)
     return i, tf.subtract(j, 1), body
Example no. 19
def Gabor_conv(input, Theta, Lambda, name, kernel_size, in_channel,
               out_channel):
    input_shape = input.get_shape()
    batch_size = input_shape[0]

    res = tf.get_variable(
        name=name,
        shape=[input_shape[0], input_shape[1], input_shape[2], out_channel],
        initializer=tf.constant_initializer(0.0),
        trainable=False)

    # different Gabor conv for each image of batch
    for i in range(batch_size):
        img = Expand_dim_up(input=input[i], num=1)
        Theta_ = Theta[i]
        Lambda_ = Lambda[i]

        Gabor = Gabor_filter(Theta=Theta_,
                             Lambda=Lambda_,
                             size=kernel_size,
                             in_channel=in_channel,
                             out_channel=out_channel)

        img_Gabor_conv = tf.nn.conv2d(img,
                                      Gabor,
                                      strides=[1, 1, 1, 1],
                                      padding='SAME')
        tmp = tf.identity(img_Gabor_conv)

        indices = tf.constant([[i]])
        res = tf.scatter_nd_update(ref=res, indices=indices, updates=tmp)

    return res
Example no. 20
def rep_loss_func(
        inputs,
        output,
        **kwargs
        ):
    data_indx = output['data_indx']
    new_data_memory = output['new_data_memory']

    memory_bank_list = output['memory_bank']
    all_labels_list = output['all_labels']
    semi_psd_labels = output.get('semi_psd_labels', None)
    confidence = output.get('confidence', None)
    if isinstance(memory_bank_list, tf.Variable):
        memory_bank_list = [memory_bank_list]
        all_labels_list = [all_labels_list]
        semi_psd_labels = [semi_psd_labels]
        confidence = [confidence]

    new_semi_psd_labels = output.get('new_semi_psd_labels', None)
    new_conf = output.get('new_conf', None)

    devices = ['/gpu:%i' % idx for idx in range(len(memory_bank_list))]
    update_ops = []
    for device, memory_bank, all_labels \
            in zip(devices, memory_bank_list, all_labels_list):
        with tf.device(device):
            mb_update_op = tf.scatter_update(
                    memory_bank, data_indx, new_data_memory)
            update_ops.append(mb_update_op)
            lb_update_op = tf.scatter_update(
                    all_labels, data_indx,
                    inputs['label'])
            update_ops.append(lb_update_op)

            # Update the first label
            bs = data_indx.get_shape().as_list()[0]
            new_data_indx = tf.concat(
                    [tf.zeros([bs, 1], dtype=data_indx.dtype),
                     tf.expand_dims(data_indx, axis=1)],
                    axis=1)
            curr_idx = devices.index(device)
            update_ops.append(tf.scatter_nd_update(
                    semi_psd_labels[curr_idx],
                    new_data_indx, new_semi_psd_labels))
            if new_conf is not None:
                conf_update_op = tf.scatter_update(
                        confidence[curr_idx],
                        data_indx, new_conf)
                update_ops.append(conf_update_op)

    loss_pure = output['loss']
    with tf.control_dependencies(update_ops):
        # Force the updates to happen before the next batch.
        loss_pure = tf.identity(loss_pure)

    ret_dict = {'loss_pure': loss_pure}
    for key, item in output.items():
        if key.startswith('loss_'):
            ret_dict[key] = item
    return ret_dict
Example no. 21
    def attention_vocab(attention_weight, sentence_index):
        ''' return indices and updates for tf.scatter_nd_update

        Args:
            attention_weight : [batch, length]
            sentence_index : [batch, length]
        '''
        batch_size = attention_weight.get_shape()[0]
        sentence_length = attention_weight.get_shape()[-1]

        batch_index = tf.range(batch_size)
        batch_index = tf.expand_dims(batch_index, [1])
        batch_index = tf.tile(batch_index, [1, sentence_length])
        batch_index = tf.reshape(
            batch_index,
            [-1, 1])  # looks like [0,0,0,0,0,1,1,1,1,1,2,2,2,2,2,....]

        zeros = tf.zeros([batch_size, self._output_size])

        flat_index = tf.reshape(sentence_index, [-1, 1])
        indices = tf.concat([batch_index, flat_index], 1)

        updates = tf.reshape(attention_weight, [-1])

        p_attn = tf.scatter_nd_update(zeros, indices, updates)

        return p_attn
Example no. 22
    def margin_loss(self, features, label, centers, beta):
        # features should be L2-normalized beforehand so that beta is meaningful

        batchSize = tf.shape(features)[0]
        val = centers - tf.reshape(
            features, [tf.shape(features)[0], 1,
                       tf.shape(features)[1]])
        distance = tf.reduce_sum(tf.square(val), 2)
        var_distance = tf.Variable(0, name='temp', dtype=distance.dtype)
        var_distance = tf.assign(var_distance, distance, validate_shape=False)

        seq = tf.range(batchSize)
        zipper = tf.stack([seq, label], 1)
        # distance of each sample to its own center
        c_distance = tf.gather_nd(distance, zipper)
        # set each sample's own-center distance to MAX_FLOAT so argmin skips it
        var_distance = tf.scatter_nd_update(
            var_distance, zipper,
            tf.ones(batchSize, dtype=tf.float32) * np.finfo(np.float32).max)

        minIndexs = tf.cast(tf.argmin(var_distance, 1), tf.int32)
        minIndexs = tf.stack([seq, minIndexs], 1)
        # min distance from each feature to all centers except its own
        minValue = tf.gather_nd(var_distance, minIndexs)

        basic_loss = tf.add(tf.subtract(c_distance, minValue), beta)
        loss = tf.reduce_mean(tf.maximum(basic_loss, 0.0), 0)

        return loss
Example no. 23
def rewiring(theta, target_nb_connection, epsilon=1e-12):
    '''
    The rewiring operation to use after each iteration.
    :param theta:
    :param target_nb_connection:
    :return:
    '''

    with tf.name_scope('rewiring'):
        th = theta.read_value()
        is_con = tf.greater(th, 0)

        n_connected = tf.reduce_sum(tf.cast(is_con, tf.int32))
        nb_reconnect = target_nb_connection - n_connected
        nb_reconnect = tf.maximum(nb_reconnect,0)

        reconnect_candidate_coord = tf.where(tf.logical_not(is_con), name='CandidateCoord')

        n_candidates = tf.shape(reconnect_candidate_coord)[0]
        reconnect_sample_id = tf.random_shuffle(tf.range(n_candidates))[:nb_reconnect]
        reconnect_sample_coord = tf.gather(reconnect_candidate_coord, reconnect_sample_id, name='SelectedCoord')

        # Apply the rewiring
        reconnect_vals = tf.fill(dims=[nb_reconnect], value=epsilon, name='InitValues')
        reconnect_op = tf.scatter_nd_update(theta, reconnect_sample_coord, reconnect_vals, name='Reconnect')

        with tf.control_dependencies([reconnect_op]):
            connection_check = assert_connection_number(theta=theta, targeted_number=target_nb_connection)
            with tf.control_dependencies([connection_check]):
                return tf.no_op('Rewiring')
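A minimal usage sketch for the rewiring op above, assuming TF 1.x graph mode. The assert_connection_number helper is not shown in this listing, so a small stand-in with the same call signature is defined here purely so the sketch runs; it is not the original implementation.

import tensorflow as tf

def assert_connection_number(theta, targeted_number):
    # Stand-in (hypothetical): check that the number of positive entries matches the target.
    n_connected = tf.reduce_sum(
        tf.cast(tf.greater(theta.read_value(), 0), tf.int32))
    return tf.assert_equal(n_connected, targeted_number)

theta = tf.Variable([[0.5, 0.0], [0.0, -0.2]], dtype=tf.float32)
rewire_op = rewiring(theta, target_nb_connection=2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(rewire_op)      # one non-positive entry is reset to epsilon
    print(sess.run(theta))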
Example no. 24
def unit_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The Euclidean norm of each column is computed.
    - The indices of the smallest k% of columns, based on their Euclidean norms,
      are selected.
    - All elements in the columns that have the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The percentage of columns that should be pruned from the matrix.

    Returns:
        The unit-pruned weight matrix.

    """
    k = tf.cast(
        tf.round(tf.cast(tf.shape(w)[1], tf.float32) * tf.constant(k)), dtype=tf.int32
    )
    norm = tf.norm(w, axis=0)
    row_indices = tf.tile(tf.range(tf.shape(w)[0]), [k])
    _, col_indices = tf.nn.top_k(tf.negative(norm), k, sorted=True, name=None)
    col_indices = tf.reshape(
        tf.tile(tf.reshape(col_indices, [-1, 1]), [1, tf.shape(w)[0]]), [-1]
    )
    indices = tf.stack([row_indices, col_indices], axis=1)

    return w.assign(
        tf.scatter_nd_update(w, indices, tf.zeros([tf.shape(w)[0] * k], tf.float32))
    )
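A minimal usage sketch for unit_pruning, assuming TF 1.x graph mode and the function above; the 4x6 matrix and the 0.5 pruning ratio are illustrative values.

import numpy as np
import tensorflow as tf

w = tf.Variable(np.random.randn(4, 6).astype(np.float32))
prune_op = unit_pruning(w, k=0.5)  # zero the 50% of columns with the smallest L2 norm

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    pruned = sess.run(prune_op)
    print(int((np.abs(pruned).sum(axis=0) == 0).sum()), "columns zeroed")  # expect 3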
Example no. 25
    def call(self, inputs, state):
        char_inputs = inputs[0]
        state_inputs = inputs[1]

        check_state_0 = tf.reduce_sum(state_inputs, axis=-1)
        check_state_1 = tf.reduce_sum(check_state_0, axis=-1)
        state_inputs_indices_for_lexicon = tf.where(
            tf.not_equal(check_state_0, 0))
        state_inputs_indices_for_not_lexicon = tf.squeeze(
            tf.where(tf.equal(check_state_1, 0)))

        state_inputs_indices_for_not_lexicon = tf.cond(
            pred=tf.equal(tf.rank(state_inputs_indices_for_not_lexicon), 0),
            true_fn=lambda: tf.expand_dims(
                state_inputs_indices_for_not_lexicon, axis=0),
            false_fn=lambda: state_inputs_indices_for_not_lexicon)

        char_inputs_indices_for_lexicon = tf.where(
            tf.not_equal(tf.reduce_sum(check_state_0, axis=-1), 0))
        char_inputs_indices_for_not_lexicon = tf.where(
            tf.equal(tf.reduce_sum(check_state_0, axis=-1), 0))

        if self._state_is_tuple:
            c, h = state
        else:
            c, h = tf.split(value=state, num_or_size_splits=2, axis=1)

        gate_inputs = tf.matmul(tf.concat([char_inputs, h], 1), self._kernel)
        gate_inputs = tf.nn.bias_add(gate_inputs, self._bias)

        i, j, f, o = tf.split(value=gate_inputs, num_or_size_splits=4, axis=1)

        new_c_without_lexicon = self._new_c_without_lexicon(
            i=i,
            f=f,
            j=j,
            c=c,
            indices_tensor=state_inputs_indices_for_not_lexicon)
        new_c = tf.scatter_nd_update(
            self._char_state_tensor,
            indices=char_inputs_indices_for_not_lexicon,
            updates=new_c_without_lexicon)

        new_c = tf.cond(tf.not_equal(
            tf.shape(state_inputs_indices_for_not_lexicon)[-1],
            tf.shape(state_inputs)[0]),
                        true_fn=lambda: self._if_not_empty_lexicon_state(
                            i, j, char_inputs, state_inputs,
                            char_inputs_indices_for_lexicon,
                            state_inputs_indices_for_lexicon, new_c),
                        false_fn=lambda: new_c)

        new_h = tf.multiply(self._activation(new_c), tf.nn.sigmoid(o))

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = tf.concat([new_c, new_h], 1)

        return new_h, new_state
Example no. 26
def weight_pruning(w: tf.Variable, k: float) -> tf.Variable:
    """Performs pruning on a weight matrix w in the following way:

    - The absolute values of all elements in the weight matrix are computed.
    - The indices of the smallest k% of elements, based on their absolute values,
      are selected.
    - All elements with the matching indices are set to 0.

    Args:
        w: The weight matrix.
        k: The percentage of weights that should be pruned from the matrix.

    Returns:
        The weight-pruned weight matrix.

    """
    k = tf.cast(
        tf.round(tf.size(w, out_type=tf.float32) * tf.constant(k)), dtype=tf.int32
    )
    w_reshaped = tf.reshape(w, [-1])
    _, indices = tf.nn.top_k(tf.negative(tf.abs(w_reshaped)), k, sorted=True, name=None)
    mask = tf.scatter_nd_update(
        tf.Variable(
            tf.ones_like(w_reshaped, dtype=tf.float32), name="mask", trainable=False
        ),
        tf.reshape(indices, [-1, 1]),
        tf.zeros([k], tf.float32),
    )

    return w.assign(tf.reshape(w_reshaped * mask, tf.shape(w)))
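Similarly, a minimal sketch for weight_pruning under the same TF 1.x assumptions; the 3x4 matrix and k=0.25 are illustrative. The function creates an internal "mask" variable, so all variables are initialized after the pruning op has been built.

import numpy as np
import tensorflow as tf

w = tf.Variable(np.random.randn(3, 4).astype(np.float32))
prune_op = weight_pruning(w, k=0.25)  # zero the 25% smallest-magnitude weights

with tf.Session() as sess:
    sess.run(w.initializer)                      # initialize w before the mask variable
    sess.run(tf.global_variables_initializer())  # also initializes the internal "mask"
    pruned = sess.run(prune_op)
    print(int((pruned == 0).sum()), "weights zeroed")  # expect 3 of 12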
Example no. 27
  def append(self, transitions, rows=None):
    """Append a batch of transitions to rows of the memory.

    Args:
      transitions: Tuple of transition quantities with batch dimension.
      rows: Episodes to append to, defaults to all.

    Returns:
      Operation.
    """
    rows = tf.range(self._capacity) if rows is None else rows
    assert rows.shape.ndims == 1
    assert_capacity = tf.assert_less(
        rows, self._capacity,
        message='capacity exceeded')
    with tf.control_dependencies([assert_capacity]):
      assert_max_length = tf.assert_less(
          tf.gather(self._length, rows), self._max_length,
          message='max length exceeded')
    with tf.control_dependencies([assert_max_length]):
      timestep = tf.gather(self._length, rows)
      indices = tf.stack([rows, timestep], 1)
      append_ops = tools.nested.map(
          lambda var, val: tf.scatter_nd_update(var, indices, val),
          self._buffers, transitions, flatten=True)
    with tf.control_dependencies(append_ops):
      episode_mask = tf.reduce_sum(tf.one_hot(
          rows, self._capacity, dtype=tf.int32), 0)
      return self._length.assign_add(episode_mask)
Example no. 28
  def append(self, transitions, rows=None):
    """Append a batch of transitions to rows of the memory.

    Args:
      transitions: Tuple of transition quantities with batch dimension.
      rows: Episodes to append to, defaults to all.

    Returns:
      Operation.
    """
    rows = tf.range(self._capacity) if rows is None else rows
    assert rows.shape.ndims == 1
    assert_capacity = tf.assert_less(
        rows, self._capacity,
        message='capacity exceeded')
    with tf.control_dependencies([assert_capacity]):
      assert_max_length = tf.assert_less(
          tf.gather(self._length, rows), self._max_length,
          message='max length exceeded')
    append_ops = []
    with tf.control_dependencies([assert_max_length]):
      for buffer_, elements in zip(self._buffers, transitions):
        timestep = tf.gather(self._length, rows)
        indices = tf.stack([rows, timestep], 1)
        append_ops.append(tf.scatter_nd_update(buffer_, indices, elements))
    with tf.control_dependencies(append_ops):
      episode_mask = tf.reduce_sum(tf.one_hot(
          rows, self._capacity, dtype=tf.int32), 0)
      return self._length.assign_add(episode_mask)
Example no. 29
def tf_inverse_flow(flow_input, b, h, w):

    # x = vertical (channel 0 in flow), y = horizontal (channel 1 in flow)
    flow_list = tf.unstack(flow_input)

    x, y = tf.meshgrid(tf.range(h), tf.range(w), indexing='ij')

    x = tf.expand_dims(x, -1)
    y = tf.expand_dims(y, -1)

    grid = tf.cast(tf.concat([x, y], axis=-1), tf.float32)

    for r in range(b):

        flow = flow_list[r]
        grid1 = grid + flow

        x1, y1 = tf.split(grid1, [1, 1], axis=-1)
        x1 = tf.clip_by_value(x1, 0, h - 1)
        y1 = tf.clip_by_value(y1, 0, w - 1)
        grid1 = tf.concat([x1, y1], axis=-1)
        grid1 = tf.cast(grid1, tf.int32)

        tf_zeros = tf.zeros([h, w, 1, 1], np.int32)
        indices = tf.expand_dims(grid1, 2)
        indices = tf.concat([indices, tf_zeros], axis=-1)

        flow_x, flow_y = tf.split(flow, [1, 1], axis=-1)

        ref_x = tf.Variable(np.zeros([h, w, 1], np.float32),
                            trainable=False,
                            dtype=tf.float32)
        ref_y = tf.Variable(np.zeros([h, w, 1], np.float32),
                            trainable=False,
                            dtype=tf.float32)
        inv_flow_x = tf.scatter_nd_update(ref_x, indices, -flow_x)
        inv_flow_y = tf.scatter_nd_update(ref_y, indices, -flow_y)
        inv_flow_batch = tf.expand_dims(tf.concat([inv_flow_x, inv_flow_y],
                                                  axis=-1),
                                        axis=0)

        if r == 0:
            inv_flow = inv_flow_batch
        else:
            inv_flow = tf.concat([inv_flow, inv_flow_batch], axis=0)

    return inv_flow
Example no. 30
def _to_affine_transform_matrix(origin=(0.0, 0.0),
                                trans=(0.0, 0.0),
                                rot=0.0,
                                scale=(1.0, 1.0),
                                shear=(0.0, 0.0)):
    """Create a 3x3 affine transformation matrix from transformation parameters.
    The transformation is applied in the following order: Shear - Scale - Rotate - Translate 
    origin: (x, y). Transformation will take place centered around this pixel location.
    trans: (tx, ty). Translation vector
    rot: theta. Rotation angle
    scale: (sx, sy). Scale (zoom) in x and y directions.
    Shear: (hx, hy). Shear in x and y directions.
    
    Returns:
        M: [3, 3] tensor.
    """
    # Rotation matrix
    #     R = [[ cos(theta)  -sin(theta)   0 ]
    #          [ sin(theta)   cos(theta)   0 ]
    #          [     0            0        1 ]]
    R = tf.Variable(lambda: tf.zeros((3, 3)), dtype=tf.float32)
    cos = tf.cast(tf.cos(rot), tf.float32)
    sin = tf.cast(tf.sin(rot), tf.float32)
    # keep the updated ref so the matmuls below depend on the scatter update
    R = tf.scatter_nd_update(R, [[2, 2], [0, 0], [1, 1], [0, 1], [1, 0]],
                             [1, cos, cos, -sin, sin])

    # Scale and shear
    #         [[ sx  0   0 ]    [[ 1   hx  0 ]    [[  sx   sx*hx   0 ]
    #     S =  [ 0   sy  0 ]  *  [ hy  1   0 ]  =  [ sy*hy   sy    0 ]
    #          [ 0   0   1 ]]    [ 0   0   1 ]]    [   0      0    1 ]]
    S = tf.Variable(lambda: tf.zeros((3, 3)), dtype=tf.float32)
    S = tf.scatter_nd_update(
        S, [[2, 2], [0, 0], [1, 1], [0, 1], [1, 0]],
        [1, scale[0], scale[1], scale[0] * shear[0], scale[1] * shear[1]])

    # Coordinate transform: shifting the origin from (0,0) to (x, y)
    #     T = [[ 1   0  -x ]
    #          [ 0   1  -y ]
    #          [ 0   0   1 ]]
    M = tf.Variable(lambda: tf.zeros((3, 3)), dtype=tf.float32)
    M = tf.scatter_nd_update(M, [[0, 0], [1, 1], [2, 2], [0, 2], [1, 2]],
                             [1, 1, 1, -origin[0], -origin[1]])

    # Translation matrix + shift the origin back to (0,0)
    #     T = [[ 1   0   tx + x ]
    #          [ 0   1   ty + y ]
    #          [ 0   0      1   ]]
    T = tf.Variable(lambda: tf.zeros((3, 3)), dtype=tf.float32)
    T = tf.scatter_nd_update(T, [[0, 0], [1, 1], [2, 2], [0, 2], [1, 2]],
                             [1, 1, 1, trans[0] + origin[0], trans[1] + origin[1]])

    # Combine transformations
    M = tf.matmul(S, M)
    M = tf.matmul(R, M)
    M = tf.matmul(T, M)

    return M
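A minimal sketch exercising _to_affine_transform_matrix above in TF 1.x graph mode; the 30-degree rotation about pixel (32, 32) is an illustrative choice.

import numpy as np
import tensorflow as tf

M = _to_affine_transform_matrix(origin=(32.0, 32.0), rot=np.deg2rad(30.0))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(M))  # 3x3 matrix composing shear, scale, rotation, translation about (32, 32)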
Example no. 31
 def add_dense(self,
               kernel,
               filters,
               strides,
               activation):
     input_shape = numpy.prod(kernel)
     output_shape = numpy.prod(filters)
     w_shape = input_shape * (input_shape + 1) // 2
     with tensorflow.variable_scope('layer_{0}'.format(len(self.outputs))):
         W = tensorflow.get_variable('weights', (w_shape,),
                             initializer=tensorflow.constant_initializer(1.0))
         s = 2 * W / (W ** 2 + 1)
         c = (1 - W ** 2) / (1 + W ** 2)
         j = 1
         i = 0
         final_shape = max(input_shape, output_shape)
         final_rotation = tensorflow.eye(final_shape)
         for c0, s0 in zip(tensorflow.unstack(c), tensorflow.unstack(s)):
             if i == j or i >= output_shape:
                 j += 1
                 i = 0
             if j >= output_shape:
                 break
             givens_matrix = tensorflow.SparseTensor(indices=[[i, i], [j, j], [i, j], [j, i]],
                             values=tensorflow.stack([c0, c0, -s0, s0]),
                                                     dense_shape=[final_shape, final_shape])
             final_rotation = tensorflow.sparse_tensor_dense_matmul(givens_matrix,
                                                                    final_rotation)
             i += 1
         final_rotation = tensorflow.slice(final_rotation, [0,] * 2, [output_shape, input_shape])
         bias = tensorflow.get_variable('bias_{0}'.format(len(self.outputs)),
                             output_shape,
                             initializer=tensorflow.constant_initializer(0.0))
         
         # convolution
         padded = tensorflow.pad(self.outputs[-1], [[k // 2, k // 2] for k in kernel])
         output_shape = [dim // stride for dim, stride in zip(self.outputs[-1].shape, strides)] + filters
         output = tensorflow.Variable(tensorflow.zeros(shape=output_shape, dtype=tensorflow.float32))
         for i in itertools.product(*[range(0, dim, stride) for dim, stride in zip(self.outputs[-1].shape, strides)]):
             k = [i0 // s for i0, s in zip(i, strides)]
             f = tensorflow.matmul(final_rotation,
                                   tensorflow.reshape(tensorflow.slice(padded,
                                                                       i,
                                                                       kernel),
                                                      [input_shape, 1]))
             tensorflow.scatter_nd_update(output, k, tensorflow.squeeze(f))
         self.outputs.append(activation(output + bias))
Example no. 32
def nmap(xs,
         s,
         coords,
         scope,
         nplayers=2,
         n=15,
         feat=32,
         init_scale=1.0,
         act=tf.nn.relu,
         reuse=False):
    nbatch, nin = [v.value for v in xs[0].get_shape()]
    # s should be of shape [nbatch, n, n]
    with tf.variable_scope(scope, reuse=reuse):
        # global read op
        r1 = conv(s, "conv-r1", nf=32, rf=3, stride=1, init_scale=np.sqrt(2))
        r2 = conv(r1, "conv-r2", nf=32, rf=3, stride=1, init_scale=np.sqrt(2))
        r3 = conv(r2, "conv-r3", nf=32, rf=3, stride=1, init_scale=np.sqrt(2))
        r3 = conv_to_fc(r3)
        r4 = fc(r3, "fc-r4", nh=256, init_scale=init_scale)
        r5 = fc(r4, "glob-r", nh=feat, init_scale=init_scale)

    _reuse = reuse
    ws = []

    for i, (x, coord) in enumerate(zip(xs, coords)):
        """
        q = fc(tf.concat([x, r5], axis=1), "c-query", nh=feat, init_scale=init_scale, reuse=_reuse)
        q = tf.reshape(q, [-1, 1, 1, feat])
        alpha = tf.reshape(tf.nn.softmax(tf.reduce_sum(s * q, axis=3)), [-1, n, n, 1])
        c = tf.reduce_sum(s * alpha, axis=[1, 2])
        """
        # key-val retrieval
        q = fc(tf.concat([x, r5], axis=1),
               "c-query",
               nh=feat // 2,
               init_scale=np.sqrt(2),
               reuse=_reuse)
        k, v = tf.split(axis=3, num_or_size_splits=2, value=s)
        q = tf.reshape(q, [-1, 1, 1, feat // 2])
        alpha = tf.reshape(tf.nn.softmax(tf.reduce_sum(k * q, axis=3)),
                           [-1, n, n, 1])
        c = tf.reduce_sum(v * alpha, axis=[1, 2])

        # write to position a, b
        batch_idx = tf.range(0, nbatch)
        batch_idx = tf.reshape(batch_idx, [-1, 1])
        coord = tf.concat([batch_idx, coord], axis=1)
        s_coord = tf.gather_nd(s, indices=coord)
        w = fc(tf.concat([x, r5, c, s_coord], axis=1),
               "write",
               nh=feat,
               init_scale=np.sqrt(2),
               reuse=_reuse)
        mem_new = tf.scatter_nd_update(s, indices=coord, updates=w)
        _reuse = True
        ws.append(w)
        xs[i] = c

    return xs, r5, ws, mem_new
Example no. 33
def unpool_with_with_argmax(pooled,
                            ind,
                            input_shape,
                            ksize=[1, 2, 2, 1],
                            name='unpool'):
    """
  https://github.com/sangeet259/tensorflow_unpooling
    To unpool the tensor after max_pool_with_argmax.
    Arguments:
        pooled:    the max pooled output tensor
        ind:       argmax indices, the second output of max_pool_with_argmax
        ksize:     ksize should be the same as what you used to pool
    Returns:
        unpooled:      the tensor after unpooling
    Some points to keep in mind:
        1. In tensorflow the indices in argmax are flattened, so that a maximum value at position [b, y, x, c] 
           becomes flattened index ((b * height + y) * width + x) * channels + c
        2. Due to point 1, use broadcasting to appropriately place the values at their right locations ! 
  """
    with tf.name_scope(name) as scope:
        # Get the shape of the tensor in the form of a list
        #input_shape = pooled.get_shape().as_list()

        # Determine the output shape
        output_shape = (input_shape[0], input_shape[1] * ksize[1],
                        input_shape[2] * ksize[2], input_shape[3])
        # Reshape into one giant tensor for better workability
        pooled_ = tf.reshape(pooled, [
            input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3]
        ])
        # The indices in argmax are flattened, so that a maximum value at position [b, y, x, c] becomes flattened index ((b * height + y) * width + x) * channels + c
        # Create a single-unit extended cuboid of length batch_size, populating it with consecutive natural numbers from zero to batch_size
        tmp_shape = np.array([input_shape[0], 1, 1, 1], dtype=np.int64)
        batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64),
                                          dtype=ind.dtype),
                                 shape=tmp_shape)
        b = tf.ones_like(ind) * batch_range
        b_ = tf.reshape(b, [
            input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3],
            1
        ])
        ind_ = tf.reshape(ind, [
            input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3],
            1
        ])
        ind_ = tf.concat([b_, ind_], 1)
        ref = tf.Variable(
            tf.zeros([
                output_shape[0],
                output_shape[1] * output_shape[2] * output_shape[3]
            ]))
        # Update the sparse matrix with the pooled values , it is a batch wise operation
        unpooled_ = tf.scatter_nd_update(ref, ind_, pooled_)
        # Reshape the vector to get the final result
        unpooled = tf.reshape(unpooled_, [
            output_shape[0], output_shape[1], output_shape[2], output_shape[3]
        ])
        return (unpooled)
Example no. 34
  def testRank3ValidShape(self):
    indices = tf.zeros([2, 2, 2], tf.int32)
    updates = tf.zeros([2, 2, 2], tf.int32)
    shape = np.array([2, 2, 2])
    self.assertAllEqual(
        tf.scatter_nd(indices, updates, shape).get_shape().as_list(), shape)

    ref = tf.Variable(tf.zeros(shape, tf.int32))
    self.assertAllEqual(
        tf.scatter_nd_update(ref, indices, updates).get_shape().as_list(),
        shape)
Example no. 35
  def testExtraIndicesDimensions(self):
    indices = tf.zeros([1, 1, 2], tf.int32)
    updates = tf.zeros([1, 1], tf.int32)
    shape = np.array([2, 2])
    scatter = tf.scatter_nd(indices, updates, shape)
    self.assertAllEqual(scatter.get_shape().as_list(), shape)
    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      self.assertAllEqual(expected_result, scatter.eval())

    ref = tf.Variable(tf.zeros(shape, tf.int32))
    scatter_update = tf.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

    with self.test_session():
      ref.initializer.run()
      self.assertAllEqual(expected_result, scatter_update.eval())
Example no. 36
  def call(self, x, mask=None):
    """Execute this layer on input tensors.

    x = [atom_features, parents, calculation_orders, membership]
    
    Parameters
    ----------
    x: list
      list of Tensors of form described above.
    mask: bool, optional
      Ignored. Present only to shadow superclass call() method.

    Returns
    -------
    outputs: tf.Tensor
      Tensor of atom features, of shape (n_atoms, n_graph_feat)
    """
    # Add trainable weights
    self.build()

    # Extract atom_features
    # Basic features of every atom: (batch_size*max_atoms) * n_atom_features
    atom_features = x[0]

    # calculation orders of graph: (batch_size*max_atoms) * max_atoms * max_atoms
    # each atom corresponds to a graph, which is represented by a `max_atoms*max_atoms` int32 matrix of indices
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    # step i calculates the graph features for atoms of index `parents[:,i,0]`
    parents = x[1]

    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    # represents the same atoms as `parents[:, :, 0]`,
    # except that these indices are positions in `atom_features`,
    # padded with max_atoms*batch_size
    calculation_orders = x[2]
    # flags: (batch_size*max_atoms)
    # 0 for paddings, 1 for real atoms
    membership = x[3]
    # number of atoms in total, should equal `batch_size*max_atoms`
    n_atoms = atom_features.get_shape()[0]

    # initialize graph features for each graph
    # another row of zeros is generated for padded dummy atoms
    graph_features = tf.Variable(
        tf.constant(0., shape=(n_atoms, self.max_atoms + 1, self.n_graph_feat)),
        trainable=False)
    # add dummy
    atom_features = tf.concat(
        axis=0,
        values=[
            atom_features, tf.constant(0., shape=(1, self.n_atom_features))
        ])
    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      batch_atom_features = tf.gather(atom_features,
                                      calculation_orders[:, count])

      # generating index for graph features used in the inputs
      index = tf.stack(
          [
              tf.reshape(
                  tf.stack([tf.range(n_atoms)] * (self.max_atoms - 1), axis=1),
                  [-1]), tf.reshape(parents[:, count, 1:], [-1])
          ],
          axis=1)
      # extracting graph features for parents of the target atoms, then flatten
      # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, index),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])

      # concat into the input tensor: (batch_size*max_atoms) * n_inputs
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # DAGgraph_step maps from batch_inputs to a batch of graph_features
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = self.DAGgraph_step(batch_inputs, self.W_list, self.b_list)

      # index for target atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      # index for dummies
      target_index2 = tf.stack(
          [tf.range(n_atoms), tf.constant(self.max_atoms, shape=(n_atoms,))],
          axis=1)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_index,
                                            batch_outputs)
      # recover dummies to zeros if being updated
      graph_features = tf.scatter_nd_update(graph_features, target_index2,
                                            tf.zeros(
                                                (n_atoms, self.n_graph_feat)))

    # last step generates graph features for all target atoms
    # masking the outputs
    outputs = tf.multiply(batch_outputs,
                          tf.expand_dims(tf.to_float(membership), axis=1))
    return outputs
Example no. 37
  def create_tensor(self, in_layers=None, set_tensors=True, **kwargs):
    """
    parent layers: atom_features, parents, calculation_orders, calculation_masks, n_atoms
    """
    if in_layers is None:
      in_layers = self.in_layers
    in_layers = convert_to_layers(in_layers)

    # Add trainable weights
    self.build()

    atom_features = in_layers[0].out_tensor
    # each atom corresponds to a graph, which is represented by a `max_atoms*max_atoms` int32 matrix of indices
    # each graph includes `max_atoms` steps (corresponding to rows) of calculating graph features
    parents = in_layers[1].out_tensor
    # target atoms for each step: (batch_size*max_atoms) * max_atoms
    calculation_orders = in_layers[2].out_tensor
    calculation_masks = in_layers[3].out_tensor

    n_atoms = in_layers[4].out_tensor
    # initialize graph features for each graph
    graph_features_initial = tf.zeros((self.max_atoms * self.batch_size,
                                       self.max_atoms + 1, self.n_graph_feat))
    # initialize graph features for each graph
    # another row of zeros is generated for padded dummy atoms
    graph_features = tf.Variable(graph_features_initial, trainable=False)

    for count in range(self.max_atoms):
      # `count`-th step
      # extracting atom features of target atoms: (batch_size*max_atoms) * n_atom_features
      mask = calculation_masks[:, count]
      current_round = tf.boolean_mask(calculation_orders[:, count], mask)
      batch_atom_features = tf.gather(atom_features, current_round)

      # generating index for graph features used in the inputs
      index = tf.stack(
          [
              tf.reshape(
                  tf.stack(
                      [tf.boolean_mask(tf.range(n_atoms), mask)] *
                      (self.max_atoms - 1),
                      axis=1), [-1]),
              tf.reshape(tf.boolean_mask(parents[:, count, 1:], mask), [-1])
          ],
          axis=1)
      # extracting graph features for parents of the target atoms, then flatten
      # shape: (batch_size*max_atoms) * [(max_atoms-1)*n_graph_features]
      batch_graph_features = tf.reshape(
          tf.gather_nd(graph_features, index),
          [-1, (self.max_atoms - 1) * self.n_graph_feat])

      # concat into the input tensor: (batch_size*max_atoms) * n_inputs
      batch_inputs = tf.concat(
          axis=1, values=[batch_atom_features, batch_graph_features])
      # DAGgraph_step maps from batch_inputs to a batch of graph_features
      # of shape: (batch_size*max_atoms) * n_graph_features
      # representing the graph features of target atoms in each graph
      batch_outputs = self.DAGgraph_step(batch_inputs, self.W_list, self.b_list,
                                         **kwargs)

      # index for target atoms
      target_index = tf.stack([tf.range(n_atoms), parents[:, count, 0]], axis=1)
      target_index = tf.boolean_mask(target_index, mask)
      # update the graph features for target atoms
      graph_features = tf.scatter_nd_update(graph_features, target_index,
                                            batch_outputs)

    out_tensor = batch_outputs
    if set_tensors:
      self.variables = self.trainable_weights
      self.out_tensor = out_tensor
    return out_tensor