Example #1
    def store_prediction(self, sess, batch_x, batch_y, batch_z, name):
        prediction, blurpredict = sess.run(
            [self.net.predicter, self.net.blurpredict],
            feed_dict={
                self.net.x: batch_x,
                self.net.y: batch_y,
                self.net.z: batch_z,
                self.net.keep_prob: 1.
            })
        pred_shape = prediction.shape
        blurpred_shape = blurpredict.shape

        loss = sess.run(self.net.cost,
                        feed_dict={
                            self.net.x: batch_x,
                            self.net.y:
                            util.crop_to_shape(batch_y, pred_shape),
                            self.net.z: batch_z,
                            self.net.keep_prob: 1.
                        })

        logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(
            error_rate(prediction,
                       util.crop_to_shape(batch_y, prediction.shape)), loss))

        img = util.combine_img_prediction(batch_x, batch_y, prediction)
        util.save_image(img, "%s/%s.jpg" % (self.prediction_path, name))

        return pred_shape, blurpred_shape
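The helpers util.combine_img_prediction and util.save_image are not included in this snippet. A minimal sketch of what the stitching step could look like, assuming the input, ground truth and prediction have already been cropped to the same NHWC shape (an illustrative assumption, not the project's verified implementation):

import numpy as np

def combine_img_prediction_sketch(data, gt, pred):
    """Place (input | ground truth | prediction) side by side for each sample
    and stack the resulting rows into one displayable image."""
    def to_rgb(img):
        # Take the first channel, normalize to [0, 255], repeat on 3 channels.
        img = img[..., :1].astype(np.float64)
        img -= img.min()
        if img.max() > 0:
            img /= img.max()
        return np.tile(img * 255.0, (1, 1, 3)).astype(np.uint8)

    rows = [np.concatenate([to_rgb(data[i]), to_rgb(gt[i]), to_rgb(pred[i])], axis=1)
            for i in range(data.shape[0])]
    return np.concatenate(rows, axis=0)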
Example #2
    def store_prediction(self, sess, batch_x, batch_y, name):
        prediction = sess.run(self.net.predicter,
                              feed_dict={
                                  self.net.x: batch_x,
                                  self.net.y: batch_y,
                                  self.net.keep_prob: 1.
                              })

        pred_shape = prediction.shape
        loss = sess.run(self.net.cost,
                        feed_dict={
                            self.net.x: batch_x,
                            self.net.y:
                            util.crop_to_shape(batch_y, pred_shape),
                            self.net.keep_prob: 1.
                        })

        logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(
            error_rate(prediction,
                       util.crop_to_shape(batch_y, prediction.shape)), loss))

        batch_y_img = util.crop_to_shape(batch_y, prediction.shape)
        for idx in range(batch_y.shape[0]):
            img_true = util.to_rgb(batch_y_img[idx])
            img_pred = util.to_rgb(np.asarray(prediction[idx], 'float'))
            util.save_image(
                img_true,
                "%s/%s_true_%s0.jpg" % (self.prediction_path, name, idx))
            util.save_image(
                img_pred,
                "%s/%s_pred_%s0.jpg" % (self.prediction_path, name, idx))
        return pred_shape
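Both variants rely on util.crop_to_shape to shrink the label batch to the smaller spatial size of the network output (the unpadded convolutions make the prediction smaller than the input). A minimal center-crop sketch, assuming NHWC numpy arrays; the project's own helper may differ in detail:

def crop_to_shape_sketch(data, shape):
    """Center-crop `data` of shape (N, H, W, C) down to the target (N, h, w, C)."""
    offset_h = (data.shape[1] - shape[1]) // 2
    offset_w = (data.shape[2] - shape[2]) // 2
    return data[:, offset_h:offset_h + shape[1], offset_w:offset_w + shape[2], :]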
Example #3
    def train(self,
              data_provider,
              output_path,
              training_iters=10,
              epochs=100,
              dropout=0.75,
              display_step=1,
              restore=False,
              write_graph=False,
              prediction_path='prediction'):
        """
        Launches the training process

        :param data_provider: callable returning training and verification data
        :param output_path: path where to store checkpoints
        :param training_iters: number of training mini-batch iterations per epoch
        :param epochs: number of epochs
        :param dropout: dropout probability
        :param display_step: number of steps until training stats are logged
        :param restore: Flag if a previously stored model should be restored
        :param write_graph: Flag if the computation graph should be written as a protobuf file to the output path
        :param prediction_path: path where to save predictions on each epoch
        """
        save_path = os.path.join(output_path, "model.ckpt")
        if epochs == 0:
            return save_path

        init = self._initialize(training_iters, output_path, restore,
                                prediction_path)

        with tf.Session() as sess:
            if write_graph:
                tf.train.write_graph(sess.graph_def, output_path, "graph.pb",
                                     False)

            sess.run(init)

            if restore:
                ckpt = tf.train.get_checkpoint_state(output_path)
                if ckpt and ckpt.model_checkpoint_path:
                    self.net.restore(sess, ckpt.model_checkpoint_path)

            test_x, test_y, test_z = data_provider(
                self.verification_batch_size)
            pred_shape, blurpred_shape = self.store_prediction(
                sess, test_x, test_y, test_z, "_init")

            summary_writer = tf.summary.FileWriter(output_path,
                                                   graph=sess.graph)
            logging.info("Start optimization")

            avg_gradients = None
            for epoch in range(epochs):
                total_loss = 0
                for step in range((epoch * training_iters),
                                  ((epoch + 1) * training_iters)):
                    batch_x, batch_y, batch_z = data_provider(self.batch_size)

                    # Run optimization op (backprop)
                    _, loss, lr, gradients = sess.run(
                        (self.optimizer, self.net.cost,
                         self.learning_rate_node, self.net.gradients_node),
                        feed_dict={
                            self.net.x: batch_x,
                            self.net.y:
                            util.crop_to_shape(batch_y, pred_shape),
                            self.net.z: batch_z,
                            self.net.keep_prob: dropout
                        })

                    if self.net.summaries and self.norm_grads:
                        avg_gradients = _update_avg_gradients(
                            avg_gradients, gradients, step)
                        norm_gradients = [
                            np.linalg.norm(gradient)
                            for gradient in avg_gradients
                        ]
                        self.norm_gradients_node.assign(norm_gradients).eval()

                    if step % display_step == 0:
                        self.output_minibatch_stats(
                            sess, summary_writer, step, batch_x,
                            util.crop_to_shape(batch_y, pred_shape), batch_z)

                    total_loss += loss

                self.output_epoch_stats(epoch, total_loss, training_iters, lr)
                self.store_prediction(sess, test_x, test_y, test_z,
                                      "epoch_%s" % epoch)

                save_path = self.net.save(sess, save_path)
            logging.info("Optimization Finished!")

            return save_path
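The gradient-norm bookkeeping in the training loop calls _update_avg_gradients, which is not included in the snippet. A plausible sketch of such a running average (an assumption about the helper, not its verified source):

import numpy as np

def _update_avg_gradients_sketch(avg_gradients, gradients, step):
    """Incrementally average the per-variable gradients over the steps seen so far."""
    if avg_gradients is None:
        avg_gradients = [np.zeros_like(gradient) for gradient in gradients]
    for i, gradient in enumerate(gradients):
        # Running mean: the newest gradient contributes with weight 1 / (step + 1).
        avg_gradients[i] = (avg_gradients[i] * (1.0 - 1.0 / (step + 1))
                            + gradient / (step + 1))
    return avg_gradients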
Example #4
    display_step = 2
    restore = True

    generator = image_gen.RgbDataProvider(nx, ny, cnt=20, rectangles=False)

    net = unet.Unet(channels=generator.channels,
                    n_class=generator.n_class,
                    layers=3,
                    features_root=4,
                    cost="IoU")

    trainer = unet.Trainer(net,
                           optimizer="momentum",
                           opt_kwargs=dict(momentum=0.2,
                                           learning_rate=0.1,
                                           decay_rate=0.9))
    path = trainer.train(generator,
                         "./unet_trained",
                         training_iters=training_iters,
                         epochs=epochs,
                         dropout=dropout,
                         display_step=display_step,
                         restore=restore)

    x_test, y_test = generator(4)
    prediction = net.predict(path, x_test)

    print("Testing error rate: {:.2f}%".format(
        unet.error_rate(prediction, util.crop_to_shape(y_test,
                                                       prediction.shape))))