def build_L2H_or_WS_or_H2L(self, which_module, prev_effect):
        '''Build one hierarchical propagation stage and sum its effects.

        Supported stages (name used in the paper in parentheses):

        L2H - Lower to higher hierarchy propagation (L2A in the paper)
        WG - Within group (WS in the paper)
        H2L - Higher hierarchy to lower propagation (A2D in the paper)

        Args:
            which_module: One of 'L2H', 'WG' or 'H2L'. Selects the
                receiver/sender getter, the state getter, and an index (0, 1
                or 2) used both as the one-hot slot in the relation attributes
                and to pick the entry of ``self.phiR_list`` /
                ``self.reuse_list`` for the MLP.
            prev_effect: Effects from the previous stage; rows are gathered
                from it with ``tf.gather_nd`` at the sender indices ``Rs`` and
                reshaped to ``[-1, self.de]``.

        Returns:
            ``sum_effects(raw_effects, Rr, self.bs, self.no)``, annotated by
            the original author as shape (BS, NO, DE).

        Raises:
            ValueError: if ``which_module`` is not 'L2H', 'WG' or 'H2L'.
        '''
        if which_module == 'L2H':
            func_get_receiver_sender = self._get_L2H_receiver_sender
            func_get_states = self._get_L2H_states
            which_one_hot = 0
        elif which_module == 'WG':
            func_get_receiver_sender = self._get_WS_receiver_sender
            func_get_states = self._get_WS_states
            which_one_hot = 1
        elif which_module == 'H2L':
            func_get_receiver_sender = self._get_H2L_receiver_sender
            func_get_states = self._get_H2L_states
            which_one_hot = 2
        else:
            # Without this branch an unknown key only surfaces later as an
            # opaque UnboundLocalError on func_get_receiver_sender.
            raise ValueError(
                "which_module must be one of 'L2H', 'WG', 'H2L'; got %r"
                % (which_module,))

        Rr, Rs = func_get_receiver_sender()
        state_Rr, state_Rs = func_get_states(Rr, Rs)

        effect_Rs = tf.reshape(tf.gather_nd(prev_effect, Rs), [-1, self.de])
        # 3-wide one-hot per relation; the hot slot identifies the stage.
        extra_attribute = tf.reshape(
            one_hot_column(effect_Rs, ndims=3, axis=which_one_hot), [-1, 3])

        # Add stiffness to Ra (the relation attributes) if needed
        if self.vary_stiff == 1:
            curr_stiff = tf.reshape(tf.gather_nd(self.inp_stiff, Rs),
                                    [-1, self.sl])
            extra_attribute = tf.concat([extra_attribute, curr_stiff], axis=-1)

        # Add ground truth distance to Ra if needed
        if self.add_dist == 1:
            # Full receiver index followed by the sender index without its
            # first column (presumably the shared batch index -- confirm).
            indx_dist = tf.concat([Rr, Rs[:, 1:]], axis=-1)
            curr_dist = tf.reshape(tf.gather_nd(self.all_dist_ts, indx_dist),
                                   [-1, 1])
            extra_attribute = tf.concat([extra_attribute, curr_dist], axis=-1)

        mlp_input = tf.concat([state_Rr, state_Rs, effect_Rs, extra_attribute],
                              axis=1)

        # Each stage has its own phiR network / weight-reuse flag.
        raw_effects = hidden_mlp(mlp_input,
                                 self.model_builder,
                                 self.cfg,
                                 self.phiR_list[which_one_hot],
                                 reuse_weights=self.reuse_list[which_one_hot],
                                 train=not self.my_test,
                                 debug=self.debug)

        summed_effects = sum_effects(raw_effects, Rr, self.bs,
                                     self.no)  # (BS, NO, DE)

        return summed_effects
    def _get_phiF(self):
        """Build the external-force / gravity network ``phiF``.

        Prepares the force inputs and, when ``self.use_actions == 1``, adds
        the action terms. It then concatenates ``self.F_Rr``, ``self.F_E`` and
        ``self.F_Ra`` along axis 1 and passes the result through
        ``hidden_mlp``. The output is stored on ``self.PhiF``; nothing is
        returned.
        """
        self.__prepare_for_phiF()
        if self.use_actions == 1:
            self.__add_actions()

        force_parts = [self.F_Rr, self.F_E, self.F_Ra]
        force_input = tf.concat(force_parts, axis=1)

        self.PhiF = hidden_mlp(
            force_input,
            self.model_builder,
            self.cfg,
            'phiF',
            reuse_weights=False,
            train=not self.my_test,
            debug=self.debug,
        )
 def _get_phiH(self):
     """Build the single-particle history network and store it on ``self.PhiH``.

     For legacy reasons this network is configured and named ``'phiS'``.

     Takes the coordinates where ``self.leaf_flag`` is true (``tf.where``),
     gathers their states with local delta positions, and appends a 4-wide
     column from ``one_hot_column(S_S, 4, 0)``. The result is passed through
     ``hidden_mlp``. The selected indices are also stored on ``self.Rr_S``.
     Nothing is returned.

     Raises:
         ValueError: if ``self.cfg`` has no ``'phiS'`` entry.
     """
     # Explicit check rather than ``assert``: asserts are stripped under
     # ``python -O``, which would defer the failure into hidden_mlp.
     if 'phiS' not in self.cfg:
         raise ValueError("Please set the network shape in cfg! "
                          "(missing 'phiS')")
     # tf.where returns int64 coordinates; downstream indexing uses int32.
     Rr_S = tf.cast(tf.where(self.leaf_flag), tf.int32)
     S_S = self.__get_states(Rr_S, local_delta_pos=True)
     # NOTE(review): 4-wide one-hot with the hot slot at 0 -- presumably a
     # relation-type marker for the history network; confirm.
     S_Ra = tf.reshape(one_hot_column(S_S, 4, 0), [-1, 4])
     S_S = tf.concat([S_S, S_Ra], axis=1)
     self.PhiH = hidden_mlp(S_S,
                            self.model_builder,
                            self.cfg,
                            'phiS',
                            reuse_weights=False,
                            train=not self.my_test,
                            debug=self.debug)
     self.Rr_S = Rr_S
    def get_predictions(self, E):
        """Run the per-object network ``phiO`` and return its reshaped output.

        Object states are built from ``self.state_pos`` + ``self.state_mass``
        and ``self.state_local_delta_pos``, each reshaped to
        ``[self.bs, self.no, -1]``. They are concatenated with
        ``self.state_gravity`` and ``E`` along axis 2. Objects are then
        flattened into the batch dimension for the MLP.

        Args:
            E: Aggregated effects; presumably shaped (bs, no, de) -- TODO
                confirm against callers.

        Returns:
            The ``phiO`` output reshaped to ``[self.bs, self.no, self.dp]``.
        """
        batch, n_obj = self.bs, self.no

        pos_and_mass = tf.reshape(
            tf.concat([self.state_pos, self.state_mass], axis=-1),
            [batch, n_obj, -1])
        local_delta = tf.reshape(self.state_local_delta_pos,
                                 [batch, n_obj, -1])
        object_states = tf.concat([pos_and_mass, local_delta], axis=-1)

        # The feature width is spelled out (ds + de + dx) instead of -1.
        feature_width = self.ds + self.de + self.dx
        mlp_input = tf.reshape(
            tf.concat([object_states, self.state_gravity, E], axis=2),
            [batch * n_obj, feature_width])

        phi_out = hidden_mlp(mlp_input,
                             self.model_builder,
                             self.cfg,
                             'phiO',
                             reuse_weights=False,
                             train=not self.my_test,
                             debug=self.debug)
        return tf.reshape(phi_out, [batch, n_obj, self.dp])
    def _get_phiC(self):
        """Build the collision network ``phiC`` from the enabled collision terms.

        Always calls ``self.__prepare_for_collisions()``. Then, each gated
        independently on its flag being exactly 1, calls
        ``self.__add_collisions()`` (``use_collisions``),
        ``self.__add_self_collisions()`` (``use_self``) and
        ``self.__add_static_collisions()`` (``use_static``). When at least one
        of those flags equals 1, ``self.C_Rr``, ``self.C_Rs`` and
        ``self.C_Ra`` are concatenated along axis 1 and passed through
        ``hidden_mlp``. The output is stored on ``self.PhiC``.
        """
        self.__prepare_for_collisions()

        if self.use_collisions == 1:
            self.__add_collisions()

        if self.use_self == 1:
            self.__add_self_collisions()

        if self.use_static == 1:
            self.__add_static_collisions()

        # Only build phiC if some collision term was requested.
        # NOTE(review): C_Rr / C_Rs / C_Ra are presumably populated by the
        # prepare/add helpers above -- confirm.
        if self.use_collisions==1 \
                or self.use_static==1 \
                or self.use_self==1:
            C = tf.concat([self.C_Rr, self.C_Rs, self.C_Ra], axis=1)

            self.PhiC = hidden_mlp(C,
                                   self.model_builder,
                                   self.cfg,
                                   'phiC',
                                   reuse_weights=False,
                                   train=not self.my_test,
                                   debug=self.debug)