Example #1
0
    def __init__(
        self, g, k, s, c, h_g, h_l, std, hidden_size, num_classes,
    ):
        """Build the recurrent attention model from its sub-networks.

        Args:
          g: size of the square patches in the glimpses extracted by the retina.
          k: number of patches to extract per glimpse.
          s: scaling factor that controls the size of successive patches.
          c: number of channels in each image.
          h_g: hidden layer size of the fc layer for `phi`.
          h_l: hidden layer size of the fc layer for `l`.
          std: standard deviation of the Gaussian policy.
          hidden_size: hidden size of the rnn.
          num_classes: number of classes in the dataset.
        """
        super().__init__()

        # Stored on the instance; not otherwise used in this constructor.
        self.std = std

        # Glimpse network: built from the retina settings (g, k, s, c) and the
        # two fc widths h_g / h_l; it is called with (images, locations).
        self.sensor = modules.GlimpseNetwork(h_g, h_l, g, k, s, c)
        # The core RNN's input is the glimpse vector, so size it from the
        # glimpse width rather than from hidden_size. The previous
        # CoreNetwork(hidden_size, hidden_size) only worked when
        # hidden_size happened to equal h_g + h_l.
        # NOTE(review): assumes GlimpseNetwork emits h_g + h_l features
        # (h_g=128, h_l=128 -> 256-wide output) — confirm in modules.
        self.rnn = modules.CoreNetwork(h_g + h_l, hidden_size)
        # Heads that read the RNN hidden state: 2-D location policy,
        # class logits, and a scalar baseline.
        self.locator = modules.LocationNetwork(hidden_size, 2, std)
        self.classifier = modules.ActionNetwork(hidden_size, num_classes)
        self.baseliner = modules.BaselineNetwork(hidden_size, 1)
    # Shape smoke test: push a small image batch through each sub-network once
    # and check that every output has the expected (batch, features) shape.
    import os

    # load images (os.path.join works whether or not data_dir ends in "/")
    paths = [os.path.join(data_dir, name) for name in ("lenna.jpg", "cat.jpg")]
    imgs = [
        torch.from_numpy(
            utils.img2array(path, desired_size=[512, 512], expand=True)
        )
        for path in paths
    ]
    # Stack along the leading axis, then move channels from last to second.
    # NOTE(review): assumes img2array(..., expand=True) yields (1, H, W, C).
    imgs = torch.cat(imgs).permute((0, 3, 1, 2))
    B, C, H, W = imgs.shape

    h_g, h_l = 128, 128
    glimpse_size = h_g + h_l  # width of the glimpse vector (256)
    hidden_size = 256
    num_classes = 10

    # One (x, y) location per image, derived from B so it tracks len(paths).
    loc = torch.tensor([[-1.0, 1.0]]).repeat(B, 1)
    sensor = modules.GlimpseNetwork(h_g=h_g, h_l=h_l, g=64, k=3, s=2, c=C)
    g_t = sensor(imgs, loc)
    assert g_t.shape == (B, glimpse_size)

    rnn = modules.CoreNetwork(input_size=glimpse_size, hidden_size=hidden_size)
    h_t = torch.zeros(B, hidden_size)
    h_t = rnn(g_t, h_t)
    assert h_t.shape == (B, hidden_size)

    classifier = modules.ActionNetwork(hidden_size, num_classes)
    a_t = classifier(h_t)
    assert a_t.shape == (B, num_classes)

    loc_net = modules.LocationNetwork(hidden_size, 2, 0.11)
    # First return value is not checked here.
    _, l_t = loc_net(h_t)
    assert l_t.shape == (B, 2)

    base = modules.BaselineNetwork(hidden_size, 1)
    b_t = base(h_t)
    assert b_t.shape == (B, 1)