Code Example #1
def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    # set seed
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True
    # back up your code
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    train_set = LIDCSegDataset(crop_size=48,
                               move=5,
                               data_path=env.data,
                               train=True)
    valid_set = None
    test_set = LIDCSegDataset(crop_size=48,
                              move=5,
                              data_path=env.data,
                              train=False)

    # Define model
    model_dict = {
        'resnet18': FCNResNet,
        'vgg16': FCNVGG,
        'densenet121': FCNDenseNet
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained,
                                     num_classes=2,
                                     backbone=cfg.backbone)

    # convert the 2D model to its 3D counterpart and load pretrained weights according to the chosen convolution type
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    if cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    if cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)
    print(model)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
    # train and test the model
    train(model=model,
          train_set=train_set,
          valid_set=valid_set,
          test_set=test_set,
          save=save_path,
          n_epochs=n_epochs)

    print('Done!')
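
Every example above hinges on the same conversion step from the acsconv package: a pretrained 2D model is wrapped by a converter so that it accepts volumetric input. A minimal sketch of that API, assuming acsconv and torchvision are installed and following the usage pattern documented for ACSConv:

import torch
from torchvision.models import resnet18
from acsconv.converters import ACSConverter

# a standard 2D model pretrained on ImageNet
model_2d = resnet18(pretrained=True)

# wrap it so its 2D convolutions operate on (N, C, D, H, W) volumes
model_3d = ACSConverter(model_2d)

# a dummy volumetric input: batch of 1, 3 channels, 64x64x64 volume
input_3d = torch.rand(1, 3, 64, 64, 64)
output_3d = model_3d(input_3d)
print(output_3d.shape)
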
Code Example #2
def main(
        save_path=cfg.save,  # save directory taken from the config
        n_epochs=cfg.n_epochs,
        seed=cfg.seed):
    # set seed
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms
    # back up your code
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    valid_set = None
    test_set = CACTwoClassDataset(crop_size=[48, 48, 48],
                                  data_path=env.data,
                                  datatype=2,
                                  fill_with=-1)

    # Define model
    model_dict = {
        'resnet18': ClsResNet,
        'vgg16': ClsVGG,
        'densenet121': ClsDenseNet
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained,
                                     num_classes=2,
                                     backbone=cfg.backbone)

    # convert the 2D model to its 3D counterpart and load pretrained weights according to the chosen convolution type
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    if cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    if cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)
    # print(model)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
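    # NOTE: this variant reloads a previously trained checkpoint (hard-coded
    # cluster path) and calls train() without a train_set, presumably to
    # evaluate the restored model only.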
    model_path = '/cluster/home/it_stu167/wwj/classification_after_crop/result/CACClass/resnet18/ACSConv/48_1-2_m0/epoch_107/model.dat'
    model.load_state_dict(torch.load(model_path))
    # train and test the model
    train(model=model,
          valid_set=valid_set,
          test_set=test_set,
          save=save_path,
          n_epochs=n_epochs)

    print('Done!')
Code Example #3
File: train_3d.py    Project: wanchen1026/ACSConv
def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True

    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)
    # Datasets
    train_data = env.data_train
    test_data = env.data_test
    shape_cp = env.shape_checkpoint

    train_set = BaseDatasetVoxel(train_data, cfg.train_samples)
    valid_set = None
    test_set = BaseDatasetVoxel(test_data, 200)

    # Models

    model = UNet(6)
    if cfg.conv == 'Conv3D':
        model = Conv3dConverter(model)
        initialize(model.modules())
    elif cfg.conv == 'Conv2_5D':
        if cfg.pretrained:
            shape_cp = torch.load(shape_cp)
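            # drop the last two state-dict entries (likely the final prediction
            # layer, whose shape does not match this task) so the remaining
            # weights can be loaded with strict=False below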
            shape_cp.popitem()
            shape_cp.popitem()
            incompatible_keys = model.load_state_dict(shape_cp, strict=False)
            print('load shape pretrained weights\n', incompatible_keys)
        model = Conv2_5dConverter(model)
    elif cfg.conv == 'ACSConv':
        # You can use either the naive ``ACSUNet`` or the ``ACSConverter(model)``
        model = ACSConverter(model)
        # model = ACSUNet(6)
        if cfg.pretrained:
            shape_cp = torch.load(shape_cp)
            shape_cp.popitem()
            shape_cp.popitem()
            incompatible_keys = model.load_state_dict(shape_cp, strict=False)
            print('load shape pretrained weights\n', incompatible_keys)
    else:
        raise ValueError(f'invalid conv: {cfg.conv}')

    print(model)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
    # Train the model
    train(model=model,
          train_set=train_set,
          valid_set=valid_set,
          test_set=test_set,
          save=save_path,
          n_epochs=n_epochs)

    print('Done!')
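
The two shape-pretraining blocks above drop the last two checkpoint entries with popitem() before loading with strict=False. A more explicit alternative (a sketch, not part of the original scripts; load_matching_weights is a hypothetical helper) keeps only the tensors whose names and shapes match the target model:

import torch

def load_matching_weights(model, checkpoint_path):
    """Load only the checkpoint tensors whose names and shapes match the model."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_state = model.state_dict()
    # keep entries that exist in the model and have identical tensor shapes
    filtered = {k: v for k, v in checkpoint.items()
                if k in model_state and v.shape == model_state[k].shape}
    incompatible = model.load_state_dict(filtered, strict=False)
    print(f'loaded {len(filtered)}/{len(checkpoint)} state-dict tensors;', incompatible)
    return model
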
Code Example #4
def main(save_path=cfg.save, n_epochs=cfg.n_epochs, seed=cfg.seed):
    # set seed
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True
    # back up your code
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    train_set = CACSegDataset(crop_size=[48, 48, 48],
                              data_path=env.data,
                              random=cfg.random,
                              datatype=0)
    valid_set = CACSegDataset(crop_size=[48, 48, 48],
                              data_path=env.data,
                              random=cfg.random,
                              datatype=1)
    test_set = CACSegDataset(crop_size=[48, 48, 48],
                             data_path=env.data,
                             random=cfg.random,
                             datatype=2)

    # Define model
    model_dict = {
        'resnet18': FCNResNet,
        'resnet34': FCNResNet,
        'resnet50': FCNResNet,
        'resnet101': FCNResNet,
        'vgg16': FCNVGG,
        'densenet121': FCNDenseNet,
        'unet': UNet
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained,
                                     num_classes=3,
                                     backbone=cfg.backbone,
                                     checkpoint=cfg.checkpoint)  # modified
    # model.load_state_dict(torch.load('/cluster/home/it_stu167/wwj/classification_after_crop/result/CACSeg/resnet18/ACSConv/200911_104150_pretrained/model.dat'))

    # convert the 2D model to its 3D counterpart and load pretrained weights according to the chosen convolution type
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    if cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    if cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)
    # print(model)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
    # train and test the model
    train(model=model,
          train_set=train_set,
          valid_set=valid_set,
          test_set=test_set,
          save=save_path,
          n_epochs=n_epochs)

    print('Done!')
Code Example #5
def main(
        save_path=cfg.save,  # save directory taken from the config
        n_epochs=cfg.n_epochs,
        seed=cfg.seed):
    # set seed
    if seed is not None:
        set_seed(seed)
    cudnn.benchmark = True  # let cuDNN auto-tune convolution algorithms
    # back up your code
    os.makedirs(save_path)
    copy_file_backup(save_path)
    redirect_stdout(save_path)

    # Datasets
    train_set = CACTwoClassDataset(crop_size=[48, 48, 48],
                                   data_path=env.data,
                                   datatype=0,
                                   fill_with=-1)
    test_set = CACTwoClassDataset(crop_size=[48, 48, 48],
                                  data_path=env.data,
                                  datatype=1,
                                  fill_with=-1)

    # Define model
    model_dict = {
        'resnet18': ClsResNet,
        'resnet34': ClsResNet,
        'resnet50': ClsResNet,
        'resnet101': ClsResNet,
        'resnet152': ClsResNet,
        'vgg16': ClsVGG,
        'densenet121': ClsDenseNet
    }
    model = model_dict[cfg.backbone](pretrained=cfg.pretrained,
                                     num_classes=2,
                                     backbone=cfg.backbone,
                                     checkpoint=cfg.checkpoint,
                                     pooling=cfg.pooling)

    # convert the 2D model to its 3D counterpart and load pretrained weights according to the chosen convolution type
    if cfg.conv == 'ACSConv':
        model = model_to_syncbn(ACSConverter(model))
    if cfg.conv == 'Conv2_5d':
        model = model_to_syncbn(Conv2_5dConverter(model))
    if cfg.conv == 'Conv3d':
        if cfg.pretrained_3d == 'i3d':
            model = model_to_syncbn(Conv3dConverter(model, i3d_repeat_axis=-3))
        else:
            model = model_to_syncbn(
                Conv3dConverter(model, i3d_repeat_axis=None))
            if cfg.pretrained_3d == 'video':
                model = load_video_pretrained_weights(
                    model, env.video_resnet18_pretrain_path)
            elif cfg.pretrained_3d == 'mednet':
                model = load_mednet_pretrained_weights(
                    model, env.mednet_resnet18_pretrain_path)
    # print(model)
    torch.save(model.state_dict(), os.path.join(save_path, 'model.dat'))
    # torch.save(model.state_dict(), os.path.join(save_path, 'model.pth'))
    # train and test the model
    train(model=model,
          train_set=train_set,
          test_set=test_set,
          save=save_path,
          n_epochs=n_epochs)

    print('Done!')