Example #1
def run():
    """Builds model, loads data, trains and evaluates"""
    model = UNet(CFG)
    model.load_data()
    model.build()
    model.train()
    model.evaluate()
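
The `CFG` object passed to `UNet` is not shown in this example. A minimal sketch of what such a config might contain, assuming a plain dictionary (every key and value here is hypothetical, not the project's actual config):

# Hypothetical config; the real CFG is defined elsewhere in the project.
CFG = {
    'data': {'image_size': 128, 'batch_size': 64},
    'train': {'epochs': 20, 'buffer_size': 1000},
}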
Example #2
import numpy as np
import tensorflow as tf
from unittest.mock import patch

# UNet, CFG and dummy_load_data are assumed to be imported
# from the project under test.

class UnetTest(tf.test.TestCase):

    def setUp(self):
        super(UnetTest, self).setUp()
        self.unet = UNet(CFG)

    def tearDown(self):
        super(UnetTest, self).tearDown()

    def test_normalize(self):
        input_image = np.array([[1., 1.], [1., 1.]])
        input_mask = 1
        # 1 / 255 ≈ 0.00392157, i.e. pixels rescaled to [0, 1]
        expected_image = np.array([[0.00392157, 0.00392157], [0.00392157, 0.00392157]])

        result = self.unet._normalize(input_image, input_mask)
        self.assertAllClose(expected_image, result[0])

    def test_output_size(self):
        shape = (1, self.unet.image_size, self.unet.image_size, 3)
        image = tf.ones(shape)
        self.unet.build()
        self.assertEqual(self.unet.model.predict(image).shape, shape)

    @patch('model.unet.DataLoader.load_data')
    def test_load_data(self, mock_data_loader):
        mock_data_loader.side_effect = dummy_load_data
        shape = tf.TensorShape([None, self.unet.image_size, self.unet.image_size, 3])

        self.unet.load_data()
        mock_data_loader.assert_called()

        self.assertItemsEqual(self.unet.train_dataset.element_spec[0].shape, shape)
        self.assertItemsEqual(self.unet.test_dataset.element_spec[0].shape, shape)
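
The `dummy_load_data` side effect referenced by the patch is not shown in the example. A minimal sketch of what it might return, assuming tiny `tf.data` datasets of (image, mask) pairs whose `element_spec` matches the shapes asserted above (names and shapes are assumptions):

def dummy_load_data(*args, **kwargs):
    # Two identical one-element datasets of (image, mask) pairs; batching
    # yields a dynamic batch dimension, so element_spec[0].shape becomes
    # [None, 128, 128, 3].
    images = tf.ones([1, 128, 128, 3])
    masks = tf.ones([1, 128, 128, 1])
    dataset = tf.data.Dataset.from_tensor_slices((images, masks)).batch(1)
    return dataset, dataset

With this stub, the shape assertions pass whenever `image_size` is 128; the suite itself can be run with `tf.test.main()` under an `if __name__ == '__main__':` guard.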
Example #3
import sys

# img_width, img_height and img_size are assumed to be defined elsewhere
test_path = 'data/test'
save_path = 'data/results'
model_weights_name = 'unet_bones_weights.hdf5'

if __name__ == "__main__":
    """ Prediction Script
    Run this Python script with a command line
    argument that defines number of test samples
    e.g. python predict.py 6
    Note that test samples names should be:
    1.jpg, 2.jpg, 3.jpg ...
    """

    # get number of samples from command line
    samples_number = int(sys.argv[1])

    # build model
    unet = UNet(
        input_size=(img_width, img_height, 1),
        n_filters=64,
        pretrained_weights=model_weights_name
    )
    unet.build()

    # generate the testing set
    test_gen = test_generator(test_path, samples_number, img_size)

    # run prediction and save the results
    results = unet.predict_generator(test_gen, samples_number, verbose=1)
    save_results(save_path, results)
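
The `test_generator` and `save_results` helpers are not shown here. A sketch of what they might look like, assuming skimage-based image I/O and the 1.jpg, 2.jpg, ... naming convention from the docstring (all names and details are assumptions, not the project's actual code):

import os
import numpy as np
import skimage.io as io
import skimage.transform as trans

def test_generator(test_path, num_samples, target_size=(256, 256)):
    # yield preprocessed grayscale images 1.jpg ... N.jpg, one per batch
    for i in range(1, num_samples + 1):
        img = io.imread(os.path.join(test_path, '{}.jpg'.format(i)), as_gray=True)
        img = trans.resize(img, target_size)
        yield np.reshape(img, (1,) + target_size + (1,))

def save_results(save_path, results):
    # write each predicted mask as <index>_predict.png
    for i, mask in enumerate(results):
        io.imsave(os.path.join(save_path, '{}_predict.png'.format(i + 1)),
                  (mask[:, :, 0] * 255).astype(np.uint8))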
Example #4
import os
import time

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.train.Saver)


def main():
    # logger, the image readers, cur_checkpoint_path and the upper-case
    # constants (NUM_CLASSES, INPUT_SHAPE, LEARNING_RATE_SETTINGS,
    # SAVING_INTERVAL) are defined elsewhere in the project.
    with tf.Session() as sess:
        unet = UNet(num_classes=NUM_CLASSES, input_shape=INPUT_SHAPE)
        unet.build()
        init = tf.global_variables_initializer()
        sess.run(init)
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

        best_val_loss = float('inf')  # best validation loss seen so far
        num_consec_worse_earlystop = 0
        num_consec_worse_lr = 0
        # learning_rate = 1e-4

        for s, lrs in enumerate(LEARNING_RATE_SETTINGS):
            for epoch in range(lrs['max_epoch']):
                ##############
                #   Train
                ##############
                start_time = time.time()
                train_data = train_img_reader.read()
                for batch, (X_batch, y_batch) in enumerate(train_data):
                    _, loss, pred = sess.run(
                        [unet.train_op, unet.loss, unet.pred],
                        feed_dict={
                            unet.is_training: True,
                            unet.X_train: X_batch,
                            unet.y_train: y_batch,
                            unet.learning_rate: lrs['lr']
                        })
                    logger.info(
                        '[set {}, epoch {}, batch {}] training loss: {}'.
                        format(s, epoch, batch, loss))

                logger.info(
                    '==== set {}, epoch {} took {:.0f} seconds to train. ===='.
                    format(s, epoch,
                           time.time() - start_time))

                ##########################
                #   Eval Validation set
                ##########################
                start_time = time.time()
                val_data = val_img_reader.read()
                losses = []
                for batch, (X_batch, y_batch) in enumerate(val_data):
                    loss, pred = sess.run(
                        [unet.loss, unet.pred],
                        feed_dict={
                            unet.is_training: False,
                            unet.X_train: X_batch,
                            unet.y_train: y_batch
                        })
                    losses.append(loss)

                avg_val_loss = np.average(losses)
                logger.info('==== average validation loss: {} ===='.format(
                    avg_val_loss))
                logger.info(
                    '==== set {}, epoch {} took {:.0f} seconds to evaluate the validation set. ===='
                    .format(s, epoch,
                            time.time() - start_time))

                def save_checkpoint(sess):
                    saver.save(sess,
                               os.path.join(cur_checkpoint_path,
                                            'unet-{}'.format(INPUT_SHAPE)),
                               global_step=s * len(LEARNING_RATE_SETTINGS) +
                               epoch)

                if lrs.get('reduce_factor'):
                    if avg_val_loss < best_val_loss:
                        best_val_loss = avg_val_loss
                        # num_consec_worse_earlystop = 0
                        num_consec_worse_lr = 0
                    else:
                        # num_consec_worse_earlystop += 1
                        num_consec_worse_lr += 1

                    if num_consec_worse_lr >= lrs.get('reduce_patience'):
                        lrs['lr'] *= lrs.get('reduce_factor')
                        logger.info(
                            '==== val loss did not improve for {} epochs, learning rate reduced to {}. ===='
                            .format(num_consec_worse_lr, lrs['lr']))
                        num_consec_worse_lr = 0

                # if num_consec_worse_earlystop >= EARLY_STOPPING_PATIENCE:
                #     logger.info('==== Training early stopped because worse val loss lasts for {} epochs. ===='.format(num_consec_worse_earlystop))
                #     save_checkpoint(sess)
                #     break

                # save periodically and at the end of each learning-rate stage
                if (epoch > 0 and epoch % SAVING_INTERVAL == 0) \
                        or epoch == lrs['max_epoch'] - 1:
                    save_checkpoint(sess)
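
The structure of `LEARNING_RATE_SETTINGS` can be inferred from the keys the loop reads (`lr`, `max_epoch`, `reduce_factor`, `reduce_patience`); a hypothetical instance, with placeholder numbers rather than the project's actual values:

LEARNING_RATE_SETTINGS = [
    # stage 1: higher LR with reduce-on-plateau (halve after 3 bad epochs)
    {'lr': 1e-4, 'max_epoch': 50, 'reduce_factor': 0.5, 'reduce_patience': 3},
    # stage 2: fixed low LR for fine-tuning (no reduce_* keys, so no decay)
    {'lr': 1e-5, 'max_epoch': 20},
]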