Exemple #1
0
    def get_loss(self, x, z):
        """Run the three loss tensors for the given inputs in inference mode.

        `x` and `z` are fed under the keys returned by tbn('x:0') and
        tbn('z:0'); the key from tbn('is_training:0') is fed False.

        Returns whatever `self.sess.run` yields for the fetch list
        [loss_G, loss_Dreal, loss_Dfake], in that order.
        """
        # NOTE(review): tbn presumably resolves a graph tensor by name -- confirm.
        inputs = {
            tbn('x:0'): x,
            tbn('z:0'): z,
            tbn('is_training:0'): False,
        }
        fetches = [self.loss_G, self.loss_Dreal, self.loss_Dfake]

        return self.sess.run(fetches, feed_dict=inputs)
Exemple #2
0
    def sample(self, z):
        """Run the 'samples:0' fetch for `z` with the training flag fed as False.

        Returns the value produced by `self.sess.run` for tbn('samples:0').
        """
        # Key lookups are kept in the original order (flag first, then z).
        inputs = {
            tbn('is_training:0'): False,
            tbn('z:0'): z,
        }

        return self.sess.run(tbn('samples:0'), feed_dict=inputs)
Exemple #3
0
    def train(self):
        """Advance the iteration counter and run one D step, then one G step.

        Both steps receive the same feed: `self.args.learning_rate` under the
        tbn('lr:0') key and True under the tbn('is_training:0') key.
        """
        self.iteration += 1

        lr_key = tbn('lr:0')
        lr_value = self.args.learning_rate
        training_key = tbn('is_training:0')
        step_feed = {lr_key: lr_value, training_key: True}

        # 'train_op_D' is run before 'train_op_G', each in its own session call.
        for op_name in ('train_op_D', 'train_op_G'):
            self.sess.run([obn(op_name)], feed_dict=step_feed)
Exemple #4
0
    def train(self, x=None, z=None, learning_rate=.001):
        """Run one update of `update_op_D` followed by one of `update_op_G`.

        Args:
            x: Value to feed under the tbn('x:0') key, or None to leave it unfed.
            z: Value to feed under the tbn('z:0') key, or None to leave it unfed.
            learning_rate: Value fed under the tbn('learning_rate:0') key.

        The same feed dict (with the tbn('is_training:0') key set to True) is
        used for both session calls.
        """
        feed = {}
        # Compare against None explicitly: truth-testing a multi-element NumPy
        # array raises ValueError, and an all-zero single-element batch would
        # otherwise be silently dropped from the feed.
        if x is not None:
            feed[tbn('x:0')] = x
        if z is not None:
            feed[tbn('z:0')] = z
        feed[tbn('learning_rate:0')] = learning_rate
        feed[tbn('is_training:0')] = True

        self.sess.run(self.update_op_D, feed_dict=feed)
        self.sess.run(self.update_op_G, feed_dict=feed)
Exemple #5
0
    def get_loss(self, xb1, xb2):
        """Evaluate every entry of the 'losses' collection and format the results.

        `xb1` and `xb2` are fed under the tbn('xb1:0') and tbn('xb2:0') keys,
        with the tbn('is_training:0') key fed False.

        Returns a single string with each loss formatted to three decimal
        places, separated by single spaces.
        """
        inputs = {
            tbn('xb1:0'): xb1,
            tbn('xb2:0'): xb2,
            tbn('is_training:0'): False,
        }
        loss_tensors = tf.get_collection('losses')
        values = self.sess.run(loss_tensors, feed_dict=inputs)

        return ' '.join(map('{:.3f}'.format, values))
Exemple #6
0
    def get_layer(self, xb1, xb2, name):
        """Run the tensor called `name` (output index 0) for the given inputs.

        `xb1` and `xb2` are fed under the tbn('xb1:0') and tbn('xb2:0') keys,
        with the tbn('is_training:0') key fed False.

        Returns the value produced by `self.sess.run` for that tensor.
        """
        # Resolve the requested tensor first, before building the feed.
        target = tbn("{}:0".format(name))

        inputs = {
            tbn('xb1:0'): xb1,
            tbn('xb2:0'): xb2,
            tbn('is_training:0'): False,
        }

        return self.sess.run(target, feed_dict=inputs)