예제 #1
0
    test_scores.clear()
    print('original image result')
    for sample in testloader:
        mri = sample['image']
        target = sample['target']
        mri = mri.cuda(device, non_blocking=True)
        target = target.to(device)

        name = 'ori-target'
        save_printer1 = Image3D(vis, name)
        save_printer1(name, torch.clamp(mri[0, :, :, :], min=0, max=1))

        output = model(mri)

        score = torch.nn.functional.softmax(output, 1)
        test_scores.update_true(target)
        test_scores.update_score(score)

    print('accuracy:', test_scores.accuracy)

    ############################ cont test ##################################
    # create latent codes
    discrete_code = FG.d_code
    continuous_code = FG.c_code
    z_dim = FG.z
    sample_num = 1000

    #temp_c = torch.linspace(-2, 2, 5)
    temp_c = torch.linspace(-2, 2, 10)

    sample_z = torch.zeros((sample_num, z_dim))
예제 #2
0
            optimizerinfo.step()

            Q.zero_grad()
            AP.zero_grad()

            # Q network
            z_AP = Q(x_q)
            # AP network
            ap_AP = AP(z_AP)
            # AP loss & back propagation
            #acc = matching(AP_AP.permute(0,2,1).data.cpu().numpy(), AP.permute(0,2,1).numpy())
            score = ap_AP
            loss_AP = BCE_loss(ap_AP, ap_q).mean()
            loss_AP.backward()
            optimizerAP.step()
            train_scores.update_true(y)
            train_scores.update_score(score)
        printers['D_loss']('train', epoch + i / len(trainloader), loss_D)
        printers['G_loss']('train', epoch + i / len(trainloader), loss_G)
        printers['info_loss']('train', epoch + i / len(trainloader), loss_info)
        printers['AP_loss']('train', epoch + i / len(trainloader), loss_AP)

        train_acc = train_scores.accuracy
        printers['acc']('train', epoch + i / len(trainloader), train_acc)
        print("Epoch: [%2d] D_loss: %.5f, G_loss: %.5f, info_loss: %.5f" %
              ((epoch + 1), loss_D.item(), loss_G.item(), loss_info.item()))

        if epoch % (10) == 0:
            #valid_data, valid_AP, _ = getData(data['valid'],FG.batch_size, iteration)
            valid_scores.clear()
            valid_printer = Image3D(vis, 'valid_output')
예제 #3
0
def main():
    """Train one cross-validation fold of the configured 3-D classifier.

    Options come from ``train_args()``; the fold trained is
    ``FLG.running_fold`` of a ``FLG.fold``-way split produced by
    ``fold_split``. The model is trained for ``FLG.max_epoch`` epochs with
    Adam and an exponential learning-rate decay. Every
    ``FLG.validation_term`` epochs it is evaluated on the fold's validation
    block. Loss and accuracy scalars are then sent to ``Summary``, and a
    checkpoint is written with ``save_checkpoint``. The checkpoint is flagged
    ``is_best`` when the validation loss improves on the best seen so far.
    """
    # option flags
    FLG = train_args()

    # torch setting
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])

    # create summary and report the option
    visenv = FLG.model
    summary = Summary(port=39199, env=visenv)
    summary.viz.text(argument_report(FLG, end='<br>'),
                     win='report' + str(FLG.running_fold))
    train_report = ScoreReport()
    valid_report = ScoreReport()
    timer = SimpleTimer()
    fold_str = 'fold' + str(FLG.running_fold)
    best_score = dict(epoch=0, loss=1e+100, accuracy=0)

    #### create dataset ###
    # target_dict.pkl holds pickled data. NumPy >= 1.16.3 refuses to
    # unpickle in np.load unless allow_pickle=True is given explicitly.
    # Only load trusted, locally produced files this way.
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)
    # kfold split
    trainblock, validblock, ratio = fold_split(
        FLG.fold, FLG.running_fold, FLG.labels,
        np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)

    def _dataset(block, transform):
        """Build an ADNIDataset over ``block`` with the given transform."""
        return ADNIDataset(FLG.labels,
                           pjoin(FLG.data_root, FLG.modal),
                           block,
                           target_dict,
                           transform=transform)

    # create train set
    trainset = _dataset(trainblock, transform_presets(FLG.augmentation))

    # create normal valid set (random-crop training is validated on nine
    # crops; every other augmentation is validated on the raw volume)
    validset = _dataset(
        validblock,
        transform_presets('nine crop' if FLG.augmentation ==
                          'random crop' else 'no augmentation'))

    # each loader
    trainloader = DataLoader(trainset,
                             batch_size=FLG.batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)
    validloader = DataLoader(validset, num_workers=4, pin_memory=True)

    # create model
    def kaiming_init(tensor):
        """He-normal initialisation (fan_out, relu gain) for conv weights."""
        return kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

    # Substring match on the model name; the first matching family wins.
    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels),
                      name=FLG.model,
                      weights_initializer=kaiming_init)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    else:
        raise NotImplementedError(FLG.model)

    print_model_parameters(model)
    model = torch.nn.DataParallel(model, FLG.devices)
    model.to(device)

    # criterion: training loss is class-weighted using the reversed split
    # ratio (scaled by 2); validation loss is unweighted.
    train_criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(
        list(map(lambda x: x * 2, reversed(ratio))))).to(device)
    valid_criterion = torch.nn.CrossEntropyLoss().to(device)

    # TODO resume
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLG.lr,
                                 weight_decay=FLG.l2_decay)
    # scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FLG.lr_gamma)

    start_epoch = 0
    global_step = start_epoch * len(trainloader)
    # One progress bar spans all training epochs of a validation window.
    pbar = None
    for epoch in range(1, FLG.max_epoch + 1):
        timer.tic()
        # NOTE(review): on PyTorch >= 1.1 the scheduler is expected to step
        # *after* optimizer.step(). Stepping here means epoch 1 already runs
        # at lr * gamma. Left unchanged to keep the existing schedule.
        scheduler.step()
        summary.scalar('lr',
                       fold_str,
                       epoch - 1,
                       optimizer.param_groups[0]['lr'],
                       ytickmin=0,
                       ytickmax=FLG.lr)

        # train()
        torch.set_grad_enabled(True)
        model.train(True)
        train_report.clear()
        if pbar is None:
            pbar = tqdm(total=len(trainloader) * FLG.validation_term,
                        desc='Epoch {:<3}-{:>3} train'.format(
                            epoch, epoch + FLG.validation_term - 1))
        for images, targets in trainloader:
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(images)
            loss = train_criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_report.update_true(targets)
            train_report.update_score(F.softmax(outputs, dim=1))

            summary.scalar('loss',
                           'train ' + fold_str,
                           global_step / len(trainloader),
                           loss.item(),
                           ytickmin=0,
                           ytickmax=1)

            pbar.update()
            global_step += 1

        # Validate only every FLG.validation_term epochs.
        if epoch % FLG.validation_term != 0:
            timer.toc()
            continue
        pbar.close()

        # valid()
        torch.set_grad_enabled(False)
        model.eval()
        valid_report.clear()
        pbar = tqdm(total=len(validloader),
                    desc='Epoch {:>3} valid'.format(epoch))
        for images, targets in validloader:
            true = targets
            npatchs = 1
            if len(images.shape) == 6:
                # 6-D input: the second axis is unpacked as the patch count;
                # flatten patches into the batch axis and repeat the label.
                _, npatchs, c, x, y, z = images.shape
                images = images.view(-1, c, x, y, z)
                targets = torch.cat([targets
                                     for _ in range(npatchs)]).squeeze()
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            output = model(images)
            loss = valid_criterion(output, targets)

            valid_report.loss += loss.item()

            if npatchs == 1:
                score = F.softmax(output, dim=1)
            else:
                # average the per-patch probabilities into a single row
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)
            valid_report.update_true(true)
            valid_report.update_score(score)

            pbar.update()
        pbar.close()

        # report
        vloss = valid_report.loss / len(validloader)
        summary.scalar('accuracy',
                       'train ' + fold_str,
                       epoch,
                       train_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        summary.scalar('loss',
                       'valid ' + fold_str,
                       epoch,
                       vloss,
                       ytickmin=0,
                       ytickmax=0.8)
        summary.scalar('accuracy',
                       'valid ' + fold_str,
                       epoch,
                       valid_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        # "best" is tracked by lowest mean validation loss, not accuracy.
        is_best = False
        if best_score['loss'] > vloss:
            best_score['loss'] = vloss
            best_score['epoch'] = epoch
            best_score['accuracy'] = valid_report.accuracy
            is_best = True

        print('Best Epoch {}: validation loss {} accuracy {}'.format(
            best_score['epoch'], best_score['loss'], best_score['accuracy']))

        # save (unwrap DataParallel so the keys have no 'module.' prefix)
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()

        save_checkpoint(
            dict(epoch=epoch,
                 best_score=best_score,
                 state_dict=state_dict,
                 optimizer_state_dict=optimizer.state_dict()),
            FLG.checkpoint_root, FLG.running_fold, FLG.model, is_best)
        pbar = None
        timer.toc()
        print('Time elapse {}h {}m {}s'.format(*timer.total()))

    # If max_epoch is not a multiple of validation_term, the loop ends on a
    # training-only epoch and the training bar would otherwise stay open.
    if pbar is not None:
        pbar.close()
예제 #4
0
            loss_list.append(loss.item())

            # Backprop and perform Adam optimisation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Track the accuracy
            total = labels.size(0)
            _, predicted = torch.max(outputs.data, 1)
            correct = (predicted == labels).sum().item()
            acc_list.append(correct / total)
            printers['loss']('train', epoch + i / len(train_loader), loss)

            score = torch.nn.functional.softmax(outputs, 1)
            train_scores.update_true(labels)
            train_scores.update_score(score)

            if (i + 1) % 100 == 0:
                print(
                    'Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'
                    .format(epoch + 1, num_epochs, i + 1, total_step,
                            loss.item(), (correct / total) * 100))
        printers['acc']('train', epoch, train_scores.accuracy)

        ######################### testing ###########################
        # Test the model
        model.eval()
        G.eval()
        with torch.no_grad():
            correct = 0
예제 #5
0
def test(FLG):
    """Evaluate each fold's saved checkpoint on its validation block.

    For every fold in ``range(FLG.fold)``:
    - the validation block from ``fold_split`` is scored with the weights
      restored by ``load_checkpoint``;
    - multi-patch (6-D) inputs have their per-patch softmax averaged into one
      probability row.

    Per-fold and overall ``metrics.classification_report`` tables are
    printed. The list of per-fold ``ScoreReport`` objects is pickled to
    ``<FLG.model>_stat.pkl``.
    """
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    overall_report = ScoreReport()
    # target_dict.pkl holds pickled data. NumPy >= 1.16.3 refuses to
    # unpickle in np.load unless allow_pickle=True is given explicitly.
    # Only load trusted, locally produced files this way.
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)

    # Substring match on the model name; the first matching family wins.
    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels), name=FLG.model)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels), FLG.model)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels), FLG.model)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels), FLG.model)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels), FLG.model)
    else:
        raise NotImplementedError(FLG.model)
    model.to(device)

    for running_fold in range(FLG.fold):
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transform_presets(FLG.augmentation))
        validloader = DataLoader(validset, pin_memory=True)

        epoch, _ = load_checkpoint(model, FLG.checkpoint_root, running_fold,
                                   FLG.model, None, True)
        model.eval()
        # Only the image goes to the GPU: the label is needed solely for the
        # score reports, which receive it unchanged.
        for image, true in validloader:
            npatches = 1
            if len(image.shape) == 6:
                # 6-D input: the second axis is unpacked as the patch count;
                # flatten patches into the batch axis.
                _, npatches, c, x, y, z = image.shape
                image = image.view(-1, c, x, y, z)
            image = image.cuda(device, non_blocking=True)

            output = model(image)

            if npatches == 1:
                score = F.softmax(output, dim=1)
            else:
                # average the per-patch probabilities into a single row
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)

            report[running_fold].update_true(true)
            report[running_fold].update_score(score)

            overall_report.update_true(true)
            overall_report.update_score(score)

        print('At {}'.format(epoch))
        print(
            metrics.classification_report(report[running_fold].y_true,
                                          report[running_fold].y_pred,
                                          target_names=FLG.labels,
                                          digits=4))
        print('accuracy {}'.format(report[running_fold].accuracy))

    print('over all')
    print(
        metrics.classification_report(overall_report.y_true,
                                      overall_report.y_pred,
                                      target_names=FLG.labels,
                                      digits=4))
    print('accuracy {}'.format(overall_report.accuracy))

    with open(FLG.model + '_stat.pkl', 'wb') as f:
        pickle.dump(report, f, pickle.HIGHEST_PROTOCOL)
예제 #6
0
            target = data['target']
            images = image.cuda(device, non_blocking=True)
            target = target.type(torch.LongTensor).cuda(device,
                                                        non_blocking=True)

            printers['train_input']('train_input', images[1, :, :, :])

            # Backprop and perform Adam optimisation
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer.step()

            score = torch.nn.functional.softmax(outputs, 1)
            train_scores.update_true(target)
            train_scores.update_score(score)
            printers['loss']('train_f{}'.format(FG.running_fold),
                             epoch + i / len(trainloader), loss)
            train_pbar.update()
        train_pbar.close()

        printers['acc']('train_f{}'.format(FG.running_fold), epoch,
                        train_scores.accuracy)

        ############################ classification : test #############################
        model.eval()
        torch.set_grad_enabled(False)
        test_scores.clear()

        test_pbar = tqdm(total=len(testloader),
예제 #7
0
def test(FLG):
    """Evaluate every fold's checkpoint and optionally accumulate 3-D CAMs.

    For every fold in ``range(FLG.fold)``, the validation block from
    ``fold_split`` is scored with the weights restored by
    ``load_checkpoint``. Multi-patch (6-D) inputs have their per-patch
    softmax averaged. Per-fold and overall ``metrics.classification_report``
    tables are printed, and the per-fold ``ScoreReport`` list is pickled to
    ``<FLG.model>_stat.pkl``.

    When ``FLG.cam`` is set, the output of ``model.layer4`` is captured with
    a forward hook, and ``summary.cam3d`` is called for each sample. The
    result is stored in one of four buckets: ``adcams`` for class-0 samples,
    ``nlcams`` otherwise. The bucket index is the first threshold in ``sb``
    that the sample's score exceeds, or bucket 3 if none.
    """
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.set_grad_enabled(False)
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])
    report = [ScoreReport() for _ in range(FLG.fold)]
    overall_report = ScoreReport()
    # target_dict.pkl holds pickled data. NumPy >= 1.16.3 refuses to
    # unpickle in np.load unless allow_pickle=True is given explicitly.
    # Only load trusted, locally produced files this way.
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)

    summary = Summary(port=10001, env=str(FLG.model) + 'CAM')

    class Feature(object):
        """Mutable holder for the most recent activation seen by a hook."""

        def __init__(self):
            self.blob = None

        def capture(self, blob):
            self.blob = blob

    # Substring match on the model name; the first matching family wins.
    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels), name=FLG.model)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels), FLG.model)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels), FLG.model)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels), FLG.model)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels), FLG.model)
    else:
        raise NotImplementedError(FLG.model)
    model.to(device)

    # CAM accumulators: 4 confidence buckets x 3 x (112, 144, 112) volume.
    adcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    nlcams = np.zeros((4, 3, 112, 144, 112), dtype="f8")
    # Lower score bounds for buckets 0-2. A score not above the last bound
    # goes to bucket len(sb) == 3.
    sb = [9.996e-01, 6.3e-01, 1.001e-01]
    for running_fold in range(FLG.fold):
        _, validblock, _ = fold_split(
            FLG.fold, running_fold, FLG.labels,
            np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)
        validset = ADNIDataset(FLG.labels,
                               pjoin(FLG.data_root, FLG.modal),
                               validblock,
                               target_dict,
                               transform=transform_presets(FLG.augmentation))
        validloader = DataLoader(validset, pin_memory=True)

        epoch, _ = load_checkpoint(model, FLG.checkpoint_root, running_fold,
                                   FLG.model, None, True)
        model.eval()
        feature = Feature()

        def hook(mod, inp, oup):
            return feature.capture(oup.data.cpu().numpy())

        # The same model object is reused for every fold. Keep the handle so
        # this fold's hook can be detached; otherwise hooks pile up and each
        # forward pass repeats the GPU->CPU copy once per fold seen so far.
        hook_handle = model.layer4.register_forward_hook(hook)
        fc_weights = model.fc.weight.data.cpu().numpy()

        transformer = Compose([CenterCrop((112, 144, 112)), ToFloatTensor()])
        im, _ = original_load(validblock, target_dict, transformer, device)

        try:
            for image, target in validloader:
                true = target
                npatches = 1
                if len(image.shape) == 6:
                    # 6-D input: the second axis is unpacked as the patch
                    # count; flatten patches and repeat the label per patch.
                    _, npatches, c, x, y, z = image.shape
                    image = image.view(-1, c, x, y, z)
                    target = torch.stack([target
                                          for _ in range(npatches)]).squeeze()
                image = image.cuda(device, non_blocking=True)
                target = target.cuda(device, non_blocking=True)

                output = model(image)

                if npatches == 1:
                    score = F.softmax(output, dim=1)
                else:
                    # average the per-patch probabilities into a single row
                    score = torch.mean(F.softmax(output, dim=1),
                                       dim=0,
                                       keepdim=True)

                report[running_fold].update_true(true)
                report[running_fold].update_score(score)

                overall_report.update_true(true)
                overall_report.update_score(score)

                print(target)
                if FLG.cam:
                    if target[0] == 0:
                        s = score[0][0]
                        cams = adcams
                    else:
                        # Must bind `s` itself: previously this went to an
                        # unused `sn`, so `s` stayed 0 and every such sample
                        # fell into the lowest bucket.
                        s = score[0][1]
                        cams = nlcams
                    # first bucket whose lower bound the score exceeds
                    bucket = next(
                        (i for i, bound in enumerate(sb) if s > bound),
                        len(sb))
                    # int(target[0]) also works when target was repeated
                    # per patch (multi-element tensors cannot index a list).
                    cams[bucket] = summary.cam3d(FLG.labels[int(target[0])],
                                                 im,
                                                 feature.blob,
                                                 fc_weights,
                                                 target,
                                                 cams[bucket],
                                                 s,
                                                 num_images=5)
        finally:
            hook_handle.remove()

        print('At {}'.format(epoch))
        print(
            metrics.classification_report(report[running_fold].y_true,
                                          report[running_fold].y_pred,
                                          target_names=FLG.labels,
                                          digits=4))
        print('accuracy {}'.format(report[running_fold].accuracy))

    print('over all')
    print(
        metrics.classification_report(overall_report.y_true,
                                      overall_report.y_pred,
                                      target_names=FLG.labels,
                                      digits=4))
    print('accuracy {}'.format(overall_report.accuracy))

    with open(FLG.model + '_stat.pkl', 'wb') as f:
        pickle.dump(report, f, pickle.HIGHEST_PROTOCOL)