def build_L2H_or_WS_or_H2L(self, which_module, prev_effect):
    '''Build one hierarchical propagation module and aggregate its effects.

    Supported modules:
        L2H - Lower to higher hierarchy propagation (L2A in the paper)
        WG  - Within group (WS in the paper)
        H2L - Higher hierarchy to lower propagation (A2D in the paper)

    Args:
        which_module: 'L2H', 'WG' or 'H2L'. Its index (0/1/2) also selects
            the relation MLP in self.phiR_list and its reuse flag in
            self.reuse_list.
        prev_effect: effects from the previous propagation step; gathered
            with tf.gather_nd at the sender indices Rs.

    Returns:
        The result of sum_effects(raw_effects, Rr, self.bs, self.no); the
        existing inline comment gives its shape as (BS, NO, DE).

    Raises:
        ValueError: if which_module is not one of the supported names.
    '''
    if which_module == 'L2H':
        func_get_receiver_sender = self._get_L2H_receiver_sender
        func_get_states = self._get_L2H_states
        which_one_hot = 0
    elif which_module == 'WG':
        func_get_receiver_sender = self._get_WS_receiver_sender
        func_get_states = self._get_WS_states
        which_one_hot = 1
    elif which_module == 'H2L':
        func_get_receiver_sender = self._get_H2L_receiver_sender
        func_get_states = self._get_H2L_states
        which_one_hot = 2
    else:
        # Without this branch an unknown name crashed later with an
        # UnboundLocalError on func_get_receiver_sender.
        raise ValueError(
            "which_module must be 'L2H', 'WG' or 'H2L', got %r"
            % (which_module,))

    Rr, Rs = func_get_receiver_sender()
    state_Rr, state_Rs = func_get_states(Rr, Rs)
    effect_Rs = tf.reshape(tf.gather_nd(prev_effect, Rs), [-1, self.de])
    # 3-column relation attribute produced by one_hot_column with the module
    # index as `axis` -- presumably a one-hot tag of the module type (confirm).
    extra_attribute = tf.reshape(
        one_hot_column(effect_Rs, ndims=3, axis=which_one_hot), [-1, 3])

    # Add stiffness to Ra if needed
    if self.vary_stiff == 1:
        curr_stiff = tf.reshape(
            tf.gather_nd(self.inp_stiff, Rs), [-1, self.sl])
        extra_attribute = tf.concat([extra_attribute, curr_stiff], axis=-1)

    # Add ground truth distance to Ra if needed
    if self.add_dist == 1:
        # Index = receiver index Rr followed by Rs without its first column
        # (presumably the shared batch index -- TODO confirm).
        indx_dist = tf.concat([Rr, Rs[:, 1:]], axis=-1)
        curr_dist = tf.reshape(
            tf.gather_nd(self.all_dist_ts, indx_dist), [-1, 1])
        extra_attribute = tf.concat([extra_attribute, curr_dist], axis=-1)

    mlp_input = tf.concat(
        [state_Rr, state_Rs, effect_Rs, extra_attribute], axis=1)
    raw_effects = hidden_mlp(
        mlp_input, self.model_builder, self.cfg,
        self.phiR_list[which_one_hot],
        reuse_weights=self.reuse_list[which_one_hot],
        train=not self.my_test, debug=self.debug)
    summed_effects = sum_effects(raw_effects, Rr, self.bs, self.no)  # (BS, NO, DE)
    return summed_effects
def _get_phiF(self):
    """Build the phiF MLP that processes external forces and gravity.

    Side effects:
        Sets ``self.PhiF`` to the ``hidden_mlp`` output computed on the
        axis-1 concatenation of ``self.F_Rr``, ``self.F_E`` and
        ``self.F_Ra``. Weights are not reused (``reuse_weights=False``).
    """
    # The preparation / action helpers are not visible here; presumably they
    # populate or extend the F_* tensors read below -- TODO confirm.
    self.__prepare_for_phiF()
    if self.use_actions == 1:
        self.__add_actions()

    force_parts = [self.F_Rr, self.F_E, self.F_Ra]
    phiF_input = tf.concat(force_parts, axis=1)
    is_training = not self.my_test
    self.PhiF = hidden_mlp(
        phiF_input,
        self.model_builder,
        self.cfg,
        'phiF',
        reuse_weights=False,
        train=is_training,
        debug=self.debug,
    )
def _get_phiH(self):
    """Build the phiH MLP over single-particle history information.

    For legacy reasons the network is configured and named 'phiS'.

    Side effects:
        Sets ``self.PhiH`` to the ``hidden_mlp`` output and ``self.Rr_S`` to
        the int32 indices of the non-zero entries of ``self.leaf_flag``.

    Raises:
        ValueError: if 'phiS' is missing from ``self.cfg``.
    """
    # Explicit check instead of `assert`, so the validation still runs when
    # Python is started with -O (which strips assert statements).
    if 'phiS' not in self.cfg:
        raise ValueError("Please set the network shape in cfg!")
    # Indices of the entries selected by leaf_flag.
    Rr_S = tf.cast(tf.where(self.leaf_flag), tf.int32)
    S_S = self.__get_states(Rr_S, local_delta_pos=True)
    # 4-column attribute from one_hot_column(S_S, 4, 0), appended to the states.
    S_Ra = tf.reshape(one_hot_column(S_S, 4, 0), [-1, 4])
    S_S = tf.concat([S_S, S_Ra], axis=1)
    self.PhiH = hidden_mlp(
        S_S, self.model_builder, self.cfg, 'phiS',
        reuse_weights=False, train=not self.my_test, debug=self.debug)
    self.Rr_S = Rr_S
def get_predictions(self, E):
    """Run the phiO MLP on per-object states, gravity and effects.

    Args:
        E: effects concatenated with the states along axis 2; presumably
            shaped (bs, no, de) -- TODO confirm against callers.

    Returns:
        The phiO output reshaped to (self.bs, self.no, self.dp).
    """
    batch, n_obj = self.bs, self.no
    pos_and_mass = tf.reshape(
        tf.concat([self.state_pos, self.state_mass], axis=-1),
        [batch, n_obj, -1])
    local_delta = tf.reshape(self.state_local_delta_pos, [batch, n_obj, -1])
    # "O" in the original naming: all per-object state features.
    all_states = tf.concat([pos_and_mass, local_delta], axis=-1)
    # Flatten objects into the batch dimension; the feature width is
    # expected to equal ds + de + dx.
    phiO_input = tf.reshape(
        tf.concat([all_states, self.state_gravity, E], axis=2),
        [batch * n_obj, self.ds + self.de + self.dx])
    phiO_out = hidden_mlp(
        phiO_input, self.model_builder, self.cfg, 'phiO',
        reuse_weights=False, train=not self.my_test, debug=self.debug)
    return tf.reshape(phiO_out, [batch, n_obj, self.dp])
def _get_phiC(self):
    """Build the phiC collision MLP when any collision source is enabled.

    Always runs the collision preparation step, then the add-helper for each
    enabled source (use_collisions, use_self, use_static). If at least one
    source is enabled, sets ``self.PhiC`` to the ``hidden_mlp`` output on
    the axis-1 concatenation of ``self.C_Rr``, ``self.C_Rs`` and
    ``self.C_Ra``; otherwise ``self.PhiC`` is left untouched.
    """
    self.__prepare_for_collisions()
    if self.use_collisions == 1:
        self.__add_collisions()
    if self.use_self == 1:
        self.__add_self_collisions()
    if self.use_static == 1:
        self.__add_static_collisions()

    has_collision_source = (self.use_collisions == 1
                            or self.use_static == 1
                            or self.use_self == 1)
    if not has_collision_source:
        return

    collision_input = tf.concat([self.C_Rr, self.C_Rs, self.C_Ra], axis=1)
    self.PhiC = hidden_mlp(
        collision_input,
        self.model_builder,
        self.cfg,
        'phiC',
        reuse_weights=False,
        train=not self.my_test,
        debug=self.debug,
    )