예제 #1
0
    accs = []
    start_time = time.time()
    EPS = 1e-12

    train_scores = ScoreReport()
    valid_scores = ScoreReport()
    max_acc = 0
    min_loss = 0

    for epoch in range(FG.num_epoch):
        timer.tic()
        D.zero_grad()
        G.zero_grad()
        Q.zero_grad()  #Encoder
        AP.zero_grad()
        train_scores.clear()

        if (epoch + 1) % 100 == 0:
            schedularD.step()
            schedularG.step()
            schedularinfo.step()
        else:
            pass
        schedularAP.step()
        printers['lr']('D', epoch, optimizerD.param_groups[0]['lr'])
        printers['lr']('G', epoch, optimizerG.param_groups[0]['lr'])
        printers['lr']('info', epoch, optimizerinfo.param_groups[0]['lr'])
        printers['lr']('AP', epoch, optimizerAP.param_groups[0]['lr'])
        for i, data in enumerate(trainloader):
            x = data['image']
            xi = data['image']
예제 #2
0
def main():
    """Train and periodically validate a classifier on one cross-validation fold.

    Reads all options from ``train_args()``, builds the train/validation
    datasets for ``FLG.running_fold``, instantiates the model selected by
    ``FLG.model`` and trains it with Adam and an exponentially decaying
    learning rate.  Every ``FLG.validation_term`` epochs the model is
    evaluated on the validation block, the scores are plotted through
    ``summary`` and a checkpoint is written; the checkpoint is flagged as best
    whenever the mean validation loss improves.

    Raises:
        NotImplementedError: if ``FLG.model`` does not contain one of the
            supported architecture names.
    """
    # option flags
    FLG = train_args()

    # torch setting
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])

    # create summary and report the option
    visenv = FLG.model
    summary = Summary(port=39199, env=visenv)
    summary.viz.text(argument_report(FLG, end='<br>'),
                     win='report' + str(FLG.running_fold))
    train_report = ScoreReport()
    valid_report = ScoreReport()
    timer = SimpleTimer()
    fold_str = 'fold' + str(FLG.running_fold)
    # best validation result seen so far; the huge initial loss guarantees
    # that the first validation pass is recorded as the best one
    best_score = dict(epoch=0, loss=1e+100, accuracy=0)

    #### create dataset ###
    # kfold split
    # The target dictionary is stored as a pickle; NumPy >= 1.16.3 refuses to
    # unpickle unless allow_pickle=True is passed explicitly.
    # NOTE(security): unpickling executes arbitrary code -- only load
    # target_dict.pkl from a trusted data_root.
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)
    trainblock, validblock, ratio = fold_split(
        FLG.fold, FLG.running_fold, FLG.labels,
        np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)

    def _dataset(block, transform):
        """Build an ADNIDataset over ``block`` using the given transform."""
        return ADNIDataset(FLG.labels,
                           pjoin(FLG.data_root, FLG.modal),
                           block,
                           target_dict,
                           transform=transform)

    # create train set
    trainset = _dataset(trainblock, transform_presets(FLG.augmentation))

    # create normal valid set
    # when training uses random crops, validation uses the 'nine crop' preset;
    # otherwise validation is run without augmentation
    validset = _dataset(
        validblock,
        transform_presets('nine crop' if FLG.augmentation ==
                          'random crop' else 'no augmentation'))

    # each loader
    trainloader = DataLoader(trainset,
                             batch_size=FLG.batch_size,
                             shuffle=True,
                             num_workers=4,
                             pin_memory=True)
    # validation uses the DataLoader default batch size (one sample per step)
    validloader = DataLoader(validset, num_workers=4, pin_memory=True)

    # create model
    def kaiming_init(tensor):
        """Initialise ``tensor`` in place with Kaiming-normal (fan_out, relu)."""
        return kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

    # the architecture is picked by substring match on the model name
    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels),
                      name=FLG.model,
                      weights_initializer=kaiming_init)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels),
                         FLG.model,
                         weights_initializer=kaiming_init)
    else:
        raise NotImplementedError(FLG.model)

    print_model_parameters(model)
    model = torch.nn.DataParallel(model, FLG.devices)
    model.to(device)

    # criterion
    # training loss is class-weighted with the reversed fold ratios scaled by 2
    # NOTE(review): presumably an inverse-frequency weighting for a two-class
    # problem -- confirm against fold_split.
    train_criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(
        list(map(lambda x: x * 2, reversed(ratio))))).to(device)
    valid_criterion = torch.nn.CrossEntropyLoss().to(device)

    # TODO resume
    # optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=FLG.lr,
                                 weight_decay=FLG.l2_decay)
    # scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FLG.lr_gamma)

    start_epoch = 0
    global_step = start_epoch * len(trainloader)
    # one training progress bar spans a whole validation term; None means a
    # new bar has to be created at the start of the next epoch
    pbar = None
    for epoch in range(1, FLG.max_epoch + 1):
        timer.tic()
        # log the learning rate that is actually used during this epoch
        summary.scalar('lr',
                       fold_str,
                       epoch - 1,
                       optimizer.param_groups[0]['lr'],
                       ytickmin=0,
                       ytickmax=FLG.lr)

        # train()
        torch.set_grad_enabled(True)
        model.train(True)
        train_report.clear()
        if pbar is None:
            pbar = tqdm(total=len(trainloader) * FLG.validation_term,
                        desc='Epoch {:<3}-{:>3} train'.format(
                            epoch, epoch + FLG.validation_term - 1))
        for images, targets in trainloader:
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            optimizer.zero_grad()

            outputs = model(images)
            loss = train_criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_report.update_true(targets)
            train_report.update_score(F.softmax(outputs, dim=1))

            # x-axis is expressed in (fractional) epochs
            summary.scalar('loss',
                           'train ' + fold_str,
                           global_step / len(trainloader),
                           loss.item(),
                           ytickmin=0,
                           ytickmax=1)

            pbar.update()
            global_step += 1

        # Decay the learning rate once per epoch, after the optimizer has
        # stepped.  Since PyTorch 1.1 stepping the scheduler before the
        # optimizer skips the first value of the schedule.
        scheduler.step()

        # only validate / checkpoint every FLG.validation_term epochs
        if epoch % FLG.validation_term != 0:
            timer.toc()
            continue
        pbar.close()

        # valid()
        torch.set_grad_enabled(False)
        model.eval()
        valid_report.clear()
        pbar = tqdm(total=len(validloader),
                    desc='Epoch {:>3} valid'.format(epoch))
        for images, targets in validloader:
            true = targets
            npatchs = 1
            if len(images.shape) == 6:
                # multi-patch sample: (batch, npatchs, c, x, y, z); fold the
                # patches into the batch axis and repeat the target per patch
                _, npatchs, c, x, y, z = images.shape
                images = images.view(-1, c, x, y, z)
                targets = torch.cat([targets
                                     for _ in range(npatchs)]).squeeze()
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            output = model(images)
            loss = valid_criterion(output, targets)

            valid_report.loss += loss.item()

            if npatchs == 1:
                score = F.softmax(output, dim=1)
            else:
                # average the per-patch probabilities into a single score row
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0,
                                   keepdim=True)
            valid_report.update_true(true)
            valid_report.update_score(score)

            pbar.update()
        pbar.close()

        # report
        vloss = valid_report.loss / len(validloader)
        summary.scalar('accuracy',
                       'train ' + fold_str,
                       epoch,
                       train_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        summary.scalar('loss',
                       'valid ' + fold_str,
                       epoch,
                       vloss,
                       ytickmin=0,
                       ytickmax=0.8)
        summary.scalar('accuracy',
                       'valid ' + fold_str,
                       epoch,
                       valid_report.accuracy,
                       ytickmin=-0.05,
                       ytickmax=1.05)

        # the best model is the one with the lowest mean validation loss
        is_best = False
        if best_score['loss'] > vloss:
            best_score['loss'] = vloss
            best_score['epoch'] = epoch
            best_score['accuracy'] = valid_report.accuracy
            is_best = True

        print('Best Epoch {}: validation loss {} accuracy {}'.format(
            best_score['epoch'], best_score['loss'], best_score['accuracy']))

        # save
        # unwrap DataParallel so the checkpoint keys carry no 'module.' prefix
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()

        save_checkpoint(
            dict(epoch=epoch,
                 best_score=best_score,
                 state_dict=state_dict,
                 optimizer_state_dict=optimizer.state_dict()),
            FLG.checkpoint_root, FLG.running_fold, FLG.model, is_best)
        pbar = None
        timer.toc()
        print('Time elapse {}h {}m {}s'.format(*timer.total()))

    # if max_epoch is not a multiple of validation_term the last training
    # progress bar is still open here; close it so the terminal is left clean
    if pbar is not None:
        pbar.close()
예제 #3
0
        y[test_idx],
        transform=Compose([
            #ToWoldCoordinateSystem(), ToTensor()]))
            ToWoldCoordinateSystem(),
            RandomCrop(80, 64),
            ToTensor()
        ]))
    # dataloader
    testloader = DataLoader(testset, num_workers=4, pin_memory=True)

    model.eval()
    G.eval()
    torch.set_grad_enabled(False)
    ############################ original : test #############################
    test_scores = ScoreReport()
    test_scores.clear()
    print('original image result')
    for sample in testloader:
        mri = sample['image']
        target = sample['target']
        mri = mri.cuda(device, non_blocking=True)
        target = target.to(device)

        name = 'ori-target'
        save_printer1 = Image3D(vis, name)
        save_printer1(name, torch.clamp(mri[0, :, :, :], min=0, max=1))

        output = model(mri)

        score = torch.nn.functional.softmax(output, 1)
        test_scores.update_true(target)
예제 #4
0
        stats = logging.Statistics(['loss_D', 'loss_G'])
        printers['lr']('D', epoch, optimizerD.param_groups[0]['lr'])
        printers['lr']('G', epoch, optimizerG.param_groups[0]['lr'])
        printers['lr']('info', epoch, optimizerinfo.param_groups[0]['lr'])
        printers['lr']('AC', epoch, optimizerAC.param_groups[0]['lr'])
        timer.tic()

        if (epoch + 1) % 100 == 0:
            schedularD.step()
            schedularG.step()
            schedularinfo.step()
            schedularAC.step()
        D.train()
        G.train()
        E.train()
        train_scores.clear()

        real_rate = torch.zeros(10)
        fake_rate = torch.zeros(10)
        for step, data in enumerate(trainloader):
            D.zero_grad()
            G.zero_grad()
            E.zero_grad()
            x = data['image']
            y = data['target']
            batch_size = x.size(0)  # batch_size <= FG.batch_size
            label = y.float().cuda(device, non_blocking=True)
            y_real = torch.ones(batch_size).cuda(device, non_blocking=True)
            y_fake = torch.zeros(batch_size).cuda(device, non_blocking=True)

            # x = x.unsqueeze(dim=1).float().cuda(device, non_blocking=True)