def fit(self, training_set, validation_set):
    """
    Fit the model.

    Builds the metric, loss and training ops on top of the last element of
    ``self.model``, then, inside a fresh ``tf.Session``, runs ``self.epochs``
    passes over ``training_set`` in consecutive slices of ``self.batch_size``
    items. Summaries are written to ``self.train_writer`` after every batch;
    validation and checkpointing are triggered from the running item index.

    Args:
        training_set: sized, sliceable collection of training items. Each
            slice of ``self.batch_size`` items is passed to
            ``self.load_batch``, which must return ``(inputs, labels)``.
        validation_set: forwarded unchanged to ``self.validation_eval``.

    Returns:
        None
    """
    # NOTE(review): `outputs` is the same object as `self.model`, so the
    # item assignments below also replace the last element seen through
    # `self.model` (and stack again if `fit` is called twice) -- confirm
    # this is intended.
    outputs = self.model
    outputs[-1] = tf.Print(outputs[-1], [self.label],
                           summarize=self.batch_size * self.n_output,
                           message="Label: ")
    outputs[-1] = tf.Print(outputs[-1], [outputs[-1]],
                           summarize=self.batch_size * self.n_output,
                           message="Prediction: ")

    output_steps = self.time_steps if self.is_sequence_output else 1
    metrics = SequenceModel.compute_metrics(
        outputs[-1], self.label, output_steps)
    loss = SequenceModel.compute_loss(
        outputs[-1], self.label, self.loss_name, output_steps)
    train_op = SequenceModel.compute_gradient(self, loss, self.global_step)

    # Merge summaries
    summaries = tf.summary.merge_all()

    # Initialize variables
    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()

    # From here on `self.global_step` is the `global_step + 1` tensor; it is
    # what gets fetched as `step` below and what is handed to `save`.
    self.global_step = tf.add(self.global_step, tf.constant(1))

    # The zero state fed for every batch never changes: build it once
    # instead of once per batch.
    initial_state = np.zeros(shape=(self.batch_size, self.units_per_cell),
                             dtype=np.float32)

    with tf.Session() as sess:
        run_opts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        sess.run(init_g)
        sess.run(init_l)
        self.train_writer.add_graph(sess.graph)

        # Load existing model
        if self.from_pretrained:
            SequenceModel.load(self, sess)

        for _ in range(self.epochs):
            # `i` is the exclusive end index of the current batch. The stop
            # is `len + 1` so that the last full batch is still produced
            # when the set size is an exact multiple of the batch size;
            # a trailing partial batch is skipped as before.
            for i in range(self.batch_size, len(training_set) + 1,
                           self.batch_size):
                time0 = time()
                batch_input, batch_label = self.load_batch(
                    training_set[i - self.batch_size:i])

                _, loss_value, summaries_value, step = sess.run(
                    [train_op, loss, summaries, self.global_step],
                    feed_dict={
                        self.input: batch_input,
                        self.label: batch_label,
                        self.initial_state: initial_state
                    },
                    options=run_opts,
                )
                self.train_writer.add_summary(summaries_value, step)
                time1 = time()

                if self.logger:
                    # Integer division keeps the batch number an int on
                    # Python 3 as well.
                    self.logger.info(
                        "Cost = {0} for batch {1} in {2:.2f} seconds".format(
                            loss_value, i // self.batch_size, time1 - time0))

                if i % self.validation_step == 0:
                    # `memory()` is called on both sides of the evaluation
                    # and the explicit garbage collection.
                    SequenceModel.memory()
                    self.validation_eval(sess, summaries, validation_set,
                                         metrics, step)
                    gc.collect()
                    SequenceModel.memory()

                if i % self.checkpoint_step == 0:
                    SequenceModel.save(self, sess, step=self.global_step)
def fit(self, training_set, validation_set, stop_at_step=None):
    """
    Fit the model.

    Builds the metric, loss and training ops on top of the last element of
    ``self.model``, loads the whole validation set once, then, inside a
    fresh ``tf.Session``, runs up to ``self.epochs`` passes over
    ``training_set`` in consecutive slices of ``self.batch_size`` items.
    After training, the last output is evaluated on the validation set and
    written to ``predictions.npy`` in the current working directory.

    Args:
        training_set: sized, sliceable collection of training items. Each
            slice of ``self.batch_size`` items is passed to
            ``self.load_batch``, which must return ``(inputs, labels)``.
        validation_set: sized collection of evaluation items; loaded in one
            go through ``self.load_batch``.
        stop_at_step: once the fetched step value reaches this number the
            whole training (not only the current epoch) stops. A falsy
            value (``None`` or ``0``) disables the limit.

    Returns:
        None
    """
    # Kept local because the file-level imports are not visible from this
    # block, but executed once rather than on every checkpoint.
    import os

    # NOTE(review): `outputs` is the same object as `self.model`, so the
    # item assignments below also replace the last element seen through
    # `self.model` (and stack again if `fit` is called twice) -- confirm
    # this is intended.
    outputs = self.model
    outputs[-1] = tf.Print(outputs[-1], [self.label],
                           summarize=self.batch_size * self.n_output,
                           message="Label: ")
    outputs[-1] = tf.Print(outputs[-1], [outputs[-1]],
                           summarize=self.batch_size * self.n_output,
                           message="Prediction: ")

    output_steps = self.time_steps if self.is_sequence_output else 1
    metrics = SequenceModel.compute_metrics(
        outputs[-1], self.label, output_steps)
    loss = SequenceModel.compute_loss(
        outputs[-1], self.label, self.loss_name, output_steps)
    train_op = SequenceModel.compute_gradient(self, loss, self.global_step)

    # From here on `self.global_step` is the `global_step + 1` tensor; it is
    # what gets fetched as `step` in the training loop.
    self.global_step = tf.add(self.global_step, tf.constant(1))

    # Merge summaries
    summaries = tf.summary.merge_all()

    # Initialize variables
    init_g = tf.global_variables_initializer()
    init_l = tf.local_variables_initializer()

    # Initialize states: one zero row per training-batch item, and one per
    # validation item since the validation set is fed as a single batch.
    initial_state = np.zeros(shape=(self.batch_size, self.units_per_cell),
                             dtype=np.float32)
    initial_state_val = np.zeros(
        shape=(len(validation_set), self.units_per_cell), dtype=np.float32)

    # Load validation set
    input_val, label_val = self.load_batch(validation_set)

    with tf.Session() as sess:
        run_opts = tf.RunOptions(report_tensor_allocations_upon_oom=True)
        sess.run(init_g)
        sess.run(init_l)
        self.train_writer.add_graph(sess.graph)

        # Load existing model
        if self.from_pretrained:
            SequenceModel.load(self, sess)

        stop_requested = False
        for _ in range(self.epochs):
            # `i` is the exclusive end index of the current batch. The stop
            # is `len + 1` so that the last full batch is still produced
            # when the set size is an exact multiple of the batch size;
            # a trailing partial batch is skipped as before.
            for i in range(self.batch_size, len(training_set) + 1,
                           self.batch_size):
                time0 = time()
                batch_input, batch_label = self.load_batch(
                    training_set[i - self.batch_size:i])

                _, loss_value, summaries_value, step = sess.run(
                    [train_op, loss, summaries, self.global_step],
                    feed_dict={
                        self.input: batch_input,
                        self.label: batch_label,
                        self.initial_state: initial_state
                    },
                    options=run_opts,
                )
                self.train_writer.add_summary(summaries_value, step)
                time1 = time()

                if self.logger:
                    # Integer division keeps the batch number an int on
                    # Python 3 as well.
                    self.logger.info(
                        "Cost = {0} for batch {1} in {2:.2f} seconds".format(
                            loss_value, i // self.batch_size, time1 - time0))

                if i % self.validation_step == 0:
                    self.validation_eval(sess, summaries, input_val,
                                         label_val, initial_state_val,
                                         metrics, step)

                if i % self.checkpoint_step == 0:
                    checkpoint_path = os.path.join(self.checkpoint_path,
                                                   self.name)
                    self.saver.save(sess, checkpoint_path, global_step=step)

                if stop_at_step and step >= stop_at_step:
                    stop_requested = True
                    break

            # A bare `break` above only leaves the batch loop; without this
            # check the next epoch would start training again.
            if stop_requested:
                break

        # Evaluate the last output on the whole validation set and dump it.
        predictions = sess.run(
            [outputs[-1]],
            feed_dict={
                self.input: input_val,
                self.label: label_val,
                self.initial_state: initial_state_val
            })
        np.save("predictions.npy", predictions[0])