def _condition_conv(self, inputs, inputs_extra, layer_sizes):
    """Applies a per-channel scale and shift, predicted from `inputs_extra`.

    Two separate MLPs map the flattened `inputs_extra` to one scale and one
    shift per channel of `inputs`. Two singleton axes are inserted before the
    channel axis so the values broadcast over them. The result is
    `scales * inputs + shifts`.

    Args:
      inputs: Tensor whose last (channel) dimension must be statically known.
        The two inserted singleton axes suggest a rank-4
        (batch, height, width, channels) tensor -- TODO confirm with callers.
      inputs_extra: Conditioning tensor; it is flattened before the MLPs.
      layer_sizes: Sequence (list or tuple) of hidden-layer sizes for each
        MLP. An output layer of size `channels` is appended automatically.

    Returns:
      The tensor `scales * inputs + shifts`. If
      `self._conditioning_postprocessing` is set, scales and shifts are first
      passed through it.

    Raises:
      ValueError: if the last dimension of `inputs` is not statically known.
    """
    channels = inputs.get_shape().as_list()[-1]
    if channels is None:
        # The MLP output width must be a concrete int; fail loudly here rather
        # than deep inside the layer construction.
        raise ValueError(
            "_condition_conv requires a statically known channel dimension, "
            "got input shape {}.".format(inputs.get_shape()))
    inputs_extra_flat = tf.keras.layers.Flatten()(inputs_extra)
    # list() accepts tuples as well as lists (tuple + list raises TypeError).
    # A fresh list is built per call, as in the original.
    scales = utils.mlp(inputs_extra_flat, list(layer_sizes) + [channels])
    shifts = utils.mlp(inputs_extra_flat, list(layer_sizes) + [channels])
    # Insert two singleton axes before the channel axis, presumably
    # (batch, channels) -> (batch, 1, 1, channels), so that the per-channel
    # values broadcast over the spatial axes of `inputs`.
    scales = tf.expand_dims(tf.expand_dims(scales, -2), -2)
    shifts = tf.expand_dims(tf.expand_dims(shifts, -2), -2)
    if self._conditioning_postprocessing is not None:
        scales, shifts = self._conditioning_postprocessing(scales, shifts)
    return scales * inputs + shifts
def _call_conv(self, inputs, inputs_extra):
    """Builds the convolutional stack and optional fully connected head.

    Block `i` (for i = 0 .. self._num_blocks - 1) is built by
    `self._conv_block` with `self._base_num_channels * 2**i` channels, inside
    the variable scope "conv_block_{i}".

    When `inputs_extra` is not None, `self._conditioning_type` selects how it
    is injected:
      "concat":       broadcast-concatenated onto the activations before
                      every block;
      "input":        broadcast-concatenated only before block 0;
      "mult_and_add": after each block, `_condition_conv` applies a scale and
                      shift (variable scope "conditioning").

    The nonlinearity is applied after every block except the last one. If
    `self._fc_layer_sizes` is set, the result is then flattened and passed
    through an MLP.

    Args:
      inputs: Input tensor.
      inputs_extra: Optional conditioning tensor, or None.

    Returns:
      A tuple `(output, endpoints)`. `endpoints` maps names to intermediate
      tensors:
      "conv_block_before_cond_{i}" (only for "mult_and_add"),
      "conv_block_before_nonlin_{i}" and "conv_block_{i}".
    """
    endpoints = {}
    # Resolve the conditioning mode once; it does not change across blocks.
    has_extra = inputs_extra is not None
    concat_every_block = has_extra and self._conditioning_type == "concat"
    concat_first_block = has_extra and self._conditioning_type == "input"
    affine_cond = has_extra and self._conditioning_type == "mult_and_add"
    final_block = self._num_blocks - 1

    net = inputs
    for block_idx in range(self._num_blocks):
        if concat_every_block or (concat_first_block and block_idx == 0):
            net = broadcast_and_concat(net, inputs_extra)
        width = self._base_num_channels * (2**block_idx)
        with tf.variable_scope("conv_block_{}".format(block_idx)):
            net = self._conv_block(net, self._layers_per_block, width)
            if affine_cond:
                endpoints["conv_block_before_cond_{}".format(block_idx)] = net
                with tf.variable_scope("conditioning"):
                    net = self._condition_conv(
                        net, inputs_extra, self._conditioning_layer_sizes)
            endpoints["conv_block_before_nonlin_{}".format(block_idx)] = net
            # The final block is left linear.
            if block_idx != final_block:
                net = self._nonlinearity(net)
            endpoints["conv_block_{}".format(block_idx)] = net

    if self._fc_layer_sizes is not None:
        net = tf.keras.layers.Flatten()(net)
        net = utils.mlp(net, self._fc_layer_sizes)
    return net, endpoints
def _call_upconv(self, inputs, inputs_extra):
    """Builds the optional MLP head, the transposed-conv stack and an output conv.

    If `self._fc_layer_sizes` is set, `inputs` is flattened, passed through
    an MLP and reshaped to `[-1] + self._upconv_reshape_size`.

    Blocks are then built in descending index order,
    i = self._num_blocks - 1 down to 0. Block `i` is built by
    `self._upconv_block` with `self._base_num_channels * 2**i` channels,
    inside the variable scope "upconv_block_{i}", and is always followed by
    the nonlinearity.

    When `inputs_extra` is not None, `self._conditioning_type` selects how it
    is injected:
      "concat":       broadcast-concatenated before every block;
      "input":        broadcast-concatenated only before block 0;
      "mult_and_add": after each block, `_condition_conv` applies a scale and
                      shift (variable scope "conditioning").

    A final 3x3 stride-1 "same" conv maps the result to `self._channels_out`
    channels.

    Args:
      inputs: Input tensor.
      inputs_extra: Optional conditioning tensor, or None.

    Returns:
      A tuple `(output, endpoints)`. `endpoints` maps names to intermediate
      tensors:
      "upconv_block_before_cond_{i}" (only for "mult_and_add"),
      "upconv_block_before_nonlin_{i}" and "upconv_block_{i}".
    """
    endpoints = {}
    x = inputs
    if self._fc_layer_sizes is not None:
        x = tf.keras.layers.Flatten()(x)
        x = utils.mlp(x, self._fc_layer_sizes)
        # list() accepts a tuple reshape size as well as a list
        # (list + tuple raises TypeError).
        x = tf.reshape(x, [-1] + list(self._upconv_reshape_size))
    for nblock in reversed(range(self._num_blocks)):
        num_channels = self._base_num_channels * (2**nblock)
        # NOTE(review): blocks are visited in reverse, so nblock == 0 is the
        # LAST upconv block. "input" conditioning is therefore injected just
        # before the final block, not at the decoder input. Confirm that this
        # is intended (it mirrors _call_conv by block index, not by position).
        if (inputs_extra is not None and
                (self._conditioning_type == "concat" or
                 (self._conditioning_type == "input" and nblock == 0))):
            x = broadcast_and_concat(x, inputs_extra)
        with tf.variable_scope("upconv_block_{}".format(nblock)):
            x = self._upconv_block(x, self._layers_per_block, num_channels)
            if (inputs_extra is not None and
                    self._conditioning_type == "mult_and_add"):
                endpoints["upconv_block_before_cond_{}".format(nblock)] = x
                with tf.variable_scope("conditioning"):
                    x = self._condition_conv(
                        x, inputs_extra, self._conditioning_layer_sizes)
            endpoints["upconv_block_before_nonlin_{}".format(nblock)] = x
            # Unlike _call_conv, every block gets the nonlinearity because the
            # output conv below follows the stack.
            x = self._nonlinearity(x)
            endpoints["upconv_block_{}".format(nblock)] = x
    # final layer that outputs the required number of channels
    x = tf.layers.conv2d(inputs=x,
                         filters=self._channels_out,
                         kernel_size=3,
                         strides=(1, 1),
                         padding="same",
                         kernel_initializer=self._kernel_initializer)
    return x, endpoints