Example #1
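# NOTE: extracted from a larger module; `os`, TensorFlow 1.x (`tf`), `UNet_v1`
# and `LOGGER` are assumed to be imported/defined elsewhere in that module.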
def export_model(RUNNING_CONFIG):

    if RUNNING_CONFIG.use_tf_amp:
        os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

    model = UNet_v1(model_name="UNet_v1",
                    input_format="NHWC",
                    compute_format=RUNNING_CONFIG.data_format,
                    n_output_channels=1,
                    unet_variant=RUNNING_CONFIG.unet_variant,
                    weight_init_method="he_normal",
                    activation_fn=RUNNING_CONFIG.activation_fn)

    config_proto = tf.ConfigProto()

    config_proto.allow_soft_placement = True
    config_proto.log_device_placement = False

    config_proto.gpu_options.allow_growth = True

    if RUNNING_CONFIG.use_xla:  # Only working on single GPU
        LOGGER.log("XLA is activated - Experimental Feature")
        config_proto.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1

    config_proto.gpu_options.force_gpu_compatible = True  # Force pinned memory

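    # Summaries, periodic checkpointing and step-count logging are effectively
    # disabled below: this Estimator is only used to restore weights and export
    # a SavedModel.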
    run_config = tf.estimator.RunConfig(
        model_dir=None,
        tf_random_seed=None,
        save_summary_steps=1e9,  # disabled
        save_checkpoints_steps=None,
        save_checkpoints_secs=None,
        session_config=config_proto,
        keep_checkpoint_max=None,
        keep_checkpoint_every_n_hours=1e9,  # disabled
        log_step_count_steps=1e9,
        train_distribute=None,
        device_fn=None,
        protocol=None,
        eval_distribute=None,
        experimental_distribute=None)

    estimator = tf.estimator.Estimator(
        model_fn=model,
        model_dir=RUNNING_CONFIG.model_checkpoint_path,
        config=run_config,
        params={'debug_verbosity': 0})

    LOGGER.log('[*] Exporting the model ...')

    input_type = tf.float32 if RUNNING_CONFIG.input_dtype else tf.float16

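    # The serving input receiver exposes a single fixed-shape placeholder
    # (batch_size x 512 x 512 x 1) named `input_tensor` as the exported model input.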
    def get_serving_input_receiver_fn():

        input_shape = [RUNNING_CONFIG.batch_size, 512, 512, 1]

        def serving_input_receiver_fn():
            features = tf.placeholder(dtype=input_type,
                                      shape=input_shape,
                                      name='input_tensor')

            return tf.estimator.export.TensorServingInputReceiver(
                features=features, receiver_tensors=features)

        return serving_input_receiver_fn

    export_path = estimator.export_saved_model(
        export_dir_base=RUNNING_CONFIG.export_dir,
        serving_input_receiver_fn=get_serving_input_receiver_fn(),
        checkpoint_path=RUNNING_CONFIG.model_checkpoint_path)

    LOGGER.log('[*] Done! path: `%s`' % export_path.decode())
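
For context, below is a minimal, hypothetical sketch of how the exported SavedModel could be loaded for inference with the TF 1.x predictor API. The export path, batch size and input key are placeholders and are not part of the original code; inspect `predictor.feed_tensors` / `predictor.fetch_tensors` to confirm the real keys.

import numpy as np
import tensorflow as tf

saved_model_dir = "/path/to/export_dir/1234567890"  # hypothetical timestamped export path

# Build a predictor from the exported SavedModel (TF 1.x contrib API).
predictor = tf.contrib.predictor.from_saved_model(saved_model_dir)
print("inputs :", predictor.feed_tensors)   # expected input key(s) and placeholders
print("outputs:", predictor.fetch_tensors)  # produced output key(s) and tensors

# One batch matching the shape baked into serving_input_receiver_fn
# (batch_size x 512 x 512 x 1); batch size 16 and the zero data are arbitrary.
batch = np.zeros((16, 512, 512, 1), dtype=np.float32)
input_key = list(predictor.feed_tensors.keys())[0]
outputs = predictor({input_key: batch})
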
Example #2
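    # NOTE: extracted from a larger `Runner` class; `os`, TensorFlow 1.x (`tf`),
    # `np`, `random`, `hvd` (horovod.tensorflow), `hvd_utils`, `UNet_v1` and
    # `known_datasets` are assumed to be imported/defined elsewhere in that module.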
    def __init__(
            self,

            # Model Params
            input_format,  # NCHW or NHWC
            compute_format,  # NCHW or NHWC
            n_channels,
            activation_fn,
            weight_init_method,
            model_variant,
            input_shape,
            mask_shape,
            input_normalization_method,

            # Training HParams
            augment_data,
            loss_fn_name,

            #  Runtime HParams
            amp,
            xla,

            # Directory Params
            model_dir=None,
            log_dir=None,
            sample_dir=None,
            data_dir=None,
            dataset_name=None,
            dataset_hparams=None,

            # Debug Params
            log_every_n_steps=1,
            debug_verbosity=0,
            seed=None):

        if dataset_hparams is None:
            dataset_hparams = dict()

        if compute_format not in ["NHWC", "NCHW"]:
            raise ValueError(
                "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % compute_format)

        if input_format not in ["NHWC", "NCHW"]:
            raise ValueError(
                "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % input_format)

        if n_channels not in [1, 3]:
            raise ValueError(
                "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
                % n_channels)

        if data_dir is not None and not os.path.exists(data_dir):
            raise ValueError("The `data_dir` received does not exists: %s" %
                             data_dir)

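        # Initialize Horovod when running distributed and derive a per-rank
        # random seed so that workers do not draw identical random numbers.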
        if hvd_utils.is_using_hvd():
            hvd.init()

            if hvd.rank() == 0:
                print("Horovod successfully initialized ...")

            tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None

        else:
            tf_seed = 2 * seed if seed is not None else None

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        os.environ['TF_GPU_THREAD_COUNT'] = ('1' if not hvd_utils.is_using_hvd()
                                             else str(hvd.size()))
        print("WORLD_SIZE", hvd.size() if hvd_utils.is_using_hvd() else 1)

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

        # =================================================

        self.xla = xla

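        # Enable TensorFlow automatic mixed precision through the graph-rewrite
        # environment variable when AMP is requested.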
        if amp:

            if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                print("TF AMP is activated - Experimental Feature")

            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

        # =================================================

        model_hparams = tf.contrib.training.HParams(
            # Model Params
            input_format=input_format,
            compute_format=compute_format,
            input_shape=input_shape,
            mask_shape=mask_shape,
            n_channels=n_channels,
            activation_fn=activation_fn,
            weight_init_method=weight_init_method,
            model_variant=model_variant,
            input_normalization_method=input_normalization_method,

            # Training HParams
            augment_data=augment_data,
            loss_fn_name=loss_fn_name,

            # Runtime Params
            amp=amp,

            # Debug Params
            log_every_n_steps=log_every_n_steps,
            debug_verbosity=debug_verbosity,
            seed=tf_seed)

        run_config_additional = tf.contrib.training.HParams(
            dataset_hparams=dataset_hparams,
            model_dir=(model_dir
                       if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None),
            log_dir=(log_dir
                     if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None),
            sample_dir=(sample_dir
                        if not hvd_utils.is_using_hvd() or hvd.rank() == 0 else None),
            data_dir=data_dir,
            num_preprocessing_threads=32)

        if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
            if sample_dir is not None:
                os.makedirs(sample_dir, exist_ok=True)

        self.run_hparams = Runner._build_hparams(model_hparams,
                                                 run_config_additional)

        if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
            print('Defining Model Estimator ...\n')

        self._model = UNet_v1(
            model_name="UNet_v1",
            input_format=self.run_hparams.input_format,
            compute_format=self.run_hparams.compute_format,
            n_output_channels=1,
            unet_variant=self.run_hparams.model_variant,
            weight_init_method=self.run_hparams.weight_init_method,
            activation_fn=self.run_hparams.activation_fn)

        if self.run_hparams.seed is not None:

            if not hvd_utils.is_using_hvd() or hvd.rank() == 0:
                print("Deterministic Run - Seed: %d\n" % seed)

            tf.set_random_seed(self.run_hparams.seed)
            np.random.seed(self.run_hparams.seed)
            random.seed(self.run_hparams.seed)

        if dataset_name not in known_datasets.keys():
            raise RuntimeError(
                "The dataset `%s` is unknown, allowed values: %s ..." %
                (dataset_name, list(known_datasets.keys())))

        self.dataset = known_datasets[dataset_name](
            data_dir=data_dir, **self.run_hparams.dataset_hparams)

        self.num_gpus = 1 if not hvd_utils.is_using_hvd() else hvd.size()
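
For illustration, a minimal, hypothetical call to this constructor is sketched below. Every concrete value (shapes, variant, loss name, directories, dataset name) is a placeholder rather than a documented default, and `dataset_name` must be a key of `known_datasets` in the surrounding project.

runner = Runner(
    # Model Params
    input_format="NHWC",
    compute_format="NCHW",
    n_channels=1,
    activation_fn="relu",
    weight_init_method="he_normal",
    model_variant="tinyUNet",              # placeholder variant name
    input_shape=(512, 512),
    mask_shape=(512, 512),
    input_normalization_method="zero_one",
    # Training HParams
    augment_data=True,
    loss_fn_name="adaptive_loss",          # placeholder loss name
    # Runtime HParams
    amp=True,
    xla=False,
    # Directory Params
    model_dir="/results/checkpoints",
    log_dir="/results/logs",
    sample_dir="/results/samples",
    data_dir="/data",
    dataset_name="DAGM2007",               # placeholder; must exist in known_datasets
    dataset_hparams=None,
    # Debug Params
    log_every_n_steps=10,
    debug_verbosity=0,
    seed=42)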