def encode(self, reflectance_map, background):
    """Fuse a reflectance map and a background into a 3-channel prediction.

    Both inputs go through ``self.singlet``, which returns an indexable
    sequence of feature maps. The last map of each branch is concatenated on
    the channel axis to form a joint bottleneck. The bottleneck is then
    decoded in four stages. Stages 2-4 receive skip connections from the
    *reflectance-map* branch only (indices 3, 2, 1); background features
    reach the decoder solely through the bottleneck.

    Args:
        reflectance_map: Input fed to the first ``singlet`` call.
        background: Input fed to the second ``singlet`` call.

    Returns:
        Output of a linear 1x1 ``encode_layer`` with 3 filters
        (``activation=None``, ``norm=False``).
    """
    # Call order matters for graph/variable construction: reflectance first.
    rm_feats = self.singlet(reflectance_map)
    bg_feats = self.singlet(background)

    joint = tf.concat([rm_feats[-1], bg_feats[-1]], axis=-1)
    # Dropout on the joint bottleneck is disabled here; re-enable with
    # tf.nn.dropout(joint, 0.5) if needed.
    joint = layers.encode_layer(joint, 1024, (3, 3), (1, 1), 1, maxpool=False)

    net = layers.decode_layer(joint, 512, (3, 3), (2, 2), 3)
    # Each entry: (index into rm_feats for the skip, filters, depth).
    skip_stages = ((3, 512, 3), (2, 256, 3), (1, 128, 1))
    for skip_idx, filters, depth in skip_stages:
        with_skip = tf.concat([net, rm_feats[skip_idx]], axis=-1)
        net = layers.decode_layer(with_skip, filters, (3, 3), (2, 2), depth)

    return layers.encode_layer(
        net, 3, (1, 1), (1, 1), 1,
        activation=None, norm=False, maxpool=False)
def decode(self, encoded):
    """Decode ``encoded`` through five decoder stages into 3 channels.

    Every stage calls ``layers.decode_layer`` with a 3x3 kernel, stride
    (2, 2) and ``FLAGS.depth``; the filter counts shrink 512 -> 512 -> 256
    -> 128 -> 64. (Whether stride (2, 2) means a 2x upsample depends on
    ``layers.decode_layer`` -- confirm there.)

    Args:
        encoded: Tensor fed to the first decoder stage.

    Returns:
        Output of a linear 1x1 ``encode_layer`` with 3 filters
        (``activation=None``, ``norm=False``, ``maxpool=False``).
    """
    net = encoded
    for filters in (512, 512, 256, 128, 64):
        net = layers.decode_layer(net, filters, (3, 3), (2, 2), FLAGS.depth)
    # Final projection to three output channels with no nonlinearity.
    return layers.encode_layer(
        net, 3, (1, 1), (1, 1), 1,
        activation=None, norm=False, maxpool=False)
def decode(self, full, feature_maps, batch_size):
    """Decode ``full`` into a 3-channel output using three skip feature maps.

    Args:
        full: Tensor concatenated with ``feature_maps[0]`` on the last axis
            before the first decoder stage.
        feature_maps: Sequence of at least three tensors used as skip inputs.
            ``feature_maps[0]`` is used twice: concatenated onto ``full``
            as-is, and reshaped to ``[batch_size, 8, 8, -1]`` as the skip
            for the second stage. ``feature_maps[1]`` and ``[2]`` are the
            skips for the third and fourth stages.
        batch_size: Leading dimension used in that reshape.

    Returns:
        Output of the last ``ed.decode_layer`` stage (3 filters, stride
        (1, 1)).

    Raises:
        ValueError: If ``feature_maps`` is None or has fewer than three
            entries.
    """
    # Every stage below indexes feature_maps[0..2]; fail fast with a clear
    # message instead of an IndexError partway through graph construction.
    if feature_maps is None or len(feature_maps) < 3:
        count = 0 if feature_maps is None else len(feature_maps)
        raise ValueError(
            "decode expects at least 3 feature maps, got {}".format(count))

    full_2 = tf.concat([full, feature_maps[0]], axis=-1)
    # Author's shape note for this output: 2,8,8,512 (presumably batch 2 --
    # TODO confirm).
    decode_1 = ed.decode_layer(full_2, 512, (3, 3), (8, 8))
    fm_1 = tf.reshape(feature_maps[0], [batch_size, 8, 8, -1])
    decode_2 = ed.decode_layer(
        tf.concat([decode_1, fm_1], axis=-1), 256, (3, 3), (2, 2))
    decode_3 = ed.decode_layer(
        tf.concat([decode_2, feature_maps[1]], axis=-1), 128, (3, 3), (2, 2))
    decode_4 = ed.decode_layer(
        tf.concat([decode_3, feature_maps[2]], axis=-1), 3, (3, 3), (1, 1))
    return decode_4
def decode(self, encoded, multiscale):
    """Decode ``encoded`` through four decoder stages into 3 channels.

    When ``FLAGS.multiscale`` is truthy, ``multiscale[-1]``, ``[-2]`` and
    ``[-3]`` are concatenated (last axis) onto the inputs of stages 1, 2 and
    3 respectively; stage 4 never receives a skip. Note the switch is the
    global flag, not the ``multiscale`` argument itself.

    Args:
        encoded: Tensor fed to the first decoder stage.
        multiscale: Indexable sequence of skip tensors; only read when
            ``FLAGS.multiscale`` is set.

    Returns:
        Output of a linear 1x1 ``encode_layer`` with 3 filters
        (``activation=None``, ``norm=False``, ``maxpool=False``).
    """
    use_skips = FLAGS.multiscale
    net = encoded
    for stage, filters in enumerate((512, 512, 256, 128)):
        if use_skips and stage < 3:
            # Deepest skip first: stage 0 -> multiscale[-1], etc.
            net = tf.concat([net, multiscale[-(stage + 1)]], axis=-1)
        net = layers.decode_layer(net, filters, (3, 3), (2, 2), FLAGS.depth)
    return layers.encode_layer(
        net, 3, (1, 1), (1, 1), 1,
        activation=None, norm=False, maxpool=False)
def decode_norms(self, encodings):
    """Decode ``encodings`` into a 3-channel map through three decoder stages.

    The stages use a 3x3 kernel and ``FLAGS.depth`` with
    (filters, stride) = (512, 1x1), (256, 2x2), (3, 1x1).

    Args:
        encodings: Tensor fed to the first decoder stage.

    Returns:
        Output of the final ``layers.decode_layer`` stage (3 filters).
    """
    # NOTE(review): unlike the sibling decoders, the 3-channel output comes
    # from decode_layer with its default activation/norm rather than a linear
    # encode_layer projection -- confirm that is intended for normals.
    stages = ((512, (1, 1)), (256, (2, 2)), (3, (1, 1)))
    out = encodings
    for filters, stride in stages:
        out = layers.decode_layer(out, filters, (3, 3), stride, FLAGS.depth)
    return out