def loss_iw(self, logits, features):
  if isinstance(logits, dict):
    losses = {}
    for k, v in six.iteritems(logits):
      losses[k] = self._loss_single_iw(
          v, k, features[k], weights=features.get(k + "_mask"))
      n, d = losses[k]
      if common_layers.should_generate_summaries():
        tf.summary.scalar(k + "_loss", n / d)
        tf.summary.scalar(k + "_loss_num", n)
        tf.summary.scalar(k + "_loss_den", d)
        if getattr(self.hparams, "visualize_logits_histogram", False):
          hist = tf.summary.histogram
          hist(k + "_predict", tf.argmax(tf.squeeze(v), axis=-1))
          hist(k + "_targets", features[k])
    return tf.add_n([n / d for n, d in losses.values()])
  else:
    return self._loss_single_iw(
        logits, "targets", features["targets"],
        weights=features.get("targets_mask"))
def optimize(loss, learning_rate, hparams, use_tpu=False):
  """Minimize loss."""
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  loss = tf.identity(loss, name="total_loss")
  log_variable_sizes(verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables()
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries = ["loss"]
  if hparams.summarize_grads and common_layers.should_generate_summaries():
    tf.logging.info("Summarizing gradients")
    opt_summaries.extend(
        ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)
  return train_op
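# Hedged usage sketch (not part of the original file): how optimize() is
# typically called from an Estimator model_fn. `learning_rate_schedule` stands
# in for whatever schedule the surrounding trainer builds, and `hparams` is
# assumed to carry the fields read above (optimizer, summarize_vars,
# clip_grad_norm, grad_noise_scale, weight_decay, weight_noise, ...).
#
#   loss = tf.reduce_mean(per_example_loss)   # scalar training loss
#   lr = learning_rate_schedule(hparams)      # scalar learning-rate tensor
#   train_op = optimize(loss, lr, hparams, use_tpu=False)
#   spec = tf.estimator.EstimatorSpec(
#       tf.estimator.ModeKeys.TRAIN, loss=loss, train_op=train_op)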
def decoder(self, x, encoder_layers=None):
  with tf.variable_scope("decoder"):
    hparams = self.hparams
    is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Up-convolutions.
    for i in range(hparams.num_hidden_layers):
      j = hparams.num_hidden_layers - i - 1
      if is_training:
        nomix_p = common_layers.inverse_lin_decay(
            int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_%d" % j, nomix_p)
      filters = hparams.hidden_size * 2**j
      filters = min(filters, hparams.max_hidden_size)
      with tf.variable_scope("layer_%d" % i):
        j = hparams.num_hidden_layers - i - 1
        x = tf.layers.conv2d_transpose(
            x,
            filters,
            kernel,
            strides=strides,
            padding="SAME",
            activation=common_layers.belu,
            name="strided")
        y = x
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y,
              residual_filters,
              residual_kernel,
              padding="SAME",
              activation=common_layers.belu,
              name="residual_%d" % r)
        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        x = common_layers.layer_norm(x, name="ln")
        x = common_attention.add_timing_signal_nd(x)
        if encoder_layers is not None:
          enc_x = encoder_layers[j]
          enc_shape = common_layers.shape_list(enc_x)
          x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
          if is_training:  # Mix at the beginning of training.
            rand = tf.random_uniform(common_layers.shape_list(x_mix))
            x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
          x = x_mix
    return x
def optimize(loss, learning_rate, hparams, use_tpu=False):
  """Minimize loss."""
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  loss = tf.identity(loss, name="total_loss")
  # Print trainable variables.
  log_variable_sizes(verbose=hparams.summarize_vars)
  # Print non-trainable variables.
  non_trainable_variables = list(
      set(tf.global_variables()) - set(tf.trainable_variables()))
  log_variable_sizes(
      non_trainable_variables,
      tag="Non-trainable variables",
      verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables()
    # Summarize non-trainable variables as well.
    summarize_variables(non_trainable_variables, tag="Non-trainable variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(
      diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries.append("loss")
  if hparams.summarize_grads:
    tf.logging.info("Summarizing gradients")
    opt_summaries.extend(
        ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True)
  return train_op
def weight_decay_and_noise(loss, hparams, learning_rate, var_list=None):
  """Apply weight decay and weight noise."""
  if var_list is None:
    var_list = tf.trainable_variables()

  decay_vars = [v for v in var_list]
  noise_vars = [v for v in var_list if "/body/" in v.name]

  weight_decay_loss = weight_decay(hparams.weight_decay, decay_vars)
  if hparams.weight_decay and common_layers.should_generate_summaries():
    tf.summary.scalar("losses/weight_decay", weight_decay_loss)
  weight_noise_ops = weight_noise(hparams.weight_noise, learning_rate,
                                  noise_vars)

  with tf.control_dependencies(weight_noise_ops):
    loss = tf.identity(loss)

  loss += weight_decay_loss
  return loss
def weight_noise(noise_rate, learning_rate, var_list):
  """Apply weight noise to vars in var_list."""
  if not noise_rate:
    return [tf.no_op()]

  tf.logging.info("Applying weight noise scaled by learning rate, "
                  "noise_rate: %0.5f", noise_rate)

  noise_ops = []

  for v in var_list:
    with tf.device(v.device):  # pylint: disable=protected-access
      scale = noise_rate * learning_rate * 0.001
      if common_layers.should_generate_summaries():
        tf.summary.scalar("weight_noise_scale", scale)
      noise = tf.truncated_normal(v.shape) * scale
      noise_op = v.assign_add(noise)
      noise_ops.append(noise_op)

  return noise_ops
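# Hedged usage sketch (not part of the original file): weight_noise() only
# builds assign_add ops, so the noise takes effect once those ops run, e.g. as
# control dependencies on the loss, exactly as weight_decay_and_noise() does
# above. The names below are illustrative.
#
#   noise_ops = weight_noise(noise_rate=0.01, learning_rate=lr,
#                            var_list=tf.trainable_variables())
#   with tf.control_dependencies(noise_ops):
#     loss = tf.identity(loss)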
def vq_gating(x, num_experts, k, bneck, hparams=None, name="vq_gating"): """VQ gating. Args: x: input Tensor with shape [batch_size, input_size] num_experts: an integer k: an integer - number of experts per example bneck: a bottleneck object hparams: optional hparams name: an optional string Returns: gates: a Tensor with shape [batch_size, num_experts] load: a Tensor with shape [num_experts] """ with tf.variable_scope(name, reuse=tf.AUTO_REUSE): if hparams.use_scales: scales = tf.get_variable("scales", [num_experts], tf.float32, initializer=tf.ones_initializer()) scales = tf.nn.softmax(scales) hparams.scales = scales input_size = x.get_shape().as_list()[-1] batch_size = common_layers.shape_list(x)[0] if k > 1: # first project into two dense layers, chop and discretize, and gate # TODO(avaswani): Maybe scale the embeddings flowing out of the experts. # We might want to do this to match the computation being done by topk x = tf.layers.dense(x, input_size * k) # x goes from [batch_size, input_size*k] to [batch_size*k, input_size] x = tf.reshape(x, [batch_size * k, input_size]) inputs = tf.expand_dims(x, axis=1) inputs = tf.expand_dims(inputs, axis=1) # VQ hparams hparams.z_size = int(math.log(num_experts, 2)) hparams.hidden_size = input_size hparams.top_k = k d = bneck.discrete_bottleneck(inputs) centroids = None exp_discrete = d["discrete"] embed_lookup = d["embed"] extra_loss = d["loss"] if hparams.residual_centroids: centroids = embed_lookup(exp_discrete) # gives the centroids top_k_indices = tf.squeeze(exp_discrete, axis=1) tf.summary.histogram("discrete_counts", top_k_indices) # if k > 1, then we need to reshape top_k_indices from [batch_size*k, 1] # to [batch_size, k] if k > 1: top_k_indices = tf.reshape(top_k_indices, [batch_size, k]) # get the top k gates top_k_gates = tf.ones([batch_size, k]) # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the # positions corresponding to all but the top k experts per example. gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices, num_experts) # Compute count per expert from the gates. # gates has shape [batch_size, num_experts] # count per expert has shape [num_experts, 1] count_per_expert = tf.reduce_sum(gates, axis=0) if hparams.use_scales: scale_loss = tf.reduce_mean(tf.to_float(count_per_expert) * scales) extra_loss += scale_loss if common_layers.should_generate_summaries(): tf.summary.histogram("vq_loss", extra_loss) tf.summary.historgram("scale_loss", scale_loss) return gates, extra_loss, centroids
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
  """Minimize loss."""
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  if hparams.get("shs_regularization", default=None) is not None:
    if hparams.shs_regularization:
      loss = weight_group_hoyer_square(loss, hparams)
  if hparams.get("ssl_regularization", default=None) is not None:
    if hparams.ssl_regularization:
      loss = weight_group_lasso(loss, hparams)
  loss = tf.identity(loss, name="total_loss")
  if variables is None:
    variables = tf.trainable_variables()
  # Print trainable variables.
  log_variable_sizes(variables, verbose=hparams.summarize_vars)
  # Print non-trainable variables.
  non_trainable_variables = list(set(tf.global_variables()) - set(variables))
  log_variable_sizes(
      non_trainable_variables,
      tag="Non-trainable variables",
      verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables(variables)
    # Summarize non-trainable variables as well.
    summarize_variables(non_trainable_variables, tag="Non-trainable variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(
      diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = tf.contrib.tpu.CrossShardOptimizer(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries.append("loss")
  if hparams.summarize_grads:
    tf.logging.info("Summarizing gradients")
    opt_summaries.extend(
        ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True,
      variables=variables)
  return train_op
def optimize(loss, learning_rate, hparams, use_tpu=False, variables=None):
  """Minimize loss."""
  loss = weight_decay_and_noise(loss, hparams, learning_rate)
  loss = tf.identity(loss, name="total_loss")
  if variables is None:
    variables = tf.trainable_variables()
  # Print trainable variables.
  log_variable_sizes(variables, verbose=hparams.summarize_vars)
  # Print non-trainable variables.
  non_trainable_variables = list(set(tf.global_variables()) - set(variables))
  log_variable_sizes(
      non_trainable_variables,
      tag="Non-trainable variables",
      verbose=hparams.summarize_vars)
  if hparams.summarize_vars:
    summarize_variables(variables)
    # Summarize non-trainable variables as well.
    summarize_variables(non_trainable_variables, tag="Non-trainable variables")
  diet_vars = [
      v for v in tf.global_variables() if v.dtype == dtypes.float16_ref
  ]
  log_variable_sizes(
      diet_vars, "Diet Variables", verbose=hparams.summarize_vars)
  opt = ConditionalOptimizer(hparams.optimizer, learning_rate, hparams, use_tpu)
  if use_tpu:
    opt = tf.contrib.tpu.CrossShardOptimizer(opt)
  if getattr(hparams, "gpu_automatic_mixed_precision", False):
    if use_tpu:
      raise RuntimeError("GPU auto mixed precision cannot be used with TPU")
    elif _mixed_precision_is_enabled(hparams):
      raise RuntimeError(
          "GPU auto mixed precision cannot be used with manual mixed precision")
    else:
      setattr(opt, "_use_locking", "True")
      setattr(opt, "_name", "ConditionalOptimizer")
      opt = tf.train.experimental.enable_mixed_precision_graph_rewrite(opt)

  opt_summaries = []
  if common_layers.should_generate_summaries():
    tf.summary.scalar("learning_rate", learning_rate)
    opt_summaries.append("loss")
  if hparams.summarize_grads:
    tf.logging.info("Summarizing gradients")
    opt_summaries.extend(
        ["gradients", "gradient_norm", "global_gradient_norm"])

  if hparams.clip_grad_norm:
    tf.logging.info("Clipping gradients, norm: %0.5f", hparams.clip_grad_norm)
  if hparams.grad_noise_scale:
    tf.logging.info("Adding noise to gradients, noise scale: %0.5f",
                    hparams.grad_noise_scale)

  train_op = tf.contrib.layers.optimize_loss(
      name="training",
      loss=loss,
      global_step=tf.train.get_or_create_global_step(),
      learning_rate=learning_rate,
      clip_gradients=hparams.clip_grad_norm or None,
      gradient_noise_scale=hparams.grad_noise_scale or None,
      optimizer=opt,
      summaries=opt_summaries,
      colocate_gradients_with_ops=True,
      variables=variables)
  return train_op
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  vocab_size = self._problem_hparams.vocab_size["targets"]
  if hasattr(self._hparams, "vocab_divisor"):
    vocab_size += (-vocab_size) % self._hparams.vocab_divisor
  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)
    # Handle videos.
    if len(labels.shape) == 5:
      labels = time_to_channels(labels)
    shape = common_layers.shape_list(labels)
    x = tf.one_hot(labels, vocab_size)
    x = self.embed(x)
    target_codes = x
    if shape[2] == 1:
      self.is1d = True
    # Run encoder.
    x, encoder_layers = self.encoder(x)
    # Bottleneck.
    b, b_loss = self.bottleneck(x)
    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(
          tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(shape=b_shape),
          common_layers.shape_list(x)[-1],
          reuse=True)
      x = tf.concat([x, g], axis=0)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  # Final dense layer.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  output_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, output_shape)

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hparams.use_vq_loss:
      (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    return reconstr, {"bottleneck_loss": 0.0}

  if hparams.gan_loss_factor != 0.0:
    res, res_gan = tf.split(res, 2, axis=0)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }
  if hparams.use_vq_loss:
    vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.2,
        min_value=hparams.vq_temperature * 2)
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      vq_temperature = None
    with tf.variable_scope("vq_loss"):
      (reconstr, _, target_codes, code_loss,
       targets_loss) = discretization.vq_loss(
           res, labels, vocab_size, temperature=vq_temperature)
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss
  else:
    reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    targets_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
        labels=tf.reshape(labels, labels_shape))
    losses["training"] = targets_loss

  # GAN losses.
  if hparams.gan_loss_factor != 0.0:
    update_means_factor = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps, min_value=0.0001)
    if hparams.use_vq_loss:
      with tf.variable_scope("vq_loss", reuse=True):
        update_means = tf.less(tf.random_uniform([]), update_means_factor)
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, labels, vocab_size,
            do_update=update_means,
            temperature=vq_temperature)
        reconstr_gan_nonoise = reconstr_gan
        code_loss_gan *= hparams.code_loss_factor * update_means_factor
        losses["code_loss_gan"] = code_loss_gan
    else:
      reconstr_gan = tf.layers.dense(
          res_gan, vocab_size, name="autoencoder_final", reuse=True)
      reconstr_gan_nonoise = reconstr_gan
      reconstr_gan = self.gumbel_sample(reconstr_gan)
      # Embed to codes.
      gan_codes = self.embed(reconstr_gan)

  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    self.image_summary("gan", reconstr_gan_nonoise)

    def discriminate(x):
      """Run a discriminator depending on the hparams."""
      if hparams.discriminator == "default":
        return common_layers.deep_discriminator(
            x, hparams.discriminator_batchnorm, is_training)
      elif hparams.discriminator == "patched":
        return common_layers.patch_discriminator(x)
      elif hparams.discriminator == "single":
        return common_layers.single_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
      elif hparams.discriminator == "double":
        return common_layers.double_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
      else:
        raise Exception("Unknown discriminator %s" % hparams.discriminator)

    tc_shape = common_layers.shape_list(target_codes)
    if len(tc_shape) > 4:
      target_codes = tf.reshape(
          target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_codes = tf.reshape(
          gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
    gan_lr = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.5)
    rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
    gan_loss = common_layers.sliced_gan_loss(
        target_codes,
        rev_grad_gan_codes,
        discriminate,
        self.hparams.num_sliced_vecs,
        do_tanh=hparams.sliced_do_tanh)
    gan_loss *= hparams.gan_loss_factor * update_means_factor
    losses["gan_loss"] = -gan_loss

  self.image_summary("ae", reconstr)

  logits = tf.reshape(reconstr, labels_shape + [vocab_size])
  return logits, losses
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  vocab_size = self._problem_hparams.modality["targets"].top_dimensionality
  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)
    # Handle videos.
    if len(labels.shape) == 5:
      labels = time_to_channels(labels)
    shape = common_layers.shape_list(labels)
    x = tf.one_hot(labels, vocab_size)
    x = self.embed(x)
    target_codes = x
    if shape[2] == 1:
      self.is1d = True
    # Run encoder.
    x, encoder_layers = self.encoder(x)
    # Bottleneck.
    b, b_loss = self.bottleneck(x)
    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(tf.square(x_stop - b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(shape=b_shape),
          common_layers.shape_list(x)[-1],
          reuse=True)
      x = tf.concat([x, g], axis=0)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  # Final dense layer.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  output_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, output_shape)

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    if hparams.use_vq_loss:
      (reconstr, _, _, _, _) = discretization.vq_loss(res, labels, vocab_size)
    else:
      reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    return reconstr, {"bottleneck_loss": 0.0}

  if hparams.gan_loss_factor != 0.0:
    res, res_gan = tf.split(res, 2, axis=0)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }
  if hparams.use_vq_loss:
    vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.2,
        min_value=hparams.vq_temperature * 2)
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      vq_temperature = None
    with tf.variable_scope("vq_loss"):
      (reconstr, _, target_codes, code_loss,
       targets_loss) = discretization.vq_loss(
           res, labels, vocab_size, temperature=vq_temperature)
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss
  else:
    reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    targets_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=tf.reshape(reconstr, labels_shape + [vocab_size]),
        labels=tf.reshape(labels, labels_shape))
    losses["training"] = targets_loss

  # GAN losses.
  if hparams.gan_loss_factor != 0.0:
    update_means_factor = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps, min_value=0.0001)
    if hparams.use_vq_loss:
      with tf.variable_scope("vq_loss", reuse=True):
        update_means = tf.less(tf.random_uniform([]), update_means_factor)
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, labels, vocab_size,
            do_update=update_means,
            temperature=vq_temperature)
        reconstr_gan_nonoise = reconstr_gan
        code_loss_gan *= hparams.code_loss_factor * update_means_factor
        losses["code_loss_gan"] = code_loss_gan
    else:
      reconstr_gan = tf.layers.dense(
          res_gan, vocab_size, name="autoencoder_final", reuse=True)
      reconstr_gan_nonoise = reconstr_gan
      reconstr_gan = self.gumbel_sample(reconstr_gan)
      # Embed to codes.
      gan_codes = self.embed(reconstr_gan)

  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    self.image_summary("gan", reconstr_gan_nonoise)

    def discriminate(x):
      """Run a discriminator depending on the hparams."""
      if hparams.discriminator == "default":
        return common_layers.deep_discriminator(
            x, hparams.discriminator_batchnorm, is_training)
      elif hparams.discriminator == "patched":
        return common_layers.patch_discriminator(x)
      elif hparams.discriminator == "single":
        return common_layers.single_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
      elif hparams.discriminator == "double":
        return common_layers.double_discriminator(
            x,
            hparams.discriminator_size,
            hparams.discriminator_kernel_size,
            hparams.discriminator_strides,
            pure_mean=hparams.discriminator_pure_mean)
      else:
        raise Exception("Unknown discriminator %s" % hparams.discriminator)

    tc_shape = common_layers.shape_list(target_codes)
    if len(tc_shape) > 4:
      target_codes = tf.reshape(
          target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_codes = tf.reshape(
          gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
    gan_lr = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.5)
    rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
    gan_loss = common_layers.sliced_gan_loss(
        target_codes,
        rev_grad_gan_codes,
        discriminate,
        self.hparams.num_sliced_vecs,
        do_tanh=hparams.sliced_do_tanh)
    gan_loss *= hparams.gan_loss_factor * update_means_factor
    losses["gan_loss"] = -gan_loss

  self.image_summary("ae", reconstr)

  logits = tf.reshape(reconstr, labels_shape + [vocab_size])
  return logits, losses
def decoder(self, x, encoder_layers=None):
  with tf.variable_scope("decoder"):
    hparams = self.hparams
    is_training = self.hparams.mode == tf.estimator.ModeKeys.TRAIN
    kernel, strides = self._get_kernel_and_strides()
    residual_kernel = (hparams.residual_kernel_height,
                       hparams.residual_kernel_width)
    residual_kernel1d = (hparams.residual_kernel_height, 1)
    residual_kernel = residual_kernel1d if self.is1d else residual_kernel
    residual_conv = tf.layers.conv2d
    if hparams.residual_use_separable_conv:
      residual_conv = tf.layers.separable_conv2d
    # Up-convolutions.
    for i in range(hparams.num_hidden_layers):
      j = hparams.num_hidden_layers - i - 1
      if is_training:
        nomix_p = common_layers.inverse_lin_decay(
            int(hparams.bottleneck_warmup_steps * 0.25 * 2**j)) + 0.01
        if common_layers.should_generate_summaries():
          tf.summary.scalar("nomix_p_%d" % j, nomix_p)
      filters = hparams.hidden_size * 2**j
      filters = min(filters, hparams.max_hidden_size)
      with tf.variable_scope("layer_%d" % i):
        j = hparams.num_hidden_layers - i - 1
        x = tf.layers.conv2d_transpose(
            x,
            filters,
            kernel,
            strides=strides,
            padding="SAME",
            activation=common_layers.belu,
            name="strided")
        y = x
        for r in range(hparams.num_residual_layers):
          residual_filters = filters
          if r < hparams.num_residual_layers - 1:
            residual_filters = int(
                filters * hparams.residual_filter_multiplier)
          y = residual_conv(
              y,
              residual_filters,
              residual_kernel,
              padding="SAME",
              activation=common_layers.belu,
              name="residual_%d" % r)
        x += tf.nn.dropout(y, 1.0 - hparams.residual_dropout)
        x = common_layers.layer_norm(x, name="ln")
        x = common_attention.add_timing_signal_nd(x)
        if encoder_layers is not None:
          enc_x = encoder_layers[j]
          enc_shape = common_layers.shape_list(enc_x)
          x_mix = x[:enc_shape[0], :enc_shape[1], :enc_shape[2], :]
          if is_training:  # Mix at the beginning of training.
            rand = tf.random_uniform(common_layers.shape_list(x_mix))
            x_mix = tf.where(tf.less(rand, nomix_p), x_mix, enc_x)
          if hparams.gan_loss_factor != 0:
            x_gan = x[enc_shape[0]:, :enc_shape[1], :enc_shape[2], :]
            x = tf.concat([x_mix, x_gan], axis=0)
          else:
            x = x_mix
    return x
def autoencoder_body(self, features):
  """Customized body function for autoencoders acting on continuous images.

  This is based on tensor2tensor.models.research.AutoencoderBasic.body and
  should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a
  discrete vocabulary and defines the loss on that vocab. It's cool and all,
  but here we prefer expressing the reconstruction loss as an actual
  continuous likelihood function.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  output_activation = (
      tf.nn.softplus if hparams.output_activation == 'softplus' else None)
  input_shape = [None, ] + common_layers.shape_list(features["inputs"])[1:]

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # In predict mode, we also define TensorFlow Hub modules for all pieces of
    # the autoencoder.
    if hparams.encode_psf and 'psf' in features:
      psf_shape = [None, ] + common_layers.shape_list(features["psf"])[1:]

    # First build encoder spec.
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))
      x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs=input_layer, outputs=b)

    def make_model_spec_psf():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      psf_layer = tf.placeholder(tf.float32, shape=psf_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))

      # If we have access to the PSF, we add this information to the encoder.
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(
            tf.signal.irfft2d(tf.cast(psf_layer[..., 0], tf.complex64)),
            axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2
        # subsampling.
        psf_image = tf.roll(
            psf_image, shift=[input_shape[1], input_shape[2]], axis=[1, 2])
        psf_image = tf.image.resize_with_crop_or_pad(
            psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(
            psf_image, hparams.hidden_size // 4, 5,
            padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs={'input': input_layer, 'psf': psf_layer},
                        outputs=b)

    spec = hub.create_module_spec(
        make_model_spec_psf if hparams.encode_psf else make_model_spec,
        drop_collections=['checkpoints'])
    encoder = hub.Module(spec, name="encoder_module")
    hub.register_module_for_export(encoder, "encoder")

    if hparams.encode_psf:
      code = encoder({'input': features["inputs"], 'psf': features['psf']})
    else:
      code = encoder(features["inputs"])
    b_shape = [None, ] + common_layers.shape_list(code)[1:]
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    # Second build decoder spec.
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=b_shape)
      x = self.unbottleneck(input_layer, res_size)
      x = self.decoder(x, None)
      reconstr = tf.layers.dense(
          x, input_shape[-1], name="autoencoder_final",
          activation=output_activation)
      hub.add_signature(inputs=input_layer, outputs=reconstr)
      hub.attach_message(
          "stamp_size",
          tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
      try:
        hub.attach_message(
            "pixel_size",
            tf.train.FloatList(value=[
                hparams.problem_hparams.pixel_scale[res]
                for res in hparams.problem_hparams.resolutions
            ]))
      except AttributeError:
        hub.attach_message(
            "pixel_size",
            tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))

    spec = hub.create_module_spec(
        make_model_spec, drop_collections=['checkpoints'])
    decoder = hub.Module(spec, name="decoder_module")
    hub.register_module_for_export(decoder, "decoder")
    reconstr = decoder(code)
    return reconstr, {"bottleneck_loss": 0.0}

  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)

    shape = common_layers.shape_list(labels)
    with tf.variable_scope('encoder_module'):
      x = self.embed(tf.expand_dims(labels, -1))

    if shape[2] == 1:
      self.is1d = True

    # Run encoder.
    with tf.variable_scope('encoder_module'):
      # If we have access to the PSF, we add this information to the encoder.
      # Note that we only support single band images so far...
      if hparams.encode_psf and 'psf' in features:
        psf_image = tf.expand_dims(
            tf.signal.irfft2d(tf.cast(features['psf'][..., 0], tf.complex64)),
            axis=-1)
        # Roll the image to undo the fftshift, assuming x1 zero padding and x2
        # subsampling.
        psf_image = tf.roll(
            psf_image, shift=[input_shape[1], input_shape[2]], axis=[1, 2])
        psf_image = tf.image.resize_with_crop_or_pad(
            psf_image, input_shape[1], input_shape[2])
        net_psf = tf.layers.conv2d(
            psf_image, hparams.hidden_size // 4, 5,
            padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)

    # Bottleneck.
    with tf.variable_scope('encoder_module'):
      b, b_loss = self.bottleneck(x)

    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    with tf.variable_scope('decoder_module'):
      b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(
          tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    with tf.variable_scope('decoder_module'):
      x = self.unbottleneck(b, res_size)

  # Run decoder.
  with tf.variable_scope('decoder_module'):
    x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  with tf.variable_scope('decoder_module'):
    reconstr = tf.layers.dense(
        res, shape[-1], name="autoencoder_final",
        activation=output_activation)

  # We apply an optional apodization of the output before taking the
  if hparams.output_apodization > 0:
    nx = reconstr.get_shape().as_list()[1]
    alpha = 2 * hparams.output_apodization / nx
    from scipy.signal.windows import tukey
    # Create a tukey window.
    w = tukey(nx, alpha)
    w = np.outer(w, w).reshape((1, nx, nx, 1)).astype('float32')
    # And penalize non zero things at the border.
    apo_loss = tf.reduce_mean(
        tf.reduce_sum(((1. - w) * reconstr)**2, axis=[1, 2, 3]))
  else:
    w = 1.0
    apo_loss = 0.

  # We apply the window.
  reconstr = reconstr * w

  # Optionally regularizes further the output.
  # Anisotropic TV:
  tv = tf.reduce_mean(tf.image.total_variation(reconstr))
  # Smoothed Isotropic TV:
  # im_dx, im_dy = tf.image.image_gradients(reconstr)
  # tv = tf.reduce_sum(tf.sqrt(im_dx**2 + im_dy**2 + 1e-6), axis=[1, 2, 3])
  # tv = tf.reduce_mean(tv)

  image_summary("without_psf", tf.reshape(reconstr, labels_shape))

  # Apply channel-wise convolution with the PSF if requested.
  if hparams.apply_psf and 'psf' in features:
    output_list = []
    for i in range(shape[3]):
      output_list.append(tf.squeeze(convolve(
          tf.expand_dims(reconstr[..., i], -1),
          tf.cast(features['psf'][..., i], tf.complex64),
          zero_padding_factor=1)))
    reconstr = tf.stack(output_list, axis=-1)
    reconstr = tf.reshape(reconstr, shape)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss,
      "total_variation": hparams.total_variation_loss * tv,
      "apodization_loss": hparams.apodization_loss * apo_loss,
  }

  loglik = loglikelihood_fn(labels, reconstr, features, hparams)
  targets_loss = tf.reduce_mean(-loglik)

  tf.summary.scalar("negloglik", targets_loss)
  tf.summary.scalar("bottleneck_loss", b_loss)

  # Compute final loss.
  losses["training"] = (targets_loss + b_loss +
                        hparams.bottleneck_l2_factor * xb_loss +
                        hparams.total_variation_loss * tv +
                        hparams.apodization_loss * apo_loss)
  logits = tf.reshape(reconstr, labels_shape)

  image_summary("ae", reconstr)
  image_summary("input", labels)

  return logits, losses
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
  """graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types

  Returns:
    A Tensor of shape [batch, length, depth(q)]
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if adjacency_matrix is not None:
      key_head_depth = common_layers.shape_list(q)[-1]
      adjacency_vectors = make_edge_vectors(
          adjacency_matrix, num_edge_types, key_head_depth, name=name)
      # transposing q to be [batch, length_q, heads, depth_k]
      # to allow for matmul with [batch, length_q, length_q, depth_k]
      q_t = tf.transpose(q, [0, 2, 1, 3])
      adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
      logits += tf.transpose(adj_logits, [0, 2, 1, 3])
      # [batch, depth, num_nodes, num_nodes]
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
    # dropping out the attention links for each of the heads
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
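# Hedged usage sketch (not part of the original file): the shapes expected by
# graph_attention() for a toy graph batch. Adjacency ids are integers in
# [0, num_edge_types); all dimensions below are illustrative.
#
#   batch, heads, length, depth = 2, 4, 16, 32
#   q = tf.random_normal([batch, heads, length, depth])
#   k = tf.random_normal([batch, heads, length, depth])
#   v = tf.random_normal([batch, heads, length, depth])
#   adj = tf.zeros([batch, length, length], dtype=tf.int32)
#   out = graph_attention(q, k, v, bias=None, adjacency_matrix=adj,
#                         num_edge_types=5)  # -> [batch, heads, length, depth]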
def dot_product_attention_mtsa(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               make_image_summary=True,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               use_k_mtsa=True,
                               afn_extra='none',
                               afn_dot='exp',
                               afn_multi='exp',
                               bias_start=0.,
                               bi_direction=False):
  """Dot-product attention (multi-dim tensorized self-attention variant).

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    use_k_mtsa: if True, compute the source2token (multi-dim) logits from k
      rather than v.
    afn_extra: activation name for the optional extra multi-dim layer
      ('none' disables the extra layer).
    afn_dot: activation name applied to the token2token logits.
    afn_multi: activation name applied to the multi-dim logits.
    bias_start: initial bias for the multi-dim dense layers.
    bi_direction: if True, split the heads into forward- and backward-masked
      halves.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # get dims
    dim_q = q.get_shape().as_list()[-1]
    dim_k = k.get_shape().as_list()[-1]
    dim_v = v.get_shape().as_list()[-1]
    # prepare
    multi_logits_scale_factor = 1. / math.sqrt(
        dim_v) if afn_multi.startswith('scaled') else 1.
    afn_extra, afn_dot, afn_multi = afn_name2fn(afn_extra), afn_name2fn(
        afn_dot), afn_name2fn(afn_multi)
    # if bias is not None:
    #   inp_mask_1d = tf.to_float(tf.equal(bias, 0.))  # bs,1,1,vl
    #   inp_mask_1d = tf.transpose(inp_mask_1d, [0, 1, 3, 2])  # bs,1,vl,1
    # else:
    #   inp_mask_1d = None

    # token2token self attention
    dot_logits = tf.matmul(q, k, transpose_b=True)  # bs,hd,ql,vl
    if bias is not None:
      bias = common_layers.cast_like(bias, dot_logits)  # 1/bs,1,ql/1,vl
      dot_logits += bias
    e_dot_logits = afn_dot(dot_logits)  # bs,hd,ql,vl
    if bi_direction:
      head_num = v.get_shape().as_list()[1]
      ql, vl = tf.shape(q)[-2], tf.shape(v)[-2]
      assert head_num is not None
      assert head_num % 2 == 0
      ones_mat = tf.ones([ql, vl], tf.float32)
      mul_mask_fw = tf.matrix_band_part(ones_mat, -1, 0)  # Lower triangular.
      mul_mask_bw = tf.matrix_band_part(ones_mat, 0, -1)  # Upper triangular.
      mul_mask_fw_tile = tf.tile(
          tf.expand_dims(mul_mask_fw, 0), [head_num // 2, 1, 1])
      mul_mask_bw_tile = tf.tile(
          tf.expand_dims(mul_mask_bw, 0), [head_num // 2, 1, 1])
      mul_mask = tf.expand_dims(
          tf.concat([mul_mask_fw_tile, mul_mask_bw_tile], axis=0), axis=0)
      e_dot_logits *= mul_mask

    # source2token self-attention
    multi_logits = multi_head_dense_layer(
        k if use_k_mtsa else v, dim_v, True,
        bias_start if afn_extra is None else 0., 'multi_logits1')
    if afn_extra is not None:  # use one extra layer for multi-dim
      multi_logits = multi_head_dense_layer(
          afn_extra(multi_logits), dim_v, True, bias_start, 'multi_logits2')
    e_multi_logits = afn_multi(
        multi_logits * multi_logits_scale_factor)  # bs,hd,vl,vd
    # if inp_mask_1d is not None:  # use mask for exp_logits
    #   e_multi_logits *= inp_mask_1d

    # mtsa
    accum_z_deno = tf.matmul(e_dot_logits, e_multi_logits)  # bs,hd,ql,vd
    accum_z_deno = tf.where(  # in case of NaN and Inf
        tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
        accum_z_deno, tf.ones_like(accum_z_deno))

    # attention dropout
    e_dot_logits = common_layers.dropout_with_broadcast_dims(
        e_dot_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    e_multi_logits = common_layers.dropout_with_broadcast_dims(
        e_multi_logits, math.sqrt(1. - dropout_rate),
        broadcast_dims=dropout_broadcast_dims)
    rep_mul_score = v * e_multi_logits  # bs,hd,vl,vd
    accum_rep_mul_score = tf.matmul(e_dot_logits, rep_mul_score)  # bs,hd,ql,vd
    # calculate the final attention results
    attn_res = accum_rep_mul_score / accum_z_deno
    # if inp_mask_1d is not None:  # use mask for output
    #   attn_res *= inp_mask_1d

    # ============ for vis =======
    weights = e_dot_logits / (tf.reduce_sum(
        e_dot_logits, axis=-1, keepdims=True, name="attention_weights") +
                              0.00001)
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = dot_logits
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return attn_res
def dot_product_area_attention(q,
                               k,
                               v,
                               bias,
                               dropout_rate=0.0,
                               image_shapes=None,
                               name=None,
                               attention_image_summary=None,
                               save_weights_to=None,
                               dropout_broadcast_dims=None,
                               max_area_width=1,
                               max_area_height=1,
                               memory_height=1,
                               area_key_mode="mean",
                               area_value_mode="sum",
                               top_k_areas=0,
                               area_temperature=1.0,
                               training=True):
  """Dot-product area attention.

  Args:
    q: Tensor with shape [..., length_q, depth_k].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    bias: bias Tensor (see attention_bias())
    dropout_rate: a float.
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    name: an optional string
    attention_image_summary: the callback for making image summary of
      attention.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than rank of q.
      Specifies in which dimensions to broadcast the dropout decisions.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be either
      "mean", or "sum".
    top_k_areas: Use the top key areas for attention.
    area_temperature: the temperature for attention softmax.
    training: indicating if it is in the training mode.

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  tf.logging.info("dot_product_area_attention: "
                  "area_h=%d, area_w=%d, mem_h=%d, "
                  "area_key_mode=%s, area_value_mode=%s, "
                  "area_temperature=%f",
                  max_area_height, max_area_width, memory_height,
                  area_key_mode, area_value_mode, area_temperature)
  with tf.variable_scope(
      name, default_name="dot_product_area_attention",
      values=[q, k, v]) as scope:
    mem_shape = common_layers.shape_list(k)
    batch_size = mem_shape[0]
    head_size = mem_shape[1]
    length = mem_shape[2]
    depth = mem_shape[3]
    k_area = compute_area_key(
        tf.reshape(k, [-1, length, depth]),
        max_area_width=max_area_width,
        max_area_height=max_area_height,
        height=memory_height,
        mode=area_key_mode,
        training=training)
    if area_value_mode == "mean":
      v_area, _, _, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, depth]),
          max_area_width=max_area_width,
          max_area_height=max_area_height,
          height=memory_height)
    elif area_value_mode == "max":
      v_area, _, _ = basic_pool(
          tf.reshape(v, [-1, length, depth]),
          max_area_width=max_area_width,
          max_area_height=max_area_height,
          height=memory_height,
          fn=tf.reduce_max)
    elif area_value_mode == "sum":
      _, _, v_area, _, _ = compute_area_features(
          tf.reshape(v, [-1, length, depth]),
          max_area_width=max_area_width,
          max_area_height=max_area_height,
          height=memory_height)
    else:
      raise ValueError("Unsupported area value mode=%s" % area_value_mode)
    k = tf.reshape(k_area, [batch_size, head_size, -1, depth])
    v = tf.reshape(v_area, [batch_size, head_size, -1, depth])
    logits = tf.matmul(q, k, transpose_b=True)  # [..., length_q, length_kv]
    if bias is not None:
      bias = common_layers.cast_like(bias, logits)
      with tf.name_scope("compute_area_att_bias", values=[bias]):
        bias_shape = common_layers.shape_list(bias)
        mem_length = bias_shape[-1]
        bias_values = tf.reshape(
            tf.to_float(tf.less(bias, -1)), [-1, mem_length, 1])
        _, _, padding_sum, _, _ = compute_area_features(
            bias_values,
            max_area_width=max_area_width,
            max_area_height=max_area_height,
            height=memory_height)
        bias = tf.where(
            tf.cast(tf.to_int32(padding_sum), tf.bool),
            tf.fill(tf.shape(padding_sum), -np.inf),
            tf.zeros_like(padding_sum, dtype=tf.float32))
        bias = tf.reshape(
            bias, [bias_shape[0], bias_shape[1], bias_shape[2], -1])
      logits += bias
    logits = logits / area_temperature
    weights = tf.nn.softmax(logits, name="attention_weights")
    if top_k_areas > 0:
      tf.logging.info("area_attention top_k_areas=%d", top_k_areas)
      top_k = tf.minimum(common_layers.shape_list(weights)[-1], top_k_areas)
      top_weights, _ = tf.nn.top_k(weights, k=top_k)
      min_values = tf.reduce_min(top_weights, -1, keepdims=True)
      weights = tf.where(
          tf.greater_equal(weights, min_values),
          weights, tf.zeros_like(weights))
      weights = tf.div(weights, tf.reduce_sum(weights, -1, keepdims=True))
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
      save_weights_to[scope.name + "/logits"] = logits
    # Drop out attention links for each head.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and attention_image_summary:
      attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
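# Hedged usage sketch (not part of the original file): dot_product_area_attention()
# consumes the usual [batch, heads, length, depth] tensors and pools keys and
# values over areas up to max_area_width long; the shapes below are
# illustrative.
#
#   q = tf.random_normal([2, 4, 16, 32])
#   k = tf.random_normal([2, 4, 16, 32])
#   v = tf.random_normal([2, 4, 16, 32])
#   out = dot_product_area_attention(
#       q, k, v, bias=None, max_area_width=3,
#       area_key_mode="mean", area_value_mode="sum")  # -> [2, 4, 16, 32]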
def noisy_top_k_gating(x,
                       num_experts,
                       train,
                       k=2,
                       initializer=tf.zeros_initializer(),
                       noisy_gating=True,
                       noise_epsilon=1e-2,
                       name=None):
  """Noisy top-k gating.

  See paper: https://arxiv.org/abs/1701.06538.

  Args:
    x: input Tensor with shape [batch_size, input_size]
    num_experts: an integer
    train: a boolean - we only add noise at training time.
    k: an integer - number of experts per example
    initializer: an initializer
    noisy_gating: a boolean
    noise_epsilon: a float
    name: an optional string

  Returns:
    gates: a Tensor with shape [batch_size, num_experts]
    load: a Tensor with shape [num_experts]
  """
  with tf.variable_scope(name, default_name="noisy_top_k_gating"):
    input_size = x.get_shape().as_list()[-1]
    w_gate = tf.get_variable(
        "w_gate", [input_size, num_experts], tf.float32, initializer)
    if noisy_gating:
      w_noise = tf.get_variable(
          "w_noise", [input_size, num_experts], tf.float32, initializer)
    clean_logits = tf.matmul(x, w_gate)
    if noisy_gating:
      raw_noise_stddev = tf.matmul(x, w_noise)
      noise_stddev = ((tf.nn.softplus(raw_noise_stddev) + noise_epsilon) *
                      (tf.to_float(train)))
      noisy_logits = clean_logits + (
          tf.random_normal(tf.shape(clean_logits)) * noise_stddev)
      logits = noisy_logits
      if common_layers.should_generate_summaries():
        tf.summary.histogram("noisy_logits", noisy_logits)
        tf.summary.histogram("noise_stddev", noise_stddev)
    else:
      logits = clean_logits
    top_logits, top_indices = _my_top_k(logits, min(k + 1, num_experts))
    # top k logits has shape [batch, k]
    top_k_logits = tf.slice(top_logits, [0, 0], [-1, k])
    top_k_indices = tf.slice(top_indices, [0, 0], [-1, k])
    top_k_gates = tf.nn.softmax(top_k_logits)
    # This will be a `Tensor` of shape `[batch_size, n]`, with zeros in the
    # positions corresponding to all but the top k experts per example.
    gates = _rowwise_unsorted_segment_sum(top_k_gates, top_k_indices,
                                          num_experts)
    if noisy_gating and k < num_experts:
      load = tf.reduce_sum(
          _prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits,
                         k), 0)
    else:
      load = _gates_to_load(gates)
    if common_layers.should_generate_summaries():
      tf.summary.histogram("importance", tf.reduce_sum(gates, 0))
      tf.summary.histogram("load", load)
    return gates, load
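# Hedged usage sketch (not part of the original file): the gates and load
# returned above are typically combined into a load-balancing auxiliary loss
# (cv_squared of importance and load, as in the MoE paper) and the gates are
# then used to dispatch examples to experts. Names below are illustrative.
#
#   x = tf.random_normal([8, 16])                 # [batch_size, input_size]
#   gates, load = noisy_top_k_gating(x, num_experts=4, train=True, k=2)
#   importance = tf.reduce_sum(gates, 0)
#   aux_loss = cv_squared(importance) + cv_squared(load)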
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)
    # Handle videos.
    if len(labels.shape) == 5:
      labels = time_to_channels(labels)
    shape = common_layers.shape_list(labels)
    x = tf.expand_dims(labels, axis=-1)
    x = self.embed(x)
    target_codes = x
    if shape[2] == 1:
      self.is1d = True
    # Run encoder.
    x, encoder_layers = self.encoder(x)
    # Bottleneck.
    b, b_loss = self.bottleneck(x)
    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(tf.reduce_sum(
          tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)
  # Run decoder.
  x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  # Final dense layer.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  output_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, output_shape)

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    reconstr = tf.layers.dense(res, self.num_channels,
                               name="autoencoder_final")
    return reconstr, {"bottleneck_loss": 0.0}

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }
  reconstr = tf.layers.dense(res, self.num_channels, name="autoencoder_final")
  reconstr = tf.reshape(reconstr, labels_shape)
  targets_loss = self.reconstruction_loss(reconstr, labels)
  losses["training"] = targets_loss

  self.image_summary("inputs", labels)
  self.image_summary("ae", reconstr)

  return reconstr, losses
def body(self, features):
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN
  encoder_layers = None  # Stays None when the encoder is skipped in PREDICT mode.
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    labels = features["targets_raw"]
    vocab_size = self._problem_hparams.target_modality.top_dimensionality
    shape = common_layers.shape_list(labels)
    x = tf.one_hot(labels, vocab_size)
    x = self.embed(x)
    target_codes = x
    is1d = shape[2] == 1
    self.is1d = is1d
    # Run encoder.
    x, encoder_layers = self.encoder(x)
    # Bottleneck.
    b, b_loss = self.bottleneck(x)
    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    b = self.unbottleneck(b, common_layers.shape_list(x)[-1])
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(
          tf.reduce_sum(tf.square(x_stop - b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
    if hparams.gan_loss_factor != 0.0:
      # Add a purely sampled batch on which we'll compute the GAN loss.
      g = self.unbottleneck(
          self.sample(shape=b_shape),
          common_layers.shape_list(x)[-1],
          reuse=True)
      x = tf.concat([g, x], axis=0)
      encoder_layers = [tf.concat([l, l], axis=0) for l in encoder_layers]
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    x = self.unbottleneck(b, res_size)

  # Run decoder.
  x = self.decoder(x, encoder_layers)
  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    return x, {"bottleneck_loss": 0.0}

  # Cut to the right size and mix before returning.
  res = x[:, :shape[1], :shape[2], :]

  # Final dense layer.
  res = tf.layers.dense(
      res, self.num_channels * hparams.hidden_size, name="res_dense")
  output_shape = common_layers.shape_list(res)[:-1] + [
      self.num_channels, self.hparams.hidden_size
  ]
  res = tf.reshape(res, output_shape)

  if hparams.gan_loss_factor != 0.0:
    res_gan, res = tf.split(res, 2, axis=0)

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }
  if hparams.use_vq_loss:
    vq_temperature = hparams.vq_temperature / common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.2,
        min_value=hparams.vq_temperature * 2)
    if hparams.mode != tf.estimator.ModeKeys.TRAIN:
      vq_temperature = None
    with tf.variable_scope("vq_loss"):
      (reconstr, _, target_codes, code_loss,
       targets_loss) = discretization.vq_loss(
           res, labels, vocab_size, temperature=vq_temperature)
    losses["code_loss"] = code_loss * hparams.code_loss_factor
    losses["training"] = targets_loss
  else:
    reconstr = tf.layers.dense(res, vocab_size, name="autoencoder_final")
    targets_loss = tf.losses.sparse_softmax_cross_entropy(
        logits=reconstr, labels=labels)
    losses["training"] = targets_loss

  # GAN losses.
  if hparams.gan_loss_factor != 0.0:
    update_means_factor = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps, min_value=0.0001)
    if hparams.use_vq_loss:
      with tf.variable_scope("vq_loss", reuse=True):
        update_means = tf.less(tf.random_uniform([]), update_means_factor)
        reconstr_gan, gan_codes, _, code_loss_gan, _ = discretization.vq_loss(
            res_gan, labels, vocab_size,
            do_update=update_means, temperature=vq_temperature)
        code_loss_gan *= hparams.code_loss_factor * update_means_factor
        losses["code_loss_gan"] = code_loss_gan
    else:
      reconstr_gan = tf.layers.dense(
          res_gan, vocab_size, name="autoencoder_final", reuse=True)
      reconstr_gan = tf.nn.log_softmax(reconstr_gan)
      if is_training and hparams.gumbel_temperature > 0.0:
        gumbel_samples = discretization.gumbel_sample(
            common_layers.shape_list(reconstr_gan))
        gumbel_samples *= hparams.gumbel_noise_factor
        reconstr_gan += gumbel_samples
        reconstr_sample = latent_layers.multinomial_sample(
            reconstr_gan, temperature=hparams.gumbel_temperature)
        reconstr_gan = tf.nn.softmax(reconstr_gan / hparams.gumbel_temperature)
      else:
        reconstr_sample = tf.argmax(reconstr_gan, axis=-1)
        reconstr_gan = tf.nn.softmax(reconstr_gan / 0.1)  # Sharpen a bit.
      # Use 1-hot forward, softmax backward.
      reconstr_hot = tf.one_hot(reconstr_sample, vocab_size)
      reconstr_gan += reconstr_hot - tf.stop_gradient(reconstr_gan)
      # Embed to codes.
      gan_codes = self.embed(reconstr_gan)

  # Add GAN loss if requested.
  gan_loss = 0.0
  if hparams.gan_loss_factor != 0.0:
    self.image_summary("gan", reconstr_gan)

    def discriminate(x):
      return self.discriminator(x, is_training=is_training)

    tc_shape = common_layers.shape_list(target_codes)
    if len(tc_shape) > 4:
      target_codes = tf.reshape(
          target_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
      gan_codes = tf.reshape(
          gan_codes, tc_shape[:-2] + [tc_shape[-1] * tc_shape[-2]])
    gan_lr = common_layers.inverse_exp_decay(
        hparams.gan_codes_warmup_steps * 1.5)
    rev_grad_gan_codes = reverse_gradient(gan_codes, lr=gan_lr)
    gan_loss = common_layers.sliced_gan_loss(
        target_codes, rev_grad_gan_codes, discriminate,
        self.hparams.num_sliced_vecs)
    gan_loss *= hparams.gan_loss_factor * update_means_factor
    losses["gan_loss"] = -gan_loss

  self.image_summary("ae", reconstr)

  logits = reconstr
  return logits, losses
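# The "1-hot forward, softmax backward" lines in the GAN branch above are a
# straight-through estimator: the forward pass emits a hard one-hot code while
# gradients flow through the (tempered) softmax. A minimal sketch of that
# trick in isolation (illustrative only, not a helper used above):
def straight_through_one_hot(logits, vocab_size, temperature=0.1):
  """Hard one-hot in the forward pass, softmax gradients in the backward pass."""
  soft = tf.nn.softmax(logits / temperature)
  hard = tf.one_hot(tf.argmax(logits, axis=-1), vocab_size)
  # Evaluates to `hard`; the gradient is that of `soft`.
  return soft + tf.stop_gradient(hard - soft)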
def graph_attention(q,
                    k,
                    v,
                    bias,
                    dropout_rate=0.0,
                    image_shapes=None,
                    name=None,
                    make_image_summary=True,
                    save_weights_to=None,
                    dropout_broadcast_dims=None,
                    adjacency_matrix=None,
                    num_edge_types=5):
  """Graph attention.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      See comments for attention_image_summary().
    name: an optional string
    make_image_summary: True if you want an image summary.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    dropout_broadcast_dims: an optional list of integers less than 4 specifying
      in which dimensions to broadcast the dropout decisions. Saves memory.
    adjacency_matrix: optional matrix of [batch, length, length] ids indicating
      edge type
    num_edge_types: an int indicating number of edge types

  Returns:
    A Tensor of shape [batch, heads, length_q, depth_v]
  """
  with tf.variable_scope(
      name, default_name="dot_product_attention", values=[q, k, v]) as scope:
    # [batch, num_heads, query_length, memory_length]
    logits = tf.matmul(q, k, transpose_b=True)
    if adjacency_matrix is not None:
      key_head_depth = common_layers.shape_list(q)[-1]
      adjacency_vectors = make_edge_vectors(
          adjacency_matrix, num_edge_types, key_head_depth, name=name)
      # Transpose q to [batch, length_q, heads, depth_k]
      # to allow for matmul with [batch, length_q, length_q, depth_k].
      q_t = tf.transpose(q, [0, 2, 1, 3])
      adj_logits = tf.matmul(q_t, adjacency_vectors, transpose_b=True)
      logits += tf.transpose(adj_logits, [0, 2, 1, 3])
      # [batch, depth, num_nodes, num_nodes]
    if bias is not None:
      logits += bias
    weights = tf.nn.softmax(logits, name="attention_weights")
    if save_weights_to is not None:
      save_weights_to[scope.name] = weights
    # Drop out attention links for each of the heads.
    weights = common_layers.dropout_with_broadcast_dims(
        weights, 1.0 - dropout_rate, broadcast_dims=dropout_broadcast_dims)
    if common_layers.should_generate_summaries() and make_image_summary:
      common_attention.attention_image_summary(weights, image_shapes)
    return tf.matmul(weights, v)
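# A hypothetical usage sketch for graph_attention with the shapes given in its
# docstring. The adjacency matrix holds an integer edge-type id per node pair;
# make_edge_vectors (defined elsewhere in this codebase) turns those ids into
# edge-type embeddings that are added to the attention logits.
def example_graph_attention_call():
  batch, heads, length, depth = 2, 4, 16, 32
  q = tf.random_normal([batch, heads, length, depth])
  k = tf.random_normal([batch, heads, length, depth])
  v = tf.random_normal([batch, heads, length, depth])
  adjacency = tf.random_uniform(
      [batch, length, length], maxval=5, dtype=tf.int32)
  attended = graph_attention(q, k, v, bias=None, dropout_rate=0.1,
                             adjacency_matrix=adjacency, num_edge_types=5)
  # attended has shape [batch, heads, length, depth].
  return attended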
def autoencoder_body(self, features):
  """Customized body function for autoencoders acting on continuous images.

  This is based on tensor2tensor.models.research.AutoencoderBasic.body
  and should be compatible with most derived classes.

  The original autoencoder class relies on embedding the channels to a
  discrete vocabulary and defines the loss on that vocab. It's cool and all,
  but here we prefer expressing the reconstruction loss as an actual
  continuous likelihood function.
  """
  hparams = self.hparams
  is_training = hparams.mode == tf.estimator.ModeKeys.TRAIN

  output_activation = tf.nn.softplus if hparams.output_activation == 'softplus' else None
  input_shape = [None, ] + common_layers.shape_list(features["inputs"])[1:]

  if hparams.mode == tf.estimator.ModeKeys.PREDICT:
    # In predict mode, we also define TensorFlow Hub modules for all pieces
    # of the autoencoder.
    # First build the encoder spec.
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))
      x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs=input_layer, outputs=b)

    def make_model_spec_psf():
      input_layer = tf.placeholder(tf.float32, shape=input_shape)
      psf_layer = tf.placeholder(tf.float32, shape=input_shape)
      x = self.embed(tf.expand_dims(input_layer, -1))

      # If we have access to the PSF, we add this information to the encoder.
      if hparams.encode_psf and 'psf' in features:
        net_psf = tf.layers.conv2d(psf_layer, hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)
      b, b_loss = self.bottleneck(x)
      hub.add_signature(inputs={'input': input_layer, 'psf': psf_layer},
                        outputs=b)

    spec = hub.create_module_spec(
        make_model_spec_psf if hparams.encode_psf else make_model_spec,
        drop_collections=['checkpoints'])
    encoder = hub.Module(spec, name="encoder_module")
    hub.register_module_for_export(encoder, "encoder")

    if hparams.encode_psf:
      code = encoder({'input': features["inputs"], 'psf': features['psf']})
    else:
      code = encoder(features["inputs"])
    b_shape = [None, ] + common_layers.shape_list(code)[1:]
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)

    # Second build the decoder spec.
    def make_model_spec():
      input_layer = tf.placeholder(tf.float32, shape=b_shape)
      x = self.unbottleneck(input_layer, res_size)
      x = self.decoder(x, None)
      reconstr = tf.layers.dense(x, self.num_channels,
                                 name="autoencoder_final",
                                 activation=output_activation)
      hub.add_signature(inputs=input_layer, outputs=reconstr)
      hub.attach_message(
          "stamp_size",
          tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
      hub.attach_message(
          "pixel_size",
          tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))

    spec = hub.create_module_spec(make_model_spec,
                                  drop_collections=['checkpoints'])
    decoder = hub.Module(spec, name="decoder_module")
    hub.register_module_for_export(decoder, "decoder")
    reconstr = decoder(code)
    return reconstr, {"bottleneck_loss": 0.0}

  encoder_layers = None
  self.is1d = hparams.sample_width == 1
  if (hparams.mode != tf.estimator.ModeKeys.PREDICT
      or self._encode_on_predict):
    labels = features["targets_raw"]
    labels_shape = common_layers.shape_list(labels)

    shape = common_layers.shape_list(labels)
    with tf.variable_scope('encoder_module'):
      x = self.embed(tf.expand_dims(labels, -1))

    if shape[2] == 1:
      self.is1d = True

    # Run encoder.
    with tf.variable_scope('encoder_module'):
      # If we have access to the PSF, we add this information to the encoder.
      if hparams.encode_psf and 'psf' in features:
        net_psf = tf.layers.conv2d(features['psf'], hparams.hidden_size // 4, 5,
                                   padding='same', name="psf_embed_1")
        net_psf = common_layers.layer_norm(net_psf, name="psf_norm")
        x, encoder_layers = self.encoder(tf.concat([x, net_psf], axis=-1))
      else:
        x, encoder_layers = self.encoder(x)

    # Bottleneck.
    with tf.variable_scope('encoder_module'):
      b, b_loss = self.bottleneck(x)

    xb_loss = 0.0
    b_shape = common_layers.shape_list(b)
    self._cur_bottleneck_tensor = b
    res_size = common_layers.shape_list(x)[-1]
    with tf.variable_scope('decoder_module'):
      b = self.unbottleneck(b, res_size)
    if not is_training:
      x = b
    else:
      l = 2**hparams.num_hidden_layers
      warm_step = int(hparams.bottleneck_warmup_steps * 0.25 * l)
      nomix_p = common_layers.inverse_lin_decay(warm_step) + 0.01
      if common_layers.should_generate_summaries():
        tf.summary.scalar("nomix_p_bottleneck", nomix_p)
      rand = tf.random_uniform(common_layers.shape_list(x))
      # This is the distance between b and x. Having this as loss helps learn
      # the bottleneck function, but if we back-propagated to x it would be
      # minimized by just setting x=0 and b=0 -- so we don't want too much
      # of the influence of this, and we stop-gradient to not zero-out x.
      x_stop = tf.stop_gradient(x)
      xb_loss = tf.reduce_mean(
          tf.reduce_sum(tf.squared_difference(x_stop, b), axis=-1))
      # To prevent this loss from exploding we clip at 1, but anneal clipping.
      clip_max = 1.0 / common_layers.inverse_exp_decay(
          warm_step, min_value=0.001)
      xb_clip = tf.maximum(tf.stop_gradient(xb_loss), clip_max)
      xb_loss *= clip_max / xb_clip
      x = tf.where(tf.less(rand, nomix_p), b, x)
  else:
    if self._cur_bottleneck_tensor is None:
      b = self.sample()
    else:
      b = self._cur_bottleneck_tensor
    self._cur_bottleneck_tensor = b
    res_size = self.hparams.hidden_size * 2**self.hparams.num_hidden_layers
    res_size = min(res_size, hparams.max_hidden_size)
    with tf.variable_scope('decoder_module'):
      x = self.unbottleneck(b, res_size)

  # Run decoder.
  with tf.variable_scope('decoder_module'):
    x = self.decoder(x, encoder_layers)

  # Cut to the right size and mix before returning.
  res = x
  if hparams.mode != tf.estimator.ModeKeys.PREDICT:
    res = x[:, :shape[1], :shape[2], :]

  with tf.variable_scope('decoder_module'):
    reconstr = tf.layers.dense(res, self.num_channels,
                               name="autoencoder_final",
                               activation=output_activation)

  # Apply channel-wise convolution with the PSF if requested.
  # TODO: Handle multiple bands.
  if hparams.apply_psf and 'psf' in features:
    if self.num_channels > 1:
      raise NotImplementedError
    rec_padded = tf.pad(
        reconstr[:, :, :, 0],
        [[0, 0],
         [0, int(hparams.psf_convolution_pad_factor * shape[1])],
         [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
    psf_padded = tf.pad(
        features['psf'][..., 0],
        [[0, 0],
         [0, int(hparams.psf_convolution_pad_factor * shape[1])],
         [0, int(hparams.psf_convolution_pad_factor * shape[2])]])
    reconstr = tf.expand_dims(
        tf.spectral.irfft2d(
            tf.spectral.rfft2d(rec_padded) *
            tf.cast(tf.abs(tf.spectral.rfft2d(psf_padded)), tf.complex64)),
        axis=-1)
    reconstr = reconstr[:, :shape[1], :shape[2], :]

  # Losses.
  losses = {
      "bottleneck_extra": b_loss,
      "bottleneck_l2": hparams.bottleneck_l2_factor * xb_loss
  }

  loglik = loglikelihood_fn(labels, reconstr, features, hparams)
  targets_loss = tf.reduce_mean(-loglik)
  tf.summary.scalar("negloglik", targets_loss)
  tf.summary.scalar("bottleneck_loss", b_loss)

  losses["training"] = targets_loss
  logits = tf.reshape(reconstr, labels_shape)

  image_summary("ae", reconstr)
  image_summary("input", labels)
  return logits, losses
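# `loglikelihood_fn` above is defined elsewhere in this codebase. As a rough,
# hypothetical sketch of what a continuous reconstruction likelihood can look
# like (an isotropic Gaussian per pixel with a fixed scale; the real function
# may use `features` and `hparams`, e.g. per-pixel noise maps, differently):
import numpy as np

def gaussian_loglikelihood_fn(labels, reconstr, features, hparams, sigma=1.0):
  """Per-example log p(labels | reconstr) under N(reconstr, sigma**2)."""
  log_prob = (-0.5 * tf.square((labels - reconstr) / sigma)
              - np.log(sigma) - 0.5 * np.log(2.0 * np.pi))
  # Sum over height, width and channels to get one scalar per example.
  return tf.reduce_sum(log_prob, axis=[1, 2, 3])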