def __call__(self, x, t):
    """Return the loss for a batch of inputs and their true labels.

    The sample array handed to ``black_out.black_out`` comes from
    ``self.sampler.sample`` with shape ``(x.shape[0], self.sample_size)``,
    unless the instance carries a ``sample_data`` attribute, in which case
    that array is used as-is (an override intended for tests).

    Args:
        x (~chainer.Variable): Input of the weight matrix multiplication.
        t (~chainer.Variable): Batch of ground truth labels.

    Returns:
        ~chainer.Variable: Loss value.

    """
    # Read eagerly, before the override check, so an input without a
    # usable ``shape`` fails here regardless of which branch is taken.
    n_rows = x.shape[0]

    if not hasattr(self, 'sample_data'):
        # Normal path: draw ``sample_size`` samples per row of ``x``.
        drawn = self.sampler.sample((n_rows, self.sample_size))
    else:
        # Test path: reuse the samples pinned on the instance.
        drawn = self.sample_data

    negatives = chainer.Variable(drawn)
    return black_out.black_out(x, t, self.W, negatives)
def forward(self, x, t):
    """Return the loss for a batch of inputs and their true labels.

    When ``self.sample_data`` is not ``None`` it is used directly as the
    sample array (an override intended for tests). Otherwise an array of
    shape ``(x.shape[0], self.sample_size)`` is drawn from
    ``self.sampler.sample``. Either way the samples are wrapped in a
    variable created with ``requires_grad=False`` before being passed to
    ``black_out.black_out``.

    Args:
        x (~chainer.Variable): Input of the weight matrix multiplication.
        t (~chainer.Variable): Batch of ground truth labels.

    Returns:
        ~chainer.Variable: Loss value.

    """
    # Read eagerly, before the override check, so an input without a
    # usable ``shape`` fails here regardless of which branch is taken.
    n_rows = x.shape[0]

    drawn = self.sample_data  # non-None only when pinned for tests
    if drawn is None:
        drawn = self.sampler.sample((n_rows, self.sample_size))

    negatives = variable.Variable(drawn, requires_grad=False)
    return black_out.black_out(x, t, self.W, negatives)