def test_smoke(self, input_shape, depth):
    """Build a small ResUNet and check its output shape on an all-ones input.

    The model is built with ``filters_root=4`` and ``classes=2``. It is run
    on a single all-ones sample of ``input_shape``. The output must keep the
    first two input dimensions and have ``classes`` entries in its third.

    Args:
        input_shape: Shape of one input sample without the batch axis. Its
            first two entries are compared against the output. Presumably
            supplied by a parameterization decorator outside this view --
            TODO confirm.
        depth: Passed straight through to ``ResUNet``.
    """
    filters_root = 4
    classes = 2
    # Prepend a batch axis of size 1.
    inputs = np.ones(shape=input_shape)[np.newaxis, ...]
    model = ResUNet(input_shape, classes, filters_root, depth)
    # Drop only the batch axis. ``squeeze()`` would also remove any other
    # size-1 axis (e.g. a height or width of 1) and shift the indices below.
    output = model(inputs).numpy()[0]
    self.assertEqual(output.shape[0], input_shape[0])
    self.assertEqual(output.shape[1], input_shape[1])
    self.assertEqual(output.shape[2], classes)
def test_smoke(self, input_shape, depth, filters_root, error=False):
    """Build a ResUNet and check its output shape, or check that building fails.

    Args:
        input_shape: Shape of one input sample without the batch axis. Its
            first two entries are compared against the output. Presumably
            supplied by a parameterization decorator outside this view --
            TODO confirm.
        depth: Passed straight through to ``ResUNet``.
        filters_root: Passed straight through to ``ResUNet``.
        error: If True, constructing ``ResUNet`` with these arguments is
            expected to raise an ``Exception``, and the test fails if it
            does not. If False, any construction error propagates and fails
            the test.
    """
    classes = 2
    # Prepend a batch axis of size 1.
    inputs = np.ones(shape=input_shape)[np.newaxis, ...]
    if error:
        # Assert the failure explicitly so that a configuration which
        # unexpectedly builds is reported, instead of falling through to
        # the shape checks and possibly passing.
        with self.assertRaises(Exception):
            ResUNet(input_shape, classes, filters_root, depth)
        return
    model = ResUNet(input_shape, classes, filters_root, depth)
    # Drop only the batch axis. ``squeeze()`` would also remove any other
    # size-1 axis (e.g. a height or width of 1) and shift the indices below.
    output = model(inputs).numpy()[0]
    self.assertEqual(output.shape[0], input_shape[0])
    self.assertEqual(output.shape[1], input_shape[1])
    self.assertEqual(output.shape[2], classes)
"Path to the validation dataset. Expects 'images' and 'masks' directory inside with images and masks named the same." ) parser.add_argument("--logs_root", default="logs", type=str) parser.add_argument("--epochs", default=100, type=int) parser.add_argument("--batch_size", default=32, type=int) parser.add_argument("--plot_model", action="store_true", default=False) parser.add_argument("--seed", default=42, type=int) args = parser.parse_args() np.random.seed(args.seed) tf.random.set_seed(args.seed) random.seed(args.seed) model = ResUNet(input_shape=(128, 128, 1), classes=2, filters_root=16, depth=3) model.summary() if args.plot_model: from tensorflow.python.keras.utils.vis_utils import plot_model plot_model(model, show_shapes=True) model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["categorical_accuracy"]) train_dataset = list( zip(*list(SimpleDataset(args.train_dataset_dir_path)()))) train_dataset = (np.array(train_dataset[0]), np.array(train_dataset[1])) x = np.array(train_dataset[0])
help="Path to the validation dataset. Expects 'images' and 'masks' directory inside with images and masks named the same.") parser.add_argument("--logs_root", default="logs", type=str) parser.add_argument("--epochs", default=10, type=int) parser.add_argument("--batch_size", default=32, type=int) parser.add_argument("--plot_model", action="store_true", default=False) parser.add_argument("--seed", default=42, type=int) args = parser.parse_args() args.plot_model = True np.random.seed(args.seed) tf.random.set_seed(args.seed) random.seed(args.seed) model = ResUNet(input_shape=(256, 256, 2), classes=2, filters_root=16, depth=4,final_activation='sigmoid') model.summary() if args.plot_model: from tensorflow.python.keras.utils.vis_utils import plot_model plot_model(model, show_shapes=True) dataset = np.load('dataset/scaled_transformer_256.npz') trainer = Trainer(name='ResUNet-transformer-original', checkpoint_callback=True) x_train = dataset['x_train'] y_train = dataset['y_train'] x_valid = dataset['x_valid'] y_valid = dataset['y_valid']