# Code example #1
class DeepStatisticalSolver:

    def __init__(self,
        sess,
        latent_dimension=10,
        hidden_layers=3,
        correction_updates=5,
        alpha=1e-3,
        non_lin='leaky_relu',
        minibatch_size=10,
        name='deep_statistical_solver',
        directory='./',
        model_to_restore=None,
        default_data_directory='datasets/spring/default',
        proxy=False):
        """
        Builds the solver: imports the problem definition, creates the
        trainable weights and the computation graph, and either restores a
        previously trained model or randomly initializes the weights.

        Args:
            sess: TF1 session used to run every operation of the model.
            latent_dimension: size of the latent message held at each node.
            hidden_layers: number of hidden layers in each neural block.
            correction_updates: number of message-passing/correction steps.
            alpha: scaling factor applied to each latent correction.
            non_lin: name of the non-linearity used in every neural block.
            minibatch_size: batch size of the tfrecord input pipeline.
            name: prefix used for all neural block names.
            directory: where config, weights and summaries are written.
            model_to_restore: optional path to a saved model directory; when
                it exists, both config and weights are reloaded from it.
            default_data_directory: folder holding `problem.py` and the
                train/val tfrecord files.
            proxy: if True, optimize the supervised proxy loss
                (distance to ground-truth U) instead of the problem cost.
        """

        self.sess = sess
        self.latent_dimension = latent_dimension
        self.hidden_layers = hidden_layers
        self.correction_updates = correction_updates
        self.alpha = alpha
        self.non_lin = non_lin
        self.minibatch_size = minibatch_size
        self.name = name
        self.directory = directory
        self.current_train_iter = 0
        self.default_data_directory = default_data_directory
        self.proxy = proxy

        # Initialize list of trainable variables
        self.trainable_variables = []

        try:
            # Try importing the dimensions associated to the problem
            sys.path.append(self.default_data_directory)
            from problem import Problem

            self.problem = Problem()

        except ImportError:
            # The problem definition is required unconditionally further down
            # (dimensions below, cost_function in build_graph), so a missing
            # problem.py is fatal: print the hint and re-raise instead of
            # continuing towards an opaque AttributeError.
            print('You should provide a compatible "problem.py" file in your data folder!')
            raise

        # Reload config if there is a model to restore
        if (model_to_restore is not None) and os.path.exists(model_to_restore):

            logging.info('    Restoring model from '+model_to_restore)

            # Reload the parameters from a json file
            path_to_config = os.path.join(model_to_restore, 'config.json')
            with open(path_to_config, 'r') as f:
                config = json.load(f)
            self.set_config(config)

        else:

            # No model to restore: take dimensions from the problem definition
            self.d_in_A = self.problem.d_in_A
            self.d_in_B = self.problem.d_in_B
            self.d_out = self.problem.d_out
            self.d_F = self.problem.d_F
            self.initial_U = self.problem.initial_U

            # Normalization constants
            self.B_mean = self.problem.B_mean
            self.B_std = self.problem.B_std
            self.A_mean = self.problem.A_mean
            self.A_std = self.problem.A_std

        # Build weight tensors
        self.build_weights()

        # Build computational graph
        self.build_graph(self.default_data_directory)

        # Restore trained weights if there is a model to restore
        if (model_to_restore is not None) and os.path.exists(model_to_restore):

            # Reload the weights from a ckpt file
            saver = tf.compat.v1.train.Saver(self.trainable_variables)
            path_to_weights = os.path.join(model_to_restore, 'model.ckpt')
            saver.restore(self.sess, path_to_weights)

        # Else, randomly initialize weights
        else:
            self.sess.run(tf.compat.v1.variables_initializer(self.trainable_variables))

        # Log config infos
        self.log_config()

    def build_weights(self):
        """
        Builds all the trainable neural network blocks.

        Creates, for each of the `correction_updates` steps, the four
        fully-connected blocks used by the message-passing graph (psi,
        phi_from, phi_to, phi_loop) and, additionally, one decoder block
        xi per latent state (including the initial one), all stored in
        dicts keyed by the stringified update index.
        """

        # Build weights of each correction update block, and store them
        self.psi = {}

        self.phi_from = {}
        self.phi_to = {}
        self.phi_loop = {}

        self.xi = {}

        for update in range(self.correction_updates):

            # psi consumes the concatenation [H, Phi_from_sum, Phi_to_sum,
            # Phi_loop_sum, b], hence the 4*latent + d_in_B input dimension
            self.psi[str(update)] = FullyConnected(
                non_lin=self.non_lin,
                latent_dimension=self.latent_dimension,
                hidden_layers=self.hidden_layers,
                name=self.name+'_correction_block_{}'.format(update),
                input_dim=4*(self.latent_dimension)+self.d_in_B
            )
            # Each phi block consumes [H_from, H_to, a_ij],
            # hence the 2*latent + d_in_A input dimension
            self.phi_from[str(update)] = FullyConnected(
                non_lin=self.non_lin,
                latent_dimension=self.latent_dimension,
                hidden_layers=self.hidden_layers,
                name=self.name+'_phi_from_{}'.format(update),
                input_dim=2*(self.latent_dimension)+self.d_in_A
            )
            self.phi_to[str(update)] = FullyConnected(
                non_lin=self.non_lin,
                latent_dimension=self.latent_dimension,
                hidden_layers=self.hidden_layers,
                name=self.name+'_phi_to_{}'.format(update),
                input_dim=2*(self.latent_dimension)+self.d_in_A
            )
            self.phi_loop[str(update)] = FullyConnected(
                non_lin=self.non_lin,
                latent_dimension=self.latent_dimension,
                hidden_layers=self.hidden_layers,
                name=self.name+'_phi_loop_{}'.format(update),
                input_dim=2*(self.latent_dimension)+self.d_in_A
            )

        # One decoder xi per latent state H^0 .. H^{correction_updates}
        for update in range(self.correction_updates+1):
            self.xi[str(update)] = FullyConnected(
                non_lin=self.non_lin,
                latent_dimension=self.latent_dimension,
                hidden_layers=self.hidden_layers,
                name=self.name+'_D_{}'.format(update),
                input_dim=self.latent_dimension,
                output_dim=self.d_out
            )
    def build_graph(self, default_data_directory):
        """
        Builds the computation graph.
        Assumes that all graphs have been merged into one supergraph.

        Sets up the tfrecord input pipeline, normalizes inputs A and B,
        unrolls the `correction_updates` message-passing steps, builds the
        discounted training loss (problem cost or supervised proxy) and the
        gradient-clipped Adam optimization op.

        Args:
            default_data_directory: folder containing 'train.tfrecords'
                and 'val.tfrecords'.
        """

        def extract_fn(tfrecord):

            # Extract features using the keys set during creation
            features = {
                'A': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
                'B': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True),
                'U': tf.io.FixedLenSequenceFeature([], tf.float32, allow_missing=True)
            }

            # Extract the data record.
            # Fixed: use tf.io.parse_single_example — the bare
            # tf.parse_single_example alias is deprecated and removed in TF2,
            # while the rest of this file already uses tf.io / tf.compat.v1.
            sample = tf.io.parse_single_example(tfrecord, features)

            return [sample['A'], sample['B'], sample['U']]

        self.train_dataset = tf.data.TFRecordDataset([os.path.join(default_data_directory,'train.tfrecords')])
        self.train_dataset = self.train_dataset.map(extract_fn).shuffle(100).batch(self.minibatch_size).repeat()

        self.valid_dataset = tf.data.TFRecordDataset([os.path.join(default_data_directory,'val.tfrecords')])
        self.valid_dataset = self.valid_dataset.map(extract_fn).shuffle(100).batch(self.minibatch_size).repeat()

        # Build a reinitializable iterator shared by both datasets
        self.iterator = tf.compat.v1.data.Iterator.from_structure(
            tf.compat.v1.data.get_output_types(self.train_dataset),
            None)
        self.next_element = self.iterator.get_next()

        # Build operations to initialize training and validation
        self.training_init_op = self.iterator.make_initializer(self.train_dataset)
        self.validation_init_op = self.iterator.make_initializer(self.valid_dataset)

        # Get the output of the data handler
        self.A_flat, self.B_flat, self.U_flat = self.next_element

        # Reshape the iterator output; A carries 2 extra columns (edge indices)
        self.minibatch_size_ = tf.shape(self.A_flat)[0]
        self.A = tf.reshape(self.A_flat, [self.minibatch_size_, -1, self.d_in_A+2])
        self.B = tf.reshape(self.B_flat, [self.minibatch_size_, -1, self.d_in_B])
        self.U_gt = tf.reshape(self.U_flat, [self.minibatch_size_, -1, self.d_out])

        # Get relevant tensor dimensions
        self.minibatch_size_tf = tf.shape(self.A)[0]
        self.num_nodes = tf.shape(self.B)[1]
        self.num_edges = tf.shape(self.A)[1]
        self.A_dim = tf.shape(self.A)[2]

        # Getting normalization constants, broadcast over the minibatch
        self.A_mean_tf = tf.ones([self.minibatch_size_tf, self.num_edges, 1]) * \
            tf.reshape(tf.constant(self.A_mean, dtype=tf.float32), [1, 1, 2+self.d_in_A])
        self.A_std_tf = tf.ones([self.minibatch_size_tf, self.num_edges, 1]) * \
            tf.reshape(tf.constant(self.A_std, dtype=tf.float32), [1, 1, 2+self.d_in_A])
        self.B_mean_tf = tf.ones([self.minibatch_size_tf, self.num_nodes, 1]) * \
            tf.reshape(tf.constant(self.B_mean, dtype=tf.float32), [1, 1, self.d_in_B])
        self.B_std_tf = tf.ones([self.minibatch_size_tf, self.num_nodes, 1]) * \
            tf.reshape(tf.constant(self.B_std, dtype=tf.float32), [1, 1, self.d_in_B])

        # Normalizing inputs A and B. Lower case means normalized
        self.a = (self.A - self.A_mean_tf) / self.A_std_tf
        self.b = (self.B - self.B_mean_tf) / self.B_std_tf

        # Extract indices from matrix A (indices are indeed not normalized)
        self.indices_from = tf.cast(self.A[:,:,0], tf.int32)
        self.indices_to = tf.cast(self.A[:,:,1], tf.int32)

        # Build mask to detect loops (edges whose endpoints coincide)
        self.mask_loop = tf.cast(tf.math.equal(self.indices_from, self.indices_to), tf.float32)
        self.mask_loop = tf.expand_dims(self.mask_loop, -1)

        # Extract normalized edge characteristics from matrix A
        self.a_ij = self.a[:,:,2:]

        # Initialize the discount factor that will later be updated
        self.discount = tf.Variable(0., trainable=False)
        self.sess.run(tf.compat.v1.variables_initializer([self.discount]))

        # Initialize messages, predictions and losses dict
        self.H = {}
        self.U = {}
        self.loss = {}
        self.loss_proxy = {}
        self.log_loss = {}
        self.cost_per_sample = {}
        self.total_loss = None
        self.total_loss_proxy = None

        # Get the natural offset that will be added to every output U at every node
        self.initial_U_tf = tf.ones([self.minibatch_size_tf, self.num_nodes, 1]) * \
            tf.reshape(tf.constant(self.initial_U, dtype=tf.float32), [1, 1, self.d_out])

        # Initialize latent message and prediction to 0
        self.H['0'] = tf.zeros([self.minibatch_size_tf, self.num_nodes, self.latent_dimension])

        # Decode the first message. Although this step is useless, it is still there for compatibility issues
        self.U['0'] = self.xi['0'](self.H['0']) + self.initial_U_tf

        # Iterate over every correction update (k_bar)
        for update in range(self.correction_updates):

            # Gather messages from both extremities of each edges
            self.H_from = custom_gather(self.H[str(update)], self.indices_from)
            self.H_to = custom_gather(self.H[str(update)], self.indices_to)

            # Concatenate all the inputs of the phi neural network
            self.Phi_input = tf.concat([self.H_from, self.H_to, self.a_ij], axis=2)

            # Compute the phi using the dedicated neural network blocks;
            # loop edges are routed to phi_loop, all others to phi_from/phi_to
            self.Phi_from = self.phi_from[str(update)](self.Phi_input) * (1.-self.mask_loop)
            self.Phi_to = self.phi_to[str(update)](self.Phi_input) * (1.-self.mask_loop)
            self.Phi_loop = self.phi_loop[str(update)](self.Phi_input) * self.mask_loop

            # Get the sum of each transformed messages at each node
            self.Phi_from_sum = custom_scatter(
                self.indices_from, 
                self.Phi_from, 
                [self.minibatch_size_tf, self.num_nodes, self.latent_dimension])
            self.Phi_to_sum = custom_scatter(
                self.indices_to, 
                self.Phi_to, 
                [self.minibatch_size_tf, self.num_nodes, self.latent_dimension])
            self.Phi_loop_sum = custom_scatter(
                self.indices_to, 
                self.Phi_loop, 
                [self.minibatch_size_tf, self.num_nodes, self.latent_dimension])

            # Concatenate all the inputs of the correction neural network
            self.correction_input = tf.concat([
                self.H[str(update)],
                self.Phi_from_sum,
                self.Phi_to_sum,
                self.Phi_loop_sum,
                self.b], axis=2)

            # Compute the correction using the dedicated neural network block
            self.correction = self.psi[str(update)](self.correction_input)

            # Apply correction, and extract the predictions from the latent message
            self.H[str(update+1)] = self.H[str(update)] + self.correction * self.alpha

            # Decode H
            self.U[str(update+1)] = self.xi[str(update+1)](self.H[str(update+1)]) + self.initial_U_tf

            # Compute the loss for each sample
            if self.proxy:
                self.loss_proxy[str(update + 1)] = \
                    tf.reduce_mean((self.U[str(update+1)] - self.U_gt)**2)

            self.cost_per_sample[str(update+1)] = \
                self.problem.cost_function(self.U[str(update+1)], self.A, self.B)

            # Take the mean, and register its value for tensorboard
            self.loss[str(update+1)] = tf.reduce_mean(self.cost_per_sample[str(update+1)])
            tf.compat.v1.summary.scalar("loss_{}".format(update+1), self.loss[str(update+1)])

            # Compute the discounted loss: later updates get higher weight
            if self.total_loss is None:
                self.total_loss = self.loss[str(update+1)] * self.discount**(self.correction_updates-1-update)
                if self.proxy:
                    self.total_loss_proxy = self.loss_proxy[str(update+1)] * \
                                            self.discount**(self.correction_updates-1-update)
            else:
                self.total_loss += self.loss[str(update+1)] * self.discount**(self.correction_updates-1-update)
                if self.proxy:
                    self.total_loss_proxy += self.loss_proxy[str(update+1)] * \
                                             self.discount**(self.correction_updates-1-update)

        # Get the final prediction and the final loss
        self.U_final = self.U[str(self.correction_updates)]
        self.loss_final = self.loss[str(self.correction_updates)]

        # Initialize the optimizer
        self.learning_rate = tf.Variable(0., trainable=False)
        self.sess.run(tf.compat.v1.variables_initializer([self.learning_rate]))
        self.optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=self.learning_rate)

        # Gradient clipping to avoid exploding gradients
        if self.proxy:
            self.gradients, self.variables = zip(*self.optimizer.compute_gradients(self.total_loss_proxy))
        else:
            self.gradients, self.variables = zip(*self.optimizer.compute_gradients(self.total_loss))
        self.gradients, _ = tf.clip_by_global_norm(self.gradients, 1e-2)

        # Define optimization operator
        self.opt_op = self.optimizer.apply_gradients(zip(self.gradients, self.variables))

        # Initialize training variables
        self.sess.run(tf.compat.v1.variables_initializer(self.optimizer.variables()))

        # Build summary to visualize the final loss in Tensorboard
        tf.compat.v1.summary.scalar("loss_final", self.loss_final)
        self.merged_summary_op = tf.compat.v1.summary.merge_all()

        # Gather trainable variables
        for update in range(self.correction_updates):
            self.trainable_variables.extend(self.phi_from[str(update)].trainable_variables)
            self.trainable_variables.extend(self.phi_to[str(update)].trainable_variables)
            self.trainable_variables.extend(self.phi_loop[str(update)].trainable_variables)
            self.trainable_variables.extend(self.psi[str(update)].trainable_variables)
        for update in range(self.correction_updates+1):
            self.trainable_variables.extend(self.xi[str(update)].trainable_variables)

    def log_config(self):
        """
        Logs the config of the whole model
        """

        logging.info('    Configuration :')
        logging.info('        Storing model in  : '+self.directory)
        logging.info('        Latent dimensions : {}'.format(self.latent_dimension))
        logging.info('        Number of hidden layers per block : {}'.format(self.hidden_layers))
        logging.info('        Number of correction updates : {}'.format(self.correction_updates))
        logging.info('        Alpha : {}'.format(self.alpha))
        logging.info('        Non linearity : {}'.format(self.non_lin))
        logging.info('        d_in_A : {}'.format(self.d_in_A))
        logging.info('        d_in_B : {}'.format(self.d_in_B))
        logging.info('        d_out : {}'.format(self.d_out))
        #logging.info('        d_F : {}'.format(self.d_F))
        logging.info('        Minibatch size : {}'.format(self.minibatch_size))
        logging.info('        Current training iteration : {}'.format(self.current_train_iter))
        logging.info('        Model name : ' + self.name)
        logging.info('        Initial U : {}'.format(self.initial_U))
        logging.info('        A mean : {}'.format(self.A_mean))
        logging.info('        A std : {}'.format(self.A_std))
        logging.info('        B mean : {}'.format(self.B_mean))
        logging.info('        B std : {}'.format(self.B_std))
        logging.info('        Proxy : {}'.format(self.proxy))

    def set_config(self, config):
        """
        Sets the config according to an inputed dict
        """

        self.latent_dimension = config['latent_dimension']
        self.hidden_layers = config['hidden_layers']
        self.correction_updates = config['correction_updates']
        self.alpha = config['alpha']
        self.non_lin = config['non_lin']
        self.d_in_A = config['d_in_A']
        self.d_in_B = config['d_in_B']
        self.d_out = config['d_out']
        #self.d_F = config['d_F']
        self.minibatch_size = config['minibatch_size']
        self.name = config['name']
        self.directory = config['directory']
        self.current_train_iter = config['current_train_iter']
        self.initial_U = np.array(config['initial_U'])
        self.A_mean = np.array(config['A_mean'])
        self.A_std = np.array(config['A_std'])
        self.B_mean = np.array(config['B_mean'])
        self.B_std = np.array(config['B_std'])
        if 'proxy' in config:
            self.proxy = config['proxy']
        else:
            self.proxy = False

    def get_config(self):
        """
        Gets the config dict
        """

        config = {
            'latent_dimension': self.latent_dimension,
            'hidden_layers': self.hidden_layers,
            'correction_updates': self.correction_updates,
            'alpha': self.alpha,
            'non_lin': self.non_lin,
            'd_in_A': self.d_in_A,
            'd_in_B': self.d_in_B,
            'd_out': self.d_out,
            #'d_F': self.d_F,
            'minibatch_size': self.minibatch_size,
            'name': self.name,
            'directory': self.directory,
            'current_train_iter': self.current_train_iter,
            'initial_U': list(self.initial_U),
            'A_mean': list(self.A_mean),
            'A_std': list(self.A_std),
            'B_mean': list(self.B_mean),
            'B_std': list(self.B_std),
            'proxy': self.proxy
        } 
        return config

    def save(self):
        """
        Persists the model to self.directory: the configuration is written
        to 'config.json' and the trained weights to a 'model.ckpt' checkpoint.
        """

        # Serialize the configuration dict as JSON
        with open(os.path.join(self.directory, 'config.json'), 'w') as f:
            json.dump(self.get_config(), f)

        # Dump the trainable variables through a TF1 Saver
        weight_saver = tf.compat.v1.train.Saver(self.trainable_variables)
        weight_saver.save(self.sess, os.path.join(self.directory, 'model.ckpt'))

    def train(self, 
        max_iter=10,
        learning_rate=3e-4, 
        discount=0.9,
        data_directory='datasets/spring/default',
        save_step=None,
        profile=False):
        """
        Performs a training process while keeping track of the validation score.

        Args:
            max_iter: number of SGD iterations to perform.
            learning_rate: value assigned to the learning-rate variable.
            discount: value assigned to the loss discount-factor variable.
            data_directory: training data folder (logged only).
            save_step: if not None, log metrics and save the model every
                `save_step` iterations (the model is always saved at the end).
            profile: if True, dump a Chrome-trace timeline per iteration.
        """

        # Log infos about training process
        logging.info('    Starting a training process :')
        logging.info('        Max iteration : {}'.format(max_iter))
        logging.info('        Learning rate : {}'.format(learning_rate))
        logging.info('        Discount : {}'.format(discount))
        logging.info('        Training data : {}'.format(data_directory))
        logging.info('        Saving model every {} iterations'.format(save_step))
        if profile:
            logging.info('        Profiling...')

        # Load dataset
        self.sess.run(self.training_init_op)

        # Build writer dedicated to training for Tensorboard
        self.training_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(self.directory, 'train'))

        # Build writer dedicated to validation for Tensorboard
        self.validation_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(self.directory, 'val'))

        # Set discount factor and learning rate
        self.sess.run(self.discount.assign(discount))
        self.sess.run(self.learning_rate.assign(learning_rate))

        # Copy the latest training iteration of the model
        starting_point = copy.copy(self.current_train_iter)

        # If profiling, then initialize useful variables.
        # Fixed: use the tf.compat.v1 aliases, consistent with the rest of
        # the file (bare tf.RunOptions/tf.RunMetadata do not exist in TF2).
        if profile:
            options = tf.compat.v1.RunOptions(trace_level=tf.compat.v1.RunOptions.FULL_TRACE)
            run_metadata = tf.compat.v1.RunMetadata()
            profile_path = os.path.join(self.directory, 'profile')

        # Training loop
        for i in tqdm(range(starting_point, starting_point+max_iter)):

            # Store current training step, so that it's always up to date
            self.current_train_iter = i

            # Perform SGD step
            if profile and i>starting_point:
                self.sess.run(self.opt_op, options=options, run_metadata=run_metadata)
                fetched_timeline = timeline.Timeline(run_metadata.step_stats)
                chrome_trace = fetched_timeline.generate_chrome_trace_format()
                with open(profile_path+'_{}.json'.format(i), 'w') as f:
                    f.write(chrome_trace)
            else:
                self.sess.run(self.opt_op)

            # Store final loss in a summary
            self.summary = self.sess.run(self.merged_summary_op)
            self.training_writer.add_summary(self.summary, self.current_train_iter)

            # Periodically log metrics and save model.
            # Fixed: 'and' instead of bitwise '&' — '&' does not short-circuit,
            # so 'i % save_step' raised a TypeError whenever save_step was None.
            if ((save_step is not None) and (i % save_step == 0)) or (i == starting_point+max_iter-1):

                # Get minibatch train loss
                loss_final_train = self.sess.run(self.loss_final)

                # Change source data to validation
                self.sess.run(self.validation_init_op)

                # Get minibatch val loss
                loss_final_val = self.sess.run(self.loss_final)

                # Store final loss in validation
                self.summary = self.sess.run(self.merged_summary_op)
                self.validation_writer.add_summary(self.summary, self.current_train_iter)

                # Change source data back to training
                self.sess.run(self.training_init_op)

                # Log metrics
                logging.info('    Learning iteration {}'.format(i))
                logging.info('        Training loss (minibatch) : {}'.format(loss_final_train))
                logging.info('        Validation loss (minibatch): {}'.format(loss_final_val))

                # Save model
                self.save()

        # Save model at the end of training
        self.save()
    def evaluate(self,
        mode='val',
        data_directory='data'):
        """
        Evaluates the final loss on the desired dataset and stores predictions.

        Args:
            mode: dataset split suffix used to pick the input files
                'A_<mode>.npy' / 'B_<mode>.npy' (e.g. 'val' or 'test').
            data_directory: folder containing the numpy dataset.

        Returns:
            The final loss evaluated on the loaded dataset.
        """

        # Import numpy dataset (unused data_plot/data_file locals removed)
        A = np.load(os.path.join(data_directory, 'A_'+mode+'.npy'))
        B = np.load(os.path.join(data_directory, 'B_'+mode+'.npy'))

        # Compute final loss by feeding the arrays directly into the graph
        loss = self.sess.run(self.loss_final, feed_dict={self.A:A, self.B:B})

        # Compute and save the final prediction.
        # Fixed: the graph exposes self.U_final (built in build_graph);
        # self.X_final was never defined and raised an AttributeError.
        # The output filename is kept unchanged for backward compatibility.
        U_final = self.sess.run(self.U_final, feed_dict={self.A:A, self.B:B})
        np.save(os.path.join(self.directory, 'X_final_pred_'+mode+'.npy'), U_final)

        return loss