def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True):
    """Assemble the SliceNet graph used for the main training step.

    Args:
        inputs: Input tensor. When its last static dimension differs from
            `hparams.model_d` it is first projected with a conv block.
        targets: Target tensor, handed on to `slicenet_middle`.
        target_space: Target-space id, embedded with `embed_target_space`.
        hparams: Hyperparameters. `model_d` and `num_hidden_layers` are read
            here; the object is also forwarded to the sub-networks.
        run_decoder: If False, stop after the encoder.

    Returns:
        The encoder output alone when `run_decoder` is False, otherwise the
        pair `(decoder_output, mean_similarity_loss)`.
    """
    with tf.variable_scope("slicenet"):
        # Project to the model depth only when the channel count differs.
        channels = inputs.get_shape().as_list()[-1]
        if channels != hparams.model_d:
            inputs = common_layers.conv_block(
                inputs,
                hparams.model_d,
                [((1, 1), (3, 3))],
                first_relu=False,
                padding="SAME",
                force2d=True)

        # Flatten the inputs and insert a singleton axis at position 2.
        flat_inputs = tf.expand_dims(
            common_layers.flatten4d3d(inputs), axis=2)
        # The mask is the complement of the padding indicator; it is taken
        # before position information is added to the inputs.
        encoder_mask = 1.0 - embedding_to_padding(flat_inputs)
        timed_inputs = common_layers.add_timing_signal(flat_inputs)

        space_embedding = embed_target_space(target_space, hparams.model_d)

        # The encoder uses 1.5x the configured layer count (truncated).
        encoder_layers = int(hparams.num_hidden_layers * 1.5)
        encoded = multi_conv_res(
            timed_inputs,
            "SAME",
            "encoder",
            encoder_layers,
            hparams,
            mask=encoder_mask)
        if not run_decoder:
            return encoded

        # Middle part: connects the encoder output with the targets.
        decoder_input, sim_loss = slicenet_middle(
            encoded, targets, space_embedding, encoder_mask, hparams)

        # Decoder stack, with the encoder output supplied as `source`.
        decoded = multi_conv_res(
            decoder_input,
            "LEFT",
            "decoder",
            hparams.num_hidden_layers,
            hparams,
            mask=encoder_mask,
            source=encoded)
        return decoded, tf.reduce_mean(sim_loss)
def slicenet_internal(inputs, targets, target_space, hparams, run_decoder=True):
    """Build the SliceNet encoder, middle section and decoder.

    Args:
        inputs: Input tensor. It is projected with a conv block when its
            last static dimension is not `hparams.hidden_size`.
        targets: Target tensor, passed to `slicenet_middle`.
        target_space: Target-space id, embedded with `embed_target_space`.
        hparams: Hyperparameters. `hidden_size` and `num_hidden_layers` are
            read here; the object is also forwarded to the sub-networks.
        run_decoder: If False, return right after the encoder.

    Returns:
        Only the encoder output when `run_decoder` is False, otherwise the
        pair `(decoder_output, mean_similarity_loss)`.
    """
    with tf.variable_scope("slicenet"):
        # Bring the channel dimension to the hidden size if it differs.
        last_dim = inputs.get_shape().as_list()[-1]
        if last_dim != hparams.hidden_size:
            inputs = common_layers.conv_block(
                inputs,
                hparams.hidden_size,
                [((1, 1), (3, 3))],
                first_relu=False,
                padding="SAME",
                force2d=True)

        # Flatten the inputs and insert a singleton axis at position 2.
        flattened = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
        # Complement of the padding indicator, computed before position
        # information is mixed into the inputs.
        source_mask = 1.0 - embedding_to_padding(flattened)
        positioned = common_layers.add_timing_signal(flattened)

        target_space_emb = embed_target_space(
            target_space, hparams.hidden_size)

        # Encoder depth is 1.5x the configured layer count (truncated).
        num_encoder_layers = int(hparams.num_hidden_layers * 1.5)
        encoder_out = multi_conv_res(
            positioned,
            "SAME",
            "encoder",
            num_encoder_layers,
            hparams,
            mask=source_mask)
        if not run_decoder:
            return encoder_out

        # Middle part between encoder and decoder.
        middle_out, similarity = slicenet_middle(
            encoder_out, targets, target_space_emb, source_mask, hparams)

        # Decoder stack reading the encoder output as its `source`.
        decoder_out = multi_conv_res(
            middle_out,
            "LEFT",
            "decoder",
            hparams.num_hidden_layers,
            hparams,
            mask=source_mask,
            source=encoder_out)
        return decoder_out, tf.reduce_mean(similarity)
def slicenet_internal(inputs, targets, target_space, problem_idx, hparams):
    """Build the SliceNet graph for the problem selected by `problem_idx`.

    Args:
        inputs: Input tensor; flattened and encoded.
        targets: Target tensor, passed to `slicenet_middle`.
        target_space: Target-space id, embedded with `embed_target_space`.
        problem_idx: Index into `hparams.problems`; the selected problem's
            target modality name decides whether a decoder is built.
        hparams: Hyperparameters. `hidden_size`, `num_hidden_layers` and
            `problems` are read here; the object is also forwarded to the
            sub-networks.

    Returns:
        Only the encoder output when the target modality name contains
        "class_label_modality", otherwise the pair
        `(decoder_output, mean_similarity_loss)`.
    """
    with tf.variable_scope("slicenet"):
        # Flatten the inputs and insert a singleton axis at position 2.
        flattened = tf.expand_dims(common_layers.flatten4d3d(inputs), axis=2)
        # Complement of the padding indicator, computed before position
        # information is mixed into the inputs.
        source_mask = 1.0 - embedding_to_padding(flattened)
        positioned = common_layers.add_timing_signal(flattened)

        space_embedding = embed_target_space(
            target_space, hparams.hidden_size)

        # Encoder depth is 1.5x the configured layer count (truncated).
        encoder_layers = int(hparams.num_hidden_layers * 1.5)
        encoded = multi_conv_res(
            positioned,
            "SAME",
            "encoder",
            encoder_layers,
            hparams,
            mask=source_mask)

        modality_name = hparams.problems[problem_idx].target_modality.name
        if "class_label_modality" in modality_name:
            # When only a class is predicted there is no use for a decoder.
            return encoded

        # Middle part between encoder and decoder.
        decoder_input, sim_loss = slicenet_middle(
            encoded, targets, space_embedding, source_mask, hparams)

        # Decoder stack reading the encoder output as its `source`.
        decoded = multi_conv_res(
            decoder_input,
            "LEFT",
            "decoder",
            hparams.num_hidden_layers,
            hparams,
            mask=source_mask,
            source=encoded)
        return decoded, tf.reduce_mean(sim_loss)
def testAddTimingSignal(self):
    """add_timing_signal must return a tensor shaped like its input."""
    # (batch, length, height, depth)
    shape = (5, 7, 3, 35)
    x = np.random.rand(*shape)
    signal = common_layers.add_timing_signal(
        tf.constant(x, dtype=tf.float32))
    res = self.evaluate(signal)
    self.assertEqual(res.shape, shape)
def testAddTimingSignal(self):
    """add_timing_signal must return a tensor shaped like its input."""
    # (batch, length, height, depth)
    shape = (5, 7, 3, 35)
    x = np.random.rand(*shape)
    with self.test_session() as session:
        signal = common_layers.add_timing_signal(
            tf.constant(x, dtype=tf.float32))
        # Run the global initializer before evaluating the op.
        session.run(tf.global_variables_initializer())
        res = session.run(signal)
    self.assertEqual(res.shape, shape)
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Connect the SliceNet encoder output to the decoder input.

    Args:
        inputs_encoded: Encoder output.
        targets: Target tensor; flattened before use.
        target_space_emb: Target-space embedding; tiled along axis 0 and used
            as the pad value when the targets are shifted right.
        mask: Input mask; `(1.0 - mask) * -1e9` becomes the attention bias.
        hparams: Hyperparameters. Reads `norm_type`, `hidden_size`,
            `norm_epsilon`, `problems`, `sim_loss_mult`, `num_hidden_layers`,
            `attention_type`, `kernel_height` and `kernel_width`.

    Returns:
        The pair `(targets_merged, similarity_loss)`. The loss is the Python
        float 0.0 when the similarity branch is skipped.
    """

    def norm_fn(x, name):
        """Apply the configured normalization in its own variable scope."""
        with tf.variable_scope(name, default_name="norm"):
            return common_layers.apply_norm(
                x, hparams.norm_type, hparams.hidden_size,
                hparams.norm_epsilon)

    # Flatten the targets and tile the target-space embedding along axis 0
    # so that it matches the leading dimension of the targets.
    flat_targets = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    tile_count = tf.shape(flat_targets)[0]
    space_emb = tf.tile(target_space_emb, [tile_count, 1, 1, 1])

    # The similarity loss is only built for multi-problem training with a
    # non-negligible multiplier.
    similarity_loss = 0.0
    if len(hparams.problems) > 1 and hparams.sim_loss_mult > 0.00001:
        timed_targets = common_layers.add_timing_signal(flat_targets)
        encoder_layers = int(hparams.num_hidden_layers * 1.5)
        # Re-enter the current scope with reuse=True so that the "encoder"
        # variables are shared rather than created anew.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            encoded_targets = multi_conv_res(
                timed_targets, "SAME", "encoder", encoder_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            similarity_loss = similarity_cost(inputs_encoded, encoded_targets)
            similarity_loss *= hparams.sim_loss_mult

    # Use attention from each target to look at the input and retrieve.
    shifted = common_layers.shift_right(flat_targets, pad_value=space_emb)
    if hparams.attention_type != "none":
        # Bias to not attend to padding.
        padding_bias = (1.0 - mask) * -1e9
        attended = attention(
            shifted, inputs_encoded, norm_fn, hparams, bias=padding_bias)
    else:
        attended = tf.zeros_like(shifted)

    # Positional targets: merge the attention output with the raw targets.
    merge_kernel = (hparams.kernel_height, hparams.kernel_width)
    merged = common_layers.subseparable_conv_block(
        tf.concat([attended, shifted], axis=3),
        hparams.hidden_size,
        [((1, 1), merge_kernel)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")
    return merged, similarity_loss
def slicenet_middle(inputs_encoded, targets, target_space_emb, mask, hparams):
    """Middle section of SliceNet, joining the encoder and the decoder.

    Args:
        inputs_encoded: Encoder output.
        targets: Target tensor; flattened before use.
        target_space_emb: Target-space embedding; tiled along axis 0 and used
            as the pad value when the targets are shifted right.
        mask: Input mask; `(1.0 - mask) * -1e9` becomes the attention bias.
        hparams: Hyperparameters. Reads `norm_type`, `hidden_size`,
            `norm_epsilon`, `problems`, `sim_loss_mult`, `num_hidden_layers`,
            `attention_type`, `kernel_height` and `kernel_width`.

    Returns:
        The pair `(targets_merged, similarity_loss)`. The loss is the Python
        float 0.0 when the similarity branch is skipped.
    """

    def norm_fn(x, name):
        """Normalize `x` inside a variable scope named `name`."""
        with tf.variable_scope(name, default_name="norm"):
            normalized = common_layers.apply_norm(
                x, hparams.norm_type, hparams.hidden_size,
                hparams.norm_epsilon)
        return normalized

    # Flatten the targets; tile the target-space embedding along axis 0 to
    # match the leading dimension of the flattened targets.
    targets_4d = tf.expand_dims(common_layers.flatten4d3d(targets), axis=2)
    tiled_space_emb = tf.tile(
        target_space_emb, [tf.shape(targets_4d)[0], 1, 1, 1])

    # Calculate the similarity loss, but only when it is actually needed.
    multi_problem = len(hparams.problems) > 1
    if multi_problem and hparams.sim_loss_mult > 0.00001:
        targets_with_timing = common_layers.add_timing_signal(targets_4d)
        num_layers = int(hparams.num_hidden_layers * 1.5)
        # reuse=True shares the "encoder" variables in the current scope.
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            targets_enc = multi_conv_res(
                targets_with_timing, "SAME", "encoder", num_layers, hparams)
        with tf.variable_scope("similarity_loss"):
            sim_loss = similarity_cost(inputs_encoded, targets_enc)
            sim_loss *= hparams.sim_loss_mult
    else:
        sim_loss = 0.0

    # Use attention from each target to look at the input and retrieve.
    right_shifted = common_layers.shift_right(
        targets_4d, pad_value=tiled_space_emb)
    if hparams.attention_type == "none":
        with_attention = tf.zeros_like(right_shifted)
    else:
        # Bias to not attend to padding.
        bias_for_padding = (1.0 - mask) * -1e9
        with_attention = attention(
            right_shifted,
            inputs_encoded,
            norm_fn,
            hparams,
            bias=bias_for_padding)

    # Positional targets: merge the attention output with the raw targets.
    merge_input = tf.concat([with_attention, right_shifted], axis=3)
    kernel_size = (hparams.kernel_height, hparams.kernel_width)
    merged_targets = common_layers.subseparable_conv_block(
        merge_input,
        hparams.hidden_size,
        [((1, 1), kernel_size)],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separability=4,
        name="targets_merge")
    return merged_targets, sim_loss
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Run the full attention layer, including target preprocessing.

    Args:
        targets_shifted: Shifted target tensor that provides the queries.
        inputs_encoded: Encoder output to attend over.
        norm_fn: Normalizer passed to the preprocessing conv block.
        hparams: Hyperparameters. Reads `separability`, `model_d`,
            `attention_type`, `num_heads` and `attention_dropout`.
        bias: Accepted for interface compatibility; not used by this body.

    Returns:
        The attention output with a singleton axis inserted at position 2.

    Raises:
        ValueError: If `hparams.attention_type` is not "transformer". The
            preprocessing block has already been built by then.
    """
    sep = hparams.separability
    separabilities = [sep - 1, sep] if sep < 0 else [sep, sep]

    # Preprocess the targets: add timing, then a two-step conv block.
    timed_input = common_layers.add_timing_signal(targets_shifted)
    targets_timed = common_layers.subseparable_conv_block(
        timed_input,
        hparams.model_d,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")

    if hparams.attention_type != "transformer":
        raise ValueError(
            "Unsupported attention_type: %s" % hparams.attention_type)

    queries = tf.squeeze(targets_timed, 2)
    query_shape = tf.shape(queries)
    targets_segment = tf.zeros([query_shape[0], query_shape[1]])
    self_bias = common_attention.attention_bias_lower_triangle(query_shape[1])

    memory = common_layers.flatten4d3d(inputs_encoded)
    # TODO(jbaccash): use input bias parameter. This code seems to assume
    # fixed size inputs.
    encdec_bias = tf.zeros([
        tf.shape(memory)[0],
        hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(memory)[1],
    ])

    # Self-attention over the targets, then attention over the encoder
    # output.
    attended = common_attention.multihead_attention(
        queries,
        None,
        self_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    attended = common_attention.multihead_attention(
        attended,
        memory,
        encdec_bias,
        hparams.model_d,
        hparams.model_d,
        hparams.model_d,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(attended, 2)
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Run the full attention layer, including target preprocessing.

    Args:
        targets_shifted: Shifted target tensor that provides the queries.
        inputs_encoded: Encoder output to attend over.
        norm_fn: Normalizer used by the preprocessing conv block and, in
            "simple" mode, on the residual sum.
        hparams: Hyperparameters. Reads `attention_type`, `separability`,
            `hidden_size`, `num_heads` and `attention_dropout`.
        bias: Optional attention bias. Only forwarded to `simple_attention`
            in "simple" mode; the "transformer" path does not use it.

    Returns:
        "transformer": the multi-head attention output with a singleton axis
        inserted at position 2. "simple": `norm_fn` applied to
        `targets_shifted` plus the simple-attention output.

    Raises:
        ValueError: If `hparams.attention_type` is neither "transformer" nor
            "simple".
    """
    # Fail fast on an unknown mode. Previously the function fell off the end
    # and returned None, which only surfaced later as an obscure error in the
    # caller.
    if hparams.attention_type not in ("transformer", "simple"):
        raise ValueError(
            "Unsupported attention_type: %s" % hparams.attention_type)

    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        separabilities = [hparams.separability - 1, hparams.separability]

    # Preprocess the targets: add timing, then a two-step conv block.
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")

    if hparams.attention_type == "transformer":
        targets_timed = tf.squeeze(targets_timed, 2)
        target_shape = tf.shape(targets_timed)
        targets_segment = tf.zeros([target_shape[0], target_shape[1]])
        target_attention_bias = common_attention.attention_bias(
            targets_segment, targets_segment, lower_triangular=True)
        # All-zero bias for the encoder-decoder attention.
        inputs_attention_bias = tf.zeros([
            tf.shape(inputs_encoded)[0],
            hparams.num_heads,
            tf.shape(targets_segment)[1],
            tf.shape(inputs_encoded)[1],
        ])
        # Self-attention over the targets, then attention over the encoder
        # output.
        qv = common_attention.multihead_attention(
            targets_timed,
            None,
            target_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="self_attention")
        qv = common_attention.multihead_attention(
            qv,
            inputs_encoded,
            inputs_attention_bias,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.hidden_size,
            hparams.num_heads,
            hparams.attention_dropout,
            name="encdec_attention")
        return tf.expand_dims(qv, 2)

    # attention_type == "simple": residual connection followed by norm.
    targets_with_attention = common_layers.simple_attention(
        targets_timed, inputs_encoded, bias=bias)
    return norm_fn(targets_shifted + targets_with_attention, name="attn_norm")
def attention(targets_shifted, inputs_encoded, norm_fn, hparams, bias=None):
    """Complete attention layer with preprocessing.

    Args:
        targets_shifted: Shifted target tensor that provides the queries.
        inputs_encoded: Encoder output to attend over.
        norm_fn: Normalizer used by the preprocessing conv block and, in
            "simple" mode, on the residual sum.
        hparams: Hyperparameters. Reads `attention_type`, `separability`,
            `hidden_size`, `num_heads` and `attention_dropout`.
        bias: Optional attention bias. Only forwarded to `simple_attention`
            in "simple" mode; the "transformer" path does not use it.

    Returns:
        "transformer": the multi-head attention output with a singleton axis
        inserted at position 2. "simple": `norm_fn` applied to
        `targets_shifted` plus the simple-attention output.

    Raises:
        ValueError: If `hparams.attention_type` is neither "transformer" nor
            "simple".
    """
    attention_type = hparams.attention_type
    # Reject unknown modes before building any ops. Previously an unknown
    # mode made the function return None implicitly, which only failed later
    # in the caller.
    if attention_type not in ("transformer", "simple"):
        raise ValueError("Unsupported attention_type: %s" % attention_type)

    separabilities = [hparams.separability, hparams.separability]
    if hparams.separability < 0:
        separabilities = [hparams.separability - 1, hparams.separability]

    # Preprocess the targets: add timing, then a two-step conv block.
    targets_timed = common_layers.subseparable_conv_block(
        common_layers.add_timing_signal(targets_shifted),
        hparams.hidden_size,
        [((1, 1), (5, 1)), ((4, 1), (5, 1))],
        normalizer_fn=norm_fn,
        padding="LEFT",
        separabilities=separabilities,
        name="targets_time")

    if attention_type == "simple":
        # Residual connection followed by normalization.
        targets_with_attention = common_layers.simple_attention(
            targets_timed, inputs_encoded, bias=bias)
        return norm_fn(
            targets_shifted + targets_with_attention, name="attn_norm")

    # attention_type == "transformer".
    targets_timed = tf.squeeze(targets_timed, 2)
    target_shape = tf.shape(targets_timed)
    targets_segment = tf.zeros([target_shape[0], target_shape[1]])
    target_attention_bias = common_attention.attention_bias(
        targets_segment, targets_segment, lower_triangular=True)
    # All-zero bias for the encoder-decoder attention.
    inputs_attention_bias = tf.zeros([
        tf.shape(inputs_encoded)[0],
        hparams.num_heads,
        tf.shape(targets_segment)[1],
        tf.shape(inputs_encoded)[1],
    ])
    # Self-attention over the targets, then attention over the encoder
    # output.
    qv = common_attention.multihead_attention(
        targets_timed,
        None,
        target_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="self_attention")
    qv = common_attention.multihead_attention(
        qv,
        inputs_encoded,
        inputs_attention_bias,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.hidden_size,
        hparams.num_heads,
        hparams.attention_dropout,
        name="encdec_attention")
    return tf.expand_dims(qv, 2)