def train_step(warper, weights, optimizer, mov, fix) -> tuple:
    """
    Run one optimisation step on the trainable DDF using a gradient tape.

    :param warper: warping function returned from layer.Warping
    :param weights: trainable ddf [1, f_dim1, f_dim2, f_dim3, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return: a tuple:
        - loss: overall loss to optimise
        - loss_image: image dissimilarity
        - loss_deform: deformation regularisation
    """
    trainable_vars = [weights]
    with tf.GradientTape() as tape:
        # warp the moving image with the current ddf
        warped = warper(inputs=[weights, mov])

        # image dissimilarity between the fixed and the warped moving image
        image_loss_fn = REGISTRY.build_loss(config=image_loss_config)
        loss_image = image_loss_fn(y_true=fix, y_pred=warped)

        # regularisation computed on the ddf itself
        deform_loss_fn = REGISTRY.build_loss(config=deform_loss_config)
        loss_deform = deform_loss_fn(inputs=weights)

        loss = loss_image + weight_deform_loss * loss_deform

    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return loss, loss_image, loss_deform
def get_data_loader(data_config: dict, split: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file of loader interfaces as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param split: must be train/valid/test
    :return: DataLoader or None, returns None if the split or dir is empty.
    :raises ValueError: if split is unknown or if a configured path
        is not an existing directory.
    """
    if split not in KNOWN_DATA_SPLITS:
        raise ValueError(f"split must be one of {KNOWN_DATA_SPLITS}, got {split}")

    if split not in data_config:
        return None
    split_config = data_config[split]

    data_dir_paths = split_config.get("dir", None)
    # None, "" and an empty list all mean that no data is configured
    if not data_dir_paths:
        return None

    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = [os.path.expanduser(path) for path in data_dir_paths]
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for split {split}"
                f" is not a directory or does not exist"
            )

    # prepare data loader config
    # drop the per-split sections first so only the shared options are copied
    data_loader_config = deepcopy(
        {k: v for k, v in data_config.items() if k not in KNOWN_DATA_SPLITS}
    )
    # the registry expects the loader type under the key "name"
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(
            category=FILE_LOADER_CLASS, key=split_config["format"]
        ),
        labeled=split_config["labeled"],
        sample_label="sample" if split == "train" else "all",
        seed=None if split == "train" else 0,
    )
    data_loader: DataLoader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args
    )
    return data_loader
def get_dataset_and_preprocess(
    self,
    training: bool,
    batch_size: int,
    repeat: bool,
    shuffle_buffer_num_batch: int,
    data_augmentation: Optional[Union[List, Dict]] = None,
) -> tf.data.Dataset:
    """
    Build the dataset pipeline: resize, shuffle, repeat, batch, prefetch, augment.

    :param training: bool, indicating if it's training or not
    :param batch_size: int, size of mini batch
    :param repeat: bool, indicating if we need to repeat the dataset
    :param shuffle_buffer_num_batch: int, when shuffling,
        the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
    :param data_augmentation: augmentation config, can be a list of dict or dict.
    :returns dataset:
    """
    autotune = tf.data.experimental.AUTOTUNE

    def _resize(sample):
        # resize every sample to the loader's moving / fixed image shapes
        return resize_inputs(
            inputs=sample,
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
        )

    dataset = self.get_dataset()
    dataset = dataset.map(_resize, num_parallel_calls=autotune)

    # shuffling is only applied for training with a positive buffer
    if training and shuffle_buffer_num_batch > 0:
        shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        dataset = dataset.shuffle(
            buffer_size=shuffle_buffer_size,
            reshuffle_each_iteration=True,
        )
    if repeat:
        dataset = dataset.repeat()
    # incomplete batches are dropped for training only
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
    dataset = dataset.prefetch(autotune)

    if not training or data_augmentation is None:
        return dataset

    # augmentation is applied on batched data, for training only
    if isinstance(data_augmentation, dict):
        augmentation_configs = [data_augmentation]
    else:
        augmentation_configs = data_augmentation
    for augmentation_config in augmentation_configs:
        da_fn = REGISTRY.build_data_augmentation(
            config=augmentation_config,
            default_args={
                "moving_image_size": self.moving_image_shape,
                "fixed_image_size": self.fixed_image_shape,
                "batch_size": batch_size,
            },
        )
        dataset = dataset.map(da_fn, num_parallel_calls=autotune)
    return dataset
def test_image_loss(self, config: dict, option: int, expected: int):
    """
    Check the number of losses of a ddf model for different image loss configs.

    :param config: training config, it is deep-copied before being modified
    :param option: 0 removes the image loss config, 1 sets the image loss
        weight to zero, 2 removes the image loss weight
    :param expected: expected number of losses in the built model
    """
    method = "ddf"
    backbone = "local"
    labeled = True

    train_config = deepcopy(config)
    train_config["method"] = method
    train_config["backbone"]["name"] = backbone
    # values from the given config take precedence over the backbone defaults
    merged_backbone = dict(backbone_args[backbone])  # type: ignore
    merged_backbone.update(train_config["backbone"])
    train_config["backbone"] = merged_backbone

    loss_config = train_config["loss"]
    if option == 0:
        # remove image loss config, so loss is not used
        del loss_config["image"]
    elif option == 1:
        # set image loss weight to zero, so loss is not used
        loss_config["image"]["weight"] = 0.0
    elif option == 2:
        # remove image loss weight, so loss is used with default weight 1
        del loss_config["image"]["weight"]

    model_config = {
        "name": method,  # TODO we store method twice
        "moving_image_size": moving_image_size,
        "fixed_image_size": fixed_image_size,
        "index_size": index_size,
        "labeled": labeled,
        "batch_size": batch_size,
        "config": train_config,
    }
    ddf_model = REGISTRY.build_model(config=model_config)
    num_losses = len(ddf_model._model.losses)  # type: ignore
    assert num_losses == expected
def model(method: str, labeled: bool, backbone: str) -> RegistrationModel:
    """
    Build a specific registration model object from the shared test config.

    :param method: name of method
    :param labeled: whether the data is labeled
    :param backbone: name of backbone
    :return: the built object
    """
    train_config = deepcopy(config)
    train_config["method"] = method

    backbone_config = train_config["backbone"]
    backbone_config["name"] = backbone  # type: ignore
    if method == "conditional":
        # control_points is removed from the backbone config for this method
        backbone_config.pop("control_points", None)  # type: ignore
    backbone_config.update(backbone_args[backbone])  # type: ignore

    model_config = {
        "name": method,  # TODO we store method twice
        "moving_image_size": moving_image_size,
        "fixed_image_size": fixed_image_size,
        "index_size": index_size,
        "labeled": labeled,
        "batch_size": batch_size,
        "config": train_config,
    }
    return REGISTRY.build_model(config=model_config)  # type: ignore
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone receives the concatenated moving image, fixed image and
    moving label, and its single sigmoid output channel is returned
    as pred_fixed_label. Labels are required.

    :return: a tf.keras.Model built from self._inputs and self._outputs
    """
    assert self.labeled

    # build inputs
    self._inputs = self.build_inputs()
    inputs = self._inputs
    moving_image = inputs["moving_image"]
    fixed_image = inputs["fixed_image"]
    moving_label = inputs["moving_label"]

    # build backbone
    backbone_inputs = self.concat_images(moving_image, fixed_image, moving_label)
    backbone_defaults = {
        "image_size": self.fixed_image_size,
        "out_channels": 1,
        "out_kernel_initializer": "glorot_uniform",
        "out_activation": "sigmoid",
    }
    backbone = REGISTRY.build_backbone(
        config=self.config["backbone"],
        default_args=backbone_defaults,
    )

    # the channel axis is removed, giving (batch, f_dim1, f_dim2, f_dim3)
    backbone_output = backbone(inputs=backbone_inputs)
    pred_fixed_label = tf.squeeze(backbone_output, axis=4)

    self._outputs = {"pred_fixed_label": pred_fixed_label}
    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def _build_loss(self, name: str, inputs_dict: dict):
    """
    Build and add one weighted loss together with the metrics.

    The entry self.config["loss"][name] may be a single loss config or a list
    of loss configs. Each config is handled independently: a config with zero
    weight is skipped without affecting the other configs in the list.

    :param name: name of loss, image / label / regularization.
    :param inputs_dict: inputs for loss function
    """
    if name not in self.config["loss"]:
        # loss config is not defined
        logger.warning(
            f"The configuration for loss {name} is not defined. "
            f"Therefore it is not used."
        )
        return

    loss_configs = self.config["loss"][name]
    if not isinstance(loss_configs, list):
        loss_configs = [loss_configs]

    for loss_config in loss_configs:
        if "weight" not in loss_config:
            # default loss weight 1, written back into the config
            logger.warning(
                f"The weight for loss {name} is not defined. "
                f"Default weight = 1.0 is used."
            )
            loss_config["weight"] = 1.0

        weight = loss_config["weight"]
        if weight == 0:
            logger.warning(
                f"The weight for loss {name} is zero. Loss is not used."
            )
            # skip this config only, later configs in the list are still built
            continue

        # do not perform reduction over batch axis for supporting multi-device
        # training, model.fit() will average over global batch size automatically
        loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight"),
            default_args={"reduction": tf.keras.losses.Reduction.NONE},
        )
        loss_value = loss_layer(**inputs_dict)
        weighted_loss = loss_value * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metrics for both the raw and the weighted value
        self._model.add_metric(
            loss_value,
            name=f"loss/{name}_{loss_layer.name}",
            aggregation="mean",
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_layer.name}_weighted",
            aggregation="mean",
        )
def get_data_loader(data_config: dict, mode: str) -> Optional[DataLoader]:
    """
    Return the corresponding data loader.

    Can't be placed in the same file of loader interfaces as it causes import cycle.

    :param data_config: a dictionary containing configuration for data
    :param mode: string, must be train/valid/test
    :return: DataLoader or None, returns None if the data_dir_paths is empty
    :raises ValueError: if a configured path is not an existing directory
    """
    assert mode in ["train", "valid", "test"], "mode must be one of train/valid/test"

    data_dir_paths = data_config["dir"].get(mode, None)
    # None, "" and an empty list all mean that no data is configured
    if not data_dir_paths:
        return None

    if isinstance(data_dir_paths, str):
        data_dir_paths = [data_dir_paths]
    # replace ~ with user home path
    data_dir_paths = [os.path.expanduser(path) for path in data_dir_paths]
    for data_dir_path in data_dir_paths:
        if not os.path.isdir(data_dir_path):
            raise ValueError(
                f"Data directory path {data_dir_path} for mode {mode}"
                f" is not a directory or does not exist"
            )

    # prepare data loader config
    # dir and format are passed through default_args instead
    data_loader_config = deepcopy(data_config)
    data_loader_config.pop("dir")
    data_loader_config.pop("format")
    # the registry expects the loader type under the key "name"
    data_loader_config["name"] = data_loader_config.pop("type")

    default_args = dict(
        data_dir_paths=data_dir_paths,
        file_loader=REGISTRY.get(
            category=FILE_LOADER_CLASS, key=data_config["format"]
        ),
        labeled=data_config["labeled"],
        sample_label="sample" if mode == "train" else "all",
        seed=None if mode == "train" else 0,
    )
    data_loader = REGISTRY.build_data_loader(
        config=data_loader_config, default_args=default_args
    )
    return data_loader
def test_data_loader(self, data_type: str, format: str):
    """
    Check that the data loader built from a test config has the registered class.

    :param data_type: name of data loader for registry
    :param format: name of file loader for registry
    """
    # single paired data loader
    config_path = f"config/test/{data_type}_{format}.yaml"
    dataset_config = load_yaml(config_path)["dataset"]

    data_loader = load.get_data_loader(data_config=dataset_config, mode="train")
    loader_cls = REGISTRY.get(category=DATA_LOADER_CLASS, key=data_type)
    assert isinstance(data_loader, loader_cls)  # type: ignore
def _build_loss(self, name: str, inputs_dict: dict):
    """
    Build and add one weighted loss together with the metrics.

    The entry self.config["loss"][name] may be a single loss config or a list
    of loss configs. Each config is handled independently: a config with zero
    weight is skipped without affecting the other configs in the list.

    :param name: name of loss
    :param inputs_dict: inputs for loss function
    """
    if name not in self.config["loss"]:
        # loss config is not defined
        logging.warning(
            f"The configuration for loss {name} is not defined. "
            f"Therefore it is not used."
        )
        return

    loss_configs = self.config["loss"][name]
    if not isinstance(loss_configs, list):
        loss_configs = [loss_configs]

    for loss_config in loss_configs:
        if "weight" not in loss_config:
            # default loss weight 1, written back into the config
            logging.warning(
                f"The weight for loss {name} is not defined. "
                f"Default weight = 1.0 is used."
            )
            loss_config["weight"] = 1.0

        weight = loss_config["weight"]
        if weight == 0:
            logging.warning(
                f"The weight for loss {name} is zero. Loss is not used."
            )
            # skip this config only, later configs in the list are still built
            continue

        loss_layer: tf.keras.layers.Layer = REGISTRY.build_loss(
            config=dict_without(d=loss_config, key="weight")
        )
        # the loss value is divided by the global batch size
        loss_value = loss_layer(**inputs_dict) / self.global_batch_size
        weighted_loss = loss_value * weight

        # add loss
        self._model.add_loss(weighted_loss)

        # add metrics for both the raw and the weighted value
        self._model.add_metric(
            loss_value,
            name=f"loss/{name}_{loss_layer.name}",
            aggregation="mean",
        )
        self._model.add_metric(
            weighted_loss,
            name=f"loss/{name}_{loss_layer.name}_weighted",
            aggregation="mean",
        )
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone predicts a ddf (and theta for GlobalNet) which is used to
    warp the moving image, and the moving label when the data is labeled.

    :return: a tf.keras.Model built from self._inputs and self._outputs
    """
    # build inputs
    self._inputs = self.build_inputs()
    moving_image = self._inputs["moving_image"]  # (batch, m_dim1, m_dim2, m_dim3)
    fixed_image = self._inputs["fixed_image"]  # (batch, f_dim1, f_dim2, f_dim3)

    # build ddf
    # control_points is not a backbone argument, it is taken out of a copy
    # so that self.config is left unchanged
    backbone_config = dict(self.config["backbone"])
    control_points = backbone_config.pop("control_points", False)

    backbone_inputs = self.concat_images(moving_image, fixed_image)
    backbone = REGISTRY.build_backbone(
        config=backbone_config,
        default_args=dict(
            image_size=self.fixed_image_size,
            out_channels=3,
            out_kernel_initializer="zeros",
            out_activation=None,
        ),
    )

    if isinstance(backbone, GlobalNet):
        # (f_dim1, f_dim2, f_dim3, 3), (4, 3)
        ddf, theta = backbone(inputs=backbone_inputs)
        self._outputs = dict(ddf=ddf, theta=theta)
    else:
        # (f_dim1, f_dim2, f_dim3, 3)
        ddf = backbone(inputs=backbone_inputs)
        if control_points:
            ddf = self._resize_interpolate(ddf, control_points)
        self._outputs = dict(ddf=ddf)

    # build outputs
    warping = layer.Warping(fixed_image_size=self.fixed_image_size)
    # (f_dim1, f_dim2, f_dim3)
    pred_fixed_image = warping(inputs=[ddf, moving_image])
    self._outputs["pred_fixed_image"] = pred_fixed_image

    if not self.labeled:
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    # (f_dim1, f_dim2, f_dim3)
    moving_label = self._inputs["moving_label"]
    pred_fixed_label = warping(inputs=[ddf, moving_label])
    self._outputs["pred_fixed_label"] = pred_fixed_label
    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def build_model(self):
    """
    Build the model to be saved as self._model.

    The backbone predicts a dvf which is converted into a ddf by layer.IntDVF.
    The ddf is used to warp the moving image, and the moving label when the
    data is labeled.

    :return: a tf.keras.Model built from self._inputs and self._outputs
    """
    # build inputs
    self._inputs = self.build_inputs()
    moving_image = self._inputs["moving_image"]
    fixed_image = self._inputs["fixed_image"]

    # control_points is not a backbone argument, it is taken out of a copy
    # so that self.config is left unchanged
    backbone_config = dict(self.config["backbone"])
    control_points = backbone_config.pop("control_points", False)

    # build ddf
    backbone_inputs = self.concat_images(moving_image, fixed_image)
    backbone = REGISTRY.build_backbone(
        config=backbone_config,
        default_args=dict(
            image_size=self.fixed_image_size,
            out_channels=3,
            out_kernel_initializer="zeros",
            out_activation=None,
        ),
    )
    dvf = backbone(inputs=backbone_inputs)
    if control_points:
        dvf = self._resize_interpolate(dvf, control_points)
    ddf = layer.IntDVF(fixed_image_size=self.fixed_image_size)(dvf)

    # build outputs
    self._warping = layer.Warping(fixed_image_size=self.fixed_image_size)
    # presumably (f_dim1, f_dim2, f_dim3) per sample - TODO confirm
    pred_fixed_image = self._warping(inputs=[ddf, moving_image])
    self._outputs = dict(dvf=dvf, ddf=ddf, pred_fixed_image=pred_fixed_image)

    if not self.labeled:
        return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)

    # presumably (f_dim1, f_dim2, f_dim3) per sample - TODO confirm
    moving_label = self._inputs["moving_label"]
    pred_fixed_label = self._warping(inputs=[ddf, moving_label])
    self._outputs["pred_fixed_label"] = pred_fixed_label
    return tf.keras.Model(inputs=self._inputs, outputs=self._outputs)
def train_step(grid, weights, optimizer, mov, fix) -> object:
    """
    Run one optimisation step on the affine parameters using a gradient tape.

    :param grid: reference grid return from layer_util.get_reference_grid
    :param weights: trainable affine parameters [1, 4, 3]
    :param optimizer: tf.optimizers
    :param mov: moving image [1, m_dim1, m_dim2, m_dim3]
    :param fix: fixed image [1, f_dim1, f_dim2, f_dim3]
    :return loss: image dissimilarity to minimise
    """
    trainable_vars = [weights]
    with tf.GradientTape() as tape:
        # transform the grid with the affine parameters, then resample
        warped_grid = layer_util.warp_grid(grid, weights)
        pred = layer_util.resample(vol=mov, loc=warped_grid)

        loss_fn = REGISTRY.build_loss(config=image_loss_config)
        loss = loss_fn(y_true=fix, y_pred=pred)

    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(zip(grads, trainable_vars))
    return loss
def predict(
    gpu: str,
    ckpt_path: str,
    split: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x.
    :param split: train / valid / test, to define the split to be evaluated.
    :param batch_size: int, batch size to perform predictions.
    :param exp_name: name of the experiment.
    :param config_path: to overwrite the default config.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to "true",
        otherwise to "false".
    :param save_nifti: if true, outputs will be saved in nifti format.
    :param save_png: if true, outputs will be saved in png format.
    :param log_dir: path of the log directory.
    :raises ValueError: if multiple GPUs are used and batch_size
        is not divisible by the number of devices.
    """
    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
        "true" if gpu_allow_growth else "false"
    )
    # only a positive num_workers is a limit, 0 or negative means not limited
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the prediction. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
    )
    # the requested batch size overrides the one from the config
    config["train"]["preprocess"]["batch_size"] = batch_size

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split=split,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader.moving_image_shape,
                fixed_image_size=data_loader.fixed_image_shape,
                index_size=data_loader.num_indices,
                labeled=config["dataset"][split]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"]
        )
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        # called with ckpt_path for loading, the returned values are not used
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0,
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
def test_register_err(self, category, key, err_msg):
    """
    Check that an invalid registration raises ValueError with the given message.

    :param category: category passed to REGISTRY.register
    :param key: name passed to REGISTRY.register
    :param err_msg: text expected inside the error message
    """
    with pytest.raises(ValueError) as err_info:
        REGISTRY.register(category=category, name=key, cls=0)
    message = str(err_info.value)
    assert err_msg in message
def test_register(self):
    """Check a registered value is stored in the dict and returned by get."""
    category = "backbone_class"
    key = "test_key"
    value = 0

    REGISTRY.register(category=category, name=key, cls=value)

    stored = REGISTRY._dict[(category, key)]
    assert stored == value
    fetched = REGISTRY.get(category, key)
    assert fetched == value
def test_get_err(self):
    """Check that getting an unknown key raises ValueError."""
    with pytest.raises(ValueError) as err_info:
        REGISTRY.get("backbone_class", "wrong_key")
    message = str(err_info.value)
    assert "has not been registered" in message
def test_get_backbone(self):
    """Check the unet backbone can be fetched from the registry."""
    # the lookup passing without error means the unet has been registered
    category, key = "backbone_class", "unet"
    REGISTRY.get(category, key)
def reg(self):
    """Return the result of REGISTRY.copy()."""
    registry_copy = REGISTRY.copy()
    return registry_copy
def test_register_backbone(self):
    """Check a backbone registered by name is returned by get_backbone."""
    key, value = "new_backbone", 0
    REGISTRY.register_backbone(name=key, cls=value)
    assert REGISTRY.get_backbone(key) == value
def predict(
    gpu: str,
    gpu_allow_growth: bool,
    ckpt_path: str,
    mode: str,
    batch_size: int,
    exp_name: str,
    config_path: Union[str, List[str]],
    save_nifti: bool = True,
    save_png: bool = True,
    log_dir: str = "logs",
):
    """
    Function to predict some metrics from the saved model and logging results.

    :param gpu: which env gpu to use.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to "true",
        otherwise to "false"
    :param ckpt_path: where model is stored, should be like log_folder/save/ckpt-x
    :param mode: train / valid / test, to define which split of dataset to be evaluated
    :param batch_size: int, batch size per GPU to perform predictions in
    :param exp_name: name of the experiment
    :param config_path: to overwrite the default config
    :param save_nifti: if true, outputs will be saved in nifti format
    :param save_png: if true, outputs will be saved in png format
    :param log_dir: path of the log directory
    """
    # TODO support custom sample_label
    logging.warning(
        "sample_label is not used in predict. "
        "It is True if and only if mode == 'train'."
    )

    # env vars
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
        "true" if gpu_allow_growth else "false"
    )

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
    )
    preprocess_config = config["train"]["preprocess"]
    # batch_size corresponds to batch_size per GPU
    gpus = tf.config.experimental.list_physical_devices("GPU")
    preprocess_config["batch_size"] = batch_size * max(len(gpus), 1)

    # data
    data_loader, dataset, _ = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=preprocess_config,
        mode=mode,
        training=False,
        repeat=False,
    )
    assert data_loader is not None

    # optimizer
    optimizer = opt.build_optimizer(optimizer_config=config["train"]["optimizer"])

    # model
    model: tf.keras.Model = REGISTRY.build_model(
        config=dict(
            name=config["train"]["method"],
            moving_image_size=data_loader.moving_image_shape,
            fixed_image_size=data_loader.fixed_image_shape,
            index_size=data_loader.num_indices,
            labeled=config["dataset"]["labeled"],
            # this is the total batch size set above, not the per GPU value
            batch_size=config["train"]["preprocess"]["batch_size"],
            config=config["train"],
        )
    )

    # metrics
    model.compile(optimizer=optimizer)

    # load weights
    if ckpt_path.endswith(".ckpt"):
        # for ckpt from tf.keras.callbacks.ModelCheckpoint
        # skip warnings because of optimizers
        # https://stackoverflow.com/questions/58289342/tf2-0-translation-model-error-when-restoring-the-saved-model-unresolved-object
        model.load_weights(ckpt_path).expect_partial()  # pragma: no cover
    else:
        # for ckpts from ckpt manager callback
        # called with ckpt_path for loading, the returned values are not used
        _, _ = build_checkpoint_callback(
            model=model,
            dataset=dataset,
            log_dir=log_dir,
            save_period=config["train"]["save_period"],
            ckpt_path=ckpt_path,
        )

    # predict
    fixed_grid_ref = tf.expand_dims(
        layer_util.get_reference_grid(grid_size=data_loader.fixed_image_shape),
        axis=0,
    )  # shape = (1, f_dim1, f_dim2, f_dim3, 3)
    predict_on_dataset(
        dataset=dataset,
        fixed_grid_ref=fixed_grid_ref,
        model=model,
        model_method=config["train"]["method"],
        save_dir=os.path.join(log_dir, "test"),
        save_nifti=save_nifti,
        save_png=save_png,
    )

    # close the opened files in data loaders
    data_loader.close()
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    gpu_allow_growth: bool,
    ckpt_path: str,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Train a registration model and close the data loaders afterwards.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to "true",
        otherwise to "false".
    :param ckpt_path: checkpoint path, passed to build_config
        and build_checkpoint_callback.
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    """
    # set env variables
    allow_growth = "true" if gpu_allow_growth else "false"
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = allow_growth

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )
    train_config = config["train"]
    dataset_config = config["dataset"]
    preprocess_config = train_config["preprocess"]

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=dataset_config,
        preprocess_config=preprocess_config,
        mode="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    if num_devices <= 1:
        strategy = tf.distribute.get_strategy()
    else:
        strategy = tf.distribute.MirroredStrategy()  # pragma: no cover
    with strategy.scope():
        model_config = {
            "name": train_config["method"],
            "moving_image_size": data_loader_train.moving_image_shape,
            "fixed_image_size": data_loader_train.fixed_image_shape,
            "index_size": data_loader_train.num_indices,
            "labeled": dataset_config["labeled"],
            "batch_size": preprocess_config["batch_size"],
            "config": train_config,
            "num_devices": num_devices,
        }
        model: tf.keras.Model = REGISTRY.build_model(config=model_config)
        optimizer = opt.build_optimizer(optimizer_config=train_config["optimizer"])

        # compile
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    save_period = train_config["save_period"]
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir, histogram_freq=save_period
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=save_period,
        ckpt_path=ckpt_path,
    )

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=train_config["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=[tensorboard_callback, ckpt_callback],
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
def train(
    gpu: str,
    config_path: Union[str, List[str]],
    ckpt_path: str,
    num_workers: int = 1,
    gpu_allow_growth: bool = True,
    exp_name: str = "",
    log_dir: str = "logs",
    max_epochs: int = -1,
):
    """
    Function to train a model.

    :param gpu: which local gpu to use to train.
    :param config_path: path to configuration set up.
    :param ckpt_path: checkpoint path, passed to build_config
        and build_checkpoint_callback.
    :param num_workers: number of cpu cores to be used, <=0 means not limited.
    :param gpu_allow_growth: if true, TF_FORCE_GPU_ALLOW_GROWTH is set to "true",
        otherwise to "false".
    :param exp_name: experiment name.
    :param log_dir: path of the log directory.
    :param max_epochs: if max_epochs > 0, will use it to overwrite the configuration.
    :raises ValueError: if multiple GPUs are used and the batch size
        is not divisible by the number of devices.
    """
    # set env variables
    os.environ["CUDA_VISIBLE_DEVICES"] = gpu
    os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = (
        "true" if gpu_allow_growth else "false"
    )
    # only a positive num_workers is a limit, 0 or negative means not limited
    if num_workers > 0:
        logger.info(
            "Limiting CPU usage by setting environment variables "
            "OMP_NUM_THREADS, TF_NUM_INTRAOP_THREADS, TF_NUM_INTEROP_THREADS to %d. "
            "This may slow down the training. "
            "Please use --num_workers flag to modify the behavior. "
            "Setting to 0 or negative values will remove the limitation.",
            num_workers,
        )
        # limit CPU usage
        # https://github.com/tensorflow/tensorflow/issues/29968#issuecomment-789604232
        os.environ["OMP_NUM_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTRAOP_THREADS"] = str(num_workers)
        os.environ["TF_NUM_INTEROP_THREADS"] = str(num_workers)

    # load config
    config, log_dir, ckpt_path = build_config(
        config_path=config_path,
        log_dir=log_dir,
        exp_name=exp_name,
        ckpt_path=ckpt_path,
        max_epochs=max_epochs,
    )

    # build dataset
    data_loader_train, dataset_train, steps_per_epoch_train = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="train",
        training=True,
        repeat=True,
    )
    assert data_loader_train is not None  # train data should not be None
    data_loader_val, dataset_val, steps_per_epoch_val = build_dataset(
        dataset_config=config["dataset"],
        preprocess_config=config["train"]["preprocess"],
        split="valid",
        training=False,
        repeat=True,
    )

    # use strategy to support multiple GPUs
    # the network is mirrored in each GPU so that we can use larger batch size
    # https://www.tensorflow.org/guide/distributed_training
    # only model, optimizer and metrics need to be defined inside the strategy
    num_devices = max(len(tf.config.list_physical_devices("GPU")), 1)
    batch_size = config["train"]["preprocess"]["batch_size"]
    if num_devices > 1:  # pragma: no cover
        strategy = tf.distribute.MirroredStrategy()
        if batch_size % num_devices != 0:
            raise ValueError(
                f"batch size {batch_size} can not be divided evenly "
                f"by the number of devices."
            )
    else:
        strategy = tf.distribute.get_strategy()
    with strategy.scope():
        model: tf.keras.Model = REGISTRY.build_model(
            config=dict(
                name=config["train"]["method"],
                moving_image_size=data_loader_train.moving_image_shape,
                fixed_image_size=data_loader_train.fixed_image_shape,
                index_size=data_loader_train.num_indices,
                labeled=config["dataset"]["train"]["labeled"],
                batch_size=batch_size,
                config=config["train"],
            )
        )
        optimizer = opt.build_optimizer(
            optimizer_config=config["train"]["optimizer"]
        )
        model.compile(optimizer=optimizer)
        model.plot_model(output_dir=log_dir)

    # build callbacks
    tensorboard_callback = tf.keras.callbacks.TensorBoard(
        log_dir=log_dir,
        histogram_freq=config["train"]["save_period"],
        update_freq=config["train"].get("update_freq", "epoch"),
    )
    ckpt_callback, initial_epoch = build_checkpoint_callback(
        model=model,
        dataset=dataset_train,
        log_dir=log_dir,
        save_period=config["train"]["save_period"],
        ckpt_path=ckpt_path,
    )
    callbacks = [tensorboard_callback, ckpt_callback]

    # train
    # it's necessary to define the steps_per_epoch
    # and validation_steps to prevent errors like
    # BaseCollectiveExecutor::StartAbort Out of range: End of sequence
    model.fit(
        x=dataset_train,
        steps_per_epoch=steps_per_epoch_train,
        initial_epoch=initial_epoch,
        epochs=config["train"]["epochs"],
        validation_data=dataset_val,
        validation_steps=steps_per_epoch_val,
        callbacks=callbacks,
    )

    # close file loaders in data loaders after training
    data_loader_train.close()
    if data_loader_val is not None:
        data_loader_val.close()
def get_dataset_and_preprocess(
    self,
    training: bool,
    batch_size: int,
    repeat: bool,
    shuffle_buffer_num_batch: int,
    data_augmentation: Optional[Union[List, Dict]] = None,
    num_parallel_calls: int = tf.data.experimental.AUTOTUNE,
) -> tf.data.Dataset:
    """
    Generate tf.data.dataset.

    The pipeline is: resize, shuffle, repeat, batch, prefetch, augment.

    Reference:

        - https://www.tensorflow.org/guide/data_performance#parallelizing_data_transformation
        - https://www.tensorflow.org/api_docs/python/tf/data/Dataset

    :param training: indicating if it's training or not
    :param batch_size: size of mini batch
    :param repeat: indicating if we need to repeat the dataset
    :param shuffle_buffer_num_batch: when shuffling,
        the shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
    :param data_augmentation: augmentation config, can be a list of dict or dict.
    :param num_parallel_calls: number elements to process asynchronously in parallel
        during preprocessing, -1 means unlimited, heuristically it should be set to
        the number of CPU cores available. AUTOTUNE=-1 means not limited.
    :returns dataset:
    """

    def _resize(sample):
        # resize every sample to the loader's moving / fixed image shapes
        return resize_inputs(
            inputs=sample,
            moving_image_size=self.moving_image_shape,
            fixed_image_size=self.fixed_image_shape,
        )

    dataset = self.get_dataset()
    dataset = dataset.map(_resize, num_parallel_calls=num_parallel_calls)

    # shuffling is only applied for training with a positive buffer
    if training and shuffle_buffer_num_batch > 0:
        shuffle_buffer_size = batch_size * shuffle_buffer_num_batch
        dataset = dataset.shuffle(
            buffer_size=shuffle_buffer_size,
            reshuffle_each_iteration=True,
        )
    if repeat:
        dataset = dataset.repeat()
    # incomplete batches are dropped for training only
    dataset = dataset.batch(batch_size=batch_size, drop_remainder=training)
    # prefetch always uses AUTOTUNE, independent of num_parallel_calls
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

    if not training or data_augmentation is None:
        return dataset

    # augmentation is applied on batched data, for training only
    if isinstance(data_augmentation, dict):
        augmentation_configs = [data_augmentation]
    else:
        augmentation_configs = data_augmentation
    for augmentation_config in augmentation_configs:
        da_fn = REGISTRY.build_data_augmentation(
            config=augmentation_config,
            default_args={
                "moving_image_size": self.moving_image_shape,
                "fixed_image_size": self.fixed_image_shape,
                "batch_size": batch_size,
            },
        )
        dataset = dataset.map(da_fn, num_parallel_calls=num_parallel_calls)
    return dataset