Example #1
def _build_generator(self):
    # Reshape the flat latent vector into a 4x4 feature map, then let
    # three stride-2 transposed convolutions upsample it 4 -> 8 -> 16 -> 32.
    layer_list = [
        tf.keras.layers.Reshape([4, 4, self.latent_dim // 16]),
        layers.Conv(256, ks=4, strides=2, transpose=True),
        layers.Act('relu'),
        layers.Conv(128, ks=4, strides=2, transpose=True),
        layers.Act('relu'),
        # Output layer: FLAGS.out_dim channels, no activation.
        layers.Conv(FLAGS.out_dim, ks=4, strides=2, transpose=True)
    ]
    return self.sequential(layer_list)
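
All of these snippets build on a project-specific `layers` module rather than raw Keras layers. The source does not show that module, but the call sites (`ks`, `transpose`, `scale`) suggest thin wrappers over standard Keras layers. A minimal sketch of what they might reduce to, assuming 'same' padding so that stride-2 layers exactly halve or double spatial size; the defaults here are assumptions:

import tensorflow as tf

def Conv(filters, ks=3, strides=1, use_bias=True, transpose=False):
    # Hypothetical wrapper: dispatch between Conv2D and Conv2DTranspose.
    cls = tf.keras.layers.Conv2DTranspose if transpose else tf.keras.layers.Conv2D
    return cls(filters, kernel_size=ks, strides=strides,
               padding='same', use_bias=use_bias)

def Act(name=None):
    # Hypothetical wrapper; the BasicBlock snippet below suggests the
    # default activation actually comes from FLAGS.conv_act.
    return tf.keras.layers.Activation(name if name else 'relu')

def Norm(scale=True):
    # Hypothetical wrapper around batch normalization.
    return tf.keras.layers.BatchNormalization(scale=scale)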
Example #2
def _build_inference(self):
    # Recognition network: two stride-2 convolutions downsample the input,
    # then a Dense head emits 2 * latent_dim units (posterior mean and
    # log-variance packed into one tensor).
    layer_list = [
        layers.Conv(32, ks=3, strides=2),
        layers.Act('relu'),
        layers.Conv(64, ks=3, strides=2),
        layers.Act('relu'),
        tf.keras.layers.Flatten(),
        layers.Dense(self.latent_dim + self.latent_dim)
    ]
    return self.sequential(layer_list)
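
The `latent_dim + latent_dim` output width is the usual VAE convention of concatenating the posterior mean and log-variance. A sketch of how a caller might consume that output (the `encode` and `reparameterize` names are illustrative, not from the source):

import tensorflow as tf

def encode(inference_net, x):
    # Split the 2*latent_dim output into posterior parameters.
    mean, logvar = tf.split(inference_net(x), num_or_size_splits=2, axis=-1)
    return mean, logvar

def reparameterize(mean, logvar):
    # Sample z = mean + sigma * eps with eps ~ N(0, I), keeping the
    # sampling step differentiable with respect to mean and logvar.
    eps = tf.random.normal(shape=tf.shape(mean))
    return mean + tf.exp(0.5 * logvar) * eps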
Example #3
def _build_encoder(self):
    # Four stride-2 convolutions halve the resolution at each step while
    # widening the channels 16 -> 64 -> 256 -> 512; every conv except the
    # last is followed by normalization.
    layer_list = [
        layers.Conv(16, ks=4, strides=2),
        layers.Norm(),
        layers.Conv(64, ks=4, strides=2),
        layers.Norm(),
        layers.Conv(256, ks=4, strides=2),
        layers.Norm(),
        layers.Conv(512, ks=4, strides=2)
    ]
    return self.sequential(layer_list)
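
With 'same' padding, each stride-2 convolution halves the spatial size, so a 64x64 input would come out as a 4x4x512 feature map. A self-contained shape check using plain Keras layers in place of the project wrappers (the 64x64x3 input is illustrative, not from the source):

import tensorflow as tf

enc = tf.keras.Sequential([
    tf.keras.layers.Conv2D(f, 4, strides=2, padding='same')
    for f in (16, 64, 256, 512)
])
# Downsampling path: 64 -> 32 -> 16 -> 8 -> 4.
print(enc(tf.zeros([1, 64, 64, 3])).shape)  # (1, 4, 4, 512)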
Example #4
def _build_generator(self):
    # Project the latent code to a 7x7x32 map, upsample 7 -> 14 -> 28 with
    # two stride-2 transposed convolutions, then map to FLAGS.out_dim
    # channels with a stride-1 layer (no final activation).
    layer_list = [
        layers.Dense(7 * 7 * 32),
        layers.Act('relu'),
        tf.keras.layers.Reshape([7, 7, 32]),
        layers.Conv(64, ks=3, strides=2, transpose=True),
        layers.Act('relu'),
        layers.Conv(32, ks=3, strides=2, transpose=True),
        layers.Act('relu'),
        layers.Conv(FLAGS.out_dim, ks=3, strides=1, transpose=True)
    ]
    return self.sequential(layer_list)
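
The shape arithmetic here targets 28x28 outputs (MNIST-sized). A quick check with plain Keras layers, assuming 'same' padding, a 100-dimensional latent code, and one output channel (all three are assumptions for illustration):

import tensorflow as tf

gen = tf.keras.Sequential([
    tf.keras.layers.Dense(7 * 7 * 32, activation='relu'),
    tf.keras.layers.Reshape([7, 7, 32]),
    tf.keras.layers.Conv2DTranspose(64, 3, strides=2, padding='same', activation='relu'),
    tf.keras.layers.Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
    # Stride-1 output layer: only changes the channel count.
    tf.keras.layers.Conv2DTranspose(1, 3, strides=1, padding='same'),
])
print(gen(tf.zeros([1, 100])).shape)  # (1, 28, 28, 1)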
Example #5
def __init__(self,
             filters,
             ks,
             strides=1,
             preact=False,
             use_norm=True,
             use_act=True,
             use_bias=False,
             last_norm=False,
             transpose=False):
    super(BasicBlock, self).__init__(filters, strides=strides)
    self.preact = preact
    self.use_norm = use_norm
    self.use_act = use_act
    self.last_norm = last_norm
    if self.use_norm:
        # BatchNorm's learned scale (gamma) is redundant when the
        # normalization is immediately followed by a ReLU-family
        # activation (it can be absorbed into the next layer).
        norm_scale = 'relu' not in FLAGS.conv_act.lower()
        self.bn = layers.Norm(scale=norm_scale)
    if self.use_act:
        self.act = layers.Act()
    self.conv = layers.Conv(filters,
                            ks,
                            strides=strides,
                            use_bias=use_bias,
                            transpose=transpose)
    if self.last_norm:
        self.last_bn = layers.Norm()
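
The constructor only stores flags and sub-layers; the source does not show the forward pass. A plausible `call` would order the sub-layers according to `preact`, following the standard pre-activation pattern from ResNet v2. A sketch under that assumption:

def call(self, x, training=None):
    # Hypothetical forward pass matching the flags stored in __init__.
    if self.preact:
        # Pre-activation ordering (ResNet v2 style): norm -> act -> conv.
        if self.use_norm:
            x = self.bn(x, training=training)
        if self.use_act:
            x = self.act(x)
        x = self.conv(x)
    else:
        # Post-activation ordering: conv -> norm -> act.
        x = self.conv(x)
        if self.use_norm:
            x = self.bn(x, training=training)
        if self.use_act:
            x = self.act(x)
    if self.last_norm:
        x = self.last_bn(x, training=training)
    return x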