# Example #1
from data_loading.data_module import DataModule
from models.nn_unet import NNUnet
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import get_main_args, is_main_process, log, make_empty_dir, set_cuda_devices, verify_ckpt_path

if __name__ == "__main__":
    # Script entry point: parse options, set up devices/seeding/data, then
    # configure callbacks for the selected run mode.
    # Presumably parses the command-line options — see utils.utils.get_main_args.
    args = get_main_args()

    # Call set_affinity unless args.affinity is "disabled". LOCAL_RANK is read
    # from the environment as a string and defaults to "0" when unset.
    # NOTE(review): `affinity` is only bound when affinity is enabled, so any
    # later use must handle the "disabled" case; also confirm set_affinity
    # accepts a str rank rather than an int.
    if args.affinity != "disabled":
        affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), args.affinity)

    # Device selection and seeding happen before the data module is built.
    # NOTE(review): `os` and `seed_everything` are used in this block but are
    # not among the visible imports — confirm they are imported elsewhere.
    set_cuda_devices(args)
    seed_everything(args.seed)
    # Build the data module, then run its prepare_data() and setup() hooks in order.
    data_module = DataModule(args)
    data_module.prepare_data()
    data_module.setup()
    # Presumably validates/resolves the checkpoint path held in args — confirm.
    ckpt_path = verify_ckpt_path(args)

    # Defaults; replaced below depending on the run mode.
    callbacks = None
    model_ckpt = None
    if args.benchmark:
        # Benchmark mode: build the model and a logging callback.
        model = NNUnet(args)
        # Training uses args.batch_size; any other exec_mode uses args.val_batch_size.
        batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
        # Despite its name, `log_dir` looks like a file path: args.results joined
        # with args.logname, or with "perf.json" when no log name was given.
        log_dir = os.path.join(args.results, args.logname if args.logname is not None else "perf.json")
        # The global batch size is the mode-specific batch size times args.gpus
        # (assumes args.gpus is an integer device count — confirm).
        callbacks = [
            LoggingCallback(
                log_dir=log_dir,
                global_batch_size=batch_size * args.gpus,
                mode=args.exec_mode,
# Example #2
        choices=["pre", "post"],
        help="Type of task to run; pre - localization, post - damage assesment",
    )
    arg("--seed", type=int, default=1)

    # Let the model class extend the parser with its own options; it returns
    # the parser that is then used to parse the command line.
    parser = Model.add_model_specific_args(parser)
    args = parser.parse_args()

    # When args.interpolate is set, force deep_supervision and dec_interp off.
    if args.interpolate:
        args.deep_supervision = args.dec_interp = False

    # Order matters here: device selection, then set_affinity (LOCAL_RANK is
    # read as a string, "0" when unset), then seeding, then the data module.
    set_cuda_devices(args.gpus)
    affinity = set_affinity(os.getenv("LOCAL_RANK", "0"), "socket_unique_interleaved")
    seed_everything(args.seed)
    data_module = DataModule(args)

    callbacks = None
    # Use args.ckpt as the checkpoint only when it names an existing path;
    # otherwise `checkpoint` stays None. A path that was explicitly given but
    # does not exist is reported with a warning rather than ignored silently,
    # so a mistyped checkpoint path is visible to the user.
    checkpoint = None
    if args.ckpt is not None:
        if os.path.exists(args.ckpt):
            checkpoint = args.ckpt
        else:
            import warnings

            warnings.warn(
                f"Checkpoint path {args.ckpt!r} does not exist; ignoring it."
            )
    if args.exec_mode == "train":
        # Training run: build a fresh model from the parsed arguments.
        model = Model(args)
        # Keep the checkpoint with the highest "f1_score" and, via
        # save_last=True, also the most recent one (assuming PyTorch
        # Lightning's ModelCheckpoint semantics — confirm the import).
        # NOTE(review): "f1_score" must match a metric name the model logs.
        model_ckpt = ModelCheckpoint(monitor="f1_score",
                                     mode="max",
                                     save_last=True)
        # Stop early once "f1_score" stops increasing; args.patience controls
        # how long to wait (units per Lightning's EarlyStopping — confirm).
        callbacks = [
            EarlyStopping(monitor="f1_score",
                          patience=args.patience,
                          verbose=True,
                          mode="max")
        ]