Example #1
    def encoder_embeddings(self, x, is_training, embed_input=False):
        # Embedding
        if embed_input:
            enc = embedding(x,
                            vocab_size=config.n_input_vocab,
                            num_units=config.hidden_units,
                            scale=True,
                            scope="enc_embed")
        else:
            enc = x

        ## Positional Encoding
        feat_dim = shape_list(enc)[-1]
        # if the input features are not the same size as the transformer units, make a linear projection
        if feat_dim != config.hidden_units:
            enc = tf.layers.dense(enc, config.hidden_units)

        enc += self.positional_encoding(enc, scope='enc_pe')

        if embed_input:
            ## Dropout
            enc = tf.layers.dropout(enc,
                                    rate=config.dropout_rate,
                                    training=tf.convert_to_tensor(is_training))
        return enc
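
The method above leans on `self.positional_encoding`, which is not shown here. Assuming it produces the standard sinusoidal encoding from the original Transformer, a minimal self-contained NumPy sketch of that signal (the function name below is ours) would be:

import numpy as np

def sinusoidal_positional_encoding(seq_len, num_units):
    # Angle for each (position, channel) pair: pos / 10000^(2i / num_units)
    positions = np.arange(seq_len)[:, None]            # (T, 1)
    channels = np.arange(num_units)[None, :]           # (1, D)
    angles = positions / np.power(10000.0, (2 * (channels // 2)) / num_units)
    encoding = np.zeros((seq_len, num_units))
    encoding[:, 0::2] = np.sin(angles[:, 0::2])        # even channels use sine
    encoding[:, 1::2] = np.cos(angles[:, 1::2])        # odd channels use cosine
    return encoding                                     # (T, D), broadcasts over the batch axis

# e.g. added to a (batch, T, D) block of embeddings, as `enc += ...` does above
enc = np.random.randn(2, 5, 8) + sinusoidal_positional_encoding(5, 8)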
Example #2
        def symbols_to_logits_fn(ids, i, cache):
            """Go from ids to logits for next symbol."""
            ids = ids[:, -1:]

            # If i is a tf tensor we are inside the while loop, so this is only built once
            # and no explicit reuse flag is needed; if i is a Python int, reuse after step 0.
            reuse_loop = i > 0 if isinstance(i, int) else None

            # targets = tf.expand_dims(tf.expand_dims(ids, axis=2), axis=3)
            targets = ids
            targets = preprocess_targets(targets, i, reuse_loop=reuse_loop)

            # -------- For test augmentation, we need to tile the previous input ----------
            if config.test_aug_times:
                targets = tf.tile(tf.expand_dims(targets, axis=0),
                                  [config.test_aug_times] + [1] *
                                  (len(targets.shape)))
                bs, beam, t, c = shape_list(targets)
                targets = tf.reshape(targets, [bs * beam, t, c])

            pre_logits, cache = self.decoder_step(targets,
                                                  cache,
                                                  is_training,
                                                  reuse_loop=reuse_loop)
            with tf.variable_scope(top_scope, reuse=tf.AUTO_REUSE):
                logits = tf.layers.dense(pre_logits,
                                         config.n_labels,
                                         reuse=reuse_loop)

            if config.test_aug_times:
                logits = tf.reshape(logits, [bs, beam, t, logits.shape[-1]])
                logits = tf.reduce_mean(logits, 0)

            return tf.squeeze(logits, axis=[1]), cache
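
For orientation, the test-augmentation branch at the end splits the merged batch back into its augmented copies and averages their logits, so the decoder only ever sees one distribution per hypothesis. A NumPy sketch of that reduction with made-up shapes:

import numpy as np

aug_times, beam, t, n_labels = 3, 4, 1, 28                # hypothetical sizes
logits = np.random.randn(aug_times * beam, t, n_labels)    # augmentations merged into the batch

# Same reduction as the reshape + reduce_mean above: un-merge the augmentation axis,
# then average over it.
logits = logits.reshape(aug_times, beam, t, n_labels).mean(axis=0)
print(logits.shape)                                        # (4, 1, 28)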
Example #3
    def __init__(self, input):

        self.input = model = input

        aug_opts = {}
        if config.test_aug_times:
            # ------------- With test augmentation, keep the first sample unaugmented -------------
            assert model.shape[0] == 1, 'Test augmentation only with bs=1'
            no_aug_input = model
            model = replicate_to_batch(model, config.test_aug_times - 1)

            aug_opts = {
                'horizontal_flip': config.horizontal_flip,
                'crop_pixels': config.crop_pixels,
            }

            no_aug_out = self.preprocess_and_augment(no_aug_input, aug_opts={})

        # with exactly one augmented copy (test_aug_times == 2), always flip it so it differs from the original
        flip_prob = 1 if config.test_aug_times == 2 else 0.5
        self.aug_out = model = self.preprocess_and_augment(model,
                                                           aug_opts=aug_opts,
                                                           flip_prob=flip_prob)

        if config.test_aug_times:
            self.aug_out = model = tf.concat([no_aug_out, self.aug_out], 0)

        # spatio-temporal frontend
        model = tf.contrib.keras.layers.ZeroPadding3D(padding=(2, 3, 3))(model)
        model = tf.layers.Conv3D(filters=64,
                                 kernel_size=(5, 7, 7),
                                 strides=[1, 2, 2],
                                 padding='valid',
                                 use_bias=False)(model)

        model = batch_normalization_wrapper(model)
        model = tf.nn.relu(model)
        model = tf.contrib.keras.layers.ZeroPadding3D(padding=(0, 1, 1))(model)
        model = tf.layers.MaxPooling3D(pool_size=(1, 3, 3),
                                       strides=(1, 2, 2))(model)

        # We want to apply the resnet on every timestep, so reshape into a batch of size b*t
        packed_model = temporal_batch_pack(model,
                                           input_shape=K.int_shape(model)[1:])
        resnet = resnet_18(packed_model)
        self.output = temporal_batch_unpack(
            resnet, shape_list(model)[1], input_shape=K.int_shape(resnet)[1:])
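
`temporal_batch_pack` and `temporal_batch_unpack` are not shown here; their presumed role is to fold the time axis into the batch axis so the 2D ResNet runs independently on every frame, and then to restore the time axis afterwards. A rough NumPy sketch of that folding (the helper names here are ours):

import numpy as np

def temporal_pack(x):
    # (b, t, h, w, c) -> (b*t, h, w, c): every frame becomes its own batch element
    b, t = x.shape[:2]
    return x.reshape((b * t,) + x.shape[2:]), (b, t)

def temporal_unpack(x, bt):
    # (b*t, ...) -> (b, t, ...): restore the time axis after the per-frame network
    b, t = bt
    return x.reshape((b, t) + x.shape[1:])

frames = np.random.randn(2, 10, 28, 28, 64)    # two clips, ten frames each
packed, bt = temporal_pack(frames)              # (20, 28, 28, 64)
features = packed.mean(axis=(1, 2))             # stand-in for the per-frame ResNet -> (20, 64)
unpacked = temporal_unpack(features, bt)        # (2, 10, 64)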
Example #4
    def decoder_body(self, enc, dec, is_training, top_scope=None):

        assert not is_training, 'Inference graph not to be used for training'

        inputs = enc
        batch_size = shape_list(dec)[0]
        decode_length = shape_list(dec)[1]

        # Create the positional encodings in advance,
        # so that we can add them on every time step within the loop
        timing_signal = self.positional_encoding(dec, scope='dec_pe')
        timing_signal = tf.expand_dims(timing_signal[0], 0)

        symbols_to_logits_fn = self.get_symbols_to_logits_fun(
            enc, timing_signal, is_training, top_scope)

        # Determine the batch_size of the logits. This will be different from the batch size if
        # we're doing test-time augmentation, as the different augmentations will have been
        # merged into 1 when we get to the logits
        logits_bs = 1 if config.test_aug_times else batch_size

        # Initialize cache and the decoding outputs to be filled in
        cache = self.initialize_cache(batch_size, enc)
        decoded_ids = tf.zeros([logits_bs, 0], dtype=tf.int64)
        decoded_logits = tf.zeros([logits_bs, 0, config.n_labels],
                                  dtype=tf.float32)
        next_id = self.go_token_idx * tf.ones([logits_bs, 1], dtype=tf.int64)

        # If we are using language model, get the symbols -> logprobs function for it
        lm_symbols_to_logprobs_fn = None
        if config.lm_path:
            lm_symbols_to_logprobs_fn = self.get_lm_symbols_to_logprobs_handle(
                cache, logits_bs, top_scope)

        scores = None

        if config.beam_size > 1:  # Beam Search
            vocab_size = config.n_labels
            initial_ids = next_id[:, 0]
            decoded_ids, scores = beam_search.beam_search(
                symbols_to_logits_fn,
                lm_symbols_to_logprobs_fn,
                initial_ids,
                config.beam_size,
                decode_length,
                vocab_size,
                batch_size=logits_bs,
                batch_size_states=batch_size,
                alpha=config.len_alpha,
                lm_alpha=config.lm_alpha,
                states=cache,
                # stop_early=(config.top_beams == 1),
                stop_early=False)

            if config.top_beams == 1:
                decoded_ids = decoded_ids[:, 0, 1:]
                if StrictVersion(tf.__version__) >= StrictVersion('1.5'):
                    decoded_ids.__dict__['_shape_val'] = tf.TensorShape(
                        [logits_bs, None])
                else:
                    decoded_ids._shape = tf.TensorShape([logits_bs, None])
            else:
                decoded_ids = decoded_ids[:, :config.top_beams, 1:]
                if StrictVersion(tf.__version__) >= StrictVersion('1.5'):
                    decoded_ids.__dict__['_shape_val'] = tf.TensorShape(
                        [logits_bs, config.top_beams, None])
                else:
                    decoded_ids._shape = tf.TensorShape(
                        [logits_bs, config.top_beams, None])

            decoded_logits = tf.zeros(
                [logits_bs, decode_length, config.n_labels])

        else:  # Greedy decoding

            def inner_loop(i, next_id, decoded_ids, decoded_logits, cache):
                logits, cache = symbols_to_logits_fn(next_id, i, cache)

                lip_logprobs = log_prob_from_logits(logits, axis=-1)
                if config.lm_path:
                    lm_logprobs, cache = lm_symbols_to_logprobs_fn(
                        next_id, i, cache)
                combined_scores = lip_logprobs
                if config.lm_path:
                    combined_scores += config.lm_alpha * lm_logprobs
                next_id = tf.to_int64(tf.argmax(combined_scores, axis=-1))
                next_id = tf.expand_dims(next_id, axis=1)

                decoded_ids = tf.concat([decoded_ids, next_id], axis=1)
                decoded_logits = tf.concat(
                    [decoded_logits,
                     tf.expand_dims(logits, axis=1)], axis=1)
                return i + 1, next_id, decoded_ids, decoded_logits, cache

            _, _, decoded_ids, decoded_logits, _ = tf.while_loop(
                lambda i, *_: tf.less(i, decode_length),
                inner_loop,
                [tf.constant(0), next_id, decoded_ids, decoded_logits, cache],
                shape_invariants=[
                    tf.TensorShape([]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None]),
                    tf.TensorShape([None, None, config.n_labels]),
                    nest.map_structure(lambda t: tf.TensorShape(t.shape),
                                       cache),
                ])

        return decoded_ids, scores, decoded_logits
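
In the greedy branch, the lip-reading and language-model scores are combined by shallow fusion: the LM log-probabilities are scaled by `lm_alpha` and added to the lip-reading log-probabilities before the argmax. A tiny NumPy illustration of that step (the values and the default `lm_alpha` here are made up):

import numpy as np

def greedy_step(lip_logprobs, lm_logprobs=None, lm_alpha=0.5):
    # Shallow fusion: add the scaled LM log-probabilities, then pick the best symbol
    scores = lip_logprobs.copy()
    if lm_logprobs is not None:
        scores = scores + lm_alpha * lm_logprobs
    return scores.argmax(axis=-1)

lip = np.log(np.array([[0.1, 0.6, 0.3]]))
lm = np.log(np.array([[0.2, 0.2, 0.6]]))
print(greedy_step(lip))                     # [1]: lip model alone
print(greedy_step(lip, lm, lm_alpha=1.0))   # [2]: a strongly weighted LM overrules the lip model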
Example #5
def replicate_to_batch(frame, times):
    # Tile `frame` along the batch axis only (all other axes keep a multiplier of 1);
    # with a batch of one, this yields `times` copies.
    replicated = tf.tile(frame, [times * shape_list(frame)[0]] + [1] *
                         (len(frame.shape) - 1))
    return replicated
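
The same tiling written out in NumPy with a concrete, made-up shape, matching the batch-of-one case it is used with in Example #3:

import numpy as np

frame = np.random.randn(1, 10, 28, 28, 3)      # one clip: (batch=1, t, h, w, c)
times = 4
# Repeat along the batch axis only; every other axis keeps a multiplier of 1
replicated = np.tile(frame, [times * frame.shape[0]] + [1] * (frame.ndim - 1))
print(replicated.shape)                         # (4, 10, 28, 28, 3)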