コード例 #1
0
ファイル: main.py プロジェクト: Vergangenheit/DL_Production
def run():
    """Run the end-to-end pipeline for a UNet configured from CFG.

    Stages, in order: load data, build the model, train it, evaluate it.
    """
    unet = UNet(CFG)
    # Each stage is looked up on the instance right before it is invoked,
    # and the stages run strictly in the order listed here.
    for stage_name in ("load_data", "build", "train", "evaluate"):
        getattr(unet, stage_name)()
コード例 #2
0
ファイル: unet_test.py プロジェクト: Nornostra/ai-summer
class UnetTest(tf.test.TestCase):
    """Unit tests for the UNet model wrapper constructed from ``CFG``."""

    # tearDown is deliberately NOT overridden: tf.test.TestCase.tearDown runs
    # the framework's own per-test cleanup, and an override that does not call
    # super().tearDown() (such as a bare ``pass``) silently skips it.

    def setUp(self):
        """Create a fresh UNet instance for every test."""
        super().setUp()
        self.unet = UNet(CFG)

    def test_normalize(self):
        """``_normalize`` maps a pixel value of 1.0 to 1/255.

        The result is indexed with ``[0]``, so the normalized image is expected
        to be the first element of whatever ``_normalize`` returns.
        """
        input_image = np.array([[1., 1.], [1., 1.]])
        input_mask = 1
        # 1/255 ~= 0.00392157; presumably the usual 8-bit [0, 255] -> [0, 1] rescale.
        expected_image = np.full((2, 2), 1.0 / 255.0)

        result = self.unet._normalize(input_image, input_mask)
        self.assertAllClose(expected_image, result[0])

    def test_output_size(self):
        """The built model maps a (1, S, S, 3) batch to an output of the same shape."""
        shape = (1, self.unet.image_size, self.unet.image_size, 3)
        image = tf.ones(shape)
        self.unet.build()
        # NOTE(review): asserting the output shape equals the input shape assumes
        # the model emits 3 output channels — confirm against CFG.
        self.assertEqual(self.unet.model.predict(image).shape, shape)

    @patch('model.unet.DataLoader.load_data')
    def test_load_data(self, mock_load_data):
        """``load_data`` goes through DataLoader.load_data and yields (None, S, S, 3) images."""
        mock_load_data.side_effect = dummy_load_data
        expected_shape = [None, self.unet.image_size, self.unet.image_size, 3]

        self.unet.load_data()
        mock_load_data.assert_called()

        # Compare the dimension lists in order. The previous assertItemsEqual
        # (assertCountEqual) is order-insensitive and would also accept a
        # permuted shape such as channels-first (None, 3, S, S).
        self.assertEqual(
            self.unet.train_dataset.element_spec[0].shape.as_list(), expected_shape)
        self.assertEqual(
            self.unet.test_dataset.element_spec[0].shape.as_list(), expected_shape)