Example 1
def model_fn(features, labels, mode, params):
    """A model is called by TpuEstimator."""
    del labels
    global_step = tf.train.get_global_step()
    graph = mtf.Graph()
    mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
    if FLAGS.use_tpu:
        ctx = params['context']
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info('device_list = %s' % device_list)
        # TODO(ylc): Better estimation of replica cache size?
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(mesh_shape, layout_rules,
                                                    mesh_devices,
                                                    ctx.device_assignment)
    else:
        var_placer = None
        mesh_devices = [''] * mesh_shape.size
        mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    with mtf.utils.outside_all_rewrites():
        logits, loss = toy_model(features, mesh)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        if FLAGS.optimizer == 'Adafactor':
            optimizer = mtf.optimize.AdafactorOptimizer()
        else:
            assert FLAGS.optimizer == 'SGD'
            optimizer = mtf.optimize.SgdOptimizer(learning_rate=FLAGS.lr)
        update_ops = optimizer.apply_grads(var_grads,
                                           graph.trainable_variables)
    else:
        # for now, we can only export fully-replicated tensors.
        fully_replicated_logits = mtf.anonymize(logits)

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = tf.to_float(lowering.export_to_tf_tensor(loss))

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        tf.logging.info('tf_update_ops: {}'.format(tf_update_ops))
        train_op = tf.group(tf_update_ops)
    else:
        tf_logits = lowering.export_to_tf_tensor(fully_replicated_logits)

    with mtf.utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        if mode == tf.estimator.ModeKeys.TRAIN:
            saver = tf.train.Saver(tf.global_variables(),
                                   sharded=True,
                                   max_to_keep=10,
                                   keep_checkpoint_every_n_hours=2,
                                   defer_build=False,
                                   save_relative_paths=True)
            tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
            saver_listener = mtf.MtfCheckpointSaverListener(lowering)
            saver_hook = tf.train.CheckpointSaverHook(
                FLAGS.model_dir,
                save_steps=1000,
                saver=saver,
                listeners=[saver_listener])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.TRAIN,
                loss=tf_loss,
                train_op=train_op,
                training_hooks=[restore_hook, saver_hook])
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(tf_logits):
                mean_logits = tf.metrics.mean(tf_logits)
                return {'mean_logits': mean_logits}

            eval_metrics = (metric_fn, [tf_logits])

            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
Example 2
    def _finish(self, update_ops, name):

        outer_wealth = self._get_non_slot(OUTER_WEALTH)
        betting_domain = self.betting_domain
        maximum_gradient = self._get_non_slot(MAXIMUM_GRADIENT)

        wealth_increment = sum(self.wealth_deltas.values())
        betting_fraction_dot_product = sum(
            self.betting_fraction_dot_product_deltas.values())
        grad_norm = sum(self.grad_norms.values())

        maximum_gradient_updated = self._assign(
            maximum_gradient, tf.maximum(maximum_gradient, grad_norm))
        update_ops.append(maximum_gradient_updated)

        gradient_scaling = 1.0 / maximum_gradient_updated
        # We will replace gradient with gradient/maximum_gradient_updated in order
        # to ensure ||gradient||_1 \le 1.
        # Since betting_fraction_dot_product and wealth_increment were calculated
        # using the original gradient, we also scale them by the same amount.
        betting_fraction_dot_product = betting_fraction_dot_product * gradient_scaling
        wealth_increment = wealth_increment * gradient_scaling

        outer_wealth_updated = self._assign_add(outer_wealth, wealth_increment)
        update_ops.append(outer_wealth_updated)

        inner_grad_scaling = (1.0 - betting_domain) / (
            1.0 - betting_fraction_dot_product)

        if self.output_summaries:
            tf.summary.scalar(self._name + "/total_wealth",
                              outer_wealth_updated)
            tf.summary.scalar(self._name + "/maximum_gradient_norm",
                              maximum_gradient_updated)
            tf.summary.scalar(self._name + "/gradient_L1_norm", grad_norm)

        if self.add_average:
            grad_norm_squared = tf.square(grad_norm)
            sum_grad_norm_squared = self._get_non_slot(SUM_GRAD_NORM_SQUARED)
            sum_grad_norm_squared_updated = self._assign_add(
                sum_grad_norm_squared, grad_norm_squared)

        for var in self.grads:

            grad = self.grads[var]

            if self.inner_optimizer == SCINOL:
                inner_grad = grad * inner_grad_scaling
            else:
                # Rescale gradient to have L1 norm at most 1.0
                scaled_grad = grad * gradient_scaling
                inner_grad = scaled_grad * inner_grad_scaling

            betting_fraction, inner_update_op = self._compute_inner_update(
                var, inner_grad)
            update_ops.append(inner_update_op)

            if self.output_summaries:
                betting_fraction_summary = tf.reduce_mean(
                    tf.abs(betting_fraction))
                tf.summary.scalar(
                    self._name + "/mean_abs_betting_fraction/" + var.name,
                    betting_fraction_summary)
                max_betting_fraction_summary = tf.reduce_max(
                    tf.abs(betting_fraction))
                tf.summary.scalar(
                    self._name + "/max_abs_betting_fraction/" + var.name,
                    max_betting_fraction_summary)

            next_offset = self.lr * betting_fraction * outer_wealth_updated
            initial_value = self.get_slot(var, INITIAL_VALUE)

            if self.add_average:
                average_offset = self.get_slot(var, AVERAGE_OFFSET)
                previous_sum_grad_norm_squared = sum_grad_norm_squared - grad_norm_squared
                average_offset_updated = self._assign_add(
                    average_offset, (previous_sum_grad_norm_squared *
                                     (next_offset - average_offset)) /
                    (sum_grad_norm_squared_updated))
                update_ops.append(average_offset_updated)

                var_updated = self._assign(
                    var, next_offset + average_offset_updated + initial_value)
            else:
                var_updated = self._assign(var, next_offset + initial_value)
            update_ops.append(var_updated)

        return tf.group(*update_ops, name=name)
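
One detail worth isolating from the method above: dividing by the running maximum gradient norm guarantees that the rescaled gradient has L1 norm at most 1, which is what the inline comment refers to. Below is a minimal NumPy sketch of just that step, outside the optimizer class; the function name and the standalone setting are illustrative, not part of the original code.

import numpy as np

def rescale_by_running_max(grad, running_max_norm):
    """Scale `grad` so that its L1 norm never exceeds 1.

    `running_max_norm` plays the role of `maximum_gradient_updated` above:
    the largest L1 gradient norm observed so far, including this step.
    """
    grad_norm = np.abs(grad).sum()
    running_max_norm = max(running_max_norm, grad_norm)
    scaled_grad = grad / running_max_norm  # ||scaled_grad||_1 <= 1 by construction
    return scaled_grad, running_max_norm

grad = np.array([0.3, -0.9])
scaled, running_max = rescale_by_running_max(grad, running_max_norm=1e-8)
assert np.abs(scaled).sum() <= 1.0
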
Example 3
    def build(self, input_shape):
        """Builds the entropy model.

    Creates the variables for the network modeling the densities, creates the
    auxiliary loss estimating the median and tail quantiles of the densities,
    and then uses that to create the probability mass functions and the discrete
    cumulative density functions used by the range coder.

    Arguments:
      input_shape: Shape of the input tensor, used to get the number of
        channels.

    Raises:
      ValueError: if `input_shape` doesn't specify the length of the channel
        dimension.
    """
        input_shape = tf.TensorShape(input_shape)
        channel_axis = self._channel_axis(input_shape.ndims)
        channels = input_shape[channel_axis].value
        if channels is None:
            raise ValueError(
                "The channel dimension of the inputs must be defined.")
        self.input_spec = tf.keras.layers.InputSpec(
            ndim=input_shape.ndims, axes={channel_axis: channels})
        filters = (1, ) + self.filters + (1, )
        scale = self.init_scale**(1 / (len(self.filters) + 1))

        # Create variables.
        self._matrices = []
        self._biases = []
        self._factors = []
        for i in range(len(self.filters) + 1):
            init = np.log(np.expm1(1 / scale / filters[i + 1]))
            matrix = self.add_variable(
                "matrix_{}".format(i),
                dtype=self.dtype,
                shape=(channels, filters[i + 1], filters[i]),
                initializer=tf.initializers.constant(init))
            matrix = tf.nn.softplus(matrix)
            self._matrices.append(matrix)

            bias = self.add_variable(
                "bias_{}".format(i),
                dtype=self.dtype,
                shape=(channels, filters[i + 1], 1),
                initializer=tf.initializers.random_uniform(-.5, .5))
            self._biases.append(bias)

            if i < len(self.filters):
                factor = self.add_variable("factor_{}".format(i),
                                           dtype=self.dtype,
                                           shape=(channels, filters[i + 1], 1),
                                           initializer=tf.initializers.zeros())
                factor = tf.math.tanh(factor)
                self._factors.append(factor)

        # To figure out what range of the densities to sample, we need to compute
        # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we
        # can't take inverses of the cumulative directly, we make it an optimization
        # problem:
        # `quantiles = argmin(|logit(cumulative) - target|)`
        # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`.
        # Taking the logit (inverse of sigmoid) of the cumulative makes the
        # representation of the right target more numerically stable.

        # Numerically stable way of computing logits of `tail_mass / 2`
        # and `1 - tail_mass / 2`.
        target = np.log(2 / self.tail_mass - 1)
        # Compute lower and upper tail quantile as well as median.
        target = tf.constant([-target, 0, target], dtype=self.dtype)

        def quantiles_initializer(shape, dtype=None, partition_info=None):
            del partition_info  # unused
            assert tuple(shape[1:]) == (1, 3)
            init = tf.constant([[[-self.init_scale, 0, self.init_scale]]],
                               dtype=dtype)
            return tf.tile(init, (shape[0], 1, 1))

        quantiles = self.add_variable("quantiles",
                                      shape=(channels, 1, 3),
                                      dtype=self.dtype,
                                      initializer=quantiles_initializer)
        logits = self._logits_cumulative(quantiles, stop_gradient=True)
        loss = tf.math.reduce_sum(abs(logits - target))
        self.add_loss(loss, inputs=None)

        # Quantize such that the median coincides with the center of a bin.
        medians = quantiles[:, 0, 1]
        self._medians = tf.stop_gradient(medians)

        # Largest distance observed between lower tail quantile and median, and
        # between median and upper tail quantile.
        minima = medians - quantiles[:, 0, 0]
        minima = tf.cast(tf.math.ceil(minima), tf.int32)
        minima = tf.math.maximum(minima, 0)
        maxima = quantiles[:, 0, 2] - medians
        maxima = tf.cast(tf.math.ceil(maxima), tf.int32)
        maxima = tf.math.maximum(maxima, 0)

        # PMF starting positions and lengths.
        self._offset = -minima
        pmf_start = medians - tf.cast(minima, self.dtype)
        pmf_length = maxima + minima + 1

        # Sample the densities in the computed ranges, possibly computing more
        # samples than necessary at the upper end.
        max_length = tf.math.reduce_max(pmf_length)
        samples = tf.range(tf.cast(max_length, self.dtype), dtype=self.dtype)
        samples += pmf_start[:, None, None]

        half = tf.constant(.5, dtype=self.dtype)
        # We strip the sigmoid from the end here, so we can use the special rule
        # below to only compute differences in the left tail of the sigmoid.
        # This increases numerical stability (see explanation in `call`).
        lower = self._logits_cumulative(samples - half, stop_gradient=True)
        upper = self._logits_cumulative(samples + half, stop_gradient=True)
        # Flip signs if we can move more towards the left tail of the sigmoid.
        sign = -tf.math.sign(tf.math.add_n([lower, upper]))
        pmf = abs(
            tf.math.sigmoid(sign * upper) - tf.math.sigmoid(sign * lower))
        pmf = pmf[:, 0, :]

        # Compute out-of-range (tail) masses.
        tail_mass = tf.math.add_n([
            tf.math.sigmoid(lower[:, 0, :1]),
            tf.math.sigmoid(-upper[:, 0, -1:]),
        ])

        # Construct a valid CDF initializer, so that we can run the model without
        # error even on the zeroth training step.
        def cdf_initializer(shape, dtype=None, partition_info=None):
            del shape, partition_info  # unused
            assert dtype == tf.int32
            fill = tf.constant(.5, dtype=self.dtype)
            prob = tf.fill((channels, 2), fill)
            cdf = range_coding_ops.pmf_to_quantized_cdf(
                prob, precision=self.range_coder_precision)
            return tf.placeholder_with_default(cdf, shape=(channels, None))

        # We need to supply an initializer without fully defined static shape
        # here, or the variable will return the wrong dynamic shape later. A
        # placeholder with default gets the trick done (see initializer above).
        quantized_cdf = self.add_variable("quantized_cdf",
                                          shape=(channels, None),
                                          dtype=tf.int32,
                                          trainable=False,
                                          initializer=cdf_initializer)
        cdf_length = self.add_variable("cdf_length",
                                       shape=(channels, ),
                                       dtype=tf.int32,
                                       trainable=False,
                                       initializer=tf.initializers.constant(3))
        # Works around a weird TF issue with reading variables inside a loop.
        self._quantized_cdf = tf.identity(quantized_cdf)
        self._cdf_length = tf.identity(cdf_length)

        update_cdf = tf.assign(quantized_cdf,
                               self._pmf_to_cdf(pmf, tail_mass, pmf_length,
                                                max_length),
                               validate_shape=False)
        update_length = tf.assign(cdf_length, pmf_length + 2)
        update_op = tf.group(update_cdf, update_length)
        self.add_update(update_op, inputs=None)

        super(EntropyBottleneck, self).build(input_shape)
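
The `target = np.log(2 / self.tail_mass - 1)` line above is exactly the logit of `1 - tail_mass / 2`: since sigmoid(log(2/t - 1)) = 1 - t/2, matching the cumulative's logits against -target, 0 and target drives the three quantile variables toward the lower tail, the median and the upper tail. A quick standalone check of that identity (NumPy only; the tail_mass value is just an example):

import numpy as np

tail_mass = 2 ** -8
target = np.log(2 / tail_mass - 1)   # logit(1 - tail_mass / 2)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

np.testing.assert_allclose(sigmoid(target), 1 - tail_mass / 2)
np.testing.assert_allclose(sigmoid(-target), tail_mass / 2)
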
Example 4
                        metrics=transformer.get_metric_functions())

    print(tf.trainable_variables())
    print(adamm.variables())
    print(len(tf.trainable_variables()))
    print(len(adamm.variables()))

    print(adamm.get_slot_names())
    print(len(adamm.get_slot_names()))

    train_model.summary()

    print('Start unit testing : New BERTWrapper')

    sess = K.get_session()
    init = tf.group(tf.global_variables_initializer(),
                    tf.local_variables_initializer())
    sess.run(init)

    test_data = [
        ['Hello', 'World'],
        ['Hello', 'World'],
        ['Hello', 'World'],
        ['Hello', 'World'],
    ]
    input_vals = itokens.encode(test_data, max_length=16)
    output_vals = otokens.encode(test_data, max_length=16)
    print(input_vals)
    print(output_vals)

    train_model.fit(x=input_vals,
                    y=output_vals,
Example 5
    def _compute_inner_update_scinol(self, var, grad):
        update_ops = []

        betting_domain = tf.cast(self.betting_domain, var.dtype.base_dtype)

        reward = self.get_slot(var, INNER_REWARD)
        betting_fraction = self.get_slot(var, OUTER_BETTING_FRACTION)
        sum_grad_squared = self.get_slot(var, INNER_SUM_GRAD_SQUARED)
        sum_grad = self.get_slot(var, INNER_SUM_GRAD)
        inner_maximum_gradient = self.get_slot(var, INNER_MAXIMUM_GRADIENT)

        # clip inner gradient to respect previous inner_maximum_gradient value
        # This introduces at most an additive constant overhead in the regret
        # since the inner betting fraction lies in a bounded domain.
        clipped_grad = tf.clip_by_value(grad, -inner_maximum_gradient,
                                        inner_maximum_gradient)

        with tf.control_dependencies([clipped_grad]):
            inner_maximum_gradient_updated = self._assign(
                inner_maximum_gradient,
                tf.maximum(inner_maximum_gradient, tf.abs(grad)))
            update_ops.append(inner_maximum_gradient_updated)

        clipped_old_betting_fraction = tf.clip_by_value(
            betting_fraction, -betting_domain, betting_domain)

        # Process grad to respect truncation to [-betting_domain, betting_domain]
        truncated_grad = tf.where(
            tf.greater_equal(
                clipped_grad *
                (betting_fraction - clipped_old_betting_fraction), 0.0),
            clipped_grad, tf.zeros(tf.shape(clipped_grad)))

        reward_delta = -betting_fraction * truncated_grad
        reward_updated = self._assign_add(reward, reward_delta)
        update_ops.append(reward_updated)

        sum_grad_squared_updated = self._assign_add(sum_grad_squared,
                                                    tf.square(truncated_grad))
        update_ops.append(sum_grad_squared_updated)

        sum_grad_updated = self._assign_add(sum_grad, truncated_grad)
        update_ops.append(sum_grad_updated)

        # The second term in this maximum, inner_maximum_gradient_updated / self.eta
        # is a hack to force the betting fraction to not be too big at first.
        scaling = tf.minimum(
            tf.rsqrt(sum_grad_squared_updated +
                     tf.square(inner_maximum_gradient_updated)),
            self.eta / inner_maximum_gradient_updated)
        theta = -sum_grad_updated * scaling

        # rescale inner flag is a hack that rescales the epsilon_v by the
        # maximum inner gradient.
        if self.rescale_inner:
            epsilon_scaling = inner_maximum_gradient_updated
        else:
            epsilon_scaling = 1.0

        inner_betting_fraction = tf.sign(theta) * tf.minimum(
            tf.abs(theta), 1.0) * scaling / 2.0
        new_betting_fraction = inner_betting_fraction * (
            reward_updated + epsilon_scaling * self.epsilon_v)

        betting_fraction_updated = self._assign(betting_fraction,
                                                new_betting_fraction)
        update_ops.append(betting_fraction_updated)

        clipped_betting_fraction = tf.clip_by_value(betting_fraction_updated,
                                                    -betting_domain,
                                                    betting_domain)

        if self.output_summaries:
            mean_unclipped_betting_fraction_summary = tf.reduce_mean(
                tf.abs(betting_fraction_updated))
            max_unclipped_betting_fraction_summary = tf.reduce_max(
                tf.abs(betting_fraction_updated))

            mean_clipped_betting_fraction_summary = tf.reduce_mean(
                tf.abs(clipped_betting_fraction))
            max_clipped_betting_fraction_summary = tf.reduce_max(
                tf.abs(clipped_betting_fraction))

            max_abs_gradient = tf.reduce_max(tf.abs(grad))
            max_truncated_grad = tf.reduce_max(tf.abs(truncated_grad))

            tf.summary.scalar(self._name + "/mean_unclipped_bet/" + var.name,
                              mean_unclipped_betting_fraction_summary)
            tf.summary.scalar(self._name + "/max_unclipped_bet/" + var.name,
                              max_unclipped_betting_fraction_summary)
            tf.summary.scalar(self._name + "/mean_clipped_bet/" + var.name,
                              mean_clipped_betting_fraction_summary)
            tf.summary.scalar(self._name + "/max_clipped_bet/" + var.name,
                              max_clipped_betting_fraction_summary)

            tf.summary.scalar(self._name + "/max_abs_inner_grad/" + var.name,
                              max_abs_gradient)
            tf.summary.scalar(
                self._name + "/max_abs_truncated_inner_grad/" + var.name,
                max_truncated_grad)
        return clipped_betting_fraction, tf.group(*update_ops)
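
The `truncated_grad` expression above keeps the clipped gradient only when `grad * (x - clip(x))` is non-negative: always while the betting fraction sits inside `[-betting_domain, betting_domain]` (the difference is zero), and otherwise only when the gradient points in the same direction as the overshoot. That is what lets the update respect the truncation without an explicit projection step. A minimal NumPy sketch of the rule outside the optimizer class (names are illustrative):

import numpy as np

def truncate_grad(grad, x, lo, hi):
    """Zero the gradient unless grad * (x - clip(x, lo, hi)) >= 0."""
    overshoot = x - np.clip(x, lo, hi)   # zero whenever x is inside [lo, hi]
    return np.where(grad * overshoot >= 0.0, grad, 0.0)

# Inside the domain the gradient passes through unchanged:
print(truncate_grad(np.array([0.4]), np.array([0.1]), -0.5, 0.5))   # [0.4]
# Outside the domain it is kept only if it has the same sign as the overshoot:
print(truncate_grad(np.array([-0.4]), np.array([0.8]), -0.5, 0.5))  # [0.]
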
Example 6
def create_optimizer(loss,
                     init_lr,
                     num_train_steps,
                     num_warmup_steps,
                     use_tpu,
                     Global_step,
                     optimizer="adamw",
                     poly_power=1.0,
                     start_warmup_step=0):
    """Creates an optimizer training op."""
    #global_step = tf.train.get_or_create_global_step()

    # by chenming
    if Global_step:
        global_step = Global_step
    else:
        global_step = tf.train.get_or_create_global_step()

    learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

    # Implements linear decay of the learning rate.
    learning_rate = tf.train.polynomial_decay(learning_rate,
                                              global_step,
                                              num_train_steps,
                                              end_learning_rate=0.0,
                                              power=poly_power,
                                              cycle=False)

    # Implements linear warmup. I.e., if global_step - start_warmup_step <
    # num_warmup_steps, the learning rate will be
    # `(global_step - start_warmup_step)/num_warmup_steps * init_lr`.
    if num_warmup_steps:
        tf.logging.info("++++++ warmup starts at step " +
                        str(start_warmup_step) + ", for " +
                        str(num_warmup_steps) + " steps ++++++")
        global_steps_int = tf.cast(global_step, tf.int32)
        start_warm_int = tf.constant(start_warmup_step, dtype=tf.int32)
        global_steps_int = global_steps_int - start_warm_int
        warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

        global_steps_float = tf.cast(global_steps_int, tf.float32)
        warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

        warmup_percent_done = global_steps_float / warmup_steps_float
        warmup_learning_rate = init_lr * warmup_percent_done

        is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
        learning_rate = ((1.0 - is_warmup) * learning_rate +
                         is_warmup * warmup_learning_rate)

    # It is OK to use this optimizer for fine-tuning, since this is how the model
    # was trained (note that the Adam m/v variables are NOT loaded from
    # init_checkpoint). It is also OK to use AdamW for fine-tuning even if the
    # model was trained with LAMB. As reported in the public BERT GitHub repo,
    # the learning rate for SQuAD 1.1 fine-tuning is 3e-5, 4e-5 or 5e-5. For
    # LAMB, users can use 3e-4, 4e-4 or 5e-4 with a batch size of 64 for
    # fine-tuning.
    if optimizer == "adamw":
        tf.logging.info("using adamw")
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    elif optimizer == "lamb":
        tf.logging.info("using lamb")
        optimizer = lamb_optimizer.LAMBOptimizer(
            learning_rate=learning_rate,
            weight_decay_rate=0.01,
            beta_1=0.9,
            beta_2=0.999,
            epsilon=1e-6,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])
    else:
        raise ValueError("Not supported optimizer: ", optimizer)

    if use_tpu:
        optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

    tvars = tf.trainable_variables()
    grads = tf.gradients(loss, tvars)

    # This is how the model was pre-trained.
    (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

    train_op = optimizer.apply_gradients(list(zip(grads, tvars)),
                                         global_step=global_step)

    # Normally the global step update is done inside of `apply_gradients`.
    # However, neither `AdamWeightDecayOptimizer` nor `LAMBOptimizer` do this.
    # But if you use a different optimizer, you should probably take this line
    # out.
    new_global_step = global_step + 1
    train_op = tf.group(train_op, [global_step.assign(new_global_step)])
    return train_op
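
The learning-rate logic above amounts to a piecewise schedule: a linear ramp `(step - start_warmup_step) / num_warmup_steps * init_lr` while warming up, then a polynomial decay to zero. A small pure-Python mirror of it, handy for sanity-checking the constants before launching a run; no TF involved, and the function name is illustrative.

def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps,
               poly_power=1.0, start_warmup_step=0):
    """Pure-Python mirror of the schedule built above."""
    if num_warmup_steps and step - start_warmup_step < num_warmup_steps:
        return init_lr * (step - start_warmup_step) / num_warmup_steps
    frac = min(step, num_train_steps) / float(num_train_steps)
    return init_lr * (1.0 - frac) ** poly_power  # end_learning_rate=0.0

print(lr_at_step(500, 3e-5, 10000, 1000))    # mid-warmup   -> 1.5e-05
print(lr_at_step(10000, 3e-5, 10000, 1000))  # end of decay -> 0.0
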
Example 7
def get_train_op(flags, total_loss, ema=None, tvars=None):
    """Generates the training operation."""
    global_step = tf.train.get_or_create_global_step()

    # increase the learning rate linearly
    if flags.warmup_steps > 0:
        warmup_lr = (tf.cast(global_step, tf.float32) /
                     tf.cast(flags.warmup_steps, tf.float32) *
                     flags.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate
    if flags.decay_method == "poly":
        decay_lr = tf.train.polynomial_decay(
            flags.learning_rate,
            global_step=global_step - flags.warmup_steps,
            decay_steps=flags.train_steps - flags.warmup_steps,
            end_learning_rate=flags.learning_rate * flags.min_lr_ratio)
    elif flags.decay_method == "cos":
        decay_lr = tf.train.cosine_decay(
            flags.learning_rate,
            global_step=global_step - flags.warmup_steps,
            decay_steps=flags.train_steps - flags.warmup_steps,
            alpha=flags.min_lr_ratio)
    else:
        raise ValueError(flags.decay_method)

    learning_rate = tf.where(global_step < flags.warmup_steps, warmup_lr,
                             decay_lr)

    if flags.weight_decay == 0:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           epsilon=flags.adam_epsilon)
    elif flags.weight_decay > 0 and flags.num_core_per_host == 1:
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            epsilon=flags.adam_epsilon,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            weight_decay_rate=flags.weight_decay)
    else:
        raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                         "training so far.")

    if flags.use_tpu:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)

    if tvars is None:
        grads_and_vars = optimizer.compute_gradients(total_loss)
    else:
        grads_and_vars = optimizer.compute_gradients(total_loss,
                                                     var_list=tvars)
    gradients, variables = zip(*grads_and_vars)
    clipped, gnorm = tf.clip_by_global_norm(gradients, flags.clip)

    train_op = optimizer.apply_gradients(zip(clipped, variables),
                                         global_step=global_step)

    # Manually increment `global_step` for AdamWeightDecayOptimizer
    if isinstance(optimizer, AdamWeightDecayOptimizer):
        new_global_step = global_step + 1
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])

    if ema is not None:
        # Update the variables with the EMA after the train op.
        with tf.control_dependencies([train_op]):
            train_op = ema.apply(tf.trainable_variables())
    return train_op, learning_rate, gnorm
Example 8
def _model_fn(features, labels, mode, params, model):
    """Model defination for the SSD model based on ResNet-50.

  Args:
    features: the input image tensor with shape [batch_size, height, width, 3].
      The height and width are fixed and equal.
    labels: the input labels in a dictionary. The labels include class targets
      and box targets which are dense label maps. The labels are generated from
      get_input_fn function in data/dataloader.py
    mode: the mode of TPUEstimator including TRAIN, EVAL, and PREDICT.
    params: a dictionary defining the hyperparameters of the model. The default
      settings are in the default_hparams function in this file.
    model: the SSD model, which outputs class logits and box regression outputs.

  Returns:
    spec: the EstimatorSpec or TPUEstimatorSpec to run training, evaluation,
      or prediction.
  """
    if mode == tf.estimator.ModeKeys.PREDICT:
        labels = features
        features = labels.pop('image')

    features -= tf.constant(constants.NORMALIZATION_MEAN,
                            shape=[1, 1, 3],
                            dtype=features.dtype)
    COEF_STD = 1.0 / tf.constant(
        constants.NORMALIZATION_STD, shape=[1, 1, 3], dtype=features.dtype)
    features *= COEF_STD

    def _model_outputs():
        return model(features,
                     params,
                     is_training_bn=(mode == tf.estimator.ModeKeys.TRAIN))

    if params['dtype'] == 'bf16':
        with tf.compat.v1.tpu.bfloat16_scope():
            cls_outputs, box_outputs = _model_outputs()
            levels = cls_outputs.keys()
            for level in levels:
                cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
                box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
    else:
        cls_outputs, box_outputs = _model_outputs()
        levels = cls_outputs.keys()

    # First check if it is in PREDICT mode.
    if mode == tf.estimator.ModeKeys.PREDICT:
        flattened_cls, flattened_box = concat_outputs(cls_outputs, box_outputs,
                                                      True)
        ssd_box_coder = faster_rcnn_box_coder.FasterRcnnBoxCoder(
            scale_factors=constants.BOX_CODER_SCALES)

        anchors = box_list.BoxList(
            tf.convert_to_tensor(dataloader.DefaultBoxes()('ltrb')))

        decoded_boxes = box_coder.batch_decode(encoded_boxes=flattened_box,
                                               box_coder=ssd_box_coder,
                                               anchors=anchors)

        pred_scores = tf.nn.softmax(flattened_cls, axis=2)

        pred_scores, indices = select_top_k_scores(
            pred_scores, constants.MAX_NUM_EVAL_BOXES)
        predictions = dict(
            labels,
            indices=indices,
            pred_scores=pred_scores,
            pred_box=decoded_boxes,
        )

        if params['visualize_dataloader']:
            # this is for inference visualization.
            predictions['image'] = features

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # Load pretrained model from checkpoint.
    if params['resnet_checkpoint'] and mode == tf.estimator.ModeKeys.TRAIN:

        def scaffold_fn():
            """Loads pretrained model through scaffold function."""
            tf.train.init_from_checkpoint(
                params['resnet_checkpoint'], {
                    '/': 'resnet%s/' % constants.RESNET_DEPTH,
                })
            return tf.train.Scaffold()
    else:
        scaffold_fn = None

    # Set up training loss and learning rate.
    update_learning_rate_schedule_parameters(params)
    global_step = tf.train.get_or_create_global_step()
    learning_rate = learning_rate_schedule(params, global_step)
    # cls_loss and box_loss are for logging. only total_loss is optimized.
    loss, cls_loss, box_loss = detection_loss(cls_outputs, box_outputs, labels)

    total_loss = loss + params['weight_decay'] * tf.add_n(
        [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               momentum=constants.MOMENTUM)

        if params['distributed_optimizer']:
            optimizer = params['distributed_optimizer'](optimizer)

        # Batch norm requires update_ops to be added as a train_op dependency.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = tf.group(optimizer.minimize(total_loss, global_step),
                            update_ops)
        scaffold = scaffold_fn() if scaffold_fn else None
        return model_fn_lib.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          scaffold=scaffold)

    if mode == tf.estimator.ModeKeys.EVAL:
        raise NotImplementedError
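
Near the top of this example, the usual input normalization `(features - mean) / std` is written as a subtraction followed by a multiplication with the precomputed reciprocal `COEF_STD`; multiplying by a cached reciprocal is a common way to avoid a per-element division. A small NumPy equivalence check (the mean/std values below are placeholders, not `constants.NORMALIZATION_MEAN/STD`):

import numpy as np

mean = np.array([123.7, 116.3, 103.5], dtype=np.float32).reshape(1, 1, 3)  # placeholder
std = np.array([58.4, 57.1, 57.4], dtype=np.float32).reshape(1, 1, 3)      # placeholder

images = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)
direct = (images - mean) / std
via_reciprocal = (images - mean) * (1.0 / std)
np.testing.assert_allclose(direct, via_reciprocal, rtol=1e-6)
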
Example 9
    def _set_up_cache(self):
        self._lower_offset, update_lower = self._cache_with_update_op(
            self._lower_offset)
        self._upper_offset, update_upper = self._cache_with_update_op(
            self._upper_offset)
        return tf.group([update_lower, update_upper])
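
The `_cache_with_update_op` helper is not included in this snippet. A common reading of the pattern is: materialize a tensor into a non-trainable variable and return both the cached value and the assign op that refreshes it, so that several refreshes can be grouped as above. The sketch below is an assumption about what such a helper might look like, not the actual implementation.

import tensorflow.compat.v1 as tf  # TF1-style graph APIs, matching the snippet above

tf.disable_eager_execution()

def cache_with_update_op(tensor, name):
    """Hypothetical helper: cache `tensor` in a variable, return (cache, refresh_op)."""
    cache = tf.get_variable(name, shape=tensor.shape, dtype=tensor.dtype,
                            trainable=False, initializer=tf.zeros_initializer())
    return cache, tf.assign(cache, tensor)

lower, update_lower = cache_with_update_op(tf.constant([1.0, 2.0]), 'lower_offset')
upper, update_upper = cache_with_update_op(tf.constant([3.0, 4.0]), 'upper_offset')
refresh_op = tf.group(update_lower, update_upper)
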
Example 10
def create_optimizer(loss, init_lr, num_train_steps, num_warmup_steps, use_tpu):
  """Creates an optimizer training op."""
  global_step = tf.train.get_or_create_global_step()

  learning_rate = tf.constant(value=init_lr, shape=[], dtype=tf.float32)

  # Implements linear decay of the learning rate.
  learning_rate = tf.train.polynomial_decay(
      learning_rate,
      global_step,
      num_train_steps,
      end_learning_rate=0.0,
      power=1.0,
      cycle=False)

  # Implements linear warmup. I.e., if global_step < num_warmup_steps, the
  # learning rate will be `global_step/num_warmup_steps * init_lr`.
  if num_warmup_steps:
    global_steps_int = tf.cast(global_step, tf.int32)
    warmup_steps_int = tf.constant(num_warmup_steps, dtype=tf.int32)

    global_steps_float = tf.cast(global_steps_int, tf.float32)
    warmup_steps_float = tf.cast(warmup_steps_int, tf.float32)

    warmup_percent_done = global_steps_float / warmup_steps_float
    warmup_learning_rate = init_lr * warmup_percent_done

    is_warmup = tf.cast(global_steps_int < warmup_steps_int, tf.float32)
    learning_rate = (
        (1.0 - is_warmup) * learning_rate + is_warmup * warmup_learning_rate)

  # It is recommended that you use this optimizer for fine tuning, since this
  # is how the model was trained (note that the Adam m/v variables are NOT
  # loaded from init_checkpoint.)
  optimizer = optimization.AdamWeightDecayOptimizer(
      learning_rate=learning_rate,
      weight_decay_rate=0.01,
      beta_1=0.9,
      beta_2=0.999,
      epsilon=1e-6,
      exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"])

  if use_tpu:
    optimizer = contrib_tpu.CrossShardOptimizer(optimizer)

  tvars = tf.trainable_variables()
  print(tvars)
  tvars = [v for v in tvars if "bert" not in v.name]
  print("no bert")
  print(tvars)
  grads = tf.gradients(loss, tvars)

  # This is how the model was pre-trained.
  (grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)

  train_op = optimizer.apply_gradients(
      zip(grads, tvars), global_step=global_step)

  # Normally the global step update is done inside of `apply_gradients`.
  # However, `AdamWeightDecayOptimizer` doesn't do this. But if you use
  # a different optimizer, you should probably take this line out.
  new_global_step = global_step + 1
  train_op = tf.group(train_op, [global_step.assign(new_global_step)])
  return train_op
Example 11
def test_train(args):
    """Trains the model."""

    if args.verbose:
        tf.logging.set_verbosity(tf.logging.INFO)

    # Create input data pipeline.
    with tf.device("/cpu:0"):
        train_files = glob.glob(args.train_glob)
        if not train_files:
            raise RuntimeError(
                "No training images found with glob '{}'.".format(
                    args.train_glob))
        train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
        train_dataset = train_dataset.shuffle(
            buffer_size=len(train_files)).repeat()
        train_dataset = train_dataset.map(
            read_png, num_parallel_calls=args.preprocess_threads)
        train_dataset = train_dataset.map(
            lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
        train_dataset = train_dataset.batch(args.batchsize)
        train_dataset = train_dataset.prefetch(32)

    num_pixels = args.batchsize * args.patchsize**2

    # Get training patch from dataset.
    x = train_dataset.make_one_shot_iterator().get_next()

    lmbda_log_dist = np.hstack((np.arange(0, 7, 0.01), np.arange(7, 0, -0.01)))
    lmbda_log_dist = tf.constant(lmbda_log_dist, dtype=tf.float32)
    s = tf.data.Dataset.from_tensor_slices(lmbda_log_dist).repeat()
    lmbda_log = s.make_one_shot_iterator().get_next()  # levels
    lmbda = 0.1 * tf.pow(2.0, lmbda_log - 6.0)  # true value

    # Instantiate model.
    analysis_transform = AnalysisTransform(args.num_filters, lmbda_log)
    synthesis_transform = SynthesisTransform(args.num_filters, lmbda_log)
    hyper_analysis_transform = HyperAnalysisTransform(args.num_filters,
                                                      lmbda_log)
    hyper_synthesis_transform = HyperSynthesisTransform(
        args.num_filters, lmbda_log)
    entropy_bottleneck = tfc.EntropyBottleneck()

    # Build autoencoder and hyperprior.
    y = analysis_transform(x)
    z = hyper_analysis_transform(abs(y))
    z_tilde, z_likelihoods = entropy_bottleneck(z, training=True)
    sigma = hyper_synthesis_transform(z_tilde)
    scale_table = np.exp(
        np.linspace(np.log(SCALES_MIN), np.log(SCALES_MAX), SCALES_LEVELS))
    conditional_bottleneck = tfc.GaussianConditional(sigma, scale_table)
    y_tilde, y_likelihoods = conditional_bottleneck(y, training=True)
    x_tilde = synthesis_transform(y_tilde)

    # Total number of bits divided by number of pixels.
    train_bpp = (tf.reduce_sum(tf.log(y_likelihoods)) + tf.reduce_sum(
        tf.log(z_likelihoods))) / (-np.log(2) * num_pixels)

    # Mean squared error across pixels.
    train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))
    # Multiply by 255^2 to correct for rescaling.
    train_mse *= 255**2

    # The rate-distortion cost.
    train_loss = lmbda * train_mse + train_bpp

    # Minimize loss and auxiliary loss, and execute update op.
    step = tf.train.create_global_step()
    main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
    main_step = main_optimizer.minimize(train_loss, global_step=step)

    aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])

    train_op = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])

    tf.summary.scalar("loss", train_loss)
    tf.summary.scalar("bpp", train_bpp)
    tf.summary.scalar("mse", train_mse)

    tf.summary.image("original", quantize_image(x))
    tf.summary.image("reconstruction", quantize_image(x_tilde))

    hooks = [
        tf.train.StopAtStepHook(last_step=args.last_step),
        tf.train.NanTensorHook(train_loss),
    ]
    with tf.train.MonitoredTrainingSession(hooks=hooks,
                                           checkpoint_dir=args.checkpoint_dir,
                                           save_checkpoint_secs=300,
                                           save_summaries_secs=60) as sess:
        while not sess.should_stop():
            sess.run(train_op)
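
The `train_bpp` expression above converts the natural-log likelihoods returned by the entropy models into bits per pixel: `sum(log(p)) / (-log(2) * num_pixels)` is the same quantity as `-sum(log2(p)) / num_pixels`. A tiny NumPy check of that identity with made-up probabilities:

import numpy as np

likelihoods = np.array([0.25, 0.5, 0.125])   # per-symbol probabilities (made up)
num_pixels = 4.0

bpp_a = np.sum(np.log(likelihoods)) / (-np.log(2) * num_pixels)
bpp_b = -np.sum(np.log2(likelihoods)) / num_pixels
np.testing.assert_allclose(bpp_a, bpp_b)     # both equal (2 + 1 + 3) / 4 = 1.5 bits/pixel
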
Example 12
def train(args):
  """Trains the model."""

  if args.verbose:
    tf.logging.set_verbosity(tf.logging.INFO)

  # Create input data pipeline.
  with tf.device("/cpu:0"):
    train_files = glob.glob(args.train_glob)
    if not train_files:
      raise RuntimeError(
          "No training images found with glob '{}'.".format(args.train_glob))
    train_dataset = tf.data.Dataset.from_tensor_slices(train_files)
    train_dataset = train_dataset.shuffle(buffer_size=len(train_files)).repeat()
    train_dataset = train_dataset.map(
        read_png, num_parallel_calls=args.preprocess_threads)
    train_dataset = train_dataset.map(
        lambda x: tf.random_crop(x, (args.patchsize, args.patchsize, 3)))
    train_dataset = train_dataset.batch(args.batchsize)
    train_dataset = train_dataset.prefetch(32)

  #num_pixels = args.batchsize * args.patchsize ** 2

  # Get training patch from dataset.
  x = train_dataset.make_one_shot_iterator().get_next()

  # Instantiate model.
  analysis_transform = AnalysisTransform(args.num_filters)
  #entropy_bottleneck = tfc.EntropyBottleneck()
  synthesis_transform = SynthesisTransform(args.num_filters)

  # Build autoencoder.
  y = analysis_transform(x)
  #y_tilde, likelihoods = entropy_bottleneck(y, training=True)
  x_tilde = synthesis_transform(y)

  # Total number of bits divided by number of pixels.
  #train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels)

  # Mean squared error across pixels.
  train_mse = tf.reduce_mean(tf.squared_difference(x, x_tilde))

  # Multiply by 255^2 to correct for rescaling.
  #train_mse *= 255 ** 2

  # Calculate psnr and ssim
  train_psnr = tf.reduce_mean(tf.image.psnr(x_tilde, x, 1))
  train_msssim_value = tf.reduce_mean(tf.image.ssim_multiscale(x_tilde, x, 1))

  # structural similarity loss
  train_ssim = tf.reduce_mean(1 - tf.image.ssim_multiscale(x_tilde, x, 1))

  #Choose distortion metric
  distortion = train_ssim if args.ssim_loss else train_mse
  
  # The rate-distortion cost.
  train_loss = distortion

  # Minimize loss and auxiliary loss, and execute update op.
  step = tf.train.create_global_step()
  main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
  main_step = main_optimizer.minimize(train_loss, global_step=step)

  #aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
  #aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])

  train_op = tf.group(main_step)

  # Log scalar values
  s_loss = tf.summary.scalar("train/loss", train_loss)
  #s_bpp = tf.summary.scalar("train/bpp", train_bpp)
  s_mse = tf.summary.scalar("train/mse", train_mse)
  s_psnr = tf.summary.scalar("train/psnr", train_psnr)
  s_msssim_value = tf.summary.scalar("train/multiscale ssim value", train_msssim_value)
  s_ssim = tf.summary.scalar("train/multiscale ssim", -10 * tf.log(train_ssim)) 

  # Log training images
  s_original = tf.summary.image("images/original", quantize_image(x))
  s_reconstruction = tf.summary.image("images/reconstruction", quantize_image(x_tilde))

  # Merge scalars into a summary
  train_summary = tf.summary.merge([s_loss, s_mse, s_psnr, s_msssim_value, s_ssim])

  #Merge images into a summary
  image_summary = tf.summary.merge([s_original, s_reconstruction])

  hooks = [
      tf.train.StopAtStepHook(last_step=args.last_step),
      tf.train.NanTensorHook(train_loss),
      tf.train.SummarySaverHook(save_secs=30,output_dir=args.checkpoint_dir,summary_op=train_summary),
      tf.train.SummarySaverHook(save_secs=3600,output_dir=args.checkpoint_dir,summary_op=image_summary)
  ]
  with tf.train.MonitoredTrainingSession(
      hooks=hooks, checkpoint_dir=args.checkpoint_dir,
      save_checkpoint_secs=300, save_summaries_steps=None, save_summaries_secs=None) as sess:
    while not sess.should_stop():
      sess.run(train_op)
Example 13
def model_fn(features, labels, mode, params):
  """The model_fn argument for creating an Estimator."""
  tf.logging.info("features = %s labels = %s mode = %s params=%s" %
                  (features, labels, mode, params))
  global_step = tf.train.get_global_step()
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  logits, loss = mnist_model(features, labels, mesh)
  mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)
  mesh_size = mesh_shape.size
  mesh_devices = [""] * mesh_size
  mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
      mesh_shape, layout_rules, mesh_devices)

  if mode == tf.estimator.ModeKeys.TRAIN:
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    optimizer = mtf.optimize.AdafactorOptimizer()
    update_ops = optimizer.apply_grads(var_grads, graph.trainable_variables)

  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  restore_hook = mtf.MtfRestoreHook(lowering)

  tf_logits = lowering.export_to_tf_tensor(logits)
  if mode != tf.estimator.ModeKeys.PREDICT:
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf.summary.scalar("loss", tf_loss)

  if mode == tf.estimator.ModeKeys.TRAIN:
    tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
    tf_update_ops.append(tf.assign_add(global_step, 1))
    train_op = tf.group(tf_update_ops)
    saver = tf.train.Saver(
        tf.global_variables(),
        sharded=True,
        max_to_keep=10,
        keep_checkpoint_every_n_hours=2,
        defer_build=False, save_relative_paths=True)
    tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
    saver_listener = mtf.MtfCheckpointSaverListener(lowering)
    saver_hook = tf.train.CheckpointSaverHook(
        FLAGS.model_dir,
        save_steps=1000,
        saver=saver,
        listeners=[saver_listener])

    accuracy = tf.metrics.accuracy(
        labels=labels, predictions=tf.argmax(tf_logits, axis=1))

    # Name tensors to be logged with LoggingTensorHook.
    tf.identity(tf_loss, "cross_entropy")
    tf.identity(accuracy[1], name="train_accuracy")

    # Save accuracy scalar to Tensorboard output.
    tf.summary.scalar("train_accuracy", accuracy[1])

    # restore_hook must come before saver_hook
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN, loss=tf_loss, train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])

  if mode == tf.estimator.ModeKeys.PREDICT:
    predictions = {
        "classes": tf.argmax(tf_logits, axis=1),
        "probabilities": tf.nn.softmax(tf_logits),
    }
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions=predictions,
        prediction_hooks=[restore_hook],
        export_outputs={
            "classify": tf.estimator.export.PredictOutput(predictions)
        })
  if mode == tf.estimator.ModeKeys.EVAL:
    return tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.EVAL,
        loss=tf_loss,
        evaluation_hooks=[restore_hook],
        eval_metric_ops={
            "accuracy":
            tf.metrics.accuracy(
                labels=labels, predictions=tf.argmax(tf_logits, axis=1)),
        })
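
A usage note on the `tf.identity(..., name=...)` calls above: giving the tensors stable names lets a `tf.train.LoggingTensorHook` pick them up by name. The hook below is not part of the original snippet; it is a sketch of how those names would typically be consumed.

import tensorflow.compat.v1 as tf  # TF1-style APIs, matching the snippet above

# "cross_entropy" and "train_accuracy" are the names assigned via tf.identity above.
logging_hook = tf.train.LoggingTensorHook(
    tensors={"loss": "cross_entropy", "accuracy": "train_accuracy"},
    every_n_iter=100)
# It would be passed alongside the other hooks, e.g. training_hooks=[logging_hook]
# in addition to training_chief_hooks=[restore_hook, saver_hook].
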
Example 14
    def test_multitower_examples_model(self):
        """Ensure graph search runs properly on a multitower setup.

    This test uses linear_model from examples/convnets.
    """
        with tf.Graph().as_default():

            def linear_model(images, labels, num_classes):
                """Creates a linear model.

        Args:
          images: The input image tensors, a tensor of size
              (batch_size x height_in x width_in x channels).
          labels: The sparse target labels, a tensor of size (batch_size x 1).
          num_classes: The number of classes, needed for one-hot encoding (int).

        Returns:
          loss: The total loss for this model (0-D tensor).
          logits: Predictions for this model (batch_size x num_classes).
        """
                images = tf.reshape(images, [images.shape[0], -1])
                logits = tf.layers.dense(images, num_classes, name='logits')
                loss = sparse_softmax_cross_entropy(labels, logits,
                                                    num_classes)
                return loss, logits

            model = linear_model
            layer_collection = lc.LayerCollection()
            num_towers = 2
            batch_size = num_towers
            num_classes = 2

            # Set up data.
            images = tf.random_uniform(shape=[batch_size, 32, 32, 1])
            labels = tf.random_uniform(dtype=tf.int64,
                                       shape=[batch_size, 1],
                                       maxval=num_classes)

            tower_images = tf.split(images, num_towers)
            tower_labels = tf.split(labels, num_towers)

            # Build model.
            losses = []
            logits = []
            for tower_id in range(num_towers):
                tower_name = 'tower%d' % tower_id
                with tf.name_scope(tower_name):
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=(tower_id > 0)):
                        current_loss, current_logits = model(
                            tower_images[tower_id], tower_labels[tower_id],
                            num_classes + 1)
                        layer_collection.register_categorical_predictive_distribution(
                            current_logits, name='logits')
                        losses.append(current_loss)
                        logits.append(current_logits)

            # Run the graph scanner.
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                gs.register_layers(layer_collection, tf.trainable_variables())
            self.assertEqual(len(layer_collection.fisher_blocks), 1)
            fisher_block = list(layer_collection.fisher_blocks.values())[0]
            self.assertIsInstance(fisher_block, fb.FullyConnectedKFACBasicFB)
            self.assertEqual(fisher_block.num_registered_towers, num_towers)

            global_step = tf.train.get_or_create_global_step()
            opt = optimizer.KfacOptimizer(learning_rate=0.1,
                                          cov_ema_decay=0.1,
                                          damping=0.1,
                                          layer_collection=layer_collection,
                                          momentum=0.1)
            cost = tf.reduce_mean(losses)
            (cov_update_thunks,
             inv_update_thunks) = opt.make_vars_and_create_op_thunks()
            cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
            inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))
            train_op = opt.minimize(cost, global_step=global_step)
            init = tf.global_variables_initializer()

            # Run a single training step.
            with self.test_session() as sess:
                sess.run(init)
                sess.run([cov_update_op])
                sess.run([inv_update_op])
                sess.run([train_op])
Example 15
  def benchmark_model(self,
                      warmup_runs,
                      bm_runs,
                      num_threads,
                      trace_filename=None):
    """Benchmark model."""
    if self.tensorrt:
      print('Using tensorrt ', self.tensorrt)
      graphdef = self.freeze_model()

    if num_threads > 0:
      print('num_threads for benchmarking: {}'.format(num_threads))
      sess_config = tf.ConfigProto(
          intra_op_parallelism_threads=num_threads,
          inter_op_parallelism_threads=1)
    else:
      sess_config = tf.ConfigProto()

    # Disable the dependency optimizer (2 == rewriter_config_pb2.RewriterConfig.OFF).
    sess_config.graph_options.rewrite_options.dependency_optimization = 2
    if self.use_xla:
      sess_config.graph_options.optimizer_options.global_jit_level = (
          tf.OptimizerOptions.ON_2)

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
      inputs = tf.placeholder(tf.float32, name='input', shape=self.inputs_shape)
      output = self.build_model(inputs, is_training=False)

      img = np.random.uniform(size=self.inputs_shape)

      sess.run(tf.global_variables_initializer())
      if self.tensorrt:
        fetches = [inputs.name] + [i.name for i in output]
        goutput = self.convert_tr(graphdef, fetches)
        inputs, output = goutput[0], goutput[1:]

      if not self.use_xla:
        # Don't use tf.group because XLA removes the whole graph for tf.group.
        output = tf.group(*output)
      else:
        output = tf.add_n([tf.reduce_sum(x) for x in output])

      output_name = [output.name]
      input_name = inputs.name
      graphdef = tf.graph_util.convert_variables_to_constants(
          sess, sess.graph_def, output_name)

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
      tf.import_graph_def(graphdef, name='')

      for i in range(warmup_runs):
        start_time = time.time()
        sess.run(output_name, feed_dict={input_name: img})
        logging.info('Warm up: {} {:.4f}s'.format(i, time.time() - start_time))

      print('Start benchmark runs total={}'.format(bm_runs))
      start = time.perf_counter()
      for i in range(bm_runs):
        sess.run(output_name, feed_dict={input_name: img})
      end = time.perf_counter()
      inference_time = (end - start) / bm_runs
      print('Per batch inference time: ', inference_time)
      print('FPS: ', self.batch_size / inference_time)

      if trace_filename:
        run_options = tf.RunOptions()
        run_options.trace_level = tf.RunOptions.FULL_TRACE
        run_metadata = tf.RunMetadata()
        sess.run(
            output_name,
            feed_dict={input_name: img},
            options=run_options,
            run_metadata=run_metadata)
        logging.info('Dumping trace to %s', trace_filename)
        trace_dir = os.path.dirname(trace_filename)
        if not tf.io.gfile.exists(trace_dir):
          tf.io.gfile.makedirs(trace_dir)
        with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          trace_file.write(trace.generate_chrome_trace_format(show_memory=True))
Example 16
def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)
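
A usage note: the thunks are zero-argument callables, so the helper simply calls each one and groups the resulting ops. The self-contained sketch below repeats the helper so it runs on its own (TF1-style graph mode; the variables are illustrative):

import tensorflow.compat.v1 as tf  # TF1-style graph APIs, matching the snippet above

tf.disable_eager_execution()

def make_update_op(update_thunks):
    update_ops = [thunk() for thunk in update_thunks]
    return tf.group(*update_ops)

v = tf.get_variable('v', initializer=0.0)
w = tf.get_variable('w', initializer=0.0)
update_op = make_update_op([lambda: tf.assign_add(v, 1.0),
                            lambda: tf.assign_add(w, 2.0)])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(update_op)
    print(sess.run([v, w]))  # [1.0, 2.0]
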
Example 17
def build_all_reduce_iterations(all_device_tensors, tower_devices,
                                variable_mgr, num_iters):
    """Builds the all-reduce ops for multiple iterations to aggregate tensors.

  The tensors in `all_device_tensors` are aggregated `num_iters` times. Each
  iteration aggregates the results from the previous iteration. The iterations
  are run sequentially, so the aggregations for an iteration do not start
  running until the previous iteration has completed. Each iteration after the
  first is aggregating already-aggregated values, but it does not matter because
  we are only aggregating for benchmarking purposes.

  Args:
    all_device_tensors: List of lists of tensors. all_device_tensors[t][i] is
      a tensor, where t is the tower the tensor is on and i is the index of
      the tensor.
    tower_devices: A list of device strings. tower_devices[t] is the device
      of the tensors in all_device_tensors[t].
    variable_mgr: The VariableMgr to perform the all-reduce.
    num_iters: Number of iterations to aggregate tensors for.
  Returns:
    An op that when run, causes the all-reduce ops to run.
  """
    for i in range(num_iters):
        with tf.name_scope('iteration_%d' % i):
            # Step 1: Do the aggregation.
            with tf.name_scope('tensor_aggregation'):
                all_device_tensors = all_reduce(all_device_tensors,
                                                variable_mgr)

            # Step 2. Create identity ops, to bring the aggregated results back to
            # each device.
            new_all_device_tensors = []
            for device, device_tensors in zip(tower_devices,
                                              all_device_tensors):
                with tf.device(device):
                    new_all_device_tensors.append([
                        tf.identity(t, name='identity_after_allreduce')
                        for t in device_tensors
                    ])
            all_device_tensors = new_all_device_tensors

            # Step 3. Add control dependencies to delay the next iteration until this
            # iteration is complete. To avoid extra overhead, we do not have any
            # cross-device control dependencies, which means it's possible for two
            # iterations to slightly overlap.
            new_all_device_tensors = []
            for device_tensors in all_device_tensors:
                new_all_device_tensors.append([
                    control_flow_ops.with_dependencies(
                        device_tensors, t, name='identity_after_dependencies')
                    for t in device_tensors
                ])
            all_device_tensors = new_all_device_tensors

    # To prevent the dependency optimizer from removing every op we created,
    # we store the results in variables.
    ops_to_run = []
    for device, device_tensors in zip(tower_devices, all_device_tensors):
        with tf.device(device):
            for t in device_tensors:
                # The placeholder initial value is never run.
                var = tf.Variable(tf.placeholder(tf.float32, t.shape),
                                  collections=[])
                ops_to_run.append(var.assign(t))
    return tf.group(*ops_to_run)
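
The all_reduce() helper and VariableMgr referenced in the docstring above belong to the benchmark harness and are not shown here. As a rough usage sketch, assuming TensorFlow 1.x and a naive sum-across-towers stand-in for the real all-reduce, the same chained-iteration structure looks like this:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

tower_devices = ['/cpu:0', '/cpu:0']  # hypothetical; normally one device per tower

def naive_all_reduce(all_device_tensors):
    """Sums tensor i across towers and hands the sum back to every tower."""
    num_tensors = len(all_device_tensors[0])
    summed = [tf.add_n([dev[i] for dev in all_device_tensors])
              for i in range(num_tensors)]
    return [list(summed) for _ in all_device_tensors]

all_device_tensors = [[tf.ones([4]), tf.ones([2])] for _ in tower_devices]
for i in range(3):  # three chained aggregation iterations
    with tf.name_scope('iteration_%d' % i):
        all_device_tensors = naive_all_reduce(all_device_tensors)
        # tf.tuple stands in for the with_dependencies() barrier used above.
        all_device_tensors = [
            [tf.identity(t) for t in tf.tuple(device_tensors)]
            for device_tensors in all_device_tensors
        ]

with tf.Session() as sess:
    # With two towers, each iteration doubles the values: ones -> 8s after 3 rounds.
    print(sess.run(all_device_tensors[0][0]))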
Example n. 18
def make_batch_executed_op(update_thunks, batch_size=1):
    return tf.group(*kfac.utils.batch_execute(
        global_step, update_thunks, batch_size=batch_size))
Example n. 19
    def __init__(self, config):

        if config.dataset.dir:
            # Gets the names of the classes
            classes_file = os.path.join(config.dataset.dir, 'classes.json')
            if tf.gfile.Exists(classes_file):
                self.class_labels = json.load(tf.gfile.GFile(classes_file))
            else:
                self.class_labels = None

        # Don't use data augmentation in predictions
        config.dataset.data_augmentation = None

        dataset_class = get_dataset(config.dataset.type)
        model_class = get_model(config.model.type)
        dataset = dataset_class(config)
        model = model_class(config)

        graph = tf.Graph()
        self.session = tf.Session(graph=graph)

        with graph.as_default():
            self.image_placeholder = tf.placeholder(tf.float32,
                                                    (None, None, 3))
            image_tf, _, process_meta = dataset.preprocess(
                self.image_placeholder)
            pred_dict = model(image_tf)

            # Restore checkpoint
            if config.train.job_dir:
                job_dir = config.train.job_dir
                if config.train.run_name:
                    job_dir = os.path.join(job_dir, config.train.run_name)
                ckpt = tf.train.get_checkpoint_state(job_dir)
                if not ckpt or not ckpt.all_model_checkpoint_paths:
                    raise ValueError(
                        'Could not find checkpoint in {}.'.format(job_dir))
                ckpt = ckpt.all_model_checkpoint_paths[-1]
                saver = tf.train.Saver(sharded=True, allow_empty=True)
                saver.restore(self.session, ckpt)
                tf.logging.info('Loaded checkpoint.')
            else:
                # A prediction without checkpoint is just used for testing
                tf.logging.warning(
                    'Could not load checkpoint. Using initialized model.')
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())
                self.session.run(init_op)

            if config.model.type == 'ssd':
                cls_prediction = pred_dict['classification_prediction']
                objects_tf = cls_prediction['objects']
                objects_labels_tf = cls_prediction['labels']
                objects_labels_prob_tf = cls_prediction['probs']
            elif config.model.type == 'fasterrcnn':
                if config.model.network.get('with_rcnn', False):
                    cls_prediction = pred_dict['classification_prediction']
                    objects_tf = cls_prediction['objects']
                    objects_labels_tf = cls_prediction['labels']
                    objects_labels_prob_tf = cls_prediction['probs']
                else:
                    rpn_prediction = pred_dict['rpn_prediction']
                    objects_tf = rpn_prediction['proposals']
                    objects_labels_prob_tf = rpn_prediction['scores']
                    # All labels without RCNN are zero
                    objects_labels_tf = tf.zeros(
                        tf.shape(objects_labels_prob_tf), dtype=tf.int32)
            else:
                raise ValueError("Model type '{}' not supported".format(
                    config.model.type))

            self.fetches = {
                'objects': objects_tf,
                'labels': objects_labels_tf,
                'probs': objects_labels_prob_tf,
                'scale_factor': process_meta['scale_factor']
            }

            # If in debug mode, return the full prediction dictionary.
            if config.train.debug:
                self.fetches['_debug'] = pred_dict
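
A condensed sketch of the restore-or-initialize pattern used in the __init__ above, assuming TensorFlow 1.x; the job_dir path and the single variable are hypothetical placeholders:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

job_dir = '/tmp/hypothetical_job_dir'       # assumption: any directory works here
x = tf.Variable(tf.zeros([3]), name='x')    # stand-in for the real model variables

session = tf.Session()
ckpt = tf.train.get_checkpoint_state(job_dir)
if ckpt and ckpt.all_model_checkpoint_paths:
    saver = tf.train.Saver(sharded=True, allow_empty=True)
    saver.restore(session, ckpt.all_model_checkpoint_paths[-1])
else:
    # Without a checkpoint, fall back to freshly initialized variables.
    tf.logging.warning('No checkpoint found; using initialized model.')
    session.run(tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer()))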
Example n. 20
def update_op_if_nan_or_inf():
    """Update loss_scale and discard gradients if nans/infs occurred."""
    return tf.group(tf.assign(loss_scale, loss_scale / 2.),
                    tf.assign(loss_scale_normal_steps, 0))
Example n. 21
    if MODE == 'wgan':
        gen_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
            gen_cost,
            var_list=lib.params_with_name('Generator'),
            colocate_gradients_with_ops=True)
        disc_train_op = tf.train.RMSPropOptimizer(learning_rate=5e-5).minimize(
            disc_cost,
            var_list=lib.params_with_name('Discriminator.'),
            colocate_gradients_with_ops=True)

        clip_ops = []
        for var in lib.params_with_name('Discriminator'):
            clip_bounds = [-.01, .01]
            clip_ops.append(
                tf.assign(
                    var, tf.clip_by_value(var, clip_bounds[0],
                                          clip_bounds[1])))
        clip_disc_weights = tf.group(*clip_ops)

    elif MODE == 'wgan-gp':
        gen_train_op = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0.,
            beta2=0.9).minimize(gen_cost,
                                var_list=lib.params_with_name('Generator'),
                                colocate_gradients_with_ops=True)
        disc_train_op = tf.train.AdamOptimizer(
            learning_rate=1e-4, beta1=0., beta2=0.9).minimize(
                disc_cost,
                var_list=lib.params_with_name('Discriminator.'),
                colocate_gradients_with_ops=True)

    elif MODE == 'dcgan':
        gen_train_op = tf.train.AdamOptimizer(
Example n. 22
def update_op_if_no_nan_or_inf():
    """Apply gradients, and update loss scaling."""
    return tf.group(
        get_loss_scale_update_op(loss_scale, loss_scale_normal_steps,
                                 inc_loss_scale_every_n),
        *get_apply_gradients_ops_func())
Example n. 23
    def my_model_fn(features, labels, mode, params=None, config=None):
        """Estimator model function.
        Args:
          features: dictionary where keys are strings like "inputs" and "targets"
            and the values are the actual values of "inputs". See TPUEstimator's
            docs for more information
          labels: ignored argument
          mode: a tf.estimator.ModeKeys
          params: dictionary containing the key "context"
          config: ignored argument
        Returns:
          a TPUEstimatorSpec
        """
        del labels, config
        global_step = tf.train.get_global_step()
        if use_tpu and "context" in params:
            ctx = params["context"]
            num_hosts = ctx.num_hosts
            host_placement_fn = ctx.tpu_host_placement_function
            device_list = [
                host_placement_fn(host_id=t) for t in range(num_hosts)
            ]
            # TODO(ylc): Better estimation of replica cache size?
            replica_cache_size = 300 * 1000000  # 300M per replica
            # Worker 0 caches all the TPU binaries.
            worker0_mem = replica_cache_size * ctx.num_replicas
            devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
            var_placer = mtf.utils.BalancedVariablePlacer(
                device_list, devices_memory_usage)
            mesh_devices = [""] * mesh_shape.size
            physical_shape = list(
                params["context"].device_assignment.topology.mesh_shape)
            logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
                mesh_shape.to_integer_list, physical_shape)
            mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
                mesh_shape,
                layout_rules,
                mesh_devices,
                ctx.device_assignment,
                logical_to_physical=logical_to_physical)
        else:
            var_placer = None
            mesh_devices = [""] * mesh_shape.size
            mesh_impl = mtf.placement_mesh_impl.PlacementMeshImpl(
                mesh_shape, layout_rules, mesh_devices)

        graph = mtf.Graph()
        mesh = mtf.Mesh(graph, "my_mesh", var_placer)

        mtf_features = {}
        for key, x in features.items():
            outer_batch_dim = mtf.Dimension("outer_batch", outer_batch_size)
            batch_dim = mtf.Dimension("batch", batch_size // outer_batch_size)
            # Some auxiliary features may have been generated in packing.
            # The names of these new features are of the form
            #   "<original_feature_name>_<suffix>", e.g. "inputs_segmentation".
            #   We look up the lengths based on the original feature name, without
            #   the "_<suffix>".
            feature_length = sequence_length[key.split("_")[0]]
            length_dim = mtf.Dimension("length", feature_length)
            ensemble_dims = ([mtf.Dimension("ensemble", ensemble_inputs)]
                             if ensemble_inputs else [])
            feature_shape = mtf.Shape(ensemble_dims +
                                      [outer_batch_dim, batch_dim, length_dim])
            x = tf.cast(features[key], tf.int32)
            x = tf.reshape(x, feature_shape.to_integer_list)
            if not use_tpu:
                tf.logging.info("feature %s : %s" % (key, x))
                x = tf.Print(x, [x],
                             "import feature %s" % key,
                             summarize=1000,
                             first_n=10)
            mtf_features[key] = mtf.import_fully_replicated(mesh,
                                                            x,
                                                            feature_shape,
                                                            name=key)
            if key == "targets" or key == "codeprefixedtargets" or key == "controlcode":
                anon_targets = mtf.anonymize(mtf_features[key])

        if mode == tf.estimator.ModeKeys.PREDICT:

            def _feature_shape(key):
                feature_length = sequence_length[key.split("_")[0]]
                return mtf.Shape([
                    mtf.Dimension("batch", batch_size),
                    mtf.Dimension("length", feature_length)
                ])

            mtf_features = {
                k: mtf.reshape(v, _feature_shape(k))
                for k, v in six.iteritems(mtf_features)
            }
            inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if has_partial_sequences:
                controlcodes = mtf_features["controlcode"]
            else:
                controlcodes = None

            if predict_fn:
                mtf_samples = predict_fn(model=transformer_model,
                                         features=mtf_features,
                                         variable_dtype=get_variable_dtype())
            elif isinstance(transformer_model, transformer.Unitransformer):
                # pad so that there is enough room for the targets
                inputs = mtf.pad(inputs, [0, sequence_length["targets"]],
                                 length_dim.name)
                mtf_samples = transformer_model.sample_autoregressive(
                    inputs,
                    variable_dtype=get_variable_dtype(),
                    remove_partial_sequences=True)
            elif isinstance(transformer_model, Bitransformer_ll):
                mtf_samples = transformer_model.decode(
                    inputs,
                    attributes=attributes,
                    controlcodes=controlcodes,
                    has_partial_sequences=has_partial_sequences,
                    remove_partial_sequences=remove_partial_sequences,
                    variable_dtype=get_variable_dtype())
            elif isinstance(
                    transformer_model,
                (transformer.Bitransformer, transformer.StudentTeacher)):
                mtf_samples = transformer_model.decode(
                    inputs, variable_dtype=get_variable_dtype())
            else:
                raise ValueError("unrecognized class")
            mtf_samples = mtf.anonymize(mtf_samples)
            inputs = mtf.anonymize(inputs)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            inputs = lowering.export_to_tf_tensor(inputs)
            outputs = lowering.export_to_tf_tensor(mtf_samples)
            predictions = {"inputs": inputs, "outputs": outputs}

            # When exporting a model, we need to communicate to TF-Serving that
            # master variables need to be copied to their slave slice variables.
            # Estimator uses a Scaffold's "local_init_op" for this purpose, so we
            # augment the default "local_init_op" here.
            #
            # The "ready_op" is also constructed here to ensure the variables
            # initialized by "local_init_op" are the same ones checked by "ready_op".
            #
            # WARNING: Any variables created outside of this model_fn()
            # (e.g. tpu_estimator/iterations_per_loop) will NOT be initialized nor
            # checked by these ops.
            def scaffold_fn():
                return tf.train.Scaffold(
                    local_init_op=tf.group(
                        tf.train.Scaffold.default_local_init_op(),
                        lowering.copy_masters_to_slices(),
                        name="mtf_local_init_op"),
                    ready_op=tf.concat([
                        tf.report_uninitialized_variables(),
                        resources.report_uninitialized_resources()
                    ],
                                       axis=0,
                                       name="mtf_ready_op"))

            return tpu_estimator.TPUEstimatorSpec(
                mode=tf.estimator.ModeKeys.PREDICT,
                predictions=predictions,
                scaffold_fn=scaffold_fn,
                prediction_hooks=[mtf.MtfRestoreHook(lowering)])

        assert (mode == tf.estimator.ModeKeys.TRAIN
                or mode == tf.estimator.ModeKeys.EVAL)

        def logits_and_loss(mtf_features):
            """Compute logits and loss.
            Args:
              mtf_features: a dictionary
            Returns:
              logits: a mtf.Tensor
              loss: a mtf.Tensor
            """
            if model_type == "lm":  # TOTRY Adapt that to our case
                if "inputs" in mtf_features:
                    mtf_features = _dynamic_text2self(mtf_features)
                _, _, length_dim = mtf_features["targets"].shape
                inputs = mtf.shift(mtf_features["targets"],
                                   offset=1,
                                   dim=length_dim,
                                   wrap=False)
            else:
                inputs = mtf_features["inputs"]

            if attribute_embedding:
                attributes = mtf_features["attribute"]
            else:
                attributes = None

            if control_codes:
                codeprefixedtargets = mtf_features["codeprefixedtargets"]
            else:
                codeprefixedtargets = None

            if isinstance(transformer_model, transformer.Unitransformer):
                position_kwargs = dict(
                    sequence_id=mtf_features.get("targets_segmentation", None),
                    position=mtf_features.get("targets_position", None),
                )
            elif isinstance(transformer_model, transformer.Bitransformer
                            ) or model_type == "bi_student_teacher":
                if control_codes:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "codeprefixedtargets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "codeprefixedtargets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "codeprefixedtargets_position", None),
                    )
                else:
                    position_kwargs = dict(
                        encoder_sequence_id=mtf_features.get(
                            "inputs_segmentation", None),
                        decoder_sequence_id=mtf_features.get(
                            "targets_segmentation", None),
                        decoder_subsequence_id=mtf_features.get(
                            "targets_subsegmentation", None),
                        encoder_position=mtf_features.get(
                            "inputs_position", None),
                        decoder_position=mtf_features.get(
                            "targets_position", None),
                    )
            else:
                raise ValueError("unrecognized class")

            if isinstance(transformer_model, Bitransformer_ll):
                if cycle_consistency_loss:
                    logits_ae, l_ae = transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    if has_partial_sequences:
                        controlcodes = mtf_features["controlcode"]
                    else:
                        controlcodes = None

                    with gin.config_scope('training'):
                        mtf_samples = transformer_model.decode(
                            inputs,
                            attributes=attributes,
                            controlcodes=controlcodes,
                            has_partial_sequences=has_partial_sequences,
                            remove_partial_sequences=remove_partial_sequences,
                            variable_dtype=get_variable_dtype())
                        # mtf_samples = mtf.anonymize(mtf_samples)
                    outputs = mtf_samples

                    logits_cycle, l_cycle = transformer_model.call_simple(
                        inputs=outputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)

                    loss_ae_cycle = lambda_ae * l_ae + lambda_cycle * l_cycle
                    return logits_cycle, loss_ae_cycle
                else:
                    return transformer_model.call_simple(
                        inputs=inputs,
                        targets=mtf_features["targets"],
                        compute_loss=True,
                        attributes=attributes,
                        codeprefixedtargets=codeprefixedtargets,
                        mode=mode,
                        variable_dtype=get_variable_dtype(),
                        **position_kwargs)
            else:
                return transformer_model.call_simple(
                    inputs=inputs,
                    targets=mtf_features["targets"],
                    compute_loss=True,
                    mode=mode,
                    variable_dtype=get_variable_dtype(),
                    num_microbatches=num_microbatches,
                    **position_kwargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            num_microbatches = serialize_num_microbatches(
                batch_dim, sequence_length, mesh_shape, layout_rules)
            if num_microbatches > 1:

                def serialized_fn(mtf_features):
                    return {
                        "loss":
                        (logits_and_loss(mtf_features)[1] / num_microbatches)
                    }

                var_grads, loss_dict = mtf.serialize_training_step(
                    mtf_features, serialized_fn, batch_dim, num_microbatches)
                loss = loss_dict["loss"]
            else:
                loss = logits_and_loss(mtf_features)[1]
                var_grads = mtf.gradients(
                    [loss], [v.outputs[0] for v in graph.trainable_variables])

            if tpu_summaries:
                mtf.scalar_summary("loss", loss)

            if callable(learning_rate_schedule):
                # the following happens on CPU since TPU can't handle summaries.
                with mtf.utils.outside_all_rewrites():
                    learning_rate = learning_rate_schedule(
                        step=tf.train.get_global_step())
                    tf.summary.scalar("learning_rate", learning_rate)
            else:
                learning_rate = learning_rate_schedule

            if isinstance(variable_filter, str):
                pattern = re.compile(variable_filter)
                variable_filter_fn = lambda v: pattern.search(v.name)
            elif variable_filter is None:
                variable_filter_fn = lambda v: True
            elif callable(variable_filter):
                variable_filter_fn = variable_filter
            else:
                raise ValueError(
                    "variable_filter must be None, a string, or a callable function"
                )
            trainable_vars = [
                v for v in graph.trainable_variables if variable_filter_fn(v)
            ]
            trainable_var_grads = [
                g for g, v in zip(var_grads, graph.trainable_variables)
                if variable_filter_fn(v)
            ]
            if len(trainable_vars) != len(graph.trainable_variables):
                tf.logging.info("Variables being trained:")
                tf.logging.info([v.name for v in trainable_vars])
                tf.logging.info("Variables not being trained:")
                tf.logging.info([
                    v.name for v in graph.trainable_variables
                    if not variable_filter_fn(v)
                ])

            update_ops = optimizer(learning_rate=learning_rate).apply_grads(
                trainable_var_grads, trainable_vars)

            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)

            tf_loss = lowering.export_to_tf_tensor(loss)
            tf_loss = tf.cast(tf_loss, tf.float32)
            if not use_tpu:
                tf_loss = tf.Print(
                    tf_loss, [tf_loss, tf.train.get_global_step()],
                    "step, tf_loss")

            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            train_op = tf.group(tf_update_ops)

            if hasattr(transformer_model, "initialize"):
                with mtf.utils.outside_all_rewrites():
                    transformer_model.initialize()

            if tpu_summaries:
                # has to be outside of
                # with mtf.utils.outside_all_rewrites()
                host_call = mtf.utils.create_host_call(model_dir)
                mtf.utils.remove_summaries()
            else:
                host_call = None

            with mtf.utils.outside_all_rewrites():

                if init_checkpoint:
                    ckpt_vars = {
                        v
                        for v, _ in tf.train.list_variables(init_checkpoint)
                    }
                    global_vars = {v.op.name for v in tf.global_variables()}
                    restore_vars = ckpt_vars.intersection(global_vars)
                    tf.logging.info("Initializing variables from %s:",
                                    init_checkpoint)
                    tf.logging.debug("\n".join(sorted(restore_vars)))
                    tf.logging.info("Variables in %s but not in graph:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(ckpt_vars - global_vars)))
                    tf.logging.info("Variables in graph but not in %s:",
                                    init_checkpoint)
                    tf.logging.info("\n".join(sorted(global_vars - ckpt_vars)))
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  {v: v
                                                   for v in restore_vars})

                # Copy master variables to slices. Must be called first.
                restore_hook = mtf.MtfRestoreHook(lowering)
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=keep_checkpoint_max,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    model_dir,
                    save_steps=save_checkpoints_steps,
                    saver=saver,
                    listeners=[saver_listener])
                gin_config_saver_hook = gin.tf.GinConfigSaverHook(
                    model_dir,
                    summarize_config=True,
                    include_step_in_filename=False)

                if use_tpu:
                    return tpu_estimator.TPUEstimatorSpec(
                        mode=tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        host_call=host_call,
                        training_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
                else:
                    return tf.estimator.EstimatorSpec(
                        tf.estimator.ModeKeys.TRAIN,
                        loss=tf_loss,
                        train_op=train_op,
                        training_chief_hooks=[
                            restore_hook,
                            saver_hook,
                            gin_config_saver_hook,
                        ])
        elif mode == tf.estimator.ModeKeys.EVAL:
            logits, loss = logits_and_loss(mtf_features)
            anon_logits = mtf.anonymize(logits)
            lowering = mtf.Lowering(graph, {mesh: mesh_impl},
                                    autostack=autostack)
            tf_loss = tf.cast(lowering.export_to_tf_tensor(loss), tf.float32)
            tf_logits = tf.cast(lowering.export_to_tf_tensor(anon_logits),
                                tf.float32)

            def simple_metrics(logits, labels):
                """Simple metrics for teacher-forced eval."""
                weights = tf.cast(tf.not_equal(labels, 0), tf.float32)
                xent = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=logits)
                predictions = tf.cast(tf.argmax(logits, axis=-1), labels.dtype)
                token_correct = tf.cast(tf.equal(predictions, labels),
                                        tf.float32) * weights
                sequence_correct = tf.to_float(
                    tf.equal(tf.reduce_sum(token_correct, -1),
                             tf.reduce_sum(weights, -1)))
                sequence_weights = tf.to_float(
                    tf.not_equal(tf.reduce_sum(weights, -1), 0))
                return {
                    "neg_log_perplexity":
                    tf.metrics.mean(-xent, weights),
                    "token_accuracy":
                    tf.metrics.mean(token_correct, weights),
                    "sequence_accuracy":
                    tf.metrics.mean(sequence_correct, sequence_weights)
                }

            labels = lowering.export_to_tf_tensor(anon_targets)
            eval_metrics = (simple_metrics, [tf_logits, labels])
            with mtf.utils.outside_all_rewrites():
                restore_hook = mtf.MtfRestoreHook(lowering)
            return tpu_estimator.TPUEstimatorSpec(
                tf.estimator.ModeKeys.EVAL,
                evaluation_hooks=[restore_hook],
                loss=tf_loss,
                eval_metrics=eval_metrics)
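
The PREDICT branch above augments the Scaffold's local_init_op and ready_op so that serving copies master variables into their slice variables before the model is reported ready. A generic sketch of the same Scaffold trick, assuming TensorFlow 1.x and with a plain assign standing in for lowering.copy_masters_to_slices():

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import resources
tf.disable_v2_behavior()

master = tf.Variable(tf.ones([2]), name='master')
replica = tf.Variable(tf.zeros([2]), name='replica', collections=[])
extra_copy_op = replica.assign(master)   # hypothetical stand-in for copy_masters_to_slices()

scaffold = tf.train.Scaffold(
    local_init_op=tf.group(tf.train.Scaffold.default_local_init_op(),
                           extra_copy_op,
                           name='local_init_op'),
    ready_op=tf.concat([tf.report_uninitialized_variables(),
                        resources.report_uninitialized_resources()],
                       axis=0, name='ready_op'))

with tf.train.MonitoredSession(
        session_creator=tf.train.ChiefSessionCreator(scaffold=scaffold)) as sess:
    print(sess.run(replica))   # [1., 1.] because local_init_op ran the copy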
Example n. 24
def increment_loss_scale_normal_steps_func():
    return tf.group(loss_scale_normal_steps.assign_add(1))
Example n. 25
    def _compute_inner_update_onsbet(self, var, grad):
        update_ops = []

        eta = tf.cast(self.eta, var.dtype.base_dtype)
        betting_domain = tf.cast(self.betting_domain, var.dtype.base_dtype)

        wealth = self.get_slot(var, INNER_WEALTH)
        betting_fraction = self.get_slot(var, OUTER_BETTING_FRACTION)
        inner_betting_fraction = self.get_slot(var, INNER_BETTING_FRACTION)
        sum_grad_squared = self.get_slot(var, INNER_SUM_GRAD_SQUARED)
        inner_maximum_gradient = self.get_slot(var, INNER_MAXIMUM_GRADIENT)

        inner_maximum_gradient_updated = self._assign(
            inner_maximum_gradient,
            tf.maximum(inner_maximum_gradient, tf.abs(grad)))
        update_ops.append(inner_maximum_gradient_updated)

        clipped_old_betting_fraction = tf.clip_by_value(
            betting_fraction, -betting_domain, betting_domain)

        # Process grad to respect truncation to [-betting_domain, betting_domain]
        truncated_grad = tf.where(
            tf.greater_equal(
                grad * (betting_fraction - clipped_old_betting_fraction), 0),
            grad, tf.zeros(tf.shape(grad)))

        wealth_delta = -betting_fraction * truncated_grad
        wealth_updated = self._assign_add(wealth, wealth_delta)
        update_ops.append(wealth_updated)

        # This is the gradient with respect to the betting fraction v
        # used by the ONS algorithm - a kind of "inner inner grad".
        # Heuristic: we also scale v_grad down by the inner maximum gradient so
        # as to make it "unitless". This is helpful because the learning rate
        # for ONS is proportional to sum v_grad**2, and so the scale of the
        # learning rate and of v_grad are unlikely to be properly matched
        # without this.
        if self.rescale_inner:
            v_grad = truncated_grad / (
                (1.0 - inner_betting_fraction * truncated_grad) *
                inner_maximum_gradient_updated)
        else:
            v_grad = truncated_grad / (
                (1.0 - inner_betting_fraction * truncated_grad))

        sum_grad_squared_updated = self._assign_add(sum_grad_squared,
                                                    tf.square(v_grad))
        update_ops.append(sum_grad_squared_updated)

        new_inner_betting_fraction = inner_betting_fraction - eta * v_grad / (
            sum_grad_squared_updated)
        new_inner_betting_fraction = tf.clip_by_value(
            new_inner_betting_fraction, -betting_domain, betting_domain)
        inner_betting_fraction_updated = self._assign(
            inner_betting_fraction, new_inner_betting_fraction)
        update_ops.append(inner_betting_fraction_updated)

        if self.output_summaries:
            mean_inner_betting_fraction_summary = tf.reduce_mean(
                tf.abs(inner_betting_fraction_updated))
            max_inner_betting_fraction_summary = tf.reduce_max(
                tf.abs(inner_betting_fraction_updated))
            inner_maximum_gradient_summary = tf.reduce_max(
                inner_maximum_gradient_updated)
            tf.summary.scalar(self._name + "/mean_inner_betting/" + var.name,
                              mean_inner_betting_fraction_summary)
            tf.summary.scalar(self._name + "/max_inner_betting/" + var.name,
                              max_inner_betting_fraction_summary)
            tf.summary.scalar(
                self._name + "/inner_maximum_gradient/" + var.name,
                inner_maximum_gradient_summary)

        betting_fraction_updated = self._assign(
            betting_fraction, inner_betting_fraction_updated * wealth_updated)
        update_ops.append(betting_fraction_updated)

        clipped_betting_fraction = tf.clip_by_value(betting_fraction_updated,
                                                    -betting_domain,
                                                    betting_domain)

        return clipped_betting_fraction, tf.group(*update_ops)
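
The "inner inner grad" comment above rescales v_grad by the running maximum gradient magnitude so that the ONS learning rate is roughly unitless. A small NumPy sketch of just that update, with eta chosen arbitrarily and the wealth/truncation bookkeeping omitted:

import numpy as np

eta = 1.0                      # assumption; the optimizer exposes this as self.eta
betting_domain = 0.5
inner_betting_fraction = 0.1
sum_grad_squared = 1e-8
inner_max_grad = 1e-8

for grad in [0.3, -0.2, 0.5]:
    inner_max_grad = max(inner_max_grad, abs(grad))
    # Rescaled ONS gradient with respect to the betting fraction.
    v_grad = grad / ((1.0 - inner_betting_fraction * grad) * inner_max_grad)
    sum_grad_squared += v_grad ** 2
    inner_betting_fraction -= eta * v_grad / sum_grad_squared
    inner_betting_fraction = float(np.clip(inner_betting_fraction,
                                           -betting_domain, betting_domain))
    print(round(inner_betting_fraction, 4))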
Example n. 26
def increase_loss_scale_func():
    return tf.group(tf.assign(loss_scale_normal_steps, 0),
                    tf.assign(loss_scale, loss_scale * 2))
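
Examples n. 20, 22, 24, and 26 are the branches of one dynamic loss-scaling scheme. A minimal sketch of how such helpers are typically wired together with tf.cond, assuming TensorFlow 1.x and using a no-op as a hypothetical stand-in for the real apply-gradients function:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

loss_scale = tf.Variable(2.0 ** 15, trainable=False)
loss_scale_normal_steps = tf.Variable(0, trainable=False)
inc_loss_scale_every_n = 2000

grads = [tf.constant([1.0, 2.0]), tf.constant([3.0])]  # stand-in gradients

def get_apply_gradients_ops_func():
    # Hypothetical stand-in: the real function returns optimizer.apply_gradients ops.
    return [tf.no_op(name='apply_gradients')]

def get_loss_scale_update_op(loss_scale, loss_scale_normal_steps,
                             inc_loss_scale_every_n):
    def increment_loss_scale_normal_steps_func():
        return tf.group(loss_scale_normal_steps.assign_add(1))

    def increase_loss_scale_func():
        return tf.group(tf.assign(loss_scale_normal_steps, 0),
                        tf.assign(loss_scale, loss_scale * 2))

    return tf.cond(loss_scale_normal_steps < inc_loss_scale_every_n,
                   increment_loss_scale_normal_steps_func,
                   increase_loss_scale_func)

def update_op_if_nan_or_inf():
    return tf.group(tf.assign(loss_scale, loss_scale / 2.),
                    tf.assign(loss_scale_normal_steps, 0))

def update_op_if_no_nan_or_inf():
    return tf.group(
        get_loss_scale_update_op(loss_scale, loss_scale_normal_steps,
                                 inc_loss_scale_every_n),
        *get_apply_gradients_ops_func())

grads_finite = tf.reduce_all([tf.reduce_all(tf.is_finite(g)) for g in grads])
train_op = tf.cond(grads_finite, update_op_if_no_nan_or_inf,
                   update_op_if_nan_or_inf)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op)
    print(sess.run([loss_scale, loss_scale_normal_steps]))  # [32768.0, 1]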
Example n. 27
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        tf.logging.info("*** Features ***")
        for name in sorted(features.keys()):
            tf.logging.info("  name = %s, shape = %s" %
                            (name, features[name].shape))

        # MTF setup.
        graph = mtf.Graph()
        mesh_shape = mtf.convert_to_shape(FLAGS.mesh_shape)
        layout_rules = mtf.convert_to_layout_rules(FLAGS.layout)

        ctx = params["context"]
        num_hosts = ctx.num_hosts
        host_placement_fn = ctx.tpu_host_placement_function
        device_list = [host_placement_fn(host_id=t) for t in range(num_hosts)]
        tf.logging.info("device_list = %s" % device_list, )
        replica_cache_size = 300 * 1000000  # 300M per replica
        # Worker 0 caches all the TPU binaries.
        worker0_mem = replica_cache_size * ctx.num_replicas
        devices_memory_usage = [worker0_mem] + [0] * (num_hosts - 1)
        var_placer = mtf.utils.BalancedVariablePlacer(device_list,
                                                      devices_memory_usage)
        mesh_devices = [""] * mesh_shape.size
        physical_shape = list(ctx.device_assignment.topology.mesh_shape)
        logical_to_physical = mtf.simd_mesh_impl.auto_logical_to_physical_tpu(
            mesh_shape.to_integer_list, physical_shape)
        mesh_impl = mtf.simd_mesh_impl.SimdMeshImpl(
            mesh_shape,
            layout_rules,
            mesh_devices,
            ctx.device_assignment,
            logical_to_physical=logical_to_physical)
        mesh = mtf.Mesh(graph, "bert_mesh", var_placer)

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        next_sentence_labels = tf.squeeze(features["next_sentence_labels"], 1)

        batch_size = input_ids.get_shape()[0].value
        batch_dim = mtf.Dimension("batch", batch_size)

        seq_length = input_ids.get_shape()[1].value
        seq_dim = mtf.Dimension("seq", seq_length)
        max_predictions_per_seq = masked_lm_positions.get_shape()[1].value
        max_predictions_per_seq_dim = mtf.Dimension("max_pred_seq",
                                                    max_predictions_per_seq)

        mtf_input_ids = mtf.import_tf_tensor(mesh, input_ids,
                                             [batch_dim, seq_dim])
        mtf_input_mask = mtf.import_tf_tensor(mesh, input_mask,
                                              [batch_dim, seq_dim])
        mtf_segment_ids = mtf.import_tf_tensor(mesh, segment_ids,
                                               [batch_dim, seq_dim])
        mtf_masked_lm_positions = mtf.import_tf_tensor(
            mesh, masked_lm_positions,
            [batch_dim, max_predictions_per_seq_dim])
        mtf_masked_lm_ids = mtf.import_tf_tensor(
            mesh, masked_lm_ids, [batch_dim, max_predictions_per_seq_dim])

        mtf_masked_lm_weights = mtf.import_tf_tensor(
            mesh, masked_lm_weights, [batch_dim, max_predictions_per_seq_dim])
        mtf_next_sentence_labels = mtf.import_tf_tensor(
            mesh, next_sentence_labels, [batch_dim])

        is_training = (mode == tf.estimator.ModeKeys.TRAIN)

        model = bert_lib.BertModel(config=bert_config,
                                   is_training=is_training,
                                   input_ids=mtf_input_ids,
                                   input_mask=mtf_input_mask,
                                   token_type_ids=mtf_segment_ids,
                                   layout=layout_rules,
                                   mesh_shape=mesh_shape)

        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_logits) = model.get_masked_lm_output(
             mtf_masked_lm_positions, mtf_masked_lm_ids, mtf_masked_lm_weights)

        (next_sentence_loss, next_sentence_example_loss, next_sentence_logits
         ) = model.get_next_sentence_output(mtf_next_sentence_labels)

        extra_loss = model.get_extra_loss()

        total_loss = masked_lm_loss + next_sentence_loss
        total_loss = mtf.anonymize(total_loss)
        masked_lm_example_loss = mtf.anonymize(masked_lm_example_loss)
        masked_lm_logits = mtf.anonymize(masked_lm_logits)
        next_sentence_example_loss = mtf.anonymize(next_sentence_example_loss)
        next_sentence_logits = mtf.anonymize(next_sentence_logits)

        # TRAIN mode
        if mode == tf.estimator.ModeKeys.TRAIN:
            _, update_ops = optimization_lib.create_optimizer(
                total_loss + extra_loss,
                learning_rate,
                num_train_steps,
                num_warmup_steps,
                optimizer=FLAGS.optimizer,
                clip_gradients=FLAGS.clip_gradients)

        lowering = mtf.Lowering(graph, {mesh: mesh_impl})

        tf_loss = tf.to_float(lowering.export_to_tf_tensor(total_loss))

        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            tf_update_ops = [
                lowering.lowered_operation(op) for op in update_ops
            ]
            tf_update_ops.append(tf.assign_add(global_step, 1))
            tf.logging.info("tf_update_ops: {}".format(tf_update_ops))
            train_op = tf.group(tf_update_ops)
        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_logits,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_logits,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_logits = tf.reshape(masked_lm_logits,
                                              [-1, masked_lm_logits.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_logits,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_logits = tf.reshape(
                    next_sentence_logits, [-1, next_sentence_logits.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_logits,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss,
                }

            eval_metrics = (metric_fn, [
                lowering.export_to_tf_tensor(masked_lm_example_loss),
                lowering.export_to_tf_tensor(masked_lm_logits), masked_lm_ids,
                masked_lm_weights,
                lowering.export_to_tf_tensor(next_sentence_example_loss),
                lowering.export_to_tf_tensor(next_sentence_logits),
                next_sentence_labels
            ])

        with mtf.utils.outside_all_rewrites():
            # Copy master variables to slices. Must be called first.
            restore_hook = mtf.MtfRestoreHook(lowering)
            if mode == tf.estimator.ModeKeys.TRAIN:
                saver = tf.train.Saver(tf.global_variables(),
                                       sharded=True,
                                       max_to_keep=10,
                                       keep_checkpoint_every_n_hours=2,
                                       defer_build=False,
                                       save_relative_paths=True)
                tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
                saver_listener = mtf.MtfCheckpointSaverListener(lowering)
                saver_hook = tf.train.CheckpointSaverHook(
                    FLAGS.output_dir,
                    save_steps=1000,
                    saver=saver,
                    listeners=[saver_listener])

                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.TRAIN,
                    loss=tf_loss,
                    train_op=train_op,
                    training_hooks=[restore_hook, saver_hook])
            elif mode == tf.estimator.ModeKeys.EVAL:
                return tf.estimator.tpu.TPUEstimatorSpec(
                    tf.estimator.ModeKeys.EVAL,
                    evaluation_hooks=[restore_hook],
                    loss=tf_loss,
                    eval_metrics=eval_metrics)
Example n. 28
    def evaluate(self, input_fn, checkpoint_path=None):
        if not tf.train.latest_checkpoint(checkpoint_path):
            raise ValueError("Could not find trained model at %s" %
                             checkpoint_path)

        with tf.Graph().as_default():
            features, labels = self._get_features_and_labels_from_input_fn(
                input_fn, ModeKeys.EVAL)
            spec, model = self._get_model_spec(features, labels, ModeKeys.EVAL)

            # Track the average loss by default
            eval_metric_ops = spec.eval_metric_ops or {}
            if model_fn_lib.LOSS_METRIC_KEY not in eval_metric_ops:
                loss_metric = tf.metrics.mean(spec.loss)
                eval_metric_ops[model_fn_lib.LOSS_METRIC_KEY] = loss_metric

            # Create the real eval op
            update_ops, eval_dict = _extract_metric_update_ops(eval_metric_ops)
            update_ops.extend(model._train_ops)
            eval_op = tf.group(*update_ops)

            # Also track the global step
            if tf.GraphKeys.GLOBAL_STEP in eval_dict:
                raise ValueError(
                    'Metric with name `global_step` is not allowed, because '
                    'Estimator already defines a default metric with the '
                    'same name.')
            eval_dict[tf.GraphKeys.GLOBAL_STEP] = \
                tf.train.get_or_create_global_step()

            # Prepare the session creator.
            scaffold = tf.train.Scaffold()
            session_creator = tf.train.ChiefSessionCreator(
                scaffold=scaffold, checkpoint_dir=checkpoint_path)

            # Prepare hooks
            all_hooks = list(spec.evaluation_hooks) or []
            final_ops_hook = tf.train.FinalOpsHook(eval_dict)
            all_hooks.append(final_ops_hook)

            # Evaluate over dataset
            self._bridge.connect()
            try:
                with tf.train.MonitoredSession(session_creator=session_creator,
                                               hooks=all_hooks) as sess:
                    if not self._restore_datablock(DATA_CHECKPOINT_INIT_VALUE):
                        raise ValueError("Restore data checkpoint error")
                    iter_id = 0
                    while not sess.should_stop():
                        self._bridge.start(iter_id)
                        logging.debug('after bridge start.')
                        start_time = time.time()
                        sess.run(eval_op)
                        end_time = time.time()
                        metrics.emit_timer(name="iter_timer",
                                           value=end_time - start_time,
                                           tags={})
                        logging.debug('after session run.')
                        self._bridge.commit()
                        logging.debug('after bridge commit.')
                        iter_id += 1
            finally:
                self._bridge.terminate()

            # Print result
            logging.info('Metrics for iteration %d: %s', iter_id,
                         _dict_to_str(final_ops_hook.final_ops_values))
            return final_ops_hook.final_ops_values
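
The _extract_metric_update_ops helper above relies on eval_metric_ops mapping each metric name to a (value_op, update_op) pair. A plausible stand-in for that helper (not the original implementation), assuming TensorFlow 1.x:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

predictions = tf.constant([1.0, 2.0, 3.0])
labels = tf.constant([1.0, 2.0, 2.0])
eval_metric_ops = {
    'loss': tf.metrics.mean(tf.abs(predictions - labels)),
    'mse': tf.metrics.mean_squared_error(labels, predictions),
}

def extract_metric_update_ops(eval_metric_ops):
    update_ops, value_ops = [], {}
    for name, (value_op, update_op) in sorted(eval_metric_ops.items()):
        update_ops.append(update_op)
        value_ops[name] = value_op
    return update_ops, value_ops

update_ops, value_ops = extract_metric_update_ops(eval_metric_ops)
with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())   # metric accumulators are local vars
    sess.run(update_ops)
    print(sess.run(value_ops))                   # {'loss': 0.333..., 'mse': 0.333...}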
Example n. 29
    def benchmark_model(self,
                        warmup_runs,
                        bm_runs,
                        num_threads,
                        trace_filename=None):
        """Benchmark model."""
        if self.tensorrt:
            print('Using tensorrt ', self.tensorrt)
            self.build_and_save_model()
            graphdef = self.freeze_model()

        if num_threads > 0:
            print('num_threads for benchmarking: {}'.format(num_threads))
            sess_config = tf.ConfigProto(
                intra_op_parallelism_threads=num_threads,
                inter_op_parallelism_threads=1)
        else:
            sess_config = tf.ConfigProto()

        # rewriter_config_pb2.RewriterConfig.OFF
        sess_config.graph_options.rewrite_options.dependency_optimization = 2
        if self.use_xla:
            sess_config.graph_options.optimizer_options.global_jit_level = (
                tf.OptimizerOptions.ON_2)

        with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
            inputs = tf.placeholder(tf.float32,
                                    name='input',
                                    shape=self.inputs_shape)
            output = self.build_model(inputs, is_training=False)

            img = np.random.uniform(size=self.inputs_shape)

            sess.run(tf.global_variables_initializer())
            if self.tensorrt:
                fetches = [inputs.name] + [i.name for i in output]
                goutput = self.convert_tr(graphdef, fetches)
                inputs, output = goutput[0], goutput[1:]

            if not self.use_xla:
                # Don't use tf.group because XLA removes the whole graph for tf.group.
                output = tf.group(*output)
            for i in range(warmup_runs):
                start_time = time.time()
                sess.run(output, feed_dict={inputs: img})
                print('Warm up: {} {:.4f}s'.format(i,
                                                   time.time() - start_time))
            print('Start benchmark runs total={}'.format(bm_runs))
            timev = []
            for i in range(bm_runs):
                if trace_filename and i == (bm_runs // 2):
                    run_options = tf.RunOptions()
                    run_options.trace_level = tf.RunOptions.FULL_TRACE
                    run_metadata = tf.RunMetadata()
                    sess.run(output,
                             feed_dict={inputs: img},
                             options=run_options,
                             run_metadata=run_metadata)
                    logging.info('Dumping trace to %s', trace_filename)
                    trace_dir = os.path.dirname(trace_filename)
                    if not tf.io.gfile.exists(trace_dir):
                        tf.io.gfile.makedirs(trace_dir)
                    with tf.io.gfile.GFile(trace_filename, 'w') as trace_file:
                        from tensorflow.python.client import timeline  # pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
                        trace = timeline.Timeline(
                            step_stats=run_metadata.step_stats)
                        trace_file.write(
                            trace.generate_chrome_trace_format(
                                show_memory=True))

                start_time = time.time()
                sess.run(output, feed_dict={inputs: img})
                timev.append(time.time() - start_time)

            timev.sort()
            timev = timev[2:bm_runs - 2]
            print(
                '{} {}runs {}threads: mean {:.4f} std {:.4f} min {:.4f} max {:.4f}'
                .format(self.model_name, len(timev), num_threads,
                        np.mean(timev), np.std(timev), np.min(timev),
                        np.max(timev)))
Example n. 30
    def train(
            self,
            num_iterations=100,
            learning_rate=1.0,
            plot_results=True,
            optimizer=tf.train.GradientDescentOptimizer):
        """Trains the model.
        Args:
          num_iterations: number of iterations to run.
          learning_rate: optimizer learning rate.
          plot_results: whether to plot the results at the end of training.
          optimizer: the optimizer to use. Default to GradientDescentOptimizer.
        Returns:
          The metrics dictionary evaluated at the last iteration.
        """
        with self._loss.graph.as_default():
            opt = optimizer(learning_rate)
            train_op = opt.minimize(self._loss)
            local_init_op = tf.group(tf.variables_initializer(opt.variables()),
                                     tf.local_variables_initializer())
            if self._session is None:
                self._session = tf.Session()
                with self._session.as_default():
                    self._session.run(tf.global_variables_initializer())
                    self._session.run(tf.tables_initializer())
                    tf.train.start_queue_runners()

        with self._session.as_default():
            local_init_op.run()
            iterations = []
            metrics = self._metrics or ({}, )
            metrics_vals = [
                collections.defaultdict(list) for _ in self._metrics
            ]

            # Train and append results.
            for i in range(num_iterations + 1):
                _, results = self._session.run((train_op, metrics))
                if (i % 10 == 0) or i == num_iterations:
                    print("\r iteration %d: " % i + ", ".join([
                        "%s=%f" % (k, v) for r in results
                        for k, v in r.items()
                    ]),
                          end='')
                    iterations.append(i)
                    for metric_val, result in zip(metrics_vals, results):
                        for k, v in result.items():
                            metric_val[k].append(v)

            for k, v in self._embedding_vars.items():
                self._embeddings[k] = v.eval()

            if plot_results:
                # Plot the metrics.
                num_subplots = len(metrics) + 1
                fig = plt.figure()
                fig.set_size_inches(num_subplots * 10, 8)
                for i, metric_vals in enumerate(metrics_vals):
                    ax = fig.add_subplot(1, num_subplots, i + 1)
                    for k, v in metric_vals.items():
                        ax.plot(iterations, v, label=k)
                    ax.set_xlim([1, num_iterations])
                    ax.legend()
                plt.show()
            return results
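
train() above builds its local_init_op from tf.variables_initializer(opt.variables()) because minimize() creates optimizer slot variables after the session may already have been initialized. A minimal sketch of that pattern, assuming TensorFlow 1.x and an Adagrad optimizer (which owns accumulator slots):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

w = tf.Variable([3.0])
loss = tf.reduce_sum(tf.square(w))

sess = tf.Session()
sess.run(tf.global_variables_initializer())   # the session is already "live"

opt = tf.train.AdagradOptimizer(0.1)          # minimize() adds slot variables below
train_op = opt.minimize(loss)
local_init_op = tf.group(tf.variables_initializer(opt.variables()),
                         tf.local_variables_initializer())

sess.run(local_init_op)                        # initialize only the new variables
for _ in range(5):
    sess.run(train_op)
print(sess.run(w))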