def _concat_type_embedding(
        self,
        xyz_scatter,
        nframes,
        natoms,
        type_embedding,
):
    '''Append neighbor type embeddings (and optionally center ones) to `xyz_scatter`.

    Every row of `xyz_scatter` gets the type embedding of its neighbor
    appended. When `self.type_one_side` is false, the type embedding of the
    center atom is appended after that as well.

    Parameters
    ----------
    xyz_scatter:
        shape is [nframes*natoms[0]*self.nnei, 1]
    nframes:
        shape is []
    natoms:
        shape is [1+1+self.ntypes]
    type_embedding:
        shape is [self.ntypes, Y] where Y=jdata['type_embedding']['neuron'][-1]

    Returns
    -------
    embedding:
        environment of each atom represented by embedding, of shape
        [nframes*natoms[0]*self.nnei, 1+Y] when self.type_one_side is true,
        otherwise [nframes*natoms[0]*self.nnei, 1+2*Y].
    '''
    embed_dim = type_embedding.get_shape().as_list()[-1]
    neighbor_types = tf.cast(self.nei_type, dtype=tf.int32)
    # one embedding row per neighbor slot: [self.nnei, embed_dim]
    neighbor_embed = tf.nn.embedding_lookup(type_embedding, neighbor_types)
    # repeat the neighbor block for every (frame, local atom) pair
    neighbor_embed = tf.tile(neighbor_embed, (nframes * natoms[0], 1))
    # [nframes*natoms[0]*self.nnei, embed_dim]
    neighbor_embed = tf.reshape(neighbor_embed, [-1, embed_dim])
    # [nframes*natoms[0]*self.nnei, 1+embed_dim]
    combined = tf.concat([xyz_scatter, neighbor_embed], 1)
    if self.type_one_side:
        return combined

    # [natoms[0], embed_dim]
    center_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
    # [nframes*natoms[0], self.nnei*embed_dim]
    center_embed = tf.tile(center_embed, (nframes, self.nnei))
    # each center-atom row repeated once per neighbor:
    # [nframes*natoms[0]*self.nnei, embed_dim]
    center_embed = tf.reshape(center_embed, [-1, embed_dim])
    # [nframes*natoms[0]*self.nnei, 1+embed_dim+embed_dim]
    return tf.concat([combined, center_embed], 1)
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the energy plus energy-dipole loss.

    The energy dipole of a frame is sum_i (E_i - E / N) * r_i, where E_i are
    the atomic energies, E is the frame energy, N is the number of atoms and
    r_i are the coordinates. It is compared with label_dict['energy_dipole'].

    Parameters
    ----------
    learning_rate
        Current learning rate. Each prefactor is interpolated as
        limit + (start - limit) * learning_rate / self.starter_learning_rate.
    natoms
        Not used; the atom count is taken from the shape of
        model_dict['atom_ener'].
    model_dict
        Model outputs; reads 'coord', 'energy' and 'atom_ener'.
    label_dict
        Labels; reads 'energy', 'energy_dipole', 'find_energy' and
        'find_energy_dipole'.
    suffix
        Used in the names of the two l2 reduction ops.

    Returns
    -------
    l2_loss
        The weighted total loss.
    more_loss
        Dict with the unweighted 'l2_ener_loss' and 'l2_ener_dipole_loss'.
    """
    coord = model_dict['coord']
    energy = model_dict['energy']
    atom_ener = model_dict['atom_ener']
    n_frames = tf.shape(atom_ener)[0]
    n_atoms = tf.shape(atom_ener)[1]

    # deviation of every atomic energy from the per-frame mean E / N
    mean_atom_ener = tf.reshape(energy / global_cvt_2_ener_float(n_atoms), [-1, 1])
    mean_atom_ener = tf.tile(mean_atom_ener, [1, n_atoms])
    mean_atom_ener = tf.reshape(mean_atom_ener, [n_frames, n_atoms])
    ener_dev = atom_ener - mean_atom_ener

    # [n_frames, 1, n_atoms] x [n_frames, n_atoms, 3] -> [n_frames, 1, 3]
    coord = tf.reshape(coord, [n_frames, n_atoms, 3])
    ener_dev = tf.reshape(ener_dev, [n_frames, 1, n_atoms])
    ener_dipole = tf.matmul(ener_dev, coord)
    ener_dipole = tf.reshape(ener_dipole, [n_frames, 3])

    energy_hat = label_dict['energy']
    ener_dipole_hat = label_dict['energy_dipole']
    find_energy = label_dict['find_energy']
    find_ener_dipole = label_dict['find_energy_dipole']

    l2_ener_loss = tf.reduce_mean(tf.square(energy - energy_hat), name='l2_' + suffix)
    # both reductions request the same op name; TensorFlow uniquifies the second
    dipole_diff = tf.reshape(ener_dipole, [-1]) - tf.reshape(ener_dipole_hat, [-1])
    l2_ener_dipole_loss = tf.reduce_mean(tf.square(dipole_diff), name='l2_' + suffix)

    atom_norm_ener = 1. / global_cvt_2_ener_float(n_atoms)
    pref_e = global_cvt_2_ener_float(
        find_energy * (self.limit_pref_e
                       + (self.start_pref_e - self.limit_pref_e)
                       * learning_rate / self.starter_learning_rate))
    pref_ed = global_cvt_2_tf_float(
        find_ener_dipole * (self.limit_pref_ed
                            + (self.start_pref_ed - self.limit_pref_ed)
                            * learning_rate / self.starter_learning_rate))

    l2_loss = 0
    l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
    l2_loss += global_cvt_2_ener_float(pref_ed * l2_ener_dipole_loss)
    more_loss = {
        'l2_ener_loss': l2_ener_loss,
        'l2_ener_dipole_loss': l2_ener_dipole_loss,
    }

    self.l2_l = l2_loss
    self.l2_more = more_loss
    return l2_loss, more_loss
def embed_atom_type(
        ntypes: int,
        natoms: tf.Tensor,
        type_embedding: tf.Tensor,
):
    """
    Make the embedded type for the atoms in system.

    The atoms are assumed to be sorted according to the type,
    thus their types are described by a `tf.Tensor` natoms, see explanation below.

    Parameters
    ----------
    ntypes:
        Number of types.
    natoms:
        The number of atoms. This tensor has the length of Ntypes + 2
        natoms[0]: number of local atoms
        natoms[1]: total number of atoms held by this processor
        natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
    type_embedding:
        The type embedding.
        It has the shape of [ntypes, embedding_dim]

    Returns
    -------
    atom_embedding
        The embedded type of each atom.
        It has the shape of [numb_atoms, embedding_dim], where numb_atoms is
        the sum of natoms[2:] (one row per atom, in type order).
    """
    embed_dim = type_embedding.get_shape().as_list()[-1]
    # type index of every atom: natoms[2] zeros, then natoms[3] ones, ...
    atype = tf.concat(
        [tf.tile([itype], [natoms[2 + itype]]) for itype in range(ntypes)],
        axis=0)
    atm_embed = tf.nn.embedding_lookup(
        type_embedding, tf.cast(atype, dtype=tf.int32))
    # [numb_atoms, embed_dim]
    return tf.reshape(atm_embed, [-1, embed_dim])
def build(
        self,
        inputs: tf.Tensor,
        natoms: tf.Tensor,
        input_dict: dict = None,
        reuse: bool = None,
        suffix: str = '',
) -> tf.Tensor:
    """
    Build the computational graph for fitting net

    Parameters
    ----------
    inputs
        The input descriptor
    input_dict
        Additional dict for inputs.
        if numb_fparam > 0, should have input_dict['fparam']
        if numb_aparam > 0, should have input_dict['aparam']
        optional 'nframes': only used to shape the zero descriptor when
        atom_ener is set; optional 'type_embedding': switches to a single
        network fed with the atom type embedding
    natoms
        The number of atoms. This tensor has the length of Ntypes + 2
        natoms[0]: number of local atoms
        natoms[1]: total number of atoms held by this processor
        natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
    reuse
        The weights in the networks should be reused when get the variable.
    suffix
        Name suffix to identify this descriptor

    Returns
    -------
    ener
        The atomic energies, flattened to a 1-D tensor.

    Raises
    ------
    RuntimeError
        If fparam/aparam statistics are missing, or if atom_ener is set
        together with a type embedding.
    """
    if input_dict is None:
        input_dict = {}
    bias_atom_e = self.bias_atom_e
    if self.numb_fparam > 0 and (self.fparam_avg is None or self.fparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')
    if self.numb_aparam > 0 and (self.aparam_avg is None or self.aparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')

    with tf.variable_scope('fitting_attr' + suffix, reuse=reuse):
        # named constants recording the fparam/aparam sizes in the graph
        t_dfparam = tf.constant(self.numb_fparam, name='dfparam', dtype=tf.int32)
        t_daparam = tf.constant(self.numb_aparam, name='daparam', dtype=tf.int32)
        if self.numb_fparam > 0:
            t_fparam_avg = tf.get_variable(
                't_fparam_avg',
                self.numb_fparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_avg))
            t_fparam_istd = tf.get_variable(
                't_fparam_istd',
                self.numb_fparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_inv_std))
        if self.numb_aparam > 0:
            t_aparam_avg = tf.get_variable(
                't_aparam_avg',
                self.numb_aparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_avg))
            t_aparam_istd = tf.get_variable(
                't_aparam_istd',
                self.numb_aparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_inv_std))

    inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
    if len(self.atom_ener):
        # only for atom_ener
        nframes = input_dict.get('nframes')
        if nframes is not None:
            # like inputs, but we don't want to add a dependency on inputs
            inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]),
                                   dtype=self.fitting_precision)
        else:
            inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision)

    if bias_atom_e is not None:
        assert (len(bias_atom_e) == self.ntypes)

    # normalize the frame / atomic parameters with the stored statistics
    fparam = None
    aparam = None
    if self.numb_fparam > 0:
        fparam = input_dict['fparam']
        fparam = tf.reshape(fparam, [-1, self.numb_fparam])
        fparam = (fparam - t_fparam_avg) * t_fparam_istd
    if self.numb_aparam > 0:
        aparam = input_dict['aparam']
        aparam = tf.reshape(aparam, [-1, self.numb_aparam])
        aparam = (aparam - t_aparam_avg) * t_aparam_istd
        aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

    type_embedding = input_dict.get('type_embedding', None)
    if type_embedding is not None:
        atype_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
        atype_embed = tf.tile(atype_embed, [tf.shape(inputs)[0], 1])
    else:
        atype_embed = None

    if atype_embed is None:
        # one network per atom type, applied to that type's contiguous slice
        start_index = 0
        outs_list = []
        for type_i in range(self.ntypes):
            type_bias_ae = 0.0 if bias_atom_e is None else bias_atom_e[type_i]
            # NOTE(review): bias_atom_e is forwarded here as the final layer's
            # `bavg` and is also added below via tf.repeat on supported TF
            # versions -- confirm the bias is not counted twice.
            final_layer = self._build_lower(
                start_index, natoms[2 + type_i],
                inputs, fparam, aparam,
                bias_atom_e=type_bias_ae,
                suffix='_type_' + str(type_i) + suffix,
                reuse=reuse)
            if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
                # shift the output so that a zero descriptor yields exactly
                # atom_ener[type_i]: evaluate the same network on zeros
                zero_layer = self._build_lower(
                    start_index, natoms[2 + type_i],
                    inputs_zero, fparam, aparam,
                    bias_atom_e=type_bias_ae,
                    suffix='_type_' + str(type_i) + suffix,
                    reuse=True)
                final_layer += self.atom_ener[type_i] - zero_layer
            final_layer = tf.reshape(final_layer,
                                     [tf.shape(inputs)[0], natoms[2 + type_i]])
            outs_list.append(final_layer)
            start_index += natoms[2 + type_i]
        # concat once may be faster than multiple concat
        outs = tf.concat(outs_list, axis=1)
    else:
        # with type embedding: a single network for all atoms
        if len(self.atom_ener) > 0:
            raise RuntimeError(
                "setting atom_ener is not supported by type embedding")
        atype_embed = tf.cast(atype_embed, self.fitting_precision)
        type_shape = atype_embed.get_shape().as_list()
        inputs = tf.concat(
            [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1)
        # _build_lower slices `inputs` using self.dim_descrpt, so widen it for
        # the duration of the call only.  Restoring it keeps a later build
        # (another suffix, or a rebuild) from seeing an already-widened value.
        original_dim_descrpt = self.dim_descrpt
        self.dim_descrpt = original_dim_descrpt + type_shape[1]
        try:
            inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
            final_layer = self._build_lower(
                0, natoms[0],
                inputs, fparam, aparam,
                bias_atom_e=0.0, suffix=suffix, reuse=reuse)
        finally:
            self.dim_descrpt = original_dim_descrpt
        outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]])

    # add atom energy bias; TF will broadcast to all batches
    # tf.repeat is avaiable in TF>=2.1 or TF 1.15
    tf_version = Version(TF_VERSION)
    if (Version('1.15') <= tf_version < Version('2') or tf_version >= Version('2.1')) \
            and self.bias_atom_e is not None:
        outs += tf.repeat(
            tf.Variable(self.bias_atom_e,
                        dtype=self.fitting_precision,
                        trainable=False,
                        name="bias_atom_ei"),
            natoms[2:])

    if self.tot_ener_zero:
        # remove the per-frame mean so that the frame energy equals force_tot_ener
        force_tot_ener = 0.0
        outs = tf.reshape(outs, [-1, natoms[0]])
        outs_mean = tf.reshape(tf.reduce_mean(outs, axis=1), [-1, 1])
        outs_mean = outs_mean - tf.ones_like(
            outs_mean, dtype=GLOBAL_TF_FLOAT_PRECISION) * (
                force_tot_ener / global_cvt_2_tf_float(natoms[0]))
        outs = outs - outs_mean
        outs = tf.reshape(outs, [-1])

    tf.summary.histogram('fitting_net_output', outs)
    return tf.reshape(outs, [-1])
def _build_lower(self,
                 start_index,
                 natoms,
                 inputs,
                 fparam=None,
                 aparam=None,
                 bias_atom_e=0.0,
                 suffix='',
                 reuse=None):
    """Run the fitting network on one contiguous slice of atoms.

    Parameters
    ----------
    start_index
        Index of the first atom of the slice within `inputs`.
    natoms
        Number of atoms in the slice (a single count, not the natoms vector).
    inputs
        2-D descriptor tensor; columns
        [start_index * self.dim_descrpt, (start_index + natoms) * self.dim_descrpt)
        are used.
    fparam
        Frame parameters with self.numb_fparam columns, tiled to every atom;
        skipped when None.
    aparam
        Atomic parameters, self.numb_aparam columns per atom; skipped when None.
    bias_atom_e
        Passed to the final layer as `bavg`.
    suffix
        Appended to every layer name.
    reuse
        Passed through to every layer.

    Returns
    -------
    The output of the final width-1 layer, one row per atom of the slice.

    Side effects: unless self.uniform_seed is set, self.seed is advanced by
    self.seed_shift after every layer (when it is not None).
    """
    dim = self.dim_descrpt

    def bump_seed():
        if (not self.uniform_seed) and (self.seed is not None):
            self.seed += self.seed_shift

    # descriptor columns of the selected atoms, one row per atom
    layer = tf.reshape(
        tf.slice(inputs, [0, start_index * dim], [-1, natoms * dim]),
        [-1, dim])
    if fparam is not None:
        # the frame parameters are shared by every atom of the frame
        frame_par = tf.tile(fparam, [1, natoms])
        frame_par = tf.cast(tf.reshape(frame_par, [-1, self.numb_fparam]),
                            self.fitting_precision)
        layer = tf.concat([layer, frame_par], axis=1)
    if aparam is not None:
        atom_par = tf.slice(aparam,
                            [0, start_index * self.numb_aparam],
                            [-1, natoms * self.numb_aparam])
        atom_par = tf.cast(tf.reshape(atom_par, [-1, self.numb_aparam]),
                           self.fitting_precision)
        layer = tf.concat([layer, atom_par], axis=1)

    hidden_kwargs = dict(
        reuse=reuse,
        activation_fn=self.fitting_activation_fn,
        precision=self.fitting_precision,
        uniform_seed=self.uniform_seed,
        initial_variables=self.fitting_net_variables,
        mixed_prec=self.mixed_prec,
    )
    for idx, width in enumerate(self.n_neuron):
        # a layer as wide as its predecessor is applied as a residual update
        residual = idx >= 1 and width == self.n_neuron[idx - 1]
        extra = {'use_timestep': self.resnet_dt} if residual else {}
        hidden = one_layer(layer,
                           width,
                           name='layer_' + str(idx) + suffix,
                           seed=self.seed,
                           trainable=self.trainable[idx],
                           **hidden_kwargs,
                           **extra)
        layer = layer + hidden if residual else hidden
        bump_seed()

    final_layer = one_layer(layer,
                            1,
                            activation_fn=None,
                            bavg=bias_atom_e,
                            name='final_layer' + suffix,
                            reuse=reuse,
                            seed=self.seed,
                            precision=self.fitting_precision,
                            trainable=self.trainable[-1],
                            uniform_seed=self.uniform_seed,
                            initial_variables=self.fitting_net_variables,
                            mixed_prec=self.mixed_prec,
                            final_layer=True)
    bump_seed()
    return final_layer
def _normalize_3d(self, a):
    """Scale every row of `a` to unit Euclidean norm.

    Parameters
    ----------
    a
        2-D tensor; each row is normalized independently. Rows are 3-vectors
        in the original use, but any number of columns works.

    Returns
    -------
    Tensor with the same shape as `a`. Zero-norm rows are not guarded and
    divide by zero, as before.
    """
    # keepdims leaves the norms as a [-1, 1] column that broadcasts against
    # `a`; no need to materialize a tiled copy (or hard-code 3 columns).
    return tf.divide(a, tf.norm(a, axis=1, keepdims=True))
def build(self, inputs, input_dict, natoms, reuse=None, suffix=''):
    """Build one fitting network per atom type and return atomic energies.

    Parameters
    ----------
    inputs
        The descriptor; reshaped to [-1, self.dim_descrpt * natoms[0]] and
        cast to self.fitting_precision.
    input_dict
        Additional inputs: 'fparam' when numb_fparam > 0, 'aparam' when
        numb_aparam > 0.
    natoms
        natoms[0]: number of local atoms; natoms[2 + i]: number of type i
        atoms (atoms are assumed sorted by type).
    reuse
        Whether network variables are reused.
    suffix
        Name suffix for the variables created here.

    Returns
    -------
    The flattened atomic energies, cast to global_tf_float_precision.

    Raises
    ------
    RuntimeError
        If fparam/aparam statistics are required but missing.
    """
    bias_atom_e = self.bias_atom_e
    if self.numb_fparam > 0 and (self.fparam_avg is None or self.fparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')
    if self.numb_aparam > 0 and (self.aparam_avg is None or self.aparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')

    with tf.variable_scope('fitting_attr' + suffix, reuse=reuse):
        # named constants recording the fparam/aparam sizes in the graph
        t_dfparam = tf.constant(self.numb_fparam, name='dfparam', dtype=tf.int32)
        t_daparam = tf.constant(self.numb_aparam, name='daparam', dtype=tf.int32)
        if self.numb_fparam > 0:
            t_fparam_avg = tf.get_variable(
                't_fparam_avg',
                self.numb_fparam,
                dtype=global_tf_float_precision,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_avg))
            t_fparam_istd = tf.get_variable(
                't_fparam_istd',
                self.numb_fparam,
                dtype=global_tf_float_precision,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_inv_std))
        if self.numb_aparam > 0:
            t_aparam_avg = tf.get_variable(
                't_aparam_avg',
                self.numb_aparam,
                dtype=global_tf_float_precision,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_avg))
            t_aparam_istd = tf.get_variable(
                't_aparam_istd',
                self.numb_aparam,
                dtype=global_tf_float_precision,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_inv_std))

    inputs = tf.cast(
        tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]]),
        self.fitting_precision)
    if bias_atom_e is not None:
        assert (len(bias_atom_e) == self.ntypes)
    if self.numb_fparam > 0:
        fparam = input_dict['fparam']
        fparam = tf.reshape(fparam, [-1, self.numb_fparam])
        fparam = (fparam - t_fparam_avg) * t_fparam_istd
    if self.numb_aparam > 0:
        aparam = input_dict['aparam']
        aparam = tf.reshape(aparam, [-1, self.numb_aparam])
        aparam = (aparam - t_aparam_avg) * t_aparam_istd
        aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

    def _type_net(layer, type_i, type_bias_ae, layer_reuse):
        # hidden layers + width-1 final layer of the network for type_i
        for ii, width in enumerate(self.n_neuron):
            layer_name = 'layer_' + str(ii) + '_type_' + str(type_i) + suffix
            if ii >= 1 and width == self.n_neuron[ii - 1]:
                # same width as the previous layer: residual update
                layer += one_layer(layer,
                                   width,
                                   name=layer_name,
                                   reuse=layer_reuse,
                                   seed=self.seed,
                                   use_timestep=self.resnet_dt,
                                   activation_fn=self.fitting_activation_fn,
                                   precision=self.fitting_precision,
                                   trainable=self.trainable[ii])
            else:
                layer = one_layer(layer,
                                  width,
                                  name=layer_name,
                                  reuse=layer_reuse,
                                  seed=self.seed,
                                  activation_fn=self.fitting_activation_fn,
                                  precision=self.fitting_precision,
                                  trainable=self.trainable[ii])
        return one_layer(layer,
                         1,
                         activation_fn=None,
                         bavg=type_bias_ae,
                         name='final_layer_type_' + str(type_i) + suffix,
                         reuse=layer_reuse,
                         seed=self.seed,
                         precision=self.fitting_precision,
                         trainable=self.trainable[-1])

    start_index = 0
    outs_list = []
    for type_i in range(self.ntypes):
        n_type = natoms[2 + type_i]
        # cut-out inputs
        inputs_i = tf.slice(inputs,
                            [0, start_index * self.dim_descrpt],
                            [-1, n_type * self.dim_descrpt])
        inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
        # per-atom fparam / aparam columns, cast to the network precision so
        # they can be concatenated with the (already cast) descriptor
        extra = []
        if self.numb_fparam > 0:
            ext_fparam = tf.tile(fparam, [1, n_type])
            ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
            extra.append(tf.cast(ext_fparam, self.fitting_precision))
        if self.numb_aparam > 0:
            ext_aparam = tf.slice(aparam,
                                  [0, start_index * self.numb_aparam],
                                  [-1, n_type * self.numb_aparam])
            ext_aparam = tf.reshape(ext_aparam, [-1, self.numb_aparam])
            extra.append(tf.cast(ext_aparam, self.fitting_precision))
        start_index += n_type
        type_bias_ae = 0.0 if bias_atom_e is None else bias_atom_e[type_i]

        layer = tf.concat([inputs_i] + extra, axis=1) if extra else inputs_i
        final_layer = _type_net(layer, type_i, type_bias_ae, reuse)

        if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
            # shift so that a zero descriptor yields exactly atom_ener[type_i];
            # zeros use the network precision, like the real descriptor
            inputs_zero = tf.zeros_like(inputs_i, dtype=self.fitting_precision)
            zero_input = tf.concat([inputs_zero] + extra, axis=1) if extra else inputs_zero
            zero_layer = _type_net(zero_input, type_i, type_bias_ae, True)
            final_layer += self.atom_ener[type_i] - zero_layer

        outs_list.append(tf.reshape(final_layer, [tf.shape(inputs)[0], n_type]))

    # concat once instead of growing the result type by type
    outs = tf.concat(outs_list, axis=1)
    return tf.cast(tf.reshape(outs, [-1]), global_tf_float_precision)
def _type_embedding_net_one_side_aparam(self,
                                        mat_g,
                                        atype,
                                        natoms,
                                        aparam,
                                        name='',
                                        reuse=None,
                                        seed=None,
                                        trainable=True):
    """Weight the embedding-net output by a neighbor type + aparam embedding.

    The embedding output is widened to self.type_nchanl channels. Each
    channel is multiplied by an embedding of (neighbor type, neighbor
    aparam), and the channels are then averaged.

    Parameters
    ----------
    mat_g
        Embedding-net output, reshaped to
        [nframes * natoms[0] * self.nnei, self.filter_neuron[-1]].
    atype
        Not used.
    natoms
        natoms[0] is the number of local atoms.
    aparam
        Atomic parameters, reshaped to [nframes, -1].
    name
        Prefix of the amplify layer name.
    reuse
        Passed to the amplify layer and the type embedding.
    seed
        Not used; the amplify layer uses self.seed.
    trainable
        Passed to the amplify layer only (the type embedding is always
        built with trainable=True).

    Returns
    -------
    Tensor of shape [nframes * natoms[0], self.nnei, self.filter_neuron[-1]].
    """
    out_size = self.filter_neuron[-1]
    n_chnl = self.type_nchanl
    nframes = tf.shape(mat_g)[0]
    n_pairs = nframes * natoms[0] * self.nnei

    # [nf*natom*nei, out_size] -> [nf*natom*nei, out_size*chnl]
    mat_g = tf.reshape(mat_g, [n_pairs, out_size])
    mat_g = one_layer(mat_g,
                      out_size * n_chnl,
                      activation_fn=None,
                      precision=self.filter_precision,
                      name=name + '_amplify',
                      reuse=reuse,
                      seed=self.seed,
                      trainable=trainable)
    # nf x natom x nei x out_size x chnl -> out_size x nf x natom x nei x chnl
    mat_g = tf.reshape(mat_g, [nframes, natoms[0], self.nnei, out_size, n_chnl])
    mat_g = tf.transpose(mat_g, perm=[3, 0, 1, 2, 4])
    # out_size x (nf x natom x nei x chnl)
    mat_g = tf.reshape(mat_g, [out_size, n_pairs * n_chnl])

    # neighbor type of every (frame, atom, neighbor) slot: (nf x natom x nei) x 1
    nei_type_col = tf.tile(tf.reshape(self.nei_type, [1, self.nnei]),
                           [nframes * natoms[0], 1])
    nei_type_col = tf.reshape(nei_type_col, [n_pairs, 1])
    # nf x (natom x naparam) -> per-neighbor aparam: (nf x natom x nei) x naparam
    aparam = tf.reshape(aparam, [nframes, -1])
    nei_aparam = op_module.map_aparam(aparam,
                                      self.nlist,
                                      natoms,
                                      n_a_sel=self.nnei_a,
                                      n_r_sel=self.nnei_r)
    nei_aparam = tf.reshape(nei_aparam, [n_pairs, self.numb_aparam])
    # (nf x natom x nei) x (naparam+1)
    embed_input = tf.concat((nei_type_col, nei_aparam), axis=1)

    # (nf x natom x nei) x chnl, flattened to (nf x natom x nei x chnl)
    nei_weight = self._type_embed(embed_input,
                                  ndim=self.numb_aparam + 1,
                                  reuse=reuse,
                                  trainable=True,
                                  suffix='')
    nei_weight = tf.reshape(nei_weight, [n_pairs * n_chnl])

    # broadcast the weights over out_size, then average over the channels
    mat_g = tf.multiply(mat_g, nei_weight)
    mat_g = tf.reshape(mat_g, [out_size, nframes, natoms[0], self.nnei, n_chnl])
    mat_g = tf.reduce_mean(mat_g, axis=4)
    # out_size x nf x natom x nei -> (nf x natom) x nei x out_size
    mat_g = tf.transpose(mat_g, perm=[1, 2, 3, 0])
    return tf.reshape(mat_g, [nframes * natoms[0], self.nnei, out_size])
def build(self,
          inputs,
          input_dict,
          natoms,
          bias_atom_e=None,
          reuse=None,
          suffix=''):
    """Build one fitting network per atom type and return per-atom outputs.

    Parameters
    ----------
    inputs
        The descriptor; reshaped to [-1, self.dim_descrpt * natoms[0]].
    input_dict
        Must contain 'fparam' when self.numb_fparam > 0.
    natoms
        natoms[0]: number of local atoms; natoms[2 + i]: number of type i
        atoms (atoms are assumed sorted by type).
    bias_atom_e
        Optional per-type values (length self.ntypes) passed to each final
        layer as `bavg`.
    reuse
        Whether network variables are reused.
    suffix
        Name suffix for the variables created here.

    Returns
    -------
    The per-atom outputs, flattened to a 1-D tensor.
    """
    with tf.variable_scope('fitting_attr' + suffix, reuse=reuse):
        # named constant recording the fparam size in the graph
        t_dfparam = tf.constant(self.numb_fparam, name='dfparam', dtype=tf.int32)
    inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
    if bias_atom_e is not None:
        assert (len(bias_atom_e) == self.ntypes)
    if self.numb_fparam > 0:
        # loop invariant: look up and reshape once, tile per type below
        fparam = tf.reshape(input_dict['fparam'], [-1, self.numb_fparam])

    start_index = 0
    outs_list = []
    for type_i in range(self.ntypes):
        n_type = natoms[2 + type_i]
        # cut-out inputs
        inputs_i = tf.slice(inputs,
                            [0, start_index * self.dim_descrpt],
                            [-1, n_type * self.dim_descrpt])
        inputs_i = tf.reshape(inputs_i, [-1, self.dim_descrpt])
        start_index += n_type
        type_bias_ae = 0.0 if bias_atom_e is None else bias_atom_e[type_i]

        layer = inputs_i
        if self.numb_fparam > 0:
            # the frame parameters are shared by every atom of the frame
            ext_fparam = tf.tile(fparam, [1, n_type])
            ext_fparam = tf.reshape(ext_fparam, [-1, self.numb_fparam])
            layer = tf.concat([layer, ext_fparam], axis=1)
        for ii, width in enumerate(self.n_neuron):
            layer_name = 'layer_' + str(ii) + '_type_' + str(type_i) + suffix
            if ii >= 1 and width == self.n_neuron[ii - 1]:
                # same width as the previous layer: residual update
                layer += one_layer(layer,
                                   width,
                                   name=layer_name,
                                   reuse=reuse,
                                   seed=self.seed,
                                   use_timestep=self.resnet_dt)
            else:
                layer = one_layer(layer,
                                  width,
                                  name=layer_name,
                                  reuse=reuse,
                                  seed=self.seed)
        final_layer = one_layer(layer,
                                1,
                                activation_fn=None,
                                bavg=type_bias_ae,
                                name='final_layer_type_' + str(type_i) + suffix,
                                reuse=reuse,
                                seed=self.seed)
        outs_list.append(tf.reshape(final_layer, [tf.shape(inputs)[0], n_type]))

    # concat once instead of growing the result type by type
    outs = tf.concat(outs_list, axis=1)
    return tf.reshape(outs, [-1])