Пример #1
0
    def _build(self, incoming, loss_config, encoder_fn, decoder_fn, *args,
               **kwargs):
        """Wire the latent (variational) bridge into the graph for ``self.mode``.

        Behaviour per mode:

        * ``Modes.GENERATE`` -- ``incoming`` is handed straight to the decoder.
        * ``Modes.ENCODE`` -- ``incoming`` is encoded and the ``z_mean``
          projection of the encoding is returned.
        * any other mode -- ``incoming`` is encoded, a latent code is sampled as
          ``z = z_mean + sqrt(exp(z_log_sigma)) * eps``, decoded, and the losses
          are built from the input, the reconstruction and the latent moments.

        Args:
            incoming: input fed to the encoder (or directly to the decoder in
                ``GENERATE`` mode).
            loss_config: loss configuration forwarded to ``self._build_loss``.
            encoder_fn: encoder callable forwarded to ``self.encode``.
            decoder_fn: decoder callable forwarded to ``self.decode``.
            *args: accepted for interface compatibility; unused here.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            ``BridgeSpec(results=..., losses=..., loss=...)``; ``losses`` and
            ``loss`` are ``None`` in ``GENERATE`` and ``ENCODE`` modes.
        """
        self._build_dependencies()

        # Generation: no encoder, no loss -- decode the given input directly.
        if Modes.GENERATE == self.mode:
            generated = self.decode(incoming=incoming, decoder_fn=decoder_fn)
            return BridgeSpec(results=generated, losses=None, loss=None)

        # Encoding only: expose the mean of the latent distribution.
        if Modes.ENCODE == self.mode:
            hidden = self.encode(incoming=incoming, encoder_fn=encoder_fn)
            return BridgeSpec(results=self.z_mean(hidden), losses=None,
                              loss=None)

        hidden = self.encode(incoming=incoming, encoder_fn=encoder_fn)
        mu = self.z_mean(hidden)
        log_sigma = self.z_log_sigma(hidden)

        # Reparameterization: draw noise independently of the encoder output so
        # the sampled code stays differentiable w.r.t. mu / log_sigma.
        noise_shape = self._get_decoder_shape(incoming)
        noise = tf.random_normal(shape=noise_shape,
                                 mean=self.mean,
                                 stddev=self.stddev,
                                 dtype=tf.float32,
                                 name='eps')
        # NOTE(review): sqrt(exp(v)) is the standard deviation only if v is a
        # log-variance; the name `z_log_sigma` suggests log-sigma -- confirm
        # against the KL term in `_build_loss`.
        sigma = tf.sqrt(tf.exp(log_sigma))
        latent = tf.add(mu, tf.multiply(sigma, noise))

        reconstruction = self.decode(incoming=latent, decoder_fn=decoder_fn)
        losses, loss = self._build_loss(incoming,
                                        reconstruction,
                                        loss_config,
                                        z_mean=mu,
                                        z_log_sigma=log_sigma)
        return BridgeSpec(results=reconstruction, losses=losses, loss=loss)
Пример #2
0
    def _build(self, incoming, loss_config, encoder_fn, decoder_fn, *args, **kwargs):
        """Build the encoder/decoder graph for ``self.mode``.

        * ``Modes.GENERATE`` -- decode ``incoming`` directly.
        * ``Modes.ENCODE`` -- return the encoding of ``incoming``.
        * otherwise -- encode then decode ``incoming``; losses are built only
          when ``Modes.is_infer(self.mode)`` is false.

        Args:
            incoming: input fed to the encoder (or directly to the decoder in
                ``GENERATE`` mode).
            loss_config: loss configuration forwarded to ``self._build_loss``.
            encoder_fn: encoder callable forwarded to ``self.encode``.
            decoder_fn: decoder callable forwarded to ``self.decode``.
            *args: accepted for interface compatibility; unused here.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            ``BridgeSpec(results=..., losses=..., loss=...)``; ``losses`` and
            ``loss`` are ``None`` whenever no loss was built.
        """
        if Modes.GENERATE == self.mode:
            generated = self.decode(incoming=incoming, decoder_fn=decoder_fn)
            return BridgeSpec(results=generated, losses=None, loss=None)

        if Modes.ENCODE == self.mode:
            code = self.encode(incoming=incoming, encoder_fn=encoder_fn)
            return BridgeSpec(results=code, losses=None, loss=None)

        code = self.encode(incoming=incoming, encoder_fn=encoder_fn)
        reconstruction = self.decode(incoming=code, decoder_fn=decoder_fn)

        # Inference modes need only the reconstruction, not a loss.
        if Modes.is_infer(self.mode):
            return BridgeSpec(results=reconstruction, losses=None, loss=None)

        losses, loss = self._build_loss(incoming, reconstruction, loss_config)
        return BridgeSpec(results=reconstruction, losses=losses, loss=loss)
Пример #3
0
    def _build(self, features, labels, loss, encoder_fn, decoder_fn, *args,
               **kwargs):
        """Build the encoder/decoder graph for ``self.mode``.

        * ``Modes.GENERATE`` -- decode ``features`` directly.
        * ``Modes.ENCODE`` -- return the encoding of ``features``.
        * otherwise -- encode then decode ``features``; losses are built only
          when ``Modes.is_infer(self.mode)`` is false.

        Args:
            features: input features; decoded directly in ``GENERATE`` mode,
                otherwise encoded first.
            labels: forwarded to ``encode``, ``decode`` and ``_build_loss``.
            loss: loss configuration forwarded to ``self._build_loss``.
            encoder_fn: encoder callable forwarded to ``self.encode``.
            decoder_fn: decoder callable forwarded to ``self.decode``.
            *args: accepted for interface compatibility; unused here.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            ``BridgeSpec(results=..., losses=..., loss=...)``; ``losses`` and
            ``loss`` are ``None`` whenever no loss was built.
        """
        # The `loss` parameter holds the loss *configuration*, but the name is
        # reused below for the computed loss. Keep the configuration first;
        # otherwise `_build_loss` would always receive `None`.
        loss_config = loss
        losses, loss = None, None
        if Modes.GENERATE == self.mode:
            # Same keyword set as the decode call in the training branch below;
            # the stray legacy `incoming=` keyword is no longer passed.
            results = self.decode(features=features,
                                  labels=labels,
                                  decoder_fn=decoder_fn)
        elif Modes.ENCODE == self.mode:
            results = self.encode(features=features,
                                  labels=labels,
                                  encoder_fn=encoder_fn)
        else:
            x = self.encode(features=features,
                            labels=labels,
                            encoder_fn=encoder_fn)
            results = self.decode(features=x,
                                  labels=labels,
                                  decoder_fn=decoder_fn)
            # Inference modes need only the reconstruction, not a loss.
            if not Modes.is_infer(self.mode):
                losses, loss = self._build_loss(results, features, labels,
                                                loss_config)

        return BridgeSpec(results=results, losses=losses, loss=loss)