Example #1
def select_optimal_rho(item, rhos, input_file, output_file, input_file_n,
                       cubes_d, points_numbers_d, cube_positions_d, scale,
                       cube_size, res):
    for i, rho in enumerate(rhos):
        print('===== select rho =====')
        postprocess(output_file, cubes_d, points_numbers_d, cube_positions_d,
                    scale, cube_size, rho)
        results = pc_error(input_file,
                           output_file,
                           input_file_n,
                           res,
                           show=False)
        PSNR = float(results[item])
        print('===== results: ', i, rho, item, PSNR)
        if i == 0:
            # seed the running maximum with the first measured PSNR
            # (seeding with 0 would silently drop the first candidate)
            MAX_PSNR = PSNR
            optimal_rho = rho
        else:
            MAX_PSNR = max(PSNR, MAX_PSNR)

        # stop once PSNR falls below the best value seen so far
        if PSNR < MAX_PSNR:
            break
        else:
            optimal_rho = rho

    return optimal_rho
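
A hedged usage sketch for the function above: the decoded cubes and their metadata are assumed to already exist in scope, and the metric key 'mseF,PSNR (p2point)' is an assumption about what pc_error reports; all paths and values are illustrative only.

# Hypothetical inputs; rhos are swept upward until PSNR starts to drop.
rhos = [1.0, 1.05, 1.1, 1.15, 1.2]
rho_d1 = select_optimal_rho('mseF,PSNR (p2point)', rhos,
                            'longdress.ply', 'longdress_rec.ply',
                            'longdress.ply', cubes_d, points_numbers_d,
                            cube_positions_d, scale=1.0, cube_size=64,
                            res=1024)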
Example #2
    def inference(self, img):
        img_info = {"id": 0}
        if isinstance(img, str):
            img_info["file_name"] = os.path.basename(img)
            img = cv2.imread(img)
            if img is None:
                raise ValueError("test image path is invalid!")
        else:
            img_info["file_name"] = None

        height, width = img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = img

        img, ratio = preprocess(img, self.test_size, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        img = F.expand_dims(mge.tensor(img), 0)

        t0 = time.time()
        outputs = self.model(img)
        outputs = postprocess(outputs, self.num_classes, self.confthre,
                              self.nmsthre)
        logger.info("Infer time: {:.4f}s".format(time.time() - t0))
        return outputs, img_info
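
A brief usage sketch, assuming a predictor object (hypothetical name) exposing the method above:

# inference() accepts either an image path or a decoded ndarray; with an
# ndarray, img_info["file_name"] stays None.
outputs, img_info = predictor.inference("demo.jpg")
print(img_info["height"], img_info["width"], img_info["ratio"])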
Example #3
def e2e(s):
    # make the input space-delimited in prefix notation
    ir = postprocess(infix_to_prefix(preprocess(s)))
    # split on space and turn into nested tuples
    tup = tuple_for_polish_expression(ir.split(' '))
    # convert to MRS and return
    return prettyUMRSForTuple(tup)
Example #4
def _main_():
    args = argparser.parse_args()
    config_path = args.config
    with open(config_path) as config_buffer:
        config = json.loads(config_buffer.read())
    weights_path = args.weights_path
    sm_model = SMModel(config['model'])
    sm_model.model.summary()
    sm_model.model.load_weights(weights_path)
    test_generator = DataGenerator(config=config['test'],
                                   preprocessing=sm_model.preprocessing,
                                   n_class=sm_model.n_class,
                                   split='test')
    encoded_pixels = []
    image_id_class_id = []
    for X, filenames in tqdm(list(test_generator)):
        preds = sm_model.model.predict_on_batch(X)
        if config['test']['tta']:
            for flip_type in ['ud', 'lr', 'udlr']:
                X_temp = flip(X.copy(), flip_type)
                pred_temp = sm_model.model.predict_on_batch(X_temp)
                preds += flip(pred_temp, flip_type)
            preds /= 4
        preds = postprocess(preds, config['postprocess'], True)
        for i in range(len(preds)):
            for j in range(4):
                encoded_pixels.append(run_length_encode(preds[i, :, :, j]))
                image_id_class_id.append(filenames[i] + '_{}'.format(j + 1))
    df = pd.DataFrame(data=encoded_pixels,
                      index=image_id_class_id,
                      columns=['EncodedPixels'])
    df.index.name = 'ImageId_ClassId'
    df.to_csv('submission.csv')
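
The TTA loop above relies on a flip helper that is not shown in this snippet; a plausible sketch, assuming batches shaped (N, H, W, C) and that applying the same flip twice restores the original orientation:

def flip(batch, flip_type):
    # 'ud' flips height, 'lr' flips width, 'udlr' flips both; each flip
    # is self-inverse, which is why the loop can map predictions back
    # with the same flip_type it used on the inputs.
    if flip_type == 'ud':
        return batch[:, ::-1, :, :]
    if flip_type == 'lr':
        return batch[:, :, ::-1, :]
    if flip_type == 'udlr':
        return batch[:, ::-1, ::-1, :]
    return batch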
Example #5
def _main_():
    args = argparser.parse_args()
    config_path = args.config
    with open(config_path) as config_buffer:
        config = json.loads(config_buffer.read())
    sm_model = SMModel(config['model'])
    oof_preds = []
    oof_true_masks = []
    for i in range(5):
        config['train']['fold'] = i
        generator = DataGenerator(config=config['train'], 
                                  preprocessing=sm_model.preprocessing,
                                  n_class=sm_model.n_class, 
                                  split='val', 
                                  full_size_mask=True)
        weights_path = os.path.join(config['train']['save_model_folder'],
                                    'val_best_fold_{}_weights.h5'.format(i))
        sm_model.model.load_weights(weights_path)
        print('Fold {} eval begin.'.format(i))
        for X, y in tqdm(list(generator)):
            y_preds = sm_model.model.predict(X)
            y_preds = postprocess(y_preds, config['postprocess'], True)
            oof_preds.append(y_preds)
            y = y[:, :, :, :4]
            oof_true_masks.append(y)
    oof_preds = np.concatenate(oof_preds)
    oof_true_masks = np.concatenate(oof_true_masks)
    
    cv_dice_coef = dice_coef_score(oof_true_masks, oof_preds)
    print('CV Dice Coef Score: {}'.format(cv_dice_coef))
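
A minimal sketch of the dice_coef_score used above, assuming binary masks shaped (N, H, W, C) and the same epsilon-smoothed Dice formula as the epoch-end callback in Example #7:

import numpy as np

def dice_coef_score(y_true, y_pred, eps=1e-7):
    # intersection and union per sample and per class, summed over H and W
    inter = (y_true * y_pred).sum(axis=(1, 2))
    union = y_true.sum(axis=(1, 2)) + y_pred.sum(axis=(1, 2))
    return float(np.mean((2 * inter + eps) / (union + eps)))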
Example #6
def save_prediction(pred, input_file, tile, path, scale=False):
    fname = input_file.split('/')[-1]
    sample = fname.split('_')[0]
    path = os.path.join(path, sample)
    os.makedirs(path, exist_ok=True)
    shape = util.shape(input_file)
    header = util.header(input_file)
    vol = process.postprocess(pred, shape, resize=True, tile=tile)
    util.save_vol(vol, os.path.join(path, fname), header, scale)
    print(fname, flush=True)
Example #7
    def on_epoch_end(self, epoch, logs=None):
        # Keras passes the logs dict in; fall back to a fresh one when
        # called standalone (a mutable default argument was a bug risk).
        if logs is None:
            logs = {}
        dice_coef = 0
        for X, y_true in list(self.generator):
            y_pred = self.model.predict(X)
            y_pred = postprocess(y_pred, self.postprocess_config)
            y_true = y_true[:, :, :, :4]
            # intersection and union per sample and per class over H and W
            inter = (y_true * y_pred).sum(1).sum(1)
            union = y_true.sum(1).sum(1) + y_pred.sum(1).sum(1)
            dice_coef_batch = (2 * inter + self.eps) / (union + self.eps)
            dice_coef += dice_coef_batch.sum()
        dice_coef /= (self.generator.n_samples * 4)
        logs.update({'dice_coef_score': dice_coef})
        print('Epoch {} Validation Dice Coefficient Score: {}.\n'.format(epoch + 1, dice_coef))
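
A hedged wiring sketch for the method above; the class name DiceCoefCallback and its constructor are assumptions, not from the source:

from tensorflow.keras.callbacks import Callback

class DiceCoefCallback(Callback):
    # generator yields (X, y_true) validation batches; self.model is
    # attached by Keras when the callback is registered with fit().
    def __init__(self, generator, postprocess_config, eps=1e-7):
        super().__init__()
        self.generator = generator
        self.postprocess_config = postprocess_config
        self.eps = eps
    # on_epoch_end would be the method shown in this example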
Example #8
def save_predictions(preds, generator, path, scale=False):
    os.makedirs(path, exist_ok=True)

    if generator.tile_inputs:
        preds = np.reshape(preds, (-1, 8) + preds.shape[1:])

    for i in range(preds.shape[0]):
        input_file = generator.input_files[i]
        fname = input_file.split('/')[-1]
        header = util.header(input_file)
        shape = util.shape(input_file)
        volume = process.postprocess(preds[i],
                                     shape,
                                     resize=True,
                                     tile=generator.tile_inputs)
        util.save_vol(volume, os.path.join(path, fname), header, scale)
Example #9
def test():
    initial_time = datetime.now()
    category = 'restaurants_bars'

    model_to_test = ''
    if request.args.get('model'):
        model_to_test = 'camembert'

    print('\n', '#'*50)
    print(f' Start Analysis on {category} '.center(50, '#'))
    print('#'*50, '\n')

    init_time = datetime.now()
    print(' Start scraping '.center(30, '#'))
    time_elapsed = datetime.now() - init_time
    print(f'Scraping time : {time_elapsed}')

    init_time = datetime.now()
    print(' Start preprocess '.center(30, '#'))
    time_elapsed = datetime.now() - init_time
    print(f'Preprocess time : {time_elapsed}')

    # 4. Predict sentiment and add it to dataframe
    init_time = datetime.now()
    print(' Start prediction '.center(30, '#'))
    if model_to_test == 'camembert':
        df = pd.read_csv(f'predict_{category}_cam.csv')
    else:
        df = pd.read_csv(f'predict_{category}.csv')
    refs = [df.site.unique(), df.site.unique()]
    time_elapsed = datetime.now() - init_time
    print(f'Prediction time : {time_elapsed}')

    # 5. Apply postprocess to transform data into json
    init_time = datetime.now()
    print(' Start postprocess '.center(30, '#'))
    json_review = process.postprocess(df, refs)
    time_elapsed = datetime.now() - init_time
    print(f'Postprocess time : {time_elapsed}')
    time_elapsed = datetime.now() - initial_time
    print(f'Total time elapsed : {time_elapsed}')

    return json_review
Example #10
def eval(input_file, rootdir, resolution, mode, cube_size, modelname,
         fixed_thres, postfix):
    # model = 'model_voxception'
    model = importlib.import_module(modelname)

    filename = os.path.split(input_file)[-1][:-4]
    output_file = filename + '_rec_' + postfix + '.ply'
    input_file_n = input_file
    csv_rootdir = os.path.join(rootdir, "csv")
    if not os.path.exists(csv_rootdir):
        os.makedirs(csv_rootdir)
    cfg_rootdir = os.path.join(rootdir, "cfg")
    if not os.path.exists(cfg_rootdir):
        os.makedirs(cfg_rootdir)
    # default config
    config, config_file = set_default_config(input_file, cfg_rootdir,
                                             resolution, mode, cube_size,
                                             modelname)
    cube_size = config.getint('DEFAULT', 'cube_size')
    min_num = config.getint('DEFAULT', 'min_num')
    res = config.getint('DEFAULT', 'resolution')
    print('cube size:', cube_size, 'min num:', min_num, 'res:', res)

    for index, rate in enumerate(config.sections()):
        scale = float(config.get(rate, 'scale'))
        ckpt_dir = str(config.get(rate, 'ckpt_dir'))
        print('====================', 'config:', rate, 'scale:', scale,
              'ckpt_dir:', ckpt_dir)

        if mode == "factorized":
            cubes_d, cube_positions, points_numbers, N, bpps = test_factorized(
                input_file, model, ckpt_dir, scale, cube_size, min_num,
                postfix)
        elif mode == "hyper":
            cubes_d, cube_positions, points_numbers, N, bpps = test_hyper(
                input_file, model, ckpt_dir, scale, cube_size, min_num,
                postfix)
        cubes_d = cubes_d.numpy()
        print("bpp:", bpps[0])
        # select rho for optimal d1/d2 metrics.
        if fixed_thres is None:
            rho_d1, rho_d2 = cfg_post_process(config, config_file, rate,
                                              input_file, output_file,
                                              input_file_n, cubes_d,
                                              points_numbers, cube_positions,
                                              scale, cube_size, res)
        else:
            rho_d1, rho_d2 = 1.0, 1.0

        # metrics.
        rho = 1.0
        postprocess(output_file, cubes_d, points_numbers, cube_positions,
                    scale, cube_size, rho, fixed_thres)
        results = pc_error(input_file,
                           output_file,
                           input_file_n,
                           res,
                           show=False)

        rho = rho_d1
        postprocess(output_file, cubes_d, points_numbers, cube_positions,
                    scale, cube_size, rho, fixed_thres)
        results_d1 = pc_error(input_file,
                              output_file,
                              input_file_n,
                              res,
                              show=False)

        rho = rho_d2
        postprocess(output_file, cubes_d, points_numbers, cube_positions,
                    scale, cube_size, rho, fixed_thres)
        results_d2 = pc_error(input_file,
                              output_file,
                              input_file_n,
                              res,
                              show=False)

        results = collect_results(results, results_d1, results_d2, bpps, N,
                                  scale, rho_d1, rho_d2)

        if index == 0:
            all_results = results.copy(deep=True)
        else:
            # DataFrame.append was removed in pandas 2.0
            all_results = pd.concat([all_results, results], ignore_index=True)

    # write to csv
    print(all_results)
    if not os.path.exists(csv_rootdir):
        os.makedirs(csv_rootdir)
    csv_name = os.path.join(csv_rootdir, filename + '.csv')
    all_results.to_csv(csv_name, index=False)
    # plot
    plot_results(all_results, filename, csv_rootdir)
    return all_results
Example #11
def oof_eval(oof_preds, oof_true_masks, config):
    oof_preds = postprocess(oof_preds, config['postprocess'], True)
    cv_dice_coef = dice_coef_score(oof_true_masks, oof_preds)
    print('CV Dice Coef Score: {}'.format(cv_dice_coef))
    return cv_dice_coef
Example #12
def eval(input_file, rootdir, cfgdir, res, mode, cube_size, modelname, fixed_thres, postfix):
    # model = 'model_voxception'
    model = importlib.import_module(modelname)

    filename = os.path.split(input_file)[-1][:-4]
    input_file_n = input_file    
    csv_rootdir = rootdir
    if not os.path.exists(csv_rootdir):
        os.makedirs(csv_rootdir)
    csv_name = os.path.join(csv_rootdir, filename + '.csv')

    config = configparser.ConfigParser()
    config.read(cfgdir)

    cube_size = config.getint('DEFAULT', 'cube_size')
    min_num = config.getint('DEFAULT', 'min_num')
    print('cube size:', cube_size, 'min num:', min_num, 'res:', res)

    for index, rate in enumerate(config.sections()):
        scale = float(config.get(rate, 'scale'))
        ckpt_dir = str(config.get(rate, 'ckpt_dir'))
        rho_d1 = float(config.get(rate, 'rho_d1'))
        rho_d2 = float(config.get(rate, 'rho_d2'))
        print('=' * 80, '\n', 'config:', rate, 'scale:', scale,
              'ckpt_dir:', ckpt_dir, 'rho_d1:', rho_d1, 'rho_d2:', rho_d2)

        if mode == "factorized":
            cubes_d, cube_positions, points_numbers, N, bpps = test_factorized(
                input_file, model, ckpt_dir, scale, cube_size, min_num, postfix)
        elif mode == "hyper":
            cubes_d, cube_positions, points_numbers, N, bpps = test_hyper(
                input_file, model, ckpt_dir, scale, cube_size, min_num, postfix)
        cubes_d = cubes_d.numpy()
        print("bpp:", bpps[0])

        # metrics.
        rho = 1.0
        output_file = filename + '_rec_' + str(rate) + '_' + 'rho' + str(round(rho*100)) + postfix + '.ply'
        postprocess(output_file, cubes_d, points_numbers, cube_positions, scale, cube_size, rho, fixed_thres)
        results = pc_error(input_file, output_file, input_file_n, res, show=False)

        rho = rho_d1
        output_file = filename + '_rec_' + str(rate) + '_' + 'rho' + str(round(rho*100)) + postfix + '.ply'
        postprocess(output_file, cubes_d, points_numbers, cube_positions, scale, cube_size, rho, fixed_thres)
        results_d1 = pc_error(input_file, output_file, input_file_n, res, show=False)

        rho = rho_d2
        output_file = filename + '_rec_' + str(rate) + '_' + 'rho' + str(round(rho*100)) + postfix + '.ply'
        postprocess(output_file, cubes_d, points_numbers, cube_positions, scale, cube_size, rho, fixed_thres)
        results_d2 = pc_error(input_file, output_file, input_file_n, res, show=False)
         
        results = collect_results(results, results_d1, results_d2, bpps, N, scale, rho_d1, rho_d2)

        if index == 0:
            all_results = results.copy(deep=True)
        else:
            all_results = pd.concat([all_results, results], ignore_index=True)

        all_results.to_csv(csv_name, index=False)

    print(all_results)
    plot_results(all_results, filename, csv_rootdir)

    return all_results
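
For reference, a sketch of the config file this eval expects, written via configparser; the section and key names are inferred from the config.get calls above, and every value is illustrative only:

import configparser

cfg = configparser.ConfigParser()
cfg['DEFAULT'] = {'cube_size': '64', 'min_num': '64'}
cfg['R1'] = {'scale': '0.375', 'ckpt_dir': './checkpoints/r1',
             'rho_d1': '1.0', 'rho_d2': '1.0'}
with open('eval.cfg', 'w') as f:
    cfg.write(f)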
Example #13
def graphs():
    initial_time = datetime.now()
    # 1. Get infos for scraping
    # Mandatory argument : Category
    category = request.args.get('category')

    # Optional argument
    # number of site to scrape per category, default 5 (0 for max)
    if request.args.get('num_of_site'):
        num_of_site = int(request.args.get('num_of_site'))
    else:
        num_of_site = 5
    # number of page to scrape per site, default 2 (0 for max)
    if request.args.get('num_page'):
        num_page = int(request.args.get('num_page'))
    else:
        num_page = 2
    # city where the scraping is desired (better with department number)
    if request.args.get('location'):
        location = request.args.get('location')
    else:
        location = 'no city'
    # model to use for scraping (one option 'camembert', else default model)
    model_to_test = ''
    if request.args.get('model'):
        model_to_test = 'camembert'

    print('\n', '#'*50)
    print(f' Start Analysis on {category} '.center(50, '#'))
    print('#'*50, '\n')

    # 2. Scrape trustpilot to get dataframe
    init_time = datetime.now()
    print(' Start scraping '.center(30, '#'))
    refs, df = scraping.scrape(category, location, num_of_site, num_page)
    time_elapsed = datetime.now() - init_time
    print(f'Scraping time : {time_elapsed}')
    if len(df) > 0:
        # 3. Preprocess dataframe before prediction
        init_time = datetime.now()
        print(' Start preprocess '.center(30, '#'))
        df = process.preprocess_df(df)
        time_elapsed = datetime.now() - init_time
        print(f'Preprocess time : {time_elapsed}')

        # 4. Predict sentiment and add it to dataframe
        init_time = datetime.now()
        print(' Start prediction '.center(30, '#'))
        if model_to_test == 'camembert':
            df = model.predict_camembert(df)
        else:
            df = model.predict(df)
        time_elapsed = datetime.now() - init_time
        print(f'Prediction time : {time_elapsed}')

        # 5. Apply postprocess to transform data into json
        init_time = datetime.now()
        print(' Start postprocess '.center(30, '#'))
        json_review = process.postprocess(df, refs)
        time_elapsed = datetime.now() - init_time
        print(f'Postprocess time : {time_elapsed}')
    else:
        print("No data found")
        json_review = "<h1>No data</h1>"
    time_elapsed = datetime.now() - initial_time
    print(f'Total time elapsed : {time_elapsed}')
    return json_review
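
A hedged client sketch, assuming the view above is mounted at a hypothetical /graphs route:

import requests

resp = requests.get('http://localhost:5000/graphs',
                    params={'category': 'restaurants_bars',
                            'num_of_site': 3, 'num_page': 1,
                            'location': 'Paris'})
print(resp.text[:200])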
Example #14
def predict():
    input_data = request.get_json(force=True)
    transformed_input_data = preprocess(input_data)
    prediction = model.predict(transformed_input_data)
    transformed_prediction = postprocess(prediction)
    return jsonify({"prediction": transformed_prediction})
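
A client-side sketch for exercising the endpoint above, assuming it is registered under a hypothetical /predict route and that the model accepts a flat feature mapping:

import requests

payload = {'feature_a': 1.0, 'feature_b': 2.0}  # hypothetical features
resp = requests.post('http://localhost:5000/predict', json=payload)
print(resp.json()['prediction'])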
Example #15
def plot_interface(pltid):
    user = root.authorized()
    app = request.query.app
    cid = request.query.cid
    jid = request.query.jid
    params = dict()

    if not cid:
        params['err'] = "No case id specified. First select a case id from the list of jobs."
        return template('error', params)

    if re.search("/", cid):
        (owner, c) = cid.split("/")
    else:
        owner = user
        c = cid

    shared = jobs(cid=c).shared
    # only allow admin to see other user's cases that have not been shared
    if owner != user and shared != "True" and user != "admin":
        return template('error', err="access forbidden")

    inputs, _, _ = root.myapps[app].read_params(owner, c)
    sim_dir = os.path.join(user_dir, owner, app, c)

    # use pltid of 0 to trigger finding the first pltid for the current app
    if int(pltid) == 0:
        query = (apps.id == plots.appid) & (apps.name == app)
        result = db(query).select().first()
        if result: pltid = result['plots']['id']

    p = Plot()

    # get the data for the pltid given
    try:
        result = db(plots.id == pltid).select().first()
        plottype = result['ptype']

        plot_title = result['title']
    except:
        exc_type, exc_value, exc_traceback = sys.exc_info()
        traceback.print_exception(exc_type, exc_value, exc_traceback)
        redirect('/plots/edit?app=' + app + '&cid=' + cid)

    # if plot not in DB return error
    if plottype is None:
        params = {'cid': cid, 'app': app, 'user': user}
        params['err'] = "Sorry! This app does not support plotting capability"
        return template('error', params)

    # determine which view template to use
    if plottype == 'flot-cat':
        tfn = 'plots/flot-cat'
    elif plottype == 'flot-scatter':
        tfn = 'plots/flot-scatter'
    elif plottype == 'flot-scatter-animated':
        tfn = 'plots/flot-scatter-animated'  # for backwards compatibility
    elif plottype == 'flot-line':
        tfn = 'plots/flot-scatter'
    elif plottype == 'plotly-hist':
        tfn = 'plots/plotly-hist'
    elif plottype == 'mpl-line' or plottype == 'mpl-bar':
        redirect('/mpl/' + pltid + '?app=' + app + '&cid=' + cid)
    elif plottype == 'handson':
        tfn = 'plots/handson'
    elif plottype == 'flot-3d':
        return plot_flot_3d(result, cid, app, sim_dir, owner, user, plot_title,
                            pltid)
    else:
        return template("error", err="plot type not supported: " + plottype)

    if result['options']:
        options = replace_tags(result['options'], inputs)
    else:
        options = ''

    # get list of all plots for this app
    query = (apps.id == plots.appid) & (apps.name == app)
    list_of_plots = db(query).select()

    # extract data from files
    data = []
    ticks = []
    plotpath = ''
    result = db(datasource.pltid == pltid).select()

    datadef = ""
    for r in result:
        plotfn = r['filename']

        # in addition to supporting input params, also support case id
        if "cid" not in inputs: inputs["cid"] = c

        # replace <cid>.dat with xyz123.dat
        plotfn = replace_tags(plotfn, inputs)
        plotpath = os.path.join(sim_dir, plotfn)

        # handle CSV data
        _, file_extension = os.path.splitext(plotfn)
        if file_extension == '.csv':
            data = p.get_csv_data(plotpath)
            stats = ''

        # handle X, Y columnar data
        else:
            cols = r['cols']
            line_range = r['line_range']
            try:
                datadef += r['data_def'] + ", "
            except:
                exc_type, exc_value, exc_traceback = sys.exc_info()
                traceback.print_exception(exc_type, exc_value, exc_traceback)
                datadef = ""

            if cols.find(":") > 0:  # two columns
                num_fields = 2
                (col1str, col2str) = cols.split(":")
                col1 = int(col1str)
                col2 = int(col2str)
            else:  # single column
                num_fields = 1
                col1 = int(cols)

            # do some postprocessing
            if line_range is not None:
                # to prevent breaking current spc apps, still support
                # expressions like 1:1000, but in the future this should
                # be changed to a range 1-1000.  Therefore, using : is deprecated
                # and will be removed in the future.
                (line1str, line2str) = re.split("[-:]", line_range)
                line1 = int(line1str)
                ## there is a problem with the following statement
                ## shows up in mendel app
                # if root.myapps[app].postprocess > 0:
                #    dat = process.postprocess(plotpath, line1, line2)
                # else:
                try:  # if line2 is specified
                    line2 = int(line2str)
                    dat = p.get_data(plotpath, col1, col2, line1, line2)
                except:  # if line2 not specified
                    exc_type, exc_value, exc_traceback = sys.exc_info()
                    traceback.print_exception(exc_type, exc_value, exc_traceback)
                    if num_fields == 2:
                        dat = p.get_data(plotpath, col1, col2, line1)
                    else:  # single column of data
                        dat = p.get_data(plotpath, col1)
                # remove this app-specific code in future
                if app == "fpg":
                    import process
                    dat = process.postprocess(plotpath, line1, line2)
            else:
                dat = p.get_data(plotpath, col1, col2)

            if dat == -1:
                stats = "ERROR: Could not read data file"
            elif dat == -2:
                stats = "ERROR: file exists, but problem parsing data. Are column and line ranges setup properly? Is all the data there?"
            else:
                stats = compute_stats(plotpath)
            # [[1,2,3]] >>> [1,2,3]

            # clean data
            #dat = [d.replace('?', '0') for d in dat]
            data.append(dat)

            if num_fields == 1: data = data[0]

            if plottype == 'flot-cat':
                ticks = p.get_ticks(plotpath, col1, col2)

    desc = jobs(cid=c).description

    params = {
        'cid': cid,
        'pltid': pltid,
        'data': data,
        'app': app,
        'user': user,
        'owner': owner,
        'ticks': ticks,
        'plot_title': plot_title,
        'plotpath': plotpath,
        'rows': list_of_plots,
        'options': options,
        'datadef': datadef,
        'stats': stats,
        'description': desc
    }

    if jid: params['jid'] = jid

    return template(tfn, params)
Example #16
            strings, min_v, max_v, shape = compress_factorized(cubes, model, args.ckpt_dir)
            if not args.output:
                args.output = os.path.split(args.input)[-1][:-4]
                rootdir = './compressed'
            else:
                rootdir, args.output = os.path.split(args.output)
            bytes_strings, bytes_pointnums, bytes_cubepos = write_binary_files_factorized(
                args.output, strings.numpy(), points_numbers, cube_positions, min_v.numpy(), max_v.numpy(), shape.numpy(), rootdir=rootdir)

        elif args.command == "decompress":
            rootdir, filename = os.path.split(args.input)
            if not args.output:
                args.output = filename + "_rec.ply"
            strings_d, points_numbers_d, cube_positions_d, min_v_d, max_v_d, shape_d = read_binary_files_factorized(filename, rootdir)
            cubes_d = decompress_factorized(strings_d, min_v_d, max_v_d, shape_d, model, args.ckpt_dir)
            postprocess(args.output, cubes_d.numpy(), points_numbers_d, cube_positions_d, args.scale, args.cube_size, args.rho)
    
    if args.mode == "hyper":
        if args.command == "compress":
            if not args.output:
                args.output = os.path.split(args.input)[-1][:-4]
                rootdir = './compressed'
            else:
                rootdir, args.output = os.path.split(args.output)

            cubes, cube_positions, points_numbers = preprocess(args.input, args.scale, args.cube_size, args.min_num)
 
            y_strings, y_min_vs, y_max_vs, y_shape, z_strings, z_min_v, z_max_v, z_shape = compress_hyper(cubes, model, args.ckpt_dir)

            bytes_strings, bytes_strings_head, bytes_strings_hyper, bytes_pointnums, bytes_cubepos = write_binary_files_hyper(
                args.output, y_strings.numpy(), z_strings.numpy(), points_numbers, cube_positions,