def test_display(mock_show):
    """display() on a list of three blank images should return None.

    `mock_show` is a fixture that patches the plotting backend so no
    window is opened during the test.
    """
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, 0, 1, 512, 512, 128, 128, 1)
    # Three all-zero single-channel images at the reader's native size.
    display_list = [np.zeros(shape=(512, 512, 1)) for _ in range(3)]
    # `is None` (identity) is the idiomatic check, not `== None`.
    assert reader.display(display_list) is None
def test_parse_image_function():
    """Mapping _parse_function over the dataset yields a proper iterable."""
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, 0, 1, 512, 512, 128, 128, 1)
    # Feature spec matching the serialized example layout in the tfrecord.
    feature_spec = {
        'data/slice': tf.io.FixedLenFeature([], tf.string),
        'data/seg': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = reader.dataset.map(
        lambda record: reader._parse_function(record, feature_spec))
    tf.debugging.assert_proper_iterable(parsed)
def test_load_image_train():
    """Chaining _parse_function with load_image_train yields a proper iterable."""
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, 0, 1, 512, 512, 128, 128, 1)
    # Feature spec matching the serialized example layout in the tfrecord.
    feature_spec = {
        'data/slice': tf.io.FixedLenFeature([], tf.string),
        'data/seg': tf.io.FixedLenFeature([], tf.string),
    }
    parsed = reader.dataset.map(
        lambda record: reader._parse_function(record, feature_spec))
    train_ready = parsed.map(reader.load_image_train,
                             num_parallel_calls=tf.data.experimental.AUTOTUNE)
    tf.debugging.assert_proper_iterable(train_ready)
def test_init():
    """The constructor produces a TFRecordsReader2D instance."""
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, 0, 1, 512, 512, 128, 128, 1)
    assert isinstance(reader, TFRecordsReader2D)
def test_read_valueerror():
    """read() must raise ValueError when the scan offset is negative (-1)."""
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, -1, 1, 512, 512, 128, 128, 1)
    # The raises context alone is the check; asserting on the return value
    # inside it is meaningless (the line must raise before the assert runs).
    with pytest.raises(ValueError):
        reader.read()
def test_read():
    """read() with valid arguments returns a proper iterable dataset."""
    reader = TFRecordsReader2D(tf.data.TFRecordDataset('train.tfrecords'), 1, 0, 1, 512, 512, 128, 128, 1)
    prepared = reader.read()
    tf.debugging.assert_proper_iterable(prepared)
# Exemplo n.º 7
# 0
    def launch_unet2d(self):
        '''Launch the 2D U-Net training process.

        Reads the train/validation tfrecord files, builds the input
        pipelines via TFRecordsReader2D, constructs the U-Net model,
        wires up checkpoint/TensorBoard/display callbacks, fits the
        model, and saves it in the TensorFlow SavedModel format.
        '''
        # Read tfrecords
        train_filenames = get_tfrecord_filenames(self.path_to_tfrecords_train)
        val_filenames = get_tfrecord_filenames(self.path_to_tfrecords_val)

        # Load tfrecords into a TFRecordDataset
        train_tfrecorddataset = tf.data.TFRecordDataset(
            filenames=train_filenames)
        val_tfrecorddataset = tf.data.TFRecordDataset(filenames=val_filenames)

        # Decode the data and prepare for training
        log.info('YELLOW', 'Loading Datasets')
        print("Training Datasets: ", train_filenames)
        print("Validation Datasets: ", val_filenames)

        # NOTE(review): the 512x512 literals assume the stored slices are
        # 512x512 before being resized to image_shape_resize — confirm
        # against the tfrecord writer.
        train_data_provider = TFRecordsReader2D(train_tfrecorddataset,
                                                np.ceil(self.num_scans / 10),
                                                0, self.batch_size, 512, 512,
                                                self.image_shape_resize[0],
                                                self.image_shape_resize[1],
                                                self.num_classes)
        val_data_provider = TFRecordsReader2D(val_tfrecorddataset,
                                              np.ceil(self.num_scans / 10), 0,
                                              self.batch_size, 512, 512,
                                              self.image_shape_resize[0],
                                              self.image_shape_resize[1],
                                              self.num_classes)

        train_dataset = train_data_provider.read()
        val_dataset = val_data_provider.read()

        # Pull a single batch so shape problems surface before training starts.
        for image, mask in train_dataset.take(1):
            sample_image, sample_mask = image, mask
            print(sample_image.shape)
            print(sample_mask.shape)

        # Load the Unet Model
        Module = Unet2D(learning_rate=self.learning_rate,
                        num_classes=self.num_classes,
                        input_size=[
                            self.image_shape_resize[0],
                            self.image_shape_resize[1], self.channels
                        ])
        model = Module.unet()

        # Create a callback that saves the model's weights
        ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=self.ckpt_dir, save_weights_only=True, verbose=1)

        # Clear any logs from previous runs. ignore_errors avoids a
        # FileNotFoundError on the very first run, when the directory
        # does not exist yet.
        shutil.rmtree(self.output_dir + "logs/", ignore_errors=True)

        # Create a callback to tensorboard
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=self.log_dir, histogram_freq=1, profile_batch=0)

        # Creates a file writer for the image prediction log directory.
        file_writer = tf.summary.create_file_writer(self.log_dir + '/img')

        # Start the training and evaluation
        model.fit(x=train_dataset,
                  epochs=self.epochs,
                  steps_per_epoch=self.steps_per_epoch,
                  callbacks=[
                      DisplayCallback(model, val_dataset, file_writer),
                      ckpt_callback, tensorboard_callback
                  ],
                  validation_data=val_dataset,
                  validation_steps=self.val_steps)

        # Save the trained model
        #tf.saved_model.save(model, self.savedmodel_dir)
        model.save(self.savedmodel_dir, save_format='tf')