def bottom(self, x):
  with tf.variable_scope(self.name):
    if not tf.executing_eagerly():
      tf.summary.image(
          "inputs", common_layers.tpu_safe_image_summary(x), max_outputs=2)
    return tf.to_float(x)
def targets_bottom(self, x):
  inputs = x
  with tf.variable_scope(self.name):
    if not tf.executing_eagerly():
      tf.summary.image(
          "targets_bottom",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=1)
    inputs_shape = common_layers.shape_list(inputs)
    if len(inputs_shape) != 4:
      raise ValueError("Assuming images given as int tensors in the format "
                       "[batch, height, width, channels] (256 values).")
    # We embed each of 256=self.top_dimensionality possible pixel values.
    embedding_var = tf.get_variable(
        "pixel_embedding",
        [self.top_dimensionality, self.PIXEL_EMBEDDING_SIZE])
    hot_inputs = tf.one_hot(tf.to_int32(inputs), self.top_dimensionality)
    hot_inputs = tf.reshape(hot_inputs, [-1, self.top_dimensionality])
    embedded = tf.matmul(hot_inputs, embedding_var)
    # Let's now merge all channels that were embedded into a single vector.
    merged_size = self.PIXEL_EMBEDDING_SIZE * inputs_shape[3]
    embedded = tf.reshape(embedded, inputs_shape[:3] + [merged_size])
    merged = tf.layers.dense(
        embedded,
        self._body_input_depth,
        name="merge_pixel_embedded_channels")
    return merged
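# The one_hot + matmul in targets_bottom above is an embedding lookup written
# in matmul form. A minimal standalone sketch of the equivalence (the names
# and toy sizes below are assumptions for illustration, not part of the
# modality code):
def embedding_lookup_equivalence_sketch():
  vocab_size, embedding_size = 256, 64
  pixels = tf.constant([[3, 200], [0, 255]])  # Integer pixel values.
  table = tf.get_variable("pixel_embedding_demo",
                          [vocab_size, embedding_size])
  # Matmul form, as above: each one-hot row selects one row of the table.
  hot = tf.reshape(tf.one_hot(pixels, vocab_size), [-1, vocab_size])
  via_matmul = tf.matmul(hot, table)
  # Gather form: the same embeddings via tf.nn.embedding_lookup.
  via_gather = tf.reshape(
      tf.nn.embedding_lookup(table, pixels), [-1, embedding_size])
  return via_matmul, via_gather  # Identical values, shape [4, 64].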
def image_summary(self, name, image_logits, max_outputs=1):
  """Helper for image summaries that are safe on TPU."""
  if len(image_logits.get_shape()) != 5:
    tf.logging.info("Not generating image summary, maybe not an image.")
    return
  return tf.summary.image(
      name,
      common_layers.tpu_safe_image_summary(tf.argmax(image_logits, -1)),
      max_outputs=max_outputs)
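# image_summary above expects 5-D logits [batch, height, width, channels,
# vocab] and summarizes the argmax over the last axis; any other rank is
# skipped with a log message. A minimal usage sketch (shapes and the
# `modality` argument are illustrative assumptions):
def image_summary_usage_sketch(modality):
  # Two 8x8 RGB images with 256-way per-channel logits: summarized.
  modality.image_summary("reconstruction", tf.zeros([2, 8, 8, 3, 256]),
                         max_outputs=2)
  # A 4-D tensor is skipped, so the helper is safe to call from code paths
  # that may or may not be producing images.
  modality.image_summary("not_an_image", tf.zeros([2, 8, 8, 3]))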
def top(self, body_output, _):
  num_channels = self._model_hparams.problem.num_channels
  num_frames = self._model_hparams.video_num_target_frames
  with tf.variable_scope("rgb"):
    body_output_shape = common_layers.shape_list(body_output)
    res = tf.layers.dense(body_output, num_channels * num_frames, name="cast")
    res = tf.reshape(res, body_output_shape[:3] + [num_channels, num_frames])
    res = tf.transpose(res, [0, 4, 1, 2, 3])  # Move frames next to batch.
    if not tf.get_variable_scope().reuse:
      res_argmax = res[:, -1, :, :, :]
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return tf.expand_dims(res, axis=-1)  # Add an axis like in perplexity.
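# Shape walkthrough for the rgb top above (concrete sizes are illustrative
# assumptions): with body_output [batch=2, h=8, w=8, hidden], num_channels=3
# and num_frames=4, the dense layer yields [2, 8, 8, 12], the reshape
# [2, 8, 8, 3, 4], and the transpose [2, 4, 8, 8, 3], i.e. frames next to
# batch with RGB last. The final expand_dims returns [2, 4, 8, 8, 3, 1],
# giving the output a trailing size-1 "vocab" axis as the perplexity comment
# suggests.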
def top(self, body_output, _):
  # TODO(lukaszkaiser): is this a universal enough way to get channels?
  num_channels = self._model_hparams.problem.num_channels
  with tf.variable_scope("rgb_softmax"):
    body_output_shape = common_layers.shape_list(body_output)
    reshape_shape = body_output_shape[:3]
    reshape_shape.extend([num_channels, self.top_dimensionality])
    res = tf.layers.dense(body_output, self.top_dimensionality * num_channels)
    res = tf.reshape(res, reshape_shape)
    if not tf.get_variable_scope().reuse:
      res_argmax = tf.argmax(res, axis=-1)
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return res
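# Shape walkthrough for the rgb_softmax top above (concrete sizes are
# illustrative assumptions): with body_output [batch=2, h=8, w=8, hidden],
# num_channels=3 and top_dimensionality=256, the dense layer yields
# [2, 8, 8, 768] and the reshape [2, 8, 8, 3, 256] -- per-channel logits in
# exactly the 5-D layout image_summary above expects.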
def top(self, body_output, _):
  num_channels = self._model_hparams.problem.num_channels
  num_frames = self._model_hparams.video_num_target_frames
  with tf.variable_scope("rgb_softmax"):
    body_output_shape = common_layers.shape_list(body_output)
    reshape_shape = body_output_shape[:3]
    reshape_shape.extend([num_channels, num_frames, self.top_dimensionality])
    res = tf.layers.dense(
        body_output, self.top_dimensionality * num_channels * num_frames)
    res = tf.reshape(res, reshape_shape)
    res = tf.transpose(res, [0, 4, 1, 2, 3, 5])
    if not tf.get_variable_scope().reuse:
      res_argmax = tf.argmax(res[:, -1, :, :, :, :], axis=-1)
      tf.summary.image(
          "result",
          common_layers.tpu_safe_image_summary(res_argmax),
          max_outputs=1)
    return res
def bottom_compress(self, inputs, name="bottom"):
  """Compresses channel-wise input pixels into whole pixel representations.

  Converts RGB pixel values to real numbers in the range -1 to 1, then
  combines the pixel channels to form a representation of shape
  [img_len, img_len].

  Args:
    inputs: Tensor representing RGB pixel intensities as integers, of shape
      [batch, img_len, img_len, channels].
    name: string, scope.

  Returns:
    body_input: Tensor of shape [batch, img_len, img_len, body_input_depth].
  """
  with tf.variable_scope(name):
    inputs = tf.to_float(inputs)
    hp = self._model_hparams
    if hp.mode != tf.estimator.ModeKeys.PREDICT:
      tf.summary.image(
          "inputs",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=2)
    inputs = common_layers.convert_rgb_to_symmetric_real(inputs)
    # Reshape inputs to apply convolutions across [img_len, img_len*channels].
    inputs_shape = common_layers.shape_list(inputs)
    inputs = tf.reshape(
        inputs, [-1, inputs_shape[1], inputs_shape[2] * inputs_shape[3], 1])
    # Compress RGB intensities for each pixel using a convolution.
    outputs = tf.layers.conv2d(
        inputs,
        self._body_input_depth,
        kernel_size=(1, self.num_channels),
        padding="VALID",
        strides=(1, self.num_channels),
        activation=tf.nn.relu,
        name="conv_input")
    return outputs
def bottom_compress(self, inputs, name="bottom"):
  """Compresses channel-wise input pixels into whole pixel representations.

  Converts RGB pixel values to real numbers in the range -1 to 1, then
  combines the pixel channels to form a representation of shape
  [img_len, img_len].

  Args:
    inputs: Tensor representing RGB pixel intensities as integers, of shape
      [batch, img_len, img_len, channels].
    name: string, scope.

  Returns:
    body_input: Tensor of shape
      [batch, img_len, img_len, self._model_hparams.hidden_size].
  """
  with tf.variable_scope(name):
    inputs = tf.to_float(inputs)
    hp = self._model_hparams
    if hp.mode != tf.estimator.ModeKeys.PREDICT:
      tf.summary.image(
          "inputs",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=2)
    inputs = common_layers.convert_rgb_to_symmetric_real(inputs)
    # Reshape inputs to apply convolutions across [img_len, img_len*channels].
    inputs_shape = common_layers.shape_list(inputs)
    inputs = tf.reshape(
        inputs, [-1, inputs_shape[1], inputs_shape[2] * inputs_shape[3], 1])
    # Compress RGB intensities for each pixel using a convolution.
    outputs = tf.layers.conv2d(
        inputs,
        self._model_hparams.hidden_size,
        kernel_size=(1, self.num_channels),
        padding="VALID",
        strides=(1, self.num_channels),
        activation=tf.nn.relu,
        name="conv_input")
    return outputs
def bottom_compress(self, inputs, name="bottom"):
  """Transform input from data space to model space.

  Converts RGB pixel values to real numbers in the range -1 to 1 and
  combines the channel values of each pixel to form a representation of
  size image_length x image_length.

  Args:
    inputs: A Tensor representing RGB pixel intensities as integers, of
      shape [batch, ...].
    name: string, scope.

  Returns:
    body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
  """
  with tf.variable_scope(name):
    inputs = tf.to_float(inputs)
    hp = self._model_hparams
    if hp.mode != tf.estimator.ModeKeys.PREDICT:
      tf.summary.image(
          "inputs",
          common_layers.tpu_safe_image_summary(inputs),
          max_outputs=2)
    inputs = common_layers.convert_rgb_to_symmetric_real(inputs)
    ishape = common_layers.shape_list(inputs)
    inputs = tf.reshape(inputs, [-1, ishape[1], ishape[2] * ishape[3], 1])
    inputs.set_shape([None, None, None, 1])
    # We compress RGB intensities for each pixel using a conv.
    x = tf.layers.conv2d(
        inputs,
        self._body_input_depth, (1, self.num_channels),
        padding="VALID",
        strides=(1, self.num_channels),
        activation=tf.nn.relu,
        name="conv_input")
    x.set_shape([None, None, None, self._body_input_depth])
    return x
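# In all bottom_compress variants above, the (1, num_channels) kernel with
# stride (1, num_channels) slides over the [img_len, img_len * channels]
# layout one pixel at a time: each window covers exactly the channel values
# of a single pixel, so every RGB triple is compressed into one depth vector.
# A standalone sketch (toy sizes are assumptions for illustration):
def channel_compress_sketch():
  batch, img_len, channels, depth = 2, 4, 3, 8
  x = tf.random_uniform([batch, img_len, img_len * channels, 1])
  y = tf.layers.conv2d(
      x,
      depth,
      kernel_size=(1, channels),
      strides=(1, channels),
      padding="VALID",
      activation=tf.nn.relu,
      name="conv_input_demo")
  return y  # Shape [2, 4, 4, 8]: one depth-8 vector per original pixel.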
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    x = features["targets"]
    shape = common_layers.shape_list(x)
    is1d = shape[2] == 1
    self.is1d = is1d
    # Run encoder.
    x = self.encoder(x)
    # Bottleneck (mix during early training, not too important but stable).
    b, b_loss = self.bottleneck(x)
    self._cur_bottleneck_tensor = b
    b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
    b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(), common_layers.shape_list(x)[-1], reuse=True)
      b = tf.concat([g, b], axis=0)
    # With probability bottleneck_max_prob use the bottleneck, otherwise x.
    # Note: a probability is never < -1.0, so the random branch below is
    # effectively disabled and x is always set to b in this version.
    if hparams.bottleneck_max_prob < -1.0:
      x = tf.where(
          tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
    else:
      x = b
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x, {"bottleneck_loss": 0.0}
  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]

  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    # Split back if we added a purely sampled batch.
    res_gan, res = tf.split(res, 2, axis=0)
    num_channels = self.hparams.problem.num_channels
    res_rgb = common_layers.convert_real_to_rgb(
        tf.nn.sigmoid(tf.layers.dense(res_gan, num_channels, name="gan_rgb")))
    tf.summary.image(
        "gan", common_layers.tpu_safe_image_summary(res_rgb), max_outputs=1)
    orig_rgb = tf.to_float(features["targets_raw"])

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    gan_loss = common_layers.sliced_gan_loss(orig_rgb,
                                             reverse_gradient(res_rgb),
                                             discriminate,
                                             self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor

  # Mix the final result and return.
  res = common_layers.mix(res, features["targets"],
                          hparams.bottleneck_warmup_steps // 2, is_training)
  return res, {"bottleneck_loss": b_loss, "gan_loss": -gan_loss}
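# reverse_gradient above acts as the identity in the forward pass but flips
# the gradient sign, so minimizing the single sliced_gan_loss trains the
# discriminator normally while pushing the generator the other way. A common
# stop_gradient formulation of that trick (a sketch; the helper actually used
# by this code may differ):
def reverse_gradient_sketch(x, lr=1.0):
  # Forward: -lr * x + (1 + lr) * x == x.
  # Backward: only the -lr * x term carries a gradient, so dL/dx is scaled
  # by -lr.
  return -lr * x + tf.stop_gradient((1.0 + lr) * x)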
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    x = features["targets"]
    labels = features["targets_raw"]
    shape = common_layers.shape_list(x)
    is1d = shape[2] == 1
    self.is1d = is1d
    # Run encoder.
    x = self.encoder(x)
    # Bottleneck (mix during early training, not too important but stable).
    b, b_loss = self.bottleneck(x)
    self._cur_bottleneck_tensor = b
    b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
    b = common_layers.mix(b, x, hparams.bottleneck_warmup_steps, is_training)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(), common_layers.shape_list(x)[-1], reuse=True)
      b = tf.concat([g, b], axis=0)
    # With probability bottleneck_max_prob use the bottleneck, otherwise x.
    # Note: a probability is never < -1.0, so the random branch below is
    # effectively disabled and x is always set to b in this version.
    if hparams.bottleneck_max_prob < -1.0:
      x = tf.where(
          tf.less(tf.random_uniform([]), hparams.bottleneck_max_prob), b, x)
    else:
      x = b
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x, {"bottleneck_loss": 0.0}
  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]

  is_image = isinstance(self.hparams.problem, image_utils.ImageProblem)
  if is_image:
    vocab_size = self.hparams.problem.vocab_size
    res = tf.layers.dense(
        res, self.hparams.problem.num_channels * self.hparams.hidden_size)
    output_shape = common_layers.shape_list(res)[:-1] + [
        self.hparams.problem.num_channels, self.hparams.hidden_size
    ]
    res = tf.reshape(res, output_shape)
  elif isinstance(self.hparams.problem, text_problems.Text2TextProblem):
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    res = tf.layers.dense(res, self.hparams.hidden_size)
  else:
    raise Exception("Unsupported problem type: %s" % self.hparams.problem)

  one_hot_labels = tf.one_hot(labels, vocab_size)
  code_loss_gan = 0.0
  if hparams.gan_loss_factor != 0.0:
    res_gan, res = tf.split(res, 2, axis=0)
    with tf.variable_scope("vq"):
      # The GAN half of the batch feeds this call; reconstr_gan is the
      # sampled reconstruction summarized as "gan" below.
      reconstr_gan, _, code_loss_gan, _ = discretization.vq_loss(
          res_gan, one_hot_labels, vocab_size)
  with tf.variable_scope("vq", reuse=tf.AUTO_REUSE):
    reconstr, target_codes, code_loss, targets_loss = discretization.vq_loss(
        res, one_hot_labels, vocab_size)

  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    if is_image:
      tf.summary.image(
          "gan",
          common_layers.tpu_safe_image_summary(tf.argmax(reconstr_gan, -1)),
          max_outputs=1)

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    gan_loss = common_layers.sliced_gan_loss(target_codes,
                                             reverse_gradient(res_gan),
                                             discriminate,
                                             self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor

  if is_image:
    tf.summary.image(
        "ae",
        common_layers.tpu_safe_image_summary(tf.argmax(reconstr, -1)),
        max_outputs=1)

  return reconstr, {
      "training": targets_loss,
      "code_loss": code_loss,
      "code_loss_gan": code_loss_gan,
      "b_loss": b_loss,
      "gan_loss": -gan_loss
  }
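# sliced_gan_loss above compares two feature distributions through random 1-D
# projections, in the spirit of sliced-Wasserstein losses. A minimal sketch
# of that idea (an assumption about the underlying technique, not the exact
# common_layers.sliced_gan_loss implementation, which also feeds both inputs
# through the discriminator):
def sliced_distance_sketch(x, y, num_vecs=4096):
  dim = common_layers.shape_list(x)[-1]
  x = tf.reshape(x, [-1, dim])
  y = tf.reshape(y, [-1, dim])
  proj = tf.random_normal([dim, num_vecs])
  proj /= tf.norm(proj, axis=0, keepdims=True)  # Unit-length directions.

  def sorted_projections(v):
    p = tf.transpose(tf.matmul(v, proj))  # [num_vecs, n]
    return tf.nn.top_k(p, k=tf.shape(p)[-1]).values  # Sorted per direction.

  # Mean squared gap between sorted projections; assumes x and y contribute
  # the same number of rows.
  return tf.reduce_mean(
      tf.square(sorted_projections(x) - sorted_projections(y)))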