Example #1
def patch_mnih(data_dir, save_dir, patch_size, pad, overlap):
    """
    Preprocess the standard mnih dataset
    :param data_dir: path to the original mnih dataset
    :param save_dir: directory to save the extracted patches
    :param patch_size: size of the patches, should be a tuple of (h, w)
    :param pad: #pixels to be padded around each tile, should be either one element or four elements
    :param overlap: #overlapping pixels between two patches in both vertical and horizontal direction
    :return:
    """
    
    for dataset in tqdm(SPLITS, desc='Train-valid split'):
        filenames = [
            fname.split('.')[0] for fname in os.listdir(os.path.join(data_dir, dataset, MODES[0]))
        ]
        # create folders and files
        patch_dir = os.path.join(save_dir, 'patches')
        misc_utils.make_dir_if_not_exist(patch_dir)
        record_file = open(os.path.join(save_dir, 'file_list_{}.txt'.format(dataset)), 'w+')

        # get rgb and gt files
        for fname in tqdm(filenames, desc='File-wise'):
            rgb_filename = os.path.join(data_dir, dataset, 'sat', fname + '.tiff')
            gt_filename = os.path.join(data_dir, dataset, 'map', fname + '.tif')
            for rgb_patch, gt_patch, y, x in patch_tile(rgb_filename, gt_filename, patch_size, pad, overlap):
                rgb_patchname = '{}_y{}x{}.jpg'.format(fname, int(y), int(x))
                gt_patchname = '{}_y{}x{}.png'.format(fname, int(y), int(x))
                misc_utils.save_file(os.path.join(patch_dir, rgb_patchname), rgb_patch.astype(np.uint8))
                misc_utils.save_file(os.path.join(patch_dir, gt_patchname), (gt_patch/255).astype(np.uint8))
                record_file.write('{} {}\n'.format(rgb_patchname, gt_patchname))
        record_file.close()
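
A minimal usage sketch for this function; the paths and patch settings below are hypothetical:

# Hypothetical invocation: extract 512x512 patches with no padding and
# 128 px of overlap between neighboring patches.
patch_mnih(data_dir='/path/to/mnih', save_dir='/path/to/patches_out',
           patch_size=(512, 512), pad=0, overlap=128)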
Example #2
def extract_(file_list, file_exts, patch_size, pad, overlap, save_path):
    assert len(file_exts) == len(file_list[0])
    pbar = tqdm(file_list)
    record_file = open(os.path.join(save_path, 'file_list.txt'), 'w')
    for files in pbar:
        pbar.set_description('Extracting {}'.format(
            os.path.basename(files[0])))
        patch_list = []
        for f, ext in zip(files, file_exts):
            patch_list_ext = []
            img = misc_utils.load_file(f)
            grid_list = make_grid(
                np.array(img.shape[:2]) + 2 * pad, patch_size, overlap)
            # extract images
            for patch, y, x in patch_block(img,
                                           pad,
                                           grid_list,
                                           patch_size,
                                           return_coord=True):
                patch_name = '{}_y{}x{}.{}'.format(
                    os.path.basename(f).split('.')[0], int(y), int(x), ext)
                patch_name = os.path.join(save_path, patch_name)
                misc_utils.save_file(patch_name, patch.astype(np.uint8))
                patch_list_ext.append(patch_name)
            patch_list.append(patch_list_ext)
        patch_list = misc_utils.rotate_list(patch_list)
        for items in patch_list:
            record_file.write('{}\n'.format(' '.join(items)))
    record_file.close()
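
A usage sketch, assuming each entry of file_list pairs an RGB tile with its ground-truth mask and file_exts gives the output extension per column (all names below are hypothetical):

# Hypothetical tile pairs; patches from the first column are saved as .jpg,
# patches from the second column as .png, and file_list.txt records each row.
file_list = [('tile1_rgb.tif', 'tile1_gt.tif'),
             ('tile2_rgb.tif', 'tile2_gt.tif')]
extract_(file_list, file_exts=('jpg', 'png'), patch_size=(512, 512),
         pad=0, overlap=0, save_path='/path/to/patches_out')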
Example #3
def main():
    # settings
    cfg = read_config()
    # set gpu to use
    device, parallel = misc_utils.set_gpu(cfg['gpu'])
    # set random seed
    misc_utils.set_random_seed(cfg['random_seed'])
    # make training directory
    misc_utils.make_dir_if_not_exist(cfg['save_dir'])
    misc_utils.save_file(os.path.join(cfg['save_dir'], 'config.json'), cfg)

    # train the model
    train_model(cfg, device, parallel)
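
read_config() is not shown; judging from the fields accessed here and in train_model (Example #10), the returned config is a dict. A minimal skeleton with placeholder values:

# Placeholder values only; the real config defines many more fields
# (optimizer, trainer, dataset sections) consumed by train_model.
cfg = {
    'gpu': '0',                   # passed to misc_utils.set_gpu
    'random_seed': 1,             # passed to misc_utils.set_random_seed
    'save_dir': '/path/to/run',   # config.json and checkpoints land here
}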
Example #4
    def infer(self, model, pred_dir, patch_size, overlap, ext='_mask', file_ext='png', visualize=False,
              densecrf=False, crf_params=None):
        if isinstance(model, (list, tuple)):
            lbl_margin = model[0].lbl_margin
        else:
            lbl_margin = model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        misc_utils.make_dir_if_not_exist(pred_dir)
        pbar = tqdm(self.rgb_files)
        for rgb_file in pbar:
            file_name = os.path.splitext(os.path.basename(rgb_file))[0].split('_')[0]
            pbar.set_description('Inferring {}'.format(file_name))
            # read data
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size, overlap)

            if isinstance(model, (list, tuple)):
                # ensemble: sum the per-tile predictions of each model
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                                                              lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size, tile_dim, tile_dim_pad, lbl_margin)

            if densecrf:
                # DenseCRF2D takes (width, height, n_classes); convert the
                # summed logits to probabilities before building the unary term
                h, w, c = tile_preds.shape
                d = dcrf.DenseCRF2D(w, h, c)
                U = unary_from_softmax(np.ascontiguousarray(
                    data_utils.change_channel_order(
                        scipy.special.softmax(tile_preds, axis=-1), False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(h, w)
            else:
                tile_preds = np.argmax(tile_preds, -1)

            if self.encode_func:
                pred_img = self.encode_func(tile_preds)
            else:
                pred_img = tile_preds

            if visualize:
                vis_utils.compare_figures([rgb, pred_img], (1, 2), fig_size=(12, 5))

            misc_utils.save_file(os.path.join(pred_dir, '{}{}.{}'.format(file_name, ext, file_ext)), pred_img)
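
infer() accepts either a single model or a list/tuple of models, in which case the per-tile predictions are summed before the argmax. A hedged usage sketch (the evaluator object and models are hypothetical):

# Hypothetical two-model ensemble; predictions are summed tile by tile.
evaluator.infer(model=[model_a, model_b], pred_dir='/path/to/preds',
                patch_size=(512, 512), overlap=0, visualize=False)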
Example #5
def patch_deepgloberoad(data_dir,
                        save_dir,
                        patch_size,
                        pad,
                        overlap,
                        valid_percent=0.14):
    dirs = ['road_trainv1/train', 'road_trainv2/train']

    # create folders and files
    patch_dir = os.path.join(save_dir, 'patches')
    misc_utils.make_dir_if_not_exist(patch_dir)
    record_file_train = open(os.path.join(save_dir, 'file_list_train.txt'),
                             'w+')
    record_file_valid = open(os.path.join(save_dir, 'file_list_valid.txt'),
                             'w+')

    # collect image/label pairs; the first valid_percent of tiles form the validation set
    files = []
    for dir_ in dirs:
        files.extend(
            data_utils.get_img_lbl(os.path.join(data_dir, dir_), 'sat.jpg',
                                   'mask.png'))
    valid_size = int(len(files) * valid_percent)

    for cnt, (img_file, lbl_file) in enumerate(tqdm(files)):
        city_name = os.path.splitext(
            os.path.basename(img_file))[0].split('_')[0]
        for rgb_patch, gt_patch, y, x in patch_tile(img_file, lbl_file,
                                                    patch_size, pad, overlap):
            img_patchname = '{}_y{}x{}.jpg'.format(city_name, int(y), int(x))
            lbl_patchname = '{}_y{}x{}.png'.format(city_name, int(y), int(x))
            misc_utils.save_file(os.path.join(patch_dir, img_patchname),
                                 rgb_patch.astype(np.uint8))
            misc_utils.save_file(os.path.join(patch_dir, lbl_patchname),
                                 (gt_patch / 255).astype(np.uint8))

            if cnt < valid_size:
                record_file_valid.write('{} {}\n'.format(
                    img_patchname, lbl_patchname))
            else:
                record_file_train.write('{} {}\n'.format(
                    img_patchname, lbl_patchname))
    record_file_train.close()
    record_file_valid.close()
Example #6
def patch_inria(data_dir, save_dir, patch_size, pad, overlap):
    """
    Preprocess the standard inria dataset
    :param data_dir: path to the original inria dataset
    :param save_dir: directory to save the extracted patches
    :param patch_size: size of the patches, should be a tuple of (h, w)
    :param pad: #pixels to be padded around each tile, should be either one element or four elements
    :param overlap: #overlapping pixels between two patches in both vertical and horizontal direction
    :return:
    """
    # create folders and files
    patch_dir = os.path.join(save_dir, 'patches')
    misc_utils.make_dir_if_not_exist(patch_dir)
    record_file_train = open(os.path.join(save_dir, 'file_list_train.txt'),
                             'w+')
    record_file_valid = open(os.path.join(save_dir, 'file_list_valid.txt'),
                             'w+')
    # get rgb and gt files
    for city_name in tqdm(SAVE_CITY, desc='City-wise'):
        for tile_id in tqdm(range(1, 37), desc='Tile-wise', leave=False):
            rgb_filename = os.path.join(data_dir, 'images',
                                        '{}{}.tif'.format(city_name, tile_id))
            gt_filename = os.path.join(data_dir, 'gt',
                                       '{}{}.tif'.format(city_name, tile_id))
            for rgb_patch, gt_patch, y, x in data_utils.patch_tile(
                    rgb_filename, gt_filename, patch_size, pad, overlap):
                rgb_patchname = '{}{}_y{}x{}.jpg'.format(
                    city_name, tile_id, int(y), int(x))
                gt_patchname = '{}{}_y{}x{}.png'.format(
                    city_name, tile_id, int(y), int(x))
                misc_utils.save_file(os.path.join(patch_dir, rgb_patchname),
                                     rgb_patch.astype(np.uint8))
                misc_utils.save_file(os.path.join(patch_dir, gt_patchname),
                                     (gt_patch / 255).astype(np.uint8))
                if city_name in VAL_CITY and tile_id in VAL_IDS:
                    record_file_valid.write('{} {}\n'.format(
                        rgb_patchname, gt_patchname))
                else:
                    record_file_train.write('{} {}\n'.format(
                        rgb_patchname, gt_patchname))
    record_file_train.close()
    record_file_valid.close()
Example #7
    def run(self, force_run=False, **kwargs):
        """
        Run the process
        :param force_run: if True, then the process will run no matter it has completed before
        :param kwargs:
        :return:
        """
        # check if state file exists
        state_exist = os.path.exists(self.state_file)
        # run the function if force run or haven't run before
        if force_run or state_exist == 0:
            print(('Start running {}'.format(self.name)))
            # write state log as incomplete
            with open(self.state_file, 'w') as f:
                f.write('Incomplete\n')

            # run the process
            self.val = self.func(**kwargs)

            # write state log as complete
            with open(self.state_file, 'w') as f:
                f.write('Finished\n')
            misc_utils.save_file(self.save_path, self.val)
        else:
            # if the process hasn't finished before, run it and save the result
            if not self.check_finish():
                self.val = self.func(**kwargs)
                misc_utils.save_file(self.save_path, self.val)

            # load the saved result
            self.val = misc_utils.load_file(self.save_path)

            # write state log as complete
            with open(self.state_file, 'w') as f:
                f.write('Finished\n')
        return self
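
run() touches self.name, self.func, self.state_file, and self.save_path, so it presumably lives on a cached-process wrapper. A hedged sketch of how such an object might be driven (the Process constructor shown here is an assumption, not the repository's actual API):

# Hypothetical wrapper exposing the attributes run() reads.
proc = Process(name='make_patches', func=patch_mnih,
               save_path='/path/to/result.pkl',
               state_file='/path/to/state.txt')
proc.run(force_run=False, data_dir='/path/to/mnih',
         save_dir='/path/to/out', patch_size=(512, 512), pad=0, overlap=128)
result = proc.val  # loaded from save_path on subsequent runs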
Example #8
def make_dataset(ds_train, ds_valid, save_dir, th=0.5):
    import solaris as sol

    # create folders and files
    patch_dir = os.path.join(save_dir, 'patches')
    misc_utils.make_dir_if_not_exist(patch_dir)
    record_file_train = open(os.path.join(save_dir, 'file_list_train.txt'),
                             'w+')
    record_file_valid = open(os.path.join(save_dir, 'file_list_valid.txt'),
                             'w+')

    # counters for removed patches
    remove_train_cnt = 0
    remove_valid_cnt = 0

    # make dataset
    ds_dict = {
        'train': {
            'ds': ds_train,
            'record': record_file_train,
            'remove_cnt': remove_train_cnt
        },
        'valid': {
            'ds': ds_valid,
            'record': record_file_valid,
            'remove_cnt': remove_valid_cnt
        }
    }

    # process the validation and training splits
    for phase in ['valid', 'train']:
        for rgb_file, gt_file in tqdm(ds_dict[phase]['ds']):
            img_save_name = os.path.join(
                patch_dir, '{}.jpg'.format(
                    os.path.splitext(os.path.basename(rgb_file))[0]))
            lbl_save_name = os.path.join(
                patch_dir, '{}.png'.format(
                    os.path.splitext(os.path.basename(rgb_file))[0]))
            convert_gtif_to_8bit(rgb_file, img_save_name)
            img = misc_utils.load_file(img_save_name)
            lbl = sol.vector.mask.footprint_mask(df=gt_file,
                                                 reference_im=rgb_file)


            blank_region = check_blank_region(img)
            if blank_region > th:
                ds_dict[phase]['remove_cnt'] += 1
                os.remove(img_save_name)
            else:
                if img.shape[0] != lbl.shape[0] or img.shape[1] != lbl.shape[1]:
                    # only safe to crop the label when the overhang is all background
                    assert np.array_equal(np.unique(lbl), np.array([0]))
                    lbl = lbl[:img.shape[0], :img.shape[1]]
                misc_utils.save_file(lbl_save_name,
                                     (lbl / 255).astype(np.uint8))
                ds_dict[phase]['record'].write('{} {}\n'.format(
                    os.path.basename(img_save_name),
                    os.path.basename(lbl_save_name)))
        ds_dict[phase]['record'].close()
        print('{} set: {:.2f}% data removed with threshold of {}'.format(
            phase,
            ds_dict[phase]['remove_cnt'] / len(ds_dict[phase]['ds']) * 100,
            th))
        print('\t kept patches: {}'.format(
            len(ds_dict[phase]['ds']) - ds_dict[phase]['remove_cnt']))

        files_remove = glob(os.path.join(patch_dir, '*.aux.xml'))
        for f in files_remove:
            os.remove(f)
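
check_blank_region() is not shown; a plausible sketch, assuming 'blank' means the fraction of all-zero pixels in the tile (an assumption, not the repository's definition):

import numpy as np

def check_blank_region(img):
    # Assumed stand-in: fraction of pixels whose channels are all zero.
    blank = np.all(img == 0, axis=-1)
    return float(np.mean(blank))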
Example #9
    def evaluate(self,
                 model,
                 patch_size,
                 overlap,
                 pred_dir=None,
                 report_dir=None,
                 save_conf=False,
                 delta=1e-6,
                 eval_class=(1, ),
                 visualize=False,
                 densecrf=False,
                 crf_params=None,
                 verbose=True):
        if isinstance(model, (list, tuple)):
            lbl_margin = model[0].lbl_margin
        else:
            lbl_margin = model.lbl_margin
        if crf_params is None and densecrf:
            crf_params = {'sxy': 3, 'srgb': 3, 'compat': 5}

        iou_a, iou_b = np.zeros(len(eval_class)), np.zeros(len(eval_class))
        report = []
        if pred_dir:
            misc_utils.make_dir_if_not_exist(pred_dir)
        for rgb_file, lbl_file in zip(self.rgb_files, self.lbl_files):
            file_name = os.path.splitext(os.path.basename(lbl_file))[0]

            # read data
            rgb = misc_utils.load_file(rgb_file)[:, :, :3]
            lbl = misc_utils.load_file(lbl_file)
            if self.decode_func:
                lbl = self.decode_func(lbl)

            # evaluate on tiles
            tile_dim = rgb.shape[:2]
            tile_dim_pad = [
                tile_dim[0] + 2 * lbl_margin, tile_dim[1] + 2 * lbl_margin
            ]
            grid_list = patch_extractor.make_grid(tile_dim_pad, patch_size,
                                                  overlap)

            if isinstance(model, (list, tuple)):
                # ensemble: sum the per-tile predictions of each model
                tile_preds = 0
                for m in model:
                    tile_preds = tile_preds + self.infer_tile(
                        m, rgb, grid_list, patch_size, tile_dim, tile_dim_pad,
                        lbl_margin)
            else:
                tile_preds = self.infer_tile(model, rgb, grid_list, patch_size,
                                             tile_dim, tile_dim_pad,
                                             lbl_margin)

            if save_conf:
                misc_utils.save_file(
                    os.path.join(pred_dir, '{}.npy'.format(file_name)),
                    scipy.special.softmax(tile_preds, axis=-1)[:, :, 1])

            if densecrf:
                # DenseCRF2D takes (width, height, n_classes)
                h, w, c = tile_preds.shape
                d = dcrf.DenseCRF2D(w, h, c)
                U = unary_from_softmax(
                    np.ascontiguousarray(
                        data_utils.change_channel_order(
                            scipy.special.softmax(tile_preds, axis=-1),
                            False)))
                d.setUnaryEnergy(U)
                d.addPairwiseBilateral(rgbim=rgb, **crf_params)
                Q = d.inference(5)
                tile_preds = np.argmax(Q, axis=0).reshape(h, w)
            else:
                tile_preds = np.argmax(tile_preds, -1)
            iou_score = metric_utils.iou_metric(lbl / self.truth_val,
                                                tile_preds,
                                                eval_class=eval_class)
            pstr, rstr = self.get_result_strings(file_name, iou_score, delta)
            tm.misc_utils.verb_print(pstr, verbose)
            report.append(rstr)
            iou_a += iou_score[0, :]
            iou_b += iou_score[1, :]
            if visualize:
                if self.encode_func:
                    vis_utils.compare_figures([
                        rgb,
                        self.encode_func(lbl),
                        self.encode_func(tile_preds)
                    ], (1, 3),
                                              fig_size=(15, 5))
                else:
                    vis_utils.compare_figures([rgb, lbl, tile_preds], (1, 3),
                                              fig_size=(15, 5))
            if pred_dir:
                if self.encode_func:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        self.encode_func(tile_preds))
                else:
                    misc_utils.save_file(
                        os.path.join(pred_dir, '{}.png'.format(file_name)),
                        tile_preds)
        pstr, rstr = self.get_result_strings('Overall',
                                             np.stack([iou_a, iou_b], axis=0),
                                             delta)
        tm.misc_utils.verb_print(pstr, verbose)
        report.append(rstr)
        if report_dir:
            misc_utils.make_dir_if_not_exist(report_dir)
            misc_utils.save_file(os.path.join(report_dir, 'result.txt'),
                                 report)
        return np.mean(iou_a / (iou_b + delta)) * 100
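
evaluate() returns the mean IoU in percent over eval_class; a hedged usage sketch (the evaluator object and paths are hypothetical):

# Hypothetical call: score the foreground class, dump per-tile PNGs and
# the aggregated result.txt report.
miou = evaluator.evaluate(model, patch_size=(512, 512), overlap=0,
                          pred_dir='/path/to/preds',
                          report_dir='/path/to/report', eval_class=(1,))
print('mIoU: {:.2f}'.format(miou))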
Example #10
def train_model(args, device, parallel):
    """
    The function to train the model
    :param args: the class carries configuration parameters defined in config.py
    :param device: the device to run the model
    :return:
    """

    model = network_io.create_model(args)
    log_dir = os.path.join(args['save_dir'], 'log')
    writer = SummaryWriter(log_dir=log_dir)
    # TODO add write_graph back, probably need to switch to tensorboard in pytorch
    if parallel:
        model.encoder = network_utils.DataParallelPassThrough(model.encoder)
        model.decoder = network_utils.DataParallelPassThrough(model.decoder)
        if args['optimizer']['aux_loss']:
            model.cls = network_utils.DataParallelPassThrough(model.cls)
        print('Parallel training mode enabled!')
    train_params = model.set_train_params(
        (args['optimizer']['learn_rate_encoder'],
         args['optimizer']['learn_rate_decoder']))

    # make optimizer
    optm = network_io.create_optimizer(args['optimizer']['name'], train_params,
                                       args['optimizer']['learn_rate_encoder'])
    criterions = network_io.create_loss(args, device=device)
    cls_criterion = None
    with_aux = False
    if args['optimizer']['aux_loss']:
        with_aux = True
        cls_criterion = metric_utils.BCEWithLogitLoss(
            device, eval(args['trainer']['class_weight']))
    scheduler = optim.lr_scheduler.MultiStepLR(
        optm,
        milestones=eval(args['optimizer']['decay_step']),
        gamma=args['optimizer']['decay_rate'])

    # if not resume, train from scratch
    if args['trainer']['resume_epoch'] == 0 and args['trainer'][
            'finetune_dir'] == 'None':
        print('Training decoder {} with encoder {} from scratch ...'.format(
            args['decoder_name'], args['encoder_name']))
    elif args['trainer']['resume_epoch'] == 0 and args['trainer'][
            'finetune_dir']:
        print('Finetuning model from {}'.format(
            args['trainer']['finetune_dir']))
        if args['trainer']['further_train']:
            network_utils.load(model,
                               args['trainer']['finetune_dir'],
                               relax_load=True,
                               optm=optm,
                               device=device)
        else:
            network_utils.load(model,
                               args['trainer']['finetune_dir'],
                               relax_load=True)
    else:
        print('Resume training decoder {} with encoder {} from epoch {} ...'.
              format(args['decoder_name'], args['encoder_name'],
                     args['trainer']['resume_epoch']))
        network_utils.load_epoch(args['save_dir'],
                                 args['trainer']['resume_epoch'], model, optm,
                                 device)

    # prepare training
    print('Total params: {:.2f}M'.format(network_utils.get_model_size(model)))
    model.to(device)
    for c in criterions:
        c.to(device)

    # make data loader
    ds_cfgs = [a for a in sorted(args.keys()) if 'dataset' in a]
    assert ds_cfgs[0] == 'dataset'
    # read the default mean and std first
    mean, std = args[ds_cfgs[0]]['mean'], args[ds_cfgs[0]]['std']

    train_val_loaders = {'train': [], 'valid': []}
    for ds_cfg in ds_cfgs:
        if args[ds_cfg]['load_func'] == 'default':
            load_func = data_utils.default_get_stats
        else:
            load_func = None

        mean, std = network_io.get_dataset_stats(
            args[ds_cfg]['ds_name'],
            args[ds_cfg]['data_dir'],
            mean_val=(eval(args[ds_cfg]['mean']), eval(args[ds_cfg]['std'])),
            load_func=load_func,
            file_list=args[ds_cfg]['train_file'])
        # update args mean and std with the actual values being used
        args[ds_cfg]['mean'], args[ds_cfg]['std'] = str(tuple(mean)), str(tuple(std))

        tsfm_train, tsfm_valid = network_io.create_tsfm(args, mean, std)
        train_loader = DataLoader(
            data_loader.get_loader(args[ds_cfg]['data_dir'],
                                   args[ds_cfg]['train_file'],
                                   transforms=tsfm_train,
                                   n_class=args[ds_cfg]['class_num'],
                                   with_aux=with_aux),
            batch_size=int(args[ds_cfg]['batch_size']),
            shuffle=True,
            num_workers=int(args['dataset']['num_workers']),
            drop_last=True)
        train_val_loaders['train'].append(train_loader)

        if 'valid_file' in args[ds_cfg]:
            valid_loader = DataLoader(
                data_loader.get_loader(args[ds_cfg]['data_dir'],
                                       args[ds_cfg]['valid_file'],
                                       transforms=tsfm_valid,
                                       n_class=args[ds_cfg]['class_num'],
                                       with_aux=with_aux),
                batch_size=int(args[ds_cfg]['batch_size']),
                shuffle=False,
                num_workers=int(args[ds_cfg]['num_workers']))
            print('Training model on the {} dataset'.format(
                args[ds_cfg]['ds_name']))
            train_val_loaders['valid'].append(valid_loader)
    misc_utils.save_file(os.path.join(args['save_dir'], 'config.json'),
                         args)  # save config with actual mean and std used

    # train the model
    loss_dict = {}
    for epoch in range(int(args['trainer']['resume_epoch']),
                       int(args['trainer']['epochs'])):
        # each epoch has a training and validation step
        for phase in ['train', 'valid']:
            start_time = timeit.default_timer()
            if phase == 'train':
                model.train()
            else:
                model.eval()

            # TODO align aux loss and normal train
            loss_dict = model.step(
                train_val_loaders[phase],
                device,
                optm,
                phase,
                criterions,
                eval(args['trainer']['bp_loss_idx']),
                True,
                mean,
                std,
                loss_weights=eval(args['trainer']['loss_weights']),
                use_emau=args['use_emau'],
                use_ocr=args['use_ocr'],
                cls_criterion=cls_criterion,
                cls_weight=args['optimizer']['aux_loss_weight'])
            network_utils.write_and_print(writer, phase, epoch,
                                          int(args['trainer']['epochs']),
                                          loss_dict, start_time)

        scheduler.step()
        # save the model
        if epoch % int(args['trainer']['save_epoch']) == 0 and epoch != 0:
            save_name = os.path.join(args['save_dir'],
                                     'epoch-{}.pth.tar'.format(epoch))
            network_utils.save(model, epoch, optm, loss_dict, save_name)
    # save model one last time
    save_name = os.path.join(
        args['save_dir'],
        'epoch-{}.pth.tar'.format(int(args['trainer']['epochs'])))
    network_utils.save(model, int(args['trainer']['epochs']), optm, loss_dict,
                       save_name)
    writer.close()