# --- Exemplo n.º 1 (score: 0) ---
def read_flag():
    """
    Parse command-line flags for training and post-process list-valued ones.

    Defaults come from module-level constants (BATCH_SIZE, DS_NAME, ...).
    :return: argparse.Namespace with `par_dir` prefixed by 'aemo/' and
             `learn_rate` / `start_layer` converted to lists via ersa_utils
    """
    def _str2tuple(s):
        # argparse `type=tuple` would split a CLI string such as '5000,5000'
        # into a tuple of single characters; parse 'H,W' into ints instead.
        return tuple(int(v) for v in s.split(','))

    def _str2bool(s):
        # argparse `type=bool` maps any non-empty string (even 'False') to
        # True; interpret the common truthy spellings explicitly.
        return s.lower() in ('true', '1', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch-size', default=BATCH_SIZE, type=int, help='batch size (10)')
    parser.add_argument('--ds-name', default=DS_NAME, type=str, help='dataset name')
    parser.add_argument('--tile-size', default=TILE_SIZE, type=_str2tuple, help='tile size 5000')
    parser.add_argument('--patch-size', default=PATCH_SIZE, type=_str2tuple, help='patch size 572')
    parser.add_argument('--epochs', default=EPOCHS, type=int, help='# epochs (1)')
    parser.add_argument('--num-classes', type=int, default=NUM_CLASS, help='# classes (including background)')
    parser.add_argument('--par-dir', type=str, default=PAR_DIR, help='parent directory name to save the model')
    parser.add_argument('--n-train', type=int, default=N_TRAIN, help='# samples per epoch')
    parser.add_argument('--n-valid', type=int, default=N_VALID, help='# patches to valid')
    parser.add_argument('--val-mult', type=int, default=VAL_MULT, help='validation_bs=val_mult*train_bs')
    parser.add_argument('--GPU', type=str, default=GPU, help="GPU used for computation.")
    parser.add_argument('--decay-step', type=float, default=DECAY_STEP, help='Learning rate decay step in number of epochs.')
    parser.add_argument('--decay-rate', type=float, default=DECAY_RATE, help='Learning rate decay rate')
    parser.add_argument('--verb-step', type=int, default=VERB_STEP, help='#steps between two verbose prints')
    parser.add_argument('--save-epoch', type=int, default=SAVE_EPOCH, help='#epochs between two model save')
    parser.add_argument('--data-dir', type=str, default=DATA_DIR, help='root directory of cityscapes')
    parser.add_argument('--from-scratch', type=_str2bool, default=FROM_SCRATCH, help='from scratch or not')
    parser.add_argument('--start-layer', type=str, default=START_LAYER, help='start layer to unfreeze')
    parser.add_argument('--model-dir', type=str, default=MODEL_DIR, help='pretrained model directory')
    parser.add_argument('--learn-rate', type=str, default=LEARN_RATE, help='learning rate')

    flags = parser.parse_args()
    flags.par_dir = 'aemo/' + flags.par_dir
    # comma-separated CLI strings -> typed lists
    flags.learn_rate = ersa_utils.str2list(flags.learn_rate, d_type=float)
    flags.start_layer = ersa_utils.str2list(flags.start_layer, d_type=int)
    return flags
# --- Exemplo n.º 2 (score: 0) ---
    def load_files(self, field_name=None, field_id=None, field_ext=None):
        """
        Load every file matching the given filters; any filter left as None
        falls back to the values stored on the collection itself
        :param field_name: name of the field
        :param field_id: name of the id
        :param field_ext: name of the field extension
        :return: the single file when exactly one matched, otherwise the
                 rotated list of lists of matches
        """
        field_name = self.field_name if field_name is None else field_name
        field_id = self.field_id if field_id is None else field_id
        # get_file_selection expects the id as a string
        if type(field_id) is not str:
            field_id = str(field_id)
        ext_list = ersa_utils.str2list(field_ext, d_type=str)

        files = []
        for ext in ext_list:
            # RGB-style extensions have a matching file extension by index;
            # anything else (e.g. ground truth) uses the last file extension
            if ext in self.rgb_ext:
                chosen_ext = self.file_ext[self.rgb_ext.index(ext)]
            else:
                chosen_ext = self.file_ext[-1]
            selection = self.get_file_selection(field_name, field_id, ext,
                                                chosen_ext)
            files.append(ersa_utils.rotate_list(selection)[0])

        files = ersa_utils.rotate_list(files)
        if len(files) == 1 and len(files[0]) == 1:
            # only one file been requested
            return files[0][0]
        return files
# --- Exemplo n.º 3 (score: 0) ---
 def get_file_selection(self, field_name, field_id, field_ext, file_ext):
     """
     Select files by field names, field ids, field extensions and file
     extensions; all four arguments may be ','-separated strings
     :param field_name: name of the fields (e.g., city names)
     :param field_id: id of the fields (e.g., tile numbers)
     :param field_ext: extension of the fields (e.g., RGB)
     :param file_ext: file extension (e.g., tif)
     :return: list of lists, where each row is file names of same place with different files
     """
     # normalize every filter into a list of strings
     field_name = ersa_utils.str2list(field_name, d_type=str)
     field_id = ersa_utils.str2list(field_id, d_type=str)
     field_ext = ersa_utils.str2list(field_ext, d_type=str)
     file_ext = ersa_utils.str2list(file_ext, d_type=str)
     # a single file extension applies to every field extension
     if len(file_ext) == 1:
         file_ext = file_ext * len(field_ext)
     file_selection = [
         self.get_files(self.make_regexp(field_name, field_id, fd_ext, fl_ext),
                        full_path=True)
         for fd_ext, fl_ext in zip(field_ext, file_ext)
     ]
     return ersa_utils.rotate_list(file_selection)
# --- Exemplo n.º 4 (score: 0) ---
    def __init__(self,
                 raw_data_path,
                 field_name,
                 field_id,
                 rgb_ext,
                 gt_ext,
                 file_ext,
                 files=None,
                 clc_name=None,
                 force_run=False):
        """
        Create a collection
        :param raw_data_path: path to where the data are stored
        :param field_name: could be name of the cities, or other prefix of the images
        :param field_id: could be id of the tiles, or other suffix of the images
        :param rgb_ext: name extensions that indicates the images are not ground truth, use ',' to separate if you have
                        multiple extensions
        :param gt_ext: name extensions that indicates the images are ground truth, you can only have at most one ground
                       truth extension
        :param file_ext: extension of the files, use ',' to separate if you have multiple extensions, if all the files
                         have the same extension, you only need to specify one
        :param files: files in the raw_data_path, can be specified by user to exclude some of the raw files, if it is
                      None, all files will be found automatically
        :param clc_name: name of the collection, if set to None, it will be the name of the raw_data_path folder
        :param force_run: force run the collection maker even if it already exists
        """
        self.raw_data_path = raw_data_path
        # ','-separated strings are normalized into lists of str up front
        self.field_name = ersa_utils.str2list(
            field_name,
            d_type=str)  # name of the 'cities' to include in the collection
        self.field_id = ersa_utils.str2list(
            field_id,
            d_type=str)  # id of the 'cities' to include in the collection
        self.rgb_ext = ersa_utils.str2list(rgb_ext, d_type=str)
        self.gt_ext = gt_ext
        # gt_ext is kept raw (not str2list'd); an empty one means the
        # collection has no ground-truth channel
        if len(gt_ext) == 0:
            has_gt_ext = 0
        else:
            has_gt_ext = 1
        self.file_ext = ersa_utils.str2list(file_ext, d_type=str)
        # either one file extension for everything, or exactly one per
        # rgb extension plus one for the ground truth (if present)
        assert len(self.file_ext) == 1 or len(
            self.file_ext) == len(self.rgb_ext) + has_gt_ext
        if len(self.file_ext) == 1:
            # broadcast the single extension across all channels
            self.file_ext = [
                self.file_ext[0]
                for _ in range(len(self.rgb_ext) + has_gt_ext)
            ]
        if clc_name is None:
            clc_name = os.path.basename(raw_data_path)
        self.clc_name = clc_name  # name of the collection
        self.clc_dir = self.get_dir()  # directory to store the collection
        self.force_run = force_run

        # make collection
        if files is None:
            # sorted() keeps the file order deterministic across platforms
            self.files = sorted(glob(os.path.join(self.raw_data_path, '*.*')))
        else:
            self.files = files
        # BasicProcess caches the result in clc_dir; make_collection only
        # re-runs when force_run is True or no cached collection exists
        self.clc_pb = processBlock.BasicProcess('collection_maker',
                                                self.clc_dir,
                                                self.make_collection)
        self.clc_pb.run(self.force_run)
        self.meta_data = self.read_meta_data()
# --- Exemplo n.º 5 (score: 0) ---
def main(flags):
    """
    Train a cropped U-Net once per run id listed in flags.run_id, optionally
    freezing all layers below flags.start_layer (start_layer >= 10 means
    train everything).
    """
    run_id_list = ersa_utils.str2list(flags.run_id, d_type=int)
    for flags.run_id in run_id_list:
        # fresh graph + seeds per run so runs are independent and repeatable
        tf.reset_default_graph()
        np.random.seed(int(flags.run_id))
        tf.set_random_seed(int(flags.run_id))

        # encode run id (and unfrozen-layer suffix when fine-tuning) into
        # the model name; the slicing assumes the previous suffix lengths
        if flags.start_layer >= 10:
            flags.model_name = flags.model_name[:-1] + str(flags.run_id)
        else:
            if int(flags.run_id) == 0:
                flags.model_name = flags.model_name[:-1] + str(flags.run_id)
            else:
                flags.model_name = flags.model_name[:-5] + str(flags.run_id)
            flags.model_name += '_up{}'.format(flags.start_layer)

        # make network
        # define place holder
        X = tf.placeholder(tf.float32, shape=[None, flags.input_size[0], flags.input_size[1], 3], name='X')
        y = tf.placeholder(tf.int32, shape=[None, flags.input_size[0], flags.input_size[1], 1], name='y')
        mode = tf.placeholder(tf.bool, name='mode')
        # NOTE(review): uses flags.learning_rate here while read_flag defines
        # 'learn-rate' -> flags.learn_rate; confirm which namespace feeds main
        model = uabMakeNetwork_UNet.UnetModelCrop({'X': X, 'Y': y},
                                                  trainable=mode,
                                                  model_name=flags.model_name,
                                                  input_size=flags.input_size,
                                                  batch_size=flags.batch_size,
                                                  learn_rate=flags.learning_rate,
                                                  decay_step=flags.decay_step,
                                                  decay_rate=flags.decay_rate,
                                                  epochs=flags.epochs,
                                                  start_filter_num=flags.sfn)
        model.create_graph('X', class_num=flags.num_classes)

        # create collection
        # the original file is in /ei-edl01/data/uab_datasets/inria
        blCol = uab_collectionFunctions.uabCollection(flags.ds_name)
        blCol.readMetadata()
        img_mean = blCol.getChannelMeans([1, 2, 3])  # get mean of rgb info

        # extract patches
        extrObj = uab_DataHandlerFunctions.uabPatchExtr([0, 1, 2, 3],
                                                        cSize=flags.input_size,
                                                        numPixOverlap=int(model.get_overlap()),
                                                        extSave=['png', 'jpg', 'jpg', 'jpg'],
                                                        isTrain=True,
                                                        gtInd=0,
                                                        pad=int(model.get_overlap()//2))
        patchDir = extrObj.run(blCol)

        # make data reader
        # use first 5 tiles for validation
        idx, file_list = uabCrossValMaker.uabUtilGetFolds(patchDir, 'fileList.txt', 'tile')
        file_list_train = uabCrossValMaker.make_file_list_by_key(idx, file_list, [0, 1, 2, 3])
        file_list_valid = uabCrossValMaker.make_file_list_by_key(idx, file_list, [4, 5])

        with tf.name_scope('image_loader'):
            # GT has no mean to subtract, append a 0 for block mean
            dataReader_train = uabDataReader.ImageLabelReader([0], [1, 2, 3], patchDir, file_list_train, flags.input_size,
                                                              flags.tile_size,
                                                              flags.batch_size, dataAug='flip,rotate',
                                                              block_mean=np.append([0], img_mean))
            # no augmentation needed for validation
            dataReader_valid = uabDataReader.ImageLabelReader([0], [1, 2, 3], patchDir, file_list_valid, flags.input_size,
                                                              flags.tile_size,
                                                              flags.batch_size, dataAug=' ', block_mean=np.append([0], img_mean))

        # train
        start_time = time.time()

        # start_layer >= 10 trains all variables; otherwise only the
        # 'layerup{start_layer}'..'layerup9' variables are trainable
        if flags.start_layer >= 10:
            model.train_config('X', 'Y', flags.n_train, flags.n_valid, flags.input_size, uabRepoPaths.modelPath,
                               loss_type='xent', par_dir='aemo/{}'.format(flags.ds_name))
        else:
            model.train_config('X', 'Y', flags.n_train, flags.n_valid, flags.input_size, uabRepoPaths.modelPath,
                               loss_type='xent', par_dir='aemo/{}'.format(flags.ds_name),
                               train_var_filter=['layerup{}'.format(i) for i in range(flags.start_layer, 10)])
        model.run(train_reader=dataReader_train,
                  valid_reader=dataReader_valid,
                  pretrained_model_dir=flags.model_dir,   # fine-tunes from this checkpoint (despite older comment)
                  isTrain=True,
                  img_mean=img_mean,
                  verb_step=100,                        # NOTE(review): hard-coded; flags.verb_step exists but is unused
                  save_epoch=20,                         # save the model every 20 epochs (hard-coded, ignores flags.save_epoch)
                  gpu=GPU,                               # NOTE(review): module constant, not flags.GPU — confirm intended
                  tile_size=flags.tile_size,
                  patch_size=flags.input_size)

        duration = time.time() - start_time
        print('duration {:.2f} hours'.format(duration/60/60))
# --- Exemplo n.º 6 (score: 0) ---
    # NOTE(review): the enclosing `def` line is outside this view; this body
    # parses the evaluation flags and returns the resulting namespace.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test-list', default=TEST_LIST, type=str, help='test list')
    parser.add_argument('--min-size', default=MIN_SIZE, type=int, help='minimum size to remove')

    flags = parser.parse_args()
    return flags


if __name__ == '__main__':
    flags = read_flag()
    # candidate result directories; flags.test_list selects by index
    model_dirs = [
        r'/hdd/Results/domain_selection/UnetCrop_inria_aug_grid_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60_DR0.1_SFN32',
        r'/hdd/Results/domain_selection/DeeplabV3_inria_aug_grid_0_PS(321, 321)_BS5_EP100_LR1e-05_DS40_DR0.1_SFN32',
        r'/hdd/Results/ugan/UnetGAN_V3_inria_gan_xregion_0_PS(572, 572)_BS20_EP30_LR0.0001_1e-06_1e-06_DS30.0_30.0_30.0_DR0.1_0.1_0.1',
    ]
    model_list = ersa_utils.str2list(flags.test_list)
    model_dirs = [model_dirs[x] for x in model_list]
    img_dir, task_dir = sis_utils.get_task_img_folder()

    for cnt, model_dir in enumerate(model_dirs):
        pred_files = sorted(glob(os.path.join(model_dir, 'inria', 'pred', '*.png')))
        # NOTE(review): model_dir is an absolute path, so this format embeds
        # the full path in the file name — confirm that is intended
        save_file = os.path.join(task_dir, '{}_xregion_wolverine.txt'.format(model_dir))
        # truncate (or create) the output file before appending results
        with open(save_file, 'w'):
            pass

        for pred_file in pred_files:
            # file name stem (e.g. 'austin1' from '.../austin1.png')
            city_name = pred_file.split('/')[-1].split('.')[0]
            pred = ersa_utils.load_file(pred_file)
            # drop connected components smaller than flags.min_size pixels
            pred = remove_small_objects(pred, flags.min_size)

            lbl, idx = get_object_id(pred)