def infer(self, features, *args, **kwargs):
  """Produce predictions from the model.

  Args:
    features: dict of input tensors; "inputs" is expanded to rank 4 if
      needed, and "targets" is set to a copy of "inputs" so model_fn can
      run in teacher-forced mode.
    *args: ignored.
    **kwargs: ignored.

  Returns:
    An int tensor of sampled ids (argmax over logits), with positions
    outside `targets_mask` forced to 1.
  """
  del args, kwargs
  # Remember a rank-<4 "inputs" so we can restore it before returning
  # (Estimator gets confused otherwise).
  original_inputs = None
  if "inputs" in features and len(features["inputs"].shape) < 4:
    original_inputs = features["inputs"]
    features["inputs"] = tf.expand_dims(features["inputs"], 2)
  features["targets"] = tf.identity(features["inputs"])
  # logits, _ = self(features)
  t2t_model.set_custom_getter_compose(self._custom_getter)
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._eager_var_store.as_default():
    self._fill_problem_hparams_features(features)
    # Intentionally disable sharding during inference (in multi GPU).
    with tf.variable_scope(self.name):
      logits, _, _, targets_mask = self.model_fn(features)
    samples = tf.argmax(logits, axis=-1)
    # Where the target mask is off, emit id 1 instead of the argmax.
    mask = tf.cast(targets_mask[..., tf.newaxis, tf.newaxis], tf.bool)
    samples = tf.where(mask, samples, tf.ones_like(samples))
    if original_inputs is not None:
      # Restore to not confuse Estimator.
      features["inputs"] = original_inputs
    return samples
def call(self, inputs, **kwargs):
  """Run the model and additionally return the encoder output.

  Modified from the upstream `call`: `model_fn_sharded` here also yields
  `sharded_enc_ou`, a reference to the per-shard encoder state, which is
  concatenated (or unwrapped) alongside the logits.

  Args:
    inputs: the features dict.
    **kwargs: ignored.

  Returns:
    (logits, losses, encoder_output). `logits` and `encoder_output` are
    dicts of concatenated tensors when the model returns dicts, otherwise
    single concatenated tensors (encoder output taken as shard 0's `.h`).
  """
  del kwargs
  features = inputs
  t2t_model.set_custom_getter_compose(self._custom_getter)
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._eager_var_store.as_default():
    self._fill_problem_hparams_features(features)
    # Modified: passing the encoder state reference in 'sharded_enc_ou'.
    sharded_features = self._shard_features(features)
    sharded_logits, losses, sharded_enc_ou = self.model_fn_sharded(
        sharded_features)
    if not isinstance(sharded_logits, dict):
      return tf.concat(sharded_logits, 0), losses, sharded_enc_ou[0].h
    concat_logits = {
        k: tf.concat(v, 0) for k, v in six.iteritems(sharded_logits)
    }
    concat_enc_ou = {
        k: tf.concat(v, 0) for k, v in six.iteritems(sharded_enc_ou)
    }
    return concat_logits, losses, concat_enc_ou
def call(self, features):
  """Shard the features, run the model, and concatenate shard logits.

  Args:
    features: the features dict.

  Returns:
    (logits, losses) where logits is the shard outputs concatenated on
    axis 0.
  """
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._eager_var_store.as_default():
    self._fill_problem_hparams_features(features)
    per_shard_features = self._shard_features(features)
    per_shard_logits, losses = self.model_fn_sharded(per_shard_features)
    logits = tf.concat(per_shard_logits, 0)
    return logits, losses
def call(self, features):
  """Shard the features, run the model, and concatenate shard logits.

  NOTE(review): this variant uses `self._var_store` (not
  `self._eager_var_store` as in the sibling `call` implementations) —
  presumably a non-eager variable store; confirm against the class.

  Args:
    features: the features dict.

  Returns:
    (logits, losses) where logits is the shard outputs concatenated on
    axis 0.
  """
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._var_store.as_default():
    self._fill_problem_hparams_features(features)
    per_shard_features = self._shard_features(features)
    per_shard_logits, losses = self.model_fn_sharded(per_shard_features)
    logits = tf.concat(per_shard_logits, 0)
    return logits, losses
def call(self, features):
  """Shard the features, run the model, and concatenate shard logits.

  Args:
    features: the features dict.

  Returns:
    (logits, losses). When the model emits a dict of logits, each entry
    is concatenated across shards on axis 0; otherwise a single
    concatenated tensor is returned.
  """
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._eager_var_store.as_default():
    self._fill_problem_hparams_features(features)
    per_shard_features = self._shard_features(features)
    per_shard_logits, losses = self.model_fn_sharded(per_shard_features)
    if not isinstance(per_shard_logits, dict):
      return tf.concat(per_shard_logits, 0), losses
    merged = {
        k: tf.concat(v, 0) for k, v in six.iteritems(per_shard_logits)
    }
    return merged, losses
def call(self, features):
  """Shard the features, run the model, and concatenate shard logits.

  Args:
    features: the features dict.

  Returns:
    (logits, losses). When the model emits a dict of logits, each entry
    is concatenated across shards on axis 0; otherwise a single
    concatenated tensor is returned.
  """
  tf.get_variable_scope().set_initializer(
      optimize.get_variable_initializer(self.hparams))
  with self._eager_var_store.as_default():
    self._fill_problem_hparams_features(features)
    sharded_features = self._shard_features(features)
    sharded_logits, losses = self.model_fn_sharded(sharded_features)
    if isinstance(sharded_logits, dict):
      concat_logits = {}
      # Fix: `dict.iteritems()` is Python-2-only; use six.iteritems for
      # Python 3 compatibility, consistent with the sibling `call`.
      for k, v in six.iteritems(sharded_logits):
        concat_logits[k] = tf.concat(v, 0)
      return concat_logits, losses
    else:
      return tf.concat(sharded_logits, 0), losses
def _test_resnet(self, img_size, output_size):
  """Build a tiny image model, run one loss step, and profile FLOPs.

  Args:
    img_size: spatial size of the random input image.
    output_size: unused here; kept for call-site compatibility.
  """
  del output_size  # Unused.
  vocab_size = 1
  batch_size = 1
  # Fix: np.random.random_integers is deprecated (removed in modern
  # NumPy). randint's upper bound is exclusive, hence 256 / vocab_size+1
  # to preserve the original inclusive ranges.
  x = np.random.randint(
      0, 256, size=(batch_size, img_size, img_size, 3))
  y = np.random.randint(1, vocab_size + 1, size=(batch_size, 1, 1, 1))
  hparams = resnet_32()
  p_hparams = problem_hparams.test_problem_hparams(vocab_size, vocab_size,
                                                   hparams)
  p_hparams.input_modality["inputs"] = modalities.ImageModality(hparams)
  p_hparams.target_modality = modalities.ClassLabelModality(
      hparams, vocab_size)
  run_meta = tf.RunMetadata()
  with self.test_session() as session:
    features = {
        "inputs": tf.constant(x, dtype=tf.int32),
        "targets": tf.constant(y, dtype=tf.int32),
    }
    model = shake_shake.ShakeShake(hparams, tf.estimator.ModeKeys.TRAIN,
                                   p_hparams)
    logits, _ = model(features)
    print(logits.get_shape())
    session.run(tf.global_variables_initializer())
    tf.get_variable_scope().set_initializer(
        optimize.get_variable_initializer(hparams))
    loss = tf.losses.sparse_softmax_cross_entropy(
        labels=tf.constant(0, dtype=tf.int32, shape=[1, 1, 1, 1, 1]),
        logits=logits)
    # Building the train op adds the backward pass to the graph so the
    # profiler counts training FLOPs, even though the op is never run.
    train_op = optimize.optimize(loss, 0.1, hparams)
    del train_op  # Graph side effect only.
    session.run(loss)
    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(
        tf.get_default_graph(), run_meta=run_meta, options=opts)
    print(flops.total_float_ops)