Example #1
def main(opt):
    global best_score, logger, logger_results
    best_score = 0
    opt.save_options()

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    # set up logger
    logger, logger_results = setup_logging(opt)
    opt.print_options(logger)

    if opt.train['random_seed'] >= 0:
        # logger.info("=> Using random seed {:d}".format(opt.train['random_seed']))
        torch.manual_seed(opt.train['random_seed'])
        torch.cuda.manual_seed(opt.train['random_seed'])
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(opt.train['random_seed'])
        random.seed(opt.train['random_seed'])
    else:
        torch.backends.cudnn.benchmark = True

    # ----- create model ----- #
    model = ResUNet34(pretrained=opt.model['pretrained'],
                      with_uncertainty=opt.with_uncertainty)
    # model = nn.DataParallel(model)
    model = model.cuda()

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.NLLLoss(ignore_index=2).cuda()

    # ----- load data ----- #
    data_transforms = {
        'train': get_transforms(opt.transform['train']),
        'val': get_transforms(opt.transform['val'])
    }

    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_vor_dir = '{:s}/train'.format(opt.train['label_vor_dir'])
    target_cluster_dir = '{:s}/train'.format(opt.train['label_cluster_dir'])
    dir_list = [img_dir, target_vor_dir, target_cluster_dir]
    post_fix = ['label_vor.png', 'label_cluster.png']
    num_channels = [3, 3, 3]
    train_set = DataFolder(dir_list, post_fix, num_channels,
                           data_transforms['train'])
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if opt.train['checkpoint']:
        if os.path.isfile(opt.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                opt.train['checkpoint']))
            checkpoint = torch.load(opt.train['checkpoint'])
            opt.train['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                opt.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                opt.train['checkpoint']))

    # ----- training and validation ----- #
    num_epochs = opt.train['num_epochs']

    for epoch in range(opt.train['start_epoch'], num_epochs):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epochs))
        train_loss, train_loss_vor, train_loss_cluster = train(
            opt, train_loader, model, optimizer, criterion)

        # evaluate on val set
        with torch.no_grad():
            val_acc, val_aji = validate(opt, model, data_transforms['val'])

        # check if it is the best accuracy
        is_best = val_aji > best_score
        best_score = max(val_aji, best_score)

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, opt.train['save_dir'], is_best, cp_flag)

        # save the training results to txt files
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
                epoch + 1, train_loss, train_loss_vor, train_loss_cluster,
                val_acc, val_aji))

    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
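
These training scripts rely on a setup_logging(opt) helper that is not shown in any of the examples. Below is a minimal sketch of what such a helper could look like, assuming opt.train['save_dir'] holds the output directory; the file names and formats are illustrative, not the project's actual ones.

import logging


def setup_logging(opt):
    # Hypothetical reconstruction: the real helper is defined elsewhere in the project.
    # Append to existing logs when resuming from a checkpoint, otherwise start fresh.
    mode = 'a' if opt.train.get('checkpoint') else 'w'

    # main logger: console + training log file
    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    fmt = logging.Formatter('%(asctime)s\t%(message)s', '%m-%d %H:%M')
    console = logging.StreamHandler()
    console.setFormatter(fmt)
    logger.addHandler(console)
    file_handler = logging.FileHandler(
        '{:s}/train_log.txt'.format(opt.train['save_dir']), mode=mode)
    file_handler.setFormatter(fmt)
    logger.addHandler(file_handler)

    # results logger: tab-separated per-epoch metrics only
    logger_results = logging.getLogger('results')
    logger_results.setLevel(logging.DEBUG)
    results_handler = logging.FileHandler(
        '{:s}/epoch_results.txt'.format(opt.train['save_dir']), mode=mode)
    results_handler.setFormatter(logging.Formatter('%(message)s'))
    logger_results.addHandler(results_handler)

    return logger, logger_results
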
Example #2
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from network import UNet
from dataset import DataFolder
import torch.utils.data as data
from util import EarlyStopping, save_nets, save_predictions, load_best_weights
from train_options import parser


args = parser.parse_args()
print(args)

all_loader = data.DataLoader(
    dataset=DataFolder('dataset/all_images_256/', 'dataset/all_masks_256/', 'all'),
    batch_size=args.eval_batch_size,
    shuffle=False,
    num_workers=2
)

train_loader = data.DataLoader(
    dataset=DataFolder('dataset/train_images_256/', 'dataset/train_masks_256/', 'train'),
    batch_size=args.train_batch_size,
    shuffle=True,
    num_workers=2
)

valid_loader = data.DataLoader(
    dataset=DataFolder('dataset/valid_images_256/', 'dataset/valid_masks_256/', 'validation'),
    batch_size=args.eval_batch_size,
Example #3
def main():
    global opt, num_iter, tb_writer, logger, logger_results
    opt = Options(isTrain=True)
    opt.parse()
    opt.save_options()

    tb_writer = SummaryWriter('{:s}/tb_logs'.format(opt.train['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    # set up logger
    logger, logger_results = setup_logging(opt)

    # ----- create model ----- #
    model = ResUNet34(pretrained=opt.model['pretrained'])
    # if not opt.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.NLLLoss(ignore_index=2).cuda()
    if opt.train['crf_weight'] > 0:
        logger.info('=> Using CRF loss...')
        global criterion_crf
        criterion_crf = CRFLoss(opt.train['sigmas'][0], opt.train['sigmas'][1])

    # ----- load data ----- #
    data_transforms = {
        'train': get_transforms(opt.transform['train']),
        'test': get_transforms(opt.transform['test'])
    }

    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_vor_dir = '{:s}/train'.format(opt.train['label_vor_dir'])
    target_cluster_dir = '{:s}/train'.format(opt.train['label_cluster_dir'])
    dir_list = [img_dir, target_vor_dir, target_cluster_dir]
    post_fix = ['label_vor.png', 'label_cluster.png']
    num_channels = [3, 3, 3]
    train_set = DataFolder(dir_list, post_fix, num_channels,
                           data_transforms['train'])
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if opt.train['checkpoint']:
        if os.path.isfile(opt.train['checkpoint']):
            logger.info("=> loading checkpoint '{}'".format(
                opt.train['checkpoint']))
            checkpoint = torch.load(opt.train['checkpoint'])
            opt.train['start_epoch'] = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                opt.train['checkpoint'], checkpoint['epoch']))
        else:
            logger.info("=> no checkpoint found at '{}'".format(
                opt.train['checkpoint']))

    # ----- training and validation ----- #
    num_epoch = opt.train['train_epochs'] + opt.train['finetune_epochs']
    num_iter = num_epoch * len(train_loader)
    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(opt.train['lr']))
    logger.info("=> Batch size: {:d}".format(opt.train['batch_size']))
    logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs']))
    logger.info("=> Fine-tune epochs using dense CRF loss: {:d}".format(
        opt.train['finetune_epochs']))
    logger.info("=> CRF loss weight: {:.2g}".format(opt.train['crf_weight']))

    for epoch in range(opt.train['start_epoch'], num_epoch):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epoch))
        finetune_flag = epoch >= opt.train['train_epochs']
        if epoch == opt.train['train_epochs']:
            logger.info("Fine-tune begins, lr = {:.2g}".format(
                opt.train['lr'] * 0.1))
            for param_group in optimizer.param_groups:
                param_group['lr'] = opt.train['lr'] * 0.1

        train_results = train(train_loader, model, optimizer, criterion,
                              finetune_flag)
        train_loss, train_loss_vor, train_loss_cluster, train_loss_crf = train_results

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, opt.train['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info('{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'.format(
            epoch + 1, train_loss, train_loss_vor, train_loss_cluster,
            train_loss_crf))
        # tensorboard logs
        tb_writer.add_scalars(
            'epoch_losses', {
                'train_loss': train_loss,
                'train_loss_vor': train_loss_vor,
                'train_loss_cluster': train_loss_cluster,
                'train_loss_crf': train_loss_crf
            }, epoch)
    tb_writer.close()
    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
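
The fine-tuning stage in this example assumes that train() adds the dense CRF term to the base loss once finetune_flag is set. A hedged sketch of that combination is shown below; the function name, the criterion_crf call signature, and the use of exp() to turn log-probabilities into probabilities are all assumptions standing in for whatever the real training loop does.

def compute_loss(log_prob, target_vor, target_cluster, image, finetune_flag):
    # Sketch only: criterion, criterion_crf and opt are the script-level globals above.
    loss_vor = criterion(log_prob, target_vor)          # Voronoi-label NLL term
    loss_cluster = criterion(log_prob, target_cluster)  # cluster-label NLL term
    loss = loss_vor + loss_cluster
    loss_crf = torch.tensor(0.0, device=log_prob.device)
    if finetune_flag and opt.train['crf_weight'] > 0:
        # dense CRF regulariser on the predicted probabilities (assumed interface)
        loss_crf = criterion_crf(torch.exp(log_prob), image)
        loss = loss + opt.train['crf_weight'] * loss_crf
    return loss, loss_vor, loss_cluster, loss_crf
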
Example #4
def main(opt):
    global best_score, num_iter, tb_writer, logger, logger_results
    best_score = 0
    opt.isTrain = True

    if not os.path.exists(opt.train['save_dir']):
        os.makedirs(opt.train['save_dir'])
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(opt.train['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        str(x) for x in opt.train['gpus'])

    opt.define_transforms()
    opt.save_options()

    # set up logger
    logger, logger_results = setup_logging(opt)

    # ----- create model ----- #
    model_name = opt.model['name']
    model = create_model(model_name, opt.model['out_c'],
                         opt.model['pretrained'])
    # if not opt.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(model)
    model = model.cuda()

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 opt.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=opt.train['weight_decay'])

    # ----- define criterion ----- #
    criterion = torch.nn.MSELoss(reduction='none').cuda()

    # ----- load data ----- #
    img_dir = '{:s}/train'.format(opt.train['img_dir'])
    target_dir = '{:s}/train'.format(opt.train['label_dir'])
    if opt.round == 0:
        dir_list = [img_dir, target_dir]
        post_fix = ['label_detect.png']
        num_channels = [3, 1]
        train_transform = get_transforms(opt.transform['train_stage1'])
    else:
        bg_dir = '{:s}/train'.format(opt.train['bg_dir'])
        dir_list = [img_dir, target_dir, bg_dir]
        post_fix = ['label_detect.png', 'label_bg.png']
        num_channels = [3, 1, 1]
        train_transform = get_transforms(opt.transform['train_stage2'])
    train_set = DataFolder(dir_list, post_fix, num_channels, train_transform)
    train_loader = DataLoader(train_set,
                              batch_size=opt.train['batch_size'],
                              shuffle=True,
                              num_workers=opt.train['workers'])
    val_transform = get_transforms(opt.transform['val'])

    # ----- training and validation ----- #
    num_epoch = opt.train['train_epochs']
    num_iter = num_epoch * len(train_loader)
    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(opt.train['lr']))
    logger.info("=> Batch size: {:d}".format(opt.train['batch_size']))
    logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(opt.train['train_epochs']))

    for epoch in range(num_epoch):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, num_epoch))
        train_loss = train(opt, train_loader, model, optimizer, criterion)

        # evaluate on val set
        with torch.no_grad():
            val_recall, val_prec, val_F1 = validate(opt, model, val_transform)

        # check if it is the best accuracy
        is_best = val_F1 > best_score
        best_score = max(val_F1, best_score)

        cp_flag = (epoch + 1) % opt.train['checkpoint_freq'] == 0
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, epoch, is_best, opt.train['save_dir'], cp_flag)

        # save the training results to txt files
        logger_results.info('{:d}\t{:.4f} || {:.4f}\t{:.4f}\t{:.4f}'.format(
            epoch + 1, train_loss, val_recall, val_prec, val_F1))
        # tensorboard logs
        tb_writer.add_scalars('epoch_loss', {'train_loss': train_loss}, epoch)
        tb_writer.add_scalars('epoch_acc', {
            'val_recall': val_recall,
            'val_prec': val_prec,
            'val_F1': val_F1
        }, epoch)

    tb_writer.close()
    for i in list(logger.handlers):
        logger.removeHandler(i)
        i.flush()
        i.close()
    for i in list(logger_results.handlers):
        logger_results.removeHandler(i)
        i.flush()
        i.close()
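
Example #4 selects the best checkpoint by val_F1, the harmonic mean of precision and recall. A minimal sketch of how it could be derived from detection counts follows; the counts and the epsilon guard are illustrative, not the project's validate() implementation.

def f1_from_counts(tp, fp, fn, eps=1e-8):
    # precision: fraction of predicted detections that are correct
    precision = tp / (tp + fp + eps)
    # recall: fraction of ground-truth objects that were detected
    recall = tp / (tp + fn + eps)
    # F1: harmonic mean of precision and recall
    return 2 * precision * recall / (precision + recall + eps)
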
Example #5
                        type=float,
                        default=10,
                        help='Early stopping patience.',
                        dest='patience')
    parser.add_argument('-d',
                        '--min_delta',
                        type=float,
                        default=0.001,
                        help='Minimum loss improvement for each epoch.',
                        dest='min_delta')

    args = parser.parse_args()
    print(args)

    train_loader = data.DataLoader(dataset=DataFolder(
        'new_dataset/train/train_images_256/',
        'new_dataset/train/train_masks_256/', 'train'),
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=4)

    valid_loader = data.DataLoader(dataset=DataFolder(
        'new_dataset/val/train_images_256/',
        'new_dataset/val/train_masks_256/', 'validation'),
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=4)

    test_loader = data.DataLoader(dataset=DataFolder(
        'new_dataset/test/train_images_256/',
        'new_dataset/test/train_masks_256/', 'test'),