class DeepStatisticalSolver:
    """Graph neural network that iteratively refines a per-node prediction U.

    The model keeps a latent message H at every node. At each of the
    `correction_updates` steps, messages are propagated along the edges
    described by A (through the `phi_*` blocks), aggregated at the nodes, and
    turned into a correction of H by a `psi` block. A decoder `xi` maps H to a
    prediction U after every step, and the training loss is a discounted sum
    of the per-step losses.

    The problem-specific quantities (input/output dimensions, normalization
    constants, cost function) come from a `Problem` class defined in a
    `problem.py` file located in `default_data_directory`.

    NOTE: `problem` is imported by name after appending the data directory to
    `sys.path`; Python caches modules, so building a second solver in the same
    process with another data directory presumably reuses the first
    `problem` module -- confirm before relying on it.
    """

    def __init__(self,
                 sess,
                 latent_dimension=10,
                 hidden_layers=3,
                 correction_updates=5,
                 alpha=1e-3,
                 non_lin='leaky_relu',
                 minibatch_size=10,
                 name='deep_statistical_solver',
                 directory='./',
                 model_to_restore=None,
                 default_data_directory='datasets/spring/default',
                 proxy=False):
        """Build the weights and the computation graph, then init or restore.

        Args:
            sess: TensorFlow (v1-style) session used to run every operation.
            latent_dimension: size of the latent message stored at each node.
            hidden_layers: number of hidden layers of every neural block.
            correction_updates: number of correction steps applied to H.
            alpha: scaling factor applied to every correction.
            non_lin: name of the non linearity forwarded to `FullyConnected`.
            minibatch_size: number of samples per minibatch.
            name: prefix used to name the neural blocks.
            directory: folder where config, weights and summaries are stored.
            model_to_restore: optional folder containing `config.json` and
                `model.ckpt`; when it exists, its config overrides the
                arguments above and its weights are loaded.
            default_data_directory: folder containing `problem.py`,
                `train.tfrecords` and `val.tfrecords`.
            proxy: if True, gradients are computed on the squared error to
                the ground truth U instead of the problem cost function.

        Raises:
            ImportError: if no compatible `problem.py` can be imported.
        """
        self.sess = sess
        self.latent_dimension = latent_dimension
        self.hidden_layers = hidden_layers
        self.correction_updates = correction_updates
        self.alpha = alpha
        self.non_lin = non_lin
        self.minibatch_size = minibatch_size
        self.name = name
        self.directory = directory
        self.current_train_iter = 0
        self.default_data_directory = default_data_directory
        self.proxy = proxy

        # Initialize list of trainable variables
        self.trainable_variables = []

        # Import the problem definition that lives next to the data
        sys.path.append(self.default_data_directory)
        try:
            from problem import Problem
            self.problem = Problem()
        except ImportError:
            # The graph cannot be built without the problem definition (its
            # cost function is always needed), so fail here with a clear
            # message rather than later with an AttributeError.
            logging.error('You should provide a compatible "problem.py" file in your data folder!')
            raise

        restoring = (model_to_restore is not None) and os.path.exists(model_to_restore)

        if restoring:
            # Reload the parameters from a json file
            logging.info(' Restoring model from '+model_to_restore)
            path_to_config = os.path.join(model_to_restore, 'config.json')
            with open(path_to_config, 'r') as f:
                config = json.load(f)
            self.set_config(config)
        else:
            # Dimensions associated to the problem
            self.d_in_A = self.problem.d_in_A
            self.d_in_B = self.problem.d_in_B
            self.d_out = self.problem.d_out
            self.d_F = self.problem.d_F
            self.initial_U = self.problem.initial_U

            # Normalization constants
            self.B_mean = self.problem.B_mean
            self.B_std = self.problem.B_std
            self.A_mean = self.problem.A_mean
            self.A_std = self.problem.A_std

        # Build weight tensors
        self.build_weights()

        # Build computational graph
        self.build_graph(self.default_data_directory)

        if restoring:
            # Reload the trained weights from a ckpt file
            saver = tf.compat.v1.train.Saver(self.trainable_variables)
            path_to_weights = os.path.join(model_to_restore, 'model.ckpt')
            saver.restore(self.sess, path_to_weights)
        else:
            # Randomly initialize weights
            self.sess.run(tf.compat.v1.variables_initializer(self.trainable_variables))

        # Log config infos
        self.log_config()

    def build_weights(self):
        """Build all the neural blocks, stored in dicts keyed by str(update).

        One `psi`, `phi_from`, `phi_to` and `phi_loop` block is created per
        correction update, and one `xi` decoder per update plus one for the
        initial latent state (keys '0' .. str(correction_updates)).
        """
        self.psi = {}
        self.phi_from = {}
        self.phi_to = {}
        self.phi_loop = {}
        self.xi = {}

        # Arguments shared by every block
        block_kwargs = dict(
            non_lin=self.non_lin,
            latent_dimension=self.latent_dimension,
            hidden_layers=self.hidden_layers,
        )

        # psi reads H and the three aggregated messages (4 latent vectors) plus b
        psi_input_dim = 4*(self.latent_dimension)+self.d_in_B
        # phi blocks read H at both ends of an edge plus the edge features
        phi_input_dim = 2*(self.latent_dimension)+self.d_in_A

        for update in range(self.correction_updates):
            key = str(update)
            self.psi[key] = FullyConnected(
                name=self.name+'_correction_block_{}'.format(update),
                input_dim=psi_input_dim,
                **block_kwargs
            )
            self.phi_from[key] = FullyConnected(
                name=self.name+'_phi_from_{}'.format(update),
                input_dim=phi_input_dim,
                **block_kwargs
            )
            self.phi_to[key] = FullyConnected(
                name=self.name+'_phi_to_{}'.format(update),
                input_dim=phi_input_dim,
                **block_kwargs
            )
            self.phi_loop[key] = FullyConnected(
                name=self.name+'_phi_loop_{}'.format(update),
                input_dim=phi_input_dim,
                **block_kwargs
            )

        # Decoders: one per latent state, including the initial one
        for update in range(self.correction_updates+1):
            self.xi[str(update)] = FullyConnected(
                name=self.name+'_D_{}'.format(update),
                input_dim=self.latent_dimension,
                output_dim=self.d_out,
                **block_kwargs
            )

    def build_graph(self, default_data_directory):
        """Build the computation graph.

        Assumes that all graphs have been merged into one supergraph.

        Args:
            default_data_directory: folder containing `train.tfrecords` and
                `val.tfrecords`, whose records hold flat float features
                'A', 'B' and 'U'.
        """

        def extract_fn(tfrecord):
            # Extract features using the keys set during creation
            features = {
                'A': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
                'B': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
                'U': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True)
            }
            # Extract the data record (tf.io spelling works in TF1 and TF2)
            sample = tf.io.parse_single_example(tfrecord, features)
            return [sample['A'], sample['B'], sample['U']]

        self.train_dataset = tf.data.TFRecordDataset(
            [os.path.join(default_data_directory, 'train.tfrecords')])
        self.train_dataset = self.train_dataset.map(extract_fn).shuffle(100).batch(self.minibatch_size).repeat()

        self.valid_dataset = tf.data.TFRecordDataset(
            [os.path.join(default_data_directory, 'val.tfrecords')])
        self.valid_dataset = self.valid_dataset.map(extract_fn).shuffle(100).batch(self.minibatch_size).repeat()

        # Build an iterator that can be bound to either dataset
        self.iterator = tf.compat.v1.data.Iterator.from_structure(
            tf.compat.v1.data.get_output_types(self.train_dataset), None)
        self.next_element = self.iterator.get_next()

        # Build operations to initialize training and validation
        self.training_init_op = self.iterator.make_initializer(self.train_dataset)
        self.validation_init_op = self.iterator.make_initializer(self.valid_dataset)

        # Get the output of the data handler
        self.A_flat, self.B_flat, self.U_flat = self.next_element

        # Reshape the flat records: A is per edge (2 index columns followed by
        # d_in_A features), B and U are per node
        self.minibatch_size_ = tf.shape(self.A_flat)[0]
        self.A = tf.reshape(self.A_flat, [self.minibatch_size_, -1, self.d_in_A+2])
        self.B = tf.reshape(self.B_flat, [self.minibatch_size_, -1, self.d_in_B])
        self.U_gt = tf.reshape(self.U_flat, [self.minibatch_size_, -1, self.d_out])

        # Get relevant tensor dimensions
        self.minibatch_size_tf = tf.shape(self.A)[0]
        self.num_nodes = tf.shape(self.B)[1]
        self.num_edges = tf.shape(self.A)[1]
        self.A_dim = tf.shape(self.A)[2]

        # Broadcast the normalization constants to the shape of A and B
        edge_ones = tf.ones([self.minibatch_size_tf, self.num_edges, 1])
        node_ones = tf.ones([self.minibatch_size_tf, self.num_nodes, 1])
        self.A_mean_tf = edge_ones * tf.reshape(
            tf.constant(self.A_mean, dtype=tf.float32), [1, 1, 2+self.d_in_A])
        self.A_std_tf = edge_ones * tf.reshape(
            tf.constant(self.A_std, dtype=tf.float32), [1, 1, 2+self.d_in_A])
        self.B_mean_tf = node_ones * tf.reshape(
            tf.constant(self.B_mean, dtype=tf.float32), [1, 1, self.d_in_B])
        self.B_std_tf = node_ones * tf.reshape(
            tf.constant(self.B_std, dtype=tf.float32), [1, 1, self.d_in_B])

        # Normalizing inputs A and B. Lower case means normalized
        self.a = (self.A - self.A_mean_tf) / self.A_std_tf
        self.b = (self.B - self.B_mean_tf) / self.B_std_tf

        # Extract indices from the raw matrix A (indices must not be normalized)
        self.indices_from = tf.cast(self.A[:, :, 0], tf.int32)
        self.indices_to = tf.cast(self.A[:, :, 1], tf.int32)

        # Build mask to detect loops (edges whose two ends are the same node)
        self.mask_loop = tf.cast(tf.math.equal(self.indices_from, self.indices_to), tf.float32)
        self.mask_loop = tf.expand_dims(self.mask_loop, -1)

        # Extract normalized edge characteristics from matrix A
        self.a_ij = self.a[:, :, 2:]

        # Initialize the discount factor that will later be set by train()
        self.discount = tf.Variable(0., trainable=False)
        self.sess.run(tf.compat.v1.variables_initializer([self.discount]))

        # Initialize messages, predictions and losses dict
        self.H = {}
        self.U = {}
        self.loss = {}
        self.loss_proxy = {}
        self.log_loss = {}
        self.cost_per_sample = {}
        self.total_loss = None
        self.total_loss_proxy = None

        # Get the natural offset that will be added to every output U at every node
        self.initial_U_tf = node_ones * tf.reshape(
            tf.constant(self.initial_U, dtype=tf.float32), [1, 1, self.d_out])

        # Initialize latent message to 0
        self.H['0'] = tf.zeros([self.minibatch_size_tf, self.num_nodes, self.latent_dimension])

        # Decode the first message. Although this step is useless, it is kept
        # for compatibility
        self.U['0'] = self.xi['0'](self.H['0']) + self.initial_U_tf

        # Iterate over every correction update (k_bar)
        for update in range(self.correction_updates):
            key = str(update)
            next_key = str(update+1)

            # Gather messages from both extremities of each edge
            self.H_from = custom_gather(self.H[key], self.indices_from)
            self.H_to = custom_gather(self.H[key], self.indices_to)

            # Concatenate all the inputs of the phi neural network
            self.Phi_input = tf.concat([self.H_from, self.H_to, self.a_ij], axis=2)

            # Compute the phi using the dedicated neural network blocks;
            # loop edges only go through phi_loop, the others through
            # phi_from and phi_to
            self.Phi_from = self.phi_from[key](self.Phi_input) * (1.-self.mask_loop)
            self.Phi_to = self.phi_to[key](self.Phi_input) * (1.-self.mask_loop)
            self.Phi_loop = self.phi_loop[key](self.Phi_input) * self.mask_loop

            # Get the sum of each transformed message at each node
            latent_shape = [self.minibatch_size_tf, self.num_nodes, self.latent_dimension]
            self.Phi_from_sum = custom_scatter(self.indices_from, self.Phi_from, latent_shape)
            self.Phi_to_sum = custom_scatter(self.indices_to, self.Phi_to, latent_shape)
            self.Phi_loop_sum = custom_scatter(self.indices_to, self.Phi_loop, latent_shape)

            # Concatenate all the inputs of the correction neural network
            self.correction_input = tf.concat([
                self.H[key],
                self.Phi_from_sum,
                self.Phi_to_sum,
                self.Phi_loop_sum,
                self.b], axis=2)

            # Compute the correction using the dedicated neural network block
            self.correction = self.psi[key](self.correction_input)

            # Apply correction to the latent message
            self.H[next_key] = self.H[key] + self.correction * self.alpha

            # Decode H
            self.U[next_key] = self.xi[next_key](self.H[next_key]) + self.initial_U_tf

            # Compute the loss for each sample
            if self.proxy:
                self.loss_proxy[next_key] = tf.reduce_mean((self.U[next_key] - self.U_gt)**2)
            self.cost_per_sample[next_key] = self.problem.cost_function(self.U[next_key], self.A, self.B)

            # Take the mean, and register its value for tensorboard
            self.loss[next_key] = tf.reduce_mean(self.cost_per_sample[next_key])
            tf.compat.v1.summary.scalar("loss_{}".format(update+1), self.loss[next_key])

            # Compute the discounted loss: the last update has weight 1 and
            # earlier updates are weighted by increasing powers of the discount
            weight = self.discount**(self.correction_updates-1-update)
            if self.total_loss is None:
                self.total_loss = self.loss[next_key] * weight
                if self.proxy:
                    self.total_loss_proxy = self.loss_proxy[next_key] * weight
            else:
                self.total_loss += self.loss[next_key] * weight
                if self.proxy:
                    self.total_loss_proxy += self.loss_proxy[next_key] * weight

        # Get the final prediction and the final loss
        self.U_final = self.U[str(self.correction_updates)]
        self.loss_final = self.loss[str(self.correction_updates)]

        # Initialize the optimizer; the learning rate is set by train()
        self.learning_rate = tf.Variable(0., trainable=False)
        self.sess.run(tf.compat.v1.variables_initializer([self.learning_rate]))
        self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate)

        # Gradient clipping to avoid exploding gradients
        objective = self.total_loss_proxy if self.proxy else self.total_loss
        self.gradients, self.variables = zip(*self.optimizer.compute_gradients(objective))
        self.gradients, _ = tf.clip_by_global_norm(self.gradients, 1e-2)

        # Define optimization operator
        self.opt_op = self.optimizer.apply_gradients(zip(self.gradients, self.variables))

        # Initialize the variables created by the optimizer
        self.sess.run(tf.compat.v1.variables_initializer(self.optimizer.variables()))

        # Build summary to visualize the final loss in Tensorboard
        tf.compat.v1.summary.scalar("loss_final", self.loss_final)
        self.merged_summary_op = tf.compat.v1.summary.merge_all()

        # Gather trainable variables
        for update in range(self.correction_updates):
            key = str(update)
            self.trainable_variables.extend(self.phi_from[key].trainable_variables)
            self.trainable_variables.extend(self.phi_to[key].trainable_variables)
            self.trainable_variables.extend(self.phi_loop[key].trainable_variables)
            self.trainable_variables.extend(self.psi[key].trainable_variables)
        for update in range(self.correction_updates+1):
            self.trainable_variables.extend(self.xi[str(update)].trainable_variables)

    def log_config(self):
        """Log the config of the whole model."""
        logging.info(' Configuration :')
        logging.info(' Storing model in : '+self.directory)
        logging.info(' Latent dimensions : {}'.format(self.latent_dimension))
        logging.info(' Number of hidden layers per block : {}'.format(self.hidden_layers))
        logging.info(' Number of correction updates : {}'.format(self.correction_updates))
        logging.info(' Alpha : {}'.format(self.alpha))
        logging.info(' Non linearity : {}'.format(self.non_lin))
        logging.info(' d_in_A : {}'.format(self.d_in_A))
        logging.info(' d_in_B : {}'.format(self.d_in_B))
        logging.info(' d_out : {}'.format(self.d_out))
        logging.info(' Minibatch size : {}'.format(self.minibatch_size))
        logging.info(' Current training iteration : {}'.format(self.current_train_iter))
        logging.info(' Model name : ' + self.name)
        logging.info(' Initial U : {}'.format(self.initial_U))
        logging.info(' A mean : {}'.format(self.A_mean))
        logging.info(' A std : {}'.format(self.A_std))
        logging.info(' B mean : {}'.format(self.B_mean))
        logging.info(' B std : {}'.format(self.B_std))
        logging.info(' Proxy : {}'.format(self.proxy))

    def set_config(self, config):
        """Set the model attributes from a config dict.

        Args:
            config: dict with the keys produced by `get_config`. The 'proxy'
                key is optional and defaults to False. `d_F` is not part of
                the config and is therefore not set here.
        """
        self.latent_dimension = config['latent_dimension']
        self.hidden_layers = config['hidden_layers']
        self.correction_updates = config['correction_updates']
        self.alpha = config['alpha']
        self.non_lin = config['non_lin']
        self.d_in_A = config['d_in_A']
        self.d_in_B = config['d_in_B']
        self.d_out = config['d_out']
        self.minibatch_size = config['minibatch_size']
        self.name = config['name']
        self.directory = config['directory']
        self.current_train_iter = config['current_train_iter']
        self.initial_U = np.array(config['initial_U'])
        self.A_mean = np.array(config['A_mean'])
        self.A_std = np.array(config['A_std'])
        self.B_mean = np.array(config['B_mean'])
        self.B_std = np.array(config['B_std'])
        # Configs saved before the proxy option existed have no 'proxy' key
        self.proxy = config.get('proxy', False)

    def get_config(self):
        """Return the config dict of the model.

        Array-like attributes are converted with `tolist()` so that the
        result only contains plain Python numbers and can be dumped to JSON
        whatever their NumPy dtype is.
        """
        config = {
            'latent_dimension': self.latent_dimension,
            'hidden_layers': self.hidden_layers,
            'correction_updates': self.correction_updates,
            'alpha': self.alpha,
            'non_lin': self.non_lin,
            'd_in_A': self.d_in_A,
            'd_in_B': self.d_in_B,
            'd_out': self.d_out,
            'minibatch_size': self.minibatch_size,
            'name': self.name,
            'directory': self.directory,
            'current_train_iter': self.current_train_iter,
            'initial_U': np.asarray(self.initial_U).tolist(),
            'A_mean': np.asarray(self.A_mean).tolist(),
            'A_std': np.asarray(self.A_std).tolist(),
            'B_mean': np.asarray(self.B_mean).tolist(),
            'B_std': np.asarray(self.B_std).tolist(),
            'proxy': self.proxy
        }
        return config

    def save(self):
        """Save the config (`config.json`) and weights (`model.ckpt`) in `self.directory`."""
        # Save config
        config = self.get_config()
        path_to_config = os.path.join(self.directory, 'config.json')
        with open(path_to_config, 'w') as f:
            json.dump(config, f)

        # Save weights
        saver = tf.compat.v1.train.Saver(self.trainable_variables)
        path_to_weights = os.path.join(self.directory, 'model.ckpt')
        saver.save(self.sess, path_to_weights)

    def train(self,
              max_iter=10,
              learning_rate=3e-4,
              discount=0.9,
              data_directory='datasets/spring/default',
              save_step=None,
              profile=False):
        """Run a training process while keeping track of the validation score.

        Args:
            max_iter: number of SGD steps to perform.
            learning_rate: learning rate assigned to the Adam optimizer.
            discount: discount factor weighting intermediate losses.
            data_directory: only logged here; the datasets actually used are
                those bound to the graph when it was built.
            save_step: if set (and non zero), metrics are logged and the
                model is saved every `save_step` iterations. The last
                iteration is always logged and saved.
            profile: if True, a Chrome trace is written in `self.directory`
                for every SGD step but the first one.
        """
        # Log infos about training process
        logging.info(' Starting a training process :')
        logging.info(' Max iteration : {}'.format(max_iter))
        logging.info(' Learning rate : {}'.format(learning_rate))
        logging.info(' Discount : {}'.format(discount))
        logging.info(' Training data : {}'.format(data_directory))
        logging.info(' Saving model every {} iterations'.format(save_step))
        if profile:
            logging.info(' Profiling...')

        # Load dataset
        self.sess.run(self.training_init_op)

        # Build writers dedicated to training and validation for Tensorboard
        self.training_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(self.directory, 'train'))
        self.validation_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(self.directory, 'val'))

        # Set discount factor and learning rate
        self.sess.run(self.discount.assign(discount))
        self.sess.run(self.learning_rate.assign(learning_rate))

        # Remember the latest training iteration of the model
        starting_point = self.current_train_iter
        last_iter = starting_point+max_iter-1

        # If profiling, then initialize useful variables
        if profile:
            options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_metadata = tf.compat.v1.RunMetadata()
            profile_path = os.path.join(self.directory, 'profile')

        # Training loop
        for i in tqdm(range(starting_point, starting_point+max_iter)):

            # Store current training step, so that it's always up to date
            self.current_train_iter = i

            # Perform SGD step (the first step is never profiled)
            if profile and i > starting_point:
                self.sess.run(self.opt_op, options=options, run_metadata=run_metadata)
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                with open(profile_path+'_{}.json'.format(i), 'w') as f:
                    f.write(chrome_trace)
            else:
                self.sess.run(self.opt_op)

            # Store final loss in a summary.
            # NOTE: every sess.run pulls a new minibatch from the iterator, so
            # this summary is not computed on the batch used for the SGD step.
            self.summary = self.sess.run(self.merged_summary_op)
            self.training_writer.add_summary(self.summary, self.current_train_iter)

            # Periodically log metrics and save model. `and` short-circuits,
            # so the modulo is never evaluated when save_step is None (or 0).
            periodic_save = bool(save_step) and (i % save_step == 0)
            if periodic_save or (i == last_iter):

                # Get minibatch train loss
                loss_final_train = self.sess.run(self.loss_final)

                # Change source data to validation
                self.sess.run(self.validation_init_op)

                # Get minibatch val loss
                loss_final_val = self.sess.run(self.loss_final)

                # Store final loss in validation
                self.summary = self.sess.run(self.merged_summary_op)
                self.validation_writer.add_summary(self.summary, self.current_train_iter)

                # Change source data back to training
                self.sess.run(self.training_init_op)

                # Log metrics
                logging.info(' Learning iteration {}'.format(i))
                logging.info(' Training loss (minibatch) : {}'.format(loss_final_train))
                logging.info(' Validation loss (minibatch): {}'.format(loss_final_val))

                # Save model
                self.save()

        # Save model at the end of training
        self.save()

    def evaluate(self, mode='val', data_directory='data'):
        """Evaluate the final loss on a numpy dataset and store the predictions.

        Loads `A_<mode>.npy` and `B_<mode>.npy` from `data_directory`, feeds
        them in place of the tensors `self.A` and `self.B`, and saves the
        final prediction as `X_final_pred_<mode>.npy` in `self.directory`.

        Args:
            mode: suffix of the dataset files to load (e.g. 'val').
            data_directory: folder containing the `.npy` files.

        Returns:
            The value of the final loss on the loaded dataset.
        """
        # Import numpy dataset
        A = np.load(os.path.join(data_directory, 'A_'+mode+'.npy'))
        B = np.load(os.path.join(data_directory, 'B_'+mode+'.npy'))
        feed_dict = {self.A: A, self.B: B}

        # Compute the final loss and the final prediction in a single run.
        # The prediction tensor built by build_graph is `U_final`.
        loss, U_final = self.sess.run([self.loss_final, self.U_final], feed_dict=feed_dict)

        # Save the final prediction (file name kept for compatibility)
        np.save(os.path.join(self.directory, 'X_final_pred_'+mode+'.npy'), U_final)

        return loss