Beispiel #1
0
    def encode(self, reflectance_map, background):
        """Encode both inputs with the shared ``singlet`` and decode a 3-channel map.

        ``self.singlet`` is applied to ``reflectance_map`` and then to
        ``background``; each result is indexed like a sequence of feature
        maps. The last entries of both are concatenated on the final axis and
        passed through one ``layers.encode_layer`` call (1024, no max-pooling).
        Four ``layers.decode_layer`` calls follow; the last three receive the
        previous output concatenated with reflectance-map encodings 3, 2 and 1.
        Only the last background encoding is used.

        Args:
            reflectance_map: input forwarded to ``self.singlet``.
            background: input forwarded to ``self.singlet``.

        Returns:
            The output of the closing ``layers.encode_layer`` call, made with
            ``activation=None``, ``norm=False`` and ``maxpool=False``
            (presumably a 1x1 projection to 3 channels; confirm against
            ``layers.encode_layer``).
        """
        kernel = (3, 3)
        step = (2, 2)

        rm_pyramid = self.singlet(reflectance_map)
        bg_pyramid = self.singlet(background)

        # Fuse the deepest encodings of both inputs along the last axis.
        merged = tf.concat([rm_pyramid[-1], bg_pyramid[-1]], axis=-1)
        #merged = tf.nn.dropout(merged, 0.5)
        merged = layers.encode_layer(merged, 1024, kernel, (1, 1), 1,
                                     maxpool=False)

        # First decode stage takes the fused tensor alone.
        decoded = layers.decode_layer(merged, 512, kernel, step, 3)

        # Remaining stages: (reflectance-map encoding index, second positional
        # argument, last positional argument) for each decode_layer call.
        skip_stages = ((3, 512, 3), (2, 256, 3), (1, 128, 1))
        for skip_index, channels, depth in skip_stages:
            decoded = layers.decode_layer(
                tf.concat([decoded, rm_pyramid[skip_index]], axis=-1),
                channels, kernel, step, depth)

        return layers.encode_layer(decoded, 3, (1, 1), (1, 1), 1,
                                   activation=None,
                                   norm=False,
                                   maxpool=False)
Beispiel #2
0
 def decode(self, encoded):
     """Run ``encoded`` through five decode layers and a final 3-output layer.

     Each stage calls ``layers.decode_layer`` with (3, 3), (2, 2) and
     ``FLAGS.depth``; the second positional argument follows the schedule
     512, 512, 256, 128, 64 (presumably the channel count; confirm against
     ``layers.decode_layer``).

     Args:
         encoded: tensor handed to the first decode layer.

     Returns:
         The result of ``layers.encode_layer`` applied to the last decode
         output with ``activation=None``, ``norm=False``, ``maxpool=False``.
     """
     channel_schedule = (512, 512, 256, 128, 64)

     current = encoded
     for channels in channel_schedule:
         current = layers.decode_layer(current, channels, (3, 3), (2, 2),
                                       FLAGS.depth)

     return layers.encode_layer(current, 3, (1, 1), (1, 1), 1,
                                activation=None,
                                norm=False,
                                maxpool=False)
Beispiel #3
0
 def decode(self, full, feature_maps, batch_size):
     """Decode ``full`` with four ``ed.decode_layer`` calls and three skip inputs.

     ``feature_maps[0]`` is used twice: concatenated with ``full`` before the
     first decode layer, and again (reshaped to ``[batch_size, 8, 8, -1]``)
     before the second. ``feature_maps[1]`` and ``feature_maps[2]`` are
     concatenated before the third and fourth decode layers.

     Args:
         full: tensor concatenated with ``feature_maps[0]`` on the last axis.
         feature_maps: sequence with at least three entries (indices 0..2
             are read).
         batch_size: leading dimension used when reshaping
             ``feature_maps[0]``.

     Returns:
         The output of the fourth ``ed.decode_layer`` call.

     Raises:
         ValueError: if ``feature_maps`` holds fewer than three entries.
     """
     # Validate before touching any tensor: indices 0..2 are all required
     # below, so a short sequence can never produce a result. Failing here
     # gives one clear error instead of an IndexError part-way through graph
     # construction (or an unbound ``decode_4`` at the return).
     if len(feature_maps) < 3:
         raise ValueError(
             "decode() requires at least 3 feature maps, got %d"
             % len(feature_maps))

     full_2 = tf.concat([full, feature_maps[0]], axis=-1)
     decode_1 = ed.decode_layer(full_2, 512, (3, 3), (8, 8))  # author's shape note: 2,8,8,512

     # Reshape the first feature map so it can be concatenated with decode_1.
     fm_1 = tf.reshape(feature_maps[0], [batch_size, 8, 8, -1])
     decode_2 = ed.decode_layer(tf.concat([decode_1, fm_1], axis=-1), 256, (3, 3), (2, 2))
     decode_3 = ed.decode_layer(tf.concat([decode_2, feature_maps[1]], axis=-1), 128, (3, 3), (2, 2))
     decode_4 = ed.decode_layer(tf.concat([decode_3, feature_maps[2]], axis=-1), 3, (3, 3), (1, 1))
     return decode_4
Beispiel #4
0
 def decode(self, encoded, multiscale):
     """Decode ``encoded`` through four decode layers and a final 3-output layer.

     When ``FLAGS.multiscale`` is truthy, the input of each of the first
     three decode layers is concatenated (last axis) with ``multiscale[-1]``,
     ``multiscale[-2]`` and ``multiscale[-3]`` respectively. Otherwise
     ``multiscale`` is not read.

     Both modes run the fourth (128) decode layer, so the multiscale path no
     longer reaches the return with ``decode_4`` unbound
     (``UnboundLocalError``).

     Args:
         encoded: tensor fed to the first decode layer.
         multiscale: indexable collection of skip tensors; only read when
             ``FLAGS.multiscale`` is truthy.

     Returns:
         The result of ``layers.encode_layer`` on the last decode output,
         called with ``activation=None``, ``norm=False``, ``maxpool=False``.
     """
     use_skips = FLAGS.multiscale

     current = encoded
     # Stages 1-3: optional skip concat with multiscale[-1], [-2], [-3],
     # followed by the decode layer for that stage.
     for level, channels in enumerate((512, 512, 256), start=1):
         if use_skips:
             current = tf.concat([current, multiscale[-level]], axis=-1)
         current = layers.decode_layer(current, channels, (3, 3), (2, 2),
                                       FLAGS.depth)

     # Stage 4 is shared by both modes and takes decode_3 directly.
     # NOTE(review): no skip tensor is concatenated here because nothing in
     # this method shows that multiscale[-4] exists or matches in shape;
     # confirm whether a fourth skip connection was intended.
     decode_4 = layers.decode_layer(current, 128, (3, 3), (2, 2),
                                    FLAGS.depth)

     return layers.encode_layer(decode_4,
                                3, (1, 1), (1, 1),
                                1,
                                activation=None,
                                norm=False,
                                maxpool=False)
Beispiel #5
0
 def decode_norms(self, encodings):
     """Pass ``encodings`` through three successive ``layers.decode_layer`` calls.

     The calls use (512, (1, 1)), (256, (2, 2)) and (3, (1, 1)) as their
     second and fourth positional arguments, always with (3, 3) as the third
     and ``FLAGS.depth`` as the fifth (presumably channels, stride, kernel
     and depth; confirm against ``layers.decode_layer``).

     Args:
         encodings: tensor handed to the first decode layer.

     Returns:
         The output of the third decode layer.
     """
     stage_specs = (
         (512, (1, 1)),
         (256, (2, 2)),
         (3, (1, 1)),
     )

     decoded = encodings
     for channels, step in stage_specs:
         decoded = layers.decode_layer(decoded, channels, (3, 3), step,
                                       FLAGS.depth)
     return decoded