def test_lovasz_loss(self):
    """Check that LovaszLoss reduces the fixture predictions to a scalar.

    Applies a freshly constructed ``LovaszLoss`` to ``self.y_true`` and
    ``self.y_pred`` and asserts that the result has an empty shape ``()``.
    """
    # self.y_true / self.y_pred come from the enclosing test case; they are
    # presumably set up in setUp (not visible here) -- confirm in the class.
    scalar_loss = LovaszLoss()(self.y_true, self.y_pred)
    loss_shape = scalar_loss.shape
    # An empty shape tuple means the loss is a rank-0 (scalar) tensor.
    self.assertAllEqual(loss_shape, ())
base_model_path = '/opt/segelectri/train/train_unet/normal_function_model/saved_model/1' clustered_model_path = '/opt/segelectri/train/train_unet/normal_function_model/clustered_model/1' pruned_model_path = '/opt/segelectri/train/train_unet/normal_function_model/pruned_pb_model/1' clustered_model = tf.keras.models.load_model(filepath=clustered_model_path, custom_objects={ 'MeanIou': MeanIou, 'FocalLoss': FocalLoss, 'LovaszLoss': LovaszLoss, 'DiceLoss': DiceLoss, 'BoundaryLoss': BoundaryLoss }) clustered_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=LovaszLoss(), metrics=[MeanIou(num_classes=4)]) get_test_loss(clustered_model, 10, False) pruned_model = tf.keras.models.load_model(filepath=pruned_model_path, custom_objects={ 'MeanIou': MeanIou, 'FocalLoss': FocalLoss, 'LovaszLoss': LovaszLoss, 'DiceLoss': DiceLoss, 'BoundaryLoss': BoundaryLoss }) pruned_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=LovaszLoss(), metrics=[MeanIou(num_classes=4)]) get_test_loss(pruned_model, 10, False) base_model = tf.keras.models.load_model(filepath=base_model_path,
routine = TrainRoutine(ds=ds, model=model) routine.run(exp_dir=freeze_expdir_list[i], epochs=freeze_epochs) model.layers[0].trainable = True model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss_list[i], metrics=[MeanIou(num_classes=4)]) routine = TrainRoutine(ds=ds, model=model) routine.run(exp_dir=unfreeze_expdir_list[i], epochs=unfreeze_epochs) if __name__ == '__main__': loss_list = [ tf.keras.losses.BinaryCrossentropy(), FocalLoss(), LovaszLoss(), DiceLoss(), BoundaryLoss() ] freeze_expdir_list = [ 'exp/46_freeze', 'exp/47_freeze', 'exp/48_freeze', 'exp/49_freeze', 'exp/50_freeze' ] unfreeze_expdir_list = [ 'exp/46_unfreeze', 'exp/47_unfreeze', 'exp/48_unfreeze', 'exp/49_unfreeze', 'exp/50_unfreeze' ] model_optimize(xla=True, mix_prec=True) train(loss_list=loss_list, freeze_expdir_list=freeze_expdir_list, unfreeze_expdir_list=unfreeze_expdir_list,