# Example 1
# 0
def Train(train_root, train_csv, test_root, test_csv):
    """Supervised training loop for a single kidney-segmentation network.

    Per epoch: train on the training set, evaluate loss/Dice on the test
    set, then re-score the training set with the network held in eval mode.
    The checkpoint with the best eval-mode training Dice is saved under
    ``args.checkpoint``.

    Args:
        train_root: root directory of the training data (consumed by kidney_seg).
        train_csv: CSV file listing the training samples.
        test_root: root directory of the test data.
        test_csv: CSV file listing the test samples.
    """

    # parameters
    args = parse_args()
    besttraindice = 0.0  # best eval-mode training Dice observed so far

    # record the run configuration for reproducibility
    record_params(args)

    # seed every RNG the pipeline touches (torch CPU/GPU, numpy, python)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        # NOTE(review): benchmark=True together with deterministic=True is
        # contradictory (cudnn autotuning is non-deterministic) — confirm
        # which behavior is actually intended.
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2  # binary segmentation: background vs. kidney

    net = build_model(args.model_name, num_classes)

    # checkpoint naming / bookkeeping
    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    # wrap in DataParallel when several GPUs are visible
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train: deterministic pipeline only (resize + tensor + normalize)
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test: identical deterministic pipeline
    test_aug = train_aug

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    test_dataset = kidney_seg(root=test_root,
                              csv_file=test_csv,
                              maskidentity=args.maskidentity,
                              train=False,
                              transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        # NOTE(review): unknown loss only prints; criterion stays unbound and
        # the loop below would raise NameError — confirm intended behavior.
        print('Do not have this loss')

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler (only one of the two branches can fire)
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # -------- train pass --------
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        # dataset yields a 3-tuple; the middle element is unused here.
        # tqdm total is only cosmetic (floor of dataset/batch).
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            # loss is averaged per-batch, so re-weight by batch size before
            # dividing by the sample count below.
            train_loss += loss.item() * inputs.shape[0]
            # NOTE(review): divided by the sample count below — assumes
            # Dice_fn returns a sum over the batch, not a mean; confirm.
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # -------- test pass --------
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        # -------- re-score the training set in eval mode --------
        # (net is still in eval() from the test pass; this Dice drives
        # checkpoint selection below.)
        traineval_loss = 0.
        traineval_dice = 0.
        traineval_count = 0
        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(train_loader),
                total=int(len(train_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            traineval_count += inputs.shape[0]
            traineval_loss += loss.item() * inputs.shape[0]
            traineval_dice += Dice_fn(outputs, targets).item()

        traineval_loss_epoch = traineval_loss / float(traineval_count)
        traineval_dice_epoch = traineval_dice / float(traineval_count)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f '
            '| traineval_dice: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
             train_dice_epoch, test_dice_epoch, traineval_dice_epoch,
             time_cost))

        # step the LR schedule once per epoch
        if args.lr_policy != 'None':
            scheduler.step()

        # save a checkpoint whenever the eval-mode training Dice improves
        if traineval_dice_epoch > besttraindice:
            besttraindice = traineval_dice_epoch
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net
            if torch.cuda.device_count() > 1:
                # NOTE(review): unwraps DataParallel via children(); the
                # conventional accessor is net.module — equivalent here.
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            # yields '<model>_r<rep>_besttraindice.pkl' under args.checkpoint
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
# Example 2
# 0
def Train(train_root, train_csv, test_csv, tempmaskfolder):
    """Co-teaching training loop: two networks with pseudo-label refresh.

    Two networks (net1/net2), both initialized from ``args.resumefile``,
    supervise each other cross-wise: net1 learns from targets2 and net2
    from targets1. Each batch combines (a) a full segmentation loss on the
    small-loss ("clean") samples selected by the peer network, (b) a
    down-weighted segmentation loss on the remaining samples, and (c) a
    consistency/correction loss pulling predictions toward the peer's
    sharpened, augmentation-averaged pseudo-labels, ramped in by
    ``rate_schedule``. Periodically the worst-scoring training masks are
    regenerated by the networks and written into ``tempmaskfolder``.
    Checkpoints are saved once the eval-mode training Dice turns from
    descending to ascending and then improves on its best.

    Args:
        train_root: root directory of the training data; test samples are
            read from the same root.
        train_csv: CSV listing the training samples.
        test_csv: CSV listing the test samples.
        tempmaskfolder: subfolder (under train_root) that receives the
            regenerated pseudo-label masks.
    """
    makefolder(os.path.join(train_root, tempmaskfolder))

    besttraindice = 0.0    # best eval-mode train Dice once 'ascending'
    changepointdice = 0.0  # last Dice seen while the curve still descends
    ascending = False      # flips True at the first upturn of the Dice curve

    # parameters / reproducibility
    args = parse_args()
    record_params(args)
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        # NOTE(review): benchmark=True with deterministic=True is
        # contradictory — confirm which behavior is intended.
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = 2
    net1 = build_model(args.model1_name, num_classes)
    net2 = build_model(args.model2_name, num_classes)

    # both networks start from the same pretrained checkpoint
    params1_name = '{}_warmup{}_temp{}_r{}_net1.pkl'.format(
        args.model1_name, args.warmup_epoch, args.temperature, args.repetition)
    params2_name = '{}_warmup{}_temp{}_r{}_net2.pkl'.format(
        args.model2_name, args.warmup_epoch, args.temperature, args.repetition)

    checkpoint1_path = os.path.join(args.checkpoint, params1_name)
    checkpoint2_path = os.path.join(args.checkpoint, params2_name)
    initializecheckpoint = torch.load(args.resumefile)['net']
    net1.load_state_dict(initializecheckpoint)
    net2.load_state_dict(initializecheckpoint)

    start_epoch = 0
    end_epoch = args.num_epoch
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)

    # data: random rotation/flip for training, deterministic for test
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])

    train_dataset = kidney_seg(root=train_root,
                               csv_file=train_csv,
                               tempmaskfolder=tempmaskfolder,
                               maskidentity=args.maskidentity,
                               train=True,
                               transform=train_aug)
    # NOTE: test samples are read from train_root, not a separate test root
    test_dataset = kidney_seg(
        root=train_root,
        csv_file=test_csv,
        tempmaskfolder=tempmaskfolder,
        maskidentity=args.maskidentity,
        train=False,
        transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)
    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        # per-image loss variant: returns one loss per sample, which the
        # sort()-based small-loss selection below relies on
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')
    # element-wise MSE (reduction='none') so the weight maps can modulate it
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)

    # correction-loss ramp: (epoch/warmup)^2 capped at 1 (see loop below)
    rate_schedule = np.ones(args.num_epoch)
    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Kidney Seg')

    for epoch in range(start_epoch, end_epoch):

        ts = time.time()
        if args.warmup_epoch == 0:
            rate_schedule[epoch] = 1.0
        else:
            rate_schedule[epoch] = min(
                (float(epoch) / float(args.warmup_epoch))**2, 1.0)
        net1.train()
        net2.train()
        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.
        # each batch: (image, augmentation set, consensus target,
        #              annotator-1 target, annotator-2 target)
        for batch_idx, (inputs, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader), total=int(
                    len(train_loader.dataset) / args.batch_size)):  # (inputs, augset, targets, targets1, targets2)

            # ---- pseudo-label generation (no gradients, eval mode) ----
            net1.eval()
            net2.eval()
            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimg = augset['img{}'.format(aug_idx + 1)].to(device)
                augoutput1.append(net1(augimg).detach())
                augoutput2.append(net2(augimg).detach())
            # map each augmented-view prediction back to the original geometry
            augoutput1 = reverseaugbatch(augset,
                                         augoutput1,
                                         classno=num_classes)
            augoutput2 = reverseaugbatch(augset,
                                         augoutput2,
                                         classno=num_classes)

            # average the softmax maps over all augmented views
            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)

                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2
            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            # temperature sharpening of the averaged soft labels
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            # confidence map: 1 - 4*p0*p1 is 0 where p0=p1=0.5 (uncertain)
            # and 1 where the prediction is confident (binary classes)
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:,
                                                   0, :, :] * pseudo_label1[:,
                                                                            1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:,
                                                   0, :, :] * pseudo_label2[:,
                                                                            1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)
            # ---- supervised forward/backward (train mode) ----
            net1.train()
            net2.train()
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets1 = targets1.to(device)
            targets2 = targets2.to(device)
            outputs1 = net1(inputs)
            outputs2 = net2(inputs)

            # per-image losses used only to rank samples (small-loss trick)
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)

            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            # NOTE(review): hard-coded 2 "clean" samples per batch — assumes
            # batch_size > 2; confirm against args.batch_size.
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            # correction loss: pull net1 toward net2's pseudo-labels on the
            # high-loss samples, weighted by net2's confidence map
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor

            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor

            optimizer1.zero_grad()
            optimizer2.zero_grad()

            # NOTE(review): retain_graph=True looks unnecessary — loss2's
            # graph goes through outputs2/net2 only and the pseudo-labels are
            # detached; confirm before removing.
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inputs.shape[0]
            train_loss1 += loss1.item() * inputs.shape[0]
            # NOTE(review): divided by sample count below — assumes Dice_fn
            # returns a per-batch sum; confirm.
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inputs.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)

        # -------- test pass for both networks --------
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inputs, _, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader), total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                targets1 = targets1.to(device)
                targets2 = targets2.to(device)
                outputs1 = net1(inputs)
                outputs2 = net2(inputs)
                # criterion is per-image here, hence the explicit .mean()
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
            test_count += inputs.shape[0]
            test_loss1 += loss1.item() * inputs.shape[0]
            test_dice1 += Dice_fn(outputs1, targets2).item()
            test_loss2 += loss2.item() * inputs.shape[0]
            test_dice2 += Dice_fn(outputs2, targets1).item()
        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)

        # -------- per-sample eval-mode scoring of the training set --------
        # (predictions are also kept so they can be written out as new
        # pseudo-label masks below)
        traindices1 = torch.zeros(len(train_dataset))
        traindices2 = torch.zeros(len(train_dataset))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_dataset)),
                              total=len(train_dataset)):
            sample = train_dataset.__getitem__(casecount)
            img = sample[0]
            # net1 is scored against sample[4], net2 against sample[3]
            # (mirrors the cross supervision used during training)
            mask1 = sample[4]
            mask2 = sample[3]
            with torch.no_grad():
                img = torch.unsqueeze(img.to(device), 0)
                output1 = net1(img)
                output1 = F.softmax(output1, dim=1)
                output2 = net2(img)
                output2 = F.softmax(output2, dim=1)
            output1 = torch.argmax(output1, dim=1)
            output2 = torch.argmax(output2, dim=1)
            output1 = output1.squeeze().cpu()
            generatedoutput1 = output1.unsqueeze(dim=0).numpy()
            output2 = output2.squeeze().cpu()
            generatedoutput2 = output2.unsqueeze(dim=0).numpy()
            traindices1[casecount] = Dice2d(generatedoutput1, mask1.numpy())
            traindices2[casecount] = Dice2d(generatedoutput2, mask2.numpy())
            generatedmask1.append(generatedoutput1)
            generatedmask2.append(generatedoutput2)
        evaltrainavgdice1 = traindices1.sum() / float(len(train_dataset))
        evaltrainavgdice2 = traindices2.sum() / float(len(train_dataset))
        evaltrainavgdicetemp = (evaltrainavgdice1 + evaltrainavgdice2) / 2.0
        # mask filename lists per annotator identity (keys '1'..'3')
        maskannotations = {
            '1': train_dataset.mask1,
            '2': train_dataset.mask2,
            '3': train_dataset.mask3
        }

        # -------- pseudo-label update --------
        # every epoch during warmup, then every 10th epoch; the worst
        # update_percent fraction of samples gets regenerated masks
        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            avgdice = evaltrainavgdicetemp
            selected_samples = int(args.update_percent * len(train_dataset))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traindices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net1.nii.gz')
                save_data = generatedmask1[selectedidx]
                # skip empty predictions — never overwrite with an all-zero mask
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            logging.info('{} masks modified for net1'.format(
                len(selectedidxs)))

            _, sortidx2 = traindices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                maskname = maskannotations['{}'.format(int(
                    args.maskidentity))][selectedidx]
                savefolder = os.path.join(save_root, maskname.split('/')[-2])
                makefolder(savefolder)
                save_name = os.path.join(
                    savefolder,
                    maskname.split('/')[-1].split('.')[0] + '_net2.nii.gz')
                save_data = generatedmask2[selectedidx]
                if save_data.sum() > 0:
                    soutput = sitk.GetImageFromArray(save_data)
                    sitk.WriteImage(soutput, save_name)
            # NOTE(review): wording differs from the net1 branch
            # ('modify' vs 'modified') — runtime string left as-is.
            logging.info('{} masks modify for net2'.format(len(selectedidxs)))

        # changepoint detection: once the Dice curve turns upward for the
        # first time, start checkpointing from that trough value
        if epoch > 0 and changepointdice < evaltrainavgdicetemp and ascending == False:
            ascending = True
            besttraindice = changepointdice

        if evaltrainavgdicetemp > besttraindice and ascending:
            besttraindice = evaltrainavgdicetemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))
            save_model = net1
            if torch.cuda.device_count() > 1:
                # NOTE(review): net.module is the conventional DataParallel
                # accessor; children()[0] is equivalent here.
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state, '{}_besttraindice.pkl'.format(
                    checkpoint1_path.split('.pkl')[0]))

            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            torch.save(
                state, '{}_besttraindice.pkl'.format(
                    checkpoint2_path.split('.pkl')[0]))

        # while still descending, keep tracking the latest Dice as the
        # candidate trough
        if not ascending:
            changepointdice = evaltrainavgdicetemp

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | '
            'train_dice1: %.3f | test_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | '
            'train_dice2: %.3f | test_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: evaltrain_dice1: %.3f | evaltrain_dice2: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, evaltrainavgdice1, evaltrainavgdice2,
               time_cost))

        # restore train mode before the next epoch's first forward pass
        net1.train()
        net2.train()
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_root, test_csv, traincase_csv,
          testcase_csv):

    # parameters
    args = parse_args()
    besttraincasedice = 0.0

    train_cases = pd.read_csv(traincase_csv)['Image'].tolist()
    test_cases = pd.read_csv(testcase_csv)['Image'].tolist()
    # record
    record_params(args)

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2

    net = build_model(args.model_name, num_classes)

    params_name = '{}_r{}.pkl'.format(args.model_name, args.repetition)
    start_epoch = 0
    history = {
        'train_loss': [],
        'test_loss': [],
        'train_dice': [],
        'test_dice': []
    }
    end_epoch = start_epoch + args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    ## test
    test_aug = train_aug

    train_dataset = prostate_seg(root=train_root,
                                 csv_file=train_csv,
                                 transform=train_aug)
    test_dataset = prostate_seg(root=test_root,
                                csv_file=test_csv,
                                transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLoss(cediceweight=cedice_weight,
                                ceclassweight=ceclass_weight,
                                diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Prostate Seg')

    for epoch in range(start_epoch, end_epoch):
        ts = time.time()

        # train
        net.train()
        train_loss = 0.
        train_dice = 0.
        train_count = 0
        for batch_idx, (inputs, _, targets) in \
                tqdm(enumerate(train_loader),total=int(len(train_loader.dataset) / args.batch_size)):
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_count += inputs.shape[0]
            train_loss += loss.item() * inputs.shape[0]
            train_dice += Dice_fn(outputs, targets).item()

        train_loss_epoch = train_loss / float(train_count)
        train_dice_epoch = train_dice / float(train_count)
        history['train_loss'].append(train_loss_epoch)
        history['train_dice'].append(train_dice_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_dice = 0.
        test_count = 0

        for batch_idx, (inputs, _, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
            test_count += inputs.shape[0]
            test_loss += loss.item() * inputs.shape[0]
            test_dice += Dice_fn(outputs, targets).item()

        test_loss_epoch = test_loss / float(test_count)
        test_dice_epoch = test_dice / float(test_count)
        history['test_loss'].append(test_loss_epoch)
        history['test_dice'].append(test_dice_epoch)

        testcasedices = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount].split('.')[0]
            caseimg = [file for file in test_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [file for file in test_dataset.masks if caseidx in file]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output = net(input)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            testcasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                 target)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseimg)
        testcasedice = testcasedices.sum() / float(len(test_cases))

        traincasedices = torch.zeros(len(train_cases))
        startimgslices = torch.zeros(len(train_cases))
        generatedmask = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseimg = [file for file in train_dataset.imgs if caseidx in file]
            caseimg.sort()
            casemask = [
                file for file in train_dataset.masks if caseidx in file
            ]
            casemask.sort()
            generatedtarget = []
            target = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseimg)):
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                input = sample[0]
                mask = sample[2]
                target.append(mask)
                with torch.no_grad():
                    input = torch.unsqueeze(input.to(device), 0)
                    output = net(input)
                    output = F.softmax(output, dim=1)
                    output = torch.argmax(output, dim=1)
                    output = output.squeeze().cpu().numpy()
                    generatedtarget.append(output)
            target = np.stack(target, axis=-1)
            generatedtarget = np.stack(generatedtarget, axis=-1)
            generatedtarget_keeplargest = keep_largest_connected_components(
                generatedtarget)
            traincasedices[casecount] = Dice3d_fn(generatedtarget_keeplargest,
                                                  target)
            generatedmask.append(generatedtarget_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseimg)
        traincasedice = traincasedices.sum() / float(len(train_cases))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | test_loss: %.3f | train_dice: %.3f | test_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, test_loss_epoch,
               train_dice_epoch, test_dice_epoch, time_cost))
        logging.info(
            'epoch[%d/%d]: traincase_dice: %.3f | testcase_dice: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, traincasedice, testcasedice, time_cost))

        if args.lr_policy != 'None':
            scheduler.step()

        if traincasedice > besttraincasedice:
            besttraincasedice = traincasedice
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'dice': test_dice_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params_name.split('.pkl')[0] + '_besttraindice.' +
                params_name.split('.')[-1])
            torch.save(state, savecheckname)
# Per-sub-module optimizer/scheduler setup and outer training loop.
# NOTE(review): this span appears to be splice residue -- the statements below
# occur twice, once at top level (auto-reformatted) and once as an indented
# fragment of an enclosing function whose header is not visible here.  The
# join is syntactically broken (see the note at the bare `if` below); confirm
# which copy is the intended one before editing further.
model_name = ['convnet', 'classifier', 'pan', 'mask_classifier']
# One independent SGD optimizer per sub-module so each can be stepped on its own.
optimizer = {
    'convnet':
    optim.SGD(convnet.parameters(), lr=args.lr, weight_decay=1e-4),
    'classifier':
    optim.SGD(classifier.parameters(), lr=args.lr, weight_decay=1e-4),
    'pan':
    optim.SGD(pan.parameters(), lr=args.lr, weight_decay=1e-4),
    'mask_classifier':
    optim.SGD(mask_classifier.parameters(), lr=args.lr, weight_decay=1e-4)
}

# Polynomial LR decay (power 0.9) over args.epochs, one scheduler per optimizer.
optimizer_lr_scheduler = {
    'convnet':
    PolyLR(optimizer['convnet'], max_iter=args.epochs, power=0.9),
    'classifier':
    PolyLR(optimizer['classifier'], max_iter=args.epochs, power=0.9),
    'pan':
    PolyLR(optimizer['pan'], max_iter=args.epochs, power=0.9),
    'mask_classifier':
    PolyLR(optimizer['mask_classifier'], max_iter=args.epochs, power=0.9)
}

best_acc = 0
for epoch in range(args.epochs):
    # step every sub-module's scheduler before training this epoch
    for m in model_name:
        optimizer_lr_scheduler[m].step(epoch)
    logging.info('Epoch:{:}'.format(epoch))
    train(epoch, optimizer, training_loader)
    # NOTE(review): this `if` has no indented body -- the lines that follow are
    # an indented fragment of a *different* (enclosing) function, so as written
    # this is an IndentationError.  Judging by the duplicate loop at the bottom
    # of this span, the intended body was presumably `test(test_loader)` --
    # confirm against the original source.
    if epoch % 1 == 0:
    #classifier = Classifier(in_features=2048, num_class=NUM_CLASS)
    pan = PAN(convnet.blocks[::-1])
    mask_classifier = Mask_Classifier(in_features=256, num_class=(NUM_CLASS+1))

    convnet.to(device)
    #classifier.to(device)
    pan.to(device)
    mask_classifier.to(device)
    
    # NOTE(review): 'classifier' entries are deliberately commented out in this
    # copy -- only convnet/pan/mask_classifier are trained here.
    #model_name = ['convnet', 'classifier', 'pan', 'mask_classifier']
    model_name = ['convnet', 'pan', 'mask_classifier']
    optimizer = {'convnet': optim.SGD(convnet.parameters(), lr=args.lr, weight_decay=1e-4),
             #'classifier': optim.SGD(classifier.parameters(), lr=args.lr, weight_decay=1e-4),
             'pan': optim.SGD(pan.parameters(), lr=args.lr, weight_decay=1e-4),
             'mask_classifier': optim.SGD(mask_classifier.parameters(), lr=args.lr, weight_decay=1e-4)}

    optimizer_lr_scheduler = {'convnet': PolyLR(optimizer['convnet'], max_iter=args.epochs, power=0.9),
                          #'classifier': PolyLR(optimizer['classifier'], max_iter=args.epochs, power=0.9),
                          'pan': PolyLR(optimizer['pan'], max_iter=args.epochs, power=0.9),
                          'mask_classifier': PolyLR(optimizer['mask_classifier'], max_iter=args.epochs, power=0.9)}
    
    best_acc = 0
    for epoch in range(args.epochs):
        for m in model_name:
            optimizer_lr_scheduler[m].step(epoch)
        logging.info('Epoch:{:}'.format(epoch))
        train(epoch, optimizer, training_loader)
        # evaluate every 5 epochs
        if epoch % 5 == 0:
            test(test_loader)
        
def main():
    """Build the model for ``args.mode``, optionally resume from a checkpoint,
    then run the corresponding training loop.

    Modes:
        'baseline_train' -- ResNet classifier (200 classes) trained from scratch.
        'pretrain'       -- DeepLabV3-ResNet50 trained with an MSE objective
                            (colorization-style pretraining), iteration-driven.
        'finetune'       -- ResNet classifier (3 classes), optionally initialized
                            from ``args.pretrained_model``.

    Reads/writes the module-level globals ``args``, ``best_prec1`` and
    ``cur_itrs``.  Saves checkpoints as ``<mode>_<dataset>.pth`` and the metric
    history as a sibling ``.npz`` file.
    """
    global args, best_prec1
    global cur_itrs
    args = parser.parse_args()
    print(args.mode)

    # STEP1: model
    if args.mode == 'baseline_train':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=200)
    elif args.mode == 'pretrain':
        model = deeplab_network.deeplabv3_resnet50(
            num_classes=args.num_classes,
            output_stride=args.output_stride,
            pretrained_backbone=False)
        set_bn_momentum(model.backbone, momentum=0.01)
    elif args.mode == 'finetune':
        model = initialize_model(use_resnet=True, pretrained=False, nclasses=3)
        # Load the pretrained *model* weights only.
        # BUGFIX: the original also called ``optimizer.load_state_dict`` here,
        # but the optimizer is only created in STEP2 below, so that line raised
        # NameError -- and the restored state would have been discarded by the
        # freshly-constructed optimizer anyway.
        if args.pretrained_model:
            if os.path.isfile(args.pretrained_model):
                print("=> loading pretrained model '{}'".format(args.pretrained_model))
                checkpoint = torch.load(args.pretrained_model)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                # BUGFIX: report the file actually loaded (was args.resume).
                print("=> loaded pretrained model '{}' (epoch {})".format(
                    args.pretrained_model, checkpoint['epoch']))
            else:
                # previously this case was silently ignored
                print("=> no pretrained model found at '{}'".format(args.pretrained_model))
    # BUGFIX: ``torch.cuda.is_available`` must be *called* -- the bare function
    # object is always truthy, so the original moved the model to CUDA even on
    # CPU-only machines.
    if torch.cuda.is_available():
        model = model.cuda()

    # STEP2: criterion and optimizer
    if args.mode in ['baseline_train', 'finetune']:
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)
    elif args.mode == 'pretrain':
        criterion = nn.MSELoss()
        # backbone gets a 10x smaller learning rate than the classifier head
        optimizer = torch.optim.SGD(params=[
            {'params': model.backbone.parameters(), 'lr': 0.1 * args.lr},
            {'params': model.classifier.parameters(), 'lr': args.lr},
        ], lr=args.lr, momentum=0.9, weight_decay=args.weight_decay)
        scheduler = PolyLR(optimizer, args.total_itrs, power=0.9)

    # STEP3: loss/prec record
    if args.mode in ['baseline_train', 'finetune']:
        train_losses = []
        train_top1s = []
        train_top5s = []

        test_losses = []
        test_top1s = []
        test_top5s = []
    elif args.mode == 'pretrain':
        train_losses = []
        test_losses = []

    # STEP4: optionally resume from a checkpoint (model + optimizer state plus
    # the metric history stored next to the .pth file as a .npz).
    if args.resume:
        print('resume')
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.mode in ['baseline_train', 'finetune']:
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
                train_top1s = list(load_data['train_top1s'])
                train_top5s = list(load_data['train_top5s'])
                test_losses = list(load_data['test_losses'])
                test_top1s = list(load_data['test_top1s'])
                test_top5s = list(load_data['test_top5s'])
            elif args.mode == 'pretrain':
                checkpoint = torch.load(args.resume)
                args.start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                scheduler.load_state_dict(checkpoint['scheduler'])
                cur_itrs = checkpoint['cur_itrs']
                datafile = args.resume.split('.pth')[0] + '.npz'
                load_data = np.load(datafile)
                train_losses = list(load_data['train_losses'])
            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # STEP5: train!
    if args.mode in ['baseline_train', 'finetune']:
        # data
        from utils import TinyImageNet_data_loader
        # BUGFIX: ``color_distortion`` is an attribute of ``args``; the bare
        # name was undefined at module scope and raised NameError.
        print('color_distortion:', args.color_distortion)
        train_loader, val_loader = TinyImageNet_data_loader(
            args.dataset, args.batch_size,
            color_distortion=args.color_distortion)

        # evaluation-only mode: run one validation pass and exit
        if args.evaluate:
            print('evaluate this model on validation dataset')
            validate(val_loader, model, criterion, args.print_freq)
            return

        for epoch in range(args.start_epoch, args.epochs):
            adjust_learning_rate(optimizer, epoch, args.lr)
            time1 = time.time()  # timekeeping

            # train for one epoch
            model.train()
            loss, top1, top5 = train(train_loader, model, criterion, optimizer, epoch, args.print_freq)
            train_losses.append(loss)
            train_top1s.append(top1)
            train_top5s.append(top5)

            # evaluate on validation set
            model.eval()
            loss, prec1, prec5 = validate(val_loader, model, criterion, args.print_freq)
            test_losses.append(loss)
            test_top1s.append(prec1)
            test_top5s.append(prec5)

            # remember the best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict()
            }, is_best, args.mode + '_' + args.dataset + '.pth')

            np.savez(args.mode + '_' + args.dataset + '.npz',
                     train_losses=train_losses, train_top1s=train_top1s,
                     train_top5s=train_top5s, test_losses=test_losses,
                     test_top1s=test_top1s, test_top5s=test_top5s)
            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:', (time2 - time1) * (args.epochs - epoch - 1) / 60, 'minutes')
            print()
    elif args.mode == 'pretrain':
        # data
        from utils import TinyImageNet_data_loader
        args.batch_size = 16  # pretraining uses a fixed, smaller batch size
        train_loader, val_loader = TinyImageNet_data_loader(args.dataset, args.batch_size, col=True)

        # evaluation-only mode: visualize predictions and exit
        if args.evaluate:
            print('evaluate this model on validation dataset')
            visulization(val_loader, model, args.start_epoch)
            return

        # Iteration-driven loop: stop once ``cur_itrs`` reaches total_itrs.
        # NOTE(review): the actual training call is commented out below, so
        # ``cur_itrs`` never advances and this loop only re-saves checkpoints
        # forever -- confirm whether the train() call should be re-enabled.
        epoch = 0
        while True:
            if cur_itrs >= args.total_itrs:
                return
            time1 = time.time()  # timekeeping

            model.train()
            # train for one epoch
            # loss, _, _ = train(train_loader, model, criterion, optimizer, epoch, args.print_freq, colorization=True, scheduler=scheduler)
            # train_losses.append(loss)

            # model.eval()
            # # evaluate on validation set
            # loss, _, _ = validate(val_loader, model, criterion, args.print_freq, colorization=True)
            # test_losses.append(loss)

            save_checkpoint({
                'epoch': epoch + 1,
                'mode': args.mode,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                "cur_itrs": cur_itrs
            }, True, args.mode + '_' + args.dataset + '.pth')

            np.savez(args.mode + '_' + args.dataset + '.npz', train_losses=train_losses)
            time2 = time.time()  # timekeeping
            print('Elapsed time for epoch:', time2 - time1, 's')
            print('ETA of completion:', (time2 - time1) * (args.total_itrs - cur_itrs - 1) / 60, 'minutes')
            print()
            epoch += 1
def Train(train_root, train_csv, test_csv, traincase_csv, testcase_csv,
          labelcase_csv, tempmaskfolder):
    makefolder(os.path.join(train_root, tempmaskfolder))

    # parameters
    args = parse_args()

    # record
    record_params(args)

    train_cases = pd.read_csv(traincase_csv)['patient_case'].tolist()
    test_cases = pd.read_csv(testcase_csv)['patient_case'].tolist()
    label_cases = pd.read_csv(labelcase_csv)['patient_case'].tolist()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 2

    net1 = build_model(args.model_name, num_classes)
    net2 = build_model(args.model_name, num_classes)

    params1_name = '{}_temp{}_r{}_net1.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    params2_name = '{}_temp{}_r{}_net2.pkl'.format(args.model_name,
                                                   args.temperature,
                                                   args.repetition)
    start_epoch = 0
    end_epoch = args.num_epoch

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net1 = nn.DataParallel(net1)
        net2 = nn.DataParallel(net2)
    net1.to(device)
    net2.to(device)

    # data
    train_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        RandomRotate(args.rotation),
        RandomHorizontallyFlip(),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])
    test_aug = Compose([
        Resize(size=(args.img_size, args.img_size)),
        ToTensor(),
        Normalize(mean=args.data_mean, std=args.data_std)
    ])

    train_dataset = chaos_seg(root=train_root,
                              csv_file=train_csv,
                              tempmaskfolder=tempmaskfolder,
                              transform=train_aug)
    test_dataset = chaos_seg(root=train_root,
                             csv_file=test_csv,
                             tempmaskfolder=tempmaskfolder,
                             transform=test_aug)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batch_size,
                              num_workers=4,
                              shuffle=True,
                              drop_last=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=False)

    # loss function, optimizer and scheduler
    cedice_weight = torch.tensor(args.cedice_weight)
    ceclass_weight = torch.tensor(args.ceclass_weight)
    diceclass_weight = torch.tensor(args.diceclass_weight)

    if args.loss == 'ce':
        criterion = CrossEntropyLoss2d(weight=ceclass_weight).to(device)
    elif args.loss == 'dice':
        criterion = MulticlassDiceLoss(weight=diceclass_weight).to(device)
    elif args.loss == 'cedice':
        criterion = CEMDiceLossImage(
            cediceweight=cedice_weight,
            ceclassweight=ceclass_weight,
            diceclassweight=diceclass_weight).to(device)
    else:
        print('Do not have this loss')
    corrlosscriterion = MulticlassMSELoss(reduction='none').to(device)

    # define augmentation loss effect schedule
    rate_schedule = np.ones(args.num_epoch)

    optimizer1 = Adam(net1.parameters(), lr=args.lr, amsgrad=True)
    optimizer2 = Adam(net2.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    if args.lr_policy == 'StepLR':
        scheduler1 = StepLR(optimizer1, step_size=30, gamma=0.5)
        scheduler2 = StepLR(optimizer2, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler1 = PolyLR(optimizer1, max_epoch=end_epoch, power=0.9)
        scheduler2 = PolyLR(optimizer2, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For CHAOS Seg')
    besttraincasedice = 0.0
    for epoch in range(start_epoch, end_epoch):

        ts = time.time()
        rate_schedule[epoch] = min(
            (float(epoch) / float(args.warmup_epoch))**2, 1.0)

        # train
        net1.train()
        net2.train()

        train_loss1 = 0.
        train_dice1 = 0.
        train_count = 0
        train_loss2 = 0.
        train_dice2 = 0.

        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(train_loader), total=int(len(train_loader.dataset) / args.batch_size)):

            augoutput1 = []
            augoutput2 = []
            for aug_idx in range(augset['augno'][0]):
                augimgin = augset['imgmodal1{}'.format(aug_idx + 1)].to(device)
                augimgout = augset['imgmodal2{}'.format(aug_idx +
                                                        1)].to(device)
                augoutput1.append(net1(augimgin, augimgout).detach())
                augoutput2.append(net2(augimgin, augimgout).detach())

            augoutput1 = reverseaug(augset, augoutput1, classno=num_classes)
            augoutput2 = reverseaug(augset, augoutput2, classno=num_classes)

            for aug_idx in range(augset['augno'][0]):
                augmask1 = torch.nn.functional.softmax(augoutput1[aug_idx],
                                                       dim=1)
                augmask2 = torch.nn.functional.softmax(augoutput2[aug_idx],
                                                       dim=1)

                if aug_idx == 0:
                    pseudo_label1 = augmask1
                    pseudo_label2 = augmask2
                else:
                    pseudo_label1 += augmask1
                    pseudo_label2 += augmask2

            pseudo_label1 = pseudo_label1 / float(augset['augno'][0])
            pseudo_label2 = pseudo_label2 / float(augset['augno'][0])
            pseudo_label1 = sharpen(pseudo_label1, args.temperature)
            pseudo_label2 = sharpen(pseudo_label2, args.temperature)
            weightmap1 = 1.0 - 4.0 * pseudo_label1[:,
                                                   0, :, :] * pseudo_label1[:,
                                                                            1, :, :]
            weightmap1 = weightmap1.unsqueeze(dim=1)
            weightmap2 = 1.0 - 4.0 * pseudo_label2[:,
                                                   0, :, :] * pseudo_label2[:,
                                                                            1, :, :]
            weightmap2 = weightmap2.unsqueeze(dim=1)

            inphase = inphase.to(device)
            outphase = outphase.to(device)
            targets1 = targets1[:, 1, :, :].to(device)
            targets2 = targets2[:, 1, :, :].to(device)
            optimizer1.zero_grad()
            optimizer2.zero_grad()

            outputs1 = net1(inphase, outphase)
            outputs2 = net2(inphase, outphase)
            loss1_segpre = criterion(outputs1, targets2)
            loss2_segpre = criterion(outputs2, targets1)
            _, indx1 = loss1_segpre.sort()
            _, indx2 = loss2_segpre.sort()
            loss1_seg1 = criterion(outputs1[indx2[0:2], :, :, :],
                                   targets2[indx2[0:2], :, :]).mean()
            loss2_seg1 = criterion(outputs2[indx1[0:2], :, :, :],
                                   targets1[indx1[0:2], :, :]).mean()
            loss1_seg2 = criterion(outputs1[indx2[2:], :, :, :],
                                   targets2[indx2[2:], :, :]).mean()
            loss2_seg2 = criterion(outputs2[indx1[2:], :, :, :],
                                   targets1[indx1[2:], :, :]).mean()
            loss1_cor = weightmap2[indx2[2:], :, :, :] * corrlosscriterion(
                outputs1[indx2[2:], :, :, :],
                pseudo_label2[indx2[2:], :, :, :])
            loss1_cor = loss1_cor.mean()
            loss1 = args.segcor_weight[0] * (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor

            loss2_cor = weightmap1[indx1[2:], :, :, :] * corrlosscriterion(
                outputs2[indx1[2:], :, :, :],
                pseudo_label1[indx1[2:], :, :, :])
            loss2_cor = loss2_cor.mean()
            loss2 = args.segcor_weight[0] * (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2) + \
                    args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor
            loss1.backward(retain_graph=True)
            optimizer1.step()
            loss2.backward()
            optimizer2.step()
            train_count += inphase.shape[0]
            train_loss1 += loss1.item() * inphase.shape[0]
            train_dice1 += Dice_fn(outputs1, targets2).item()
            train_loss2 += loss2.item() * inphase.shape[0]
            train_dice2 += Dice_fn(outputs2, targets1).item()
        train_loss1_epoch = train_loss1 / float(train_count)
        train_dice1_epoch = train_dice1 / float(train_count)
        train_loss2_epoch = train_loss2 / float(train_count)
        train_dice2_epoch = train_dice2 / float(train_count)

        print(rate_schedule[epoch])
        print(args.segcor_weight[0] *
              (loss1_seg1 + (1.0 - rate_schedule[epoch]) * loss1_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss1_cor)

        print(args.segcor_weight[0] *
              (loss2_seg1 + (1.0 - rate_schedule[epoch]) * loss2_seg2))
        print(args.segcor_weight[1] * rate_schedule[epoch] * loss2_cor)

        # test
        net1.eval()
        net2.eval()
        test_loss1 = 0.
        test_dice1 = 0.
        test_loss2 = 0.
        test_dice2 = 0.
        test_count = 0
        for batch_idx, (inphase, outphase, augset, targets, targets1, targets2) in \
                tqdm(enumerate(test_loader), total=int(len(test_loader.dataset) / args.batch_size)):
            with torch.no_grad():
                inphase = inphase.to(device)
                outphase = outphase.to(device)
                targets1 = targets1[:, 1, :, :].to(device)
                targets2 = targets2[:, 1, :, :].to(device)
                outputs1 = net1(inphase, outphase)
                outputs2 = net2(inphase, outphase)
                loss1 = criterion(outputs1, targets2).mean()
                loss2 = criterion(outputs2, targets1).mean()
            test_count += inphase.shape[0]
            test_loss1 += loss1.item() * inphase.shape[0]
            test_dice1 += Dice_fn(outputs1, targets2).item()
            test_loss2 += loss2.item() * inphase.shape[0]
            test_dice2 += Dice_fn(outputs2, targets1).item()

        test_loss1_epoch = test_loss1 / float(test_count)
        test_dice1_epoch = test_dice1 / float(test_count)
        test_loss2_epoch = test_loss2 / float(test_count)
        test_dice2_epoch = test_dice2 / float(test_count)

        testcasedices1 = torch.zeros(len(test_cases))
        testcasedices2 = torch.zeros(len(test_cases))
        startimgslices = torch.zeros(len(test_cases))
        for casecount in tqdm(range(len(test_cases)), total=len(test_cases)):
            caseidx = test_cases[casecount]
            caseinphaseimg = [
                file for file in test_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in test_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            casemask = [
                file for file in test_dataset.masks
                if int(file.split('/')[0]) == caseidx
            ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                       casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                       caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                       int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = test_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)
                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)
            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            testcasedices1[casecount] = Dice3d_fn(generatedtarget1_keeplargest,
                                                  target1)
            testcasedices2[casecount] = Dice3d_fn(generatedtarget2_keeplargest,
                                                  target2)
            if casecount + 1 < len(test_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)
        testcasedice1 = testcasedices1.sum() / float(len(test_cases))
        testcasedice2 = testcasedices2.sum() / float(len(test_cases))

        traincasedices1 = torch.zeros(len(train_cases))
        traincasedices2 = torch.zeros(len(train_cases))
        # update pseudolabel
        startimgslices = torch.zeros(len(train_cases))
        generatedmask1 = []
        generatedmask2 = []
        for casecount in tqdm(range(len(train_cases)), total=len(train_cases)):
            caseidx = train_cases[casecount]
            caseinphaseimg = [
                file for file in train_dataset.t1inphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseinphaseimg.sort()
            caseoutphaseimg = [
                file for file in train_dataset.t1outphase
                if int(file.split('/')[0]) == caseidx
            ]
            caseoutphaseimg.sort()
            if caseidx in label_cases:
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[0].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[0]) == caseidx
                ]
            else:
                casemask = [
                    file for file in train_dataset.masks
                    if file.split('/')[-2].isdigit()
                ]
                casemask = [
                    file for file in casemask
                    if int(file.split('/')[-2]) == caseidx
                ]
            casemask.sort()
            generatedtarget1 = []
            generatedtarget2 = []
            target1 = []
            target2 = []
            startcaseimg = int(torch.sum(startimgslices[:casecount + 1]))
            for imgidx in range(len(caseinphaseimg)):
                assert caseinphaseimg[imgidx].split('/')[-1].split('.')[0] == \
                       casemask[imgidx].split('/')[-1].split('.')[0]
                assert caseinphaseimg[imgidx].split('/')[-1].split('-')[1] == \
                       caseoutphaseimg[imgidx].split('/')[-1].split('-')[1]
                assert int(caseinphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) == \
                       int(caseoutphaseimg[imgidx].split('/')[-1].split('-')[-1].split('.')[0]) + 1
                sample = train_dataset.__getitem__(imgidx + startcaseimg)
                inphase = sample[0]
                outphase = sample[1]
                mask1 = sample[3]
                mask2 = sample[4]
                target1.append(mask1[1, :, :])
                target2.append(mask2[1, :, :])
                with torch.no_grad():
                    inphase = torch.unsqueeze(inphase.to(device), 0)
                    outphase = torch.unsqueeze(outphase.to(device), 0)
                    output1 = net1(inphase, outphase)
                    output1 = F.softmax(output1, dim=1)
                    output1 = torch.argmax(output1, dim=1)
                    output1 = output1.squeeze().cpu().numpy()
                    generatedtarget1.append(output1)

                    output2 = net2(inphase, outphase)
                    output2 = F.softmax(output2, dim=1)
                    output2 = torch.argmax(output2, dim=1)
                    output2 = output2.squeeze().cpu().numpy()
                    generatedtarget2.append(output2)

            target1 = np.stack(target1, axis=-1)
            target2 = np.stack(target2, axis=-1)
            generatedtarget1 = np.stack(generatedtarget1, axis=-1)
            generatedtarget2 = np.stack(generatedtarget2, axis=-1)
            generatedtarget1_keeplargest = keep_largest_connected_components(
                generatedtarget1)
            generatedtarget2_keeplargest = keep_largest_connected_components(
                generatedtarget2)
            traincasedices1[casecount] = Dice3d_fn(
                generatedtarget1_keeplargest, target1)
            traincasedices2[casecount] = Dice3d_fn(
                generatedtarget2_keeplargest, target2)
            generatedmask1.append(generatedtarget1_keeplargest)
            generatedmask2.append(generatedtarget2_keeplargest)
            if casecount + 1 < len(train_cases):
                startimgslices[casecount + 1] = len(caseinphaseimg)

        traincasedice1 = traincasedices1.sum() / float(len(train_cases))
        traincasedice2 = traincasedices2.sum() / float(len(train_cases))

        traincasediceavgtemp = (traincasedice1 + traincasedice2) / 2.0

        if traincasediceavgtemp > besttraincasedice:
            besttraincasedice = traincasediceavgtemp
            logging.info('Best Checkpoint {} Saving...'.format(epoch + 1))

            save_model = net1
            if torch.cuda.device_count() > 1:
                save_model = list(net1.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss1_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params1_name.split('.pkl')[0] + '_besttraincasedice.' +
                params1_name.split('.')[-1])
            torch.save(state, savecheckname)

            save_model = net2
            if torch.cuda.device_count() > 1:
                save_model = list(net2.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss2_epoch,
                'epoch': epoch + 1,
            }
            savecheckname = os.path.join(
                args.checkpoint,
                params2_name.split('.pkl')[0] + '_besttraincasedicde.' +
                params2_name.split('.')[-1])
            torch.save(state, savecheckname)

        if (epoch + 1) <= args.warmup_epoch or (epoch + 1) % 10 == 0:
            selected_samples = int(0.25 * len(train_cases))
            save_root = os.path.join(train_root, tempmaskfolder)
            _, sortidx1 = traincasedices1.sort()
            selectedidxs = sortidx1[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net1.png')
                        save_data = generatedmask1[selectedidx][:, :, imgidx]
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net1'.format(
                [train_cases[i] for i in selectedidxs]))

            _, sortidx2 = traincasedices2.sort()
            selectedidxs = sortidx2[:selected_samples]
            for selectedidx in selectedidxs:
                caseidx = train_cases[selectedidx]
                if caseidx not in label_cases:
                    caseinphaseimg = [
                        file for file in train_dataset.t1inphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseinphaseimg.sort()
                    caseoutphaseimg = [
                        file for file in train_dataset.t1outphase
                        if int(file.split('/')[0]) == caseidx
                    ]
                    caseoutphaseimg.sort()
                    casemask = [
                        file for file in train_dataset.masks
                        if file.split('/')[-2].isdigit()
                    ]
                    casemask = [
                        file for file in casemask
                        if int(file.split('/')[-2]) == caseidx
                    ]
                    casemask.sort()
                    for imgidx in range(len(caseinphaseimg)):
                        save_folder = os.path.join(save_root, str(caseidx))
                        makefolder(save_folder)
                        save_name = os.path.join(
                            save_folder,
                            casemask[imgidx].split('/')[-1].split('.')[0] +
                            '_net2.png')
                        save_data = generatedmask2[selectedidx][:, :, imgidx]
                        output_pil = save_data * 63
                        output_pil = Image.fromarray(
                            output_pil.astype(np.uint8), 'L')
                        output_pil.save(save_name)
            logging.info('Mask {} modify for net2'.format(
                [train_cases[i] for i in selectedidxs]))

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss1: %.3f | test_loss1: %.3f | train_dice1: %.3f | test_dice1: %.3f || '
            'traincase_dice1: %.3f || testcase_dice1: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss1_epoch, test_loss1_epoch,
             train_dice1_epoch, test_dice1_epoch, traincasedice1,
             testcasedice1, time_cost))
        logging.info(
            'epoch[%d/%d]: train_loss2: %.3f | test_loss2: %.3f | train_dice2: %.3f | test_dice2: %.3f || '
            'traincase_dice2: %.3f || testcase_dice2: %.3f || time: %.1f' %
            (epoch + 1, end_epoch, train_loss2_epoch, test_loss2_epoch,
             train_dice2_epoch, test_dice2_epoch, traincasedice2,
             testcasedice2, time_cost))
        if args.lr_policy != 'None':
            scheduler1.step()
            scheduler2.step()
def Train(train_root, train_csv, test_csv):
    """Train and evaluate a 4-class breast-density classifier.

    Builds the model from command-line args, optionally resumes from the
    best-loss checkpoint, then runs a train/test loop for ``args.num_epoch``
    epochs, maintaining a best-test-loss and a best-test-accuracy checkpoint.

    NOTE(review): this shadows an earlier ``Train(train_root, train_csv,
    test_root, test_csv)`` defined above in the same file — only this later
    definition is reachable after import; confirm that is intended.

    Args:
        train_root: root directory holding the images; used for BOTH the
            train and the test dataset below.
        train_csv: CSV listing the training samples.
        test_csv: CSV listing the test samples.
    """
    # parameters
    args = parse_args()
    record_params(args)

    # Reproducibility: seed every RNG the pipeline touches.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_order
    torch.manual_seed(args.torch_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.torch_seed)
    np.random.seed(args.torch_seed)
    random.seed(args.torch_seed)

    # NOTE(review): benchmark=True together with deterministic=True is
    # contradictory (benchmark autotunes algorithms non-deterministically);
    # confirm which behavior is actually wanted when args.cudnn != 0.
    if args.cudnn == 0:
        cudnn.benchmark = False
    else:
        cudnn.benchmark = True
        cudnn.deterministic = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    num_classes = 4  # presumably the four breast-density categories — confirm
    net = build_model(args.model_name, num_classes, args.pretrain)

    # resume: two checkpoint paths derived from args.params_name — one
    # tracks the best test loss, the other the best test accuracy.
    checkpoint_name_loss = os.path.join(
        args.checkpoint,
        args.params_name.split('.')[0] + '_loss.' +
        args.params_name.split('.')[-1])
    checkpoint_name_acc = os.path.join(
        args.checkpoint,
        args.params_name.split('.')[0] + '_acc.' +
        args.params_name.split('.')[-1])
    if args.resume != 0:
        # Restores weights, history and best-so-far metrics from the
        # best-loss checkpoint; optimizer state is NOT restored.
        logging.info('Resuming from checkpoint...')
        checkpoint = torch.load(checkpoint_name_loss)
        best_loss = checkpoint['loss']
        best_acc = checkpoint['acc']
        start_epoch = checkpoint['epoch']
        history = checkpoint['history']
        net.load_state_dict(checkpoint['net'])
    else:
        best_loss = float('inf')
        best_acc = 0.0
        start_epoch = 0
        history = {
            'train_loss': [],
            'train_acc': [],
            'test_loss': [],
            'test_acc': []
        }
    end_epoch = start_epoch + args.num_epoch

    # Wrap for multi-GPU; the checkpoint code below unwraps before saving.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        net = nn.DataParallel(net)
    net.to(device)

    # data
    img_size = args.img_size
    ## train: geometric augmentation + normalization to roughly [-1, 1]
    train_aug = Compose([
        Resize(size=(img_size, img_size)),
        RandomHorizontallyFlip(),
        RandomVerticallyFlip(),
        RandomRotate(90),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    ## test: no augmentation, same resize/normalization
    # test_aug = train_aug
    test_aug = Compose([
        Resize(size=(img_size, img_size)),
        ToTensor(),
        Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # NOTE(review): both datasets read from train_root; only the CSV
    # differs — verify the test images really live under train_root.
    train_dataset = breast_classify_inbreast(root=train_root,
                                             csv_file=train_csv,
                                             transform=train_aug)
    test_dataset = breast_classify_inbreast(root=train_root,
                                            csv_file=test_csv,
                                            transform=test_aug)

    # NOTE(review): despite the flag's name, no WeightedRandomSampler is
    # used — both branches build identical shuffled loaders; the weights
    # only re-balance the NLL loss below.
    if args.weighted_sampling == 1:
        weights = torch.FloatTensor([1.0, 1.0, 1.5, 5.0]).to(device)
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)
    else:
        weights = None
        train_loader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=4,
                                  shuffle=True)

    # train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
    #                           num_workers=4, shuffle=True)
    # shuffle on the test loader is harmless for the epoch-level metrics.
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             num_workers=4,
                             shuffle=True)

    # loss function, optimizer and scheduler
    # NOTE(review): size_average is deprecated in modern PyTorch; the
    # equivalent spelling is reduction='mean'.
    criterion = nn.NLLLoss(size_average=True, weight=weights).to(device)

    optimizer = Adam(net.parameters(), lr=args.lr, amsgrad=True)

    ## scheduler
    # NOTE(review): any lr_policy other than 'StepLR'/'PolyLR'/'None'
    # leaves `scheduler` undefined and raises NameError at scheduler.step().
    if args.lr_policy == 'StepLR':
        scheduler = StepLR(optimizer, step_size=30, gamma=0.5)
    if args.lr_policy == 'PolyLR':
        scheduler = PolyLR(optimizer, max_epoch=end_epoch, power=0.9)

    # training process
    logging.info('Start Training For Breast Density Classification')
    for epoch in range(start_epoch, end_epoch):
        ts = time.time()  # wall-clock start of this epoch
        # NOTE(review): since PyTorch 1.1, scheduler.step() is expected
        # AFTER the epoch's optimizer steps; stepping here at the top of
        # the epoch skips the first scheduled LR value.
        if args.lr_policy != 'None':
            scheduler.step()

        # train
        net.train()
        train_loss = 0.
        train_acc = 0.  # running count of correct predictions

        for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),
                                                 total=int(len(train_loader))):
            inputs = inputs.to(device)
            targets = targets.to(device)
            targets = targets.long()
            optimizer.zero_grad()
            outputs = net(inputs)
            # NLLLoss expects log-probabilities, hence log_softmax here.
            loss = criterion(F.log_softmax(outputs, dim=1), targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            # Number of correct predictions in this batch.
            accuracy = float(sum(outputs.argmax(dim=1) == targets))
            train_acc += accuracy

        train_acc_epoch = train_acc / (len(train_loader.dataset))

        # Mean of per-batch mean losses (slightly biased if the last batch
        # is smaller, but consistent across epochs).
        train_loss_epoch = train_loss / (batch_idx + 1)
        history['train_loss'].append(train_loss_epoch)
        history['train_acc'].append(train_acc_epoch)

        # test
        net.eval()
        test_loss = 0.
        test_acc = 0.

        for batch_idx, (inputs, targets) in tqdm(
                enumerate(test_loader),
                total=int(len(test_loader.dataset) / args.batch_size) + 1):
            with torch.no_grad():
                inputs = inputs.to(device)
                targets = targets.to(device)
                targets = targets.long()
                outputs = net(inputs)
                loss = criterion(F.log_softmax(outputs, dim=1), targets)
                accuracy = float(sum(outputs.argmax(dim=1) == targets))

            test_acc += accuracy
            test_loss += loss.item()

        test_loss_epoch = test_loss / (batch_idx + 1)
        test_acc_epoch = test_acc / (len(test_loader.dataset))
        history['test_loss'].append(test_loss_epoch)
        history['test_acc'].append(test_acc_epoch)

        time_cost = time.time() - ts
        logging.info(
            'epoch[%d/%d]: train_loss: %.3f | train_acc: %.3f | test_loss: %.3f | test_acc: %.3f || time: %.1f'
            % (epoch + 1, end_epoch, train_loss_epoch, train_acc_epoch,
               test_loss_epoch, test_acc_epoch, time_cost))

        # save checkpoint when the test loss improves
        if test_loss_epoch < best_loss:
            logging.info('Loss checkpoint Saving...')

            # Unwrap DataParallel so the saved state_dict carries no
            # 'module.' prefix.
            # NOTE(review): list(net.children())[0] relies on the wrapped
            # model being the first child; `net.module` is the standard
            # (and safer) accessor.
            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'acc': test_acc_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            torch.save(state, checkpoint_name_loss)
            best_loss = test_loss_epoch

        # save checkpoint when the test accuracy improves
        if test_acc_epoch > best_acc:
            logging.info('Acc checkpoint Saving...')

            save_model = net
            if torch.cuda.device_count() > 1:
                save_model = list(net.children())[0]
            state = {
                'net': save_model.state_dict(),
                'loss': test_loss_epoch,
                'acc': test_acc_epoch,
                'epoch': epoch + 1,
                'history': history
            }
            torch.save(state, checkpoint_name_acc)
            best_acc = test_acc_epoch