Ejemplo n.º 1
0
 def forward(self, x, y):
     """Run two sub-models and fuse their outputs.

     ``x`` is passed through ``self.modelA`` and ``y`` through ``self.modelB``;
     each result is unwrapped with ``extractDict`` and the two are combined
     according to ``self.fusion_function``:

     * ``'conv'`` -- concatenate along dim 1 and apply ``self.fusion``;
     * ``'max'``  -- element-wise maximum (``torch.max``);
     * ``'sum'``  -- element-wise sum (``torch.add``).

     Raises:
         ValueError: if ``self.fusion_function`` is not one of the values above
             (previously this surfaced as an UnboundLocalError on ``out``).
     """
     x = self.modelA(x)
     y = self.modelB(y)
     x = extractDict(x)
     y = extractDict(y)
     if self.fusion_function == 'conv':
         # dim=1 is presumably the channel axis (NCHW) -- TODO confirm.
         return self.fusion(torch.cat((x, y), dim=1))
     if self.fusion_function == 'max':
         return torch.max(x, y)
     if self.fusion_function == 'sum':
         return torch.add(x, y)
     raise ValueError(
         "Unsupported fusion_function {!r}; expected 'conv', 'max' or 'sum'".format(
             self.fusion_function))
Ejemplo n.º 2
0
    def forward(self, x, y):
        """Segment the input pair and upsample the logits.

        The backbone consumes ``(x, y)``; its output is optionally unwrapped with
        ``extractDict`` (keeping ``features['fine_grained']`` for PointRend) and
        fed to ``self.classifier``.

        * Without PointRend: the logits are bilinearly resized to
          ``input_shape x input_shape`` and returned as a tensor.
        * With PointRend, eval mode: returns ``rend_upsample(512, ...)['out']``.
        * With PointRend, training mode: returns the ``rend_upsample`` result
          dict, with ``'out'`` replaced by the bilinearly resized logits.

        Raises:
            ValueError: if PointRend upsampling is enabled but
                ``if_extract_dict`` is False, because the ``'fine_grained'``
                features are then never extracted (previously a NameError).
        """
        # NOTE: only the last spatial dimension is read and passed as a scalar
        # ``size``, so the output is square; assumes square inputs.
        input_shape = x.shape[-1]
        # contract: features is a dict of tensors
        features = self.backbone(x, y)
        fine_grained = None
        if self.if_extract_dict:
            fine_grained = features['fine_grained']
            features = extractDict(features)
        if isinstance(features, tuple):
            logits = self.classifier(*features)
        else:
            logits = self.classifier(features)

        if not self.if_point_rend_upsample:
            return F.interpolate(logits,
                                 size=input_shape,
                                 mode='bilinear',
                                 align_corners=False)

        if not self.if_extract_dict:
            raise ValueError(
                "if_point_rend_upsample=True requires if_extract_dict=True so that "
                "the backbone's 'fine_grained' features are available")

        if not self.training:
            # Inference: PointRend upsamples straight to a fixed 512 resolution.
            return self.rend_upsample(512, fine_grained, logits)['out']

        result = self.rend_upsample(input_shape, fine_grained, logits)
        # When training, seg_loss is computed from the interpolated coarse
        # logits stored in result['out']; the point loss uses the other entries
        # of result (e.g. 'point'/'rend').
        result['out'] = F.interpolate(logits,
                                      size=input_shape,
                                      mode='bilinear',
                                      align_corners=False)
        return result
Ejemplo n.º 3
0
def train(epoch):
    """Run one mixup training epoch and record its mean loss.

    Relies on module-level objects defined elsewhere (``model``, ``optimizer``,
    ``criterion``, ``args``, ``DEVICE``, ``all_quality`` and the helpers
    ``get_dataloader``, ``mixup_data``, ``mixup_criterion_type``,
    ``extractDict`` and ``log_mean``).

    Args:
        epoch: Index of the current epoch, used for logging.

    Returns:
        The mean per-batch training loss as returned by ``log_mean``; the same
        value is appended to ``all_quality['train_loss']``.
    """
    loader = get_dataloader('train')

    logging.info("Epoch " + str(epoch))
    batch_losses = []
    model.train()

    for batch in loader:
        ids, images, masks, us_data, labels4, _labels2 = batch
        # Only the image (input) and mask (target) drive the loss; the
        # auxiliary fields are carried through mixup but not used afterwards.
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)
        if args.criterion.strip() == 'BCELoss':
            masks = masks.float()
        us_data = us_data.to(DEVICE)
        aux_targets = us_data[:, :args.length_aux]
        labels4 = labels4.to(DEVICE)

        (images, masks_a, masks_b,
         _aux_a, _aux_b, _labels4_a, _labels4_b, lam) = mixup_data(
            images, masks, aux_targets, labels4,
            device=DEVICE, alpha=args.alpha)

        logits = extractDict(model(images))

        loss = mixup_criterion_type(criterion, logits, masks_a, masks_b, lam)
        batch_losses.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_mean = log_mean(epoch, batch_losses, "train loss", isLog=True)
    all_quality['train_loss'].append(epoch_mean)
    return epoch_mean
Ejemplo n.º 4
0
def testLoss(Loss, epochs=100, lr=1e-3):
    """Smoke-test a segmentation loss by fitting FCN-ResNet50 on random data.

    Builds an untrained ``fcn_resnet50`` (2 classes), a random input batch of
    shape (8, 3, 224, 224) and a random binary label of shape (8, 1, 224, 224),
    then runs ``epochs`` Adam steps minimising ``Loss(output, label)``, printing
    the output shape and the loss at every step.

    Args:
        Loss: Callable ``Loss(output, label)`` returning a scalar tensor.
        epochs: Number of optimisation steps (default 100, as before).
        lr: Adam learning rate (default 1e-3, as before).
    """
    model = torchvision.models.segmentation.fcn_resnet50(pretrained=False,
                                                         progress=False,
                                                         num_classes=2,
                                                         aux_loss=None)
    inputs = torch.rand((8, 3, 224, 224))
    # Random binary mask: 1 where the uniform sample is >= 0.5, else 0.
    label = (torch.rand((8, 1, 224, 224)) >= 0.5).float()
    # Create the optimizer once so Adam's running moment estimates persist
    # across steps; re-creating it every iteration reset that state.
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    for _ in range(epochs):
        output = extractDict(model(inputs))
        print(output.shape)
        loss = Loss(output, label)
        print(loss)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Ejemplo n.º 5
0
def train(epoch):
    """Run one mixup training epoch on the fluid-segmentation dataset.

    Relies on module-level objects defined elsewhere (``model``, ``optimizer``,
    ``criterion``, ``args``, ``DEVICE``, ``BATCH_SIZE``, ``NUM_CLASSES``,
    ``all_quality``, the dataset paths and helper functions).

    Args:
        epoch: Index of the current epoch, used for logging.

    Returns:
        The mean of the per-batch segmentation (mixup) losses as returned by
        ``log_mean``; the same value is appended to
        ``all_quality['train_loss']``.

    Raises:
        ValueError: if ``args.net_input_num`` is neither 1 nor 2 (previously a
            NameError on ``out``).
    """
    # Build the training dataset and loader for this epoch.
    train_dataset = MTLDataset.FluidSegDataset(
        str(data_root) + 'TRAIN/', args.seg_root, args.fluid_root,
        binary_fluid=args.binary_fluid, us_path=us_path,
        num_classes=NUM_CLASSES, train_or_test='Train',
        screener=rf_sort_list, screen_num=10, seg_channel=NUM_CLASSES)
    train_dataloader = data.DataLoader(
        train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    logging.info("Epoch " + str(epoch))
    train_loss_list = []
    model.train()
    for ID, img, fluid_img, seg_label, US_data, label4, label2 in train_dataloader:
        # The model input is img (plus fluid_img); the target is seg_label.
        # The remaining fields are only carried through mixup here.
        img = img.to(DEVICE)
        seg_label = seg_label.to(DEVICE)
        if args.n_class == 2:
            # One channel per class: channel 0 = seg_label, channel 1 = 1 - seg_label.
            seg_label = torch.cat((seg_label, 1 - seg_label), dim=1)
        fluid_img = fluid_img.to(DEVICE)
        if args.criterion.strip() == 'BCELoss':
            seg_label = seg_label.float()
        US_data = US_data.to(DEVICE)
        US_label = US_data[:, :args.length_aux]
        label4 = label4.to(DEVICE)

        # Mixup of inputs and targets with mixing coefficient lam.
        (img, fluid_img, seg_label_a, seg_label_b,
         US_label_a, US_label_b, label4_a, label4_b, lam) = mixup_data2(
            img, fluid_img, seg_label, US_label, label4,
            device=DEVICE, alpha=args.alpha)

        # Forward pass: either a single concatenated input or two inputs.
        if args.net_input_num == 1:
            out = model(torch.cat((img, fluid_img), dim=1))
        elif args.net_input_num == 2:
            out = model(img, fluid_img)
        else:
            raise ValueError(
                "args.net_input_num must be 1 or 2, got {!r}".format(args.net_input_num))

        if args.ifPointRend:
            rend = out["rend"]
            points = out["points"]
        out = extractDict(out, True)

        train_loss = mixup_criterion_type(criterion, out, seg_label_a, seg_label_b, lam)
        # NOTE: only the segmentation (mixup) loss is logged; the PointRend
        # point loss added below is optimised but not included in the log.
        train_loss_list.append(train_loss.item())
        if args.ifPointRend:
            # NOTE(review): point targets are sampled from the un-mixed
            # seg_label although img was mixed above -- confirm intended.
            gt_points = point_sample(
                seg_label.float(),
                points,
                mode="nearest",
                align_corners=False
            ).long()
            # gt_points[:, 1] is used as the class index (for n_class == 2
            # this is the 1 - seg_label channel built above).
            point_loss = F.cross_entropy(rend, gt_points[:, 1])
            train_loss = point_loss + train_loss

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

    train_loss_mean = log_mean(epoch, train_loss_list, "train loss", isLog=True)
    all_quality['train_loss'].append(train_loss_mean)
    return train_loss_mean
Ejemplo n.º 6
0
def test(epoch):
    """Evaluate the model on the fluid-segmentation test split for one epoch.

    For each batch:
    1. Predictions are turned into probabilities: sigmoid for 1 channel,
       softmax over the channel axis otherwise.
    2. Probabilities are resized to 512 if the last dimension differs.
    3. They are binarised at 0.5; the result is inverted for the 'GDL'/'GWDL'
       criteria.
    4. Loss, dice and the ``metrics.get_sum_metrics`` quality values are
       accumulated.

    Side effects: per-epoch means are appended to ``all_quality``, per-image
    dices to ``all_img_dice``, and the first batch's predicted masks are saved
    as JPEGs under ``../Log/<args.logdir>/V<j>/``.

    Args:
        epoch: Index of the current epoch, used for logging and file names.

    Raises:
        ValueError: if ``args.net_input_num`` is neither 1 nor 2.
    """
    test_loss_list = []
    dice2_list = []
    # Build the test dataset and loader.
    test_dataset = MTLDataset.FluidSegDataset(
        str(data_root) + 'TEST/', args.seg_root, args.fluid_root,
        binary_fluid=args.binary_fluid, us_path=us_path,
        num_classes=NUM_CLASSES, train_or_test='Test',
        screener=rf_sort_list,
        screen_num=10, seg_channel=NUM_CLASSES)
    test_dataloader = data.DataLoader(
        test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    model.eval()
    epoch_quality = {}
    with torch.no_grad():
        for i, (ID, img, fluid_img, seg_label, US_data, label4, label2) in enumerate(test_dataloader):
            # The last batch may be smaller than BATCH_SIZE.
            actual_batch_size = len(ID)
            img = img.to(DEVICE)
            seg_label = seg_label.to(DEVICE)
            if args.n_class == 2:
                # One channel per class: channel 0 = seg_label, channel 1 = 1 - seg_label.
                seg_label = torch.cat((seg_label, 1 - seg_label), dim=1)
            fluid_img = fluid_img.to(DEVICE)
            if args.criterion.strip() == 'BCELoss':
                seg_label = seg_label.float()

            # Forward pass: either a single concatenated input or two inputs.
            if args.net_input_num == 1:
                output = model(torch.cat((img, fluid_img), dim=1))
            elif args.net_input_num == 2:
                output = model(img, fluid_img)
            else:
                raise ValueError(
                    "args.net_input_num must be 1 or 2, got {!r}".format(args.net_input_num))
            output = extractDict(output, True)
            if output.shape[1] == 1:
                output = torch.sigmoid(output)
            else:
                # Normalise over the class channel; dim=0 (the previous value)
                # normalised across the samples of the batch instead.
                output = torch.softmax(output, dim=1)

            # Resize the probabilities *before* thresholding so the final mask
            # stays strictly binary (bilinear resizing of a 0/1 mask yields
            # fractional values along edges).
            if output.shape[3] != 512:
                output = F.interpolate(output, size=512, mode='bilinear', align_corners=True)

            # Binarise at 0.5. For GDL/GWDL the mask is inverted. The previous
            # two-step in-place assignment turned every pixel into 1 in that
            # case, because pixels just set to 0 were then caught by "< 0.5".
            if args.criterion in ['GDL', 'GWDL']:
                output = (output < 0.5).to(output.dtype)
            else:
                output = (output >= 0.5).to(output.dtype)

            # NOTE: the loss is computed on the binarised prediction, so it is
            # not the same quantity as the training loss.
            loss = criterion(output, seg_label)
            test_loss_list.append(loss.item())

            # Metrics are computed on the first channel only.
            seg_test = seg_label[:, 0:1].long()
            output = output[:, 0:1]
            dice2 = metrics.dice_index(output, seg_test)
            dice2_list.append(dice2)
            quality, dices = metrics.get_sum_metrics(output, seg_test, count_metrics, printDice=True)
            epoch_quality = dict_sum(epoch_quality, quality)

            # Save the predicted masks of the first batch for visual inspection.
            if i == 0:
                output = output.cpu()
                for j in range(actual_batch_size):
                    save_img = transforms.ToPILImage()(output[j][0]).convert('L')
                    path = '../Log/' + args.logdir + '/V' + str(j) + '/'
                    os.makedirs(path, exist_ok=True)
                    save_img.save(path + 'E' + str(epoch) + '_' + ID[j] + '.jpg')

            # Record the dice of every image.
            for j in range(actual_batch_size):
                all_img_dice[ID[j]].append(dices[j])

    test_loss_mean = log_mean(epoch, test_loss_list, "test loss", isLog=True)
    dice2_mean = log_mean(epoch, dice2_list, "dice2", isLog=True)
    for k in epoch_quality:
        epoch_quality[k] /= len(test_dataset)
    logging.info("Epoch {0} MEAN {1}".format(epoch, epoch_quality))

    for k in epoch_quality:
        all_quality[k].append(epoch_quality[k])
    all_quality['test_loss'].append(test_loss_mean)
    all_quality['dice2'].append(dice2_mean)
Ejemplo n.º 7
0
def test(epoch):
    """Evaluate the model on the test split for one epoch and record metrics.

    For each batch:
    1. Predictions are passed through a sigmoid.
    2. They are resized to 512 if the last dimension differs.
    3. They are binarised at 0.5; the result is inverted for the 'GDL'/'GWDL'
       criteria.
    4. Loss, dice and the ``metrics.get_sum_metrics`` quality values are
       accumulated.

    Side effects: per-epoch means are appended to ``all_quality``, and the
    first batch's predicted masks are saved as JPEGs under
    ``../Log/<args.logdir>/V<j>/``.

    Args:
        epoch: Index of the current epoch, used for logging and file names.
    """
    test_loss_list = []
    dice2_list = []
    test_dataloader = get_dataloader('test')
    model.eval()
    epoch_quality = {}
    with torch.no_grad():
        for i, (ID, img, seg_label, US_data, label4, label2) in enumerate(test_dataloader):
            # Only the image (input) and seg_label (target) are used here.
            img = img.to(DEVICE)
            seg_label = seg_label.long().to(DEVICE)
            if args.criterion.strip() == 'BCELoss':
                seg_label = seg_label.float()

            output = extractDict(model(img))
            output = torch.sigmoid(output)

            # Resize the probabilities *before* thresholding so the final mask
            # stays strictly binary (bilinear resizing of a 0/1 mask yields
            # fractional values along edges).
            if output.shape[3] != 512:
                output = F.interpolate(output, size=512, mode='bilinear', align_corners=True)

            # Binarise at 0.5. For GDL/GWDL the mask is inverted. The previous
            # two-step in-place assignment turned every pixel into 1 in that
            # case, because pixels just set to 0 were then caught by "< 0.5".
            if args.criterion in ['GDL', 'GWDL']:
                output = (output < 0.5).to(output.dtype)
            else:
                output = (output >= 0.5).to(output.dtype)
            seg_test = seg_label.long()

            # Record the loss and compute the performance metrics.
            loss = criterion(output, seg_label)
            test_loss_list.append(loss.item())
            dice2 = metrics.dice_index(output, seg_test)
            dice2_list.append(dice2)
            quality = metrics.get_sum_metrics(output, seg_test, count_metrics)
            epoch_quality = dict_sum(epoch_quality, quality)

            # Save the predicted masks of the first batch for visual inspection;
            # iterate over the actual batch length, which may be < BATCH_SIZE.
            if i == 0:
                output = output.cpu()
                for j in range(len(ID)):
                    save_img = transforms.ToPILImage()(output[j][0]).convert('L')
                    path = '../Log/' + args.logdir + '/V' + str(j) + '/'
                    os.makedirs(path, exist_ok=True)
                    save_img.save(path + 'E' + str(epoch) + '_' + ID[j] + '.jpg')

    test_loss_mean = log_mean(epoch, test_loss_list, "test loss", isLog=True)
    dice2_mean = log_mean(epoch, dice2_list, "dice2", isLog=True)
    for k in epoch_quality:
        epoch_quality[k] /= len(test_dataloader.dataset)
    logging.info("Epoch {0} MEAN {1}".format(epoch, epoch_quality))

    for k in epoch_quality:
        all_quality[k].append(epoch_quality[k])
    all_quality['test_loss'].append(test_loss_mean)
    all_quality['dice2'].append(dice2_mean)