# Example #1
# 0
def Train(train_root, train_csv, test_root, test_csv):
    """Train a single segmentation network for kidney segmentation.

    Per epoch: supervised training on the train loader, loss/dice evaluation
    on the test loader, a second loss/dice evaluation over the train loader
    in eval mode, and a checkpoint save whenever that train-set eval dice
    improves on the best seen so far.

    Args:
        train_root: root directory of the training data.
        train_csv: csv file listing the training samples.
        test_root: root directory of the test data.
        test_csv: csv file listing the test samples.

    Raises:
        ValueError: if ``args.loss`` is not one of 'ce', 'dice', 'cedice'.
    """
    # parameters
    args = parse_args()
    besttraindice = 0.0  # best train-set eval dice so far (checkpoint criterion)

    # record run configuration
    record_params(args)

    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2  # binary segmentation: background vs. kidney

    net = build_model(args.model_name, num_classes)

    # checkpoint name and per-epoch metric history
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train: deterministic resize + normalize (no random augmentation)
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test: identical deterministic pipeline
    test_aug = train_aug

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(root=test_root,
                              csv_file=test_csv,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # Fail fast: the previous code only printed a message here and then
        # crashed later with a NameError on the undefined `criterion`.
        raise ValueError('Do not have this loss: {}'.format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler (stays None when args.lr_policy requests no scheduling)
    scheduler = None
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    elif args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # accumulate per-sample so epoch averages stay correct even if
            # the last batch is smaller
            train_loss += loss.item() * inputs.shape[0]
            # assumes Dice_fn returns the batch-summed dice -- TODO confirm
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # re-evaluate on the training set; note that the network is still in
        # eval() mode from the test loop above
        traineval_loss = 0.
        traineval_dice = 0.
        traineval_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            traineval_count += inputs.shape[0]
            traineval_loss += loss.item() * inputs.shape[0]
            traineval_dice += Dice_fn(outputs, targets).item()

        traineval_loss_epoch = traineval_loss / float(traineval_count)
        traineval_dice_epoch = traineval_dice / float(traineval_count)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f '
            '| traineval_dice: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
             train_dice_epoch, test_dice_epoch, traineval_dice_epoch,
             time_cost))

        if scheduler is not None:
            scheduler.step()

        # checkpoint on improved train-set eval dice
        if traineval_dice_epoch > besttraindice:
            besttraindice = traineval_dice_epoch
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            # unwrap nn.DataParallel so the checkpoint loads on a single GPU
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
# Example #2
# 0
def Train(train_root, train_csv, test_csv, tempmaskfolder):
    """Co-train two segmentation networks with noisy-mask correction.

    Both networks start from the same pretrained checkpoint
    (``args.resumefile``) and teach each other: each net trains on the
    small-loss samples selected by the *other* net's per-sample loss
    ranking, plus a confidence-weighted consistency term against the other
    net's augmentation-averaged, sharpened pseudo-labels. Periodically the
    masks of the worst-predicted training samples are regenerated into
    ``tempmaskfolder`` so later epochs train on cleaner labels.

    Args:
        train_root: root directory of the training data.
        train_csv: csv listing training samples.
        test_csv: csv listing test samples (also rooted at ``train_root``).
        tempmaskfolder: subfolder of ``train_root`` for regenerated masks.

    Raises:
        ValueError: if ``args.loss`` is not one of 'ce', 'dice', 'cedice'.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))

    besttraindice = 0.0    # best train-set dice once the ascending phase starts
    changepointdice = 0.0  # last dice seen while the metric was still falling
    ascending = False      # flips True once the train-set dice starts rising

    # parameters
    args = parse_args()
    record_params(args)
    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2
    net1 = build_model(args.model1_name, num_classes)
    net2 = build_model(args.model2_name, num_classes)

    # checkpoint names encode model, warmup, temperature and repetition
    params1_name = '{}_warmup{}_temp{}_r{}_net1.pkl'.format(
        args.model1_name, args.warmup_epoch, args.temperature, args.repetition)
    params2_name = '{}_warmup{}_temp{}_r{}_net2.pkl'.format(
        args.model2_name, args.warmup_epoch, args.temperature, args.repetition)

    checkpoint1_path = os.path.join(args.checkpoint, params1_name)
    checkpoint2_path = os.path.join(args.checkpoint, params2_name)
    # both nets resume from the same warm-start weights
    initializecheckpoint = torch.load(args.resumefile)['net']
    net1.load_state_dict(initializecheckpoint)
    net2.load_state_dict(initializecheckpoint)

    start_epoch = 0
    end_epoch = args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)

    # data: random augmentation for training, deterministic for testing
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               tempmaskfolder=tempmaskfolder,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(
        root=train_root,
        csv_file=test_csv,
        tempmaskfolder=tempmaskfolder,
        maskidentity=args.maskidentity,
        train=False,
        transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        # per-image loss (no batch reduction) so samples can be ranked below
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        # Fail fast: the previous code only printed a message here and then
        # crashed later with a NameError on the undefined `criterion`.
        raise ValueError('Do not have this loss: {}'.format(args.loss))
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)

    # per-epoch weight of the consistency (correction) term
    rate_schedule = np.ones(args.num_epoch)
    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)

    ## schedulers (stay None when args.lr_policy requests no scheduling)
    scheduler1 = None
    scheduler2 = None
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    elif args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')

    for epoch in range(start_epoch, end_epoch):

        ts = time.time()
        # ramp the consistency weight quadratically over the warmup epochs
        if args.warmup_epoch == 0:
            rate_schedule[epoch] = 1.0
        else:
            rate_schedule[epoch] = min(
                (float(epoch) / float(args.warmup_epoch))**2, 1.0)
        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader), total=int(
                    len(train_loader.dataset) / args.batch_size)):

            # --- build pseudo-labels from augmented views (no gradients) ---
            net1.eval()
            net2.eval()
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimg = augset['img{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimg).detach())
                augoutput2.append(net2(augimg).detach())
            # undo each augmentation so predictions align with the raw image
            augoutput1 = reverseaugbatch(augset,
                                         augoutput1,
                                         classno=num_classes)
            augoutput2 = reverseaugbatch(augset,
                                         augoutput2,
                                         classno=num_classes)

            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)

                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            # average over views, then sharpen the soft labels
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # confidence map: 1 at confident pixels, 0 where the two class
            # probabilities are both 0.5 (maximally uncertain)
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:,
                                                   0, :, :] * pseudo_label1[:,
                                                                            1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:,
                                                   0, :, :] * pseudo_label2[:,
                                                                            1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)

            # --- supervised cross-teaching step ---
            net1.train()
            net2.train()
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)

            # per-sample losses (assumes criterion returns one loss per
            # sample -- TODO confirm against CEMDiceLossImage)
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)

            # small-loss ranking: each net's selection is used by the OTHER
            # net; the first 2 samples are treated as clean (assumes
            # batch_size > 2 -- TODO confirm)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            # consistency toward the other net's pseudo-labels, weighted by
            # that net's pixel confidence
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor

            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor

            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # NOTE(review): retain_graph looks defensive -- the two graphs
            # appear disjoint (pseudo-labels are detached); kept as-is
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inputs.shape[0]
            train_loss1 += loss1.item() * inputs.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inputs.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)

        # test both nets
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader), total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                targets1 = targets1.to(device)
                targets2 = targets2.to(device)
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
            test_count += inputs.shape[0]
            test_loss1 += loss1.item() * inputs.shape[0]
            test_dice1 += Dice_fn(outputs1, targets2).item()
            test_loss2 += loss2.item() * inputs.shape[0]
            test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)

        # per-sample evaluation over the whole training set, collecting each
        # net's hard predictions for the mask-update step below
        traindices1 = torch.zeros(len(train_dataset))
        traindices2 = torch.zeros(len(train_dataset))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_dataset)),
                              total=len(train_dataset)):
            sample = train_dataset.__getitem__(casecount)
            img = sample[0]
            mask1 = sample[4]
            mask2 = sample[3]
            with torch.no_grad():
                img = torch.unsqueeze(img.to(device), 0)
                output1 = net1(img)
                output1 = F.softmax(output1, dim=1)
                output2 = net2(img)
                output2 = F.softmax(output2, dim=1)
            output1 = torch.argmax(output1, dim=1)
            output2 = torch.argmax(output2, dim=1)
            output1 = output1.squeeze().cpu()
            generatedoutput1 = output1.unsqueeze(dim=0).numpy()
            output2 = output2.squeeze().cpu()
            generatedoutput2 = output2.unsqueeze(dim=0).numpy()
            traindices1[casecount] = Dice2d(generatedoutput1, mask1.numpy())
            traindices2[casecount] = Dice2d(generatedoutput2, mask2.numpy())
            generatedmask1.append(generatedoutput1)
            generatedmask2.append(generatedoutput2)
        evaltrainavgdice1 = traindices1.sum() / float(len(train_dataset))
        evaltrainavgdice2 = traindices2.sum() / float(len(train_dataset))
        evaltrainavgdicetemp = (evaltrainavgdice1 + evaltrainavgdice2) / 2.0
        maskannotations = {
            '1': train_dataset.mask1,
            '2': train_dataset.mask2,
            '3': train_dataset.mask3
        }

        # update pseudolabel: during warmup and every 10th epoch, rewrite the
        # masks of the lowest-dice training samples with each net's prediction
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            avgdice = evaltrainavgdicetemp
            selected_samples = int(args.update_percent * len(train_dataset))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traindices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net1.nii.gz')
                save_data = generatedmask1[selectedidx]
                # skip all-background predictions
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modified for net1'.format(
                len(selectedidxs)))

            _, sortidx2 = traindices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net2.nii.gz')
                save_data = generatedmask2[selectedidx]
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modify for net2'.format(len(selectedidxs)))

        # checkpointing only begins once the train-set dice turns upward;
        # the dice at the turning point becomes the baseline to beat
        if epoch > 0 and changepointdice < evaltrainavgdicetemp and ascending == False:
            ascending = True
            besttraindice = changepointdice

        if evaltrainavgdicetemp > besttraindice and ascending:
            besttraindice = evaltrainavgdicetemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            # unwrap nn.DataParallel so checkpoints load on a single GPU
            save_model = net1
            if torch.cuda.device_count() > 1:
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state, '{}_besttraindice.pkl'.format(
                    checkpoint1_path.split('.pkl')[0]))

            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state, '{}_besttraindice.pkl'.format(
                    checkpoint2_path.split('.pkl')[0]))

        if not ascending:
            changepointdice = evaltrainavgdicetemp

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | '
            'train_dice1: %.3f | test_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | '
            'train_dice2: %.3f | test_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: evaltrain_dice1: %.3f | evaltrain_dice2: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, evaltrainavgdice1, evaltrainavgdice2,
               time_cost))

        net1.train()
        net2.train()
        if scheduler1 is not None:
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_root, test_csv, traincase_csv,
          testcase_csv):
    """Train a segmentation network for prostate segmentation.

    Per epoch: supervised training on 2-D slices, slice-level loss/dice on
    the test loader, then case-level (3-D, largest-connected-component
    post-processed) dice over the train and test case lists. Checkpoints
    whenever the train-case dice improves.

    Args:
        train_root: root directory of the training data.
        train_csv: csv listing training slices.
        test_root: root directory of the test data.
        test_csv: csv listing test slices.
        traincase_csv: csv with an 'Image' column naming training cases.
        testcase_csv: csv with an 'Image' column naming test cases.

    Raises:
        ValueError: if ``args.loss`` is not one of 'ce', 'dice', 'cedice'.
    """
    # parameters
    args = parse_args()
    besttraincasedice = 0.0  # best case-level train dice (checkpoint criterion)

    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()
    # record run configuration
    record_params(args)

    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2  # binary segmentation: background vs. prostate

    net = build_model(args.model_name, num_classes)

    # checkpoint name and per-epoch metric history
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train: deterministic resize + normalize (no random augmentation)
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test: identical deterministic pipeline
    test_aug = train_aug

    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=test_root,
                                csv_file=test_csv,
                                transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # Fail fast: the previous code only printed a message here and then
        # crashed later with a NameError on the undefined `criterion`.
        raise ValueError('Do not have this loss: {}'.format(args.loss))

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler (stays None when args.lr_policy requests no scheduling)
    scheduler = None
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    elif args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Prostate Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # accumulate per-sample so epoch averages stay correct even if
            # the last batch is smaller
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test (slice level)
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # case-level (3-D) dice on the test cases; slices of one case are
        # assumed contiguous in the dataset. startimgslices[i+1] stores case
        # i's slice count, so the prefix sum gives case i+1's first index.
        testcasedices = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount].split('.')[0]
            caseimg = [file for file in test_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [file for file in test_dataset.masks if caseidx in file]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                slice_img = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    slice_img = torch.unsqueeze(slice_img.to(device), 0)
                    output = net(slice_img)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            # keep only the largest connected component of the prediction
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            testcasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                 target)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseimg)
        testcasedice = testcasedices.sum() / float(len(test_cases))

        # case-level (3-D) dice on the train cases (same scheme as above)
        traincasedices = torch.zeros(len(train_cases))
        startimgslices = torch.zeros(len(train_cases))
        generatedmask = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseimg = [file for file in train_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [
                file for file in train_dataset.masks if caseidx in file
            ]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                slice_img = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    slice_img = torch.unsqueeze(slice_img.to(device), 0)
                    output = net(slice_img)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            traincasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                  target)
            generatedmask.append(generatedtarget_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseimg)
        traincasedice = traincasedices.sum() / float(len(train_cases))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: traincase_dice: %.3f | testcase_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, traincasedice, testcasedice, time_cost))

        if scheduler is not None:
            scheduler.step()

        # checkpoint on improved case-level train dice
        if traincasedice > besttraincasedice:
            besttraincasedice = traincasedice
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            # unwrap nn.DataParallel so the checkpoint loads on a single GPU
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
def Train(train_root, train_csv, test_csv, traincase_csv, testcase_csv,
          labelcase_csv, tempmaskfolder):
    """Co-train two segmentation networks on CHAOS data with noisy pseudo-labels.

    Two identical networks (net1, net2) supervise each other in a
    co-teaching scheme: net1 learns from ``targets2`` and net2 from
    ``targets1``.  Within each batch the samples with the smallest peer
    loss are trusted with a full segmentation loss, while the rest are
    gradually shifted (per ``rate_schedule``) from the segmentation loss
    to a weighted MSE consistency loss against the peer network's
    augmentation-averaged, temperature-sharpened pseudo-labels.
    Periodically the worst-scoring unlabeled training cases have their
    pseudo-mask PNGs on disk regenerated from the current predictions.

    Args:
        train_root: root folder of training images; also receives the
            temporary pseudo-mask folder.
        train_csv / test_csv: per-slice csv files for train/test splits.
        traincase_csv / testcase_csv: csv files listing patient cases.
        labelcase_csv: csv listing cases with ground-truth labels; these
            cases are never overwritten with generated pseudo-masks.
        tempmaskfolder: subfolder (under train_root) holding pseudo-masks.

    Side effects: creates folders, writes checkpoints to args.checkpoint,
    writes pseudo-mask PNGs under train_root/tempmaskfolder, and logs
    per-epoch metrics.  Returns None.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))

    # parameters
    args = parse_args()

    # record
    record_params(args)

    train_cases = pd.read_csv(traincase_csv)['patient_case'].tolist()
    test_cases = pd.read_csv(testcase_csv)['patient_case'].tolist()
    label_cases = pd.read_csv(labelcase_csv)['patient_case'].tolist()

    # reproducibility: seed every RNG the pipeline touches
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2

    # two peer networks with identical architecture
    net1 = build_model(args.model_name, num_classes)
    net2 = build_model(args.model_name, num_classes)

    params1_name = '{}_temp{}_r{}_net1.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    params2_name = '{}_temp{}_r{}_net2.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    start_epoch = 0
    end_epoch = args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)

    # data: geometric augmentation only at train time
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])

    train_dataset = chaos_seg(root=train_root,
                              csv_file=train_csv,
                              tempmaskfolder=tempmaskfolder,
                              transform=train_aug)
    test_dataset = chaos_seg(root=train_root,
                             csv_file=test_csv,
                             tempmaskfolder=tempmaskfolder,
                             transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')
    # per-element (reduction='none') so it can be weighted by the
    # confidence weight maps before averaging
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)

    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)

    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For CHAOS Seg')
    besttraincasedice = 0.0
    for epoch in range(start_epoch, end_epoch):

        ts = time.time()
        # quadratic ramp from 0 to 1 over the warm-up epochs; controls how
        # much of the non-trusted samples' loss comes from the consistency
        # term instead of the segmentation term
        rate_schedule[epoch] = min(
            (float(epoch) / float(args.warmup_epoch))**2, 1.0)

        # train
        net1.train()
        net2.train()

        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.

        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader), total=int(len(train_loader.dataset) / args.batch_size)):

            # run each augmented view through both nets (no grad through
            # the pseudo-label branch: outputs are detached)
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimgin = augset['imgmodal1{}'.format(aug_idx + 1)].to(device)
                augimgout = augset['imgmodal2{}'.format(aug_idx +
                                                        1)].to(device)
                augoutput1.append(net1(augimgin, augimgout).detach())
                augoutput2.append(net2(augimgin, augimgout).detach())

            # undo the geometric augmentations so predictions align spatially
            augoutput1 = reverseaug(augset, augoutput1, classno=num_classes)
            augoutput2 = reverseaug(augset, augoutput2, classno=num_classes)

            # average the softmaxed predictions over augmented views
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)

                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2

            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # confidence weight map: 1 - 4*p0*p1 is 0 at maximal uncertainty
            # (p0 = p1 = 0.5) and 1 where the pseudo-label is confident
            # (binary / two-channel softmax assumed here)
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:,
                                                   0, :, :] * pseudo_label1[:,
                                                                            1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:,
                                                   0, :, :] * pseudo_label2[:,
                                                                            1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)

            inphase = inphase.to(device)
            outphase = outphase.to(device)
            # keep only the foreground channel as the hard target
            targets1 = targets1[:, 1, :, :].to(device)
            targets2 = targets2[:, 1, :, :].to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # cross supervision: net1 is judged/trained on targets2 and
            # vice versa; each net picks the peer's small-loss samples
            outputs1 = net1(inphase, outphase)
            outputs2 = net2(inphase, outphase)
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            # NOTE(review): the [0:2] / [2:] split hard-codes "two trusted
            # samples per batch" — assumes batch_size > 2; confirm against
            # the configured args.batch_size
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            # consistency loss against the peer's pseudo-labels, weighted
            # by the peer's confidence map
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor

            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inphase.shape[0]
            train_loss1 += loss1.item() * inphase.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inphase.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)

        # debug output: loss components of the LAST training batch only
        print(rate_schedule[epoch])
        print(args.segcor_weight[0] *
              (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor)

        print(args.segcor_weight[0] *
              (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor)

        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader), total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inphase = inphase.to(device)
                outphase = outphase.to(device)
                targets1 = targets1[:, 1, :, :].to(device)
                targets2 = targets2[:, 1, :, :].to(device)
                outputs1 = net1(inphase, outphase)
                outputs2 = net2(inphase, outphase)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
            test_count += inphase.shape[0]
            test_loss1 += loss1.item() * inphase.shape[0]
            test_dice1 += Dice_fn(outputs1, targets2).item()
            test_loss2 += loss2.item() * inphase.shape[0]
            test_dice2 += Dice_fn(outputs2, targets1).item()

        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)

        # per-case 3D dice on the test split: slices are gathered per case,
        # stacked into a volume, largest-component filtered, then scored.
        # startimgslices accumulates per-case slice counts so
        # startcaseimg = dataset offset of the current case's first slice.
        testcasedices1 = torch.zeros(len(test_cases))
        testcasedices2 = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount]
            caseinphaseimg = [
                file for file in test_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in test_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            casemask = [
                file for file in test_dataset.masks
                if int(file.split('/')[0]) == caseidx
            ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                # sanity-check that image/mask/phase file lists stay aligned
                # (out-of-phase slice numbers are offset by one from in-phase)
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                       casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                       caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                       int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            testcasedices1[casecount] = Dice3d_fn(generatedtarget1_keeplargest,
                                                  target1)
            testcasedices2[casecount] = Dice3d_fn(generatedtarget2_keeplargest,
                                                  target2)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)
        testcasedice1 = testcasedices1.sum() / float(len(test_cases))
        testcasedice2 = testcasedices2.sum() / float(len(test_cases))

        traincasedices1 = torch.zeros(len(train_cases))
        traincasedices2 = torch.zeros(len(train_cases))
        # update pseudolabel: same per-case volume scoring on the train
        # split; predicted volumes are kept for the mask-rewrite step below
        startimgslices = torch.zeros(len(train_cases))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseinphaseimg = [
                file for file in train_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in train_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            # labeled cases keep masks at '<case>/...'; unlabeled cases use
            # pseudo-masks at '.../<case>/<file>' (case id one level deeper)
            if caseidx in label_cases:
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[0].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[0]) == caseidx
                ]
            else:
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[-2].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[-2]) == caseidx
                ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                       casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                       caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                       int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)

                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)

            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            traincasedices1[casecount] = Dice3d_fn(
                generatedtarget1_keeplargest, target1)
            traincasedices2[casecount] = Dice3d_fn(
                generatedtarget2_keeplargest, target2)
            generatedmask1.append(generatedtarget1_keeplargest)
            generatedmask2.append(generatedtarget2_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)

        traincasedice1 = traincasedices1.sum() / float(len(train_cases))
        traincasedice2 = traincasedices2.sum() / float(len(train_cases))

        traincasediceavgtemp = (traincasedice1 + traincasedice2) / 2.0

        # checkpoint both nets when the averaged train-case dice improves
        if traincasediceavgtemp > besttraincasedice:
            besttraincasedice = traincasediceavgtemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net1
            if torch.cuda.device_count() > 1:
                # unwrap DataParallel so the state_dict keys have no
                # 'module.' prefix
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params1_name.split('.pkl')[0] + '_besttraincasedice.' +
                params1_name.split('.')[-1])
            torch.save(state, savecheckname)

            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            # fixed typo: suffix was '_besttraincasedicde', inconsistent
            # with the net1 checkpoint name above
            savecheckname = os.path.join(
                args.checkpoint,
                params2_name.split('.pkl')[0] + '_besttraincasedice.' +
                params2_name.split('.')[-1])
            torch.save(state, savecheckname)

        # during warm-up (and every 10th epoch after) rewrite the on-disk
        # pseudo-masks of the 25% worst-scoring train cases with the
        # current predictions; labeled cases are never overwritten
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            selected_samples = int(0.25 * len(train_cases))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traincasedices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net1.png')
                        save_data = generatedmask1[selectedidx][:, :, imgidx]
                        # scale class ids into visible grayscale values
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net1'.format(
                [train_cases[i] for i in selectedidxs]))

            _, sortidx2 = traincasedices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net2.png')
                        save_data = generatedmask2[selectedidx][:, :, imgidx]
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net2'.format(
                [train_cases[i] for i in selectedidxs]))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | train_dice1: %.3f | test_dice1: %.3f || '
            'traincase_dice1: %.3f || testcase_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, traincasedice1,
             testcasedice1, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | train_dice2: %.3f | test_dice2: %.3f || '
            'traincase_dice2: %.3f || testcase_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, traincasedice2,
             testcasedice2, time_cost))
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()