Example #1: training a UNet on the Inria building-segmentation tiles (austin, chicago, kitsap, tyrol-w, vienna)
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = unet.UNet(flags.num_classes, flags.patch_size, suffix=flags.model_suffix, learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step, decay_rate=flags.decay_rate, epochs=flags.epochs,
                          batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='austin,chicago,kitsap,tyrol-w,vienna',
                                         field_id=','.join(str(i) for i in range(37)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='tif',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1 / 255, ['GT', 'gt_d255']). \
        run(force_run=False, file_ext='png', d_type=np.uint8)
    cm.replace_channel(gt_d255.files, True, ['GT', 'gt_d255'])
    cm.print_meta_data()
    file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(6, 37)), field_ext='RGB,gt_d255')
    file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(6)), field_ext='RGB,gt_d255')
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir, val_mult=flags.val_mult, loss_type='xent')
    train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                       value_names=['train_loss', 'learning_rate'], print_val=[0])
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step()*flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(model.get_epoch_step(), [model.loss, model.loss_iou],
                                            value_names=['valid_loss', 'valid_mIoU'], log_time=True,
                                            run_time=model.n_valid, iou_pos=1)
    image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label, model.output,
                                            nn_utils.image_summary, img_mean=cm.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op, valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
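Each main(flags) in this listing receives an already-populated flags namespace that none of the examples construct. A minimal argparse sketch for Example #1 (flag names are taken from the code above; every default is a placeholder and main is assumed to be in scope):

import argparse

def parse_flags():
    # names mirror the flags referenced in Example #1; defaults are illustrative only
    p = argparse.ArgumentParser()
    p.add_argument('--num_classes', type=int, default=2)
    p.add_argument('--patch_size', type=int, nargs=2, default=(572, 572))
    p.add_argument('--tile_size', type=int, nargs=2, default=(5000, 5000))
    p.add_argument('--model_suffix', default='inria')
    p.add_argument('--learning_rate', type=float, default=1e-4)
    p.add_argument('--decay_step', type=int, default=60)
    p.add_argument('--decay_rate', type=float, default=0.1)
    p.add_argument('--epochs', type=int, default=100)
    p.add_argument('--batch_size', type=int, default=5)
    p.add_argument('--data_dir', default='/path/to/tiles')
    p.add_argument('--ds_name', default='inria')
    p.add_argument('--n_train', type=int, default=8000)
    p.add_argument('--n_valid', type=int, default=1000)
    p.add_argument('--val_mult', type=int, default=5)
    p.add_argument('--verb_step', type=int, default=50)
    p.add_argument('--save_epoch', type=int, default=5)
    p.add_argument('--model_par_dir', default='inria')
    return p.parse_args()

if __name__ == '__main__':
    main(parse_flags())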
Example #2: training a PSPNet on Cityscapes
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = pspnet.PSPNet(flags.num_classes, flags.tile_size, suffix=flags.model_suffix, learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step, decay_rate=flags.decay_rate, epochs=flags.epochs,
                          batch_size=flags.batch_size, weight_decay=flags.weight_decay, momentum=flags.momentum)

    cm_train = cityscapes_reader.CollectionMakerCityscapes(flags.data_dir, flags.rgb_type, flags.gt_type, 'train',
                                                           flags.rgb_ext, flags.gt_ext, ['png', 'png'],
                                                           clc_name='{}_train'.format(flags.ds_name),
                                                           force_run=flags.force_run)
    cm_valid = cityscapes_reader.CollectionMakerCityscapes(flags.data_dir, flags.rgb_type, flags.gt_type, 'val',
                                                           flags.rgb_ext, flags.gt_ext, ['png', 'png'],
                                                           clc_name='{}_valid'.format(flags.ds_name),
                                                           force_run=flags.force_run)
    cm_train.print_meta_data()
    cm_valid.print_meta_data()

    resize_func = lambda img: resize_image(img, flags.tile_size)
    train_init_op, valid_init_op, reader_op = dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.tile_size, cm_train.meta_data['file_list'], cm_valid.meta_data['file_list'],
            flags.batch_size, cm_train.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori,
                                                                         reader_utils.image_scaling_with_label],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func)\
        .read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.load_resnet(flags.res_dir)
    model.compile(feature, label, flags.n_train, flags.n_valid, flags.tile_size, ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir, val_mult=flags.val_mult, loss_type='xent')
    train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                       value_names=['train_loss', 'learning_rate'], print_val=[0])
    model_save_hook = hook.ModelSaveHook(model.get_epoch_step()*flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHook(model.get_epoch_step(), [model.loss, model.loss_iou],
                                            value_names=['valid_loss', 'valid_mIoU'], log_time=True,
                                            run_time=model.n_valid, iou_pos=1)
    image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label, model.output,
                                            cityscapes_labels.image_summary, img_mean=cm_train.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op, valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time)/3600))
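Example #2 calls a resize_image helper that is not shown anywhere in the listing. A plausible stand-in (an assumption, modeled on the skimage.transform.resize call that Example #9 uses for the same dataset):

import skimage.transform

def resize_image(img, target_size):
    # resize to (height, width) = target_size; mode='reflect' follows Example #9
    return skimage.transform.resize(img, target_size, mode='reflect', preserve_range=True)

Note that for label maps a nearest-neighbour resize (order=0, as in Example #9's resize_func_test) would be needed to keep class ids intact.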
Example #3: PSPNet set-up for the AEMO collection
suffix = 'aemo_hist_rgb'
ds_name = 'aemo'
lr = 1e-3
ds = 60
dr = 0.1
epochs = 130
bs = 5
valid_mult = 5
gpu = 0
n_train = 785
n_valid = 500
verb_step = 50
save_epoch = 5
class_num = 2  # assumed: not defined in the original snippet
patch_size = (384, 384)  # assumed from the PS(384, 384) tag in model_dir below
model_dir = r'/hdd6/Models/spca/psp101/pspnet_spca_PS(384, 384)_BS5_EP100_LR0.001_DS40_DR0.1'

nn_utils.set_gpu(gpu)

# define network
model = pspnet.PSPNet(class_num, patch_size, suffix=suffix, learn_rate=lr, decay_step=ds, decay_rate=dr,
                      epochs=epochs, batch_size=bs, weight_decay=1e-3)
overlap = model.get_overlap()

cm = collectionMaker.read_collection(raw_data_path=r'/home/lab/Documents/bohao/data/aemo',
                                     field_name='aus10,aus30,aus50',
                                     field_id='',
                                     rgb_ext='.*rgb',
                                     gt_ext='.*gt',
                                     file_ext='tif',
                                     force_run=False,
                                     clc_name='aemo')
gt_d255 = collectionEditor.SingleChanMult(cm.clc_dir, 1 / 255, ['.*gt', 'gt_d255']). \
    run(force_run=False, file_ext='png', d_type=np.uint8)  # completion assumed; mirrors the chain in Example #1
Example #4: training a PSPNet (with a preloaded ResNet backbone) on Fresno, Modesto and Stockton tiles
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = pspnet.PSPNet(flags.num_classes,
                          flags.patch_size,
                          suffix=flags.model_suffix,
                          learn_rate=flags.learning_rate,
                          decay_step=flags.decay_step,
                          decay_rate=flags.decay_rate,
                          epochs=flags.epochs,
                          batch_size=flags.batch_size,
                          weight_decay=flags.weight_decay,
                          momentum=flags.momentum)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                         field_name='Fresno,Modesto,Stockton',
                                         field_id=','.join(str(i) for i in range(663)),
                                         rgb_ext='RGB',
                                         gt_ext='GT',
                                         file_ext='jpg,png',
                                         force_run=False,
                                         clc_name=flags.ds_name)
    cm.print_meta_data()
    file_list_train = cm.load_files(field_id=','.join(str(i) for i in range(0, 250)), field_ext='RGB,GT')
    file_list_valid = cm.load_files(field_id=','.join(str(i) for i in range(250, 500)), field_ext='RGB,GT')
    chan_mean = cm.meta_data['chan_mean'][:3]

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_train',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size, flags.ds_name + '_valid',
                                                     overlap, overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size, chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.load_resnet(flags.res_dir, keep_last=False)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_xent, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.output,
                                            nn_utils.image_summary,
                                            img_mean=cm.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
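Examples #1 and #4 hand PatchExtractor an overlap of model.get_overlap() and a pad of overlap // 2. The sliding-window arithmetic behind such an extractor, sketched under the assumption of a uniform stride (this is not ersa's actual code):

def patch_starts(tile_dim, patch_size, overlap):
    # top-left coordinates of overlapping windows covering one axis of a tile
    stride = patch_size - overlap
    starts = list(range(0, tile_dim - patch_size, stride))
    starts.append(tile_dim - patch_size)  # final window flush with the edge
    return starts

For instance, patch_starts(5000, 321, 184) steps by 137 px, so neighbouring patches share 184 px of context that can later be cropped away when the tile is reassembled.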
Example #5: patch-based evaluation that stitches predictions per tile and scores IoU
    def process(self):
        """
        Evaluate the network
        :return:
        """
        nn_utils.set_gpu(self.gpu)

        if self.score_results:
            with open(os.path.join(self.score_save_dir, 'result.txt'), 'w'):
                pass
        iou_record = []

        # prepare the reader
        if self.score_results:
            init_op, reader_op = dataReaderSegmentation.DataReaderSegmentationTesting(
                self.input_size,
                self.tile_size,
                self.file_list,
                overlap=self.model.get_overlap(),
                pad=self.model.get_overlap() // 2,
                batch_size=self.batch_size,
                chan_mean=self.img_mean,
                is_train=False,
                has_gt=True,
                random=False,
                gt_dim=1,
                include_gt=True).read_op()
            feature, label = reader_op
            self.model.create_graph(feature, **self.kwargs)
        else:
            init_op, reader_op = dataReaderSegmentation.DataReaderSegmentationTesting(
                self.input_size,
                self.tile_size,
                self.file_list,
                overlap=self.model.get_overlap(),
                pad=self.model.get_overlap() // 2,
                batch_size=self.batch_size,
                chan_mean=self.img_mean,
                is_train=False,
                has_gt=False,
                random=False,
                gt_dim=0,
                include_gt=False).read_op()
            feature = reader_op
            self.model.create_graph(feature[0], **self.kwargs)
        pad = self.model.get_overlap()

        for file_cnt, file_name_list in enumerate(self.file_list):
            file_name_truth = None
            if self.score_results:
                file_name, file_name_truth = file_name_list
                tile_name = os.path.basename(file_name_truth).split(
                    self.split_char)[0]
            else:
                file_name = file_name_list[0]
                tile_name = os.path.basename(file_name).split(
                    self.split_char)[0]
            if self.verb:
                print('Evaluating {} ... '.format(tile_name))

            # read tile size if no tile size is given
            if self.tile_size is None or self.compute_shape_flag:
                self.compute_shape_flag = True
                tile = ersa_utils.load_file(file_name)
                self.tile_size = tile.shape[:2]

            start_time = time.time()

            # run the model
            if self.model.config is None:
                self.model.config = tf.ConfigProto(allow_soft_placement=True)
            with tf.Session(config=self.model.config) as sess:
                init = tf.global_variables_initializer()
                sess.run(init)
                self.model.load(self.model_dir,
                                sess,
                                epoch=self.load_epoch_num,
                                best_model=self.best_model)
                result = self.model.test_sample(sess, init_op[file_cnt])
            image_pred = patchExtractor.unpatch_block(
                result,
                tile_dim=[self.tile_size[0] + pad, self.tile_size[1] + pad],
                patch_size=self.input_size,
                tile_dim_output=self.tile_size,
                patch_size_output=[
                    self.input_size[0] - pad, self.input_size[1] - pad
                ],
                overlap=pad)
            if self.compute_shape_flag:
                self.tile_size = None

            pred = nn_utils.get_pred_labels(image_pred) * self.truth_val

            if self.score_results:
                truth_label_img = ersa_utils.load_file(file_name_truth)
                iou = nn_utils.iou_metric(truth_label_img,
                                          pred,
                                          divide_flag=True)
                iou_record.append(iou)

                duration = time.time() - start_time
                if self.verb:
                    print('{} mean IoU={:.3f}, duration: {:.3f}'.format(
                        tile_name, iou[0] / iou[1], duration))

                # save results
                pred_save_dir = os.path.join(self.score_save_dir, 'pred')
                ersa_utils.make_dir_if_not_exist(pred_save_dir)
                ersa_utils.save_file(
                    os.path.join(pred_save_dir, '{}.png'.format(tile_name)),
                    pred.astype(np.uint8))
                with open(os.path.join(self.score_save_dir, 'result.txt'), 'a+') as file:
                    file.write('{} {}\n'.format(tile_name, iou))

        if self.score_results:
            iou_record = np.array(iou_record)
            mean_iou = np.sum(iou_record[:, 0]) / np.sum(iou_record[:, 1])
            print('Overall mean IoU={:.3f}'.format(mean_iou))
            with open(os.path.join(self.score_save_dir, 'result.txt'),
                      'a+') as file:
                file.write('{}'.format(mean_iou))
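The scoring above reads iou[0] / iou[1] per tile and divides summed intersections by summed unions at the end, so nn_utils.iou_metric with divide_flag=True evidently returns raw counts rather than a ratio. A hedged re-implementation of that contract for a binary mask:

import numpy as np

def iou_metric(truth, pred, divide_flag=False):
    # binary foreground IoU; divide_flag returns raw counts so tiles can be pooled
    truth, pred = truth > 0, pred > 0
    intersection = float(np.logical_and(truth, pred).sum())
    union = float(np.logical_or(truth, pred).sum())
    return np.array([intersection, union]) if divide_flag else intersection / union

Pooling the counts before dividing weights every pixel equally across the dataset instead of averaging per-tile ratios.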
Example #6: scene evaluation that scores per-tile mean IU from a prebuilt reader
    def process(self):
        """
        Evaluate the network
        :return:
        """
        nn_utils.set_gpu(self.gpu)

        if self.score_results:
            with open(os.path.join(self.score_save_dir, 'result.txt'), 'w'):
                pass
        iou_record = []

        # prepare the reader
        if self.score_results:
            feature, label = self.reader_op
            self.model.create_graph(feature, **self.kwargs)
        else:
            feature = self.reader_op
            self.model.create_graph(feature[0], **self.kwargs)

        # run the model
        with tf.Session() as sess:
            init = tf.global_variables_initializer()
            sess.run([init, self.init_op])
            self.model.load(self.model_dir,
                            sess,
                            epoch=self.load_epoch_num,
                            best_model=self.best_model)
            for file_cnt, file_name_list in enumerate(self.file_list):
                file_name_truth = None
                if self.score_results:
                    file_name, file_name_truth = file_name_list
                    tile_name = os.path.basename(file_name_truth).split(
                        self.split_char)[0]
                else:
                    file_name = file_name_list[0]
                    tile_name = os.path.basename(file_name).split(
                        self.split_char)[0]
                if self.verb:
                    print('Evaluating {} ... '.format(tile_name))
                start_time = time.time()

                result = sess.run(self.model.output)
                pred = np.argmax(np.squeeze(result, axis=0), axis=-1)
                if self.post_func is not None:
                    pred = self.post_func(pred)
                if self.save_func is not None:
                    save_img = self.save_func(pred)
                else:
                    save_img = pred

                if self.score_results:
                    truth_label_img = ersa_utils.load_file(file_name_truth)
                    iou = nn_evals.mean_IU(pred, truth_label_img,
                                           self.ignore_label)
                    iou_record.append(iou)

                    duration = time.time() - start_time
                    if self.verb:
                        print('{} mean IoU={:.3f}, duration: {:.3f}'.format(
                            tile_name, iou, duration))

                    # save results
                    pred_save_dir = os.path.join(self.score_save_dir, 'pred')
                    ersa_utils.make_dir_if_not_exist(pred_save_dir)
                    ersa_utils.save_file(
                        os.path.join(pred_save_dir,
                                     '{}.png'.format(tile_name)),
                        save_img.astype(np.uint8))
                    with open(os.path.join(self.score_save_dir, 'result.txt'), 'a+') as file:
                        file.write('{} {}\n'.format(tile_name, iou))

        if self.score_results:
            iou_record = np.array(iou_record)
            mean_iou = np.mean(iou_record)
            print('Overall mean IoU={:.3f}'.format(mean_iou))
            with open(os.path.join(self.score_save_dir, 'result.txt'),
                      'a+') as file:
                file.write('{}'.format(mean_iou))
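Unlike Example #5, this loop gets a single scalar per tile from nn_evals.mean_IU(pred, truth, ignore_label) and averages the scalars at the end. A sketch consistent with that usage (an assumption, including the tuple-valued ignore label that Example #9 passes as (-1, 255)):

import numpy as np

def mean_IU(pred, truth, ignore_labels=()):
    # mean per-class IoU over the classes present in the ground truth
    valid = ~np.isin(truth, ignore_labels)
    scores = []
    for c in np.unique(truth[valid]):
        p, t = (pred == c) & valid, (truth == c) & valid
        union = float(np.logical_or(p, t).sum())
        if union > 0:
            scores.append(float(np.logical_and(p, t).sum()) / union)
    return float(np.mean(scores))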
Example #7: training a UNet on tower-only ground truth
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    np.random.seed(flags.run_id)
    tf.set_random_seed(flags.run_id)

    # define network
    model = unet.UNet(flags.num_classes,
                      flags.patch_size,
                      suffix=flags.suffix,
                      learn_rate=flags.learn_rate,
                      decay_step=flags.decay_step,
                      decay_rate=flags.decay_rate,
                      epochs=flags.epochs,
                      batch_size=flags.batch_size)
    overlap = model.get_overlap()

    cm = collectionMaker.read_collection(
        raw_data_path=flags.data_dir,
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(1, 16)),
        rgb_ext='RGB',
        gt_ext='GT',
        file_ext='tif,png',
        force_run=False,
        clc_name=flags.ds_name)
    gt_d255 = collectionEditor.SingleChanSwitch(cm.clc_dir,
                                                {2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 0},
                                                ['GT', 'GT_switch'], 'tower_only'). \
        run(force_run=False, file_ext='png', d_type=np.uint8)
    cm.replace_channel(gt_d255.files, True, ['GT', 'GT_switch'])
    cm.print_meta_data()

    file_list_train = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id=','.join(str(i) for i in range(4, 16)),
        field_ext='RGB,GT_switch')
    file_list_valid = cm.load_files(
        field_name='Tucson,Colwich,Clyde,Wilmington',
        field_id='1,2,3',
        field_ext='RGB,GT_switch')

    patch_list_train = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size,
                                                     ds_name=flags.ds_name + '_tower_only',
                                                     overlap=overlap, pad=overlap // 2). \
        run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
    chan_mean = cm.meta_data['chan_mean']

    train_init_op, valid_init_op, reader_op = \
        dataReaderSegmentation.DataReaderSegmentationTrainValid(
            flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
            chan_mean=chan_mean,
            aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
    feature, label = reader_op

    model.create_graph(feature)
    model.compile(feature,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.patch_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.par_dir,
                  loss_type='xent',
                  pos_weight=flags.pos_weight)
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [model.loss, model.lr_op],
        value_names=['train_loss', 'learning_rate'],
        print_val=[0])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_iou, model.loss_xent],
        value_names=['IoU', 'valid_loss'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(model.input_size,
                                            model.get_epoch_step(),
                                            feature,
                                            label,
                                            model.pred,
                                            partial(
                                                nn_utils.image_summary,
                                                label_num=flags.num_classes),
                                            img_mean=chan_mean)
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=train_init_op,
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
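The SingleChanSwitch call above collapses the multi-class ground truth to a tower-only binary mask via the {2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 0} mapping. The same remap in plain numpy, as a sketch rather than ersa's implementation:

import numpy as np

def switch_labels(gt, mapping):
    # rewrite each listed label value; unlisted values pass through unchanged
    out = gt.copy()
    for src, dst in mapping.items():
        out[gt == src] = dst
    return out

# tower_only = switch_labels(gt, {2: 0, 3: 1, 4: 0, 5: 0, 6: 0, 7: 0})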
Example #8: training an STN to refine DeepLab predictions on Cityscapes
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = stn.STN(flags.num_classes,
                    flags.tile_size,
                    suffix=flags.model_suffix,
                    learn_rate=flags.learning_rate,
                    decay_step=flags.decay_step,
                    decay_rate=flags.decay_rate,
                    epochs=flags.epochs,
                    batch_size=flags.batch_size)

    cm_train = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'train',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_train'.format(flags.ds_name),
        force_run=flags.force_run)

    cm_valid = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'val',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_valid'.format(flags.ds_name),
        force_run=flags.force_run)
    cm_train.print_meta_data()
    cm_valid.print_meta_data()

    train_pred_dir = r'/home/lab/Documents/bohao/data/deeplab_model/vis_train/raw_segmentation_results'
    valid_pred_dir = r'/home/lab/Documents/bohao/data/deeplab_model/vis/raw_segmentation_results'
    file_list_train = get_image_list(flags.data_dir, train_pred_dir, 'train')
    file_list_valid = get_image_list(flags.data_dir, valid_pred_dir, 'val')

    resize_func = lambda img: resize_image(img, flags.tile_size)
    train_init_op, valid_init_op, reader_op = DataReaderSegmentationTrainValid(
            flags.tile_size, file_list_train, file_list_valid,
            flags.batch_size, cm_train.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori],
            random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func)\
        .read_op()
    feature, label, pred = reader_op
    train_init_op_valid, _, reader_op = DataReaderSegmentationTrainValid(
        flags.tile_size, file_list_valid, file_list_train,
        flags.batch_size, cm_valid.meta_data['chan_mean'], aug_func=[reader_utils.image_flipping_hori],
        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult, global_func=resize_func) \
        .read_op()
    _, _, pred_valid = reader_op

    model.create_graph(pred, feature_valid=pred_valid, rgb=feature)

    model.compile(pred,
                  label,
                  flags.n_train,
                  flags.n_valid,
                  flags.tile_size,
                  ersaPath.PATH['model'],
                  par_dir=flags.model_par_dir,
                  val_mult=flags.val_mult,
                  loss_type='xent')
    train_hook = hook.ValueSummaryHook(
        flags.verb_step, [
            model.loss, model.g_loss, model.d_loss, model.lr_op[0],
            model.lr_op[1], model.lr_op[2]
        ],
        value_names=['seg_loss', 'g_loss', 'd_loss', 'lr_seg', 'lr_g', 'lr_d'],
        print_val=[0, 1, 2])
    model_save_hook = hook.ModelSaveHook(
        model.get_epoch_step() * flags.save_epoch, model.ckdir)
    valid_loss_hook = hook.ValueSummaryHookIters(
        model.get_epoch_step(), [model.loss_xent, model.loss_iou],
        value_names=['valid_loss', 'valid_mIoU'],
        log_time=True,
        run_time=model.n_valid)
    image_hook = hook.ImageValidSummaryHook(
        model.input_size,
        model.get_epoch_step(),
        feature,
        label,
        model.refine,
        cityscapes_labels.image_summary,
        img_mean=cm_train.meta_data['chan_mean'])
    start_time = time.time()
    model.train(train_hooks=[train_hook, model_save_hook],
                valid_hooks=[valid_loss_hook, image_hook],
                train_init=[train_init_op, train_init_op_valid],
                valid_init=valid_init_op)
    print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
Example #9: evaluating a DeepLab model on Cityscapes validation scenes
def main(flags):
    nn_utils.set_gpu(GPU)

    # define network
    model = deeplab.DeepLab(flags.num_classes,
                            flags.tile_size,
                            batch_size=flags.batch_size)

    cm_train = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'train',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_train'.format(flags.ds_name),
        force_run=False)
    cm_test = cityscapes_reader.CollectionMakerCityscapes(
        flags.data_dir,
        flags.rgb_type,
        flags.gt_type,
        'val',
        flags.rgb_ext,
        flags.gt_ext, ['png', 'png'],
        clc_name='{}_valid'.format(flags.ds_name),
        force_run=False)
    cm_test.print_meta_data()
    resize_func_train = lambda img: skimage.transform.resize(img, flags.tile_size, mode='reflect')
    resize_func_test = lambda img: skimage.transform.resize(img, cm_test.meta_data['tile_dim'],
                                                            order=0, preserve_range=True, mode='reflect')

    init_op, reader_op = dataReaderSegmentation.DataReaderSegmentation(
        flags.tile_size,
        cm_test.meta_data['file_list'],
        batch_size=flags.batch_size,
        random=False,
        chan_mean=cm_train.meta_data['chan_mean'],
        is_train=False,
        has_gt=True,
        gt_dim=1,
        include_gt=True,
        global_func=resize_func_train).read_op()
    estimator = nn_processor.NNEstimatorSegmentScene(
        model,
        cm_test.meta_data['file_list'],
        flags.res_dir,
        init_op,
        reader_op,
        ds_name='city_scapes',
        save_result_parent_dir='Cityscapes',
        gpu=flags.GPU,
        score_result=True,
        split_char='.',
        post_func=resize_func_test,
        save_func=make_general_id_map,
        ignore_label=(-1, 255))
    estimator.run(force_run=flags.force_run)
Example #10: restoring a trained DeepLab checkpoint for inference
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from glob import glob
import ersa_utils
from nn import nn_utils, deeplab
from preprocess import patchExtractor
from collection import collectionMaker
from visualize import visualize_utils

patch_size = (321, 321)
overlap = 184
class_num = 2
bs = 1
suffix = 'aemo_pad'
nn_utils.set_gpu(1)

cm = collectionMaker.read_collection('aemo_pad')
chan_mean = cm.meta_data['chan_mean'][-3:]
file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb_hist,.*gt_d255')
pred_dir = r'/hdd/Results/aemo/deeplab_aemo_spca_hist_PS(321, 321)_BS5_EP2_LR0.001_DS50_DR0.1/default/pred'
pred_files = sorted(glob(os.path.join(pred_dir, '*.png')))

model = deeplab.DeepLab(class_num, patch_size, suffix=suffix, batch_size=bs)
model_dir = r'/hdd6/Models/aemo/new2/deeplab_aemo_spca_hist_PS(321, 321)_BS5_EP2_LR0.001_DS50_DR0.1'
feature = tf.placeholder(tf.float32, shape=(None, patch_size[0], patch_size[1], 3))
model.create_graph(feature)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    model.load(model_dir, sess)
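    # sketch of a possible continuation (not in the original snippet); that
    # model.output holds the logits tensor is an assumption based on Examples #1 and #6
    patch = np.zeros((1, patch_size[0], patch_size[1], 3), np.float32) - chan_mean
    logits = sess.run(model.output, feed_dict={feature: patch})
    pred = np.argmax(logits, axis=-1)[0]  # (321, 321) label map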
Example #11: UNet fine-tuning sweep over start layers, learning rates and seeds
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo_newloss'
        else:
            suffix_base = 'aemo_newloss_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(4):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes, flags.patch_size, suffix=suffix, learn_rate=lr,
                                  decay_step=flags.decay_step, decay_rate=flags.decay_rate,
                                  epochs=flags.epochs, batch_size=flags.batch_size)
                overlap = model.get_overlap()

                cm = collectionMaker.read_collection(raw_data_path=flags.data_dir,
                                                     field_name='aus10,aus30,aus50',
                                                     field_id='',
                                                     rgb_ext='.*rgb',
                                                     gt_ext='.*gt',
                                                     file_ext='tif',
                                                     force_run=False,
                                                     clc_name=flags.ds_name)
                cm.print_meta_data()

                file_list_train = cm.load_files(field_name='aus10,aus30', field_id='', field_ext='.*rgb,.*gt')
                file_list_valid = cm.load_files(field_name='aus50', field_id='', field_ext='.*rgb,.*gt')

                patch_list_train = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_train_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_train, file_exts=['jpg', 'png'], force_run=False).get_filelist()
                patch_list_valid = patchExtractor.PatchExtractor(flags.patch_size, flags.tile_size,
                                                                 flags.ds_name + '_valid_hist',
                                                                 overlap, overlap // 2). \
                    run(file_list=file_list_valid, file_exts=['jpg', 'png'], force_run=False).get_filelist()
                chan_mean = cm.meta_data['chan_mean']

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                if start_layer >= 10:
                    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                                  par_dir=flags.par_dir, loss_type='xent')
                else:
                    model.compile(feature, label, flags.n_train, flags.n_valid, flags.patch_size, ersaPath.PATH['model'],
                                  par_dir=flags.par_dir, loss_type='xent',
                                  train_var_filter=['layerup{}'.format(i) for i in range(start_layer, 10)])
                train_hook = hook.ValueSummaryHook(flags.verb_step, [model.loss, model.lr_op],
                                                   value_names=['train_loss', 'learning_rate'],
                                                   print_val=[0])
                model_save_hook = hook.ModelSaveHook(model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                                                             value_names=['valid_loss', 'IoU'], log_time=True,
                                                             run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size, model.get_epoch_step(), feature, label,
                                                        model.pred,
                                                        nn_utils.image_summary, img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook], valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op, valid_init=valid_init_op)
                print('Duration: {:.3f}'.format((time.time() - start_time) / 3600))
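train_var_filter=['layerup{}'.format(i) ...] above implies that only the decoder layers from start_layer upward are optimized while everything else stays frozen. How such a filter can be applied in TF1, sketched as the presumed mechanism (not ersa's actual code):

import tensorflow as tf

def build_train_op(loss, learning_rate, train_var_filter=None):
    # restrict the optimizer to variables whose names match one of the filters
    var_list = tf.trainable_variables()
    if train_var_filter is not None:
        var_list = [v for v in var_list if any(f in v.name for f in train_var_filter)]
    return tf.train.AdamOptimizer(learning_rate).minimize(loss, var_list=var_list)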
Example #12: sliding-window patch generator (fragment)
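            # (fragment) tail of a sliding-window generator: it crops the image and
            # ground-truth patch at the current window, verifies a constant block
            # (ref_val) at the expected offset, then yields both patches with the
            # block offsets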
            patch_img = img[y_slide - patch_size:y_slide,
                            x_slide - patch_size:x_slide, :]
            patch_gt = gt[y_slide - patch_size:y_slide,
                          x_slide - patch_size:x_slide]
            # the inserted block (value ref_val) must sit at the expected offset inside the patch
            y0 = patch_size - 88 - block_size - cnt_y
            x0 = patch_size - 88 - block_size - cnt_x
            assert np.all(patch_img[y0:y0 + block_size, x0:x0 + block_size, :] == ref_val)

            yield patch_img, patch_gt, cnt_y, cnt_x


if __name__ == '__main__':
    force_run = False
    nn_utils.set_gpu(0)
    pretrained_model_dir = r'/hdd6/Models/Inria_decay/DeeplabV3_inria_decay_0_PS(321, 321)_BS5_' \
                           r'EP100_LR1e-05_DS40.0_DR0.1_SFN32'
    for field_name in ['austin', 'chicago', 'kitsap', 'tyrol-w', 'vienna']:
        for field_id in ['1', '2', '3', '4', '5']:
            y = 800
            x = 800
            patch_size = 321
            stride = 1
            block_size = 33
            output_size = 121
            tf.reset_default_graph()
            record_matrix = []

            img_dir, task_dir = sis_utils.get_task_img_folder()
            save_file_name = os.path.join(
Example #13: building a cropped-UNet graph on an Inria tile
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import ersa_utils, sis_utils
import uab_collectionFunctions
from nn import nn_utils
from bohaoCustom import uabMakeNetwork_UNet

model_dir = r'/hdd6/Models/Inria_decay/UnetCrop_inria_decay_0_PS(572, 572)_BS5_EP100_LR0.0001_DS60.0_DR0.1_SFN32'
rgb_dir = r'/media/ei-edl01/data/uab_datasets/inria/data/Original_Tiles/austin1_RGB.tif'

rgb = ersa_utils.load_file(rgb_dir)
rgb = rgb[1200:2200, 900:2300, :]
height = 2200 - 1200
width = 2300 - 900
rgb_temp = np.copy(rgb)
nn_utils.set_gpu(-1)
img_dir, task_dir = sis_utils.get_task_img_folder()

blCol = uab_collectionFunctions.uabCollection('inria')
img_mean = blCol.getChannelMeans([0, 1, 2])

X = tf.placeholder(tf.float32, shape=[None, 572, 572, 3], name='X')
y = tf.placeholder(tf.int32, shape=[None, 572, 572, 1], name='y')
mode = tf.placeholder(tf.bool, name='mode')
model = uabMakeNetwork_UNet.UnetModelCrop({'X': X, 'Y': y},
                                          trainable=mode,
                                          input_size=(572, 572),
                                          batch_size=1)  # closing parenthesis assumed; the snippet is cut off at the source
Example #14: UNet fine-tuning from a precomputed patch list
def main(flags):
    nn_utils.set_gpu(flags.GPU)
    for start_layer in flags.start_layer:
        if start_layer >= 10:
            suffix_base = 'aemo'
        else:
            suffix_base = 'aemo_up{}'.format(start_layer)
        if flags.from_scratch:
            suffix_base += '_scratch'
        for lr in flags.learn_rate:
            for run_id in range(1):
                suffix = '{}_{}'.format(suffix_base, run_id)
                tf.reset_default_graph()

                np.random.seed(run_id)
                tf.set_random_seed(run_id)

                # define network
                model = unet.UNet(flags.num_classes,
                                  flags.patch_size,
                                  suffix=suffix,
                                  learn_rate=lr,
                                  decay_step=flags.decay_step,
                                  decay_rate=flags.decay_rate,
                                  epochs=flags.epochs,
                                  batch_size=flags.batch_size)

                file_list = os.path.join(flags.data_dir, 'file_list.txt')
                lines = ersa_utils.load_file(file_list)

                patch_list_train = []
                patch_list_valid = []
                train_tile_names = ['aus10', 'aus30']
                valid_tile_names = ['aus50']

                for line in lines:
                    tile_name = os.path.basename(
                        line.split(' ')[0]).split('_')[0].strip()
                    if tile_name in train_tile_names:
                        patch_list_train.append(line.strip().split(' '))
                    elif tile_name in valid_tile_names:
                        patch_list_valid.append(line.strip().split(' '))
                    else:
                        raise ValueError

                cm = collectionMaker.read_collection('aemo_align')
                chan_mean = cm.meta_data['chan_mean']

                train_init_op, valid_init_op, reader_op = \
                    dataReaderSegmentation.DataReaderSegmentationTrainValid(
                        flags.patch_size, patch_list_train, patch_list_valid, batch_size=flags.batch_size,
                        chan_mean=chan_mean,
                        aug_func=[reader_utils.image_flipping, reader_utils.image_rotating],
                        random=True, has_gt=True, gt_dim=1, include_gt=True, valid_mult=flags.val_mult).read_op()
                feature, label = reader_op

                model.create_graph(feature)
                if start_layer >= 10:
                    model.compile(feature,
                                  label,
                                  flags.n_train,
                                  flags.n_valid,
                                  flags.patch_size,
                                  ersaPath.PATH['model'],
                                  par_dir=flags.par_dir,
                                  loss_type='xent')
                else:
                    model.compile(feature,
                                  label,
                                  flags.n_train,
                                  flags.n_valid,
                                  flags.patch_size,
                                  ersaPath.PATH['model'],
                                  par_dir=flags.par_dir,
                                  loss_type='xent',
                                  train_var_filter=[
                                      'layerup{}'.format(i)
                                      for i in range(start_layer, 10)
                                  ])
                train_hook = hook.ValueSummaryHook(
                    flags.verb_step, [model.loss, model.lr_op],
                    value_names=['train_loss', 'learning_rate'],
                    print_val=[0])
                model_save_hook = hook.ModelSaveHook(
                    model.get_epoch_step() * flags.save_epoch, model.ckdir)
                valid_loss_hook = hook.ValueSummaryHookIters(
                    model.get_epoch_step(), [model.loss_xent, model.loss_iou],
                    value_names=['valid_loss', 'IoU'],
                    log_time=True,
                    run_time=model.n_valid)
                image_hook = hook.ImageValidSummaryHook(model.input_size,
                                                        model.get_epoch_step(),
                                                        feature,
                                                        label,
                                                        model.pred,
                                                        nn_utils.image_summary,
                                                        img_mean=chan_mean)
                start_time = time.time()
                if not flags.from_scratch:
                    model.load(flags.model_dir)
                model.train(train_hooks=[train_hook, model_save_hook],
                            valid_hooks=[valid_loss_hook, image_hook],
                            train_init=train_init_op,
                            valid_init=valid_init_op)
                print('Duration: {:.3f}'.format(
                    (time.time() - start_time) / 3600))