Example #1
    def test(self, args):
        print("[*] Testing....")

        if not os.path.exists(args.output_path):
            os.makedirs(args.output_path)

        self.saver = tf.train.Saver()

        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())

        if args.checkpoint_path != '':
            self.saver.restore(self.sess, args.checkpoint_path)
            print(" [*] Load model: SUCCESS")
        else:
            print(" [*] Load failed: no checkpoint path provided")
            print(" [*] End Testing...")
            raise ValueError('args.checkpoint_path is empty')

        dataloader = Dataloader(file=args.dataset_testing,
                                isTraining=self.isTraining)
        disp_batch = dataloader.disp
        line = dataloader.disp_filename
        num_samples = dataloader.count_text_lines(args.dataset_testing)

        # zero-pad the sigmoid confidence map by self.radius pixels on each
        # side and rescale it to the uint16 range before encoding it as a
        # 16-bit PNG
        prediction = tf.pad(
            tf.nn.sigmoid(self.prediction),
            tf.constant([[0, 0], [self.radius, self.radius],
                         [self.radius, self.radius], [0, 0]]), "CONSTANT")
        png = tf.image.encode_png(
            tf.cast(tf.scalar_mul(65535.0, tf.squeeze(prediction, axis=0)),
                    dtype=tf.uint16))

        # start the input-queue threads that feed the dataloader tensors
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        print(" [*] Start Testing...")
        for step in range(num_samples):
            batch, filename = self.sess.run([disp_batch, line])
            print(" [*] Test image:" + filename)
            start = time.time()
            confidence = self.sess.run(png, feed_dict={self.disp: batch})
            current = time.time()
            output_file = os.path.join(args.output_path,
                                       filename.strip().split('/')[-1])
            self.sess.run(tf.write_file(output_file, confidence))
            print(" [*] CCNN confidence prediction saved in:" + output_file)
            print(" [*] CCNN running time:" + str(current - start) + "s")

        coord.request_stop()
        coord.join(threads)
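
A minimal, self-contained sketch of the 16-bit PNG export used in the test loop above: a confidence map in [0, 1] is zero-padded on each side, rescaled to the uint16 range, and written with TensorFlow's PNG encoder. TensorFlow 1.x graph mode is assumed; the radius value, placeholder shape, and output file name are illustrative only, not taken from the example.

import numpy as np
import tensorflow as tf

radius = 2                                          # assumed border size
conf = tf.placeholder(tf.float32, [1, None, None, 1])

# zero-pad the confidence map and encode it as a 16-bit PNG, as in test()
padded = tf.pad(conf,
                [[0, 0], [radius, radius], [radius, radius], [0, 0]],
                "CONSTANT")
png = tf.image.encode_png(
    tf.cast(tf.scalar_mul(65535.0, tf.squeeze(padded, axis=0)), tf.uint16))
write_op = tf.write_file('confidence.png', png)

with tf.Session() as sess:
    sess.run(write_op,
             feed_dict={conf: np.random.rand(1, 64, 64, 1).astype(np.float32)})
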
Example #2
    def train(self, args):
        print("\n [*] Training....")

        if not os.path.exists(args.log_directory):
            os.makedirs(args.log_directory)

        # plain SGD over the model variables; the learning rate is fed at
        # run time (see the feed_dict below) so it can be decayed in Python
        self.optimizer = tf.train.GradientDescentOptimizer(
            self.learning_rate).minimize(self.loss, var_list=self.vars)
        self.saver = tf.train.Saver()
        self.summary_op = tf.summary.merge_all(self.model_collection[0])
        self.writer = tf.summary.FileWriter(args.log_directory + "/summary/",
                                            graph=self.sess.graph)

        total_num_parameters = 0
        for variable in tf.trainable_variables():
            total_num_parameters += np.array(
                variable.get_shape().as_list()).prod()
        print(" [*] Number of trainable parameters: {}".format(
            total_num_parameters))

        self.sess.run(tf.global_variables_initializer())
        self.sess.run(tf.local_variables_initializer())

        print(' [*] Loading training set...')
        dataloader = Dataloader(file=args.dataset_training,
                                isTraining=self.isTraining)
        patches_disp, patches_gt = dataloader.get_training_patches(
            self.patch_size, args.threshold)
        line = dataloader.disp_filename
        num_samples = dataloader.count_text_lines(args.dataset_training)

        print(' [*] Training data loaded successfully')
        epoch = 0
        iteration = 0
        lr = self.initial_learning_rate

        # start the input-queue threads that feed the training patches
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        print(" [*] Start Training...")
        while epoch < self.epoch:
            for i in range(num_samples):
                batch_disp, batch_gt, filename = self.sess.run(
                    [patches_disp, patches_gt, line])
                print(" [*] Training image: " + filename)

                step_image = 0
                while step_image < len(batch_disp):
                    # slice a mini-batch of patches from the current image;
                    # the offset wraps around the available patches
                    offset = (step_image * self.batch_size) % (
                        batch_disp.shape[0] - self.batch_size)
                    batch_data = batch_disp[offset:(offset +
                                                    self.batch_size), :, :, :]
                    # keep only the central pixel of each ground-truth patch
                    # as the training label
                    batch_labels = batch_gt[offset:(offset + self.batch_size),
                                            self.radius:self.radius + 1,
                                            self.radius:self.radius + 1, :]

                    _, loss, summary_str = self.sess.run(
                        [self.optimizer, self.loss, self.summary_op],
                        feed_dict={
                            self.disp: batch_data,
                            self.gt: batch_labels,
                            self.learning_rate: lr
                        })

                    print("Epoch: [%2d]" % epoch + ", Image: [%2d]" % i +
                          ", Iter: [%2d]" % iteration + ", Loss: [%2f]" % loss)
                    self.writer.add_summary(summary_str, global_step=iteration)
                    iteration = iteration + 1
                    step_image = step_image + self.batch_size

            epoch = epoch + 1

            # periodically save a checkpoint
            if np.mod(epoch, args.save_epoch_freq) == 0:
                self.saver.save(self.sess,
                                args.log_directory + '/' + self.model_name,
                                global_step=iteration)

            # decay the learning rate by a factor of 10 once, after 10 epochs
            if epoch == 10:
                lr = lr / 10

        coord.request_stop()
        coord.join(threads)
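
The training loop above builds a plain SGD optimizer once and feeds the learning rate through the feed_dict at every step, decaying it in Python after 10 epochs. The sketch below isolates that pattern on a throwaway least-squares problem; TensorFlow 1.x is assumed, and all shapes, names, and values are illustrative rather than taken from the example.

import numpy as np
import tensorflow as tf

# a throwaway regression graph, standing in for the CCNN loss
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([1, 1]))
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

# the learning rate is a placeholder, so the Python-side value fed each
# step is what the optimizer actually uses
learning_rate = tf.placeholder(tf.float32, [])
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)

lr = 0.003
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(20):
        xs = np.random.rand(8, 1).astype(np.float32)
        sess.run(optimizer, feed_dict={x: xs, y: 2 * xs, learning_rate: lr})
        if epoch == 10:
            lr = lr / 10        # same one-off decay as in train() above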