def main(): args = parse_args() if args.conf_path is None: train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH)) conf_path = TRAIN_CONF_PATH else: train_conf = ConfigFactory.parse_file(str(Path(args.conf_path))) conf_path = Path(args.conf_path) if get_conf(train_conf, group="experiment", key="experiment_dir") is not None: experiment_dir = Path(get_conf(train_conf, group="experiment", key="experiment_dir")) else: experiment_dir = None config = ExperimentConfig( experiment_dir=experiment_dir, batch_size=get_conf(train_conf, group="experiment", key="batch_size"), num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), num_workers=get_conf(train_conf, group="experiment", key="num_workers"), pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"), n_inference=get_conf(train_conf, group="experiment", key="n_inference"), seed=get_conf(train_conf, group="experiment", key="seed"), ) augmentation_config = AugmentationConfig.from_conf(conf_path) shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf"))) network = DefSegNet( in_channels=get_conf(train_conf, group="network", key="in_channels"), n_classes=get_conf(train_conf, group="network", key="n_classes"), n_filters=get_conf(train_conf, group="network", key="n_filters"), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), int_downsize=get_conf(train_conf, group="network", key="integrate_downsize"), bidir=True, ) checkpoint = torch.load( str(get_conf(train_conf, group="network", key="checkpoint_path")), map_location=torch.device(get_conf(train_conf, group="experiment", key="device")) ) network.load_state_dict(checkpoint) network.cuda(device=get_conf(train_conf, group="experiment", key="device")) network.freeze_vxm() training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset( DataConfig.from_conf(conf_path), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), is_3d=True, seed=config.seed, augmentation_config=augmentation_config, output_dir=config.experiment_dir, template_path=Path(get_conf(train_conf, group="network", key="template_path")), template_image_path=Path(get_conf(train_conf, group="network", key="template_image_path")), ) for train in training_sets: train.export(config.experiment_dir.joinpath("training_set_{}.csv".format(train.name))) for val in validation_sets: val.export(config.experiment_dir.joinpath("validation_set_{}.csv".format(val.name))) for val in extra_validation_sets: val.export(config.experiment_dir.joinpath("extra_validation_set_{}.csv".format(val.name))) if get_conf(train_conf, group="optimizer", key="type") == "SGD": optimizer = torch.optim.SGD( network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate"), momentum=get_conf(train_conf, group="optimizer", key="momentum"), ) else: optimizer = torch.optim.Adam( network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate") ) loss = UDADefSegLoss( penalty="l2", loss_mult=get_conf(train_conf, group="network", key="integrate_downsize"), weights=get_conf(train_conf, group="loss", key="weights"), ) experiment = DefSegExperiment( config=config, network=network, training_sets=training_sets, validation_sets=validation_sets, extra_validation_sets=extra_validation_sets, optimizer=optimizer, loss=loss, other_validation_metrics=[ DefSegWarpedTemplateDice(), DefSegPredDice(), DefSegWarpedMapsDice() ], inference_func=inference ) experiment.train()
def main(): args = parse_args() if args.conf_path is None: train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH)) conf_path = TRAIN_CONF_PATH else: train_conf = ConfigFactory.parse_file(str(Path(args.conf_path))) conf_path = Path(args.conf_path) if get_conf(train_conf, group="experiment", key="experiment_dir") is not None: experiment_dir = Path( get_conf(train_conf, group="experiment", key="experiment_dir")) else: experiment_dir = None config = ExperimentConfig( experiment_dir=experiment_dir, batch_size=get_conf(train_conf, group="experiment", key="batch_size"), num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), num_workers=get_conf(train_conf, group="experiment", key="num_workers"), pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"), n_inference=get_conf(train_conf, group="experiment", key="n_inference"), seed=get_conf(train_conf, group="experiment", key="seed"), ) augmentation_config = AugmentationConfig.from_conf(conf_path) shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf"))) network = UNet( in_channels=get_conf(train_conf, group="network", key="in_channels"), n_classes=get_conf(train_conf, group="network", key="n_classes"), n_filters=get_conf(train_conf, group="network", key="n_filters"), ) training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset( DataConfig.from_conf(conf_path), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), is_3d=True, seed=config.seed, augmentation_config=augmentation_config, output_dir=config.experiment_dir, ) for train in training_sets: train.export( config.experiment_dir.joinpath("training_set_{}.csv".format( train.name))) for val in validation_sets: val.export( config.experiment_dir.joinpath("validation_set_{}.csv".format( val.name))) for val in extra_validation_sets: val.export( config.experiment_dir.joinpath( "extra_validation_set_{}.csv".format(val.name))) if get_conf(train_conf, group="optimizer", key="type") == "SGD": optimizer = torch.optim.SGD( network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate"), momentum=get_conf(train_conf, group="optimizer", key="momentum"), ) else: optimizer = torch.optim.Adam(network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate")) if get_conf(train_conf, group="loss", key="type") == "FocalLoss": loss = FocalLoss( alpha=get_conf(train_conf, group="loss", key="alpha"), gamma=get_conf(train_conf, group="loss", key="gamma"), logits=True, ) else: loss = BCELoss() experiment = FCN3DExperiment( config=config, network=network, training_sets=training_sets, validation_sets=validation_sets, extra_validation_sets=extra_validation_sets, optimizer=optimizer, loss=loss, other_validation_metrics=[DiceCoeffWithLogits()], inference_func=inference) experiment.train()
def main(): args = parse_args() if args.starting_epoch > 0: assert args.exp_dir is not None exp_dir = Path(args.exp_dir) if args.conf_path is None: train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH)) conf_path = TRAIN_CONF_PATH else: train_conf = ConfigFactory.parse_file(str(Path(args.conf_path))) conf_path = Path(args.conf_path) if get_conf(train_conf, group="experiment", key="experiment_dir") is not None: experiment_dir = Path( get_conf(train_conf, group="experiment", key="experiment_dir")) else: experiment_dir = None config = ExperimentConfig( experiment_dir=experiment_dir, batch_size=get_conf(train_conf, group="experiment", key="batch_size"), num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), num_workers=get_conf(train_conf, group="experiment", key="num_workers"), pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"), n_inference=get_conf(train_conf, group="experiment", key="n_inference"), seed=get_conf(train_conf, group="experiment", key="seed"), ) augmentation_config = AugmentationConfig.from_conf(conf_path) shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf"))) if get_conf(train_conf, group="network", key="name") == "encodedecode": network = ImgTemplateEncoderNet( feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), n_filters=get_conf(train_conf, group="network", key="n_filters"), batch_norm=get_conf(train_conf, group="network", key="batch_norm"), group_norm=get_conf(train_conf, group="network", key="group_norm"), int_downsize=get_conf(train_conf, group="network", key="integrate_downsize"), bidir=True, in_channels=get_conf(train_conf, group="network", key="in_channels"), n_classes=get_conf(train_conf, group="network", key="n_classes"), batch_size=get_conf(train_conf, group="experiment", key="batch_size"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), ) elif get_conf(train_conf, group="network", key="name") == "istn": network = ISTNNet( feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), n_filters=get_conf(train_conf, group="network", key="n_filters"), batch_norm=get_conf(train_conf, group="network", key="batch_norm"), group_norm=get_conf(train_conf, group="network", key="group_norm"), int_downsize=get_conf(train_conf, group="network", key="integrate_downsize"), bidir=True, in_channels=get_conf(train_conf, group="network", key="in_channels"), n_classes=get_conf(train_conf, group="network", key="n_classes"), batch_size=get_conf(train_conf, group="experiment", key="batch_size"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), ) else: raise ValueError("network name not supported. istn or encodedecode") if args.starting_epoch > 0: print( exp_dir.joinpath("checkpoints", "CP_{}.pth".format(args.starting_epoch - 1))) checkpoint = torch.load( str( exp_dir.joinpath("checkpoints", "CP_{}.pth".format(args.starting_epoch - 1))), map_location=torch.device( get_conf(train_conf, group="experiment", key="device")), ) network.load_state_dict(checkpoint) network.cuda( device=get_conf(train_conf, group="experiment", key="device")) image = DefSegDataset.read_image( image_path=exp_dir.joinpath( "atlas", "epoch_{}".format(args.starting_epoch - 1), "image.nii.gz"), feature_size=None, n_slices=None, ) label = DefSegDataset.read_label( label_path=exp_dir.joinpath( "atlas", "epoch_{}".format(args.starting_epoch - 1), "label.nii.gz"), feature_size=None, n_slices=None, ) atlas = Atlas(image=image, label=label) else: atlas = None training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset( DataConfig.from_conf(conf_path), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), is_3d=True, seed=config.seed, augmentation_config=augmentation_config, output_dir=config.experiment_dir, ) for train in training_sets: train.export( config.experiment_dir.joinpath("training_set_{}.csv".format( train.name))) for val in validation_sets: val.export( config.experiment_dir.joinpath("validation_set_{}.csv".format( val.name))) for val in extra_validation_sets: val.export( config.experiment_dir.joinpath( "extra_validation_set_{}.csv".format(val.name))) if get_conf(train_conf, group="optimizer", key="type") == "SGD": optimizer = torch.optim.SGD( network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate"), momentum=get_conf(train_conf, group="optimizer", key="momentum"), ) else: optimizer = torch.optim.Adam(network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate")) loss = DefLoss( penalty="l2", loss_mult=get_conf(train_conf, group="network", key="integrate_downsize"), weights=get_conf(train_conf, group="loss", key="weights"), ) experiment = DefSegExperiment( config=config, network=network, training_sets=training_sets, validation_sets=validation_sets, extra_validation_sets=extra_validation_sets, optimizer=optimizer, loss=loss, other_validation_metrics=[ DefAffineWarpedTemplateDice(), DefWarpedTemplateDice(), DefPredDice(), ], inference_func=inference, ) experiment.train( starting_epoch=args.starting_epoch, atlas=atlas, atlas_eta=get_conf(train_conf, group="network", key="atlas_eta"), )
def main(): args = parse_args() if args.conf_path is None: train_conf = ConfigFactory.parse_file(str(TRAIN_CONF_PATH)) conf_path = TRAIN_CONF_PATH else: train_conf = ConfigFactory.parse_file(str(Path(args.conf_path))) conf_path = Path(args.conf_path) if get_conf(train_conf, group="experiment", key="experiment_dir") is not None: experiment_dir = Path( get_conf(train_conf, group="experiment", key="experiment_dir")) else: experiment_dir = None config = ExperimentConfig( experiment_dir=experiment_dir, batch_size=get_conf(train_conf, group="experiment", key="batch_size"), num_epochs=get_conf(train_conf, group="experiment", key="num_epochs"), gpu=get_conf(train_conf, group="experiment", key="gpu"), device=get_conf(train_conf, group="experiment", key="device"), num_workers=get_conf(train_conf, group="experiment", key="num_workers"), pin_memory=get_conf(train_conf, group="experiment", key="pin_memory"), n_inference=get_conf(train_conf, group="experiment", key="n_inference"), seed=get_conf(train_conf, group="experiment", key="seed"), ) augmentation_config = AugmentationConfig.from_conf(conf_path) shutil.copy(str(conf_path), str(config.experiment_dir.joinpath("train.conf"))) atlas_path = "D:\surui\\rbh\cardiac\DL_segmentation\RBH_3D_atlases\IHD586_10RZ04876_RBH_IHD_201410241040_MRC25598\seg_lvsa_SR_ED.nii.gz" atlas = Torch2DSegmentationDataset.read_label( label_path=Path(atlas_path), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), ) network = DefSegNet( in_channels=get_conf(train_conf, group="network", key="in_channels"), n_classes=get_conf(train_conf, group="network", key="n_classes"), n_filters=get_conf(train_conf, group="network", key="n_filters"), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), template=atlas, int_downsize=2, ) training_sets, validation_sets, extra_validation_sets = construct_training_validation_dataset( DataConfig.from_conf(conf_path), feature_size=get_conf(train_conf, group="network", key="feature_size"), n_slices=get_conf(train_conf, group="network", key="n_slices"), is_3d=True, seed=config.seed, augmentation_config=augmentation_config, output_dir=config.experiment_dir, ) for train in training_sets: train.export( config.experiment_dir.joinpath("training_set_{}.csv".format( train.name))) for val in validation_sets: val.export( config.experiment_dir.joinpath("validation_set_{}.csv".format( val.name))) for val in extra_validation_sets: val.export( config.experiment_dir.joinpath( "extra_validation_set_{}.csv".format(val.name))) if get_conf(train_conf, group="optimizer", key="type") == "SGD": optimizer = torch.optim.SGD( network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate"), momentum=get_conf(train_conf, group="optimizer", key="momentum"), ) else: optimizer = torch.optim.Adam(network.parameters(), lr=get_conf(train_conf, group="optimizer", key="learning_rate")) if get_conf(train_conf, group="loss", key="type") == "FocalLoss": loss = FocalLoss( alpha=get_conf(train_conf, group="loss", key="alpha"), gamma=get_conf(train_conf, group="loss", key="gamma"), logits=True, ) else: loss = DefSegLoss(penalty="l2", loss_mult=2, template=atlas) experiment = Experiment(config=config, network=network, training_sets=training_sets, validation_sets=validation_sets, extra_validation_sets=extra_validation_sets, optimizer=optimizer, loss=loss, other_validation_metrics=[ DefSegWarpedDice(template=atlas), DefSegPredDice() ], inference_func=inference) experiment.train()