    def infer(self, features, *args, **kwargs):
        """Produce predictions from the model."""
        del args, kwargs
        inputs_old = None
        # The model expects 4-D inputs: expand 3-D inputs with a dummy axis
        # and remember the original so it can be restored afterwards.
        if "inputs" in features and len(features["inputs"].shape) < 4:
            inputs_old = features["inputs"]
            features["inputs"] = tf.expand_dims(features["inputs"], 2)
        features["targets"] = tf.identity(features["inputs"])

        # Originally: logits, _ = self(features). Expanded below so that
        # model_fn also returns targets_mask.
        t2t_model.set_custom_getter_compose(self._custom_getter)
        tf.get_variable_scope().set_initializer(
            optimize.get_variable_initializer(self.hparams))
        with self._eager_var_store.as_default():
            self._fill_problem_hparams_features(features)
            # Intentionally disable sharding during inference (multi-GPU).
            with tf.variable_scope(self.name):
                logits, _, _, targets_mask = self.model_fn(features)

        # Greedy predictions: argmax over the vocabulary dimension.
        samples = tf.argmax(logits, axis=-1)
        # Keep predictions inside the target mask; positions outside it are
        # filled with ones.
        samples = tf.where(
            tf.cast(targets_mask[..., tf.newaxis, tf.newaxis], tf.bool),
            samples, tf.ones_like(samples))
        if inputs_old is not None:  # Restore to not confuse Estimator.
            features["inputs"] = inputs_old
        return samples
    def call(self, inputs, **kwargs):
        del kwargs
        features = inputs
        t2t_model.set_custom_getter_compose(self._custom_getter)
        tf.get_variable_scope().set_initializer(
            optimize.get_variable_initializer(self.hparams))
        with self._eager_var_store.as_default():
            self._fill_problem_hparams_features(features)
            """ Modified """
            # Passing the encoder state reference in 'sharded_enc_ou' variable.
            sharded_features = self._shard_features(features)
            sharded_logits, losses, sharded_enc_ou = self.model_fn_sharded(
                sharded_features)

            if isinstance(sharded_logits, dict):
                # Concatenate each per-shard entry along the batch dimension.
                concat_logits = {}
                for k, v in six.iteritems(sharded_logits):
                    concat_logits[k] = tf.concat(v, 0)
                concat_enc_ou = {}
                for k, v in six.iteritems(sharded_enc_ou):
                    concat_enc_ou[k] = tf.concat(v, 0)
                return concat_logits, losses, concat_enc_ou

            else:
                # Non-dict case: return concatenated logits and the first
                # shard's final encoder hidden state.
                return (tf.concat(sharded_logits, 0), losses,
                        sharded_enc_ou[0].h)
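
A minimal usage sketch of the modified call() above; the `model` and
`features` names here are hypothetical placeholders, not part of the
original example:

 # Assumes `model` is an instance of the class above and `features` is a
 # T2T feature dict; both names are illustrative.
 logits, losses, enc_output = model(features)
 # Non-dict case: enc_output is the first shard's final encoder hidden
 # state (sharded_enc_ou[0].h); in the dict case it is a dict of encoder
 # outputs concatenated across shards.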
Example #3
 def call(self, features):
     tf.get_variable_scope().set_initializer(
         optimize.get_variable_initializer(self.hparams))
     with self._eager_var_store.as_default():
         self._fill_problem_hparams_features(features)
         sharded_features = self._shard_features(features)
         sharded_logits, losses = self.model_fn_sharded(sharded_features)
         return tf.concat(sharded_logits, 0), losses
Example #4
 def call(self, features):
   tf.get_variable_scope().set_initializer(
       optimize.get_variable_initializer(self.hparams))
   with self._var_store.as_default():
     self._fill_problem_hparams_features(features)
     sharded_features = self._shard_features(features)
     sharded_logits, losses = self.model_fn_sharded(sharded_features)
     return tf.concat(sharded_logits, 0), losses
Example #5
 def call(self, features):
   tf.get_variable_scope().set_initializer(
       optimize.get_variable_initializer(self.hparams))
   with self._eager_var_store.as_default():
     self._fill_problem_hparams_features(features)
     sharded_features = self._shard_features(features)
     sharded_logits, losses = self.model_fn_sharded(sharded_features)
     if isinstance(sharded_logits, dict):
       concat_logits = {}
       for k, v in six.iteritems(sharded_logits):
         concat_logits[k] = tf.concat(v, 0)
       return concat_logits, losses
     else:
       return tf.concat(sharded_logits, 0), losses
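
The dict branch above covers problems where model_fn_sharded returns one
set of logits per output key; each entry is concatenated across shards
independently. A hypothetical sketch of consuming that case (`model`,
`features`, and the "targets" key are illustrative, not from the original):

 logits, losses = model(features)
 if isinstance(logits, dict):
     # One concatenated tensor per output key.
     targets_logits = logits["targets"]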
Example #6
 def call(self, features):
   tf.get_variable_scope().set_initializer(
       optimize.get_variable_initializer(self.hparams))
   with self._eager_var_store.as_default():
     self._fill_problem_hparams_features(features)
     sharded_features = self._shard_features(features)
     sharded_logits, losses = self.model_fn_sharded(sharded_features)
     if isinstance(sharded_logits, dict):
       concat_logits = {}
       for k, v in six.iteritems(sharded_logits):
         concat_logits[k] = tf.concat(v, 0)
       return concat_logits, losses
     else:
       return tf.concat(sharded_logits, 0), losses
Example #7
 def _test_resnet(self, img_size, output_size):
     vocab_size = 1
     batch_size = 1
     # np.random.random_integers is deprecated; np.random.randint uses an
     # exclusive upper bound, hence the shifted limits.
     x = np.random.randint(0, high=256,
                           size=(batch_size, img_size, img_size, 3))
     y = np.random.randint(1, high=vocab_size + 1,
                           size=(batch_size, 1, 1, 1))
     # Alternative hparams sets:
     # hparams = resnet_tiny_cpu()
     # hparams = resnet_50()
     hparams = resnet_32()
     p_hparams = problem_hparams.test_problem_hparams(
         vocab_size, vocab_size, hparams)
     p_hparams.input_modality["inputs"] = modalities.ImageModality(hparams)
     p_hparams.target_modality = modalities.ClassLabelModality(
         hparams, vocab_size)
     run_meta = tf.RunMetadata()
     with self.test_session() as session:
         features = {
             "inputs": tf.constant(x, dtype=tf.int32),
             "targets": tf.constant(y, dtype=tf.int32),
         }
         # model = resnet.Resnet(hparams, tf.estimator.ModeKeys.TRAIN,
         #                       p_hparams)
         model = shake_shake.ShakeShake(hparams,
                                        tf.estimator.ModeKeys.TRAIN,
                                        p_hparams)
         logits, _ = model(features)
         print(logits.get_shape())
         # opts = tf.profiler.ProfileOptionBuilder.float_operation()
         # flops = tf.profiler.profile(tf.get_default_graph(),
         #                             run_meta=run_meta, options=opts)
         # print(flops.total_float_ops)
         session.run(tf.global_variables_initializer())
         # res = session.run(logits)
         tf.get_variable_scope().set_initializer(
             optimize.get_variable_initializer(hparams))
         labels = tf.constant(0, dtype=tf.int32, shape=[1, 1, 1, 1, 1])
         loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                       logits=logits)
         # Build the train op so its backward-pass ops are added to the
         # graph; only the loss is actually evaluated here.
         train_op = optimize.optimize(loss, 0.1, hparams)
         session.run(loss)
         opts = tf.profiler.ProfileOptionBuilder.float_operation()
         flops = tf.profiler.profile(tf.get_default_graph(),
                                     run_meta=run_meta,
                                     options=opts)
         print(flops.total_float_ops)
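
The profiler calls above count the FLOPs of every op in the default graph,
including the backward-pass ops added by optimize.optimize, which is why
the profile runs after the train op is built. A companion measurement with
the same TF1 profiler API is the trainable-parameter count; a minimal
sketch, assuming the test's graph is still the default graph:

 # Report trainable parameters instead of float operations.
 param_opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
 params = tf.profiler.profile(tf.get_default_graph(), options=param_opts)
 print(params.total_parameters)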