Example #1
    def __call__(self, trainer=None):
        """Executes the evaluator extension.

        Unlike usual extensions, this extension can be executed without passing
        a trainer object. This extension reports the performance on the
        validation dataset using the :func:`~pytorch_trainer.report`
        function, so users can use this extension independently of any
        trainer by manually configuring a :class:`~pytorch_trainer.Reporter`
        object.

        Args:
            trainer (~pytorch_trainer.training.Trainer): Trainer object that
                invokes this extension. It can be omitted when this
                extension is called manually.

        Returns:
            dict: Result dictionary that contains mean statistics of values
            reported by the evaluation function.

        """
        # set up a reporter
        reporter = reporter_module.Reporter()
        if self.name is not None:
            prefix = self.name + '/'
        else:
            prefix = ''
        for name, target in six.iteritems(self._targets):
            reporter.add_observer(prefix + name, target)
            reporter.add_observers(prefix + name, target.named_children())

        with reporter:
            with torch.no_grad():
                result = self.evaluate()

        reporter_module.report(result)
        return result
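Since the extension builds its own `Reporter`, it can be invoked outside a training loop. A minimal standalone-call sketch, assuming an `Evaluator` constructed in the usual iterator-plus-target style (`val_iter`, `model`, and the printed values are illustrative, not part of the snippet above):

    # Hypothetical setup; `val_iter` and `model` are assumed to exist.
    evaluator = extensions.Evaluator(val_iter, {'main': model})
    result = evaluator()  # no trainer argument needed
    print(result)  # e.g. {'main/loss': 0.42, 'main/accuracy': 0.87}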
Example #2
    def generative_lossfun(self, p_fake):
        # least-squares GAN generator loss: push D(G(z)) toward 1
        t_1 = torch.ones(p_fake.shape,
                         dtype=p_fake.dtype,
                         device=p_fake.device)
        loss = F.mse_loss(p_fake, t_1)
        reporter.report({'loss_gen': loss})
        return loss
Example #3
    def forward(self, x, t):
        y = self.predictor(x)
        # F.nll_loss expects log-probabilities, so the predictor
        # should end with log_softmax
        loss = F.nll_loss(y, t)
        reporter.report({'loss': loss}, self)
        acc = accuracy(y, t)
        reporter.report({'accuracy': acc}, self)
        return loss
Example #4
    def discriminative_lossfun(self, p_real, p_fake):
        # least-squares GAN discriminator loss:
        # push D(x) toward 1 and D(G(z)) toward 0
        t_1 = torch.ones(p_real.shape,
                         dtype=p_real.dtype,
                         device=p_real.device)
        t_0 = torch.zeros(p_fake.shape,
                          dtype=p_fake.dtype,
                          device=p_fake.device)
        loss = (F.mse_loss(p_real, t_1)
                + F.mse_loss(p_fake, t_0)) * 0.5
        reporter.report({'loss_dis': loss})
        return loss
Example #5
    def forward(self, x):
        h = F.relu(self.l1(x))
        h = self.dropout(h)
        h = F.relu(self.l2(h))
        h = self.dropout(h)

        out = self.l3(h)
        sigma = self.l3_sigma(h)  # second head, e.g. a predicted scale/uncertainty

        reporter.report({'sigma': sigma.mean()}, self)

        return out, sigma
Example #6
    def conditional_lossfun(self, y_fake, y_true):

        model = self.generator

        # prefer the generator's own loss function if it defines one
        if hasattr(model, 'lossfun'):
            lossfun = model.lossfun
        else:
            lossfun = self.loss_func

        loss = lossfun(y_fake, y_true)
        reporter.report({'loss_cond': loss})
        return loss
Example #7
    def __call__(self, trainer):
        observation = trainer.observation
        if not (self._numerator_key in observation and
                self._denominator_key in observation):
            return

        self._numerator += observation[self._numerator_key]
        self._denominator += observation[self._denominator_key]

        if self._trigger(trainer):
            result = float(self._numerator) / self._denominator
            self._numerator = 0
            self._denominator = 0
            reporter.report({self._result_key: result})
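This `__call__` accumulates the two observation entries between trigger firings and reports their ratio, i.e. a micro-average. A sketch of how such an extension might be attached to a trainer, assuming a MicroAverage-style constructor (the key names and signature here are assumptions):

    # Assumed constructor and key names, in the style of Chainer's MicroAverage.
    ext = MicroAverage('main/correct', 'main/total', 'main/accuracy',
                       trigger=(1, 'epoch'))  # emit the ratio once per epoch
    trainer.extend(ext)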
Example #8
    def forward(self, x):

        h = super().forward(x)

        out = self['conv_out'](h)
        out = crop(out, x.shape)

        if not self._sigma:
            return out

        sigma = self['conv_sigma'](h)
        sigma = crop(sigma, x.shape)

        reporter.report({'sigma': torch.mean(sigma)}, self)

        return out, sigma
Example #9
    def report(self, trainer):

        # set up a reporter
        reporter = reporter_module.Reporter()
        if self.name is not None:
            prefix = self.name + '/'
        else:
            prefix = ''
        for name, target in six.iteritems(self._targets):
            reporter.add_observer(prefix + name, target)
            reporter.add_observers(prefix + name + '/',
                                   target.named_children())

        with reporter:
            result = self.evaluate(trainer)

        reporter_module.report(result)

        return result
Example #10
    def forward(self, *args, **kwargs):
        """Computes the loss value for input and label pair.
        It also computes accuracy and stores it to the attribute.
        Args:
            args (list of ~chainer.Variable): Input minibatch.
            kwargs (dict of ~chainer.Variable): Input minibatch.
        When ``label_key`` is ``int``, the corresponding element in ``args``
        is treated as ground truth labels. And when it is ``str``, the
        element in ``kwargs`` is used.
        The all elements of ``args`` and ``kwargs`` except the ground truth
        labels are features.
        It feeds features to the predictor and compare the result
        with ground truth labels.
        .. note::
            We set ``None`` to the attributes ``y``, ``loss`` and ``accuracy``
            each time before running the predictor, to avoid unnecessary memory
            consumption. Note that the variables set on those attributes hold
            the whole computation graph when they are computed. The graph
            stores interim values on memory required for back-propagation.
            We need to clear the attributes to free those values.
        Returns:
            ~chainer.Variable: Loss value.
        """

        self._reset()

        n_args = len(args) + len(kwargs)
        x = get_values(args, kwargs, self.x_keys)
        t = get_values(args, kwargs, self.t_keys) if n_args > 1 else None

        # predict, and then apply final activation
        y = self.predictor(x)

        if self.activation is not None:
            y = self.activation(y)

        # preserve the inputs, predictions and labels for later inspection
        self.x = x
        self.y = y
        self.t = t

        # if only the input `x` is given, return the predictions
        if t is None:
            return y

        # if the ground-truth label `t` is given, evaluate the loss and
        # accuracy; return the loss during training, otherwise return the
        # predictions.
        if self.lossfun is not None:
            self.loss = self.lossfun(y, t)
            reporter.report({'loss': self.loss}, self)

        if self.accfun is not None:
            self.accuracy = self.accfun(y, t)
            reporter.report({'accuracy': self.accuracy}, self)

        if self.training:

            if self.loss is None:
                raise ValueError('loss is None; `lossfun` must be set '
                                 'during training.')

            return self.loss

        else:
            return self.y
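A sketch of how this forward pass behaves in its three modes, assuming a `Classifier`-style wrapper that holds the method above (constructor arguments and helper names are illustrative):

    # Hypothetical instance; `MLP`, `x` and `t` are assumed to exist.
    model = Classifier(predictor=MLP(), lossfun=F.cross_entropy, accfun=accuracy)

    model.train()
    loss = model(x, t)  # training mode: returns the loss
    model.eval()
    y = model(x, t)     # eval mode: returns the (activated) predictions
    y = model(x)        # no label given: always returns the predictions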
Example #11
    def generative_lossfun(self, p_fake):
        # non-saturating GAN generator loss: softplus(-x) == -log(sigmoid(x))
        size = p_fake.numel() / p_fake.shape[1]
        loss = torch.sum(F.softplus(-p_fake)) / size
        reporter.report({'loss_gen': loss})
        return loss
Example #12
    def discriminative_lossfun(self, p_real, p_fake):
        size = p_real.numel() / p_real.shape[1]
        # NOTE: equivalent to binary cross entropy with logits
        loss = (torch.sum(F.softplus(-p_real)) / size
                + torch.sum(F.softplus(p_fake)) / size) * 0.5
        reporter.report({'loss_dis': loss})
        return loss