def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    """Train and test an FCN segmentation model on LIDC crops.

    A backbone chosen by ``cfg.backbone`` is built, converted according to
    ``cfg.conv`` ('ACSConv', 'Conv2_5d' or 'Conv3d'), its initial state
    dict is written to ``<save_path>/model.dat``, and it is then passed to
    ``train`` together with the LIDC train/test datasets.

    Args:
        save_path: Output directory. Created with ``os.makedirs``, so it
            must not exist yet (``FileExistsError`` otherwise).
        n_epochs: Number of epochs, forwarded to ``train``.
        seed: Seed passed to ``set_seed``; ``None`` skips seeding.
    """
    # Seed with the ``seed`` argument itself. The previous code called
    # set_seed(cfg.seed), silently ignoring an explicitly passed seed.
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True  # let cuDNN benchmark and pick conv algorithms

    # Back up the code; redirect_stdout presumably routes printed output
    # into save_path (TODO confirm).
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets: train/test splits only, no validation set.
    train_set = LIDCSegDataset(crop_size=48, move=5, data_path=env.data, train=True)
    valid_set = None
    test_set = LIDCSegDataset(crop_size=48, move=5, data_path=env.data, train=False)

    # Define model
    model_dict = {
        'resnet18': FCNResNet,
        'vgg16': FCNVGG,
        'densenet121': FCNDenseNet,
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained, num_classes=2,
                                     backbone=cfg.backbone)

    # Convert to counterparts and load pretrained weights according to the
    # requested convolution type (the cases are mutually exclusive).
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    elif cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    elif cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            # NOTE(review): these checkpoint paths are resnet18-specific even
            # when cfg.backbone selects a different architecture.
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)
    print(model)
    # Snapshot of the initial (pre-training) weights.
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))

    # train and test the model
    train(model=model, train_set=train_set, valid_set=valid_set, test_set=test_set,
          save=save_path, n_epochs=n_epochs)
    print('Done!')
def main(
        save_path=cfg.save,  # defaults come from the configuration module ``cfg``
        n_epochs=cfg.n_epochs,
        seed=cfg.seed,
        checkpoint_path=(
            '/cluster/home/it_stu167/wwj/classification_after_crop/result/'
            'CACClass/resnet18/ACSConv/48_1-2_m0/epoch_107/model.dat')):
    """Run a CAC two-class classifier restored from a checkpoint.

    The classifier is built from ``cfg.backbone``, converted according to
    ``cfg.conv``, and loaded with the weights stored at ``checkpoint_path``.
    It is then passed to ``train`` with only a test set (datatype=2). No
    ``train_set`` is given, so this presumably only evaluates; confirm
    against ``train``.

    Args:
        save_path: Output directory. Created with ``os.makedirs``, so it
            must not exist yet (``FileExistsError`` otherwise).
        n_epochs: Forwarded to ``train``.
        seed: Seed passed to ``set_seed``; ``None`` skips seeding.
        checkpoint_path: State-dict file loaded into the converted model.
            Defaults to the path that used to be hard-coded in the body.
    """
    # Seed with the ``seed`` argument itself. The previous code called
    # set_seed(cfg.seed), silently ignoring an explicitly passed seed.
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True  # improve efficiency

    # Back up the code; redirect_stdout presumably routes printed output
    # into save_path (TODO confirm).
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets: only the datatype=2 split is used.
    valid_set = None
    test_set = CACTwoClassDataset(crop_size=[48, 48, 48], data_path=env.data,
                                  datatype=2, fill_with=-1)

    # Define model
    model_dict = {
        'resnet18': ClsResNet,
        'vgg16': ClsVGG,
        'densenet121': ClsDenseNet,
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained, num_classes=2,
                                     backbone=cfg.backbone)

    # Convert to counterparts and load pretrained weights according to the
    # requested convolution type (the cases are mutually exclusive).
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    elif cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    elif cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)

    # NOTE(review): this writes the freshly built weights, i.e. before the
    # checkpoint below is loaded.
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
    # map_location='cpu' lets GPU-saved checkpoints load without a GPU;
    # load_state_dict copies the values onto the model's own parameters.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

    # train and test the model
    train(model=model, valid_set=valid_set, test_set=test_set,
          save=save_path, n_epochs=n_epochs)
    print('Done!')
def _load_shape_pretrained(model, checkpoint_path):
    """Load shape-pretrained weights into ``model`` with ``strict=False``.

    The last two entries of the loaded state dict are dropped before
    loading. They are presumably the final layer's weight and bias, whose
    shapes depend on the pretraining task (TODO confirm). The
    missing/unexpected keys reported by ``load_state_dict`` are printed.

    Returns:
        ``model``, updated in place.
    """
    # map_location='cpu' lets GPU-saved checkpoints load without a GPU;
    # load_state_dict copies the values onto the model's own parameters.
    state = torch.load(checkpoint_path, map_location='cpu')
    state.popitem()
    state.popitem()
    incompatible_keys = model.load_state_dict(state, strict=False)
    print('load shape pretrained weights\n', incompatible_keys)
    return model


def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    """Train and test a ``UNet(6)`` on voxel data, converted per ``cfg.conv``.

    ``cfg.conv`` selects the conversion: 'Conv3D' (freshly initialized),
    'Conv2_5D' or 'ACSConv'. For the last two, shape-pretrained weights
    from ``env.shape_checkpoint`` are loaded when ``cfg.pretrained`` is set.

    Args:
        save_path: Output directory. Created with ``os.makedirs``, so it
            must not exist yet (``FileExistsError`` otherwise).
        n_epochs: Number of epochs, forwarded to ``train``.
        seed: Seed passed to ``set_seed``; ``None`` skips seeding.

    Raises:
        ValueError: if ``cfg.conv`` is not one of the supported values.
    """
    # Seed with the ``seed`` argument itself. The previous code called
    # set_seed(cfg.seed), silently ignoring an explicitly passed seed.
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets. The second argument is presumably a sample count: the
    # config value for training, a fixed 200 for testing.
    train_set = BaseDatasetVoxel(env.data_train, cfg.train_samples)
    valid_set = None
    test_set = BaseDatasetVoxel(env.data_test, 200)

    # Models
    model = UNet(6)
    if cfg.conv == 'Conv3D':
        model = Conv3dConverter(model)
        initialize(model.modules())
    elif cfg.conv == 'Conv2_5D':
        # Pretrained weights go into the 2D UNet *before* conversion.
        if cfg.pretrained:
            _load_shape_pretrained(model, env.shape_checkpoint)
        model = Conv2_5dConverter(model)
    elif cfg.conv == 'ACSConv':
        # ``ACSUNet(6)`` is an alternative to ``ACSConverter(model)`` here.
        # Pretrained weights are loaded *after* conversion.
        model = ACSConverter(model)
        if cfg.pretrained:
            _load_shape_pretrained(model, env.shape_checkpoint)
    else:
        raise ValueError(f'not valid conv: {cfg.conv!r}')
    print(model)
    # Snapshot of the initial (pre-training) weights.
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))

    # Train the model
    train(model=model, train_set=train_set, valid_set=valid_set, test_set=test_set,
          save=save_path, n_epochs=n_epochs)
    print('Done!')
def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    """Train and test a three-class segmentation model on CAC crops.

    The train/valid/test datasets use ``datatype`` 0/1/2 respectively. A
    model is built from ``cfg.backbone`` (optionally from
    ``cfg.checkpoint``), converted according to ``cfg.conv``, its initial
    state dict is written to ``<save_path>/model.dat``, and it is then
    passed to ``train``.

    Args:
        save_path: Output directory. Created with ``os.makedirs``, so it
            must not exist yet (``FileExistsError`` otherwise).
        n_epochs: Number of epochs, forwarded to ``train``.
        seed: Seed passed to ``set_seed``; ``None`` skips seeding.
    """
    # Seed with the ``seed`` argument itself. The previous code called
    # set_seed(cfg.seed), silently ignoring an explicitly passed seed.
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True

    # Back up the code; redirect_stdout presumably routes printed output
    # into save_path (TODO confirm).
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    train_set = CACSegDataset(crop_size=[48, 48, 48], data_path=env.data,
                              random=cfg.random, datatype=0)
    valid_set = CACSegDataset(crop_size=[48, 48, 48], data_path=env.data,
                              random=cfg.random, datatype=1)
    test_set = CACSegDataset(crop_size=[48, 48, 48], data_path=env.data,
                             random=cfg.random, datatype=2)

    # Define model
    model_dict = {
        'resnet18': FCNResNet,
        'resnet34': FCNResNet,
        'resnet50': FCNResNet,
        'resnet101': FCNResNet,
        'vgg16': FCNVGG,
        'densenet121': FCNDenseNet,
        'unet': UNet,
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained, num_classes=3,
                                     backbone=cfg.backbone,
                                     checkpoint=cfg.checkpoint)

    # Convert to counterparts and load pretrained weights according to the
    # requested convolution type (the cases are mutually exclusive).
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    elif cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    elif cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            # NOTE(review): these checkpoint paths are resnet18-specific even
            # when cfg.backbone selects a different architecture.
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)

    # Snapshot of the initial (pre-training) weights.
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))

    # train and test the model
    train(model=model, train_set=train_set, valid_set=valid_set, test_set=test_set,
          save=save_path, n_epochs=n_epochs)
    print('Done!')
def main(
        save_path=cfg.save,  # defaults come from the configuration module ``cfg``
        n_epochs=cfg.n_epochs,
        seed=cfg.seed):
    """Train and test a two-class classifier on CAC crops.

    Training uses ``datatype=0`` and testing uses ``datatype=1``; no
    validation set is passed to ``train``. The classifier is built from
    ``cfg.backbone`` (with ``cfg.checkpoint`` and ``cfg.pooling``), converted
    according to ``cfg.conv``, and its initial state dict is written to
    ``<save_path>/model.dat``.

    Args:
        save_path: Output directory. Created with ``os.makedirs``, so it
            must not exist yet (``FileExistsError`` otherwise).
        n_epochs: Number of epochs, forwarded to ``train``.
        seed: Seed passed to ``set_seed``; ``None`` skips seeding.
    """
    # Seed with the ``seed`` argument itself. The previous code called
    # set_seed(cfg.seed), silently ignoring an explicitly passed seed.
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True  # improve efficiency

    # Back up the code; redirect_stdout presumably routes printed output
    # into save_path (TODO confirm).
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    train_set = CACTwoClassDataset(crop_size=[48, 48, 48], data_path=env.data,
                                   datatype=0, fill_with=-1)
    test_set = CACTwoClassDataset(crop_size=[48, 48, 48], data_path=env.data,
                                  datatype=1, fill_with=-1)

    # Define model
    model_dict = {
        'resnet18': ClsResNet,
        'resnet34': ClsResNet,
        'resnet50': ClsResNet,
        'resnet101': ClsResNet,
        'resnet152': ClsResNet,
        'vgg16': ClsVGG,
        'densenet121': ClsDenseNet,
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained, num_classes=2,
                                     backbone=cfg.backbone,
                                     checkpoint=cfg.checkpoint,
                                     pooling=cfg.pooling)

    # Convert to counterparts and load pretrained weights according to the
    # requested convolution type (the cases are mutually exclusive).
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    elif cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    elif cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            # NOTE(review): these checkpoint paths are resnet18-specific even
            # when cfg.backbone selects a different architecture.
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)

    # Snapshot of the initial (pre-training) weights.
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))

    # train and test the model
    train(model=model, train_set=train_set, test_set=test_set,
          save=save_path, n_epochs=n_epochs)
    print('Done!')