# for i, ((tr_block, block_idx), conditional_blocks) in enumerate(datasets):
        #    cspn = cspns[i]
        if i == 0:
            # first time, we only care about the structure to put nans
            mpe_query_blocks = np.zeros_like(tr_block[0:num_mpes, :].reshape(num_mpes, -1))
            sample_query_blocks = np.zeros_like(tr_block[0:num_samples, :].reshape(num_samples, -1))
        else:
            # i+1 time: we set the previous mpe values as evidence
            mpe_query_blocks = np.zeros_like(np.array(tr_block[0:num_mpes, :].reshape(num_mpes, -1)))
            mpe_query_blocks[:, -(mpe_result.shape[1]) :] = mpe_result

            sample_query_blocks = np.zeros_like(np.array(tr_block[0:num_samples, :].reshape(num_samples, -1)))
            sample_query_blocks[:, -(sample_result.shape[1]) :] = sample_result

        cspn_mpe_query = set_sub_block_nans(mpe_query_blocks, inp=block_idx, nans=block_idx[0:conditional_blocks])
        mpe_result = mpe(cspn, cspn_mpe_query)

        mpe_img_blocks = stitch_imgs(
            mpe_result.shape[0], img_size=img_size, num_blocks=num_blocks, blocks={tuple(block_idx): mpe_result}
        )

        cspn_sample_query = set_sub_block_nans(sample_query_blocks, inp=block_idx, nans=block_idx[0:conditional_blocks])
        sample_result = sample_instances(cspn, cspn_sample_query, RandomState(123))

        sample_img_blocks = stitch_imgs(
            sample_result.shape[0], img_size=img_size, num_blocks=num_blocks, blocks={tuple(block_idx): sample_result}
        )

        for j in range(num_mpes):
            mpe_fname = output_path + "mpe_%s_%s.png" % ("-".join(map(str, block_idx)), j)
Example No. 2
0
        if mpe_query_blocks is None:
            # first time, we only care about the structure to put nans
            mpe_query_blocks = np.zeros_like(tr_block[0:10, :].reshape(10, -1))
            sample_query_blocks = mpe_query_blocks
        else:
            # i+1 time: we set the previous mpe values as evidence
            mpe_query_blocks = np.zeros_like(np.array(tr_block[0:10, :].reshape(10, -1)))
            mpe_query_blocks[:, -(mpe_result.shape[1] - 10):] = mpe_result[:, 0:-10]

            sample_query_blocks = np.zeros_like(np.array(tr_block[0:10, :].reshape(10, -1)))
            sample_query_blocks[:, -(sample_result.shape[1] - 10):] = sample_result[:, 0:-10]




        cspn_mpe_query = np.concatenate((set_sub_block_nans(mpe_query_blocks, inp=block_idx, nans=[block_idx[0]]),
                                         np.eye(10, 10)), axis=1)
        mpe_result = mpe(cspn, cspn_mpe_query)

        mpe_img_blocks = stitch_imgs(mpe_result.shape[0], img_size=(20, 20), num_blocks=(2, 2),
                                     blocks={tuple(block_idx): mpe_result[:, 0:-10]})

        cspn_sample_query = np.concatenate((set_sub_block_nans(sample_query_blocks, inp=block_idx, nans=[block_idx[0]]),
                                            np.eye(10, 10)),
                                           axis=1)
        sample_result = sample_instances(cspn, cspn_sample_query, RandomState(123))

        sample_img_blocks = stitch_imgs(mpe_result.shape[0], img_size=(20, 20), num_blocks=(2, 2),
                                        blocks={tuple(block_idx): sample_result[:, 0:-10]})

        for c in range(10):
Example No. 3
0
            mpe_query_blocks = np.zeros_like(tr_block[0:10, :].reshape(10, -1))
            sample_query_blocks = np.zeros_like(tr_block[0:10, :].reshape(
                10, -1))
        else:
            # i+1 time: we set the previous mpe values as evidence
            mpe_query_blocks = np.zeros_like(
                np.array(tr_block[0:10, :].reshape(10, -1)))
            mpe_query_blocks[:, -(mpe_result.shape[1]):] = mpe_result

            sample_query_blocks = np.zeros_like(
                np.array(tr_block[0:10, :].reshape(10, -1)))
            sample_query_blocks[:, -(
                sample_query_blocks.shape[1]):] = sample_result

        cspn_mpe_query = set_sub_block_nans(mpe_query_blocks,
                                            inp=block_idx,
                                            nans=[block_idx[0]])
        mpe_result = mpe(cspn, cspn_mpe_query)

        mpe_img_blocks = stitch_imgs(mpe_result.shape[0],
                                     img_size=(64, 64),
                                     num_blocks=(2, 2),
                                     blocks={tuple(block_idx): mpe_result})

        cspn_sample_query = set_sub_block_nans(sample_query_blocks,
                                               inp=block_idx,
                                               nans=[block_idx[0]])
        sample_result = sample_instances(cspn, cspn_sample_query,
                                         RandomState(123))

        sample_img_blocks = stitch_imgs(
Example No. 4
0
        else:
            # i+1 time: we set the previous mpe values as evidence
            mpe_query_blocks = np.zeros_like(
                np.array(tr_block[0:10, :].reshape(10, -1)))
            mpe_query_blocks[:,
                             -(mpe_result.shape[1] - 10):] = mpe_result[:,
                                                                        0:-10]

            sample_query_blocks = np.zeros_like(
                np.array(tr_block[0:10, :].reshape(10, -1)))
            sample_query_blocks[:, -(sample_result.shape[1] -
                                     10):] = sample_result[:, 0:-10]

        cspn_mpe_query = np.concatenate(
            (set_sub_block_nans(mpe_query_blocks,
                                inp=block_idx,
                                nans=[block_idx[0]]), np.eye(10, 10)),
            axis=1)
        mpe_result = mpe(cspn, cspn_mpe_query)

        mpe_img_blocks = stitch_imgs(
            mpe_result.shape[0],
            img_size=(20, 20),
            num_blocks=(2, 2),
            blocks={tuple(block_idx): mpe_result[:, 0:-10]})

        cspn_sample_query = np.concatenate(
            (set_sub_block_nans(sample_query_blocks,
                                inp=block_idx,
                                nans=[block_idx[0]]), np.eye(10, 10)),
            axis=1)