示例#1
0
def store_value(filepath, key, value):
    """Store ``str(value)`` under ``key`` in the JSON object kept at ``filepath``.

    When ``filepath`` does not exist a fresh, empty object is used; otherwise
    the current contents are loaded with ``io.read_json``. The updated object
    is then written back, replacing the file's previous contents.

    Returns:
        ``True`` unconditionally.
    """
    # Start from an empty mapping for a brand-new file, else extend the old one.
    data = io.read_json(filepath) if os.path.exists(filepath) else {}
    # Values are always persisted as strings.
    data[key] = str(value)
    with open(filepath, "w") as handle:
        json.dump(data, handle)
    return True
示例#2
0
def write_from_json(dataset_path, label_json_fp, label_name, **kwargs):
    """Append labels from a JSON file to the headers of ``*ID*.TXT`` recordings.

    Recordings are collected from ``<dataset_path>/*/<c>/*ID*.TXT`` for ``c``
    in ``(0, 1)``. A recording whose header (comma-separated chunks) already
    contains ``label_name`` is reported and left untouched. Otherwise the
    label stored in the JSON under the recording's ``get_prefix`` key is
    appended to the header as ``",<label_name><label>"``.

    Args:
        dataset_path: Root directory of the dataset.
        label_json_fp: Path of the JSON file mapping recording prefixes to labels.
        label_name: Marker string identifying the label chunk in the header.
        **kwargs: Unused.

    Raises:
        KeyError: If an unlabelled recording has no entry in the JSON file.
        TypeError: If the JSON entry for an unlabelled recording is not an ``int``.
    """
    data = io.read_json(label_json_fp)
    fps = []
    for c in (0, 1):
        fps.extend(glob.glob(os.path.join(dataset_path, "*", str(c), "*ID*.TXT")))
    for fp in fps:
        header = HTPA32x32d.tools.read_txt_header(fp)
        label = None
        for chunk in header.split(","):
            if label_name in chunk:
                # If several chunks match, the last one wins.
                label = int(chunk.split(label_name)[-1])
        # Compare with None explicitly: int(...) may legitimately yield 0, and a
        # truthiness test would treat such a recording as unlabelled and append
        # a second label entry to its header.
        if label is not None:
            print("Label for {} exists, ignoring!".format(fp))
            continue
        new_label = data[get_prefix(fp)]
        # Explicit check instead of `assert` so it is not stripped under `python -O`.
        if type(new_label) is not int:
            raise TypeError(
                "Label for {} must be an int, got {!r}".format(fp, new_label))
        new_header = header + "," + label_name + str(new_label)
        HTPA32x32d.tools.modify_txt_header(fp, new_header)
示例#3
0
K = tf.keras.backend


if __name__ == "__main__":
    # Command-line entry point: the single positional argument names a model
    # configuration in the "settings" directory. Everything from the first
    # ".json" on is dropped, so the suffix may be given or omitted.
    parser = argparse.ArgumentParser()
    parser.add_argument('config_json_name', type=str)
    args = parser.parse_args()
    config_json_name = args.config_json_name.split(".json")[0]
    # Create the training log with a header line only if it does not exist yet.
    if not os.path.exists(TRAIN_LOG_FP):
        with open(TRAIN_LOG_FP, "w") as f:
            dt = datetime.datetime.now()
            f.write("Log started by {} on {} at {}. \n".format(
                config_json_name, dt.date(), dt.time()))
    # 1.
    # Read model configuration and make sure every required key is present.
    model_config = io.read_json(os.path.join(
        "settings", config_json_name+".json"))
    for key in MODEL_CONFIG_KEYS:
        if key not in model_config:
            raise Exception(
                "Key {} not found in model configuration.".format(key))
    # Resolve the setup class by name. On failure, list only the known setup
    # names and suppress the chained KeyError traceback, which adds nothing.
    try:
        Setup = SETUP_DIC[model_config["setup"]]
    except KeyError:
        raise ValueError("Specified setup {} doesn't exist. Implemented setups: {}".format(
            model_config["setup"], list(SETUP_DIC))) from None
    dataset_path, destination_parent_path = model_config[
        'dataset_intermediate_path'], model_config['dataset_processed_parent_path']
    # Processed output lives in <processed parent>/<last path component of dataset_path>.
    processed_develop_path = os.path.join(
        destination_parent_path, os.path.split(dataset_path)[1])
示例#4
0
def label_gui(dataset_path, label_json_fp, label_name, label_batch_size, **kwargs):
    """Interactively label recordings by picking one frame of their RGB sequence.

    For every ``<dataset_path>/*/<c>/*ID*.TXT`` recording (``c`` in ``(0, 1)``)
    whose entry in the label JSON is falsy, the ``*.jpg`` frames in the
    matching ``<absolute prefix>IDRGB`` directory are shown in a matplotlib
    window, starting at the last frame:

    * ``left`` / ``right`` step backwards / forwards, wrapping around;
    * ``space`` selects the current frame and closes the window.

    The selected frame's file name is looked up in ``timesteps.pkl`` (same
    directory). Its position there is stored as the recording's label. At
    most ``label_batch_size`` recordings are labelled per call. The updated
    mapping is then written back to ``label_json_fp``.

    Recordings without frames, and recordings whose window is closed without
    pressing space, are skipped and stay unlabelled.

    Args:
        dataset_path: Root directory of the dataset.
        label_json_fp: JSON file mapping recording prefixes to labels; it is
            read at the start and rewritten at the end.
        label_name: Unused by this function.
        label_batch_size: Maximum number of recordings to label in this call.
        **kwargs: Unused.
    """
    # Viewer state is shared with the key handler through module globals.
    global global_curr_pos
    global global_max_pos
    global global_timesteps
    global global_img_fp_dict
    global global_label_pos
    data = io.read_json(label_json_fp)
    fps = []
    for c in (0, 1):
        fps.extend(glob.glob(os.path.join(dataset_path, "*", str(c), "*ID*.TXT")))
    loop_idx = 0
    for fp in fps:
        prefix = get_prefix(fp)
        # A truthy entry means the recording is already labelled.
        # NOTE(review): a stored label of 0 is falsy, so it would be offered for
        # labelling again -- confirm which value marks "unlabelled" in the JSON.
        if data[prefix]:
            continue
        absolute_prefix = get_absolute_prefix(fp)
        rgb_dir = absolute_prefix + "IDRGB"
        # Frame files are named after their timestep with "-" in place of ".",
        # e.g. "12-34.jpg" -> 12.34.
        global_img_fp_dict = {}
        for img_fp in glob.glob(os.path.join(rgb_dir, "*.jpg")):
            timestep = float(os.path.split(img_fp)[-1].split(".jpg")[0].replace("-", "."))
            global_img_fp_dict[timestep] = img_fp
        global_timesteps = sorted(global_img_fp_dict)
        if not global_timesteps:
            print("No frames found in {}, skipping!".format(rgb_dir))
            continue

        global_max_pos = len(global_timesteps) - 1
        global_curr_pos = global_max_pos
        global_label_pos = None

        ax = plt.gca()
        fig = plt.gcf()

        def show_current_frame():
            # Draw the frame at global_curr_pos and show the position in the title.
            # Figure.canvas.set_window_title was removed from matplotlib; the
            # figure manager's method is the supported call.
            manager = fig.canvas.manager
            if manager is not None:
                manager.set_window_title(str(global_curr_pos))
            ax.cla()
            # Reverse the channel order: cv2.imread yields BGR, imshow expects RGB.
            ax.imshow(cv2.imread(global_img_fp_dict[global_timesteps[global_curr_pos]])[:, :, ::-1])
            # ax.cla() resets the ticks, so hide them again after every redraw.
            ax.set_xticks([])
            ax.set_yticks([])

        def key_event(e):
            global global_curr_pos
            global global_label_pos
            if e.key == " ":
                global_label_pos = global_curr_pos
                plt.close()
                return
            n_frames = global_max_pos + 1
            if e.key == "right":
                global_curr_pos = (global_curr_pos + 1) % n_frames
            elif e.key == "left":
                global_curr_pos = (global_curr_pos - 1) % n_frames
            show_current_frame()
            fig.canvas.draw()

        show_current_frame()
        fig.canvas.mpl_connect('key_press_event', key_event)
        plt.show()

        # Window closed without a selection: leave this recording unlabelled
        # instead of crashing (which would also lose every label in the batch).
        if global_label_pos is None:
            print("No frame selected for {}, skipping!".format(absolute_prefix))
            continue

        # NOTE(review): pickle.load can execute arbitrary code; only use on trusted data.
        with open(os.path.join(rgb_dir, "timesteps.pkl"), "rb") as handle:
            timesteps_pkl = pickle.load(handle)
        # timesteps.pkl presumably holds the frame file names in order. Rebuild
        # the selected name (2 decimals, "-" for ".") and use its index as the label.
        selected_name = "{0:.2f}".format(global_timesteps[global_label_pos]).replace(".", "-") + ".jpg"
        real_label = timesteps_pkl.index(selected_name)
        print("Labeling {} as {} at {}!".format(absolute_prefix, real_label, timesteps_pkl[real_label]))
        data[prefix] = real_label
        loop_idx += 1
        if loop_idx >= label_batch_size:
            break
    with open(label_json_fp, "w") as f:
        json.dump(data, f)