# Exemplo n.º 1
# 0
def main():
    """Train a segmentation network, periodically checkpointing, and run a
    final validation pass.

    Relies on module-level names defined elsewhere in this file: ``args``
    (parsed options), ``models``, ``criterions``, ``datasets``, ``ckpts``
    (checkpoint directory), ``CycleSampler``, ``AverageMeter``,
    ``init_fn``, ``adjust_learning_rate``, ``add_mask`` and ``validate``.
    """
    # setup environments and seeds
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    # setup networks
    # args.net names a class in the models module; args.net_params is a
    # dict of keyword arguments for its constructor.
    Network = getattr(models, args.net)
    model = Network(**args.net_params)
    model = model.cuda()

    optimizer = getattr(torch.optim, args.opt)(model.parameters(),
                                               **args.opt_params)
    criterion = getattr(criterions, args.criterion)

    msg = ''
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # resume mid-run: restore the iteration counter as well as the
            # model and optimizer state
            args.start_iter = checkpoint['iter']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optim_dict'])
            msg = ("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            msg = "=> no checkpoint found at '{}'".format(args.resume)
    else:
        msg = '-------------- New training session ----------------'

    msg += '\n' + str(args)
    logging.info(msg)

    # Data loading code
    Dataset = getattr(datasets, args.dataset)

    # The loader will get 1000 patches from 50 subjects for each sub epoch
    # each subject sample 20 patches
    train_list = os.path.join(args.data_dir, args.train_list)
    train_set = Dataset(train_list,
                        root=args.data_dir,
                        for_train=True,
                        num_patches=args.num_patches,
                        transforms=args.train_transforms,
                        sample_size=args.sample_size,
                        sub_sample_size=args.sub_sample_size,
                        target_size=args.target_size)

    # Total iteration budget: an explicit args.num_iters wins; otherwise
    # derive it from epochs * dataset size.  Subtract start_iter so a
    # resumed run does only the remaining iterations.
    num_iters = args.num_iters or (len(train_set) *
                                   args.num_epochs) // args.batch_size
    num_iters -= args.start_iter
    # CycleSampler yields exactly num_iters * batch_size indices so the
    # single for-loop below covers the whole training run.
    train_sampler = CycleSampler(len(train_set), num_iters * args.batch_size)
    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              collate_fn=train_set.collate,
                              sampler=train_sampler,
                              num_workers=args.workers,
                              pin_memory=True,
                              worker_init_fn=init_fn)

    if args.valid_list:
        valid_list = os.path.join(args.data_dir, args.valid_list)
        valid_set = Dataset(valid_list,
                            root=args.data_dir,
                            for_train=False,
                            crop=False,
                            transforms=args.test_transforms,
                            sample_size=args.sample_size,
                            sub_sample_size=args.sub_sample_size,
                            target_size=args.target_size)
        valid_loader = DataLoader(valid_set,
                                  batch_size=1,
                                  shuffle=False,
                                  collate_fn=valid_set.collate,
                                  num_workers=4,
                                  pin_memory=True)

        # Evaluation-mode view of the training data (no cropping/augmentation);
        # only used by the commented-out training-set validation below.
        train_valid_set = Dataset(train_list,
                                  root=args.data_dir,
                                  for_train=False,
                                  crop=False,
                                  transforms=args.test_transforms,
                                  sample_size=args.sample_size,
                                  sub_sample_size=args.sub_sample_size,
                                  target_size=args.target_size)
        train_valid_loader = DataLoader(train_valid_set,
                                        batch_size=1,
                                        shuffle=False,
                                        collate_fn=train_valid_set.collate,
                                        num_workers=4,
                                        pin_memory=True)

    start = time.time()

    # Convert epoch-denominated settings (schedule keys, save/valid
    # frequencies) into iteration counts.
    enum_batches = len(train_set) / float(args.batch_size)
    args.schedule = {
        int(k * enum_batches): v
        for k, v in args.schedule.items()
    }
    args.save_freq = int(enum_batches * args.save_freq)
    args.valid_freq = int(enum_batches * args.valid_freq)

    losses = AverageMeter()
    torch.set_grad_enabled(True)

    for i, (data, label) in enumerate(train_loader, args.start_iter):

        ## validation
        #if args.valid_list and  (i % args.valid_freq) == 0:
        #    logging.info('-'*50)
        #    msg  =  'Iter {}, Epoch {:.4f}, {}'.format(i, i/enum_batches, 'validation')
        #    logging.info(msg)
        #    with torch.no_grad():
        #        validate(valid_loader, model, batch_size=args.mini_batch_size, names=valid_set.names)

        # actual training
        adjust_learning_rate(optimizer, i)
        # Split each loader batch into mini-batches (gradient is stepped per
        # mini-batch, not accumulated).  NOTE(review): the inner loop rebinds
        # `data`, shadowing the loader batch — intentional here but fragile.
        for data in zip(*[d.split(args.mini_batch_size) for d in data]):

            data = [t.cuda(non_blocking=True) for t in data]
            x1, x2, target = data[:3]

            if len(data) > 3:  # has mask
                m1, m2 = data[3:]
                x1 = add_mask(x1, m1, 1)
                x2 = add_mask(x2, m2, 1)

            # compute output
            output = model((x1, x2))  # output nx5x9x9x9, target nx9x9x9
            loss = criterion(output, target, args.alpha)

            # measure accuracy and record loss
            losses.update(loss.item(), target.numel())

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i + 1) % args.save_freq == 0:
            epoch = int((i + 1) // enum_batches)

            file_name = os.path.join(ckpts, 'model_epoch_{}.tar'.format(epoch))
            torch.save(
                {
                    'iter': i + 1,
                    'state_dict': model.state_dict(),
                    'optim_dict': optimizer.state_dict(),
                }, file_name)

        msg = 'Iter {0:}, Epoch {1:.4f}, Loss {2:.4f}'.format(
            i + 1, (i + 1) / enum_batches, losses.avg)
        logging.info(msg)

        losses.reset()

    # Final checkpoint after the training loop completes.
    i = num_iters + args.start_iter
    file_name = os.path.join(ckpts, 'model_last.tar')
    torch.save(
        {
            'iter': i,
            'state_dict': model.state_dict(),
            'optim_dict': optimizer.state_dict(),
        }, file_name)

    if args.valid_list:
        logging.info('-' * 50)
        msg = 'Iter {}, Epoch {:.4f}, {}'.format(i, i / enum_batches,
                                                 'validate validation data')
        logging.info(msg)

        with torch.no_grad():
            validate(valid_loader,
                     model,
                     batch_size=args.mini_batch_size,
                     names=valid_set.names,
                     out_dir=args.out)

        #logging.info('-'*50)
        #msg  =  'Iter {}, Epoch {:.4f}, {}'.format(i, i/enum_batches, 'validate training data')
        #logging.info(msg)

        #with torch.no_grad():
        #    validate(train_valid_loader, model, batch_size=args.mini_batch_size, names=train_valid_set.names, verbose=False)

    msg = 'total time: {:.4f} minutes'.format((time.time() - start) / 60)
    logging.info(msg)
def validate(valid_loader,
             model,
             batch_size,
             out_dir='',
             names=None,
             scoring=True,
             verbose=True):
    """Run patch-based inference over every subject in ``valid_loader``,
    stitch the per-patch class probabilities back into full volumes, and
    optionally compute Dice scores.

    Parameters
    ----------
    valid_loader : DataLoader yielding ``(data, labels)`` per subject, where
        ``data`` is ``[volume, patch_coords]`` plus an optional mask.
    model : network taking a ``(x1, x2)`` pair of crops; returns per-patch
        class logits (5 classes — see the hard-coded 5 below).
    batch_size : number of patch coordinates processed per forward pass.
    out_dir : if non-empty, per-subject probability volumes are saved as
        ``<name>_preds.npy`` there.
    names : optional subject names for logging/output files; falls back to
        ``valid_loader.dataset.names`` when None.
    scoring : when True, also compute per-patch cross-entropy and Dice.
    verbose : when True, log one line per subject.

    Returns
    -------
    Average Dice scores accumulated over all subjects (``AverageMeter.avg``).
    """
    # Canonical full-volume size; presumably BraTS 240x240x155 scans —
    # stitched outputs are cropped back to this shape.
    H, W, T = 240, 240, 155

    dset = valid_loader.dataset
    # Only fall back to the dataset's names when the caller supplied none.
    # (The previous code unconditionally overwrote `names`, making the
    # parameter dead.)
    if names is None:
        names = dset.names
    h, w, t = dset.shape
    h, w, t = int(h), int(w), int(t)
    sample_size = dset.sample_size
    sub_sample_size = dset.sub_sample_size
    target_size = dset.target_size
    dtype = torch.float32

    model.eval()
    criterion = F.cross_entropy

    vals = AverageMeter()
    for i, (data, labels) in enumerate(valid_loader):

        y = labels.cuda(non_blocking=True)
        data = [t.cuda(non_blocking=True) for t in data]
        x, coords = data[:2]

        if len(data) > 2:  # has mask
            x = add_mask(x, data.pop(), 0)

        # One probability patch of side `target_size` per grid location;
        # h*w*t is the number of patch positions covering the volume.
        outputs = torch.zeros(
            (5, h * w * t, target_size, target_size, target_size), dtype=dtype)

        sample_loss = AverageMeter(
        ) if scoring and criterion is not None else None

        for b, coord in enumerate(coords.split(batch_size)):
            # Two crops per location: full-resolution context and a wider
            # sub-sampled (stride-3) context.
            x1 = multicrop.crop3d_gpu(x, coord, sample_size, sample_size,
                                      sample_size, 1, True)
            x2 = multicrop.crop3d_gpu(x, coord, sub_sample_size,
                                      sub_sample_size, sub_sample_size, 3,
                                      True)

            if scoring:
                target = multicrop.crop3d_gpu(y, coord, target_size,
                                              target_size, target_size, 1,
                                              True)

            # compute output
            logit = model((x1, x2))  # n x 5 x target_size^3
            output = F.softmax(logit, dim=1)

            # copy this mini-batch of patches into the big buffer
            start = b * batch_size
            end = start + output.shape[0]
            outputs[:, start:end] = output.permute(1, 0, 2, 3, 4).cpu()

            # measure accuracy and record loss
            if scoring and criterion is not None:
                loss = criterion(logit, target)
                sample_loss.update(loss.item(), target.size(0))

        # Stitch the patch grid back into a full volume.  The previous code
        # hard-coded 9 here while allocating with `target_size` above; use
        # `target_size` consistently so any patch size works.
        ts = target_size
        outputs = outputs.view(5, h, w, t, ts, ts,
                               ts).permute(0, 1, 4, 2, 5, 3, 6)
        outputs = outputs.reshape(5, h * ts, w * ts, t * ts)
        outputs = outputs[:, :H, :W, :T].numpy()

        msg = 'Subject {}/{}, '.format(i + 1, len(valid_loader))
        name = str(i)
        if names:
            name = names[i]
            msg += '{:>20}, '.format(name)

        if out_dir:
            np.save(os.path.join(out_dir, name + '_preds'), outputs)

        if scoring:
            labels = labels.numpy()
            outputs = outputs.argmax(0)
            scores = dice(outputs, labels)
            vals.update(np.array(scores))

            msg += ', '.join(
                ['{}: {:.4f}'.format(k, v) for k, v in zip(keys, scores)])

        if verbose:
            logging.info(msg)

    if scoring:
        msg = 'Average scores: '
        msg += ', '.join(
            ['{}: {:.4f}'.format(k, v) for k, v in zip(keys, vals.avg)])
        logging.info(msg)

    model.train()
    return vals.avg
# Exemplo n.º 3
# 0
def validate(valid_loader,
             model,
             out_dir='',
             names=None,
             scoring=True,
             verbose=True):
    """Run whole-volume inference over ``valid_loader``, optionally export a
    BraTS-style NIfTI segmentation per subject, and optionally compute Dice.

    Parameters
    ----------
    valid_loader : DataLoader yielding ``[volume, target]`` (plus an
        optional mask) per subject, batch size 1.
    model : network taking the volume and returning per-voxel class logits.
    out_dir : if non-empty, save ``<name>.nii.gz`` label maps there.
    names : optional subject names used for logging and output filenames.
    scoring : when True, compute Dice scores against the ground truth.
    verbose : when True, log one line per subject.

    Returns
    -------
    Average Dice scores accumulated over all subjects (``AverageMeter.avg``).
    """
    # Canonical full-volume size; presumably BraTS 240x240x155 scans.
    H, W, T = 240, 240, 155
    dtype = torch.float32

    dset = valid_loader.dataset

    model.eval()
    criterion = F.cross_entropy

    vals = AverageMeter()
    for i, data in enumerate(valid_loader):

        # Keep a CPU copy of the ground truth before moving tensors to the
        # GPU.  The previous version commented this line out, which made the
        # `scoring` branch below fail with a NameError on `target_cpu`.
        target_cpu = data[1][0, :H, :W, :T].numpy() if scoring else None
        data = [t.cuda(non_blocking=True) for t in data]

        x, target = data[:2]

        if len(data) > 2:
            x = add_mask(x, data.pop(), 1)

        # compute output
        logit = model(x)
        output = F.softmax(logit, dim=1)  # n x classes x H x W x T

        output = output[0, :, :H, :W, :T].cpu().numpy()
        # Collapse class probabilities to a label map exactly once.  The
        # previous version argmax-ed inside the out_dir branch and then again
        # in the scoring branch, so with both enabled the Dice computation
        # received a doubly-argmaxed array.
        pred = output.argmax(0)

        msg = 'Subject {}/{}, '.format(i + 1, len(valid_loader))
        name = str(i)
        if names:
            name = names[i]
            msg += '{:>20}, '.format(name)

        if out_dir:
            oname = os.path.join(out_dir, name + '.nii.gz')
            # BraTS label convention: 1 = necrotic core, 2 = edema,
            # 4 = enhancing tumor (class index 3 unused) — TODO confirm.
            seg_img = np.zeros(shape=(H, W, T), dtype=np.uint8)
            seg_img[np.where(pred == 1)] = 1
            seg_img[np.where(pred == 2)] = 2
            seg_img[np.where(pred == 4)] = 4

            print('1:', np.sum(seg_img == 1), ' | 2:', np.sum(seg_img == 2),
                  ' | 4:', np.sum(seg_img == 4))
            print('WT:',
                  np.sum((seg_img == 1) | (seg_img == 2) | (seg_img == 4)),
                  ' | TC:', np.sum((seg_img == 1) | (seg_img == 4)), ' | ET:',
                  np.sum(seg_img == 4))
            nib.save(nib.Nifti1Image(seg_img, None), oname)

        if scoring:
            scores = dice(pred, target_cpu)
            vals.update(np.array(scores))

            msg += ', '.join(
                ['{}: {:.4f}'.format(k, v) for k, v in zip(keys, scores)])

        if verbose:
            logging.info(msg)

    if scoring:
        msg = 'Average scores: '
        msg += ', '.join(
            ['{}: {:.4f}'.format(k, v) for k, v in zip(keys, vals.avg)])
        logging.info(msg)

    model.train()
    return vals.avg
def validate(valid_loader, model,
        out_dir='', names=None, scoring=True, verbose=True):
    """Whole-volume evaluation: run the model on each subject, optionally
    dump the probability volume to ``<name>_preds.npy``, and optionally
    compute Dice scores against the ground truth.

    Returns the averaged Dice scores (``AverageMeter.avg``).
    """
    # Canonical full-volume size; presumably BraTS 240x240x155 scans.
    H, W, T = 240, 240, 155
    dtype = torch.float32

    dset = valid_loader.dataset

    model.eval()
    criterion = F.cross_entropy

    vals = AverageMeter()
    for idx, batch in enumerate(valid_loader):

        # CPU copy of the ground truth, cropped to the canonical shape,
        # taken before everything moves to the GPU.
        target_cpu = batch[1][0, :H, :W, :T].numpy() if scoring else None
        batch = [t.cuda(non_blocking=True) for t in batch]

        x, target = batch[:2]

        if len(batch) > 2:
            x = add_mask(x, batch.pop(), 1)

        # Forward pass; softmax over the class dimension.
        probs = F.softmax(model(x), dim=1)
        probs = probs[0, :, :H, :W, :T].cpu().numpy()

        subject = str(idx)
        msg = 'Subject {}/{}, '.format(idx + 1, len(valid_loader))
        if names:
            subject = names[idx]
            msg += '{:>20}, '.format(subject)

        if out_dir:
            np.save(os.path.join(out_dir, subject + '_preds'), probs)

        if scoring:
            scores = dice(probs.argmax(0), target_cpu)
            vals.update(np.array(scores))
            msg += ', '.join('{}: {:.4f}'.format(k, v)
                             for k, v in zip(keys, scores))

        if verbose:
            logging.info(msg)

    if scoring:
        summary = 'Average scores: '
        summary += ', '.join('{}: {:.4f}'.format(k, v)
                             for k, v in zip(keys, vals.avg))
        logging.info(summary)

    model.train()
    return vals.avg