Example no. 1
0
 def _inverse(self, y, **kwargs):
     """Apply the inverse act-norm transform to `y`.

     If a truthy 'init' kwarg is passed, the act-norm parameters are first
     initialized from the data batch `y` (data-dependent initialization).
     """
     if kwargs.get('init'):
         self._init_from_data(y)
     scale, bias = var(self.log_s), var(self.b)
     return act_norm(y, scale, bias, inverse=tf.constant(True))
Example no. 2
0
 def train(self, train_data: tf.data.Dataset, steps_per_epoch, num_epochs=1,
           lam=1.0, lam_decay=0.0, alpha=0.0, **flow_kwargs):
     """Adversarially train the discriminators and generators on batches from `train_data`.

     Args:
         train_data: dataset yielding (x, y) batch pairs.
         steps_per_epoch: number of batches consumed per epoch.
         num_epochs: number of passes over the `steps_per_epoch` batches.
         lam: initial weight for the generator objective term; decayed in place each epoch.
         lam_decay: amount subtracted from `lam` after every epoch.
         alpha: auxiliary weight forwarded to `train_generators_on_batch`.
         **flow_kwargs: accepted for interface compatibility; not used here.

     Returns:
         dict of accumulated training metrics (see `utils.update_metrics`).
     """
     train_gen_data = train_data.take(steps_per_epoch).repeat(num_epochs)
     with tqdm(total=steps_per_epoch*num_epochs, desc='train') as prog:
         hist = dict()
         # Wrap lam in a tf.Variable so it can be decayed in place across epochs.
         lam = tf.Variable(lam, dtype=tf.float32)
         for epoch in range(num_epochs):
             for x, y in train_gen_data.take(steps_per_epoch):
                 # train discriminators
                 dx_loss, dy_loss = self.train_discriminators_on_batch(x, y)
                 # train generators
                 g_obj, nll_x, nll_y, gx_loss, gy_loss, gx_aux, gy_aux = self.train_generators_on_batch(x, y, alpha=alpha, lam=utils.var(lam))
                 # BUG FIX: previously logged dy_loss (discriminator-y loss) under the
                 # 'gy_loss' key; record the generator-y loss instead.
                 utils.update_metrics(hist, g_obj=g_obj.numpy(), gx_loss=gx_loss.numpy(), gy_loss=gy_loss.numpy(),
                                      nll_x=nll_x.numpy(), nll_y=nll_y.numpy())
                 prog.update(1)
                 prog.set_postfix(utils.get_metrics(hist))
             # anneal the generator objective weight once per epoch
             lam.assign_sub(lam_decay)
     return hist
Example no. 3
0
 def _inverse(self, y, **kwargs):
     """Invert the 1x1 convolution on `y` using the stored LU-factored weights."""
     lu_params = (var(self.L), var(self.U), var(self.P), var(self.log_d), var(self.sgn_d))
     return invertible_1x1_conv(y, *lu_params, inverse=tf.constant(True))
Example no. 4
0
 def _forward(self, x, **kwargs):
     """Apply the forward act-norm transform to `x`.

     A truthy 'init' kwarg triggers data-dependent initialization of the
     act-norm parameters from the batch `x` before the transform is applied.
     """
     if kwargs.get('init'):
         self._init_from_data(x)
     scale, bias = var(self.log_s), var(self.b)
     return act_norm(x, scale, bias)
Example no. 5
0
 def _forward(self, x, **kwargs):
     """Apply the forward 1x1 convolution to `x` using the stored LU-factored weights."""
     lu_params = (var(self.L), var(self.U), var(self.P), var(self.log_d), var(self.sgn_d))
     return invertible_1x1_conv(x, *lu_params)