def encoder_embeddings(self, x, is_training, embed_input=0):
    """Prepare the encoder input: optional embedding, projection, positions.

    Args:
        x: Encoder input. Passed through ``embedding`` when ``embed_input``
            is truthy, otherwise used as-is.
        is_training: Converted to a tensor and used as the ``training``
            flag of the dropout layer.
        embed_input: Truthy to embed ``x`` first and to apply dropout at
            the end; falsy to skip both.

    Returns:
        The (possibly embedded and projected) input with the positional
        encoding added, followed by dropout only when ``embed_input`` is
        truthy.
    """
    if embed_input:
        features = embedding(
            x,
            vocab_size=config.n_input_vocab,
            num_units=config.hidden_units,
            scale=True,
            scope="enc_embed",
        )
    else:
        features = x

    # A linear projection brings the last dimension to the transformer
    # width whenever the incoming features have a different size.
    width = shape_list(features)[-1]
    if not width == config.hidden_units:
        features = tf.layers.dense(features, config.hidden_units)

    features += self.positional_encoding(features, scope='enc_pe')

    # Dropout is only applied on the embedded-input path.
    if not embed_input:
        return features
    return tf.layers.dropout(
        features,
        rate=config.dropout_rate,
        training=tf.convert_to_tensor(is_training),
    )
def symbols_to_logits_fn(ids, i, cache):
    """Go from ids to logits for next symbol.

    Only the last column of ``ids`` is used. With test-time augmentation
    enabled (``config.test_aug_times`` truthy) the step input is tiled
    along a new leading axis, run through the decoder as one flat batch,
    and the resulting logits are averaged over that axis.

    Args:
        ids: Ids decoded so far; sliced to ``ids[:, -1:]``.
        i: Decoding step. A Python int selects explicit variable reuse
            for every step after the first; anything else leaves the
            reuse flag as None.
        cache: Decoder state, passed to and returned by
            ``self.decoder_step``.

    Returns:
        Tuple ``(logits, cache)`` where ``logits`` has axis 1 squeezed out.
    """
    # if i is a tf tensor then its in the while loop so only called once
    if isinstance(i, int):
        reuse_loop = i > 0
    else:
        reuse_loop = None

    last_ids = ids[:, -1:]
    targets = preprocess_targets(last_ids, i, reuse_loop=reuse_loop)

    # For test augmentation the previous input is tiled along a new axis 0
    # and then folded into the batch axis for the decoder step.
    if config.test_aug_times:
        with_aug_axis = tf.expand_dims(targets, axis=0)
        multiples = [config.test_aug_times] + [1] * (len(targets.shape))
        targets = tf.tile(with_aug_axis, multiples)
        n_aug, n_rows, n_steps, depth = shape_list(targets)
        targets = tf.reshape(targets, [n_aug * n_rows, n_steps, depth])

    pre_logits, cache = self.decoder_step(
        targets, cache, is_training, reuse_loop=reuse_loop)

    with tf.variable_scope(top_scope, reuse=tf.AUTO_REUSE):
        logits = tf.layers.dense(
            pre_logits, config.n_labels, reuse=reuse_loop)

    # Undo the fold and average the logits over the augmentation axis.
    if config.test_aug_times:
        unfolded = [n_aug, n_rows, n_steps, logits.shape[-1]]
        logits = tf.reshape(logits, unfolded)
        logits = tf.reduce_mean(logits, 0)

    return tf.squeeze(logits, axis=[1]), cache
def __init__(self, input):
    """Build the visual front-end graph on top of ``input``.

    The input is preprocessed/augmented, passed through a 3D convolution
    stem (pad, conv, batch norm, ReLU, pad, max-pool) and then through
    ``resnet_18`` applied on a time-packed batch.

    With ``config.test_aug_times`` truthy, the single input sample is kept
    un-augmented as the first batch element and followed by
    ``config.test_aug_times - 1`` augmented replicas.

    Sets:
        self.input: The tensor passed in.
        self.aug_out: The preprocessed (and possibly augmented) input.
        self.output: The time-unpacked ``resnet_18`` output.

    Args:
        input: Input tensor. Its leading dimension must be 1 when test
            augmentation is enabled.
    """
    self.input = input
    model = input

    aug_opts = {}
    if config.test_aug_times:
        # With test augmentation, keep first sample the same
        assert model.shape[0] == 1, 'Test augmentation only with bs=1'
        no_aug_input = model
        model = replicate_to_batch(model, config.test_aug_times - 1)
        aug_opts = {
            'horizontal_flip': config.horizontal_flip,
            'crop_pixels': config.crop_pixels,
        }
        no_aug_out = self.preprocess_and_augment(no_aug_input, aug_opts={})

    # With exactly two test-time views, the augmented one is always flipped.
    flip_prob = 1 if config.test_aug_times == 2 else 0.5
    model = self.preprocess_and_augment(
        model, aug_opts=aug_opts, flip_prob=flip_prob)
    self.aug_out = model

    if config.test_aug_times:
        model = tf.concat([no_aug_out, model], 0)
        self.aug_out = model

    # spatio-temporal frontend
    stem_pad = tf.contrib.keras.layers.ZeroPadding3D(padding=(2, 3, 3))
    model = stem_pad(model)
    stem_conv = tf.layers.Conv3D(
        filters=64,
        kernel_size=(5, 7, 7),
        strides=[1, 2, 2],
        padding='valid',
        use_bias=False,
    )
    model = stem_conv(model)
    model = batch_normalization_wrapper(model)
    model = tf.nn.relu(model)
    pool_pad = tf.contrib.keras.layers.ZeroPadding3D(padding=(0, 1, 1))
    model = pool_pad(model)
    stem_pool = tf.layers.MaxPooling3D(pool_size=(1, 3, 3), strides=(1, 2, 2))
    model = stem_pool(model)

    # We want to apply the resnet on every timestep, so reshape into a
    # batch of size b*t
    stem_shape = K.int_shape(model)[1:]
    packed_model = temporal_batch_pack(model, input_shape=stem_shape)
    resnet = resnet_18(packed_model)

    n_steps = shape_list(model)[1]
    resnet_shape = K.int_shape(resnet)[1:]
    self.output = temporal_batch_unpack(
        resnet, n_steps, input_shape=resnet_shape)
def decoder_body(self, enc, dec, is_training, top_scope=None):
    """Build the inference-time autoregressive decoding graph.

    Uses beam search when ``config.beam_size > 1`` and a greedy
    ``tf.while_loop`` otherwise. When ``config.lm_path`` is set, language
    model log-probabilities are combined with the decoder scores.

    Args:
        enc: Encoder output; passed to the per-step logits function and
            to ``self.initialize_cache``.
        dec: Decoder input. Only its first two dimensions (batch size and
            number of decoding steps) and its positional encoding are used.
        is_training: Must be falsy; this graph is inference-only.
        top_scope: Variable scope forwarded to the logits functions.

    Returns:
        Tuple ``(decoded_ids, scores, decoded_logits)``. ``scores`` is
        None for greedy decoding. For beam search ``decoded_logits`` is an
        all-zero tensor of shape
        ``[logits_bs, decode_length, config.n_labels]``.

    Raises:
        AssertionError: If ``is_training`` is truthy.
    """
    assert not is_training, 'Inference graph not to be used for training'

    def _override_static_shape(tensor, dims):
        # Overwrites the static shape cached on the tensor object. The
        # private attribute holding it is `_shape_val` from TF 1.5 on and
        # `_shape` before that.
        static_shape = tf.TensorShape(dims)
        if StrictVersion(tf.__version__) >= StrictVersion('1.5'):
            tensor.__dict__['_shape_val'] = static_shape
        else:
            tensor._shape = static_shape

    batch_size = shape_list(dec)[0]
    decode_length = shape_list(dec)[1]

    # Create the positional encodings in advance,
    # so that we can add them on every time step within the loop
    timing_signal = self.positional_encoding(dec, scope='dec_pe')
    timing_signal = tf.expand_dims(timing_signal[0], 0)

    symbols_to_logits_fn = self.get_symbols_to_logits_fun(
        enc, timing_signal, is_training, top_scope)

    # Determine the batch_size of the logits. This will be different to the
    # batch size if we're doing test-time augmentation, as the different
    # augmentations will have been merged into 1 when we get to the logits
    logits_bs = 1 if config.test_aug_times else batch_size

    # Initialize cache and the decoding outputs to be filled in
    cache = self.initialize_cache(batch_size, enc)
    decoded_ids = tf.zeros([logits_bs, 0], dtype=tf.int64)
    decoded_logits = tf.zeros(
        [logits_bs, 0, config.n_labels], dtype=tf.float32)
    next_id = self.go_token_idx * tf.ones([logits_bs, 1], dtype=tf.int64)

    # If we are using language model, get the symbols -> logprobs function
    # for it
    lm_symbols_to_logprobs_fn = None
    if config.lm_path:
        lm_symbols_to_logprobs_fn = self.get_lm_symbols_to_logprobs_handle(
            cache, logits_bs, top_scope)

    scores = None
    if config.beam_size > 1:  # Beam Search
        initial_ids = next_id[:, 0]
        decoded_ids, scores = beam_search.beam_search(
            symbols_to_logits_fn,
            lm_symbols_to_logprobs_fn,
            initial_ids,
            config.beam_size,
            decode_length,
            config.n_labels,
            batch_size=logits_bs,
            batch_size_states=batch_size,
            alpha=config.len_alpha,
            lm_alpha=config.lm_alpha,
            states=cache,
            stop_early=False)

        # The first position of every beam is sliced off in both cases.
        if config.top_beams == 1:
            decoded_ids = decoded_ids[:, 0, 1:]
            _override_static_shape(decoded_ids, [logits_bs, None])
        else:
            decoded_ids = decoded_ids[:, :config.top_beams, 1:]
            # Previously this branch dereferenced `decoded_ids.decoded_ids`,
            # which raises AttributeError on TF >= 1.5.
            _override_static_shape(
                decoded_ids, [logits_bs, config.top_beams, None])

        # Beam search does not produce per-step logits here; zeros of the
        # full decoded length are returned in their place.
        decoded_logits = tf.zeros(
            [logits_bs, decode_length, config.n_labels])
    else:  # Greedy decoding

        def inner_loop(i, next_id, decoded_ids, decoded_logits, cache):
            logits, cache = symbols_to_logits_fn(next_id, i, cache)
            combined_scores = log_prob_from_logits(logits, axis=-1)
            if config.lm_path:
                lm_logprobs, cache = lm_symbols_to_logprobs_fn(
                    next_id, i, cache)
                combined_scores += config.lm_alpha * lm_logprobs
            next_id = tf.to_int64(tf.argmax(combined_scores, axis=-1))
            next_id = tf.expand_dims(next_id, axis=1)
            decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
            decoded_logits = tf.concat(
                [decoded_logits, tf.expand_dims(logits, axis=1)], axis=1)
            return i + 1, next_id, decoded_ids, decoded_logits, cache

        # decoded_ids / decoded_logits grow along axis 1 on every step, so
        # their loop invariants leave that dimension unknown.
        _, _, decoded_ids, decoded_logits, _ = tf.while_loop(
            lambda i, *_: tf.less(i, decode_length),
            inner_loop,
            [tf.constant(0), next_id, decoded_ids, decoded_logits, cache],
            shape_invariants=[
                tf.TensorShape([]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([None, None, config.n_labels]),
                nest.map_structure(
                    lambda t: tf.TensorShape(t.shape), cache),
            ])

    return decoded_ids, scores, decoded_logits
def replicate_to_batch(frame, times):
    """Tile ``frame`` along its leading axis, leaving other axes unrepeated.

    The leading-axis multiple passed to ``tf.tile`` is
    ``times * shape_list(frame)[0]``.

    NOTE(review): because the multiple includes the leading dimension
    itself, the result's leading dimension is ``times * batch ** 2``. That
    equals ``times`` copies only when the leading dimension is 1 -- confirm
    this is intended before calling with a larger batch.

    Args:
        frame: Tensor to replicate.
        times: Replication factor for the leading axis (see note above).

    Returns:
        The tiled tensor.
    """
    leading_dim = shape_list(frame)[0]
    n_trailing_axes = len(frame.shape) - 1
    multiples = [times * leading_dim] + [1] * n_trailing_axes
    return tf.tile(frame, multiples)