示例#1
0
    def fit(self, training_set, validation_set):
        """
        Fit the model.

        Builds the metric, loss and training ops on top of the last model
        output, then runs ``self.epochs`` passes over ``training_set`` in
        full batches of ``self.batch_size`` inside a fresh ``tf.Session``.
        Summaries are written after every batch; validation and
        checkpointing run whenever the sample index is a multiple of
        ``self.validation_step`` / ``self.checkpoint_step``.

        Args:
            training_set: sized, sliceable collection of training samples;
                each slice of ``self.batch_size`` items is handed to
                ``self.load_batch``.
            validation_set: data forwarded unchanged to
                ``self.validation_eval``.

        Returns:
            None.

        """
        outputs = self.model

        # Debug ops: log the labels, then the predictions, whenever the last
        # output is evaluated.
        outputs[-1] = tf.Print(outputs[-1], [self.label],
                               summarize=self.batch_size * self.n_output,
                               message="Label: ")
        outputs[-1] = tf.Print(outputs[-1], [outputs[-1]],
                               summarize=self.batch_size * self.n_output,
                               message="Prediction: ")

        n_steps = self.time_steps if self.is_sequence_output else 1
        metrics = SequenceModel.compute_metrics(
            outputs[-1], self.label, n_steps)
        loss = SequenceModel.compute_loss(
            outputs[-1], self.label, self.loss_name, n_steps)

        train_op = SequenceModel.compute_gradient(self, loss, self.global_step)

        # Merge summaries
        summaries = tf.summary.merge_all()

        # Initialize variables
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()

        # NOTE(review): this rebinds the attribute to a "global_step + 1"
        # tensor, so calling fit() twice on the same instance stacks the
        # additions -- confirm that this is intended.
        self.global_step = tf.add(self.global_step, tf.constant(1))

        # The initial state is identical for every batch: allocate it once.
        initial_state = np.zeros(shape=(self.batch_size,
                                        self.units_per_cell),
                                 dtype=np.float32)

        with tf.Session() as sess:

            run_opts = tf.RunOptions(report_tensor_allocations_upon_oom=True)

            sess.run(init_g)
            sess.run(init_l)

            self.train_writer.add_graph(sess.graph)

            # Load existing model
            if self.from_pretrained:
                SequenceModel.load(self, sess)

            for _ in range(self.epochs):

                # ``i`` is the exclusive end index of the current batch. The
                # "+ 1" keeps the last full batch, which range() would
                # otherwise drop when len(training_set) is a multiple of
                # batch_size. A trailing partial batch is still skipped
                # because initial_state has a fixed batch dimension.
                for i in range(self.batch_size, len(training_set) + 1,
                               self.batch_size):

                    time0 = time()
                    batch_input, batch_label = self.load_batch(
                        training_set[i - self.batch_size:i])

                    _, loss_value, summaries_value, step = sess.run(
                        [train_op, loss, summaries, self.global_step],
                        feed_dict={
                            self.input: batch_input,
                            self.label: batch_label,
                            self.initial_state: initial_state
                        },
                        options=run_opts,
                    )

                    self.train_writer.add_summary(summaries_value, step)

                    time1 = time()
                    if self.logger:
                        # Integer division so the batch number is logged as
                        # "3" rather than "3.0".
                        self.logger.info(
                            "Cost = {0} for batch {1} in {2:.2f} "
                            "seconds".format(
                                loss_value, i // self.batch_size,
                                time1 - time0))

                    if i % self.validation_step == 0:
                        # Memory is reported before and after validation.
                        SequenceModel.memory()
                        self.validation_eval(sess, summaries, validation_set,
                                             metrics, step)
                        gc.collect()
                        SequenceModel.memory()

                    if i % self.checkpoint_step == 0:
                        SequenceModel.save(self, sess, step=self.global_step)
示例#2
0
    def fit(self, training_set, validation_set, stop_at_step=None):
        """
        Fit the model.

        Builds the metric, loss and training ops on top of the last model
        output, then runs up to ``self.epochs`` passes over ``training_set``
        in full batches of ``self.batch_size`` inside a fresh
        ``tf.Session``. Once training ends, the last model output is
        evaluated on the validation data and stored in ``predictions.npy``.

        Args:
            training_set: set of data for training; must support ``len()``
                and slicing, each slice of ``self.batch_size`` items being
                handed to ``self.load_batch``.
            validation_set: set of data for evaluation; loaded once, in
                full, through ``self.load_batch``.
            stop_at_step: step from which to stop. Training ends entirely
                as soon as the fetched step reaches this value. ``None``
                (or any other falsy value) disables early stopping.

        Returns:
            None.

        """
        # Local import: this method does not rely on a module-level ``os``.
        import os

        outputs = self.model

        # Debug ops: log the labels, then the predictions, whenever the last
        # output is evaluated.
        outputs[-1] = tf.Print(outputs[-1], [self.label],
                               summarize=self.batch_size * self.n_output,
                               message="Label: ")
        outputs[-1] = tf.Print(outputs[-1], [outputs[-1]],
                               summarize=self.batch_size * self.n_output,
                               message="Prediction: ")

        n_steps = self.time_steps if self.is_sequence_output else 1
        metrics = SequenceModel.compute_metrics(
            outputs[-1], self.label, n_steps)
        loss = SequenceModel.compute_loss(
            outputs[-1], self.label, self.loss_name, n_steps)

        train_op = SequenceModel.compute_gradient(self, loss, self.global_step)

        # NOTE(review): this rebinds the attribute to a "global_step + 1"
        # tensor, so calling fit() twice on the same instance stacks the
        # additions -- confirm that this is intended.
        self.global_step = tf.add(self.global_step, tf.constant(1))

        # Merge summaries
        summaries = tf.summary.merge_all()

        # Initialize variables
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()

        # Initialize states
        initial_state = np.zeros(shape=(self.batch_size, self.units_per_cell),
                                 dtype=np.float32)
        initial_state_val = np.zeros(shape=(len(validation_set),
                                            self.units_per_cell),
                                     dtype=np.float32)

        # load validation set
        input_val, label_val = self.load_batch(validation_set)

        with tf.Session() as sess:

            run_opts = tf.RunOptions(report_tensor_allocations_upon_oom=True)

            sess.run(init_g)
            sess.run(init_l)

            self.train_writer.add_graph(sess.graph)

            # Load existing model
            if self.from_pretrained:
                SequenceModel.load(self, sess)

            stop_requested = False
            for _ in range(self.epochs):

                # ``i`` is the exclusive end index of the current batch. The
                # "+ 1" keeps the last full batch, which range() would
                # otherwise drop when len(training_set) is a multiple of
                # batch_size. A trailing partial batch is still skipped
                # because initial_state has a fixed batch dimension.
                for i in range(self.batch_size, len(training_set) + 1,
                               self.batch_size):

                    time0 = time()
                    batch_input, batch_label = self.load_batch(
                        training_set[i - self.batch_size:i])

                    _, loss_value, summaries_value, step = sess.run(
                        [train_op, loss, summaries, self.global_step],
                        feed_dict={
                            self.input: batch_input,
                            self.label: batch_label,
                            self.initial_state: initial_state
                        },
                        options=run_opts,
                    )

                    self.train_writer.add_summary(summaries_value, step)

                    time1 = time()
                    if self.logger:
                        # Integer division so the batch number is logged as
                        # "3" rather than "3.0".
                        self.logger.info(
                            "Cost = {0} for batch {1} in {2:.2f} "
                            "seconds".format(
                                loss_value, i // self.batch_size,
                                time1 - time0))

                    if i % self.validation_step == 0:
                        self.validation_eval(sess, summaries, input_val,
                                             label_val, initial_state_val,
                                             metrics, step)

                    if i % self.checkpoint_step == 0:
                        checkpoint_path = os.path.join(self.checkpoint_path,
                                                       self.name)
                        self.saver.save(sess,
                                        checkpoint_path,
                                        global_step=step)

                    if stop_at_step and step >= stop_at_step:
                        stop_requested = True
                        break

                # Leave the epoch loop too; a bare inner ``break`` would let
                # training resume with the next epoch.
                if stop_requested:
                    break

            # Evaluate the trained model on the validation data and persist
            # the predictions in the current working directory.
            predictions = sess.run(
                [outputs[-1]],
                feed_dict={
                    self.input: input_val,
                    self.label: label_val,
                    self.initial_state: initial_state_val
                })
            np.save("predictions.npy", predictions[0])