Example #1

# Standard-library and third-party imports used by both training loops below.
# Project-specific pieces (parse_args, record_params, build_model, kidney_seg,
# prostate_seg, CrossEntropyLoss2d, MulticlassDiceLoss, CEMDiceLoss, PolyLR,
# Dice_fn, Dice3d_fn, keep_largest_connected_components, and the Compose /
# Resize / ToTensor / Normalize transforms, which may be torchvision's or the
# project's own joint transforms) come from this repository's modules and are
# not reproduced here.
import logging
import os
import random
import time

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from tqdm import tqdm

def Train(train_root, train_csv, test_root, test_csv):

    # parameters
    args = parse_args()
    besttraindice = 0.0

    # record
    record_params(args)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

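    # cudnn.benchmark lets cuDNN auto-tune convolution kernels for the fixed
    # input size; cudnn.deterministic additionally requests reproducible
    # algorithms.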
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2

    net = build_model(args.model_name, num_classes)

    # resume
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test
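    # resize + normalize only, with no random augmentation, so the same
    # pipeline is reused for the test set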
    test_aug = train_aug

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(root=test_root,
                              csv_file=test_csv,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        raise ValueError('Unsupported loss: {}'.format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    elif args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader), total=len(train_loader)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
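            # accumulate per-sample totals so the epoch averages below are
            # weighted by batch size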
            train_count += inputs.shape[0]
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader), total=len(test_loader)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

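        # re-score the training set in eval mode (dropout off, BatchNorm in
        # inference mode); the best checkpoint is selected on this score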
        traineval_loss = 0.
        traineval_dice = 0.
        traineval_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader), total=len(train_loader)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            traineval_count += inputs.shape[0]
            traineval_loss += loss.item() * inputs.shape[0]
            traineval_dice += Dice_fn(outputs, targets).item()

        traineval_loss_epoch = traineval_loss / float(traineval_count)
        traineval_dice_epoch = traineval_dice / float(traineval_count)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f '
            '| traineval_dice: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
             train_dice_epoch, test_dice_epoch, traineval_dice_epoch,
             time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        if traineval_dice_epoch > besttraindice:
            besttraindice = traineval_dice_epoch
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net
            if torch.cuda.device_count() > 1:
                # unwrap nn.DataParallel so the checkpoint loads without the
                # 'module.' prefix
                save_model = net.module
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.replace('.pkl', '_besttraindice.pkl'))
            torch.save(state, savecheckname)

Example #2

def Train(train_root, train_csv, test_root, test_csv, traincase_csv,
          testcase_csv):

    # parameters
    args = parse_args()
    besttraincasedice = 0.0

    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()
    # record
    record_params(args)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2

    net = build_model(args.model_name, num_classes)

    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test
    test_aug = train_aug

    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=test_root,
                                csv_file=test_csv,
                                transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        raise ValueError('Unsupported loss: {}'.format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    elif args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Prostate Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader), total=len(train_loader)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader), total=len(test_loader)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

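        # per-case 3D evaluation: reassemble each patient's 2D slices into a
        # volume, keep only the largest connected component of the prediction,
        # and score it with a volumetric Dice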
        testcasedices = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount].split('.')[0]
            caseimg = [file for file in test_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [file for file in test_dataset.masks if caseidx in file]
            casemask.sort()
            generatedtarget = []
            target = []
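            # startimgslices[k + 1] records the slice count of case k, so the
            # running sum gives this case's first index in the case-ordered
            # dataset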
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = test_dataset[imgidx + startcaseimg]
                image = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    image = torch.unsqueeze(image.to(device), 0)
                    output = net(image)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            testcasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                 target)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseimg)
        testcasedice = testcasedices.sum() / float(len(test_cases))

        traincasedices = torch.zeros(len(train_cases))
        startimgslices = torch.zeros(len(train_cases))
        generatedmask = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseimg = [file for file in train_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [
                file for file in train_dataset.masks if caseidx in file
            ]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = train_dataset[imgidx + startcaseimg]
                image = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    image = torch.unsqueeze(image.to(device), 0)
                    output = net(image)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            traincasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                  target)
            generatedmask.append(generatedtarget_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseimg)
        traincasedice = traincasedices.sum() / float(len(train_cases))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: traincase_dice: %.3f | testcase_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, traincasedice, testcasedice, time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        if traincasedice > besttraincasedice:
            besttraincasedice = traincasedice
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net
            if torch.cuda.device_count() > 1:
                # unwrap nn.DataParallel so the checkpoint loads without the
                # 'module.' prefix
                save_model = net.module
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.replace('.pkl', '_besttraindice.pkl'))
            torch.save(state, savecheckname)
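
# A minimal entry-point sketch for the prostate loop above; the paths and CSV
# names below are hypothetical stand-ins for whatever the launcher script
# actually passes in.
if __name__ == '__main__':
    Train(train_root='data/prostate/train',
          train_csv='data/prostate/train_slices.csv',
          test_root='data/prostate/test',
          test_csv='data/prostate/test_slices.csv',
          traincase_csv='data/prostate/train_cases.csv',
          testcase_csv='data/prostate/test_cases.csv')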