Example #1
from DeepJetCore.DataCollection import DataCollection
import sys

djdc_path = sys.argv[1]

train_data = DataCollection(djdc_path)

# split(0.9) keeps 90% of the data in train_data and returns the remaining 10%
# as a new DataCollection for validation; it can be used in the same way as train_data
val_data = train_data.split(0.9)

# Set the batch size.
# If the data is ragged in dimension 1 (see the convert options), this is the
# maximum number of elements per batch; how they are distributed over the
# individual examples can vary. E.g., if the first example has 50 elements, the
# second 48, and the third 30, and the batch size is set to 100, the first
# batch returns the first two examples (99 elements in total), and so on.
# This helps avoid out-of-memory errors during training.

train_data.setBatchSize(100)

print("batch size: 100")
# prepare the generator

train_data.invokeGenerator()

# loop over epochs here ...

train_data.generator.shuffleFilelist()
train_data.generator.prepareNextEpoch()

# this number can differ from epoch to epoch for ragged data!
nbatches = train_data.generator.getNBatches()
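
For completeness, here is a minimal sketch of what the per-epoch loop hinted at above could look like. The batch-fetching call feedNumpyData() is an assumption based on DeepJetCore's TrainDataGenerator; check the generator class of your installed version for the exact name and yield format.

from DeepJetCore.DataCollection import DataCollection
import sys

train_data = DataCollection(sys.argv[1])
train_data.setBatchSize(100)
train_data.invokeGenerator()

for epoch in range(10):
    train_data.generator.shuffleFilelist()
    train_data.generator.prepareNextEpoch()
    # for ragged data the number of batches can change between epochs
    nbatches = train_data.generator.getNBatches()
    # feedNumpyData() is assumed to yield per-batch numpy data
    # (e.g. features, truth, weights) suitable for one training step
    batches = train_data.generator.feedNumpyData()
    for _ in range(nbatches):
        batch = next(batches)
        # ... run one training step on this batch ...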
Example #2
    def train(self):

        placeholder_input, placeholder_output = self.model.get_placeholders()
        graph_output = self.model.get_compute_graphs()
        graph_loss = self.model.get_losses()
        graph_optimiser = self.model.get_optimizer()
        graph_summary = self.model.get_summary()

        if self.from_scratch:
            self.clean_summary_dir()

        init = tf.global_variables_initializer()
        with tf.Session() as sess:
            sess.run(init)
            if self.use_tf_records:
                coord = tf.train.Coordinator()
                threads = tf.train.start_queue_runners(coord=coord)
                record_batch_input, record_batch_target = self.get_record_placeholders()
            else:
                input_data = self.config['train_data_path']
                train_data = DataCollection()
                train_data.readFromFile(input_data)

                # split(0.9) keeps 90% of the data in train_data and returns
                # the remaining 10% as the validation set (see Example #1)
                val_data = train_data.split(0.9)
                train_data.setBatchSize(self.batch_size)
                val_data.setBatchSize(self.batch_size)
                val_data_generator = val_data.generator()
                train_data_generator = train_data.generator()

            summary_writer = tf.summary.FileWriter(self.summary_path,
                                                   sess.graph)

            if not self.from_scratch:
                self.saver_all.restore(sess, self.model_path)
                print("\n\nINFO: Loading model\n\n")
                with open(self.model_path + '.txt', 'r') as f:
                    iteration_number = int(f.read())
            else:
                iteration_number = 0

            print("Starting iterations")
            while iteration_number < self.train_for_iterations:
                if self.use_tf_records:
                    # decode the raw byte strings coming off the record queue
                    # into numpy arrays (np.fromstring is deprecated in recent
                    # numpy; np.frombuffer is the drop-in replacement)
                    input, output = sess.run(
                        [record_batch_input, record_batch_target])
                    input = [
                        np.fromstring(''.join(i)).reshape(
                            13, 13, int(self.config['num_layers']),
                            int(self.config['num_channels'])) for i in input
                    ]
                    output = [
                        np.fromstring(''.join(i)).reshape(
                            13, 13, int(self.config['num_layers']))
                        for i in output
                    ]
                else:
                    # the DeepJetCore generator yields (input, output, weights)
                    input, output, _ = next(train_data_generator)
                    input = np.squeeze(input, axis=0)
                    output = np.squeeze(output, axis=0)

                _, eval_loss, _, eval_summary = sess.run(
                    [graph_output, graph_loss, graph_optimiser, graph_summary],
                    feed_dict={
                        placeholder_input: input,
                        placeholder_output: output
                    })
                print("Iteration %4d: loss %0.5f" %
                      (iteration_number, eval_loss))
                iteration_number += 1
                summary_writer.add_summary(eval_summary, iteration_number)
                if iteration_number % self.save_after_iterations == 0:
                    print("\n\nINFO: Saving model\n\n")
                    self.saver_all.save(sess, self.model_path)
                    with open(self.model_path + '.txt', 'w') as f:
                        f.write(str(iteration_number))
            if self.use_tf_records:
                # Stop the threads
                coord.request_stop()

                # Wait for threads to stop
                coord.join(threads)
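
Note that Example #2 targets the TensorFlow 1.x API: tf.Session, tf.global_variables_initializer, tf.summary.FileWriter, and the Coordinator/queue-runner machinery were all removed or moved under tf.compat.v1 in TensorFlow 2. As a hedged sketch only, the record-reading branch could be expressed in TensorFlow 2 with tf.data, which subsumes the coordinator and queue runners (the file name and batch size here are placeholders, not from the original code):

import tensorflow as tf

# tf.data handles the shuffling, batching, and prefetching that the TF1
# version delegated to queue runners and a Coordinator
dataset = (tf.data.TFRecordDataset(["train.tfrecord"])
           .shuffle(buffer_size=1000)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

for record_batch in dataset:
    # record_batch holds serialized examples (tf.string tensors); decoding and
    # reshaping would happen here, analogous to the np.fromstring block above
    pass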