def __init__(
    self, g, k, s, c, h_g, h_l, std, hidden_size, num_classes,
):
    """Constructor.

    Args:
        g: size of the square patches in the glimpses extracted by the retina.
        k: number of patches to extract per glimpse.
        s: scaling factor that controls the size of successive patches.
        c: number of channels in each image.
        h_g: hidden layer size of the fc layer for `phi`.
        h_l: hidden layer size of the fc layer for `l`.
        std: standard deviation of the Gaussian policy.
        hidden_size: hidden size of the rnn.
        num_classes: number of classes in the dataset.
    """
    super().__init__()

    self.std = std

    self.sensor = modules.GlimpseNetwork(h_g, h_l, g, k, s, c)
    self.rnn = modules.CoreNetwork(hidden_size, hidden_size)
    self.locator = modules.LocationNetwork(hidden_size, 2, std)
    self.classifier = modules.ActionNetwork(hidden_size, num_classes)
    self.baseliner = modules.BaselineNetwork(hidden_size, 1)
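
# A minimal sketch of how a single glimpse step might wire these modules
# together, assuming the call signatures exercised in the smoke test below
# (sensor(x, l), rnn(g, h), locator(h) -> (mu, l), baseliner(h),
# classifier(h)). The `last` flag and the return layout are illustrative
# choices, not a fixed API.
def forward(self, x, l_t_prev, h_t_prev, last=False):
    # Encode a retina glimpse taken at the previous location.
    g_t = self.sensor(x, l_t_prev)
    # Fold the glimpse encoding into the recurrent hidden state.
    h_t = self.rnn(g_t, h_t_prev)
    # Mean of the Gaussian policy and the sampled next location.
    mu, l_t = self.locator(h_t)
    # Scalar baseline used to reduce REINFORCE gradient variance.
    b_t = self.baseliner(h_t)
    if last:
        # Classify only after the final glimpse.
        a_t = self.classifier(h_t)
        return h_t, l_t, b_t, a_t
    return h_t, l_t, b_t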
import os

import torch

import modules
import utils

# Directory holding the test images; assumed layout, adjust as needed.
data_dir = "./data/"

# load images
imgs = []
paths = [os.path.join(data_dir, "lenna.jpg"), os.path.join(data_dir, "cat.jpg")]
for path in paths:
    img = utils.img2array(path, desired_size=[512, 512], expand=True)
    imgs.append(torch.from_numpy(img))
imgs = torch.cat(imgs).permute((0, 3, 1, 2))  # NHWC -> NCHW
B, C, H, W = imgs.shape

# glimpse network: encodes a retina patch and its location into g_t
loc = torch.Tensor([[-1.0, 1.0], [-1.0, 1.0]])
sensor = modules.GlimpseNetwork(h_g=128, h_l=128, g=64, k=3, s=2, c=3)
g_t = sensor(imgs, loc)
assert g_t.shape == (B, 256)

# core network: recurrent state update from the glimpse encoding
rnn = modules.CoreNetwork(input_size=256, hidden_size=256)
h_t = torch.zeros(g_t.shape[0], 256)
h_t = rnn(g_t, h_t)
assert h_t.shape == (B, 256)

# action network: class scores from the hidden state
classifier = modules.ActionNetwork(256, 10)
a_t = classifier(h_t)
assert a_t.shape == (B, 10)

# location network: Gaussian policy over the next glimpse location
loc_net = modules.LocationNetwork(256, 2, 0.11)
mu, l_t = loc_net(h_t)
assert l_t.shape == (B, 2)

# baseline network: scalar reward baseline for REINFORCE
base = modules.BaselineNetwork(256, 1)
b_t = base(h_t)
assert b_t.shape == (B, 1)
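
# A hedged sketch of chaining the same modules over several glimpses, i.e.
# the recurrence the constructor above sets up. `num_glimpses` and the
# uniform initial location are illustrative choices, not values from the
# source.
num_glimpses = 6
h_t = torch.zeros(B, 256)
l_t = torch.empty(B, 2).uniform_(-1.0, 1.0)  # random location in [-1, 1]
for _ in range(num_glimpses):
    g_t = sensor(imgs, l_t)  # glimpse at the current location
    h_t = rnn(g_t, h_t)      # recurrent state update
    mu, l_t = loc_net(h_t)   # next location from the Gaussian policy
    b_t = base(h_t)          # baseline for the REINFORCE update
a_t = classifier(h_t)        # classify only after the final glimpse
assert a_t.shape == (B, 10)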