예제 #1
0
 def _condition_conv(self, inputs, inputs_extra, layer_sizes):
     """Applies a feature-wise scale-and-shift conditioning to `inputs`.

     Two separate MLPs map the flattened `inputs_extra` to a per-channel scale
     and a per-channel shift. These are broadcast over the spatial axes of
     `inputs` and combined as `scales * inputs + shifts`.

     Args:
       inputs: Feature tensor whose last dimension is the channel dimension.
         The two `expand_dims` calls assume it is rank 4 (presumably
         NHWC) -- TODO confirm.
       inputs_extra: Conditioning tensor; flattened (all but the batch axis)
         before being fed to the MLPs.
       layer_sizes: Hidden-layer sizes for each MLP. Any sequence of ints
         (list, tuple, ...) is accepted; the channel count of `inputs` is
         appended as the output size.

     Returns:
       `scales * inputs + shifts`, where `scales`/`shifts` have first been
       passed through `self._conditioning_postprocessing` when it is set.
     """
     channels = inputs.get_shape().as_list()[-1]
     inputs_extra_flat = tf.keras.layers.Flatten()(inputs_extra)
     # list() so tuples and other sequences work: `tuple + list` raises a
     # TypeError and a numpy array would silently add `channels` elementwise.
     # A fresh list is built for each MLP so neither call can observe the
     # other's argument.
     scales = utils.mlp(inputs_extra_flat, list(layer_sizes) + [channels])
     shifts = utils.mlp(inputs_extra_flat, list(layer_sizes) + [channels])
     # Insert two singleton axes before the channel axis, presumably
     # (batch, channels) -> (batch, 1, 1, channels), so the per-channel values
     # broadcast over the spatial dimensions of `inputs`.
     scales = tf.expand_dims(tf.expand_dims(scales, -2), -2)
     shifts = tf.expand_dims(tf.expand_dims(shifts, -2), -2)
     if self._conditioning_postprocessing is not None:
         scales, shifts = self._conditioning_postprocessing(scales, shifts)
     return scales * inputs + shifts
예제 #2
0
 def _call_conv(self, inputs, inputs_extra):
     """Runs the downsampling conv stack, optionally conditioned on extra inputs.

     Block `i` has `self._base_num_channels * 2**i` channels. Conditioning on
     `inputs_extra` (when it is not None) follows `self._conditioning_type`:
       * "concat": concatenated before every block;
       * "input": concatenated before block 0 only;
       * "mult_and_add": scale/shift via `_condition_conv` after every block.

     Args:
       inputs: Input tensor for the first conv block.
       inputs_extra: Optional conditioning tensor, or None.

     Returns:
       A tuple `(output, endpoints)`. `endpoints` holds, per block, the tensors
       "conv_block_before_cond_i" (only for "mult_and_add"),
       "conv_block_before_nonlin_i" and "conv_block_i". When
       `self._fc_layer_sizes` is set, `output` is the MLP applied to the
       flattened last block; otherwise it is the last block itself.
     """
     has_extra = inputs_extra is not None
     cond_type = self._conditioning_type
     apply_film = has_extra and cond_type == "mult_and_add"
     last_block = self._num_blocks - 1

     endpoints = {}
     net = inputs
     for block_idx in range(self._num_blocks):
         concat_here = has_extra and (
             cond_type == "concat"
             or (cond_type == "input" and block_idx == 0))
         if concat_here:
             net = broadcast_and_concat(net, inputs_extra)
         width = self._base_num_channels * 2**block_idx
         with tf.variable_scope("conv_block_{}".format(block_idx)):
             net = self._conv_block(net, self._layers_per_block, width)
             if apply_film:
                 endpoints["conv_block_before_cond_{}".format(block_idx)] = net
                 with tf.variable_scope("conditioning"):
                     net = self._condition_conv(
                         net, inputs_extra, self._conditioning_layer_sizes)
             endpoints["conv_block_before_nonlin_{}".format(block_idx)] = net
             # The last block is left pre-activation.
             # NOTE(review): this also holds when FC layers follow below, so
             # the conv output feeds the MLP without a nonlinearity -- confirm
             # that is intended.
             if block_idx < last_block:
                 net = self._nonlinearity(net)
         endpoints["conv_block_{}".format(block_idx)] = net

     if self._fc_layer_sizes is not None:
         net = utils.mlp(tf.keras.layers.Flatten()(net), self._fc_layer_sizes)
     return net, endpoints
예제 #3
0
 def _call_upconv(self, inputs, inputs_extra):
     """Runs the upsampling (decoder) stack, optionally conditioned.

     If `self._fc_layer_sizes` is set, `inputs` is first flattened, passed
     through an MLP and reshaped to `[-1] + self._upconv_reshape_size`. Upconv
     blocks are then applied from `nblock = num_blocks - 1` down to 0. Block
     `nblock` has `self._base_num_channels * 2**nblock` channels. A final 3x3
     conv maps to `self._channels_out` channels.

     Args:
       inputs: Input tensor.
       inputs_extra: Optional conditioning tensor, or None. Used according to
         `self._conditioning_type` ("concat", "input" or "mult_and_add").

     Returns:
       A tuple `(x, endpoints)`. `x` is the output of the final conv.
       `endpoints` holds, per block, "upconv_block_before_cond_i" (only for
       "mult_and_add"), "upconv_block_before_nonlin_i" and "upconv_block_i".
     """
     endpoints = {}
     x = inputs
     if self._fc_layer_sizes is not None:
         x = tf.keras.layers.Flatten()(x)
         x = utils.mlp(x, self._fc_layer_sizes)
         # list() so the reshape size may be given as a tuple as well as a
         # list; `[-1] + tuple` would raise a TypeError.
         x = tf.reshape(x, [-1] + list(self._upconv_reshape_size))
     # Widest block first: nblock runs num_blocks - 1, ..., 0.
     for nblock in reversed(range(self._num_blocks)):
         num_channels = self._base_num_channels * (2**nblock)
         # NOTE(review): iteration is descending, so nblock == 0 is the *last*
         # block here. "input" conditioning is therefore concatenated just
         # before the final upconv block, not at the decoder input (in
         # `_call_conv` nblock == 0 is the first block). Confirm intended.
         if (inputs_extra is not None and
                 (self._conditioning_type == "concat" or
                  (self._conditioning_type == "input" and nblock == 0))):
             x = broadcast_and_concat(x, inputs_extra)
         with tf.variable_scope("upconv_block_{}".format(nblock)):
             x = self._upconv_block(x, self._layers_per_block, num_channels)
             if (inputs_extra is not None
                     and self._conditioning_type == "mult_and_add"):
                 endpoints["upconv_block_before_cond_{}".format(nblock)] = x
                 with tf.variable_scope("conditioning"):
                     x = self._condition_conv(
                         x, inputs_extra, self._conditioning_layer_sizes)
             endpoints["upconv_block_before_nonlin_{}".format(nblock)] = x
             # Every upconv block, including the last, gets a nonlinearity.
             # The output conv below has no activation argument, so it is the
             # linear output layer.
             x = self._nonlinearity(x)
         endpoints["upconv_block_{}".format(nblock)] = x
     # Final layer that outputs the required number of channels.
     x = tf.layers.conv2d(inputs=x,
                          filters=self._channels_out,
                          kernel_size=3,
                          strides=(1, 1),
                          padding="same",
                          kernel_initializer=self._kernel_initializer)
     return x, endpoints