def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the energy + energy-dipole loss.

    The energy dipole of a frame is sum_i (e_i - E / N) * r_i, where e_i are
    the atomic energies, E the frame energy, N the number of atoms (taken
    from the second dimension of ``atom_ener``) and r_i the coordinates.

    Parameters
    ----------
    learning_rate
        Current learning rate. Each prefactor is interpolated linearly
        between its ``limit_*`` and ``start_*`` value according to
        ``learning_rate / self.starter_learning_rate``.
    natoms
        Not read; the atom count is re-derived from ``atom_ener``.
    model_dict
        Model outputs; 'coord', 'energy' and 'atom_ener' are read.
    label_dict
        Labels; 'energy', 'energy_dipole', 'find_energy' and
        'find_energy_dipole' are read.
    suffix
        Appended to the names of the reduce-mean loss ops.

    Returns
    -------
    l2_loss
        Weighted sum of the energy and energy-dipole losses.
    more_loss
        Dict with the unweighted 'l2_ener_loss' and 'l2_ener_dipole_loss'.
    """
    coord = model_dict['coord']
    energy = model_dict['energy']
    atom_ener = model_dict['atom_ener']
    nframes = tf.shape(atom_ener)[0]
    # atom count derived from the per-atom energies (the argument is ignored)
    nloc = tf.shape(atom_ener)[1]

    # per-atom mean energy E / N, broadcast to nframes x nloc
    mean_ener = tf.reshape(energy / global_cvt_2_ener_float(nloc), [-1, 1])
    mean_ener = tf.reshape(tf.tile(mean_ener, [1, nloc]), [nframes, nloc])
    shifted_ener = atom_ener - mean_ener
    coord = tf.reshape(coord, [nframes, nloc, 3])
    shifted_ener = tf.reshape(shifted_ener, [nframes, 1, nloc])
    # (nframes x 1 x nloc) @ (nframes x nloc x 3) -> nframes x 3
    ener_dipole = tf.reshape(tf.matmul(shifted_ener, coord), [nframes, 3])

    energy_hat = label_dict['energy']
    ener_dipole_hat = label_dict['energy_dipole']
    find_energy = label_dict['find_energy']
    find_ener_dipole = label_dict['find_energy_dipole']

    l2_ener_loss = tf.reduce_mean(tf.square(energy - energy_hat),
                                  name='l2_' + suffix)
    dipole_diff = tf.reshape(ener_dipole, [-1]) - tf.reshape(ener_dipole_hat, [-1])
    l2_ener_dipole_loss = tf.reduce_mean(tf.square(dipole_diff),
                                         name='l2_' + suffix)

    def _scheduled(find, start, limit):
        # linear schedule from `start` (at the starter lr) towards `limit`,
        # switched off when the label is absent (find == 0)
        return find * (limit + (start - limit) * learning_rate /
                       self.starter_learning_rate)

    atom_norm_ener = 1. / global_cvt_2_ener_float(nloc)
    pref_e = global_cvt_2_ener_float(
        _scheduled(find_energy, self.start_pref_e, self.limit_pref_e))
    pref_ed = global_cvt_2_tf_float(
        _scheduled(find_ener_dipole, self.start_pref_ed, self.limit_pref_ed))

    l2_loss = 0
    l2_loss += atom_norm_ener * (pref_e * l2_ener_loss)
    l2_loss += global_cvt_2_ener_float(pref_ed * l2_ener_dipole_loss)
    more_loss = {
        'l2_ener_loss': l2_ener_loss,
        'l2_ener_dipole_loss': l2_ener_dipole_loss,
    }
    self.l2_l = l2_loss
    self.l2_more = more_loss
    return l2_loss, more_loss
def _filter_r(self, inputs, type_input, natoms, activation_fn=tf.nn.tanh, stddev=1.0, bavg=0.0, name='linear', reuse=None, trainable=True):
    """Radial embedding filter for centre atoms of type ``type_input``.

    ``inputs`` is split column-wise into per-neighbour-type blocks of widths
    ``self.sel_r``. Every scalar of a block goes through an embedding net,
    unless ``(type_input, type_i)`` is in ``self.exclude_types``, in which
    case a zero block is used. The blocks are concatenated over neighbours,
    averaged over the neighbour axis and scaled by 1/5.

    Parameters
    ----------
    inputs
        natom x sum(self.sel_r) tensor.
    type_input
        Type of the centre atoms; only used for the exclusion check.
    natoms
        Not read.
    activation_fn, stddev, bavg, trainable
        Forwarded to ``embedding_net``.
    name, reuse
        Variable scope name and reuse flag.

    Returns
    -------
    natom x self.filter_neuron[-1] tensor.

    Notes
    -----
    ``self.seed`` is advanced by ``self.seed_shift`` after each embedding net
    that is built, unless ``self.uniform_seed`` is set or the seed is None.
    """
    outputs_size = [1] + self.filter_neuron
    out_dim = outputs_size[-1]
    with tf.variable_scope(name, reuse=reuse):
        offset = 0
        blocks = []
        for type_i in range(self.ntypes):
            nsel = self.sel_r[type_i]
            # natom x nsel: neighbours of type `type_i`
            inputs_i = tf.slice(inputs, [0, offset], [-1, nsel])
            offset += nsel
            shape_i = inputs_i.get_shape().as_list()
            # (natom x nsel) x 1
            scalars = tf.reshape(inputs_i, [-1, 1])
            if (type_input, type_i) in self.exclude_types:
                # excluded pair: zero block of the expected shape
                natom = tf.shape(inputs)[0]
                block = tf.cast(
                    tf.fill((natom, shape_i[1], out_dim), 0.),
                    GLOBAL_TF_FLOAT_PRECISION)
            else:
                embedded = embedding_net(
                    scalars,
                    self.filter_neuron,
                    self.filter_precision,
                    activation_fn=activation_fn,
                    resnet_dt=self.filter_resnet_dt,
                    name_suffix="_" + str(type_i),
                    stddev=stddev,
                    bavg=bavg,
                    seed=self.seed,
                    trainable=trainable,
                    uniform_seed=self.uniform_seed,
                    initial_variables=self.embedding_net_variables,
                )
                if (not self.uniform_seed) and (self.seed is not None):
                    self.seed += self.seed_shift
                # natom x nsel x out_dim
                block = tf.reshape(embedded, (-1, shape_i[1], out_dim))
            blocks.append(block)
        # natom x nnei x out_dim
        xyz_scatter = tf.concat(blocks, axis=1)
        res_rescale = 1. / 5.
        # natom x out_dim
        result = tf.reduce_mean(xyz_scatter, axis=1) * res_rescale
    return result
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the mean-squared-error loss on the 'wfc' output.

    ``learning_rate`` and ``natoms`` are not used. The loss is stored in
    ``self.l2_l`` and returned with an empty dict of extra losses.
    """
    residual = model_dict['wfc'] - label_dict['wfc']
    self.l2_l = tf.reduce_mean(tf.square(residual), name='l2_' + suffix)
    return self.l2_l, {}
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the mean-squared-error loss between a model tensor and its label.

    The prediction is ``model_dict[self.tensor_name]`` and the reference is
    ``label_dict[self.label_name]``. ``learning_rate`` and ``natoms`` are not
    used. The loss is stored in ``self.l2_l`` and returned with an empty dict.
    """
    residual = model_dict[self.tensor_name] - label_dict[self.label_name]
    self.l2_l = tf.reduce_mean(tf.square(residual), name='l2_' + suffix)
    return self.l2_l, {}
def variable_summaries(var: tf.Variable, name: str):
    """Attach TensorBoard summaries describing a tensor.

    Under the name scope ``name`` this records scalar summaries of the mean,
    the (population) standard deviation, the maximum and the minimum of all
    elements of ``var``, plus a histogram of its values.

    Parameters
    ----------
    var : tf.Variable
        Tensor or variable to summarise.
    name : str
        Name scope grouping the summaries.
    """
    with tf.name_scope(name):
        avg = tf.reduce_mean(var)
        tf.summary.scalar('mean', avg)
        with tf.name_scope('stddev'):
            spread = tf.sqrt(tf.reduce_mean(tf.square(var - avg)))
        tf.summary.scalar('stddev', spread)
        for label, reducer in (('max', tf.reduce_max), ('min', tf.reduce_min)):
            tf.summary.scalar(label, reducer(var))
        tf.summary.histogram('histogram', var)
def _type_embedding_net_one_side(self, mat_g, atype, natoms, name='', reuse=None, seed=None, trainable=True):
    """Modulate embedding-net output with an embedding of neighbour types.

    The last embedding dimension is amplified to ``self.type_nchanl``
    channels by a linear layer. Each channel is then multiplied by the type
    embedding of the corresponding neighbour slot (``self.nei_type``), and
    the channels are averaged.

    Parameters
    ----------
    mat_g
        Embedding-net output. Its leading dimension is taken as the number
        of frames ``nf``, and it is reshaped to
        (nf * natoms[0] * self.nnei) x self.filter_neuron[-1].
    atype
        Not used.
    natoms
        Atom-count tensor; only ``natoms[0]`` is read.
    name
        Prefix of the amplifying layer name ('<name>_amplify').
    reuse
        Forwarded to the amplifying layer and to ``self._type_embed``.
    seed
        Not used; the amplifying layer is seeded with ``self.seed``.
    trainable
        Trainability of the amplifying layer. The type embedding is always
        built with ``trainable=True``.

    Returns
    -------
    Tensor of shape (nf * natoms[0]) x self.nnei x self.filter_neuron[-1].
    """
    outputs_size = self.filter_neuron[-1]
    nframes = tf.shape(mat_g)[0]
    # (nf x natom x nei) x outputs_size
    mat_g = tf.reshape(mat_g, [nframes * natoms[0] * self.nnei, outputs_size])
    # (nf x natom x nei) x (outputs_size x chnl)
    mat_g = one_layer(
        mat_g,
        outputs_size * self.type_nchanl,
        activation_fn=None,
        precision=self.filter_precision,
        name=name + '_amplify',
        reuse=reuse,
        seed=self.seed,
        trainable=trainable,
    )
    # nf x natom x nei x outputs_size x chnl
    mat_g = tf.reshape(
        mat_g, [nframes, natoms[0], self.nnei, outputs_size, self.type_nchanl])
    # nf x natom x outputs_size x nei x chnl
    mat_g = tf.transpose(mat_g, perm=[0, 1, 3, 2, 4])
    # nf x natom x outputs_size x (nei x chnl)
    mat_g = tf.reshape(
        mat_g, [nframes, natoms[0], outputs_size, self.nnei * self.type_nchanl])

    # type embedding of every neighbour slot, flattened to (nei x chnl) so
    # it broadcasts over the leading nf x natom x outputs_size axes
    nei_type_ebd = self._type_embed(
        self.nei_type, reuse=reuse, trainable=True, suffix='')
    nei_type_ebd = tf.reshape(nei_type_ebd, [self.nnei * self.type_nchanl])
    mat_g = tf.multiply(mat_g, nei_type_ebd)

    # nf x natom x outputs_size x nei x chnl
    mat_g = tf.reshape(
        mat_g, [nframes, natoms[0], outputs_size, self.nnei, self.type_nchanl])
    # average over channels -> nf x natom x outputs_size x nei
    mat_g = tf.reduce_mean(mat_g, axis=4)
    # nf x natom x nei x outputs_size
    mat_g = tf.transpose(mat_g, perm=[0, 1, 3, 2])
    # (nf x natom) x nei x outputs_size
    return tf.reshape(mat_g, [nframes * natoms[0], self.nnei, outputs_size])
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the scaled mean-squared-error loss of a tensorial property.

    The residual ``model_dict[self.tensor_name] - label_dict[self.label_name]``
    is multiplied by ``self.scale`` before squaring. When ``self.atomic`` is
    false, the loss is divided by ``natoms[0]``. ``learning_rate`` is not
    used. The loss is stored in ``self.l2_l`` and returned with an empty dict.
    """
    polar_hat = label_dict[self.label_name]
    polar = model_dict[self.tensor_name]
    loss = tf.reduce_mean(tf.square(self.scale * (polar - polar_hat)),
                          name='l2_' + suffix)
    if not self.atomic:
        # global label: normalise by the number of local atoms
        loss = loss * (1. / global_cvt_2_tf_float(natoms[0]))
    self.l2_l = loss
    return loss, {}
def _filter_r(self, inputs, type_input, natoms, activation_fn=tf.nn.tanh, stddev=1.0, bavg=0.0, name='linear', reuse=None, seed=None, trainable=True):
    """Radial embedding filter built from explicit per-layer variables.

    ``inputs`` is split column-wise into per-neighbour-type blocks of widths
    ``self.sel_r``. Each scalar of a block is passed through a fully
    connected net with layer sizes ``[1] + self.filter_neuron``. A layer
    whose width equals (or doubles) the previous width adds a residual
    shortcut (the input, or the input concatenated with itself); with
    ``self.filter_resnet_dt`` the layer output is scaled by a learned
    ``idt`` vector before the shortcut is added. Pairs listed in
    ``self.exclude_types`` are multiplied by a zero matrix instead. The
    result is averaged over neighbours and scaled by 1/5.

    Parameters
    ----------
    inputs
        natom x sum(self.sel_r) tensor.
    type_input
        Type of the centre atoms; only used for the exclusion check.
    natoms
        Not read.
    activation_fn
        Non-linearity of every layer.
    stddev, bavg
        Scale/mean of the random-normal initialisers.
    name, reuse
        Variable scope name and reuse flag.
    seed
        Seed for every initialiser (the same value is used for all).
    trainable
        Whether the created variables are trainable.

    Returns
    -------
    natom x self.filter_neuron[-1] tensor.
    """
    outputs_size = [1] + self.filter_neuron
    with tf.variable_scope(name, reuse=reuse):
        offset = 0
        blocks = []
        for type_i in range(self.ntypes):
            nsel = self.sel_r[type_i]
            # natom x nsel
            inputs_i = tf.slice(inputs, [0, offset], [-1, nsel])
            offset += nsel
            shape_i = inputs_i.get_shape().as_list()
            # (natom x nsel) x 1
            hidden = tf.reshape(inputs_i, [-1, 1])
            if (type_input, type_i) in self.exclude_types:
                # excluded pair: project onto zeros of the output width
                zero_w = tf.zeros((outputs_size[0], outputs_size[-1]),
                                  dtype=global_tf_float_precision)
                hidden = tf.matmul(hidden, zero_w)
            else:
                for ii in range(1, len(outputs_size)):
                    n_in, n_out = outputs_size[ii - 1], outputs_size[ii]
                    tag = str(ii) + '_' + str(type_i)
                    w = tf.get_variable(
                        'matrix_' + tag,
                        [n_in, n_out],
                        self.filter_precision,
                        tf.random_normal_initializer(
                            stddev=stddev / np.sqrt(n_out + n_in), seed=seed),
                        trainable=trainable)
                    b = tf.get_variable(
                        'bias_' + tag,
                        [1, n_out],
                        self.filter_precision,
                        tf.random_normal_initializer(
                            stddev=stddev, mean=bavg, seed=seed),
                        trainable=trainable)
                    idt = None
                    if self.filter_resnet_dt:
                        # created whenever resnet_dt is on, even for layers
                        # without a shortcut
                        idt = tf.get_variable(
                            'idt_' + tag,
                            [1, n_out],
                            self.filter_precision,
                            tf.random_normal_initializer(
                                stddev=0.001, mean=1.0, seed=seed),
                            trainable=trainable)
                    if n_out == n_in:
                        shortcut = hidden
                    elif n_out == n_in * 2:
                        shortcut = tf.concat([hidden, hidden], 1)
                    else:
                        shortcut = None
                    layer_out = activation_fn(tf.matmul(hidden, w) + b)
                    if shortcut is None:
                        hidden = layer_out
                    else:
                        if idt is not None:
                            layer_out = layer_out * idt
                        hidden = shortcut + layer_out
            # natom x nsel x out_size
            blocks.append(tf.reshape(hidden, (-1, shape_i[1], outputs_size[-1])))
        # natom x nnei x out_size
        xyz_scatter = tf.concat(blocks, axis=1)
        res_rescale = 1. / 5.
        # natom x out_size
        result = tf.reduce_mean(xyz_scatter, axis=1) * res_rescale
    return result
def build(
        self,
        inputs: tf.Tensor,
        natoms: tf.Tensor,
        input_dict: dict = None,
        reuse: bool = None,
        suffix: str = '',
) -> tf.Tensor:
    """Build the computational graph for the energy fitting net.

    Parameters
    ----------
    inputs
        The input descriptor
    natoms
        The number of atoms. This tensor has the length of Ntypes + 2
        natoms[0]: number of local atoms
        natoms[1]: total number of atoms held by this processor
        natoms[i]: 2 <= i < Ntypes+2, number of type i atoms
    input_dict
        Additional dict for inputs.
        if numb_fparam > 0, should have input_dict['fparam']
        if numb_aparam > 0, should have input_dict['aparam']
        Optional keys: 'nframes' (sizes the zero input used with
        ``atom_ener``) and 'type_embedding'.
    reuse
        The weights in the networks should be reused when get the variable.
    suffix
        Name suffix to identify this descriptor

    Returns
    -------
    ener
        The per-atom outputs of all frames, flattened to 1-D
        (nframes * natoms[0]).

    Raises
    ------
    RuntimeError
        If frame/atomic parameters are configured but their statistics are
        missing, or if ``atom_ener`` is combined with a type embedding.
    """
    if input_dict is None:
        input_dict = {}
    bias_atom_e = self.bias_atom_e
    if self.numb_fparam > 0 and (self.fparam_avg is None
                                 or self.fparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')
    if self.numb_aparam > 0 and (self.aparam_avg is None
                                 or self.aparam_inv_std is None):
        raise RuntimeError(
            'No data stat result. one should do data statisitic, before build')

    with tf.variable_scope('fitting_attr' + suffix, reuse=reuse):
        # named graph constants; not used further in this method
        t_dfparam = tf.constant(self.numb_fparam, name='dfparam', dtype=tf.int32)
        t_daparam = tf.constant(self.numb_aparam, name='daparam', dtype=tf.int32)
        if self.numb_fparam > 0:
            t_fparam_avg = tf.get_variable(
                't_fparam_avg',
                self.numb_fparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_avg))
            t_fparam_istd = tf.get_variable(
                't_fparam_istd',
                self.numb_fparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.fparam_inv_std))
        if self.numb_aparam > 0:
            t_aparam_avg = tf.get_variable(
                't_aparam_avg',
                self.numb_aparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_avg))
            t_aparam_istd = tf.get_variable(
                't_aparam_istd',
                self.numb_aparam,
                dtype=GLOBAL_TF_FLOAT_PRECISION,
                trainable=False,
                initializer=tf.constant_initializer(self.aparam_inv_std))

    inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
    if len(self.atom_ener):
        # only for atom_ener
        nframes = input_dict.get('nframes')
        if nframes is not None:
            # like inputs, but we don't want to add a dependency on inputs
            inputs_zero = tf.zeros((nframes, self.dim_descrpt * natoms[0]),
                                   dtype=self.fitting_precision)
        else:
            inputs_zero = tf.zeros_like(inputs, dtype=self.fitting_precision)

    if bias_atom_e is not None:
        assert (len(bias_atom_e) == self.ntypes)

    fparam = None
    aparam = None
    if self.numb_fparam > 0:
        fparam = input_dict['fparam']
        fparam = tf.reshape(fparam, [-1, self.numb_fparam])
        fparam = (fparam - t_fparam_avg) * t_fparam_istd
    if self.numb_aparam > 0:
        aparam = input_dict['aparam']
        aparam = tf.reshape(aparam, [-1, self.numb_aparam])
        aparam = (aparam - t_aparam_avg) * t_aparam_istd
        aparam = tf.reshape(aparam, [-1, self.numb_aparam * natoms[0]])

    type_embedding = input_dict.get('type_embedding', None)
    if type_embedding is not None:
        atype_embed = embed_atom_type(self.ntypes, natoms, type_embedding)
        atype_embed = tf.tile(atype_embed, [tf.shape(inputs)[0], 1])
    else:
        atype_embed = None

    # tf.repeat is avaiable in TF>=2.1 or TF 1.15. When it is, the atom
    # energy bias is added exactly once at the end (see below), so it must
    # not also be handed to _build_lower; otherwise it would be applied twice.
    _TF_VERSION = Version(TF_VERSION)
    add_bias_by_repeat = (
        (Version('1.15') <= _TF_VERSION < Version('2')
         or _TF_VERSION >= Version('2.1'))
        and self.bias_atom_e is not None)

    if atype_embed is None:
        start_index = 0
        outs_list = []
        for type_i in range(self.ntypes):
            if bias_atom_e is None or add_bias_by_repeat:
                type_bias_ae = 0.0
            else:
                type_bias_ae = bias_atom_e[type_i]
            final_layer = self._build_lower(
                start_index, natoms[2 + type_i], inputs, fparam, aparam,
                bias_atom_e=type_bias_ae,
                suffix='_type_' + str(type_i) + suffix,
                reuse=reuse)
            if type_i < len(self.atom_ener) and self.atom_ener[type_i] is not None:
                # same net (reuse=True) on all-zero descriptors; shifting by
                # atom_ener - zero_layer pins the zero-descriptor output
                zero_layer = self._build_lower(
                    start_index, natoms[2 + type_i], inputs_zero, fparam, aparam,
                    bias_atom_e=type_bias_ae,
                    suffix='_type_' + str(type_i) + suffix,
                    reuse=True)
                final_layer += self.atom_ener[type_i] - zero_layer
            final_layer = tf.reshape(
                final_layer, [tf.shape(inputs)[0], natoms[2 + type_i]])
            outs_list.append(final_layer)
            start_index += natoms[2 + type_i]
        # concat once may be faster than multiple concat
        outs = tf.concat(outs_list, axis=1)
    else:
        # with type embedding
        if len(self.atom_ener) > 0:
            raise RuntimeError(
                "setting atom_ener is not supported by type embedding")
        atype_embed = tf.cast(atype_embed, self.fitting_precision)
        type_shape = atype_embed.get_shape().as_list()
        inputs = tf.concat(
            [tf.reshape(inputs, [-1, self.dim_descrpt]), atype_embed], axis=1)
        # _build_lower reads self.dim_descrpt, so widen it temporarily; it is
        # restored afterwards so that a second build() starts from the
        # original descriptor width instead of accumulating the embedding size
        original_dim_descrpt = self.dim_descrpt
        self.dim_descrpt = original_dim_descrpt + type_shape[1]
        try:
            inputs = tf.reshape(inputs, [-1, self.dim_descrpt * natoms[0]])
            final_layer = self._build_lower(
                0, natoms[0], inputs, fparam, aparam,
                bias_atom_e=0.0, suffix=suffix, reuse=reuse)
        finally:
            self.dim_descrpt = original_dim_descrpt
        outs = tf.reshape(final_layer, [tf.shape(inputs)[0], natoms[0]])

    # add atom energy bias; TF will broadcast to all batches
    if add_bias_by_repeat:
        outs += tf.repeat(
            tf.Variable(self.bias_atom_e,
                        dtype=self.fitting_precision,
                        trainable=False,
                        name="bias_atom_ei"),
            natoms[2:])

    if self.tot_ener_zero:
        # shift every atom so that the per-frame mean equals
        # force_tot_ener / natoms[0] (i.e. zero total energy)
        force_tot_ener = 0.0
        outs = tf.reshape(outs, [-1, natoms[0]])
        outs_mean = tf.reshape(tf.reduce_mean(outs, axis=1), [-1, 1])
        outs_mean = outs_mean - tf.ones_like(
            outs_mean, dtype=GLOBAL_TF_FLOAT_PRECISION) * (
                force_tot_ener / global_cvt_2_tf_float(natoms[0]))
        outs = outs - outs_mean
        outs = tf.reshape(outs, [-1])

    tf.summary.histogram('fitting_net_output', outs)
    return tf.reshape(outs, [-1])
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build a loss mixing per-atom (local) and per-frame (global) tensor errors.

    Local loss: scaled MSE between the flattened prediction and the
    'atomic_<label_name>' label, weighted by ``self.local_weight``.
    Global loss: the prediction is summed over atoms per frame and compared
    with the '<label_name>' label. It is weighted by ``self.global_weight``
    and divided by the number of atoms carrying the tensor (the atoms of
    ``self.type_sel`` if given, else ``natoms[0]``). Each part is switched
    off by its 'find_*' flag. ``learning_rate`` is not used.

    Returns
    -------
    l2_loss
        Weighted sum of the enabled parts.
    more_loss
        Dict with 'local_loss' and 'global_loss'. The global entry is the
        value before the division by the atom count.
    """
    polar_hat = label_dict[self.label_name]
    atomic_polar_hat = label_dict["atomic_" + self.label_name]
    polar = tf.reshape(model_dict[self.tensor_name], [-1])
    find_global = label_dict['find_' + self.label_name]
    find_atomic = label_dict['find_atomic_' + self.label_name]

    l2_loss = global_cvt_2_tf_float(0.0)
    more_loss = {
        "local_loss": global_cvt_2_tf_float(0.0),
        "global_loss": global_cvt_2_tf_float(0.0),
    }

    if self.local_weight > 0.0:
        local_loss = global_cvt_2_tf_float(find_atomic) * tf.reduce_mean(
            tf.square(self.scale * (polar - atomic_polar_hat)),
            name='l2_' + suffix)
        more_loss['local_loss'] = local_loss
        l2_loss += self.local_weight * local_loss
        self.l2_loss_local_summary = tf.summary.scalar(
            'l2_local_loss', tf.sqrt(more_loss['local_loss']))

    if self.global_weight > 0.0:
        # only atoms that carry the tensor are counted
        if self.type_sel is None:
            atoms = natoms[0]
        else:
            atoms = sum(natoms[2 + w] for w in self.type_sel)
        nframes = tf.shape(polar)[0] // self.tensor_size // atoms
        # per-frame sum of the atomic predictions, flattened
        frame_sum = tf.reduce_sum(
            tf.reshape(polar, [nframes, -1, self.tensor_size]), axis=1)
        global_polar = tf.reshape(frame_sum, [-1])
        # polar_hat is used as-is, i.e. it is expected to be a global label
        global_loss = global_cvt_2_tf_float(find_global) * tf.reduce_mean(
            tf.square(self.scale * (global_polar - polar_hat)),
            name='l2_' + suffix)
        more_loss['global_loss'] = global_loss
        self.l2_loss_global_summary = tf.summary.scalar(
            'l2_global_loss',
            tf.sqrt(more_loss['global_loss']) / global_cvt_2_tf_float(atoms))
        # normalise by the atoms carrying the tensor, not by natoms[0]
        atom_norm = 1. / global_cvt_2_tf_float(atoms)
        global_loss *= atom_norm
        l2_loss += self.global_weight * global_loss

    self.l2_more = more_loss
    self.l2_l = l2_loss
    self.l2_loss_summary = tf.summary.scalar('l2_loss', tf.sqrt(l2_loss))
    return l2_loss, more_loss
def build(self, learning_rate, natoms, model_dict, label_dict, suffix):
    """Build the combined energy/force/virial/atomic-energy loss.

    Each term is an MSE weighted by a prefactor that is interpolated
    linearly from ``start_pref_*`` (at ``self.starter_learning_rate``)
    towards ``limit_pref_*`` as ``learning_rate`` decays. The prefactor is
    multiplied by the matching 'find_*' label flag. Only terms whose
    ``has_*`` flag is set enter the total. The energy term is divided by
    ``natoms[0]`` and the virial term by ``natoms[0]``. When
    ``self.relative_f`` is set, force errors are divided per atom by
    ``|F_hat| + relative_f``. The 'pf' term weights squared force errors by
    ``atom_pref``.

    Returns
    -------
    l2_loss
        Weighted total loss (the int 0 if no term is enabled).
    more_loss
        Unweighted loss of every enabled term, keyed by name.
    """
    energy = model_dict['energy']
    force = model_dict['force']
    virial = model_dict['virial']
    atom_ener = model_dict['atom_ener']
    energy_hat = label_dict['energy']
    force_hat = label_dict['force']
    virial_hat = label_dict['virial']
    atom_ener_hat = label_dict['atom_ener']
    atom_pref = label_dict['atom_pref']
    find_energy = label_dict['find_energy']
    find_force = label_dict['find_force']
    find_virial = label_dict['find_virial']
    find_atom_ener = label_dict['find_atom_ener']
    find_atom_pref = label_dict['find_atom_pref']

    l2_ener_loss = tf.reduce_mean(tf.square(energy - energy_hat),
                                  name='l2_' + suffix)

    force_flat = tf.reshape(force, [-1])
    force_hat_flat = tf.reshape(force_hat, [-1])
    atom_pref_flat = tf.reshape(atom_pref, [-1])
    diff_f = force_hat_flat - force_flat
    if self.relative_f is not None:
        # divide each atom's force error by |F_hat| + relative_f
        force_hat_3 = tf.reshape(force_hat, [-1, 3])
        norm_f = tf.reshape(tf.norm(force_hat_3, axis=1), [-1, 1]) + self.relative_f
        diff_f = tf.reshape(tf.reshape(diff_f, [-1, 3]) / norm_f, [-1])
    l2_force_loss = tf.reduce_mean(tf.square(diff_f),
                                   name="l2_force_" + suffix)
    l2_pref_force_loss = tf.reduce_mean(
        tf.multiply(tf.square(diff_f), atom_pref_flat),
        name="l2_pref_force_" + suffix)

    virial_flat = tf.reshape(virial, [-1])
    virial_hat_flat = tf.reshape(virial_hat, [-1])
    l2_virial_loss = tf.reduce_mean(tf.square(virial_hat_flat - virial_flat),
                                    name="l2_virial_" + suffix)

    atom_ener_flat = tf.reshape(atom_ener, [-1])
    atom_ener_hat_flat = tf.reshape(atom_ener_hat, [-1])
    l2_atom_ener_loss = tf.reduce_mean(
        tf.square(atom_ener_hat_flat - atom_ener_flat),
        name="l2_atom_ener_" + suffix)

    atom_norm = 1. / global_cvt_2_tf_float(natoms[0])
    atom_norm_ener = 1. / global_cvt_2_ener_float(natoms[0])

    def _pref(find, start, limit, cast):
        # scheduled prefactor, zeroed when the label is absent
        return cast(find * (limit + (start - limit) * learning_rate /
                            self.starter_learning_rate))

    pref_e = _pref(find_energy, self.start_pref_e, self.limit_pref_e,
                   global_cvt_2_ener_float)
    pref_f = _pref(find_force, self.start_pref_f, self.limit_pref_f,
                   global_cvt_2_tf_float)
    pref_v = _pref(find_virial, self.start_pref_v, self.limit_pref_v,
                   global_cvt_2_tf_float)
    pref_ae = _pref(find_atom_ener, self.start_pref_ae, self.limit_pref_ae,
                    global_cvt_2_tf_float)
    pref_pf = _pref(find_atom_pref, self.start_pref_pf, self.limit_pref_pf,
                    global_cvt_2_tf_float)

    # (enabled, key, unweighted loss, builder of the weighted term); the
    # builders are lambdas so disabled terms create no ops
    terms = (
        (self.has_e, 'l2_ener_loss', l2_ener_loss,
         lambda: atom_norm_ener * (pref_e * l2_ener_loss)),
        (self.has_f, 'l2_force_loss', l2_force_loss,
         lambda: global_cvt_2_ener_float(pref_f * l2_force_loss)),
        (self.has_v, 'l2_virial_loss', l2_virial_loss,
         lambda: global_cvt_2_ener_float(atom_norm * (pref_v * l2_virial_loss))),
        (self.has_ae, 'l2_atom_ener_loss', l2_atom_ener_loss,
         lambda: global_cvt_2_ener_float(pref_ae * l2_atom_ener_loss)),
        (self.has_pf, 'l2_pref_force_loss', l2_pref_force_loss,
         lambda: global_cvt_2_ener_float(pref_pf * l2_pref_force_loss)),
    )
    l2_loss = 0
    more_loss = {}
    for enabled, key, raw_loss, weighted in terms:
        if enabled:
            l2_loss += weighted()
            more_loss[key] = raw_loss

    self.l2_l = l2_loss
    self.l2_more = more_loss
    return l2_loss, more_loss
def _type_embedding_net_one_side_aparam(self, mat_g, atype, natoms, aparam, name='', reuse=None, seed=None, trainable=True):
    """Modulate embedding-net output with an embedding of neighbour type + aparam.

    Like the type-only variant, the last embedding dimension is amplified
    to ``self.type_nchanl`` channels and the channels are averaged after a
    channel-wise product. Here the per-neighbour embedding is computed from
    the concatenation of the neighbour slot's type (``self.nei_type``) and
    the neighbour's atomic parameters. The latter are obtained from
    ``op_module.map_aparam`` using ``self.nlist``.

    Parameters
    ----------
    mat_g
        Embedding-net output. Its leading dimension is taken as the number
        of frames ``nf``, and it is reshaped to
        (nf * natoms[0] * self.nnei) x self.filter_neuron[-1].
    atype
        Not used.
    natoms
        Atom-count tensor; ``natoms[0]`` is read and the whole tensor is
        passed to ``map_aparam``.
    aparam
        Atomic parameters, reshaped to nf x (natom * naparam).
    name
        Prefix of the amplifying layer name ('<name>_amplify').
    reuse
        Forwarded to the amplifying layer and to ``self._type_embed``.
    seed
        Not used; the amplifying layer is seeded with ``self.seed``.
    trainable
        Trainability of the amplifying layer. The type embedding is always
        built with ``trainable=True``.

    Returns
    -------
    Tensor of shape (nf * natoms[0]) x self.nnei x self.filter_neuron[-1].
    """
    outputs_size = self.filter_neuron[-1]
    nframes = tf.shape(mat_g)[0]
    # (nf x natom x nei) x outputs_size
    mat_g = tf.reshape(mat_g, [nframes * natoms[0] * self.nnei, outputs_size])
    # (nf x natom x nei) x (outputs_size x chnl)
    mat_g = one_layer(
        mat_g,
        outputs_size * self.type_nchanl,
        activation_fn=None,
        precision=self.filter_precision,
        name=name + '_amplify',
        reuse=reuse,
        seed=self.seed,
        trainable=trainable,
    )
    # nf x natom x nei x outputs_size x chnl
    mat_g = tf.reshape(
        mat_g, [nframes, natoms[0], self.nnei, outputs_size, self.type_nchanl])
    # outputs_size x nf x natom x nei x chnl
    mat_g = tf.transpose(mat_g, perm=[3, 0, 1, 2, 4])
    # outputs_size x (nf x natom x nei x chnl)
    mat_g = tf.reshape(
        mat_g,
        [outputs_size, nframes * natoms[0] * self.nnei * self.type_nchanl])

    # (nf x natom) x nei: the neighbour-slot types repeated for every atom
    embed_type = tf.tile(tf.reshape(self.nei_type, [1, self.nnei]),
                         [nframes * natoms[0], 1])
    # (nf x natom x nei) x 1
    embed_type = tf.reshape(embed_type, [nframes * natoms[0] * self.nnei, 1])
    # nf x (natom x naparam)
    aparam = tf.reshape(aparam, [nframes, -1])
    # per-neighbour aparam; assumed nf x natom x nei x naparam (external op)
    embed_aparam = op_module.map_aparam(aparam, self.nlist, natoms,
                                        n_a_sel=self.nnei_a,
                                        n_r_sel=self.nnei_r)
    # (nf x natom x nei) x naparam
    embed_aparam = tf.reshape(
        embed_aparam, [nframes * natoms[0] * self.nnei, self.numb_aparam])
    # (nf x natom x nei) x (naparam + 1)
    embed_input = tf.concat((embed_type, embed_aparam), axis=1)
    # (nf x natom x nei) x chnl
    nei_ebd = self._type_embed(embed_input,
                               ndim=self.numb_aparam + 1,
                               reuse=reuse,
                               trainable=True,
                               suffix='')
    # flattened to (nf x natom x nei x chnl) to broadcast against mat_g
    nei_ebd = tf.reshape(
        nei_ebd, [nframes * natoms[0] * self.nnei * self.type_nchanl])
    mat_g = tf.multiply(mat_g, nei_ebd)

    # outputs_size x nf x natom x nei x chnl
    mat_g = tf.reshape(
        mat_g, [outputs_size, nframes, natoms[0], self.nnei, self.type_nchanl])
    # average over channels -> outputs_size x nf x natom x nei
    mat_g = tf.reduce_mean(mat_g, axis=4)
    # nf x natom x nei x outputs_size
    mat_g = tf.transpose(mat_g, perm=[1, 2, 3, 0])
    # (nf x natom) x nei x outputs_size
    return tf.reshape(mat_g, [nframes * natoms[0], self.nnei, outputs_size])