def test_epoch(self, x_batches, y_batches, itr_max=None):
    """Evaluate accuracy over paired batches using ``self.infer_v2``.

    For each batch, layer activities are initialised in one of two ways and
    the last layer is clamped to the input batch. Inference is then run and
    the prediction is read from ``x[0]``.

    - If ``self.amortised`` is truthy, a sweep through ``self.W_q`` /
      ``self.b_q`` starts from the input batch, and the resulting list is
      reversed.
    - Otherwise, a sweep through ``self.W`` / ``self.b`` starts from Gaussian
      noise (std 0.1) shaped like the label batch.

    Args:
        x_batches: Sequence of input batches. Each is converted with
            ``set_tensor`` and its ``size(1)`` is used as the batch size
            (presumably features x batch -- TODO confirm).
        y_batches: Sequence of label batches, paired with ``x_batches`` via
            ``zip``.
        itr_max: Forwarded unchanged to ``self.infer_v2``.

    Returns:
        Tuple ``(accs, avg_itr)``:

        - ``accs``: per-batch values from ``mnist_utils.mnist_accuracy``.
        - ``avg_itr``: mean of the iteration counts returned by
          ``self.infer_v2`` over the batches actually processed, or 0.0 if
          no batch was processed.
    """
    accs = []
    total_itr = 0
    n_processed = 0
    for x_batch, y_batch in zip(x_batches, y_batches):
        x_batch = set_tensor(x_batch, self.device)
        y_batch = set_tensor(y_batch, self.device)
        batch_size = x_batch.size(1)

        if self.amortised:
            # Amortised sweep starting at the input (q[0]). It is reversed
            # so that the final amortised layer becomes x[0].
            q = [[] for _ in range(self.n_layers)]
            q[0] = x_batch
            for l in range(1, self.n_layers):
                b_q = self.b_q[l - 1].repeat(1, batch_size)
                q[l] = self.W_q[l - 1] @ F.f(q[l - 1], self.act_fn) + b_q
            x = q[::-1]
        else:
            # Sweep through W/b starting from small Gaussian noise shaped
            # like the labels.
            x = [[] for _ in range(self.n_layers)]
            x[0] = torch.empty_like(y_batch).normal_(mean=0.0, std=0.1)
            # The last layer is clamped to the input below, so its forward
            # value would be discarded; stop one layer short.
            for l in range(1, self.n_layers - 1):
                b = self.b[l - 1].repeat(1, batch_size)
                x[l] = self.W[l - 1] @ F.f(x[l - 1], self.act_fn) + b
        # Clamp the last layer to the observed input batch.
        x[self.n_layers - 1] = x_batch

        x, _errors, its = self.infer_v2(x, batch_size, x_batch, itr_max=itr_max)
        accs.append(mnist_utils.mnist_accuracy(x[0], y_batch))
        total_itr += its
        n_processed += 1

    # Average over the batches actually processed. This avoids a
    # ZeroDivisionError on empty input and stays correct when zip truncates
    # to the shorter of the two sequences.
    avg_itr = total_itr / n_processed if n_processed else 0.0
    return accs, avg_itr
def test_amortised_epoch(self, x_batches, y_batches):
    """Score only the amortised network (``W_q`` / ``b_q``), with no inference.

    Each input batch is swept forward through the amortised weights, and the
    final layer's activity is compared against the labels.

    Args:
        x_batches: Sequence of input batches. ``size(1)`` of the converted
            tensor is used as the batch size (presumably features x batch --
            TODO confirm).
        y_batches: Sequence of label batches, paired with ``x_batches`` via
            ``zip``.

    Returns:
        List of per-batch values from ``mnist_utils.mnist_accuracy``.
    """
    accuracies = []
    for inputs, targets in zip(x_batches, y_batches):
        inputs = set_tensor(inputs, self.device)
        targets = set_tensor(targets, self.device)
        n_samples = inputs.size(1)

        activities = [None] * self.n_layers
        activities[0] = inputs
        for layer in range(1, self.n_layers):
            bias = self.b_q[layer - 1].repeat(1, n_samples)
            pre = F.f(activities[layer - 1], self.act_fn)
            activities[layer] = self.W_q[layer - 1] @ pre + bias

        accuracies.append(mnist_utils.mnist_accuracy(activities[-1], targets))
    return accuracies
def test_epoch(self, x_batches, y_batches):
    """Score a single forward sweep through ``self.W`` / ``self.b`` per batch.

    Args:
        x_batches: Sequence of input batches. ``size(1)`` of the converted
            tensor is used as the batch size (presumably features x batch --
            TODO confirm).
        y_batches: Sequence of label batches, paired with ``x_batches`` via
            ``zip``.

    Returns:
        List of per-batch values from ``mnist_utils.mnist_accuracy``,
        computed on the last layer's activity.
    """
    scores = []
    for inputs, targets in zip(x_batches, y_batches):
        inputs = set_tensor(inputs, self.device)
        targets = set_tensor(targets, self.device)
        n_samples = inputs.size(1)

        layers = [None] * self.n_layers
        layers[0] = inputs
        for idx in range(1, self.n_layers):
            bias = self.b[idx - 1].repeat(1, n_samples)
            pre = F.f(layers[idx - 1], self.act_fn)
            layers[idx] = self.W[idx - 1] @ pre + bias

        scores.append(mnist_utils.mnist_accuracy(layers[-1], targets))
    return scores
def test(self, imgs, labels):
    """Run a forward sweep over the test set, print the mean accuracy, and return per-batch accuracies.

    The data is split with ``self._get_batches(imgs, labels, self.batch_size)``.
    Each batch is swept through ``self.W`` / ``self.b``, with the bias tiled
    across the batch via ``np.tile``, and the last layer is scored with
    ``mnist_utils.mnist_accuracy``.

    Args:
        imgs: Images, passed to ``self._get_batches``.
        labels: Labels, passed to ``self._get_batches``.

    Returns:
        List of per-batch accuracies. The original version returned ``None``;
        returning the list is backward-compatible for callers that ignore it.
    """
    img_batches, label_batches, batch_sizes = self._get_batches(
        imgs, labels, self.batch_size)
    n_batches = len(img_batches)
    print(f"testing on {n_batches} batches of size {self.batch_size}")

    accs = []
    for img_batch, label_batch, batch_size in zip(img_batches, label_batches, batch_sizes):
        # Only the final layer is scored, so keep just the running activity.
        activity = img_batch
        for l in range(1, self.n_layers):
            activity = self.W[l - 1] @ F.f(activity, self.act_fn) + np.tile(
                self.b[l - 1], (1, batch_size))
        accs.append(mnist_utils.mnist_accuracy(activity, label_batch))

    # np.mean of an empty array warns and returns nan; produce the same
    # printed "nan" explicitly, without the RuntimeWarning.
    mean_acc = np.mean(np.array(accs)) if accs else float("nan")
    print(f"average accuracy {mean_acc}")
    return accs
def test_epoch(self, img_batches, label_batches, itr_max=None):
    """Evaluate accuracy using the amortised initialisation followed by ``self.hybrid_infer``.

    For each batch:

    1. The image batch is swept through ``self.W_q`` / ``self.b_q``.
    2. The activity list is reversed, so the image sits in the last slot.
    3. ``self.hybrid_infer(..., test=True)`` is run.
    4. The prediction is read from ``x[0]``.

    Args:
        img_batches: Sequence of image batches. ``size(1)`` of the converted
            tensor is used as the batch size (presumably features x batch --
            TODO confirm).
        label_batches: Sequence of label batches, paired with
            ``img_batches`` via ``zip``.
        itr_max: Forwarded unchanged to ``self.hybrid_infer``.

    Returns:
        Tuple ``(accs, avg_itr)``:

        - ``accs``: per-batch values from ``mnist_utils.mnist_accuracy``.
        - ``avg_itr``: mean iteration count returned by
          ``self.hybrid_infer`` over the batches actually processed, or 0.0
          if none were processed.
    """
    accs = []
    total_itr = 0
    n_processed = 0
    for img_batch, label_batch in zip(img_batches, label_batches):
        img_batch = set_tensor(img_batch, self.device)
        label_batch = set_tensor(label_batch, self.device)
        batch_size = img_batch.size(1)

        # Amortised sweep from the image layer.
        x = [[] for _ in range(self.n_layers)]
        x[0] = img_batch
        for l in range(1, self.n_layers):
            b_q = self.b_q[l - 1].repeat(1, batch_size)
            x[l] = self.W_q[l - 1] @ F.f(x[l - 1], self.act_fn) + b_q
        # Reverse so the final amortised layer is x[0] and the image is last.
        x = x[::-1]
        # Explicit clamp of the last layer to the image batch.
        x[self.n_layers - 1] = img_batch

        x, _errors, _q_errors, its = self.hybrid_infer(
            x, batch_size, itr_max=itr_max, test=True)
        accs.append(mnist_utils.mnist_accuracy(x[0], label_batch))
        total_itr += its
        n_processed += 1

    # Average over processed batches. This avoids a ZeroDivisionError on
    # empty input and a wrong denominator when zip truncates mismatched
    # inputs.
    avg_itr = total_itr / n_processed if n_processed else 0.0
    return accs, avg_itr