示例#1
0
def main():
    """Visualise the glimpses a ``GlimpseSensor.Retina`` extracts from sample images.

    Loads the sample images from the configured ``data_dir`` (resized to
    512x512), foveates every image at the locations in ``loc``, merges each
    image's ``num_patches`` patches into one picture with ``pil_merge_images``
    and shows one subplot row per image.

    The number of images is taken from the loaded tensor, so adding or
    removing entries in ``paths`` needs no other change.
    """
    data_dir = get_ram_config()["data_dir"]

    # Load the images and stack them; the permute moves channels from the
    # last axis to axis 1 (NHWC -> NCHW).
    # NOTE(review): assumes pil_img_to_np_array(..., expand=True) returns a
    # (1, H, W, C) array per image — confirm.
    paths = [data_dir / "lenna.jpg", data_dir / "cat.jpg"]
    imgs = torch.cat([
        torch.from_numpy(
            pil_img_to_np_array(path, desired_size=[512, 512], expand=True))
        for path in paths
    ]).permute(0, 3, 1, 2)
    num_imgs = imgs.shape[0]

    # One (x, y) glimpse location per image, float64 like the original
    # numpy-built tensor. For random locations use instead:
    #   torch.empty(num_imgs, 2, dtype=torch.float64).uniform_(-1, 1)
    loc = torch.zeros(num_imgs, 2, dtype=torch.float64)

    num_patches = 5
    scale = 2
    patch_size = 10

    ret = GlimpseSensor.Retina(
        size_first_patch=patch_size,
        num_patches_per_glimpse=num_patches,
        scale_factor_suc=scale,
    )
    glimpse = ret.foveate(imgs, loc).data.numpy()

    # Split the foveate output into per-image, per-patch (C, H, W) blocks,
    # then move channels last so each patch can be turned into a PIL image.
    glimpse = numpy.reshape(
        glimpse, [num_imgs, num_patches, 3, patch_size, patch_size])
    glimpse = numpy.transpose(glimpse, [0, 1, 3, 4, 2])

    # Merge each image's patches into a single picture.
    merged = [
        reduce(pil_merge_images,
               [np_array_to_pil_img(patch) for patch in patches])
        for patches in glimpse
    ]
    merged = [numpy.asarray(img, dtype="float32") / 255.0 for img in merged]

    # squeeze=False keeps ``axs`` a 2-D array even when there is one image,
    # so ``axs.flat`` always works.
    _, axs = pyplot.subplots(nrows=num_imgs, ncols=1, squeeze=False)
    for ax, img in zip(axs.flat, merged):
        ax.imshow(img)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    pyplot.show()
示例#2
0
        Args:
          i:
        """
        color = "r"
        co = coords[i]
        for j, ax in enumerate(axs.flat):
            for p in ax.patches:
                p.remove()
            c = co[j]
            rect = matplotlib_bounding_box(c[0], c[1], size, color)
            ax.add_patch(rect)

    anim = animation.FuncAnimation(fig,
                                   update_data,
                                   frames=num_anims,
                                   interval=500,
                                   repeat=True)

    anim.save(str(plot_dir / f"epoch_{epoch}.gif"),
              writer=animation.PillowWriter(fps=30))
    # anim.save(        str(plot_dir / f"epoch_{epoch}.mp4"),        extra_args=["-vcodec", "h264", "-pix_fmt", "yuv420p"]    )  # save as mp4
    pyplot.show()


def latest_plot_dir(root):
    """Return the last sub-directory of ``root`` in sorted (name) order.

    ``Path.iterdir`` yields entries in arbitrary order, so the entries are
    sorted explicitly to make the choice deterministic. Plain files are
    skipped because the result is used as a directory to write plots into.
    NOTE(review): assumes run directories are named so that lexical order
    matches creation order (e.g. timestamps) — confirm.

    Args:
        root: ``pathlib.Path`` whose sub-directories are candidate plot dirs.

    Returns:
        The ``Path`` of the lexically greatest sub-directory.

    Raises:
        FileNotFoundError: if ``root`` contains no sub-directories.
    """
    run_dirs = sorted(p for p in root.iterdir() if p.is_dir())
    if not run_dirs:
        raise FileNotFoundError(f"no plot sub-directories found in {root}")
    return run_dirs[-1]


if __name__ == "__main__":
    config = get_ram_config()
    plot_dir = latest_plot_dir(config.plot_dir)
    print(plot_dir)
    main(plot_dir)