예제 #1
0
    def _setup_loss(self):
        """Create ``self.train_loss_function`` and ``self.eval_loss_function``.

        The training loss is selected by ``self.loss[TYPE]``:

        * ``'softmax_cross_entropy'`` builds a ``SoftmaxCrossEntropyLoss``.
        * ``'sampled_softmax_cross_entropy'`` builds a
          ``SampledSoftmaxCrossEntropyLoss``; this is the only variant that
          also receives ``self.decoder_obj``.

        Both receive ``self.num_classes``, the whole ``self.loss`` config as
        ``feature_loss``, and the name ``'train_loss'``.

        The evaluation loss is always a ``SoftmaxCrossEntropyLoss`` named
        ``'eval_loss'``, whichever training loss was chosen.

        Raises:
            ValueError: if ``self.loss[TYPE]`` is not one of the two
                supported values.
        """
        loss_type = self.loss[TYPE]

        if loss_type == 'softmax_cross_entropy':
            train_loss = SoftmaxCrossEntropyLoss(
                num_classes=self.num_classes,
                feature_loss=self.loss,
                name='train_loss',
            )
        elif loss_type == 'sampled_softmax_cross_entropy':
            # The sampled variant additionally needs the decoder object.
            train_loss = SampledSoftmaxCrossEntropyLoss(
                decoder_obj=self.decoder_obj,
                num_classes=self.num_classes,
                feature_loss=self.loss,
                name='train_loss',
            )
        else:
            raise ValueError(
                "Loss type {} is not supported. Valid values are "
                "'softmax_cross_entropy' or "
                "'sampled_softmax_cross_entropy'".format(loss_type)
            )
        self.train_loss_function = train_loss

        # Evaluation always uses the full (non-sampled) softmax cross-entropy.
        self.eval_loss_function = SoftmaxCrossEntropyLoss(
            num_classes=self.num_classes,
            feature_loss=self.loss,
            name='eval_loss',
        )
예제 #2
0
    def __init__(self,
                 decoder_obj=None,
                 num_classes=0,
                 feature_loss=None,
                 name='sampled_softmax_cross_entropy_metric'):
        """Initialize the metric around a ``SampledSoftmaxCrossEntropyLoss``.

        Args:
            decoder_obj: forwarded unchanged to ``SampledSoftmaxCrossEntropyLoss``.
            num_classes: forwarded unchanged to ``SampledSoftmaxCrossEntropyLoss``.
            feature_loss: forwarded unchanged to ``SampledSoftmaxCrossEntropyLoss``.
            name: name handed to the parent class constructor. It is NOT
                forwarded to the wrapped loss.
        """
        super(SampledSoftmaxCrossEntropyMetric, self).__init__(name=name)

        loss_kwargs = dict(
            decoder_obj=decoder_obj,
            num_classes=num_classes,
            feature_loss=feature_loss,
        )
        # The wrapped loss object is stored for use by the rest of the
        # metric class (not visible here).
        self.metric_function = SampledSoftmaxCrossEntropyLoss(**loss_kwargs)