Ejemplo n.º 1
0
    def __init__(
            self,

            # Model Params
            input_format,  # NCHW or NHWC
            compute_format,  # NCHW or NHWC
            n_channels,
            activation_fn,
            weight_init_method,
            model_variant,
            input_shape,
            mask_shape,
            input_normalization_method,

            # Training HParams
            augment_data,
            loss_fn_name,

            #  Runtime HParams
            amp,
            xla,

            # Directory Params
            model_dir=None,
            log_dir=None,
            sample_dir=None,
            data_dir=None,
            dataset_name=None,
            dataset_hparams=None,

            # Debug Params
            log_every_n_steps=1,
            debug_verbosity=0,
            seed=None):
        """Validate the configuration, set up the runtime and build the model.

        Args:
            input_format: Layout of the input tensors, "NHWC" or "NCHW".
            compute_format: Layout used for computation, "NHWC" or "NCHW".
            n_channels: Number of image channels, 1 (grayscale) or 3 (color).
            activation_fn, weight_init_method, model_variant: Stored in the
                model hyper-parameters and forwarded to ``UNet_v1``.
            input_shape, mask_shape, input_normalization_method,
            augment_data, loss_fn_name, log_every_n_steps, debug_verbosity:
                Stored in the model hyper-parameters.
            amp: If truthy, sets the
                ``TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE`` environment
                variable to enable TF automatic mixed precision.
            xla: Stored as ``self.xla``.
            model_dir, log_dir, sample_dir: Output directories. They are only
                kept on the master process (rank 0, or the single process
                when Horovod is not used); ``sample_dir`` is created there
                when it is not ``None``.
            data_dir: Dataset directory; must exist when given.
            dataset_name: Key into ``known_datasets``.
            dataset_hparams: Extra keyword arguments passed to the dataset
                class (defaults to an empty dict).
            seed: Base random seed. The effective seed is
                ``2 * (seed + rank)`` with Horovod and ``2 * seed`` without;
                ``None`` disables explicit seeding.

        Raises:
            ValueError: On an unsupported format or channel count, or a
                ``data_dir`` that does not exist.
            RuntimeError: If ``dataset_name`` is not a known dataset.
        """

        if dataset_hparams is None:
            dataset_hparams = dict()

        if compute_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `compute_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % compute_format)

        if input_format not in ["NHWC", 'NCHW']:
            raise ValueError(
                "Unknown `input_format` received: %s (allowed: ['NHWC', 'NCHW'])"
                % input_format)

        if n_channels not in [1, 3]:
            raise ValueError(
                "Unsupported number of channels: %d (allowed: 1 (grayscale) and 3 (color))"
                % n_channels)

        if data_dir is not None and not os.path.exists(data_dir):
            raise ValueError("The `data_dir` received does not exists: %s" %
                             data_dir)

        # Validate the dataset name up front, before Horovod initialisation,
        # environment mutation and model construction take place.
        if dataset_name not in known_datasets.keys():
            raise RuntimeError(
                "The dataset `%s` is unknown, allowed values: %s ..." %
                (dataset_name, list(known_datasets.keys())))

        using_hvd = hvd_utils.is_using_hvd()

        if using_hvd:
            hvd.init()

            if hvd.rank() == 0:
                print("Horovod successfully initialized ...")

            # Distinct seed per rank so that workers do not draw identical
            # random streams.
            tf_seed = 2 * (seed + hvd.rank()) if seed is not None else None
            world_size = hvd.size()

        else:
            tf_seed = 2 * seed if seed is not None else None
            world_size = 1

        # Only the master process (rank 0, or the single process without
        # Horovod) prints, keeps output directories and creates sample_dir.
        is_master = not using_hvd or hvd.rank() == 0

        # ============================================
        # Optimisation Flags - Do not remove
        # ============================================

        os.environ['CUDA_CACHE_DISABLE'] = '0'

        os.environ['HOROVOD_GPU_ALLREDUCE'] = 'NCCL'

        os.environ['TF_GPU_THREAD_MODE'] = 'gpu_private'
        os.environ['TF_GPU_THREAD_COUNT'] = str(world_size)

        # world_size is 1 without Horovod, so this never calls hvd.size()
        # on an uninitialised Horovod.
        if is_master:
            print("WORLD_SIZE", world_size)

        os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'

        os.environ['TF_ADJUST_HUE_FUSED'] = '1'
        os.environ['TF_ADJUST_SATURATION_FUSED'] = '1'
        os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1'

        os.environ['TF_SYNC_ON_FINISH'] = '0'
        os.environ['TF_AUTOTUNE_THRESHOLD'] = '2'

        # =================================================

        self.xla = xla

        if amp:

            if is_master:
                print("TF AMP is activated - Experimental Feature")

            os.environ["TF_ENABLE_AUTO_MIXED_PRECISION_GRAPH_REWRITE"] = "1"

        # =================================================

        model_hparams = tf.contrib.training.HParams(
            # Model Params
            input_format=input_format,
            compute_format=compute_format,
            input_shape=input_shape,
            mask_shape=mask_shape,
            n_channels=n_channels,
            activation_fn=activation_fn,
            weight_init_method=weight_init_method,
            model_variant=model_variant,
            input_normalization_method=input_normalization_method,

            # Training HParams
            augment_data=augment_data,
            loss_fn_name=loss_fn_name,

            # Runtime Params
            amp=amp,

            # Debug Params
            log_every_n_steps=log_every_n_steps,
            debug_verbosity=debug_verbosity,
            seed=tf_seed)

        # Output directories are dropped on non-master ranks so that only
        # one process writes them.
        run_config_additional = tf.contrib.training.HParams(
            dataset_hparams=dataset_hparams,
            model_dir=model_dir if is_master else None,
            log_dir=log_dir if is_master else None,
            sample_dir=sample_dir if is_master else None,
            data_dir=data_dir,
            num_preprocessing_threads=32,
        )

        # sample_dir defaults to None; os.makedirs(None) would raise TypeError.
        if is_master and sample_dir is not None:
            os.makedirs(sample_dir, exist_ok=True)

        self.run_hparams = Runner._build_hparams(model_hparams,
                                                 run_config_additional)

        if is_master:
            print('Defining Model Estimator ...\n')

        self._model = UNet_v1(
            model_name="UNet_v1",
            input_format=self.run_hparams.input_format,
            compute_format=self.run_hparams.compute_format,
            n_output_channels=1,
            unet_variant=self.run_hparams.model_variant,
            weight_init_method=self.run_hparams.weight_init_method,
            activation_fn=self.run_hparams.activation_fn)

        if self.run_hparams.seed is not None:

            if is_master:
                print("Deterministic Run - Seed: %d\n" % seed)

            tf.set_random_seed(self.run_hparams.seed)
            np.random.seed(self.run_hparams.seed)
            random.seed(self.run_hparams.seed)

        self.dataset = known_datasets[dataset_name](
            data_dir=data_dir, **self.run_hparams.dataset_hparams)

        self.num_gpus = world_size
def parse_cmdline():
    """Define the JoC-UNet_v1-TF command-line interface and parse it.

    Uses ``ArgumentParser.parse_known_args`` so that every unexpected
    argument can be reported individually before failing.

    Returns:
        argparse.Namespace: The parsed flags.

    Raises:
        ValueError: If any unknown command-line argument was supplied (each
            one is printed first).
    """

    p = argparse.ArgumentParser(description="JoC-UNet_v1-TF")

    p.add_argument(
        '--unet_variant',
        default="tinyUNet",
        choices=UNet_v1.authorized_models_variants,
        type=str,
        required=False,
        help=
        """Which model size is used. This parameter controls directly the size and the number of parameters"""
    )

    p.add_argument(
        '--activation_fn',
        choices=authorized_activation_fn,
        type=str,
        default="relu",
        required=False,
        help=
        """Which activation function is used after the convolution layers""")

    p.add_argument('--exec_mode',
                   choices=[
                       'train', 'train_and_evaluate', 'evaluate',
                       'training_benchmark', 'inference_benchmark'
                   ],
                   type=str,
                   required=True,
                   help="""Which execution mode to run the model into""")

    p.add_argument(
        '--iter_unit',
        choices=['epoch', 'batch'],
        type=str,
        required=True,
        help="""Will the model be run for X batches or X epochs ?""")

    p.add_argument('--num_iter',
                   type=int,
                   required=True,
                   help="""Number of iterations to run.""")

    p.add_argument('--batch_size',
                   type=int,
                   required=True,
                   help="""Size of each minibatch per GPU.""")

    p.add_argument(
        '--warmup_step',
        default=200,
        type=int,
        required=False,
        help=
        """Number of steps considered as warmup and not taken into account for performance measurements."""
    )

    p.add_argument(
        '--results_dir',
        type=str,
        required=True,
        help=
        """Directory in which to write training logs, summaries and checkpoints."""
    )

    _add_bool_argument(
        parser=p,
        name="save_eval_results_to_json",
        default=False,
        required=False,
        help="Whether to save evaluation results in JSON format.")

    p.add_argument('--data_dir',
                   required=False,
                   default=None,
                   type=str,
                   help="Path to dataset directory")

    p.add_argument(
        '--dataset_name',
        choices=list(known_datasets.keys()),
        type=str,
        required=True,
        help=
        """Name of the dataset used in this run (only DAGM2007 is supported atm.)"""
    )

    p.add_argument(
        '--dataset_classID',
        default=None,
        type=int,
        required=False,
        help=
        """ClassID to consider to train or evaluate the network (used for DAGM)."""
    )

    p.add_argument(
        '--data_format',
        choices=['NHWC', 'NCHW'],
        type=str,
        default="NCHW",
        required=False,
        help="""Which Tensor format is used for computation inside the model""")

    _add_bool_argument(
        parser=p,
        name="use_tf_amp",
        default=False,
        required=False,
        help=
        "Enable Automatic Mixed Precision to speedup FP32 computation using tensor cores"
    )

    _add_bool_argument(parser=p,
                       name="use_xla",
                       default=False,
                       required=False,
                       help="Enable Tensorflow XLA to maximise performance.")

    p.add_argument(
        '--weight_init_method',
        choices=UNet_v1.authorized_weight_init_methods,
        default="he_normal",
        type=str,
        required=False,
        help=
        """Which initialisation method is used to randomly initialize the model during training"""
    )

    p.add_argument('--learning_rate',
                   default=1e-5,
                   type=float,
                   required=False,
                   help="""Learning rate value.""")

    p.add_argument('--learning_rate_decay_factor',
                   default=0.75,
                   type=float,
                   required=False,
                   help="""Decay factor to decrease the learning rate.""")

    # The help text previously duplicated the decay-factor description.
    p.add_argument('--learning_rate_decay_steps',
                   default=500,
                   type=int,
                   required=False,
                   help="""Number of steps used to decay the learning rate.""")

    p.add_argument('--rmsprop_decay',
                   default=0.9,
                   type=float,
                   required=False,
                   help="""RMSProp - Decay value.""")

    p.add_argument('--rmsprop_momentum',
                   default=0.8,
                   type=float,
                   required=False,
                   help="""RMSProp - Momentum value.""")

    p.add_argument('--weight_decay',
                   default=1e-4,
                   type=float,
                   required=False,
                   help="""Weight Decay scale factor""")

    _add_bool_argument(parser=p,
                       name="use_auto_loss_scaling",
                       default=False,
                       required=False,
                       help="Use AutoLossScaling with TF-AMP")

    p.add_argument('--loss_fn_name',
                   type=str,
                   default="adaptive_loss",
                   required=False,
                   help="""Loss function Name to use to train the network""")

    _add_bool_argument(parser=p,
                       name="augment_data",
                       default=True,
                       required=False,
                       help="Choose whether to use data augmentation")

    p.add_argument(
        '--display_every',
        type=int,
        default=50,
        required=False,
        help="""How often (in batches) to print out debug information.""")

    p.add_argument(
        '--debug_verbosity',
        choices=[0, 1, 2],
        default=0,
        type=int,
        required=False,
        help=
        """Verbosity Level: 0 minimum, 1 with layer creation debug info, 2 with layer + var creation debug info."""
    )

    p.add_argument('--seed', type=int, default=None, help="""Random seed.""")

    # parse_known_args (rather than parse_args) lets us list every unknown
    # argument before failing.
    FLAGS, unknown_args = p.parse_known_args()

    if unknown_args:

        for bad_arg in unknown_args:
            print("ERROR: Unknown command line arg: %s" % bad_arg)

        raise ValueError("Invalid command line arg(s)")

    return FLAGS