# 示例#1 (Example #1)
# 0
def Train():
    """Train the ECE classifier (ResNet with Bottleneck blocks [3, 4, 6, 3]).

    Each epoch runs one optimisation pass over the training loader and one
    evaluation pass (under ``torch.no_grad``) over the validation loader.

    Per epoch, the function:

    * logs parameter histograms, mean losses and AUCs to TensorBoard;
    * steps ``ReduceLROnPlateau`` on the mean validation loss;
    * feeds the same value to ``EarlyStopping``, with ``model_folder`` as
      its save path.

    Training ends after ``total_epochs`` epochs or when early stopping fires.
    The model input is the five tensors ``inputs[0..4]`` (t2, dwi, adc, roi,
    prostate) concatenated along dim 1. The target is ``outputs`` squeezed on
    dim 1.

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    total_epochs = 1000
    log_every = 10  # iterations between running-loss prints

    train_loader, validation_loader = LoadTVData()
    model = ResNet(Bottleneck, [3, 4, 6, 3]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cla_criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=20, factor=0.5, verbose=True)
    early_stopping = EarlyStopping(patience=100, verbose=True)
    writer = SummaryWriter(log_dir=graph_path, comment='Net')

    try:
        for epoch in range(total_epochs):
            train_loss_list, valid_loss_list = [], []
            # Running sums for the periodic prints. Reset every epoch so the first
            # print of an epoch does not include leftovers from the previous epoch.
            train_loss = 0.0
            valid_loss = 0.0

            # ---------------- training ----------------
            class_list, class_pred_list = [], []
            model.train()
            for i, (inputs, outputs) in enumerate(train_loader):
                t2, dwi, adc, roi, prostate = inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]
                inputs = torch.cat([t2, dwi, adc, roi, prostate], dim=1).float().to(device)
                ece = torch.squeeze(outputs, dim=1).float().to(device)

                optimizer.zero_grad()
                class_out, _ = model(inputs)
                class_out = torch.squeeze(class_out, dim=1)
                # BCEWithLogitsLoss takes raw logits; the sigmoid is only used for AUC/printing.
                class_out_sigmoid = class_out.sigmoid()

                loss = cla_criterion(class_out, ece)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                train_loss += loss_value
                train_loss_list.append(loss_value)

                # collect labels / probabilities for the epoch AUC
                class_list.extend(list(ece.cpu().numpy()))
                class_pred_list.extend(list(class_out_sigmoid.detach().cpu().numpy()))

                if (i + 1) % log_every == 0:
                    print('Epoch [%d / %d], Iter [%d], Train Loss: %.4f' % (epoch + 1, total_epochs, i + 1, train_loss / log_every))
                    print(list(class_out_sigmoid.detach().cpu().numpy()))
                    train_loss = 0.0

            _, _, train_auc = get_auc(class_pred_list, class_list)

            # ---------------- validation ----------------
            class_list, class_pred_list = [], []
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(validation_loader):
                    t2, dwi, adc, roi, prostate = inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]
                    inputs = torch.cat([t2, dwi, adc, roi, prostate], dim=1).float().to(device)
                    ece = torch.squeeze(outputs, dim=1).float().to(device)

                    class_out, _ = model(inputs)
                    class_out = torch.squeeze(class_out, dim=1)
                    class_out_sigmoid = class_out.sigmoid()

                    loss_value = cla_criterion(class_out, ece).item()
                    valid_loss += loss_value
                    valid_loss_list.append(loss_value)

                    class_list.extend(list(ece.cpu().numpy()))
                    class_pred_list.extend(list(class_out_sigmoid.cpu().numpy()))

                    if (i + 1) % log_every == 0:
                        print('Epoch [%d / %d], Iter [%d],  Valid Loss: %.4f' % (epoch + 1, total_epochs, i + 1, valid_loss / log_every))
                        print(list(class_out_sigmoid.cpu().numpy()))
                        valid_loss = 0.0

            if not valid_loss_list:
                raise RuntimeError('validation loader yielded no batches')
            _, _, valid_auc = get_auc(class_pred_list, class_list)

            # ---------------- logging / scheduling ----------------
            for name, param in model.named_parameters():
                if 'bn' not in name:
                    writer.add_histogram(name + '_data', param.detach().cpu().numpy(), epoch + 1)

            mean_train_loss = float(np.mean(train_loss_list))
            mean_valid_loss = float(np.mean(valid_loss_list))
            writer.add_scalars('Train_Val_Loss',
                               {'train_loss': mean_train_loss, 'val_loss': mean_valid_loss}, epoch + 1)
            writer.add_scalars('Train_Val_auc',
                               {'train_auc': train_auc, 'val_auc': valid_auc}, epoch + 1)
            # Flush rather than close: closing every epoch tears the writer down and
            # a new event file is opened on the next write.
            writer.flush()

            print('Epoch:', epoch + 1, 'Training Loss:', mean_train_loss, 'Valid Loss:',
                  mean_valid_loss, 'Train auc:', train_auc, 'Valid auc:', valid_auc)

            scheduler.step(mean_valid_loss)
            early_stopping(mean_valid_loss, model, save_path=model_folder, evaluation=min)

            if early_stopping.early_stop:
                print("Early stopping")
                break
    finally:
        writer.close()
# 示例#2 (Example #2)
# 0
def Train():
    """Train the ECE classifier (ResNet, Bottleneck [3, 4, 6, 3]) from augmented loaders.

    The loaders are built with ``_GetLoader`` from the subject lists
    ``sub_train`` / ``sub_val`` and the augmentation ``param_config`` below.
    ``sub_train``, ``sub_val``, ``input_shape`` and ``batch_size`` are empty
    placeholders here and must be filled in before use.

    Weights are initialised with ``HeWeightInit``. Each epoch does one
    optimisation pass over training data and one evaluation pass (under
    ``torch.no_grad``) over validation data, then:

    * logs parameter histograms and mean losses to TensorBoard;
    * steps ``ReduceLROnPlateau`` on the mean validation loss;
    * feeds the same value to ``EarlyStopping``, with ``model_folder`` as
      its save path.

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    total_epochs = 1000
    log_every = 10  # iterations between running-loss prints

    sub_train = []
    sub_val = []
    param_config = {
        RotateTransform.name: {'theta': ['uniform', -10, 10]},
        ShiftTransform.name: {'horizontal_shift': ['uniform', -0.05, 0.05],
                              'vertical_shift': ['uniform', -0.05, 0.05]},
        ZoomTransform.name: {'horizontal_zoom': ['uniform', 0.95, 1.05],
                             'vertical_zoom': ['uniform', 0.95, 1.05]},
        FlipTransform.name: {'horizontal_flip': ['choice', True, False]},
        BiasTransform.name: {'center': ['uniform', -1., 1., 2],
                             'drop_ratio': ['uniform', 0., 1.]},
        NoiseTransform.name: {'noise_sigma': ['uniform', 0., 0.03]},
        ContrastTransform.name: {'factor': ['uniform', 0.8, 1.2]},
        GammaTransform.name: {'gamma': ['uniform', 0.8, 1.2]},
        ElasticTransform.name: ['elastic', 1, 0.1, 256]
    }
    input_shape = []
    batch_size = []

    train_loader, train_batches = _GetLoader(sub_train, param_config, input_shape, batch_size, True)
    # NOTE(review): the validation loader gets the same augmentation config and the same
    # trailing flag as the training loader. Confirm whether validation data is meant
    # to be augmented.
    val_loader, val_batches = _GetLoader(sub_val, param_config, input_shape, batch_size, True)

    # Debugging aid: anomaly detection makes autograd noticeably slower.
    torch.autograd.set_detect_anomaly(True)
    model = ResNet(Bottleneck, [3, 4, 6, 3]).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    cla_criterion = nn.BCEWithLogitsLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=20, factor=0.5, verbose=True)
    early_stopping = EarlyStopping(patience=100, verbose=True)
    writer = SummaryWriter(log_dir=graph_path, comment='Net')
    # In-place re-initialisation; the optimizer still holds the same parameter objects.
    model.apply(HeWeightInit)

    try:
        for epoch in range(total_epochs):
            train_loss_list, valid_loss_list = [], []
            # Running sums for the periodic prints, reset per epoch so no
            # leftovers from the previous epoch leak into the first print.
            train_loss = 0.0
            valid_loss = 0.0

            # ---------------- training ----------------
            # TODO: labels/probabilities are collected but no AUC is computed from them yet.
            class_list, class_pred_list = [], []
            model.train()
            for i, (inputs, outputs) in enumerate(train_loader):
                t2, dwi, adc, roi, prostate = inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]
                inputs = torch.cat([t2, dwi, adc, roi, prostate], dim=1).float().to(device)
                ece = torch.squeeze(outputs, dim=1).float().to(device)

                optimizer.zero_grad()
                class_out, _ = model(inputs)
                class_out = torch.squeeze(class_out, dim=1)
                class_out_sigmoid = class_out.sigmoid()

                loss = cla_criterion(class_out, ece)
                loss.backward()
                optimizer.step()

                loss_value = loss.item()
                train_loss += loss_value
                train_loss_list.append(loss_value)

                class_list.extend(list(ece.cpu().numpy()))
                class_pred_list.extend(list(class_out_sigmoid.detach().cpu().numpy()))

                if (i + 1) % log_every == 0:
                    print('Epoch [%d / %d], Iter [%d], Train Loss: %.4f' % (epoch + 1, total_epochs, i + 1, train_loss / log_every))
                    print(list(class_out_sigmoid.detach().cpu().numpy()))
                    train_loss = 0.0

            # ---------------- validation ----------------
            # fresh lists so validation results are not mixed with training ones
            class_list, class_pred_list = [], []
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(val_loader):
                    t2, dwi, adc, roi, prostate = inputs[0], inputs[1], inputs[2], inputs[3], inputs[4]
                    inputs = torch.cat([t2, dwi, adc, roi, prostate], dim=1).float().to(device)
                    ece = torch.squeeze(outputs, dim=1).float().to(device)

                    class_out, _ = model(inputs)
                    class_out = torch.squeeze(class_out, dim=1)
                    class_out_sigmoid = class_out.sigmoid()

                    loss_value = cla_criterion(class_out, ece).item()
                    valid_loss += loss_value
                    valid_loss_list.append(loss_value)

                    class_list.extend(list(ece.cpu().numpy()))
                    class_pred_list.extend(list(class_out_sigmoid.cpu().numpy()))

                    if (i + 1) % log_every == 0:
                        print('Epoch [%d / %d], Iter [%d],  Valid Loss: %.4f' % (epoch + 1, total_epochs, i + 1, valid_loss / log_every))
                        print(list(class_out_sigmoid.cpu().numpy()))
                        valid_loss = 0.0

            if not valid_loss_list:
                raise RuntimeError('validation loader yielded no batches')

            # ---------------- logging / scheduling ----------------
            for name, param in model.named_parameters():
                if 'bn' not in name:
                    writer.add_histogram(name + '_data', param.detach().cpu().numpy(), epoch + 1)

            mean_train_loss = float(np.mean(train_loss_list))
            mean_valid_loss = float(np.mean(valid_loss_list))
            writer.add_scalars('Train_Val_Loss',
                               {'train_loss': mean_train_loss, 'val_loss': mean_valid_loss}, epoch + 1)
            # Flush rather than close, so the same event file is reused across epochs.
            writer.flush()

            # The original printed the 'Valid Loss:' label without its value.
            print('Epoch:', epoch + 1, 'Training Loss:', mean_train_loss, 'Valid Loss:', mean_valid_loss)

            scheduler.step(mean_valid_loss)
            early_stopping(mean_valid_loss, model, save_path=model_folder, evaluation=min)

            if early_stopping.early_stop:
                print("Early stopping")
                break
    finally:
        writer.close()
# 示例#3 (Example #3)
# 0
def Train():
    """Pre-train the two-head segmentation model (cancer ROI + prostate).

    ``MultiSegModel`` takes t2/dwi/adc concatenated along dim 1 and returns two
    maps (``roi_out``, ``prostate_out``). Each map is scored with ``nn.BCELoss``
    against ``outputs[0]`` / ``outputs[1]``, and the total loss is their
    unweighted sum.

    Each epoch trains on the 'PreTrain' split and evaluates on 'PreValid'
    (under ``torch.no_grad``). It then:

    * logs parameter histograms and the three mean losses to TensorBoard;
    * steps ``ReduceLROnPlateau`` on the mean total validation loss;
    * feeds that value to ``EarlyStopping``, with ``model_folder`` as its
      save path.

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    total_epochs = 1000
    log_every = 10  # iterations between running-loss prints

    ClearGraphPath()
    train_loader, validation_loader = LoadTVData(is_test=False, folder=data_folder, setname=['PreTrain', 'PreValid'])
    model = MultiSegModel(in_channels=3, out_channels=1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # NOTE(review): BCELoss expects probabilities in [0, 1]; presumably the model ends in
    # a sigmoid. Confirm this, or switch to BCEWithLogitsLoss.
    seg_criterion = nn.BCELoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=20, factor=0.5, verbose=True)
    early_stopping = EarlyStopping(patience=100, verbose=True)
    writer = SummaryWriter(log_dir=graph_path, comment='Net')

    try:
        for epoch in range(total_epochs):
            train_loss1_list, valid_loss1_list = [], []
            train_loss2_list, valid_loss2_list = [], []
            train_loss_list, valid_loss_list = [], []
            # running sums for the periodic prints (reset every log_every iterations)
            train_loss1 = 0.0
            train_loss2 = 0.0
            train_loss = 0.0
            valid_loss1 = 0.0
            valid_loss2 = 0.0
            valid_loss = 0.0

            # ---------------- training ----------------
            model.train()
            for i, (inputs, outputs) in enumerate(train_loader):
                t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                # BCELoss needs float targets matching the (float32) predictions.
                roi, prostate = outputs[0].float().to(device), outputs[1].float().to(device)

                inputs = torch.cat([t2, dwi, adc], dim=1).float().to(device)

                roi_out, prostate_out = model(inputs)

                loss1 = seg_criterion(roi_out, roi)
                loss2 = seg_criterion(prostate_out, prostate)
                loss = loss1 + loss2

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                loss1_value, loss2_value, loss_value = loss1.item(), loss2.item(), loss.item()
                train_loss1 += loss1_value
                train_loss1_list.append(loss1_value)
                train_loss2 += loss2_value
                train_loss2_list.append(loss2_value)
                train_loss += loss_value
                train_loss_list.append(loss_value)

                if (i + 1) % log_every == 0:
                    print('Epoch [%d / %d], Iter [%d], Cancer train Loss: %.4f, Prostate train Loss: %.4f, Loss: %.4f' %
                          (epoch + 1, total_epochs, i + 1, train_loss1 / log_every, train_loss2 / log_every, train_loss / log_every))
                    train_loss = 0.0
                    train_loss1 = 0.0
                    train_loss2 = 0.0

            # ---------------- validation ----------------
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(validation_loader):
                    t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                    roi, prostate = outputs[0].float().to(device), outputs[1].float().to(device)

                    inputs = torch.cat([t2, dwi, adc], dim=1).float().to(device)

                    roi_out, prostate_out = model(inputs)

                    loss1 = seg_criterion(roi_out, roi)
                    loss2 = seg_criterion(prostate_out, prostate)
                    loss = loss1 + loss2

                    loss1_value, loss2_value, loss_value = loss1.item(), loss2.item(), loss.item()
                    valid_loss1 += loss1_value
                    valid_loss1_list.append(loss1_value)
                    valid_loss2 += loss2_value
                    valid_loss2_list.append(loss2_value)
                    valid_loss += loss_value
                    valid_loss_list.append(loss_value)

                    if (i + 1) % log_every == 0:
                        print('Epoch [%d / %d], Iter [%d], Cancer validation Loss: %.4f, Prostate validation Loss: %.4f, Loss: %.4f' %
                              (epoch + 1, total_epochs, i + 1, valid_loss1 / log_every, valid_loss2 / log_every, valid_loss / log_every))
                        valid_loss1 = 0.0
                        valid_loss2 = 0.0
                        valid_loss = 0.0

            if not valid_loss_list:
                raise RuntimeError('validation loader yielded no batches')

            # ---------------- logging / scheduling ----------------
            for name, param in model.named_parameters():
                if 'bn' not in name:
                    writer.add_histogram(name + '_data', param.detach().cpu().numpy(), epoch + 1)

            mean_train_loss = float(np.mean(train_loss_list))
            mean_valid_loss = float(np.mean(valid_loss_list))
            # NOTE(review): these tags say "dice" but the values are BCE losses (seg_criterion).
            # The tag names are kept so existing TensorBoard runs remain comparable.
            writer.add_scalars('Train_Val_Loss1',
                               {'train_cancer_dice_loss': np.mean(train_loss1_list), 'val_cancer_dice_loss': np.mean(valid_loss1_list)}, epoch + 1)
            writer.add_scalars('Train_Val_Loss2',
                               {'train_prostate_dice_loss': np.mean(train_loss2_list), 'val_prostate_dice_loss': np.mean(valid_loss2_list)}, epoch + 1)
            writer.add_scalars('Train_Val_Loss',
                               {'train_loss': mean_train_loss, 'val_loss': mean_valid_loss}, epoch + 1)
            # Flush rather than close, so the same event file is reused across epochs.
            writer.flush()

            print('Epoch:', epoch + 1, 'Training Loss:', mean_train_loss, 'Valid Loss:', mean_valid_loss)

            scheduler.step(mean_valid_loss)
            early_stopping(mean_valid_loss, model, save_path=model_folder, evaluation=min)

            if early_stopping.early_stop:
                print("Early stopping")
                break
    finally:
        writer.close()
# 示例#4 (Example #4)
# 0
def Train():
    """Train the multi-task model: two segmentation heads plus an ECE classifier.

    ``MultiTaskModel`` takes t2/dwi/adc concatenated along dim 1 and returns
    ``(roi_out, prostate_out, class_out, _)``. The losses are:

    * ``DiceLoss`` for each segmentation head, against ``outputs[1]`` /
      ``outputs[2]``;
    * ``nn.CrossEntropyLoss`` for the classifier, against the argmax over
      dim 1 of ``outputs[0]`` squeezed on dim 1 (presumably one-hot labels;
      confirm).

    The total loss is the unweighted sum of the three. Each epoch trains,
    then evaluates under ``torch.no_grad``. It logs histograms, mean losses
    and AUCs, where the AUC uses the softmax probability of class 1. It then
    steps ``ReduceLROnPlateau`` on the mean total validation loss and feeds
    that value to ``EarlyStopping``, with ``model_folder`` as its save path.

    Raises:
        RuntimeError: if the validation loader yields no batches.
    """
    total_epochs = 1000
    log_every = 10  # iterations between running-loss prints

    ClearGraphPath()
    train_loader, validation_loader = LoadTVData()
    model = MultiTaskModel(in_channels=3, out_channels=1).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    seg_criterion1 = DiceLoss()
    seg_criterion2 = DiceLoss()
    cla_criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           patience=20,
                                                           factor=0.5,
                                                           verbose=True)
    early_stopping = EarlyStopping(patience=100, verbose=True)
    writer = SummaryWriter(log_dir=graph_path, comment='Net')

    try:
        for epoch in range(total_epochs):
            train_loss1_list, valid_loss1_list = [], []
            train_loss2_list, valid_loss2_list = [], []
            train_loss3_list, valid_loss3_list = [], []
            train_loss_list, valid_loss_list = [], []
            # Running sums for the periodic prints. Reset every epoch so the first
            # print of an epoch does not include leftovers from the previous epoch.
            train_loss1 = train_loss2 = train_loss3 = train_loss = 0.0
            valid_loss1 = valid_loss2 = valid_loss3 = valid_loss = 0.0

            # ---------------- training ----------------
            class_list, class_pred_list = [], []
            model.train()
            for i, (inputs, outputs) in enumerate(train_loader):
                t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                roi, prostate = outputs[1].to(device), outputs[2].to(device)
                # class indices for CrossEntropyLoss
                ece = torch.argmax(torch.squeeze(outputs[0], dim=1), dim=1).long().to(device)

                inputs = torch.cat([t2, dwi, adc], dim=1).float().to(device)

                optimizer.zero_grad()

                roi_out, prostate_out, class_out, _ = model(inputs)
                class_out_softmax = nn.functional.softmax(class_out, dim=1)

                loss1 = seg_criterion1(roi_out, roi)
                loss2 = seg_criterion2(prostate_out, prostate)
                loss3 = cla_criterion(class_out, ece)
                loss = loss1 + loss2 + loss3

                loss.backward()
                optimizer.step()

                loss1_value, loss2_value = loss1.item(), loss2.item()
                loss3_value, loss_value = loss3.item(), loss.item()
                train_loss1 += loss1_value
                train_loss1_list.append(loss1_value)
                train_loss2 += loss2_value
                train_loss2_list.append(loss2_value)
                train_loss3 += loss3_value
                train_loss3_list.append(loss3_value)
                train_loss += loss_value
                train_loss_list.append(loss_value)

                # labels and class-1 probabilities for the epoch AUC
                class_list.extend(list(ece.cpu().numpy()))
                class_pred_list.extend(
                    list(class_out_softmax.detach().cpu().numpy()[..., 1]))

                if (i + 1) % log_every == 0:
                    print(
                        'Epoch [%d / %d], Iter [%d], Cancer train Loss: %.4f, Prostate train Loss: %.4f, ECE train Loss: %.4f, Loss: %.4f'
                        % (epoch + 1, total_epochs, i + 1, train_loss1 / log_every,
                           train_loss2 / log_every, train_loss3 / log_every, train_loss / log_every))
                    train_loss = 0.0
                    train_loss1 = 0.0
                    train_loss2 = 0.0
                    train_loss3 = 0.0

            _, _, train_auc = get_auc(class_pred_list, class_list)

            # ---------------- validation ----------------
            class_list, class_pred_list = [], []
            model.eval()
            with torch.no_grad():
                for i, (inputs, outputs) in enumerate(validation_loader):
                    t2, dwi, adc = inputs[0], inputs[1], inputs[2]
                    roi, prostate = outputs[1].to(device), outputs[2].to(device)
                    ece = torch.argmax(torch.squeeze(outputs[0], dim=1), dim=1).long().to(device)

                    inputs = torch.cat([t2, dwi, adc], dim=1).float().to(device)

                    roi_out, prostate_out, class_out, _ = model(inputs)
                    class_out_softmax = nn.functional.softmax(class_out, dim=1)

                    loss1 = seg_criterion1(roi_out, roi)
                    loss2 = seg_criterion2(prostate_out, prostate)
                    loss3 = cla_criterion(class_out, ece)
                    loss = loss1 + loss2 + loss3

                    loss1_value, loss2_value = loss1.item(), loss2.item()
                    loss3_value, loss_value = loss3.item(), loss.item()
                    valid_loss1 += loss1_value
                    valid_loss1_list.append(loss1_value)
                    valid_loss2 += loss2_value
                    valid_loss2_list.append(loss2_value)
                    valid_loss3 += loss3_value
                    valid_loss3_list.append(loss3_value)
                    valid_loss += loss_value
                    valid_loss_list.append(loss_value)

                    class_list.extend(list(ece.cpu().numpy()))
                    class_pred_list.extend(
                        list(class_out_softmax.cpu().numpy()[..., 1]))

                    if (i + 1) % log_every == 0:
                        print(
                            'Epoch [%d / %d], Iter [%d], Cancer validation Loss: %.4f, Prostate validation Loss: %.4f, ECE validation Loss: %.4f, Loss: %.4f'
                            %
                            (epoch + 1, total_epochs, i + 1, valid_loss1 / log_every,
                             valid_loss2 / log_every, valid_loss3 / log_every, valid_loss / log_every))
                        valid_loss1 = 0.0
                        valid_loss2 = 0.0
                        valid_loss3 = 0.0
                        valid_loss = 0.0

            if not valid_loss_list:
                raise RuntimeError('validation loader yielded no batches')
            _, _, valid_auc = get_auc(class_pred_list, class_list)

            # ---------------- logging / scheduling ----------------
            for name, param in model.named_parameters():
                if 'bn' not in name:
                    writer.add_histogram(name + '_data',
                                         param.detach().cpu().numpy(), epoch + 1)

            mean_train_loss = float(np.mean(train_loss_list))
            mean_valid_loss = float(np.mean(valid_loss_list))
            writer.add_scalars(
                'Train_Val_Loss1', {
                    'train_cancer_dice_loss': np.mean(train_loss1_list),
                    'val_cancer_dice_loss': np.mean(valid_loss1_list)
                }, epoch + 1)
            writer.add_scalars(
                'Train_Val_Loss2', {
                    'train_prostate_dice_loss': np.mean(train_loss2_list),
                    'val_prostate_dice_loss': np.mean(valid_loss2_list)
                }, epoch + 1)
            # NOTE(review): tagged "bce" but this is the CrossEntropyLoss of the classifier.
            # The tag names are kept so existing TensorBoard runs remain comparable.
            writer.add_scalars(
                'Train_Val_Loss3', {
                    'train_bce_loss': np.mean(train_loss3_list),
                    'val_bce_loss': np.mean(valid_loss3_list)
                }, epoch + 1)
            writer.add_scalars(
                'Train_Val_Loss', {
                    'train_loss': mean_train_loss,
                    'val_loss': mean_valid_loss
                }, epoch + 1)
            writer.add_scalars('Train_Val_auc', {
                'train_auc': train_auc,
                'val_auc': valid_auc
            }, epoch + 1)
            # Flush rather than close, so the same event file is reused across epochs.
            writer.flush()

            print('Epoch:', epoch + 1,
                  'Training Loss:', mean_train_loss, 'Valid Loss:',
                  mean_valid_loss, 'Train auc:', train_auc, 'Valid auc:',
                  valid_auc)

            scheduler.step(mean_valid_loss)
            early_stopping(mean_valid_loss,
                           model,
                           save_path=model_folder,
                           evaluation=min)

            if early_stopping.early_stop:
                print("Early stopping")
                break
    finally:
        writer.close()