示例#1
0
def train(fn_train, cell_max_min, cell_resolution):
    """
    @params: [fn_train, cell_max_min, cell_resolution]
    @returns: []
    Fits the velocity 3D BHM frame by frame for the first args.num_frames
    frames of fn_train (args is read from module scope). Each frame is read
    with utils_filereader.read_frame_velocity and handed to
    train_methods.train_velocity; nothing is returned.
    """
    print('\nTraining started---------------')
    # Fixed hyper-parameters forwarded unchanged to train_velocity.
    alpha = 10**-2
    beta = 10**2
    for frame_idx in range(args.num_frames):
        (inputs, vel_x, vel_y, vel_z,
         parts) = utils_filereader.read_frame_velocity(args, frame_idx,
                                                       fn_train, cell_max_min)
        train_methods.train_velocity(
            args, alpha, beta,
            inputs, vel_x, vel_y, vel_z, parts,
            cell_resolution, cell_max_min, frame_idx,
        )
        # Drop this frame's references before the next frame is loaded.
        del inputs, vel_x, vel_y, vel_z, parts

    print('Training completed---------------\n')
示例#2
0
def query(fn_train, cell_max_min, cell_resolution=None):
    """
    @params: [fn_train, cell_max_min, cell_resolution]
    @returns: []
    Queries the velocity 3D BHM on each of the first args.num_frames frames of
    fn_train (args is read from module scope), delegating each frame to
    query_methods.query_velocity.

    cell_resolution used to be read only from a module-level global of the
    same name. It can now be passed explicitly. When it is omitted (None), the
    module-level global is used, so existing callers behave as before.
    """
    print('Querying started---------------')
    for framei in range(args.num_frames):
        if cell_resolution is None:
            # Backward compatibility: fall back to the module-level global this
            # function used to read implicitly (KeyError if it is not defined).
            cell_resolution = globals()['cell_resolution']
        X, y_vx, y_vy, y_vz, partitions = utils_filereader.read_frame_velocity(
            args, framei, fn_train, cell_max_min)
        print(
            f"cell_resolution: {cell_resolution}, args.query_dist: {args.query_dist}"
        )
        query_methods.query_velocity(args, X, y_vx, y_vy, y_vz, partitions,
                                     cell_resolution, cell_max_min, framei)
        # Release this frame's data before loading the next one.
        del X, y_vx, y_vy, y_vz, partitions

    print('Querying completed---------------\n')
示例#3
0
def train(fn_train, cell_max_min, cell_resolution):
    """
    @params: [fn_train, cell_max_min, cell_resolution]
    @returns: []
    Fits the 3D BHM on each of the first args.num_frames frames of fn_train
    (args is read from module scope), dispatching on args.model_type:
      * 'occupancy'  -> train_methods.train_occupancy
      * 'regression' -> train_methods.train_regression
      * 'velocity'   -> train_methods.train_velocity
    Raises ValueError (on the first frame) for any other args.model_type.
    """
    print('\nTraining started---------------')
    # Fixed hyper-parameters forwarded to the regression and velocity trainers.
    alpha = 10**-2
    beta = 10**2
    for framei in range(args.num_frames):
        # One dispatch per frame: read, train and release the frame in the
        # same branch (this removes the former unreachable trailing else).
        if args.model_type in ("occupancy", "regression"):
            g, X, y_occupancy, sigma, partitions = utils_filereader.read_frame(
                args, framei, fn_train, cell_max_min)
            if args.model_type == 'occupancy':
                train_methods.train_occupancy(args, partitions, cell_resolution, X,
                                              y_occupancy, sigma, framei)
            else:
                # Regression uses only the first two columns of sigma; this is
                # hard coded here in the call to train_regression.
                train_methods.train_regression(args, alpha, beta, cell_resolution,
                                               cell_max_min, X, y_occupancy, g,
                                               sigma[:, :2], framei)
            del g, X, y_occupancy, sigma, partitions
        elif args.model_type == "velocity":
            X, y_vx, y_vy, y_vz, partitions = utils_filereader.read_frame_velocity(
                args, framei, fn_train, cell_max_min)
            train_methods.train_velocity(args, alpha, beta, X, y_vx, y_vy,
                                         y_vz, partitions, cell_resolution,
                                         cell_max_min, framei)
            del X, y_vx, y_vy, y_vz, partitions
        else:
            raise ValueError("Unknown model type: \"{}\"".format(
                args.model_type))
    print('Training completed---------------\n')
示例#4
0
def query(fn_train, cell_max_min, cell_resolution=None):
    """
    @params: [fn_train, cell_max_min, cell_resolution]
    @returns: []
    Queries the 3D BHM on each of the first args.num_frames frames of fn_train
    (args is read from module scope), dispatching on args.model_type:
      * 'occupancy'  -> query_methods.query_occupancy
      * 'regression' -> query_methods.query_regression
      * 'velocity'   -> query_methods.query_velocity
    Raises ValueError (on the first frame) for any other args.model_type.

    cell_resolution is only needed by the velocity model. It used to be read
    only from a module-level global of the same name. It can now be passed
    explicitly. When it is omitted (None), the global is looked up lazily in
    the velocity branch, so existing callers behave as before.
    """
    print('Querying started---------------')
    for framei in range(args.num_frames):
        # One dispatch per frame: read, query and release the frame in the
        # same branch (this removes the former unreachable trailing else).
        if args.model_type in ("occupancy", "regression"):
            g, X, y_occupancy, sigma, partitions = utils_filereader.read_frame(
                args, framei, fn_train, cell_max_min)
            if args.model_type == 'occupancy':
                query_methods.query_occupancy(args, cell_max_min, partitions, X,
                                              y_occupancy, framei)
            else:
                query_methods.query_regression(args, cell_max_min, X, y_occupancy,
                                               g, framei)
            # sigma is read with the frame but not used for querying.
            del g, X, y_occupancy, sigma, partitions
        elif args.model_type == "velocity":
            if cell_resolution is None:
                # Backward compatibility: fall back to the module-level global
                # this function used to read implicitly (KeyError if absent).
                cell_resolution = globals()['cell_resolution']
            X, y_vx, y_vy, y_vz, partitions = utils_filereader.read_frame_velocity(
                args, framei, fn_train, cell_max_min)
            query_methods.query_velocity(args, X, y_vx, y_vy, y_vz, partitions,
                                         cell_resolution, cell_max_min, framei)
            del X, y_vx, y_vy, y_vz, partitions
        else:
            raise ValueError("Unknown model type: \"{}\"".format(
                args.model_type))
    print('Querying completed---------------\n')
示例#5
0
def _build_velocity_query_points(args, cell_max_min):
    """Build the query inputs for the grid and slice modes of query_velocity.

    The caller handles the case where all of args.query_dist[0:3] are
    non-positive (query the training inputs), so here at least one is positive.

    * Any of query_dist[0:3] non-positive: the first such axis (x, then y, then
      z) is collapsed to the single value query_dist[3] (a slice). The other
      two axes span cell_max_min with step query_dist[axis].
    * Otherwise: a full 3D grid over cell_max_min, stepped by query_dist[axis]
      along each axis.

    Returns (Xq_mv, option): an (N, 3) tensor of query points and a short
    description that is forwarded to the evaluation report.
    """
    qd = args.query_dist

    def span(axis):
        # cell_max_min is indexed as (lower, upper) pairs per axis:
        # x -> [0], [1]; y -> [2], [3]; z -> [4], [5].
        lower, upper = cell_max_min[2 * axis], cell_max_min[2 * axis + 1]
        # arange excludes its end, so extend it by one step to reach `upper`.
        return torch.arange(lower, upper + qd[axis], qd[axis])

    def single(value):
        # One-element range containing only `value`.
        return torch.arange(value, value + 0.1, 1)

    if qd[0] <= 0 or qd[1] <= 0 or qd[2] <= 0:
        sliced = 0 if qd[0] <= 0 else (1 if qd[1] <= 0 else 2)
        label = 'xyz'[sliced]
        print(" Query data is {}={} slice ".format(label, qd[3]))
        ranges = [single(qd[3]) if axis == sliced else span(axis)
                  for axis in range(3)]
        # The original format strings had no placeholder, so the slice position
        # was silently dropped from the report label.
        option = '{} slice at {}'.format(label.upper(), qd[3])
    else:
        print(" Query data is a 3D grid.")
        ranges = [span(0), span(1), span(2)]
        option = '3D grid'

    xx, yy, zz = torch.meshgrid(*ranges)
    return torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=1), option


def query_velocity(args, X, y_vx, y_vy, y_vz, partitions, cell_resolution, cell_max_min, framei):
    """Query the trained velocity BHM for frame `framei`, save the predictions
    and, if args.eval is set, score them.

    The model is loaded from 'velocity/{args.save_model_path}_f{framei}' and
    the predictions are saved to 'velocity/{args.save_query_data_path}_f{framei}'.

    Query points are chosen as follows:
      * args.eval and a non-empty args.eval_path: the frame read from
        args.eval_path. Its velocities are the ground truth.
      * all of args.query_dist[0:3] non-positive: the training inputs X, with
        y_vx / y_vy / y_vz as the ground truth.
      * otherwise: a 3D grid or a single-axis slice (see
        _build_velocity_query_points). There is no ground truth for these.

    partitions and cell_resolution are accepted for interface compatibility
    but are not used here.

    For args.likelihood_type == "gamma" the model yields no variances, so
    var_x / var_y / var_z are saved as None.

    Raises:
        ValueError: unsupported args.likelihood_type, or args.eval is set but
            the query points have no ground truth (grid / slice queries).
    """
    bhm_velocity_mdl, train_time = load_mdl(args, 'velocity/{}_f{}'.format(args.save_model_path, framei))

    # Ground-truth velocities (x, y, z) for the query points, when known.
    y_true = None
    if args.eval_path != '' and args.eval:
        print(" Query data from the test dataset")
        Xq_mv, y_vx_true, y_vy_true, y_vz_true, _ = read_frame_velocity(args, framei, args.eval_path, cell_max_min)
        y_true = (y_vx_true, y_vy_true, y_vz_true)
        option = args.eval_path
    elif args.query_dist[0] <= 0 and args.query_dist[1] <= 0 and args.query_dist[2] <= 0:
        print(" Query data is the same as input data")
        Xq_mv = X
        y_true = (y_vx, y_vy, y_vz)
        option = 'Train data'
    else:
        Xq_mv, option = _build_velocity_query_points(args, cell_max_min)

    time1 = time.time()
    if args.likelihood_type == "gamma":
        mean_x, mean_y, mean_z = bhm_velocity_mdl.predict(Xq_mv)
        # The gamma model returns means only. Binding the variances explicitly
        # avoids the NameError save_query_data used to raise below.
        var_x = var_y = var_z = None
    elif args.likelihood_type == "gaussian":
        mean_x, var_x, mean_y, var_y, mean_z, var_z = bhm_velocity_mdl.predict(Xq_mv, args.query_blocks, args.variance_only)
    else:
        raise ValueError("Unsupported likelihood type: \"{}\"".format(args.likelihood_type))
    query_time = time.time() - time1

    print(' Total querying time={} s'.format(round(query_time, 2)))
    save_query_data((X, y_vx, y_vy, y_vz, Xq_mv, mean_x, var_x, mean_y, var_y, mean_z, var_z, framei),
                    'velocity/{}_f{}'.format(args.save_query_data_path, framei))

    if not args.eval:
        return
    if y_true is None:
        raise ValueError(
            "Evaluation needs ground truth for the query points: set eval_path, "
            "or make query_dist[0:3] all non-positive to query the training inputs")

    notes = getattr(args, 'report_notes', '')
    means = (mean_x, mean_y, mean_z)
    variances = (var_x, var_y, var_z)
    for axis, axis_true, mean, var in zip(('x', 'y', 'z'), y_true, means, variances):
        mdl_name = 'reports/' + args.plot_title + '_' + axis
        # Only the diagonal of var is scored. There is no variance for gamma.
        # NOTE(review): confirm calc_scores_velocity accepts predicted_var=None.
        predicted_var = None if var is None else np.diagonal(var.numpy())
        calc_scores_velocity(mdl_name, option, axis_true.numpy(), mean.numpy().ravel(),
                             predicted_var=predicted_var, train_time=train_time,
                             query_time=query_time, save_report=True, notes=notes)