Ejemplo n.º 1
0
# Derive the YOLT box size in pixels from the object size and image
# resolution, both given in meters.
car_size = 3  # meters
GSD = 0.15  # meters
yolt_box_size = np.rint(car_size / GSD)  # size in pixels
print("yolt_box_size (pixels):", yolt_box_size)
##############################

##############################
# slicing variables
slice_overlap = 0.1
zero_frac_thresh = 0.2
sliceHeight, sliceWidth = 544, 544  # for 82m windows
##############################

##############################
# set yolt category params from pbtxt
label_map_dict = preprocess_tfrecords.load_pbtxt(label_map_path, verbose=False)
print("label_map_dict:", label_map_dict)
# get ordered keys
key_list = sorted(label_map_dict.keys())
category_num = len(key_list)
print(key_list)
# category list for yolt
cat_list = [label_map_dict[k] for k in key_list]
# NOTE(review): the pbtxt-derived list is deliberately overridden with a
# single hard-coded category. It must be a list, not a bare string:
# joining or enumerating the string 'car' would iterate over its characters
# and produce "c,a,r" / {'c': 0, 'a': 1, 'r': 2}.
cat_list = ['car']
print("cat list:", cat_list)
yolt_cat_str = ','.join(cat_list)
print("yolt cat str:", yolt_cat_str)
# create yolt_category dictionary (should start at 0, not 1!)
yolt_cat_dict = {x: i for i, x in enumerate(cat_list)}
print("yolt_cat_dict:", yolt_cat_dict)
# conversion between yolt and pbtxt numbers (just increase number by 1)
Ejemplo n.º 2
0
def main():
    """Convert a detections tfrecord to a dataframe, save it, and plot it.

    Steps, in order:
      1. Parse command-line options.
      2. Load the class label map with ``preprocess_tfrecords.load_pbtxt``.
      3. Convert the tfrecord to a dataframe with ``tf_to_df``.
      4. Drop rows whose ``Prob`` is below ``--plot_thresh``.
      5. Augment the dataframe with ``post_process.augment_df`` and write it
         to ``<outdir>/00_dataframe.csv``.
      6. If ``--plot_thresh`` is positive, call
         ``post_process.refine_and_plot_df``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--tfrecords_filename',
        type=str,
        default='/cosmiq/simrdwn/tmp/val_detections_ssd.tfrecord',
        help="tfrecords file")
    parser.add_argument('--outdir',
                        type=str,
                        default='/cosmiq/simrdwn/tmp/images_ssd',
                        help="Output file location")
    parser.add_argument(
        '--pbtxt_filename',
        type=str,
        default='/cosmiq/simrdwn/data/class_labels_airplane_boat_car.pbtxt',
        help="Class dictionary")
    parser.add_argument('--tf_type',
                        type=str,
                        default='test',
                        help="weather the tfrecord is for test or train")
    parser.add_argument('--slice_val_images',
                        type=int,
                        default=0,
                        help="Switch for if validaion images are sliced")
    parser.add_argument('--verbose',
                        type=int,
                        default=0,
                        help="Print a lot o stuff?")

    # Plotting settings
    parser.add_argument(
        '--plot_thresh',
        type=float,
        default=0.33,
        help="Threshold for plotting boxes, set < 0 to skip plotting")
    parser.add_argument(
        '--nms_overlap_thresh',
        type=float,
        default=0.5,
        help="IOU threshold for non-max-suppresion, skip if < 0")
    parser.add_argument('--make_box_labels',
                        type=int,
                        default=1,
                        help="If 1, make print label above each box")
    parser.add_argument('--scale_alpha',
                        type=int,
                        default=1,
                        help="If 1, scale box opacity with confidence")
    # The help text used to be a copy of the --scale_alpha description.
    parser.add_argument('--plot_line_thickness',
                        type=int,
                        default=1,
                        help="Line thickness used when plotting boxes")

    args = parser.parse_args()
    print("args:", args)
    t0 = time.time()

    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(args.outdir, exist_ok=True)

    # make label_map_dict (key=int, value=str)
    label_map_dict = preprocess_tfrecords.load_pbtxt(args.pbtxt_filename,
                                                     verbose=False)

    # convert tfrecord to dataframe
    df_init0 = tf_to_df(tfrecords_filename=args.tfrecords_filename,
                        label_map_dict=label_map_dict,
                        tf_type=args.tf_type)
    t1 = time.time()
    print("Time to run tf_to_df():", t1 - t0, "seconds")
    print("df_init.columns:", df_init0.columns)

    # filter out low confidence detections (filter first, then copy only
    # the surviving rows)
    df_init = df_init0[df_init0['Prob'] >= args.plot_thresh].copy()

    # augment dataframe columns
    df = post_process.augment_df(df_init,
                                 valid_testims_dir_tot='',
                                 slice_sizes=[0],
                                 valid_slice_sep='__',
                                 edge_buffer_valid=0,
                                 max_edge_aspect_ratio=4,
                                 valid_box_rescale_frac=1.0,
                                 rotate_boxes=False,
                                 verbose=bool(args.verbose))
    print("len df:", len(df))
    print("df.columns:", df.columns)
    # Guard the sample-row print: iloc[0] raises IndexError on an empty
    # dataframe (e.g. when every detection falls below plot_thresh).
    if len(df) > 0:
        print("df.iloc[0[:", df.iloc[0])
    outfile_df = os.path.join(args.outdir, '00_dataframe.csv')
    df.to_csv(outfile_df)

    # plot
    if args.plot_thresh > 0:
        post_process.refine_and_plot_df(
            df,
            label_map_dict=label_map_dict,
            outdir=args.outdir,
            sliced=bool(args.slice_val_images),
            plot_thresh=args.plot_thresh,
            nms_overlap_thresh=args.nms_overlap_thresh,
            show_labels=args.make_box_labels,
            alpha_scaling=args.scale_alpha,
            plot_line_thickness=args.plot_line_thickness,
            verbose=bool(args.verbose))

    print("Plots output to:", args.outdir)
    print("Time to get and plot records:", time.time() - t0, "seconds")
Ejemplo n.º 3
0
    def gen_data(self):
        """Build the training image list, labels, tfrecord and test images.

        Steps, in order:
          1. Load the label map from ``self.label_map_path`` and derive the
             category list and the zero-based ``yolt_cat_dict``.
          2. If ``self.im_list_name`` does not exist yet, slice every
             ``.jpg``/``.png`` file found in each training directory, using
             ``parse_cowc.slice_im_cowc`` when the matching annotation file
             exists and ``parse_cowc.slice_im_no_mask`` otherwise.
          3. Write the paths of all files in ``self.images_dir`` to
             ``self.im_list_name`` and copy that list to
             ``self.simrdwn_data_dir``.
          4. Plot a sample of the labels with
             ``yolt_data_prep_funcs.plot_training_bboxes``.
          5. Create the training tfrecord and copy it to
             ``self.simrdwn_data_dir``.
          6. Copy the test images (excluding ``_Cars.png`` and
             ``_Negatives.png`` files) into ``self.test_out_dir``.

        NOTE(review): the image list in step 3 is rewritten even when
        slicing is skipped; confirm that this is intended.
        """
        sys.path.append(os.path.join(self.path_simrdwn_utils, '..', 'core'))
        import preprocess_tfrecords
        ##############################
        # set yolt category params from pbtxt
        label_map_dict = preprocess_tfrecords.load_pbtxt(self.label_map_path,
                                                         verbose=False)
        print("label_map_dict:", label_map_dict)
        # category list for yolt, ordered by label-map key
        cat_list = [label_map_dict[k] for k in sorted(label_map_dict.keys())]
        print("cat list:", cat_list)
        yolt_cat_str = ','.join(cat_list)
        print("yolt cat str:", yolt_cat_str)
        # create yolt_category dictionary (should start at 0, not 1!)
        yolt_cat_dict = {x: i for i, x in enumerate(cat_list)}
        print("yolt_cat_dict:", yolt_cat_dict)
        # conversion between yolt and pbtxt numbers (just increase number by 1)
        convert_dict = {x: x + 1 for x in range(100)}
        print("convert_dict:", convert_dict)
        ##############################

        ##############################
        # Slice large images into smaller chunks
        ##############################
        print("self.im_list_name:", self.im_list_name)
        # Only slice when the image list has not been produced yet.
        run_slice = not os.path.exists(self.im_list_name)

        # Keyword arguments shared by both slicing functions.
        slice_kwargs = dict(sliceHeight=self.sliceHeight,
                            sliceWidth=self.sliceWidth,
                            zero_frac_thresh=self.zero_frac_thresh,
                            overlap=self.slice_overlap,
                            pad=0,
                            verbose=self.verbose)

        for d in self.train_dirs:
            dtot = os.path.join(self.ground_truth_dir, d)
            print("dtot:", dtot)

            img_files = [
                f for f in os.listdir(dtot) if f.endswith(('.jpg', '.png'))
            ]

            for imfile in img_files:
                ext = imfile.split('.')[-1]
                # Strip the extension by slicing the suffix off.  Splitting
                # on the extension text would cut the name at the first
                # occurrence of e.g. "png" anywhere in the file name.
                # name_root keeps the trailing dot, as before.
                name_root = imfile[:-len(ext)]
                annotate_file = name_root + self.annotation_suffix
                annotate_file_tot = os.path.join(dtot, annotate_file)
                imfile_tot = os.path.join(dtot, imfile)
                outroot = d + '_' + imfile[:-(len(ext) + 1)]
                print("\nName_root", name_root)
                print("   annotate_file:", annotate_file)
                print("  imfile:", imfile)
                print("  imfile_tot:", imfile_tot)
                print("  outroot:", outroot)

                if not run_slice:
                    continue

                if os.path.exists(annotate_file_tot):
                    parse_cowc.slice_im_cowc(
                        imfile_tot,
                        annotate_file_tot,
                        outroot,
                        self.images_dir,
                        self.labels_dir,
                        yolt_cat_dict,
                        cat_list[0],
                        self.yolt_box_size,
                        **slice_kwargs)
                else:
                    parse_cowc.slice_im_no_mask(
                        imfile_tot,
                        outroot,
                        self.images_dir,
                        self.labels_dir,
                        yolt_cat_dict,
                        cat_list[0],
                        self.yolt_box_size,
                        **slice_kwargs)
        ##############################

        ##############################
        # Get list for simrdwn/data/, copy to data dir
        ##############################
        train_ims = [
            os.path.join(self.images_dir, f)
            for f in os.listdir(self.images_dir)
        ]
        # The context manager guarantees the list is flushed and closed
        # before it is copied, even if a write fails.
        with open(self.im_list_name, 'w') as im_list_file:
            for item in train_ims:
                im_list_file.write("%s\n" % item)
        # copy to data dir
        print("Copying", self.im_list_name, "to:", self.simrdwn_data_dir)
        shutil.copy(self.im_list_name, self.simrdwn_data_dir)
        ##############################

        ##############################
        # Ensure labels were created correctly by plotting a few
        ##############################
        max_plots = 50
        thickness = 2
        yolt_data_prep_funcs.plot_training_bboxes(
            self.labels_dir,
            self.images_dir,
            ignore_augment=False,
            sample_label_vis_dir=self.sample_label_vis_dir,
            max_plots=max_plots,
            thickness=thickness,
            ext='.png',
            verbose=True)

        ##############################
        # Make a .tfrecords file
        ##############################
        importlib.reload(preprocess_tfrecords)
        preprocess_tfrecords.yolt_imlist_to_tf(
            self.im_list_name,
            label_map_dict,
            TF_RecordPath=self.tfrecord_train,
            TF_PathVal='',
            val_frac=0.0,
            convert_dict=convert_dict,
            verbose=True)
        # copy train file to data dir
        print("Copying", self.tfrecord_train, "to:", self.simrdwn_data_dir)
        shutil.copy(self.tfrecord_train, self.simrdwn_data_dir)

        ##############################
        # Copy test images to test dir
        print("Copying test images to:", self.test_out_dir)
        for td in self.test_dirs:
            td_tot_in = os.path.join(self.ground_truth_dir, td)
            td_tot_out = os.path.join(self.test_out_dir, td)
            os.makedirs(td_tot_out, exist_ok=True)
            # copy non-label files
            for f in os.listdir(td_tot_in):
                if f.endswith(('.png', '.jpg')) and not f.endswith(
                        ('_Cars.png', '_Negatives.png', '.xcf')):
                    shutil.copy2(os.path.join(td_tot_in, f), td_tot_out)
Ejemplo n.º 4
0
def main():
    """Load a detections csv, augment it, save it, and plot it.

    Steps, in order:
      1. Parse command-line options.
      2. Load the class label map with ``preprocess_tfrecords.load_pbtxt``.
      3. Read ``--df_csv`` (no header row; columns are assigned from a fixed
         list) and replace the integer ``Category`` values with their
         label-map strings.
      4. Augment the dataframe with ``post_process.augment_df`` and write it
         to ``--df_csv_out``, or to ``<outdir>/00_dataframe.csv`` when that
         option is left empty.
      5. If ``--plot_thresh`` is positive, call
         ``post_process.refine_and_plot_df``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--outdir',
                        type=str,
                        default='/cosmiq/simrdwn/tmp/images_ssd',
                        help="Output file location")
    parser.add_argument(
        '--pbtxt_filename',
        type=str,
        default='/cosmiq/simrdwn/data/class_labels_airplane_boat_car.pbtxt',
        help="Class dictionary")
    parser.add_argument('--df_csv', type=str, default='', help="dataframe csv")
    parser.add_argument('--df_csv_out',
                        type=str,
                        default='',
                        help="output dataframe csv")
    parser.add_argument('--verbose',
                        type=int,
                        default=0,
                        help="Print a lot o stuff?")

    # Plotting settings
    parser.add_argument('--slice_val_images',
                        type=int,
                        default=0,
                        help="Switch for if validaion images are sliced")
    parser.add_argument(
        '--plot_thresh',
        type=float,
        default=0.33,
        help="Threshold for plotting boxes, set < 0 to skip plotting")
    parser.add_argument(
        '--nms_overlap_thresh',
        type=float,
        default=0.5,
        help="IOU threshold for non-max-suppresion, skip if < 0")
    parser.add_argument('--make_box_labels',
                        type=int,
                        default=1,
                        help="If 1, make print label above each box")
    parser.add_argument('--scale_alpha',
                        type=int,
                        default=1,
                        help="If 1, scale box opacity with confidence")
    # The help text used to be a copy of the --scale_alpha description.
    parser.add_argument('--plot_line_thickness',
                        type=int,
                        default=1,
                        help="Line thickness used when plotting boxes")

    args = parser.parse_args()
    print("args:", args)
    # The empty default can never be read; fail early with a clear message
    # instead of an I/O error from pandas.
    if not args.df_csv:
        parser.error("--df_csv must be set to an input csv path")
    t0 = time.time()

    header = ['Loc_Tmp', 'Prob', 'Xmin', 'Ymin', 'Xmax', 'Ymax', 'Category']

    # make label_map_dict (key=int, value=str)
    label_map_dict = preprocess_tfrecords.load_pbtxt(args.pbtxt_filename,
                                                     verbose=False)

    # makedirs also creates missing parents and tolerates an existing dir.
    os.makedirs(args.outdir, exist_ok=True)

    # read dataframe
    df_init = pd.read_csv(args.df_csv, names=header)
    # tf_infer_cmd outputs integer categories, update to strings
    df_init['Category'] = [
        label_map_dict[ktmp] for ktmp in df_init['Category'].values
    ]

    # augment dataframe columns
    df = post_process.augment_df(df_init,
                                 valid_testims_dir_tot='',
                                 slice_sizes=[0],
                                 valid_slice_sep='__',
                                 edge_buffer_valid=0,
                                 max_edge_aspect_ratio=4,
                                 valid_box_rescale_frac=1.0,
                                 rotate_boxes=False,
                                 verbose=bool(args.verbose))
    print("df.columns:", df.columns)
    # Guard the sample-row print: iloc[0] raises IndexError on an empty
    # dataframe.
    if len(df) > 0:
        print("df.iloc[0[:", df.iloc[0])

    # Fall back to a file inside outdir when no output path was given;
    # writing to the empty default path would fail.
    outfile_df = args.df_csv_out or os.path.join(args.outdir,
                                                 '00_dataframe.csv')
    df.to_csv(outfile_df)

    # plot
    if args.plot_thresh > 0:
        post_process.refine_and_plot_df(
            df,
            label_map_dict=label_map_dict,
            outdir=args.outdir,
            sliced=bool(args.slice_val_images),
            plot_thresh=args.plot_thresh,
            nms_overlap_thresh=args.nms_overlap_thresh,
            show_labels=args.make_box_labels,
            alpha_scaling=args.scale_alpha,
            plot_line_thickness=args.plot_line_thickness,
            verbose=bool(args.verbose))

    print("Plots output to:", args.outdir)
    print("Time to get and plot records:", time.time() - t0, "seconds")