Example no. 1
    def get_features(self,
                     tokenizer,
                     max_length=None,
                     pad_on_left=False,
                     pad_token=0,
                     mask_padding_with_zero=True,
                     return_tensors=None):
        """
        Converts the examples into a list of ``InputFeatures``

        Args:
            tokenizer: Instance of a tokenizer that will tokenize the examples
            max_length: Maximum example length
            pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
            pad_token: Padding token
            mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
                and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
                actual values)
            return_tensors: If set to ``'tf'`` or ``'pt'``, returns a ``tf.data.Dataset`` or a PyTorch
                ``TensorDataset`` instead of a list of ``InputFeatures``

        Returns:
            A list of task-specific ``InputFeatures`` that can be fed to the model or, if ``return_tensors`` is
            set, the corresponding ``tf.data.Dataset`` (``'tf'``) or PyTorch ``TensorDataset`` (``'pt'``).

        """
        if max_length is None:
            max_length = tokenizer.max_len

        label_map = {label: i for i, label in enumerate(self.labels)}

        all_input_ids = []
        for (ex_index, example) in enumerate(self.examples):
            if ex_index % 10000 == 0:
                logger.info("Tokenizing example %d", ex_index)

            input_ids = tokenizer.encode(
                example.text_a,
                add_special_tokens=True,
                max_length=min(max_length, tokenizer.max_len),
            )
            all_input_ids.append(input_ids)

        batch_length = max(len(input_ids) for input_ids in all_input_ids)

        features = []
        for (ex_index,
             (input_ids,
              example)) in enumerate(zip(all_input_ids, self.examples)):
            if ex_index % 10000 == 0:
                logger.info("Writing example %d", ex_index)
            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            attention_mask = [1 if mask_padding_with_zero else 0
                              ] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding_length = batch_length - len(input_ids)
            if pad_on_left:
                input_ids = ([pad_token] * padding_length) + input_ids
                attention_mask = ([0 if mask_padding_with_zero else 1] *
                                  padding_length) + attention_mask
            else:
                input_ids = input_ids + ([pad_token] * padding_length)
                attention_mask = attention_mask + (
                    [0 if mask_padding_with_zero else 1] * padding_length)

            assert len(
                input_ids
            ) == batch_length, "Error with input length {} vs {}".format(
                len(input_ids), batch_length)
            assert len(
                attention_mask
            ) == batch_length, "Error with input length {} vs {}".format(
                len(attention_mask), batch_length)

            if self.mode == "classification":
                label = label_map[example.label]
            elif self.mode == "regression":
                label = float(example.label)
            else:
                raise ValueError(self.mode)

            if ex_index < 5 and self.verbose:
                logger.info("*** Example ***")
                logger.info("guid: %s" % (example.guid))
                logger.info("input_ids: %s" %
                            " ".join([str(x) for x in input_ids]))
                logger.info("attention_mask: %s" %
                            " ".join([str(x) for x in attention_mask]))
                logger.info("label: %s (id = %d)" % (example.label, label))

            features.append(
                InputFeatures(input_ids=input_ids,
                              attention_mask=attention_mask,
                              label=label))

        if return_tensors is None:
            return features
        elif return_tensors == 'tf':
            if not is_tf_available():
                raise ImportError(
                    "return_tensors set to 'tf' but TensorFlow 2.0 can't be imported"
                )
            import tensorflow as tf

            def gen():
                for ex in features:
                    yield ({
                        'input_ids': ex.input_ids,
                        'attention_mask': ex.attention_mask
                    }, ex.label)

            dataset = tf.data.Dataset.from_generator(
                gen, ({
                    'input_ids': tf.int32,
                    'attention_mask': tf.int32
                }, tf.int64), ({
                    'input_ids': tf.TensorShape([None]),
                    'attention_mask': tf.TensorShape([None])
                }, tf.TensorShape([])))
            return dataset
        elif return_tensors == 'pt':
            if not is_torch_available():
                raise ImportError(
                    "return_tensors set to 'pt' but PyTorch can't be imported")
            import torch
            from torch.utils.data import TensorDataset
            all_input_ids = torch.tensor([f.input_ids for f in features],
                                         dtype=torch.long)
            all_attention_mask = torch.tensor(
                [f.attention_mask for f in features], dtype=torch.long)
            if self.mode == "classification":
                all_labels = torch.tensor([f.label for f in features],
                                          dtype=torch.long)
            elif self.mode == "regression":
                all_labels = torch.tensor([f.label for f in features],
                                          dtype=torch.float)

            dataset = TensorDataset(all_input_ids, all_attention_mask,
                                    all_labels)
            return dataset
        else:
            raise ValueError("return_tensors should be one of 'tf' or 'pt'")
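A minimal usage sketch: `processor` is assumed to be an instance of the class that defines get_features above, already populated with examples and labels, and the tokenizer checkpoint name is illustrative.

from transformers import AutoTokenizer

# `processor` is a hypothetical, already-built instance of the class above.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Plain Python list of InputFeatures
features = processor.get_features(tokenizer, max_length=128)

# Or a ready-to-use PyTorch TensorDataset
dataset = processor.get_features(tokenizer, max_length=128, return_tensors="pt")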
 def test_with_np_int32(self):
   t = computation_types.TensorType(np.int32, [10])
   self.assertEqual(t.dtype, tf.int32)
   self.assertEqual(t.shape, tf.TensorShape([10]))
 def test_with_np_int32_in_dict(self):
   t = computation_types.to_type(collections.OrderedDict([('foo', np.int32)]))
   self.assertIsInstance(t, computation_types.StructType)
   self.assertIsInstance(t.foo, computation_types.TensorType)
   self.assertEqual(t.foo.dtype, tf.int32)
   self.assertEqual(t.foo.shape, tf.TensorShape([]))
Example no. 4
    def build(self, input_shape):
        input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
        input_tensor_shape = tf.TensorShape(input_tensor)
        if len(input_tensor_shape.as_list()) != 3:
            raise ValueError(
                "TransformerEncoderBlock expects a three-dimensional "
                "input of shape [batch, sequence, width].")
        batch_size, sequence_length, hidden_size = input_tensor_shape

        if len(input_shape) == 2:
            mask_tensor_shape = tf.TensorShape(input_shape[1])
            expected_mask_tensor_shape = tf.TensorShape(
                [batch_size, sequence_length, sequence_length])
            if not expected_mask_tensor_shape.is_compatible_with(
                    mask_tensor_shape):
                raise ValueError(
                    "When passing a mask tensor to "
                    "TransformerEncoderBlock, the mask tensor must be of "
                    "shape [batch, sequence_length, sequence_length] "
                    "(here %s). Got a mask tensor of shape %s." %
                    (expected_mask_tensor_shape, mask_tensor_shape))
        if hidden_size % self._num_heads != 0:
            raise ValueError(
                "The input size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, self._num_heads))
        self._attention_head_size = int(hidden_size // self._num_heads)
        common_kwargs = dict(bias_initializer=self._bias_initializer,
                             kernel_regularizer=self._kernel_regularizer,
                             bias_regularizer=self._bias_regularizer,
                             activity_regularizer=self._activity_regularizer,
                             kernel_constraint=self._kernel_constraint,
                             bias_constraint=self._bias_constraint)
        self._attention_layer = tf.keras.layers.MultiHeadAttention(
            num_heads=self._num_heads,
            key_dim=self._attention_head_size,
            dropout=self._attention_dropout,
            use_bias=self._use_bias,
            kernel_initializer=self._attention_initializer,
            name="self_attention",
            **common_kwargs)
        self._attention_dropout = tf.keras.layers.Dropout(
            rate=self._output_dropout)
        # Use float32 in layernorm for numeric stability.
        # It is probably safe in mixed_float16, but we haven't validated this yet.
        self._attention_layer_norm = (tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
        self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
            "abc,cd->abd",
            output_shape=(None, self._inner_dim),
            bias_axes="d",
            kernel_initializer=self._kernel_initializer,
            name="intermediate",
            **common_kwargs)
        policy = tf.keras.mixed_precision.experimental.global_policy()
        if policy.name == "mixed_bfloat16":
            # bfloat16 causes BERT with the LAMB optimizer to not converge
            # as well, so we use float32.
            # TODO(b/154538392): Investigate this.
            policy = tf.float32
        self._intermediate_activation_layer = tf.keras.layers.Activation(
            self._inner_activation, dtype=policy)
        self._inner_dropout_layer = tf.keras.layers.Dropout(
            rate=self._inner_dropout)
        self._output_dense = tf.keras.layers.experimental.EinsumDense(
            "abc,cd->abd",
            output_shape=(None, hidden_size),
            bias_axes="d",
            name="output",
            kernel_initializer=self._kernel_initializer,
            **common_kwargs)
        self._output_dropout = tf.keras.layers.Dropout(
            rate=self._output_dropout)
        # Use float32 in layernorm for numeric stability.
        self._output_layer_norm = tf.keras.layers.LayerNormalization(
            name="output_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32)

        super(TransformerEncoderBlock, self).build(input_shape)
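The mask check in build() relies on tf.TensorShape.is_compatible_with, where None dimensions act as wildcards; in isolation the check behaves like this:

import tensorflow as tf

# None dimensions act as wildcards in compatibility checks.
expected = tf.TensorShape([None, 128, 128])
print(expected.is_compatible_with(tf.TensorShape([8, 128, 128])))  # True
print(expected.is_compatible_with(tf.TensorShape([8, 128, 64])))   # False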
Example no. 5
def glue_convert_examples_to_features(
    examples,
    tokenizer,
    max_length=512,
    task=None,
    label_list=None,
    output_mode=None,
    pad_on_left=False,
    pad_token=0,
    pad_token_segment_id=0,
    mask_padding_with_zero=True,
):
    """
    Loads a data file into a list of ``InputFeatures``

    Args:
        examples: List of ``InputExamples`` or ``tf.data.Dataset`` containing the examples.
        tokenizer: Instance of a tokenizer that will tokenize the examples
        max_length: Maximum example length
        task: GLUE task
        label_list: List of labels. Can be obtained from the processor using the ``processor.get_labels()`` method
        output_mode: String indicating the output mode. Either ``regression`` or ``classification``
        pad_on_left: If set to ``True``, the examples will be padded on the left rather than on the right (default)
        pad_token: Padding token
        pad_token_segment_id: The segment ID for the padding token (It is usually 0, but can vary such as for XLNet where it is 4)
        mask_padding_with_zero: If set to ``True``, the attention mask will be filled by ``1`` for actual values
            and by ``0`` for padded values. If set to ``False``, inverts it (``1`` for padded values, ``0`` for
            actual values)

    Returns:
        If the ``examples`` input is a ``tf.data.Dataset``, will return a ``tf.data.Dataset``
        containing the task-specific features. If the input is a list of ``InputExamples``, will return
        a list of task-specific ``InputFeatures`` which can be fed to the model.

    """
    is_tf_dataset = False
    if is_tf_available() and isinstance(examples, tf.data.Dataset):
        is_tf_dataset = True

    if task is not None:
        processor = glue_processors[task]()
        if label_list is None:
            label_list = processor.get_labels()
            logger.info("Using label list %s for task %s" % (label_list, task))
        if output_mode is None:
            output_mode = glue_output_modes[task]
            logger.info("Using output mode %s for task %s" % (output_mode, task))

    label_map = {label: i for i, label in enumerate(label_list)}

    features = []
    for (ex_index, example) in enumerate(examples):
        len_examples = 0
        if is_tf_dataset:
            example = processor.get_example_from_tensor_dict(example)
            example = processor.tfds_map(example)
            len_examples = tf.data.experimental.cardinality(examples)
        else:
            len_examples = len(examples)
        if ex_index % 10000 == 0:
            logger.info("Writing example %d/%d" % (ex_index, len_examples))

        inputs = tokenizer.encode_plus(example.text_a, example.text_b, add_special_tokens=True, max_length=max_length,)
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = max_length - len(input_ids)
        if pad_on_left:
            input_ids = ([pad_token] * padding_length) + input_ids
            attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
            token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
        else:
            input_ids = input_ids + ([pad_token] * padding_length)
            attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
            token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(
            len(attention_mask), max_length
        )
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(
            len(token_type_ids), max_length
        )

        if output_mode == "classification":
            label = label_map[example.label]
        elif output_mode == "regression":
            label = float(example.label)
        else:
            raise KeyError(output_mode)

        if ex_index < 5:
            logger.info("*** Example ***")
            logger.info("guid: %s" % (example.guid))
            logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            logger.info("attention_mask: %s" % " ".join([str(x) for x in attention_mask]))
            logger.info("token_type_ids: %s" % " ".join([str(x) for x in token_type_ids]))
            logger.info("label: %s (id = %d)" % (example.label, label))

        features.append(
            InputFeatures(
                input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, label=label
            )
        )

    if is_tf_available() and is_tf_dataset:

        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        return tf.data.Dataset.from_generator(
            gen,
            ({"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32}, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )

    return features
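An illustrative call, assuming the MRPC processor from glue_processors and a data_dir containing the GLUE files; the directory and the tokenizer checkpoint are placeholders, not values from this codebase.

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = glue_processors["mrpc"]()
examples = processor.get_train_examples(data_dir)  # data_dir is assumed to exist

features = glue_convert_examples_to_features(
    examples, tokenizer, max_length=128, task="mrpc")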
Example no. 6
 def test_return_correct_batchSize(self):
     tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
     self.assertEqual(32, static_shape.get_batch_size(tensor_shape))
 def _event_shape(self):
     return tf.TensorShape([])
 def compute_output_shape(self, input_shape):
     output_shape = list(input_shape)
     output_shape[self.axis] = self.depth_v.value
     print(output_shape)
     return tf.TensorShape(output_shape)
Example no. 9
 def feature_spec(self):
     class_num = self.taskconf['classes']['num']
     if self.input_type == 'samples':
         # waveforms
         output_shapes = (
             tf.TensorShape([self.example_len]),
             tf.TensorShape([self.max_text_len]),
             tf.TensorShape([]),
             tf.TensorShape([]),
             tf.TensorShape([]),
             tf.TensorShape([class_num]),  # soft_label
         )
     else:
         # features
         output_shapes = (
             tf.TensorShape(self.feature_shape),  # audio_feat (3000, 40, 3)
             tf.TensorShape([self.max_text_len]),  # text
             tf.TensorShape([]),  # label
             tf.TensorShape([]),  # filename
             tf.TensorShape([]),  # clip_id
             tf.TensorShape([class_num]),  # soft_label
         )
     output_types = (
         tf.float32,
         tf.int32,
         tf.int32,
         tf.string,
         tf.int32,
         tf.float32,
     )
     return output_shapes, output_types
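The (output_shapes, output_types) pair returned above is meant for tf.data.Dataset.from_generator, which takes output_types first. A toy, self-contained sketch of that wiring; the generator and shapes here are illustrative, not this task's real spec.

import tensorflow as tf

# Toy stand-in for the (output_shapes, output_types) pair returned by feature_spec.
output_shapes = (tf.TensorShape([None]), tf.TensorShape([]))
output_types = (tf.float32, tf.int32)

def generator_fn():
    yield ([0.1, 0.2, 0.3], 1)
    yield ([0.4], 0)

dataset = tf.data.Dataset.from_generator(
    generator_fn, output_types=output_types, output_shapes=output_shapes)
for feat, label in dataset:
    print(feat.shape, label.numpy())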
Example no. 10
    def _tf_custom_classifier_convert_examples_to_features(
        examples: tf.data.Dataset,
        tokenizer: PreTrainedTokenizer,
        separator: str,
        custom_info: dict,
        task: str,
        max_length: Optional[int] = None,
    ) -> tf.data.Dataset:
        """
        Returns:
            A ``tf.data.Dataset`` containing the task-specific features.

        """
        def get_example_from_tensor_dict(tensor_dict):
            return InputExample(
                tensor_dict["idx"].numpy(),
                tensor_dict["sentence1"].numpy().decode("utf-8"),
                tensor_dict["sentence2"].numpy().decode("utf-8"),
                str(tensor_dict["label"].numpy()),
            )

        processor = DataProcessor()
        examples = [
            processor.tfds_map(get_example_from_tensor_dict(example))
            for example in examples
        ]
        features = custom_classifier_convert_examples_to_features(
            examples,
            tokenizer,
            separator,
            custom_info,
            max_length=max_length,
            task=task)

        def gen():
            for ex in features:
                yield (
                    {
                        "input_ids": ex.input_ids,
                        "attention_mask": ex.attention_mask,
                        "token_type_ids": ex.token_type_ids,
                    },
                    ex.label,
                )

        return tf.data.Dataset.from_generator(
            gen,
            ({
                "input_ids": tf.int32,
                "attention_mask": tf.int32,
                "token_type_ids": tf.int32
            }, tf.int64),
            (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                },
                tf.TensorShape([]),
            ),
        )
Example no. 11
    def _length_regulator(self, encoder_hidden_states, durations_gt):
        """Length regulator logic."""
        sum_durations = tf.reduce_sum(durations_gt, axis=-1)  # [batch_size]
        max_durations = tf.reduce_max(sum_durations)

        input_shape = tf.shape(encoder_hidden_states)
        batch_size = input_shape[0]
        hidden_size = input_shape[-1]

        # initialize output hidden states and encoder masking.
        if self.enable_tflite_convertible:
            # There is only 1 batch in inference, so we don't have to use
            # `tf.While` op with 3-D output tensor.
            repeats = durations_gt[0]
            real_length = tf.reduce_sum(repeats)
            pad_size = max_durations - real_length
            # masks : [max_durations]
            masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
            repeat_encoder_hidden_states = tf.repeat(
                encoder_hidden_states[0], repeats=repeats, axis=0
            )
            repeat_encoder_hidden_states = tf.expand_dims(
                tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
            )  # [1, max_durations, hidden_size]

            outputs = repeat_encoder_hidden_states
            encoder_masks = masks
        else:
            outputs = tf.zeros(shape=[0, max_durations, hidden_size], dtype=tf.float32)
            encoder_masks = tf.zeros(shape=[0, max_durations], dtype=tf.int32)

            def condition(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                return tf.less(i, batch_size)

            def body(
                i,
                batch_size,
                outputs,
                encoder_masks,
                encoder_hidden_states,
                durations_gt,
                max_durations,
            ):
                repeats = durations_gt[i]
                real_length = tf.reduce_sum(repeats)
                pad_size = max_durations - real_length
                masks = tf.sequence_mask([real_length], max_durations, dtype=tf.int32)
                repeat_encoder_hidden_states = tf.repeat(
                    encoder_hidden_states[i], repeats=repeats, axis=0
                )
                repeat_encoder_hidden_states = tf.expand_dims(
                    tf.pad(repeat_encoder_hidden_states, [[0, pad_size], [0, 0]]), 0
                )  # [1, max_durations, hidden_size]
                outputs = tf.concat([outputs, repeat_encoder_hidden_states], axis=0)
                encoder_masks = tf.concat([encoder_masks, masks], axis=0)
                return [
                    i + 1,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ]

            # initialize iteration i.
            i = tf.constant(0, dtype=tf.int32)
            _, _, outputs, encoder_masks, _, _, _, = tf.while_loop(
                condition,
                body,
                [
                    i,
                    batch_size,
                    outputs,
                    encoder_masks,
                    encoder_hidden_states,
                    durations_gt,
                    max_durations,
                ],
                shape_invariants=[
                    i.get_shape(),
                    batch_size.get_shape(),
                    tf.TensorShape(
                        [
                            None,
                            None,
                            self.config.encoder_self_attention_params.hidden_size,
                        ]
                    ),
                    tf.TensorShape([None, None]),
                    encoder_hidden_states.get_shape(),
                    durations_gt.get_shape(),
                    max_durations.get_shape(),
                ],
            )

        return outputs, encoder_masks
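The loop above depends on the shape_invariants argument: outputs and encoder_masks grow along their first axis at every iteration, so their invariants use tf.TensorShape with None for the growing dimensions. A minimal, self-contained sketch of that pattern with toy sizes, unrelated to the model's real shapes:

import tensorflow as tf

def stack_rows(n, width=4):
    """Builds an [n, width] tensor one row at a time inside tf.while_loop."""
    i = tf.constant(0)
    acc = tf.zeros([0, width], dtype=tf.float32)

    def cond(i, acc):
        return tf.less(i, n)

    def body(i, acc):
        row = tf.fill([1, width], tf.cast(i, tf.float32))
        # acc changes shape every iteration, which is only legal because its
        # shape invariant below declares the first dimension as None.
        return i + 1, tf.concat([acc, row], axis=0)

    _, acc = tf.while_loop(
        cond, body, [i, acc],
        shape_invariants=[i.get_shape(), tf.TensorShape([None, width])])
    return acc

print(stack_rows(tf.constant(3)).shape)  # (3, 4)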
Example no. 12
    def _get_data(self, filenames, split_name, **config):
        def _gen_shape():
            primitives = parse_primitives(config['primitives'],
                                          self.drawing_primitives)
            while True:
                primitive = np.random.choice(primitives)
                image = synthetic_dataset.generate_background(
                    config['generation']['image_size'],
                    **config['generation']['params']['generate_background'])
                points = np.array(
                    getattr(synthetic_dataset,
                            primitive)(image,
                                       **config['generation']['params'].get(
                                           primitive, {})))
                yield (np.expand_dims(image, axis=-1).astype(np.float32),
                       np.flip(points.astype(np.float32), 1))

        def _read_image(filename):
            image = tf.read_file(filename)
            image = tf.image.decode_png(image, channels=1)
            return tf.cast(image, tf.float32)

        # Python function
        def _read_points(filename):
            return np.load(filename.decode('utf-8')).astype(np.float32)

        if config['on-the-fly']:
            data = tf.data.Dataset.from_generator(
                _gen_shape, (tf.float32, tf.float32),
                (tf.TensorShape(config['generation']['image_size'] + [1]),
                 tf.TensorShape([None, 2])))
            data = data.map(lambda i, c: pipeline.downsample(
                i, c, **config['preprocessing']))
        else:
            # Initialize dataset with file names
            data = tf.data.Dataset.from_tensor_slices(
                (filenames[split_name]['images'],
                 filenames[split_name]['points']))
            # Read image and point coordinates
            data = data.map(lambda image, points: (_read_image(
                image), tf.py_func(_read_points, [points], tf.float32)))
            data = data.map(lambda image, points:
                            (image, tf.reshape(points, [-1, 2])))

        if split_name == 'validation':
            data = data.take(config['validation_size'])
        elif split_name == 'test':
            data = data.take(config['test_size'])

        data = data.map(lambda image, kp: {'image': image, 'keypoints': kp})
        data = data.map(pipeline.add_dummy_valid_mask)

        if config['cache_in_memory'] and not config['on-the-fly']:
            tf.compat.v1.logging.info(
                'Caching data, first access will take some time.')
            data = data.cache()

        # Apply augmentation
        if split_name == 'training' or config['add_augmentation_to_test_set']:
            if config['augmentation']['photometric']['enable']:
                data = data.map_parallel(
                    lambda d: pipeline.photometric_augmentation(
                        d, **config['augmentation']['photometric']))
            if config['augmentation']['homographic']['enable']:
                data = data.map_parallel(
                    lambda d: pipeline.homographic_augmentation(
                        d, **config['augmentation']['homographic']))

        # Convert the point coordinates to a dense keypoint map
        data = data.map_parallel(pipeline.add_keypoint_map)

        return data
Example no. 13
 def output_shapes(self):
     return (tensorflow.TensorShape([]), tensorflow.TensorShape([]))
@author: landon
"""

import tensorflow as tf
import numpy as np
import json
import argparse
import os
import time

HAS_TITLE = [1,1,0,1,1,1,1,1,1]
N_TRACKS = [0,5,5,10,25,25,100,100,1]
IS_RANDOM = [0,0,0,0,0,1,0,1,0]

TENSOR_SPEC = tf.RaggedTensorSpec(tf.TensorShape([4, None]), tf.int32, 1, tf.int64)



def delete_tensor_by_indices(tensor,indices,n_tracks):
    idxs = tf.reshape(indices,(-1,1))
    mask = ~tf.scatter_nd(indices=idxs,updates=tf.ones_like(indices,dtype=tf.bool),shape=[n_tracks])
    return tf.boolean_mask(tensor,mask)

@tf.autograph.experimental.do_not_convert
def map_func(x):
    return {
        'track_ids':x[0],
        'title_ids':x[1],
        'n_tracks':x[2][0],
        }
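TENSOR_SPEC above sets (shape, dtype, ragged_rank, row_splits_dtype) positionally. A small sketch of using such a spec as an output_signature for a dataset; this is toy data and assumes a TF version (2.4+) where output_signature is available.

import tensorflow as tf

spec = tf.RaggedTensorSpec(tf.TensorShape([4, None]), tf.int32, 1, tf.int64)

def gen():
    # One ragged element with 4 rows of varying length, matching the spec.
    yield tf.ragged.constant([[1, 2], [3], [4, 5, 6], [7]], dtype=tf.int32)

ds = tf.data.Dataset.from_generator(gen, output_signature=spec)
for x in ds:
    print(x.shape)  # (4, None)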
Example no. 15
    def _common(cls, node, **kwargs):
        tensor_dict = kwargs["tensor_dict"]
        boxes = tensor_dict[node.inputs[0]]
        scores = tensor_dict[node.inputs[1]]
        # in the ONNX spec max_output_boxes_per_class needs to be in int64 but
        # max_output_boxes for tf.image.non_max_suppression must be in tf.int32
        # therefore need to cast this input to tf.int32
        max_output_boxes_per_class = tf.cast(
            tensor_dict['max_output_boxes_per_class'], tf.int32
        ) if 'max_output_boxes_per_class' in tensor_dict else tf.cast(
            boxes.shape[1], tf.int32)
        iou_threshold = tensor_dict[
            'iou_threshold'] if 'iou_threshold' in tensor_dict else tf.constant(
                [0.5], tf.float32)
        score_threshold = tensor_dict[
            'score_threshold'] if 'score_threshold' in tensor_dict else tf.constant(
                [float('-inf')], tf.float32)
        center_point_box = node.attrs.get("center_point_box", 0)

        if center_point_box == 1:
            boxes_t = tf.transpose(boxes, perm=[0, 2, 1])
            x_centers = tf.slice(boxes_t, [0, 0, 0], [-1, 1, -1])
            y_centers = tf.slice(boxes_t, [0, 1, 0], [-1, 1, -1])
            widths = tf.slice(boxes_t, [0, 2, 0], [-1, 1, -1])
            heights = tf.slice(boxes_t, [0, 3, 0], [-1, 1, -1])
            y1 = tf.subtract(y_centers, tf.divide(heights, 2))
            x1 = tf.subtract(x_centers, tf.divide(widths, 2))
            y2 = tf.add(y_centers, tf.divide(heights, 2))
            x2 = tf.add(x_centers, tf.divide(widths, 2))
            boxes_t = tf.concat([y1, x1, y2, x2], 1)
            boxes = tf.transpose(boxes_t, perm=[0, 2, 1])

        @tf.function
        def create_nodes(boxes, scores, max_output_boxes_per_class,
                         iou_threshold, score_threshold, result):
            # get number of batches in boxes
            num_batches = tf.shape(boxes)[0]
            for batch_i in tf.range(num_batches):
                # get boxes in batch_i only
                tf_boxes = tf.squeeze(tf.gather(boxes, [batch_i]), axis=0)
                # get scores of all classes in batch_i only
                batch_i_scores = tf.squeeze(tf.gather(scores, [batch_i]),
                                            axis=0)
                # get number of classes in batch_i only
                num_classes = tf.shape(batch_i_scores)[0]
                for class_j in tf.range(num_classes):
                    # get scores in class_j for batch_i only
                    tf_scores = tf.squeeze(tf.gather(batch_i_scores,
                                                     [class_j]),
                                           axis=0)
                    # get the selected boxes indices
                    selected_indices = tf.image.non_max_suppression(
                        tf_boxes, tf_scores, max_output_boxes_per_class[0],
                        iou_threshold[0], score_threshold[0])
                    # add batch and class information into the indices
                    output = tf.transpose(
                        [tf.cast(selected_indices, dtype=tf.int64)])
                    paddings = tf.constant([[0, 0], [1, 0]])
                    output = tf.pad(output,
                                    paddings,
                                    constant_values=tf.cast(class_j,
                                                            dtype=tf.int64))
                    output = tf.pad(output,
                                    paddings,
                                    constant_values=tf.cast(batch_i,
                                                            dtype=tf.int64))
                    # tf.function will auto convert "result" from variable to placeholder
                    # therefore don't need to use assign here
                    result = output if tf.equal(batch_i, 0) and tf.equal(
                        class_j, 0) else tf.concat([result, output], 0)

            return result

        # Since tf.function doesn't support locals() and requires all variables to be
        # defined before they are used in the "for loop" (before it performs any auto
        # conversion of the Python code), "result" is defined as a Variable here and
        # passed in as a parameter to "create_nodes".
        result = tf.Variable([[0, 0, 0]],
                             dtype=tf.int64,
                             shape=tf.TensorShape([None, 3]))
        return [
            create_nodes(boxes, scores, max_output_boxes_per_class,
                         iou_threshold, score_threshold, result)
        ]
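The Variable above is created with shape=tf.TensorShape([None, 3]) precisely so the traced function may rebind it to outputs with a different number of rows. A standalone sketch of that behaviour:

import tensorflow as tf

# A Variable declared with an unknown first dimension accepts assignments of
# different row counts.
result = tf.Variable([[0, 0, 0]], dtype=tf.int64, shape=tf.TensorShape([None, 3]))
result.assign([[1, 2, 3], [4, 5, 6]])  # ok: the first dimension is unconstrained
print(result.shape)  # (None, 3); the current value has 2 rows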
Example no. 16
    def train_input_fn(self, path, batch_size):
        """ Make a Tensorflow dataset that is shuffled, batched and parsed
        Args:
            path: path of the record file to unpack and read
            batch_size: Size of the batch for training
        Returns:
            A dataset that is shuffled and padded
        """

        if not os.path.isfile(path):
            raise Exception('ERROR: Provided path is not a file')

        def _parse(ex):
            """ Explain to TF how to go back from a serialized example to tensors
            Args:
                ex: An example
            Returns:
                A dictionary of tensors
            """
            # Define how to parse the example
            context_features = {
                "source_len": tf.FixedLenFeature([], dtype=tf.int64),
                "target_len": tf.FixedLenFeature([], dtype=tf.int64)
            }
            sequence_features = {
                "source_seq": tf.FixedLenSequenceFeature([], dtype=tf.int64),
                "target_seq": tf.FixedLenSequenceFeature([], dtype=tf.int64),
                "decoder_tgt": tf.FixedLenSequenceFeature([], dtype=tf.int64)
            }
            # Parse the example and return dict of tensors
            context_parsed, sequence_parsed = tf.parse_single_sequence_example(
                serialized=ex,
                context_features=context_features,
                sequence_features=sequence_features)

            return {
                "source_seq": sequence_parsed["source_seq"],
                "source_len": context_parsed["source_len"]
            }, {
                "target_seq": sequence_parsed["target_seq"],
                "target_len": context_parsed["target_len"],
                "decoder_tgt": sequence_parsed["decoder_tgt"]
            }

        dataset = tf.data.TFRecordDataset([path], num_parallel_reads=4).map(
            _parse, num_parallel_calls=10).shuffle(buffer_size=2 * batch_size +
                                                   1).repeat(None)

        padded_shapes = (
            {
                "source_seq":
                tf.TensorShape([None]),  # pads to largest sentence in batch
                "source_len": tf.TensorShape([])
            },  # No padding
            {
                "target_seq": tf.TensorShape([None]),
                "target_len": tf.TensorShape([]),
                "decoder_tgt": tf.TensorShape([None])
            })

        dataset = dataset.padded_batch(batch_size, padded_shapes=padded_shapes)

        # enables pipelines
        dataset = dataset.prefetch(2)

        return dataset
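In padded_shapes, tf.TensorShape([None]) pads a component to the longest element in the batch, while tf.TensorShape([]) marks a scalar that needs no padding. A toy sketch of the same mechanism, run eagerly under TF 2:

import tensorflow as tf

# Variable-length sequences are padded to the longest element in each batch
# because the padded shape's dimension is None.
ds = tf.data.Dataset.from_generator(
    lambda: iter([[1, 2, 3], [4, 5], [6]]),
    output_types=tf.int32,
    output_shapes=tf.TensorShape([None]))
ds = ds.padded_batch(2, padded_shapes=tf.TensorShape([None]))
for batch in ds:
    print(batch.numpy())  # [[1 2 3], [4 5 0]] then [[6]]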
Example no. 17
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix, save_weights_only=True)

# Executing training
# EPOCHS=10
# history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])

tf.train.latest_checkpoint(checkpoint_dir)

model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

model.build(tf.TensorShape([1, None]))

model.summary()


def generate_text(model, start_string):
    # Evaluation step (generating text using the learned model)

    # Number of characters to generate
    num_generate = 500

    # Converting our start string to numbers (vectorizing)
    input_eval = [char2idx[s] for s in start_string]
    input_eval = tf.expand_dims(input_eval, 0)

    # Empty string to store our results
Example no. 18
def squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length,
    doc_stride,
    max_query_length,
    is_training,
    padding_strategy="max_length",
    return_dataset=False,
    threads=1,
    tqdm_enabled=True,
):
    """
    Converts a list of examples into a list of features that can be directly given as input to a model. It is
    model-dependent and takes advantage of many of the tokenizer's features to create the model's inputs.
    Args:
        examples: list of :class:`~transformers.data.processors.squad.SquadExample`
        tokenizer: an instance of a child of :class:`~transformers.PreTrainedTokenizer`
        max_seq_length: The maximum sequence length of the inputs.
        doc_stride: The stride used when the context is too large and is split across several features.
        max_query_length: The maximum length of the query.
        is_training: whether to create features for model evaluation or model training.
        padding_strategy: Defaults to "max_length". Which padding strategy to use.
        return_dataset: Defaults to False. Either 'pt' or 'tf':
            if 'pt', returns a torch.data.TensorDataset; if 'tf', returns a tf.data.Dataset.
        threads: multiple processing threads.
    Returns:
        list of :class:`~transformers.data.processors.squad.SquadFeatures`
    Example::
        processor = SquadV2Processor()
        examples = processor.get_dev_examples(data_dir)
        features = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=not evaluate,
        )
    """
    # Defining helper methods
    features = []

    threads = min(threads, cpu_count())
    with Pool(threads,
              initializer=squad_convert_example_to_features_init,
              initargs=(tokenizer, )) as p:
        annotate_ = partial(
            squad_convert_example_to_features,
            max_seq_length=max_seq_length,
            doc_stride=doc_stride,
            max_query_length=max_query_length,
            padding_strategy=padding_strategy,
            is_training=is_training,
        )
        features = list(
            tqdm(
                p.imap(annotate_, examples, chunksize=32),
                total=len(examples),
                desc="convert squad examples to features",
                disable=not tqdm_enabled,
            ))

    new_features = []
    unique_id = 1000000000
    example_index = 0
    for example_features in tqdm(features,
                                 total=len(features),
                                 desc="add example index and unique id",
                                 disable=not tqdm_enabled):
        if not example_features:
            continue
        for example_feature in example_features:
            example_feature.example_index = example_index
            example_feature.unique_id = unique_id
            new_features.append(example_feature)
            unique_id += 1
        example_index += 1
    features = new_features
    del new_features
    if return_dataset == "pt":
        if not is_torch_available():
            raise RuntimeError(
                "PyTorch must be installed to return a PyTorch dataset.")

        # Convert to Tensors and build dataset
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_attention_masks = torch.tensor(
            [f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features],
                                          dtype=torch.long)
        all_cls_index = torch.tensor([f.cls_index for f in features],
                                     dtype=torch.long)
        all_p_mask = torch.tensor([f.p_mask for f in features],
                                  dtype=torch.float)
        all_is_impossible = torch.tensor([f.is_impossible for f in features],
                                         dtype=torch.float)

        if not is_training:
            all_feature_index = torch.arange(all_input_ids.size(0),
                                             dtype=torch.long)
            dataset = TensorDataset(all_input_ids, all_attention_masks,
                                    all_token_type_ids, all_feature_index,
                                    all_cls_index, all_p_mask)
        else:
            all_start_positions = torch.tensor(
                [f.start_position for f in features], dtype=torch.long)
            all_end_positions = torch.tensor(
                [f.end_position for f in features], dtype=torch.long)
            dataset = TensorDataset(
                all_input_ids,
                all_attention_masks,
                all_token_type_ids,
                all_start_positions,
                all_end_positions,
                all_cls_index,
                all_p_mask,
                all_is_impossible,
            )

        return features, dataset
    elif return_dataset == "tf":
        if not is_tf_available():
            raise RuntimeError(
                "TensorFlow must be installed to return a TensorFlow dataset.")

        def gen():
            for i, ex in enumerate(features):
                if ex.token_type_ids is None:
                    yield (
                        {
                            "input_ids": ex.input_ids,
                            "attention_mask": ex.attention_mask,
                            "feature_index": i,
                            "qas_id": ex.qas_id,
                        },
                        {
                            "start_positions": ex.start_position,
                            "end_positions": ex.end_position,
                            "cls_index": ex.cls_index,
                            "p_mask": ex.p_mask,
                            "is_impossible": ex.is_impossible,
                        },
                    )
                else:
                    yield (
                        {
                            "input_ids": ex.input_ids,
                            "attention_mask": ex.attention_mask,
                            "token_type_ids": ex.token_type_ids,
                            "feature_index": i,
                            "qas_id": ex.qas_id,
                        },
                        {
                            "start_positions": ex.start_position,
                            "end_positions": ex.end_position,
                            "cls_index": ex.cls_index,
                            "p_mask": ex.p_mask,
                            "is_impossible": ex.is_impossible,
                        },
                    )

        # Why have we split the batch into a tuple? PyTorch just has a list of tensors.
        if "token_type_ids" in tokenizer.model_input_names:
            train_types = (
                {
                    "input_ids": tf.int32,
                    "attention_mask": tf.int32,
                    "token_type_ids": tf.int32,
                    "feature_index": tf.int64,
                    "qas_id": tf.string,
                },
                {
                    "start_positions": tf.int64,
                    "end_positions": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,
                },
            )

            train_shapes = (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "token_type_ids": tf.TensorShape([None]),
                    "feature_index": tf.TensorShape([]),
                    "qas_id": tf.TensorShape([]),
                },
                {
                    "start_positions": tf.TensorShape([]),
                    "end_positions": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
                },
            )
        else:
            train_types = (
                {
                    "input_ids": tf.int32,
                    "attention_mask": tf.int32,
                    "feature_index": tf.int64,
                    "qas_id": tf.string
                },
                {
                    "start_positions": tf.int64,
                    "end_positions": tf.int64,
                    "cls_index": tf.int64,
                    "p_mask": tf.int32,
                    "is_impossible": tf.int32,
                },
            )

            train_shapes = (
                {
                    "input_ids": tf.TensorShape([None]),
                    "attention_mask": tf.TensorShape([None]),
                    "feature_index": tf.TensorShape([]),
                    "qas_id": tf.TensorShape([]),
                },
                {
                    "start_positions": tf.TensorShape([]),
                    "end_positions": tf.TensorShape([]),
                    "cls_index": tf.TensorShape([]),
                    "p_mask": tf.TensorShape([None]),
                    "is_impossible": tf.TensorShape([]),
                },
            )

        return tf.data.Dataset.from_generator(gen, train_types, train_shapes)
    else:
        return features
Example no. 19
 def test_return_correct_depth(self):
     tensor_shape = tf.TensorShape(dims=[32, 299, 384, 3])
     self.assertEqual(3, static_shape.get_depth(tensor_shape))
Example no. 20
def MovingAvgQuantize(inputs,
                      per_channel=False,
                      init_min=-6.0,
                      init_max=6.0,
                      ema_decay=0.999,
                      updates_collection=ops.GraphKeys.UPDATE_OPS,
                      vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                      name_prefix='MovingAvgQuantize',
                      reuse=None,
                      is_training=True,
                      num_bits=8,
                      narrow_range=False):
  """Adds a layer that collects quantization ranges as EMAs of input ranges.

  MovingAvgQuantize creates variables called 'min' and 'max', representing the
  interval used for quantization and clamping.

  Args:
    inputs: a tensor containing values to be quantized.
    per_channel: (default False) a boolean specifying whether to use different
      quantization ranges per output channel.
    init_min: a float scalar, the initial value for variable min.
    init_max: a float scalar, the initial value for variable max.
    ema_decay: EMA decay parameter.
    updates_collection: (Optional) collections to collect the update ops for
      computation.
    vars_collection: (Optional) collection where to store variables for
      quantization interval ends.
    name_prefix: name_prefix for created nodes.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer, the scope must be given.
    is_training: Whether the op is applied to a training or eval graph.
    num_bits: Number of bits to use for quantization, must be between 2 and 8.
    narrow_range: Whether to use the narrow quantization range
      [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1].
  Returns:
    a tensor containing quantized values.
  """
  with variable_scope.variable_scope(
      None, default_name=name_prefix, values=[inputs], reuse=reuse):
    input_shape = inputs.get_shape()
    input_dim = 4 if inputs.shape == tf.TensorShape(None) else len(input_shape) # FIXME (Chia-Lin)
    if per_channel:
      # Only support quantizing 1-, 2- and 4-dimensional tensors.
      assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
                                      ' scope: %s' % (input_shape, name_prefix))
      min_max_shape = [input_shape[-1]]
    else:
      min_max_shape = []

    min_var = model_variable(
        'min',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_min),
        collections=[vars_collection],
        trainable=False)
    max_var = model_variable(
        'max',
        shape=min_max_shape,
        initializer=init_ops.constant_initializer(init_max),
        collections=[vars_collection],
        trainable=False)
    if not is_training:
      return _FakeQuantWithMinMaxVars(
          inputs,
          min_var,
          max_var,
          per_channel=per_channel,
          num_bits=num_bits,
          narrow_range=narrow_range)
    if per_channel:
      if input_dim == 2:
        reduce_dims = [0]
      elif input_dim == 4:
        reduce_dims = [0, 1, 2]

    if per_channel:
      if input_dim >= 2:
        batch_min = math_ops.reduce_min(
            inputs, reduction_indices=reduce_dims, name='BatchMin')
      else:
        batch_min = inputs
    else:
      batch_min = math_ops.reduce_min(inputs, name='BatchMin')
    # B-eng requires that 0.0 is always in the [min; max] range.
    batch_min = math_ops.minimum(batch_min, 0.0)
    assign_min = moving_averages.assign_moving_average(
        min_var, batch_min, ema_decay, name='AssignMinEma')
    ops.add_to_collection(updates_collection, assign_min.op)

    if per_channel:
      if input_dim >= 2:
        batch_max = math_ops.reduce_max(
            inputs, reduction_indices=reduce_dims, name='BatchMax')
      else:
        batch_max = inputs
    else:
      batch_max = math_ops.reduce_max(inputs, name='BatchMax')
    # B-eng requires that 0.0 is always in the [min; max] range.
    batch_max = math_ops.maximum(batch_max, 0.0)
    assign_max = moving_averages.assign_moving_average(
        max_var, batch_max, ema_decay, name='AssignMaxEma')
    ops.add_to_collection(updates_collection, assign_max.op)

    return _FakeQuantWithMinMaxVars(
        inputs,
        assign_min,
        assign_max,
        per_channel=per_channel,
        num_bits=num_bits,
        narrow_range=narrow_range)
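_FakeQuantWithMinMaxVars is a private helper presumably defined elsewhere in this module; the public op underlying this kind of fake quantization is tf.quantization.fake_quant_with_min_max_vars, shown here in isolation as a sketch of the clamp-and-quantize behaviour, not the EMA bookkeeping above.

import tensorflow as tf

x = tf.constant([-8.0, -1.0, 0.0, 3.0, 10.0])
min_var = tf.constant(-6.0)
max_var = tf.constant(6.0)

# Values are clamped to [min, max] and snapped onto the 8-bit quantization grid.
y = tf.quantization.fake_quant_with_min_max_vars(x, min_var, max_var, num_bits=8)
print(y.numpy())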
 def compute_output_shape(self, input_shape):
     """ Outputs produced by the layer """
     return [
         tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[1][2])),
         tf.TensorShape((input_shape[1][0], input_shape[1][1], input_shape[0][1]))
     ]
Example no. 22
 def output_shapes(self):
   return tuple([
       tensorflow.TensorShape([])]) if self._batch == 0 else tuple([
           tensorflow.TensorShape([None])])
Example no. 23
    next_Y_ph = tf.placeholder(tf.float32, [config.batch_size, action_num],
                               name="next_Y_ph")
    reward_ph = tf.placeholder(tf.float32, [config.batch_size],
                               name="reward_ph")

    ph_lst = [input_state_ph, action_ph, Y_ph, next_Y_ph, reward_ph]

    q = tf.FIFOQueue(2, [ph.dtype for ph in ph_lst],
                     [ph.get_shape() for ph in ph_lst])
    enqueue_op = q.enqueue(ph_lst)
    input_state, action, Y, next_Y, reward = q.dequeue()

    # so that i can feed inputs with different batch sizes.
    input_state = tf.placeholder_with_default(
        input_state,
        shape=tf.TensorShape([None]).concatenate(input_state.get_shape()[1:]))
    action = tf.placeholder_with_default(action, shape=[None])
    next_input_state_ph = tf.placeholder(tf.float32,
                                         [config.batch_size, 84, 84, 4],
                                         name="next_input_state_placeholder")

    with tf.variable_scope("DQN"):
        Q, R, predicted_next_Q = createQNetwork(input_state, action, config,
                                                "DQN")
        DQN_params = tf.get_collection("DQN_weights")
        max_action_DQN = tf.argmax(Q, 1)
    with tf.variable_scope("DQNTarget"):
        # passing an action is useless because the target never runs the next_Y_prediction, but it is needed for the code to work
        QT, RT, predicted_next_QT = createQNetwork(next_input_state_ph, action,
                                                   config, "DQNT")
        DQNT_params = tf.get_collection("DQNT_weights")
 def compute_output_shape(self, input_shape):
     shape = tf.TensorShape(input_shape).as_list()
     shape[-1] = self.output_dim
     return tf.TensorShape(shape)
Example no. 25
class SharpMask(resnet_model.Model):
    mask_size = 224
    types = {'score': tf.float32, 'mask': tf.int8, 'image': tf.float32}
    shapes = {
        'score': tf.TensorShape([None]),
        'mask': tf.TensorShape([None, mask_size, mask_size]),
        'image': tf.TensorShape([None, mask_size, mask_size, 3])
    }

    def __init__(self,
                 train_path,
                 validation_path,
                 session=None,
                 resnet_ckpt=None,
                 summary_path=None,
                 checkpoint_path=None,
                 batch_size=32):
        super(SharpMask, self).__init__(resnet_size=50,
                                        bottleneck=True,
                                        num_classes=1001,
                                        num_filters=64,
                                        kernel_size=7,
                                        conv_stride=2,
                                        first_pool_size=3,
                                        first_pool_stride=2,
                                        second_pool_size=7,
                                        second_pool_stride=1,
                                        block_sizes=[3, 4, 6, 3],
                                        block_strides=[1, 2, 2, 2],
                                        final_size=2048,
                                        version=resnet_model.DEFAULT_VERSION,
                                        data_format=None,
                                        dtype=resnet_model.DEFAULT_DTYPE)
        if session is None:
            self.sess = tf.Session()
        else:
            self.sess = session

        it_structure = tf.data.Iterator.from_structure(self.types, self.shapes)
        self.iterator = it_structure.get_next()

        self.image_placeholder = tf.placeholder_with_default("", shape=())

        self.image_input = self.iterator['image']
        self.score_target = self.iterator['score']
        self.seg_target = self.iterator['mask']

        self.score_placeholder = tf.placeholder_with_default([1.0], (1, ))
        self.mask_placeholder = tf.placeholder_with_default(
            tf.ones((1, self.mask_size, self.mask_size), dtype=tf.int8),
            (1, self.mask_size, self.mask_size))
        dummy_ds = tf.data.Dataset.from_tensor_slices({
            'image':
            tf.expand_dims(transform_image(self.image_placeholder), 0),
            'score':
            self.score_placeholder,
            'mask':
            self.mask_placeholder
        }).map(
            lambda x: {
                'score': tf.expand_dims(x['score'], 0),
                'mask': tf.expand_dims(x['mask'], 0),
                'image': tf.expand_dims(x['image'], 0)
            })
        self.placeholder_init_op = it_structure.make_initializer(dummy_ds)

        if train_path is not None:
            self.train_ds = self._create_dataset(train_path, batch_size)
            self.training_init_op = it_structure.make_initializer(
                self.train_ds)

        if validation_path is not None:
            self.validation_ds = self._create_dataset(validation_path,
                                                      batch_size)
            self.validation_init_op = it_structure.make_initializer(
                self.validation_ds)

        if summary_path is not None:
            self.summary_writer = tf.summary.FileWriter(
                summary_path, self.sess.graph)

        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

        self.checkpoint_file = os.path.join(checkpoint_path, 'sharpmask.ckpt')

        self.resnet_output = self(self.image_input, False)

        if resnet_ckpt is not None:
            saver = tf.train.Saver()
            saver.restore(self.sess, resnet_ckpt)

        self.block_layers = [
            self.sess.graph.get_tensor_by_name(
                "resnet_model/block_layer{}:0".format(i + 1)) for i in range(4)
        ]

        self.training_mode = tf.placeholder_with_default(True, shape=())

        with tf.variable_scope("deepmask_trunk"):
            trunk = tf.layers.conv2d(self.block_layers[-1],
                                     512, (1, 1),
                                     activation=tf.nn.relu,
                                     data_format=self.data_format)
            trunk = tf.layers.flatten(trunk)
            trunk = tf.layers.dense(trunk, 512)
        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='deepmask_trunk')))

        with tf.variable_scope("segmentation_branch"):
            seg_predictions = tf.layers.dense(trunk, 56 * 56)
            seg_predictions = tf.reshape(seg_predictions, [-1, 56, 56, 1])
            self.dm_seg_prediction = tf.squeeze(
                tf.image.resize_bilinear(seg_predictions, [224, 224]), 3)

        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='segmentation_branch')))

        with tf.variable_scope("score_branch"):
            score_predictions = tf.layers.dropout(trunk,
                                                  rate=0.5,
                                                  training=self.training_mode)
            score_predictions = tf.layers.dense(score_predictions,
                                                1024,
                                                activation=tf.nn.relu)
            score_predictions = tf.layers.dropout(score_predictions,
                                                  rate=0.5,
                                                  training=self.training_mode)
            self.score_predictions = tf.layers.dense(score_predictions,
                                                     1,
                                                     name='score_out')

        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='score_branch')))

        #self.saver = tf.train.Saver()

        k = 32
        with tf.variable_scope("refinement"):
            M = tf.layers.dense(trunk, k * 7 * 7, name='vertical_0')
            if self.data_format == "channels_first":
                M = tf.reshape(M, [-1, k, 7, 7])
            else:
                M = tf.reshape(M, [-1, 7, 7, k])

            for i in range(1, 5):
                ki = int(k / 2**(i - 1))
                knext = int(ki / 2)

                F = self.block_layers[4 - i]
                S = tf.layers.conv2d(F,
                                     64 if i < 4 else 32, (3, 3),
                                     padding='SAME',
                                     activation=tf.nn.relu,
                                     data_format=self.data_format,
                                     name='horizontal_{}_64'.format(i))

                S = tf.layers.conv2d(S,
                                     ki, (3, 3),
                                     padding='SAME',
                                     activation=tf.nn.relu,
                                     data_format=self.data_format,
                                     name='horizontal_{}_{}'.format(i, ki))
                S = tf.layers.conv2d(S,
                                     knext, (3, 3),
                                     padding='SAME',
                                     data_format=self.data_format,
                                     name='horizontal_{}_{}'.format(i, knext))

                # Use the precomputed integer channel count ki so conv2d
                # receives an int filter count rather than the float k / 2**(i - 1).
                M = tf.layers.conv2d(M,
                                     ki, (3, 3),
                                     padding='SAME',
                                     activation=tf.nn.relu,
                                     data_format=self.data_format,
                                     name='vertical_{}_{}'.format(i, ki))
                M = tf.layers.conv2d(M,
                                     knext, (3, 3),
                                     padding='SAME',
                                     data_format=self.data_format,
                                     name='vertical_{}_{}'.format(i, knext))

                M = tf.nn.relu(S + M)
                if self.data_format == "channels_first":
                    M = tf.transpose(M, perm=[0, 2, 3, 1])
                    M = tf.image.resize_bilinear(
                        M, [M.shape[1] * 2, M.shape[2] * 2])
                    M = tf.transpose(M, perm=[0, 3, 1, 2])
                else:
                    M = tf.image.resize_bilinear(
                        M, [M.shape[1] * 2, M.shape[2] * 2])

            refinement_out = tf.layers.conv2d(M,
                                              1, (3, 3),
                                              padding='SAME',
                                              data_format=self.data_format,
                                              name='refinement_out')
            if self.data_format == "channels_first":
                refinement_out = tf.transpose(refinement_out,
                                              perm=[0, 2, 3, 1])

            refinement_out = tf.image.resize_bilinear(
                refinement_out,
                [refinement_out.shape[1] * 2, refinement_out.shape[2] * 2])
            refinement_out = tf.squeeze(refinement_out, axis=3)
            self.refinement_prediction = refinement_out

        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='refinement')))

        with tf.variable_scope("metrics"):
            score_metric_prediction = tf.where(
                self.score_predictions > 0.0,
                tf.ones_like(self.score_predictions),
                -tf.ones_like(self.score_predictions))
            self.score_accuracy_metric, self.score_accuracy_update = tf.metrics.accuracy(
                self.score_target, score_metric_prediction)

            self.dm_seg_iou_metric, self.dm_seg_iou_update = self._create_seg_metrics(
                self.dm_seg_prediction)
            self.sm_seg_iou_metric, self.sm_seg_iou_update = self._create_seg_metrics(
                self.refinement_prediction)

        self.saver = tf.train.Saver()

    def restore(self):
        self.saver.restore(self.sess, self.checkpoint_file)

    def fit_deepmask(self,
                     epochs=25,
                     lr=0.001,
                     score_factor=1.0 / 32,
                     weight_decay=0.00005):
        with tf.variable_scope("deepmask_training"):
            score_loss, segmentation_loss = self._binary_regression_loss(
                self.dm_seg_prediction, score_factor=score_factor)

            # Fixed learning rate; could be replaced with e.g.
            # tf.train.inverse_time_decay(lr, global_step, 1, weight_decay).
            lr_var = tf.constant(lr)
            weight_loss, weight_vars = self._weight_decay()
            weight_decay_opt = tf.train.GradientDescentOptimizer(
                learning_rate=weight_decay)
            weight_decay_opt_op = weight_decay_opt.minimize(
                weight_loss, var_list=weight_vars)
            opt = tf.train.MomentumOptimizer(learning_rate=lr_var,
                                             momentum=0.9,
                                             use_nesterov=True)
            opt_op = opt.minimize(segmentation_loss + score_loss)

        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='deepmask_training')))

        self._fit_cycle(epochs,
                        lr_var,
                        progress_ops_dict={
                            'segmentation_loss': segmentation_loss,
                            'score_loss': score_loss,
                            'segmentation_iou': self.dm_seg_iou_metric,
                            'score_accuracy': self.score_accuracy_metric
                        },
                        opt_ops=[opt_op, weight_decay_opt_op],
                        metric_update_ops=[
                            self.dm_seg_iou_update, self.score_accuracy_update
                        ])

        print('Deep mask fit cycle completed')

    def fit_sharpmask(self, epochs=25, lr=0.001, weight_decay=0.00005):
        with tf.variable_scope("sharpmask_training"):
            _, segmentation_loss = self._binary_regression_loss(
                self.refinement_prediction)

            global_step = tf.Variable(initial_value=0)
            lr_var = tf.constant(lr)

            segmentation_opt = tf.train.MomentumOptimizer(learning_rate=lr_var,
                                                          momentum=0.9,
                                                          use_nesterov=True)
            segmentation_opt_op = segmentation_opt.minimize(
                segmentation_loss,
                global_step=global_step,
                var_list=tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                           scope='refinement'))

        self.sess.run(
            tf.variables_initializer(
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope='sharpmask_training')))

        self._fit_cycle(epochs,
                        lr_var,
                        progress_ops_dict={
                            'segmentation_loss': segmentation_loss,
                            'segmentation_iou': self.sm_seg_iou_metric
                        },
                        opt_ops=[segmentation_opt_op],
                        metric_update_ops=[self.sm_seg_iou_update])

        print('Sharp mask fit cycle completed')

    def deepmask_validation(self):
        self._run_validation(
            {
                'segmentation_iou': self.dm_seg_iou_metric,
                'score_accuracy': self.score_accuracy_metric
            },
            metric_update_ops=[
                self.dm_seg_iou_update, self.score_accuracy_update
            ])

    def sharpmask_validation(self):
        self._run_validation({'segmentation_iou': self.sm_seg_iou_metric},
                             metric_update_ops=[self.sm_seg_iou_update])

    def eval_sharpmask(self, eval_source, eval_target):
        self._eval_prediction(eval_source, eval_target,
                              self.refinement_prediction)

    def eval_deepmask(self, eval_source, eval_target):
        self._eval_prediction(eval_source, eval_target, self.dm_seg_prediction)

    def _create_seg_metrics(self, seg_predictions):
        mask_indices = tf.where(self.score_target > 0)
        seg_metric_prediction = tf.gather(seg_predictions, mask_indices)
        seg_metric_prediction = tf.where(seg_metric_prediction > 0.0,
                                         tf.ones_like(seg_metric_prediction),
                                         tf.zeros_like(seg_metric_prediction))
        seg_mask = tf.gather(self.seg_target, mask_indices)
        seg_mask = tf.where(seg_mask > 0, tf.ones_like(seg_mask),
                            tf.zeros_like(seg_mask))
        return tf.metrics.mean_iou(seg_mask, seg_metric_prediction, 2)

    def _eval_prediction(self,
                         eval_source,
                         eval_target,
                         seg_predictions,
                         threshold=-1.0):
        self.sess.run([self.placeholder_init_op],
                      feed_dict={
                          self.image_placeholder: eval_source,
                          self.training_mode: False
                      })
        score_predictions, seg_predictions = self.sess.run(
            [self.score_predictions, seg_predictions])

        print('Predicted score is {}'.format(score_predictions[0]))

        eval_image = io.imread(eval_source)
        mask = np.where(seg_predictions[0] > threshold, 255, 0)
        mask = np.expand_dims(mask, axis=2).astype(np.uint8)
        mask = cv2.resize(mask, (eval_image.shape[1], eval_image.shape[0]))
        mask = Image.fromarray(mask)
        mask = mask.convert('RGB')

        eval_image = Image.fromarray(eval_image)
        eval_image = eval_image.convert('RGB')

        target_img = Image.blend(eval_image, mask, 0.5)
        target_img.save(eval_target)

        print('Image with the mask applied stored at {}'.format(eval_target))

    def _eval_resnet(self, eval_source):
        self.sess.run([self.placeholder_init_op],
                      feed_dict={self.image_placeholder: eval_source})
        prediction = self.sess.run([self.resnet_output])
        return IM_CLASSES[np.argmax(prediction[0])]

    def _create_dataset(self, data_path, batch_size):
        tfrecord_files = glob.glob(os.path.join(data_path, '*.tfrecord'))
        dataset = tf.data.TFRecordDataset(tfrecord_files,
                                          buffer_size=1572864000)
        dataset = dataset.shuffle(20000)
        dataset = dataset.map(transform_ds, num_parallel_calls=20)
        dataset = dataset.batch(batch_size)

        return dataset

    def _binary_regression_loss(self, seg_predictions, score_factor=1.0 / 32):
        mask_target = tf.cast(self.seg_target, tf.float32)
        segmentation_loss = tf.reduce_mean(
            (1.0 + self.score_target) / 2.0 * tf.reduce_mean(
                tf.log(1.0 + tf.exp(-seg_predictions * mask_target)),
                axis=[1, 2]))
        score_loss = tf.reduce_mean(
            tf.log(1.0 + tf.exp(-self.score_target *
                                self.score_predictions))) * score_factor
        return score_loss, segmentation_loss

    def _weight_decay(
            self,
            scopes=['deepmask_trunk', 'segmentation_branch', 'score_branch']):
        weights = list(
            itertools.chain(*[
                tf.get_collection(key=tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope=scope) for scope in scopes
            ]))
        weights = list(filter(lambda x: 'kernel' in x.name, weights))
        weights_norm = tf.reduce_sum(
            input_tensor=tf.stack([tf.nn.l2_loss(i) for i in weights]),
            name='weights_norm')

        return weights_norm, weights

    def _run_validation(self,
                        progress_ops_dict,
                        metric_update_ops,
                        validation_steps_count=None):
        progress_ops_names, progress_ops = zip(*progress_ops_dict.items())
        progress_ops = list(progress_ops)

        validation_ops = metric_update_ops + progress_ops

        pbar = tqdm(total=validation_steps_count,
                    desc='Validation',
                    file=sys.stdout)
        counter = 0
        progress = []

        self.sess.run(tf.local_variables_initializer())
        self.sess.run(self.validation_init_op)

        while True:
            try:
                progress = self.sess.run(validation_ops,
                                         feed_dict={self.training_mode: False
                                                    })[-len(progress_ops):]
                counter += 1
                pbar.update()
                pbar.set_description('Validation ({})'.format(', '.join([
                    '{}={}'.format(name, val)
                    for name, val in zip(progress_ops_names, progress)
                ])))
            except tf.errors.OutOfRangeError as oe:
                break

        result = {
            name: value
            for name, value in zip(progress_ops_names, progress)
        }
        result['total_steps'] = counter

        return result

    def _fit_cycle(self, epochs, lr_var, progress_ops_dict, opt_ops,
                   metric_update_ops):
        progress_ops_names, progress_ops = zip(*progress_ops_dict.items())
        training_ops = opt_ops + metric_update_ops + list(progress_ops)

        train_steps_per_epoch = None
        validation_steps_per_epoch = None

        for e in range(epochs):
            tic = datetime.datetime.now()
            lr = self.sess.run([
                lr_var, self.training_init_op,
                tf.local_variables_initializer()
            ])[0]

            print()
            tqdm.write("----- Epoch {}/{} ; learning rate {} -----".format(
                e + 1, epochs, lr))
            pbar = tqdm(total=train_steps_per_epoch,
                        desc='Training',
                        file=sys.stdout)
            train_steps_per_epoch = 0

            while True:
                try:
                    progress = self.sess.run(training_ops)[-len(progress_ops):]
                    pbar.update()
                    pbar.set_description('Training ({})'.format(', '.join([
                        '{}={}'.format(name, val)
                        for name, val in zip(progress_ops_names, progress)
                    ])))
                    train_steps_per_epoch += 1
                except tf.errors.OutOfRangeError as oe:
                    break

            del pbar
            validation_results = self._run_validation(
                progress_ops_dict, metric_update_ops,
                validation_steps_per_epoch)
            training_report = ', '.join([
                'Training {}={}'.format(name, val)
                for name, val in zip(progress_ops_names, progress)
            ])
            validation_report = ', '.join([
                'Validation {}={}'.format(name, val)
                for name, val in validation_results.items()
            ])
            validation_steps_per_epoch = validation_results['total_steps']
            self.saver.save(self.sess, self.checkpoint_file)
            gc.collect()
            toc = datetime.datetime.now()
            tqdm.write("----- Epoch {} finished in {} -- {}. {}".format(
                e + 1, toc - tic, training_report, validation_report))
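
A rough usage sketch for the class above follows. All paths are hypothetical placeholders; it assumes directories of COCO-style TFRecords and a pre-trained ResNet-50 checkpoint, as the constructor restores ResNet weights from resnet_ckpt.

# Minimal usage sketch (paths are placeholders, not from the original example).
model = SharpMask(train_path='data/train',
                  validation_path='data/val',
                  resnet_ckpt='ckpt/resnet50/model.ckpt',
                  summary_path='summaries',
                  checkpoint_path='ckpt/sharpmask',
                  batch_size=32)
model.fit_deepmask(epochs=25, lr=0.001)    # train the score and coarse mask heads
model.fit_sharpmask(epochs=25, lr=0.001)   # then train the refinement stack
model.eval_sharpmask('example.jpg', 'example_masked.jpg')
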
Example No. 26
0
def sample_sequence(*, hparams, length, start_token=None,
                    batch_size=None, context=None, temperature=1,
                    top_k=0, top_p=0.0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens,
                                past=past, reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(
            hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(temperature, tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(
                logits, num_samples=1, dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True
        _, _, tokens = tf.nest.map_structure(
            tf.stop_gradient,
            tf.while_loop(
                cond=cond, body=body,
                maximum_iterations=length,
                loop_vars=[
                    context_output['presents'],
                    context[:, -1],
                    context,
                ],
                shape_invariants=[
                    tf.TensorShape(model.past_shape(
                        hparams=hparams, batch_size=batch_size)),
                    tf.TensorShape([batch_size]),
                    tf.TensorShape([batch_size, None]),
                ],
                back_prop=False,
            ))

        return tokens
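
A plausible way to drive sample_sequence in a TF1-style session is sketched below. The hparams object and the enc tokenizer are assumed to come from the surrounding GPT-2-style codebase and are placeholders here; weight restoration is omitted.

# Sketch only: hparams and enc are assumptions, not defined in this snippet.
with tf.compat.v1.Session() as sess:
    context_ph = tf.compat.v1.placeholder(tf.int32, [1, None])
    output = sample_sequence(hparams=hparams, length=40,
                             context=context_ph, batch_size=1,
                             temperature=0.8, top_k=40)
    # Restore pretrained weights here, e.g. with a tf.train.Saver (omitted).
    tokens_out = sess.run(output,
                          feed_dict={context_ph: [enc.encode("Hello")]})
    print(enc.decode(tokens_out[0]))
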
    def test_with_np_int32_in_tensor_spec(self):
        t = computation_types.to_type((np.int32, [5]))
        self.assertIsInstance(t, computation_types.TensorType)
        self.assertEqual(t.dtype, tf.int32)
        self.assertEqual(t.shape, tf.TensorShape([5]))
    def create_sprite_image(self, examples):
        """Returns an encoded sprite image for use in Facets Dive.

        Args:
          examples: A list of serialized example protos to get images for.

        Returns:
          An encoded PNG.
        """
        def generate_image_from_thumbnails(thumbnails, thumbnail_dims):
            """Generates a sprite atlas image from a set of thumbnails."""
            num_thumbnails = tf.shape(thumbnails)[0].eval()
            images_per_row = int(math.ceil(math.sqrt(num_thumbnails)))
            thumb_height = thumbnail_dims[0]
            thumb_width = thumbnail_dims[1]
            master_height = images_per_row * thumb_height
            master_width = images_per_row * thumb_width
            num_channels = 3
            master = np.zeros([master_height, master_width, num_channels])
            for idx, image in enumerate(thumbnails.eval()):
                left_idx = idx % images_per_row
                top_idx = int(math.floor(idx / images_per_row))
                left_start = left_idx * thumb_width
                left_end = left_start + thumb_width
                top_start = top_idx * thumb_height
                top_end = top_start + thumb_height
                master[top_start:top_end, left_start:left_end, :] = image
            return tf.image.encode_png(master)

        with tf.Session():
            keys_to_features = {
                self.image_feature_name:
                tf.FixedLenFeature((), tf.string, default_value=''),
            }
            parsed = tf.parse_example(examples, keys_to_features)
            images = tf.zeros([1, 1, 1, 1], tf.float32)
            i = tf.constant(0)
            thumbnail_dims = (self.sprite_thumbnail_dim_px,
                              self.sprite_thumbnail_dim_px)
            num_examples = tf.constant(len(examples))
            encoded_images = parsed[self.image_feature_name]

            # Loop over all examples, decoding the image feature value, resizing
            # and appending to a list of all images.
            def loop_body(i, encoded_images, images):
                encoded_image = encoded_images[i]
                image = tf.image.decode_jpeg(encoded_image, channels=3)
                resized_image = tf.image.resize_images(image, thumbnail_dims)
                expanded_image = tf.expand_dims(resized_image, 0)
                images = tf.cond(
                    tf.equal(i, 0), lambda: expanded_image,
                    lambda: tf.concat([images, expanded_image], 0))
                return i + 1, encoded_images, images

            loop_out = tf.while_loop(
                lambda i, encoded_images, images: tf.less(i, num_examples),
                loop_body, [i, encoded_images, images],
                shape_invariants=[
                    i.get_shape(),
                    encoded_images.get_shape(),
                    tf.TensorShape(None)
                ])

            # Create the single sprite atlas image from these thumbnails.
            sprite = generate_image_from_thumbnails(loop_out[2], thumbnail_dims)
            return sprite.eval()
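
The sprite layout above places thumbnails row-major in a square grid of ceil(sqrt(n)) cells per side. The following NumPy-only sketch mirrors that layout logic; the function name and array sizes are illustrative, not part of the original code.

import math
import numpy as np

def layout_sprite(thumbnails):
    """Row-major square-grid layout, mirroring generate_image_from_thumbnails."""
    n, h, w, c = thumbnails.shape
    per_row = int(math.ceil(math.sqrt(n)))
    master = np.zeros([per_row * h, per_row * w, c], dtype=thumbnails.dtype)
    for idx, img in enumerate(thumbnails):
        row, col = divmod(idx, per_row)
        master[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = img
    return master

# e.g. 10 random 32x32 RGB thumbnails -> a 4x4-cell atlas
atlas = layout_sprite(np.random.rand(10, 32, 32, 3))
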
    def test_unknown_tensorshape(self):
        t = computation_types.TensorType(tf.int32, tf.TensorShape(None))
        self.assertEqual(t.dtype, tf.int32)
        self.assertEqual(t.shape, tf.TensorShape(None))
Example No. 30
0
def _main(_):
    # Data
    train_data = tx.data.MonoTextData(config.train_data_hparams)
    val_data = tx.data.MonoTextData(config.val_data_hparams)
    test_data = tx.data.MonoTextData(config.test_data_hparams)
    iterator = tx.data.TrainTestDataIterator(train=train_data,
                                             val=val_data,
                                             test=test_data)
    data_batch = iterator.get_next()

    opt_vars = {
        'learning_rate': config.lr_decay_hparams["init_lr"],
        'best_valid_nll': 1e100,
        'steps_not_improved': 0,
        'kl_weight': config.kl_anneal_hparams["start"]
    }

    decay_cnt = 0
    max_decay = config.lr_decay_hparams["max_decay"]
    decay_factor = config.lr_decay_hparams["decay_factor"]
    decay_ts = config.lr_decay_hparams["threshold"]

    save_dir = "./models/%s" % config.dataset

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    suffix = "%s_%sDecoder.ckpt" % \
            (config.dataset, config.decoder_hparams["type"])

    save_path = os.path.join(save_dir, suffix)

    # KL term annealing rate
    anneal_r = 1.0 / (config.kl_anneal_hparams["warm_up"] * \
        (train_data.dataset_size() / config.batch_size))

    # Model architecture
    encoder_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.vocab.size, hparams=config.emb_hparams)
    decoder_embedder = tx.modules.WordEmbedder(
        vocab_size=train_data.vocab.size, hparams=config.emb_hparams)

    input_embed = encoder_embedder(data_batch["text_ids"])
    output_embed = decoder_embedder(data_batch["text_ids"][:, :-1])

    if config.enc_keep_prob_in < 1:
        input_embed = tf.nn.dropout(
            input_embed, tx.utils.switch_dropout(config.enc_keep_prob_in))

    if config.dec_keep_prob_in < 1:
        output_embed = tf.nn.dropout(
            output_embed, tx.utils.switch_dropout(config.dec_keep_prob_in))

    encoder = tx.modules.UnidirectionalRNNEncoder(
        hparams={"rnn_cell": config.enc_cell_hparams})

    if config.decoder_hparams["type"] == "lstm":
        decoder = tx.modules.BasicRNNDecoder(
            vocab_size=train_data.vocab.size,
            hparams={"rnn_cell": config.dec_cell_hparams})
        decoder_initial_state_size = decoder.cell.state_size
    elif config.decoder_hparams["type"] == 'transformer':
        decoder = tx.modules.TransformerDecoder(
            embedding=decoder_embedder.embedding, hparams=config.trans_hparams)
        decoder_initial_state_size = tf.TensorShape(
            [1, config.emb_hparams["dim"]])
    else:
        raise NotImplementedError

    connector_mlp = tx.modules.MLPTransformConnector(config.latent_dims * 2)

    connector_stoch = tx.modules.ReparameterizedStochasticConnector(
        decoder_initial_state_size)

    _, ecdr_states = encoder(input_embed, sequence_length=data_batch["length"])

    mean_logvar = connector_mlp(ecdr_states)
    mean, logvar = tf.split(mean_logvar, 2, 1)
    kl_loss = kl_dvg(mean, logvar)

    dst = tfd.MultivariateNormalDiag(loc=mean, scale_diag=tf.exp(0.5 * logvar))

    dcdr_states, latent_z = connector_stoch(dst)

    # decoder
    if config.decoder_hparams["type"] == "lstm":
        # concat latent variable to input at every time step
        latent_z = tf.expand_dims(latent_z, axis=1)
        latent_z = tf.tile(latent_z, [1, tf.shape(output_embed)[1], 1])
        output_embed = tf.concat([output_embed, latent_z], axis=2)

        outputs, _, _ = decoder(initial_state=dcdr_states,
                                decoding_strategy="train_greedy",
                                inputs=output_embed,
                                sequence_length=data_batch["length"] - 1)
    else:
        outputs = decoder(inputs=output_embed,
                          memory=dcdr_states,
                          memory_sequence_length=tf.ones(
                              tf.shape(dcdr_states)[0]))

    logits = outputs.logits

    seq_lengths = data_batch["length"] - 1
    # Losses & train ops
    rc_loss = tx.losses.sequence_sparse_softmax_cross_entropy(
        labels=data_batch["text_ids"][:, 1:],
        logits=logits,
        sequence_length=data_batch["length"] - 1)

    # KL annealing
    kl_weight = tf.placeholder(tf.float32, shape=())

    nll = rc_loss + kl_weight * kl_loss

    learning_rate = tf.placeholder(dtype=tf.float32,
                                   shape=(),
                                   name='learning_rate')
    train_op = tx.core.get_train_op(nll,
                                    learning_rate=learning_rate,
                                    hparams=config.opt_hparams)

    def _run_epoch(sess, epoch, mode_string, display=10):
        if mode_string == 'train':
            iterator.switch_to_train_data(sess)
        elif mode_string == 'valid':
            iterator.switch_to_val_data(sess)
        elif mode_string == 'test':
            iterator.switch_to_test_data(sess)

        step = 0
        start_time = time.time()
        num_words = num_sents = 0
        nll_ = 0.
        kl_loss_ = rc_loss_ = 0.

        while True:
            try:
                fetches = {
                    "nll": nll,
                    "kl_loss": kl_loss,
                    "rc_loss": rc_loss,
                    "lengths": seq_lengths
                }

                if mode_string == 'train':
                    fetches["train_op"] = train_op
                    opt_vars["kl_weight"] = min(
                        1.0, opt_vars["kl_weight"] + anneal_r)

                    kl_weight_ = opt_vars["kl_weight"]
                else:
                    kl_weight_ = 1.0

                mode = (tf.estimator.ModeKeys.TRAIN if mode_string == 'train'
                        else tf.estimator.ModeKeys.EVAL)

                feed = {
                    tx.global_mode(): mode,
                    kl_weight: kl_weight_,
                    learning_rate: opt_vars["learning_rate"]
                }

                fetches_ = sess.run(fetches, feed_dict=feed)

                batch_size = len(fetches_["lengths"])
                num_sents += batch_size

                num_words += sum(fetches_["lengths"])
                nll_ += fetches_["nll"] * batch_size
                kl_loss_ += fetches_["kl_loss"] * batch_size
                rc_loss_ += fetches_["rc_loss"] * batch_size

                if step % display == 0 and mode_string == 'train':
                    print('%s: epoch %d, step %d, nll %.4f, klw: %.4f, ' \
                           'KL %.4f,  rc %.4f, log_ppl %.4f, ppl %.4f, ' \
                           'time elapsed: %.1fs' % \
                          (mode_string, epoch, step, nll_ / num_sents,
                           opt_vars["kl_weight"], kl_loss_ / num_sents,
                           rc_loss_ / num_sents, nll_ / num_words,
                           np.exp(nll_ / num_words), time.time() - start_time))

                    sys.stdout.flush()

                step += 1

            except tf.errors.OutOfRangeError:
                print('\n%s: epoch %d, nll %.4f, KL %.4f, rc %.4f, ' \
                      'log_ppl %.4f, ppl %.4f\n' %
                      (mode_string, epoch, nll_ / num_sents,
                       kl_loss_ / num_sents, rc_loss_ / num_sents,
                       nll_ / num_words, np.exp(nll_ / num_words)))
                break

        return nll_ / num_sents, np.exp(nll_ / num_words)

    def generate(sess, saver, fname=None):
        if tf.train.checkpoint_exists(FLAGS.model):
            saver.restore(sess, FLAGS.model)
        else:
            raise ValueError("cannot find checkpoint model")

        batch_size = train_data.batch_size

        dst = tfd.MultivariateNormalDiag(
            loc=tf.zeros([batch_size, config.latent_dims]),
            scale_diag=tf.ones([batch_size, config.latent_dims]))

        dcdr_states, latent_z = connector_stoch(dst)

        # to concatenate latent variable to input word embeddings
        def _cat_embedder(ids):
            embedding = decoder_embedder(ids)
            return tf.concat([embedding, latent_z], axis=1)

        vocab = train_data.vocab
        start_tokens = tf.ones(batch_size, tf.int32) * vocab.bos_token_id
        end_token = vocab.eos_token_id

        if config.decoder_hparams["type"] == "lstm":
            outputs, _, _ = decoder(initial_state=dcdr_states,
                                    decoding_strategy="infer_sample",
                                    embedding=_cat_embedder,
                                    max_decoding_length=100,
                                    start_tokens=start_tokens,
                                    end_token=end_token)
        else:
            outputs, _ = decoder(memory=dcdr_states,
                                 decoding_strategy="infer_sample",
                                 memory_sequence_length=tf.ones(
                                     tf.shape(dcdr_states)[0]),
                                 max_decoding_length=100,
                                 start_tokens=start_tokens,
                                 end_token=end_token)

        sample_tokens = vocab.map_ids_to_tokens(outputs.sample_id)
        sess.run(tf.tables_initializer())

        mode_key = tf.estimator.ModeKeys.EVAL
        feed = {tx.global_mode(): mode_key}
        sample_tokens_ = sess.run(sample_tokens, feed_dict=feed)
        if fname is None:
            fh = sys.stdout
        else:
            fh = open(fname, 'w', encoding='utf-8')

        for sent in sample_tokens_:
            sent = list(sent)
            end_id = sent.index(vocab.eos_token)
            fh.write(' '.join(sent[:end_id + 1]) + '\n')

        # Only close file handles we opened ourselves, never sys.stdout.
        if fname is not None:
            fh.close()

    saver = tf.train.Saver()
    with tf.Session() as sess:
        # generate samples from prior
        if FLAGS.mode == "predict":
            generate(sess, saver, FLAGS.out)
            return

        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        sess.run(tf.tables_initializer())

        # Counts trainable parameters
        total_parameters = 0
        for variable in tf.trainable_variables():
            shape = variable.get_shape()  # shape is an array of tf.Dimension
            variable_parameters = 1
            for dim in shape:
                variable_parameters *= dim.value
            total_parameters += variable_parameters
        print("%d total parameters" % total_parameters)

        best_nll = best_ppl = 0.

        for epoch in range(config.num_epochs):
            _, _ = _run_epoch(sess, epoch, 'train', display=200)
            val_nll, _ = _run_epoch(sess, epoch, 'valid')
            test_nll, test_ppl = _run_epoch(sess, epoch, 'test')

            if val_nll < opt_vars['best_valid_nll']:
                opt_vars['best_valid_nll'] = val_nll
                opt_vars['steps_not_improved'] = 0
                best_nll = test_nll
                best_ppl = test_ppl
                saver.save(sess, save_path)
            else:
                opt_vars['steps_not_improved'] += 1
                if opt_vars['steps_not_improved'] == decay_ts:
                    old_lr = opt_vars['learning_rate']
                    opt_vars['learning_rate'] *= decay_factor
                    opt_vars['steps_not_improved'] = 0
                    new_lr = opt_vars['learning_rate']

                    print('-----\nchange lr, old lr: %f, new lr: %f\n-----' %
                          (old_lr, new_lr))

                    saver.restore(sess, save_path)

                    decay_cnt += 1
                    if decay_cnt == max_decay:
                        break

        print('\nbest testing nll: %.4f, best testing ppl %.4f\n' %
              (best_nll, best_ppl))
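
The kl_dvg helper used above is not shown in this example. Assuming it computes the standard KL divergence between the diagonal Gaussian posterior N(mean, exp(logvar)) and a standard normal prior, averaged over the batch, a typical implementation would look like the sketch below (an assumption, not the original helper).

def kl_dvg(means, logvars):
    """KL( N(mean, exp(logvar)) || N(0, I) ), averaged over the batch.

    Assumed implementation; the original helper is defined elsewhere.
    """
    kl_cost = -0.5 * (logvars - tf.square(means) - tf.exp(logvars) + 1.0)
    kl_cost = tf.reduce_mean(kl_cost, 0)   # average over the batch dimension
    return tf.reduce_sum(kl_cost)          # sum over latent dimensions
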