Example #1
 def model_fn(self, features, labels, mode, params):
     """TPUEstimator compatible model function."""
     del labels
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
     data_shape = features.get_shape().as_list()[1:]
     z_mean, z_logvar = self.gaussian_encoder(features,
                                              is_training=is_training)
     z_sampled = self.sample_from_latent_distribution(z_mean, z_logvar)
     z_shuffle = shuffle_codes(z_sampled)
     with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
         logits_z, probs_z = architectures.make_discriminator(
             z_sampled, is_training=is_training)
         _, probs_z_shuffle = architectures.make_discriminator(
             z_shuffle, is_training=is_training)
     reconstructions = self.decode(z_sampled, data_shape, is_training)
     per_sample_loss = losses.make_reconstruction_loss(
         features, reconstructions)
     reconstruction_loss = tf.reduce_mean(per_sample_loss)
     kl_loss = compute_gaussian_kl(z_mean, z_logvar)
     standard_vae_loss = tf.add(reconstruction_loss,
                                kl_loss,
                                name="VAE_loss")
     # tc = E[log(p_real)-log(p_fake)] = E[logit_real - logit_fake]
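      # Assuming the discriminator probabilities are a softmax over its two
      # logits, the log-normalizers cancel, so the log-probability ratio
      # reduces to the logit difference used below.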
     tc_loss_per_sample = logits_z[:, 0] - logits_z[:, 1]
     tc_loss = tf.reduce_mean(tc_loss_per_sample, axis=0)
     regularizer = kl_loss + self.gamma * tc_loss
     factor_vae_loss = tf.add(standard_vae_loss,
                              self.gamma * tc_loss,
                              name="factor_VAE_loss")
     discr_loss = tf.add(0.5 * tf.reduce_mean(tf.log(probs_z[:, 0])),
                         0.5 *
                         tf.reduce_mean(tf.log(probs_z_shuffle[:, 1])),
                         name="discriminator_loss")
     if mode == tf.estimator.ModeKeys.TRAIN:
         optimizer_vae = optimizers.make_vae_optimizer()
         optimizer_discriminator = optimizers.make_discriminator_optimizer()
         all_variables = tf.trainable_variables()
         encoder_vars = [
             var for var in all_variables if "encoder" in var.name
         ]
         decoder_vars = [
             var for var in all_variables if "decoder" in var.name
         ]
          discriminator_vars = [
              var for var in all_variables if "discriminator" in var.name
          ]
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         train_op_vae = optimizer_vae.minimize(
             loss=factor_vae_loss,
             global_step=tf.train.get_global_step(),
             var_list=encoder_vars + decoder_vars)
         train_op_discr = optimizer_discriminator.minimize(
             loss=-discr_loss,
             global_step=tf.train.get_global_step(),
             var_list=discriminator_vars)
         train_op = tf.group(train_op_vae, train_op_discr, update_ops)
         tf.summary.scalar("reconstruction_loss", reconstruction_loss)
         logging_hook = tf.train.LoggingTensorHook(
             {
                 "loss": factor_vae_loss,
                 "reconstruction_loss": reconstruction_loss
             },
             every_n_iter=50)
         return contrib_tpu.TPUEstimatorSpec(mode=mode,
                                             loss=factor_vae_loss,
                                             train_op=train_op,
                                             training_hooks=[logging_hook])
     elif mode == tf.estimator.ModeKeys.EVAL:
         return contrib_tpu.TPUEstimatorSpec(
             mode=mode,
             loss=factor_vae_loss,
             eval_metrics=(make_metric_fn("reconstruction_loss",
                                          "regularizer", "kl_loss"),
                           [reconstruction_loss, regularizer, kl_loss]))
     else:
         raise NotImplementedError("Eval mode not supported.")
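Example #2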
    def __init__(self, preds, labels, pos_weight, norm, d_real, d_fake,
                 pred_attrs, attr_labels_list, sample_list):
        attr_preds_list = pred_attrs
        preds_sub = preds
        labels_sub = labels

        self.real = d_real

        # Discriminator loss
        self.dc_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(
                self.real),
                                                    logits=self.real,
                                                    name='dclreal'))

        self.dc_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.zeros_like(d_fake), logits=d_fake, name='dcfake'))
        self.dc_loss = self.dc_loss_fake + self.dc_loss_real

        # Generator loss
        generator_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=tf.ones_like(d_fake), logits=d_fake, name='gl'))

        self.link_cost = norm * tf.reduce_mean(
            tf.nn.weighted_cross_entropy_with_logits(
                logits=preds_sub, targets=labels_sub, pos_weight=pos_weight))

        self.attr_loss = tf.losses.softmax_cross_entropy(
            logits=tf.cast(attr_preds_list[0], tf.float32),
            onehot_labels=attr_labels_list[0],
            reduction=tf.losses.Reduction.NONE)
        mask_attr = np.sum(attr_labels_list[0], axis=1)
        self.attr_loss = tf.multiply(self.attr_loss, sample_list)
        self.attr_loss = tf.reduce_mean(tf.multiply(self.attr_loss, mask_attr))

        self.pri_loss = tf.losses.softmax_cross_entropy(
            logits=tf.cast(attr_preds_list[1], tf.float32),
            onehot_labels=attr_labels_list[1],
            reduction=tf.losses.Reduction.NONE)
        mask_attr = np.sum(attr_labels_list[1], axis=1)
        self.pri_loss = tf.multiply(self.pri_loss, sample_list)
        self.pri_loss = tf.reduce_mean(tf.multiply(self.pri_loss, mask_attr))

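        # Adversarial trade-off (wired up below): O_opt_op minimizes this cost
        # over every variable except the privacy head ('pri_'), while A_opt_op
        # trains the privacy head alone, so subtracting pri_loss pushes the
        # encoder to hide private attributes.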
        self.attr_cost = FLAGS.uti_attr_weight * self.attr_loss - (
            FLAGS.pri_weight * self.pri_loss)

        self.cost = self.attr_cost + FLAGS.link_weight * self.link_cost

        self.generator_loss = generator_loss + self.cost

        all_variables = tf.trainable_variables()
        dc_var = [var for var in all_variables if 'dc_' in var.name]
        en_var = [var for var in all_variables if 'e_' in var.name]
        pri_var = [var for var in all_variables if 'pri_' in var.name]
        all_rm_pri = [x for x in all_variables if x not in pri_var]

        self.O_optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate)  # Adam Optimizer
        self.O_opt_op = self.O_optimizer.minimize(self.cost,
                                                  var_list=all_rm_pri)
        #self.O_grads_vars = self.O_optimizer.compute_gradients(self.generator_loss)

        self.A_optimizer = tf.train.AdamOptimizer(
            learning_rate=FLAGS.learning_rate)  # Adam Optimizer
        self.A_opt_op = self.A_optimizer.minimize(self.pri_loss,
                                                  var_list=pri_var)
        #self.A_grads_vars = self.A_optimizer.compute_gradients(self.attr2_loss)

        with tf.variable_scope(tf.get_variable_scope()):
            self.discriminator_optimizer = tf.train.AdamOptimizer(
                learning_rate=FLAGS.discriminator_learning_rate,
                beta1=0.9,
                name='adam1').minimize(
                    self.dc_loss,
                    var_list=dc_var)  #minimize(dc_loss_real, var_list=dc_var)

            self.generator_optimizer = tf.train.AdamOptimizer(
                learning_rate=FLAGS.discriminator_learning_rate,
                beta1=0.9,
                name='adam2').minimize(self.generator_loss, var_list=en_var)
Example #3
                'Attempting to find it in GLOBAL_VARIABLES collection.',
                variable_name)
        global_vars = tensor.graph.get_collection(
            tf.GraphKeys.GLOBAL_VARIABLES)
        matched_vars = [
            v for v in global_vars if v.name == variable_name + ':0'
        ]
        if not matched_vars:
            raise ValueError(
                'Variable %s is in GraphDef but not in the live graph.' %
                variable_name)
        assert len(matched_vars) == 1
        return matched_vars[0]


var_store = {}
top_level_scope = tf.get_variable_scope()


def write_to_variable(tensor, fail_if_exists=True):
    """Saves a tensor for later retrieval on CPU."""
    # Only relevant for debugging.
    debug_name = 'tpu_util__' + tensor.name.split(':')[0]

    reuse = False if fail_if_exists else tf.compat.v1.AUTO_REUSE
    with tf.variable_scope(top_level_scope, reuse=reuse):
        variable = tf.get_variable(name=debug_name,
                                   shape=tensor.shape,
                                   dtype=tensor.dtype,
                                   trainable=False,
                                   use_resource=True)
    # Assumed completion (the snippet is truncated here): store the tensor's
    # value in the variable and return a copy gated on the assignment.
    var_store[tensor] = variable
    with tf.control_dependencies([variable.assign(tensor)]):
        tensor_copy = tf.identity(tensor)
    var_store[tensor_copy] = variable
    return tensor_copy
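Example #4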
def style_prediction(style_input_,
                     activation_names,
                     activation_depths,
                     is_training=True,
                     trainable=True,
                     inception_end_point='Mixed_6e',
                     style_prediction_bottleneck=100,
                     reuse=None):
    """Maps style images to the style embeddings (beta and gamma parameters).

  Args:
    style_input_: Tensor. Batch of style input images.
    activation_names: list of strings. Scope names of the activations of the
        transformer network which are used to apply style normalization.
    activation_depths: list of ints. Depths of the activations of the
        transformer network which are used to apply style normalization.
    is_training: bool. Is it training phase or not?
    trainable: bool. Should the parameters be marked as trainable?
    inception_end_point: string. Specifies the endpoint to construct the
        inception_v3 network up to. This network is part of the style prediction
        network.
    style_prediction_bottleneck: int. Specifies the bottleneck size in the
        number of parameters of the style embedding.
    reuse: bool. Whether to reuse model parameters. Defaults to None.

  Returns:
    Tensor for the output of the style prediction network, Tensor for the
        bottleneck of style parameters of the style prediction network.
  """
    with tf.name_scope('style_prediction'), tf.variable_scope(
            tf.get_variable_scope(), reuse=reuse):
        with slim.arg_scope(_inception_v3_arg_scope(is_training=is_training)):
            with slim.arg_scope(
                [slim.conv2d, slim.fully_connected, slim.batch_norm],
                    trainable=trainable):
                with slim.arg_scope([slim.batch_norm, slim.dropout],
                                    is_training=is_training):
                    _, end_points = inception_v3.inception_v3_base(
                        style_input_,
                        scope='InceptionV3',
                        final_endpoint=inception_end_point)

        # Shape of feat_convlayer is (batch_size, ?, ?, depth).
        # For Mixed_6e end point, depth is 768, for input image size of 256x256
        # width and height are 14x14.
        feat_convlayer = end_points[inception_end_point]
        with tf.name_scope('bottleneck'):
            # (batch_size, 1, 1, depth).
            bottleneck_feat = tf.reduce_mean(feat_convlayer,
                                             axis=[1, 2],
                                             keep_dims=True)

        if style_prediction_bottleneck > 0:
            with slim.arg_scope([slim.conv2d],
                                activation_fn=None,
                                normalizer_fn=None,
                                trainable=trainable):
                # (batch_size, 1, 1, style_prediction_bottleneck).
                bottleneck_feat = slim.conv2d(bottleneck_feat,
                                              style_prediction_bottleneck,
                                              [1, 1])

        style_params = {}
        with tf.variable_scope('style_params'):
            for i in range(len(activation_depths)):
                with tf.variable_scope(activation_names[i], reuse=reuse):
                    with slim.arg_scope([slim.conv2d],
                                        activation_fn=None,
                                        normalizer_fn=None,
                                        trainable=trainable):

                        # Computing beta parameter of the style normalization for the
                        # activation_names[i] layer of the style transformer network.
                        # (batch_size, 1, 1, activation_depths[i])
                        beta = slim.conv2d(bottleneck_feat,
                                           activation_depths[i], [1, 1])
                        # (batch_size, activation_depths[i])
                        beta = tf.squeeze(beta, [1, 2], name='SpatialSqueeze')
                        style_params['{}/beta'.format(
                            activation_names[i])] = beta

                        # Computing gamma parameter of the style normalization for the
                        # activation_names[i] layer of the style transformer network.
                        # (batch_size, 1, 1, activation_depths[i])
                        gamma = slim.conv2d(bottleneck_feat,
                                            activation_depths[i], [1, 1])
                        # (batch_size, activation_depths[i])
                        gamma = tf.squeeze(gamma, [1, 2],
                                           name='SpatialSqueeze')
                        style_params['{}/gamma'.format(
                            activation_names[i])] = gamma

    return style_params, bottleneck_feat
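
# Hedged usage sketch (every name below is an assumption for illustration,
# not part of the original snippet): map a small batch of style images to the
# per-activation normalization parameters.
style_images = tf.random_uniform([4, 256, 256, 3])
style_params, bottleneck = style_prediction(
    style_images,
    activation_names=['conv1', 'conv2'],  # assumed transformer scope names
    activation_depths=[32, 64],
    is_training=False)
beta = style_params['conv1/beta']  # shape (4, 32)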
Example #5
def train_mnist_multitower(num_epochs,
                           num_towers,
                           devices,
                           use_fake_data=False,
                           session_config=None):
    """Train a ConvNet on MNIST.

  Training data is split equally among the towers. Each tower computes loss on
  its own batch of data and the loss is aggregated on the CPU. The model
  variables are placed on the first tower. The covariance and inverse update
  ops and variables are placed on the specified devices in a round-robin
  manner.

  Args:
    num_epochs: int. Number of passes to make over the training set.
    num_towers: int. Number of towers.
    devices: list of strings. List of devices to place the towers.
    use_fake_data: bool. If True, generate a synthetic dataset.
    session_config: None or tf.ConfigProto. Configuration for tf.Session().

  Returns:
    accuracy of model on the final minibatch of training data.
  """
    num_towers = 1 if not devices else len(devices)  # One tower per device.
    # Load a dataset.
    tower_batch_size = 128
    batch_size = tower_batch_size * num_towers
    tf.logging.info(
        ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d "
         "tower batch size.") % (batch_size, num_towers, tower_batch_size))
    (examples,
     labels) = mnist.load_mnist_as_iterator(num_epochs,
                                            batch_size,
                                            use_fake_data=use_fake_data,
                                            flatten_images=False)

    # Split minibatch across towers.
    examples = tf.split(examples, num_towers)
    labels = tf.split(labels, num_towers)

    # Build a ConvNet. Each tower's layers will be added to the LayerCollection.
    layer_collection = kfac.LayerCollection()
    tower_results = []
    for tower_id in range(num_towers):
        with tf.device(devices[tower_id]):
            with tf.name_scope("tower%d" % tower_id):
                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=(tower_id > 0)):
                    tf.logging.info("Building tower %d." % tower_id)
                    tower_results.append(
                        build_model(examples[tower_id],
                                    labels[tower_id],
                                    10,
                                    layer_collection,
                                    register_layers_manually=_USE_MANUAL_REG))
    losses, accuracies = zip(*tower_results)
    # When using multiple towers we only want to perform automatic
    # registration once, after the final tower is made.
    if not _USE_MANUAL_REG:
        layer_collection.auto_register_layers()

    # Average across towers.
    loss = tf.reduce_mean(losses)
    accuracy = tf.reduce_mean(accuracies)

    # Fit model.
    g_step = tf.train.get_or_create_global_step()
    optimizer = kfac.PeriodicInvCovUpdateKfacOpt(
        invert_every=_INVERT_EVERY,
        cov_update_every=_COV_UPDATE_EVERY,
        learning_rate=0.0001,
        cov_ema_decay=0.95,
        damping=0.001,
        layer_collection=layer_collection,
        placement_strategy="round_robin",
        cov_devices=devices,
        inv_devices=devices,
        trans_devices=devices,
        momentum=0.9)

    with tf.device(devices[0]):
        train_op = optimizer.minimize(loss, global_step=g_step)

    # Without setting allow_soft_placement=True there will be problems when
    # the optimizer tries to place certain ops like "mod" on the GPU (which isn't
    # supported).
    if not session_config:
        session_config = tf.ConfigProto(allow_soft_placement=True)

    tf.logging.info("Starting training.")
    with tf.train.MonitoredTrainingSession(config=session_config) as sess:
        while not sess.should_stop():
            global_step_, loss_, accuracy_, _ = sess.run(
                [g_step, loss, accuracy, train_op])

            if global_step_ % _REPORT_EVERY == 0:
                tf.logging.info("global_step: %d | loss: %f | accuracy: %s",
                                global_step_, loss_, accuracy_)

    # Assumed completion (truncated in the snippet): return the final-minibatch
    # accuracy promised by the docstring.
    return accuracy_
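
# Hedged usage sketch (assumes two visible GPUs; constants such as
# _INVERT_EVERY come from the surrounding example module):
final_accuracy = train_mnist_multitower(num_epochs=2,
                                        num_towers=2,
                                        devices=["/gpu:0", "/gpu:1"])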
Example #6
    def solve(
        self,
        ode_fn,
        initial_time,
        initial_state,
        solution_times,
        jacobian_fn=None,
        jacobian_sparsity=None,
        batch_ndims=None,
        previous_solver_internal_state=None,
        constants=None,
    ):
        """Solves an initial value problem.

    An initial value problem consists of a system of ODEs and an initial
    condition:

    ```none
    dy/dt(t) = ode_fn(t, y(t), **constants)
    y(initial_time) = initial_state
    ```

    Here, `t` (also called time) is a scalar float `Tensor` and `y(t)` (also
    called the state at time `t`) is an N-D float or complex `Tensor`.
    `constants` is a dictionary of values that are constant with respect to
    time. Passing the constants here rather than just closing over them in
    `ode_fn` is only necessary if you want gradients with respect to these
    values.

    ### Example

    The ODE `dy/dt(t) = dot(A, y(t))` is solved below.

    ```python
    t_init, t0, t1 = 0., 0.5, 1.
    y_init = tf.constant([1., 1.], dtype=tf.float64)
    A = tf.constant([[-1., -2.], [-3., -4.]], dtype=tf.float64)

    def ode_fn(t, y):
      return tf.linalg.matvec(A, y)

    results = tfp.math.ode.BDF().solve(ode_fn, t_init, y_init,
                                       solution_times=[t0, t1])
    y0 = results.states[0]  # == dot(matrix_exp(A * t0), y_init)
    y1 = results.states[1]  # == dot(matrix_exp(A * t1), y_init)
    ```

    If the exact solution times are not important, it can be much
    more efficient to let the solver choose them using
    `solution_times=tfp.math.ode.ChosenBySolver(final_time=1.)`.
    This yields the state at various times between `t_init` and `final_time`,
    in which case `results.states[i]` is the state at time `results.times[i]`.
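
    A minimal sketch of this variant (reusing `ode_fn`, `t_init`, and `y_init`
    from the example above):

    ```python
    results = tfp.math.ode.BDF().solve(
        ode_fn, t_init, y_init,
        solution_times=tfp.math.ode.ChosenBySolver(final_time=1.))
    results.times   # Solver-chosen times in (t_init, final_time].
    results.states  # States at the corresponding times.
    ```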

    #### Gradients

    The gradients are computed using the adjoint sensitivity method described in
    [Chen et al. (2018)][1].

    ```python
    grad = tf.gradients(y1, y0) # == dot(e, J)
    # J is the Jacobian of y1 with respect to y0. In this case, J = exp(A * t1).
    # e = [1, ..., 1] is the row vector of ones.
    ```

    This is not capable of computing gradients with respect to values closed
    over by `ode_fn`, e.g., in the example above:

    ```python
    def ode_fn(t, y):
      return tf.linalg.matvec(A, y)

    with tf.GradientTape() as tape:
      tape.watch(A)
      results = tfp.math.ode.BDF().solve(ode_fn, t_init, y_init,
                                         solution_times=[t0, t1])
    tape.gradient(results.states, A)  # Undefined!
    ```

    There are two options to get the gradients flowing through these values:

    1. Use `tf.Variable` for these values.
    2. Pass the values in explicitly using the `constants` argument:

    ```python
    def ode_fn(t, y, A):
      return tf.linalg.matvec(A, y)

    with tf.GradientTape() as tape:
      tape.watch(A)
      results = tfp.math.ode.BDF().solve(ode_fn, t_init, y_init,
                                         solution_times=[t0, t1],
                                         constants={'A': A})
    tape.gradient(results.states, A)  # Fine.
    ```

    By default, this uses the same solver for the augmented ODE. This can be
    controlled via `make_adjoint_solver_fn`.
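
    For example, a hedged sketch (assuming `make_adjoint_solver_fn` is a
    zero-argument callable that constructs the adjoint solver):

    ```python
    solver = tfp.math.ode.BDF(
        make_adjoint_solver_fn=tfp.math.ode.DormandPrince)
    ```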

    #### References

    [1]: Chen, Tian Qi, et al. "Neural ordinary differential equations."
         Advances in Neural Information Processing Systems. 2018.

    Args:
      ode_fn: Function of the form `ode_fn(t, y, **constants)`. The input `t` is
        a scalar float `Tensor`. The input `y` and output are both `Tensor`s
        with the same shape and `dtype` as `initial_state`. `constants` is are
        values that are constant with respect to time. Passing the constants
        here rather than just closing over them in `ode_fn` is only necessary if
        you want gradients with respect to these values.
      initial_time: Scalar float `Tensor` specifying the initial time.
      initial_state: N-D float or complex `Tensor` specifying the initial state.
        The `dtype` of `initial_state` must be complex for problems with
        complex-valued states (even if the initial state is real).
      solution_times: 1-D float `Tensor` specifying a list of times. The solver
        stores the computed state at each of these times in the returned
        `Results` object. Must satisfy `initial_time <= solution_times[0]` and
        `solution_times[i] < solution_times[i+1]`. Alternatively, the user can
        pass `tfp.math.ode.ChosenBySolver(final_time)` where `final_time` is a
        scalar float `Tensor` satisfying `initial_time < final_time`. Doing so
        requests that the solver automatically choose suitable times up to and
        including `final_time` at which to store the computed state.
      jacobian_fn: Optional function of the form `jacobian_fn(t, y)`. The input
        `t` is a scalar float `Tensor`. The input `y` has the same shape and
        `dtype` as `initial_state`. The output is a 2N-D `Tensor` whose shape is
        `initial_state.shape + initial_state.shape` and whose `dtype` is the
        same as `initial_state`. In particular, the `(i1, ..., iN, j1, ...,
        jN)`-th entry of `jacobian_fn(t, y)` is the derivative of the `(i1, ...,
        iN)`-th entry of `ode_fn(t, y)` with respect to the `(j1, ..., jN)`-th
        entry of `y`. If this argument is left unspecified, the solver
        automatically computes the Jacobian if and when it is needed.
        Default value: `None`.
      jacobian_sparsity: Optional 2N-D boolean `Tensor` whose shape is
        `initial_state.shape + initial_state.shape` specifying the sparsity
        pattern of the Jacobian. This argument is ignored if `jacobian_fn` is
        specified.
        Default value: `None`.
      batch_ndims: Optional nonnegative integer. When specified, the first
        `batch_ndims` dimensions of `initial_state` are batch dimensions.
        Default value: `None`.
      previous_solver_internal_state: Optional solver-specific argument used to
        warm-start this invocation of `solve`.
        Default value: `None`.
      constants: Optional dictionary with string keys and values being (possibly
        nested) float `Tensor`s. These represent values that are constant with
        respect to time. Specifying these here allows the adjoint sensitivity
        method to compute gradients of the results with respect to these values.

    Returns:
      Object of type `Results`.
    """
        if constants is None:
            constants = {}
        input_state_structure = initial_state
        constant_state_structure = constants
        flat_initial_state = tf.nest.flatten(initial_state)
        flat_constants = tf.nest.flatten(constants)
        num_state_components = len(flat_initial_state)

        @tf.custom_gradient
        def gradient_helper(*flat_initial_state_and_constants):
            """Restricts gradient to initial state components and constants."""
            flat_initial_state_and_constants = [
                tf.convert_to_tensor(c)
                for c in flat_initial_state_and_constants
            ]
            flat_initial_state = (
                flat_initial_state_and_constants[:num_state_components])
            flat_constants = flat_initial_state_and_constants[
                num_state_components:]
            initial_state = tf.nest.pack_sequence_as(input_state_structure,
                                                     flat_initial_state)
            constants = tf.nest.pack_sequence_as(constant_state_structure,
                                                 flat_constants)

            results = self._solve(
                ode_fn=functools.partial(ode_fn, **constants),
                initial_time=initial_time,
                initial_state=initial_state,
                solution_times=solution_times,
                jacobian_fn=jacobian_fn,
                jacobian_sparsity=jacobian_sparsity,
                batch_ndims=batch_ndims,
                previous_solver_internal_state=previous_solver_internal_state,
            )
            results = Results(
                times=tf.stop_gradient(results.times),
                states=results.states,
                diagnostics=util.stop_gradient_of_real_or_complex_entries(
                    results.diagnostics),
                solver_internal_state=util.
                stop_gradient_of_real_or_complex_entries(
                    results.solver_internal_state))

            def grad_fn(*dresults, **kwargs):
                """Adjoint sensitivity method to compute gradients."""
                adjoint_solver = self._make_adjoint_solver_fn()
                dresults = tf.nest.pack_sequence_as(results, dresults)
                dstates = dresults.states
                # The signature grad_fn(*dresults, variables=None) is not valid Python 2
                # so use kwargs instead.
                variables = kwargs.pop('variables', [])
                assert not kwargs  # This assert should never fail.
                # TODO(b/138304303): Support complex types.
                with tf.name_scope('{}Gradients'.format(self._name)):
                    get_dtype = lambda x: x.dtype

                    def error_if_complex(dtype):
                        if dtype.is_complex:
                            raise NotImplementedError(
                                'The adjoint sensitivity method does '
                                'not support complex dtypes.')

                    state_dtypes = tf.nest.map_structure(
                        get_dtype, initial_state)
                    tf.nest.map_structure(error_if_complex, state_dtypes)
                    common_state_dtype = dtype_util.common_dtype(initial_state)
                    real_dtype = dtype_util.real_dtype(common_state_dtype)

                    # We add initial_time to ensure that we know where to stop.
                    result_times = tf.concat(
                        [[tf.cast(initial_time, real_dtype)], results.times],
                        0)
                    num_result_times = tf.size(result_times)

                    # The first two components correspond to the reverse and
                    # adjoint states; the last two are the adjoint states for
                    # the variables and the constants.
                    terminal_augmented_state = tuple([
                        rk_util.nest_constant(initial_state, 0.0),
                        rk_util.nest_constant(initial_state, 0.0),
                        tuple(
                            rk_util.nest_constant(variable, 0.0)
                            for variable in variables),
                        rk_util.nest_constant(constants, 0.0),
                    ])

                    # The XLA compiler does not compile code which slices/indexes using
                    # integer `Tensor`s. `TensorArray`s are used to get around this.
                    result_time_array = tf.TensorArray(
                        results.times.dtype,
                        clear_after_read=False,
                        size=num_result_times,
                        element_shape=[]).unstack(result_times)

                    # TensorArray shape should not include time dimension, hence shape[1:]
                    result_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            clear_after_read=False,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(results.states)
                    ]
                    result_state_arrays = tf.nest.pack_sequence_as(
                        results.states, result_state_arrays)
                    dresult_state_arrays = [
                        tf.TensorArray(  # pylint: disable=g-complex-comprehension
                            dtype=component.dtype,
                            size=num_result_times - 1,
                            clear_after_read=False,
                            element_shape=component.shape[1:]).unstack(
                                component)
                        for component in tf.nest.flatten(dstates)
                    ]
                    dresult_state_arrays = tf.nest.pack_sequence_as(
                        results.states, dresult_state_arrays)

                    def augmented_ode_fn(backward_time, augmented_state):
                        """Dynamics function for the augmented system.

            Describes a differential equation that evolves the augmented state
            backwards in time to compute gradients using the adjoint method.
            Augmented state consists of 4 components `(state, adjoint_state,
            vars, constants)` all evaluated at time `backward_time`:

            state: represents the solution of the user-provided `ode_fn`. The
              structure coincides with that of `initial_state`.
            adjoint_state: represents the solution of the adjoint sensitivity
              differential equation as discussed below. Has the same structure
              and shape as `state`.
            variables: represents the solution of the adjoint equation for
              variable gradients. Represented as a `Tuple(Tensor, ...)` with as
              many tensors as there are `variables` outside this function.
            constants: represents the solution of the adjoint equation for
              constant gradients. Has the same structure and shape as the
              `constants` variable outside this function.

            The adjoint sensitivity equation describes the gradient of the
            solution with respect to the value of the solution at a previous
            time t. Its dynamics are given by
            d/dt[adj(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), z)
            which is computed as:
            d/dt[adj(t)]_i = -1 * sum_j(adj(t)_j * d/dz_i[ode_fn(t, z)_j])
            d/dt[adj(t)]_i = -1 * d/dz_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]
            where in the last line we moved adj(t)_j under the derivative by
            stopping its gradient.

            The adjoint equation for the gradient with respect to each
            `tf.Variable` and constant theta is:
            d/dt[grad_theta(t)] = -1 * adj(t) @ jacobian(ode_fn(t, z), theta)
            = -1 * d/d theta_i[sum_j(no_grad_adj_j * ode_fn(t, z)_j)]

            Args:
              backward_time: Floating `Tensor` representing current time.
              augmented_state: `Tuple(state, adjoint_state, variable_grads,
                constant_grads)`

            Returns:
              negative_derivatives: Structure of `Tensor`s equal to backwards
                time derivative of the `state` component.
              adjoint_ode: Structure of `Tensor`s equal to backwards time
                derivative of the `adjoint_state` component.
              adjoint_variables_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `vars` component.
              adjoint_constants_ode: Structure of `Tensor`s equal to backwards
                time derivative of the `constants` component.
            """
                        # The negative sign disappears after the change of variables.
                        # The ODE solver cannot handle the case initial_time > final_time
                        # and hence a change of variables backward_time = -time is used.
                        time = -backward_time
                        state, adjoint_state, _, _ = augmented_state

                        # TODO(b/152464477): Doesn't work reliably in TF1.
                        with tf.GradientTape() as tape:
                            tape.watch([variables, state, constants])
                            derivatives = ode_fn(time, state, **constants)
                            adjoint_no_grad = tf.nest.map_structure(
                                tf.stop_gradient, adjoint_state)
                            negative_derivatives = rk_util.weighted_sum(
                                [-1.0], [derivatives])

                            def dot_prod(tensor_a, tensor_b):
                                return tf.reduce_sum(tensor_a * tensor_b)

                            # See docstring for details.
                            adjoint_dot_derivatives = tf.nest.map_structure(
                                dot_prod, adjoint_no_grad, derivatives)
                            adjoint_dot_derivatives = tf.squeeze(
                                tf.add_n(
                                    tf.nest.flatten(adjoint_dot_derivatives)))

                        (adjoint_ode, adjoint_variables_ode,
                         adjoint_constants_ode) = tape.gradient(
                             adjoint_dot_derivatives,
                             (state, tuple(variables), constants),
                             unconnected_gradients=tf.UnconnectedGradients.ZERO
                         )
                        return (negative_derivatives, adjoint_ode,
                                adjoint_variables_ode, adjoint_constants_ode)

                    def make_augmented_state(n, prev_augmented_state):
                        """Constructs the augmented state for step `n`."""
                        (_, adjoint_state, adjoint_variable_state,
                         adjoint_constant_state) = prev_augmented_state
                        initial_state = _read_solution_components(
                            result_state_arrays,
                            input_state_structure,
                            n - 1,
                        )
                        initial_adjoint = _read_solution_components(
                            dresult_state_arrays,
                            input_state_structure,
                            n - 1,
                        )
                        initial_adjoint_state = rk_util.weighted_sum(
                            [1.0, 1.0], [adjoint_state, initial_adjoint])
                        augmented_state = (
                            initial_state,
                            initial_adjoint_state,
                            adjoint_variable_state,
                            adjoint_constant_state,
                        )
                        return augmented_state

                    def reverse_to_result_time(n, augmented_state,
                                               solver_internal_state, _):
                        """Integrates the augmented system backwards in time."""
                        lower_bound_of_integration = result_time_array.read(n)
                        upper_bound_of_integration = result_time_array.read(n -
                                                                            1)
                        initial_augmented_state = make_augmented_state(
                            n, augmented_state)
                        # TODO(b/138304303): Allow the user to specify the Hessian of
                        # `ode_fn` so that we can get the Jacobian of the adjoint system.
                        # TODO(b/143624114): Support higher order derivatives.
                        solver_internal_state = (
                            adjoint_solver.
                            _adjust_solver_internal_state_for_state_jump(  # pylint: disable=protected-access
                                ode_fn=augmented_ode_fn,
                                initial_time=-lower_bound_of_integration,
                                initial_state=initial_augmented_state,
                                previous_solver_internal_state=
                                solver_internal_state,
                                previous_state=augmented_state,
                            ))
                        augmented_results = adjoint_solver.solve(
                            ode_fn=augmented_ode_fn,
                            initial_time=-lower_bound_of_integration,
                            initial_state=initial_augmented_state,
                            solution_times=[-upper_bound_of_integration],
                            batch_ndims=batch_ndims,
                            previous_solver_internal_state=
                            solver_internal_state,
                        )
                        # Results added an extra time dim of size 1, squeeze it.
                        select_result = lambda x: tf.squeeze(x, [0])
                        result_state = augmented_results.states
                        result_state = tf.nest.map_structure(
                            select_result, result_state)
                        status = augmented_results.diagnostics.status
                        return (n - 1, result_state,
                                augmented_results.solver_internal_state,
                                status)

                    initial_n = num_result_times - 1
                    solver_internal_state = adjoint_solver._initialize_solver_internal_state(  # pylint: disable=protected-access
                        ode_fn=augmented_ode_fn,
                        initial_time=result_time_array.read(initial_n),
                        initial_state=make_augmented_state(
                            initial_n, terminal_augmented_state),
                    )

                    _, augmented_state, _, _ = tf.while_loop(
                        lambda n, _as, _sis, status:
                        (n >= 1) & tf.equal(status, 0),
                        reverse_to_result_time,
                        (initial_n, terminal_augmented_state,
                         solver_internal_state, 0),
                        back_prop=False,
                    )
                    (_, adjoint_state, adjoint_variables,
                     adjoint_constants) = augmented_state
                    return (tf.nest.flatten(adjoint_state) +
                            tf.nest.flatten(adjoint_constants),
                            list(adjoint_variables))

            return results, grad_fn

        # TODO(b/140760650): We must use a resource-using variable scope, otherwise
        # custom_gradient will complain even if there are no variables in `ode_fn`.
        with tf1.variable_scope(tf1.get_variable_scope(), use_resource=True):
            return gradient_helper(*(flat_initial_state + flat_constants))
Example #7
    def _build_loss(self):
        """Builds the loss tensor, to be minimized by the optimizer."""
        self.reader = reader.DataReader(
            self.data_dir,
            self.batch_size,
            self.img_height,
            self.img_width,
            SEQ_LENGTH,
            1,  # num_scales
            self.file_extension,
            self.random_scale_crop,
            reader.FLIP_RANDOM,
            self.random_color,
            self.imagenet_norm,
            self.shuffle,
            self.input_file,
            queue_size=self.queue_size)

        (self.image_stack, self.image_stack_norm, self.seg_stack,
         self.intrinsic_mat, _) = self.reader.read_data()
        if self.learn_intrinsics:
            self.intrinsic_mat = None
        if self.intrinsic_mat is None and not self.learn_intrinsics:
            raise RuntimeError(
                'Could not read intrinsic matrix. Turn '
                'learn_intrinsics on to learn it instead of loading '
                'it.')
        self.export('self.image_stack', self.image_stack)

        object_masks = []
        for i in range(self.batch_size):
            object_ids = tf.unique(tf.reshape(self.seg_stack[i], [-1]))[0]
            object_masks_i = []
            for j in range(SEQ_LENGTH):
                current_seg = self.seg_stack[i, :, :, j * 3]  # (H, W)

                def process_obj_mask(obj_id):
                    """Create a mask for obj_id, skipping the background mask."""
                    mask = tf.logical_and(
                        tf.equal(current_seg, obj_id),  # pylint: disable=cell-var-from-loop
                        tf.not_equal(tf.cast(0, tf.uint8), obj_id))
                    # Leave out very small masks, which are most often errors.
                    size = tf.reduce_sum(tf.to_int32(mask))
                    mask = tf.logical_and(mask,
                                          tf.greater(size, MIN_OBJECT_AREA))
                    if not self.boxify:
                        return mask
                    # Complete the mask to its bounding box.
                    binary_obj_masks_y = tf.reduce_any(mask,
                                                       axis=1,
                                                       keepdims=True)
                    binary_obj_masks_x = tf.reduce_any(mask,
                                                       axis=0,
                                                       keepdims=True)
                    return tf.logical_and(binary_obj_masks_y,
                                          binary_obj_masks_x)

                object_mask = tf.map_fn(  # (N, H, W)
                    process_obj_mask, object_ids, dtype=tf.bool)
                object_mask = tf.reduce_any(object_mask, axis=0)
                object_masks_i.append(object_mask)
            object_masks.append(tf.stack(object_masks_i, axis=-1))

        self.seg_stack = tf.cast(tf.stack(object_masks, axis=0), tf.float32)
        tf.summary.image('Masks', self.seg_stack)

        with tf.variable_scope(DEPTH_SCOPE):
            # Organized by ...[i][scale].  Note that the order is flipped in
            # variables in build_loss() below.
            self.disp = {}
            self.depth = {}

            # Parabolic rampup of the noise over LAYER_NORM_NOISE_RAMPUP_STEPS steps.
            # We stop at 0.5 because this is the value above which the multiplicative
            # noise we use can become negative. Further experimentation is needed to
            # find if non-negativity is indeed needed.
            noise_stddev = 0.5 * tf.square(
                tf.minimum(
                    tf.cast(self.global_step, tf.float32) /
                    float(LAYER_NORM_NOISE_RAMPUP_STEPS), 1.0))

            def _normalizer_fn(x, is_train, name='bn'):
                return randomized_layer_normalization.normalize(
                    x, is_train=is_train, name=name, stddev=noise_stddev)

            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH):
                    image = self.image_stack_norm[:, :, :, 3 * i:3 * (i + 1)]
                    self.depth[
                        i] = depth_prediction_net.depth_prediction_resnet18unet(
                            image, True, self.weight_reg, _normalizer_fn)
                    self.disp[i] = 1.0 / self.depth[i]

        with tf.name_scope('compute_loss'):
            self.reconstr_loss = 0
            self.smooth_loss = 0
            self.ssim_loss = 0
            self.depth_consistency_loss = 0

            # Smoothness.
            if self.smooth_weight > 0:
                for i in range(SEQ_LENGTH):
                    disp_smoothing = self.disp[i]
                    # Perform depth normalization, dividing by the mean.
                    mean_disp = tf.reduce_mean(disp_smoothing,
                                               axis=[1, 2, 3],
                                               keep_dims=True)
                    disp_input = disp_smoothing / mean_disp
                    self.smooth_loss += _depth_smoothness(
                        disp_input, self.image_stack[:, :, :,
                                                     3 * i:3 * (i + 1)])

            self.rot_loss = 0.0
            self.trans_loss = 0.0

            def add_result_to_loss_and_summaries(endpoints, i, j):
                tf.summary.image(
                    'valid_mask%d%d' % (i, j),
                    tf.expand_dims(endpoints['depth_proximity_weight'], -1))

                self.depth_consistency_loss += endpoints['depth_error']
                self.reconstr_loss += endpoints['rgb_error']
                self.ssim_loss += 0.5 * endpoints['ssim_error']
                self.rot_loss += endpoints['rotation_error']
                self.trans_loss += endpoints['translation_error']

            self.motion_smoothing = 0.0
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                for i in range(SEQ_LENGTH - 1):
                    j = i + 1
                    depth_i = self.depth[i][:, :, :, 0]
                    depth_j = self.depth[j][:, :, :, 0]
                    image_j = self.image_stack[:, :, :, 3 * j:3 * (j + 1)]
                    image_i = self.image_stack[:, :, :, i * 3:(i + 1) * 3]
                    # We select a pair of consecutive images (and their respective
                    # predicted depth maps). Now we have the network predict a motion
                    # field that connects the two. We feed the pair of images into the
                    # network, once in forward order and then in reverse order. The
                    # results are fed into the loss calculation. The following losses are
                    # calculated:
                    # - RGB and SSIM photometric consistency.
                    # - Cycle consistency of rotations and translations for every pixel.
                    # - L1 smoothness of the disparity and the motion field.
                    # - Depth consistency
                    rot, trans, trans_res, mat = motion_prediction_net.motion_field_net(
                        images=tf.concat([image_i, image_j], axis=-1),
                        weight_reg=self.weight_reg)
                    inv_rot, inv_trans, inv_trans_res, inv_mat = (
                        motion_prediction_net.motion_field_net(
                            images=tf.concat([image_j, image_i], axis=-1),
                            weight_reg=self.weight_reg))

                    if self.learn_intrinsics:
                        intrinsic_mat = 0.5 * (mat + inv_mat)
                    else:
                        intrinsic_mat = self.intrinsic_mat[:, 0, :, :]

                    def dilate(x):
                        # Dilation by n pixels is roughly max pooling by 2 * n + 1.
                        p = self.foreground_dilation * 2 + 1
                        return tf.nn.max_pool(x, [1, p, p, 1], [1] * 4, 'SAME')

                    trans += trans_res * dilate(self.seg_stack[:, :, :,
                                                               j:j + 1])
                    inv_trans += inv_trans_res * dilate(
                        self.seg_stack[:, :, :, i:i + 1])

                    tf.summary.image('trans%d%d' % (i, i + 1), trans)
                    tf.summary.image('trans%d%d' % (i + 1, i), inv_trans)

                    tf.summary.image('trans_res%d%d' % (i + 1, i),
                                     inv_trans_res)
                    tf.summary.image('trans_res%d%d' % (i, i + 1), trans_res)

                    self.motion_smoothing += _smoothness(trans)
                    self.motion_smoothing += _smoothness(inv_trans)
                    tf.summary.scalar(
                        'trans_stdev',
                        tf.sqrt(0.5 * tf.reduce_mean(
                            tf.square(trans) + tf.square(inv_trans))))

                    transformed_depth_j = transform_depth_map.using_motion_vector(
                        depth_j, trans, rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_j, image_j, depth_i, image_i,
                            rot, trans, inv_rot, inv_trans), i, j)

                    transformed_depth_i = transform_depth_map.using_motion_vector(
                        depth_i, inv_trans, inv_rot, intrinsic_mat)

                    add_result_to_loss_and_summaries(
                        consistency_losses.rgbd_and_motion_consistency_loss(
                            transformed_depth_i, image_i, depth_j, image_j,
                            inv_rot, inv_trans, rot, trans), j, i)

            # Build the total loss as composed of L1 reconstruction, SSIM, smoothing
            # and object size constraint loss as appropriate.
            self.reconstr_loss *= self.reconstr_weight
            self.export('self.reconstr_loss', self.reconstr_loss)
            self.total_loss = self.reconstr_loss
            if self.smooth_weight > 0:
                self.smooth_loss *= self.smooth_weight
                self.total_loss += self.smooth_loss
                self.export('self.smooth_loss', self.smooth_loss)
            if self.ssim_weight > 0:
                self.ssim_loss *= self.ssim_weight
                self.total_loss += self.ssim_loss
                self.export('self.ssim_loss', self.ssim_loss)

            if self.motion_smoothing_weight > 0:
                self.motion_smoothing *= self.motion_smoothing_weight
                self.total_loss += self.motion_smoothing
                self.export('self.motion_sm_loss', self.motion_smoothing)

            if self.depth_consistency_loss_weight:
                self.depth_consistency_loss *= self.depth_consistency_loss_weight
                self.total_loss += self.depth_consistency_loss
                self.export('self.depth_consistency_loss',
                            self.depth_consistency_loss)

            self.rot_loss *= self.rotation_consistency_weight
            self.trans_loss *= self.translation_consistency_weight
            self.export('rot_loss', self.rot_loss)
            self.export('trans_loss', self.trans_loss)

            self.total_loss += self.rot_loss
            self.total_loss += self.trans_loss

            self.export('self.total_loss', self.total_loss)
Example #8
def run_pathnet_training_and_evaluation(
    task_names,
    task_data,
    input_data_shape,
    training_hparams,
    components_layers,
    evaluate_on,
    summary_dir,
    resume_checkpoint_dir=None,
    save_checkpoint_every_n_steps=250,
    intermediate_eval_steps=[]):
  """Trains and evaluates a PathNet multitask image classification model.

  Args:
    task_names: (list of strings) names of tasks.
    task_data: (list of dicts) list of dictionaries, one per task.
      Each dictionary should map strings into `tf.data.Dataset`s.
      The `i`-th dictionary should contain all dataset splits (such as 'train',
      'test', 'eval', etc) for the `i`-th task. The splits can be arbitrary,
      but to run the training, every dataset should contain a 'train' split.
    input_data_shape: (sequence of ints) expected shape of input images
      (excluding batch dimension). For example, for the MNIST dataset
      `input_data_shape=[28, 28, 1]`.
    training_hparams: (tf.contrib.training.HParams) training hyperparameters.
    components_layers: (list of `pn.ComponentsLayer`s) layers that make up
      the PathNet model.
    evaluate_on: (list of strings) dataset splits on which the trained PathNet
      should be evaluated. These keys should be present in every dictionary
      in `task_data`.
    summary_dir: (string) directory for the summary writer.
    resume_checkpoint_dir: (string or None) directory for the checkpoint
      to reload, or None if should start from scratch.
    save_checkpoint_every_n_steps: (int) frequency for saving model checkpoints.
    intermediate_eval_steps: (list of ints) training step numbers at which
      accuracy should be evaluated. An evaluation after the last step is
      always performed.
  """
  session = tf.Session(graph=tf.get_default_graph())

  summary_writer = tf.contrib.summary.create_file_writer(summary_dir)
  summary_writer.set_as_default()

  num_tasks = len(task_names)

  # Every `num_tasks` subsequent steps contain exactly one step for each task,
  # and always in the order as they appear in `task_data`. Setting the logging
  # frequency to `num_tasks + 1` (or any other number coprime with `num_tasks`)
  # guarantees that each task will get to record summaries with the same
  # frequency.
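  # For example, with num_tasks = 3 and logging frequency 4, summaries fire at
  # global steps 0, 4, 8, 12, ..., whose task ids (step mod 3) cycle through
  # 0, 1, 2, 0, ... evenly, assuming the per-step task order stays fixed.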
  with tf.contrib.summary.record_summaries_every_n_global_steps(num_tasks + 1):
    pathnet = pn.PathNet(components_layers, training_hparams)
    num_steps = training_hparams.num_steps

    eval_steps = intermediate_eval_steps + [num_steps]

    # Repeat each training dataset forever.
    train_data = [
        dataset['train'].repeat().make_one_shot_iterator().get_next()
        for dataset in task_data
    ]

    # Attach the task id to each dataset.
    train_data = list(enumerate(train_data))

    p_inputs = tf.placeholder(tf.float32, shape=[None] + input_data_shape)
    p_labels = tf.placeholder(tf.int64, shape=[None])
    p_task_id = tf.placeholder(tf.int32, shape=[], name='task_id')

    train_step_op, _ = build_pathnet_graph(
        p_inputs, p_labels, p_task_id, pathnet, training=True)

  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    _, out_logits_eval = build_pathnet_graph(
        p_inputs, p_labels, p_task_id, pathnet, training=False)

  session.run(tf.global_variables_initializer())
  tf.contrib.summary.initialize(session=session)

  saver = tf.train.Saver(tf.global_variables())

  start_step = 0

  p_task_accuracies = {}
  accuracy_summary_op = {}

  for data_split in evaluate_on:
    p_task_accuracies[data_split], accuracy_summary_op[data_split] = \
      create_accuracy_summary_ops(
          task_names, summary_name_prefix='final_eval_%s' % data_split)

  if resume_checkpoint_dir is not None:
    print('Resuming from checkpoint: %s' % resume_checkpoint_dir)

    last_global_step = int(resume_checkpoint_dir.split('-')[-1])

    assert last_global_step % num_tasks == 0
    start_step = last_global_step // num_tasks

    saver.restore(session, resume_checkpoint_dir)

  for dataset in task_data:
    for data_split in evaluate_on:
      num_batches = count_batches(
          session, dataset[data_split].make_one_shot_iterator().get_next())

      dataset[data_split] = dataset[data_split].repeat()
      dataset[data_split] = (
          dataset[data_split].make_one_shot_iterator().get_next())

      dataset[data_split] = (dataset[data_split], num_batches)

  for step in tqdm(range(start_step, num_steps)):
    random.shuffle(train_data)

    run_pathnet_training_step(
        session, p_inputs, p_labels, p_task_id, train_step_op, train_data)

    if step + 1 in eval_steps:
      for data_split in evaluate_on:
        eval_data = [dataset[data_split] for dataset in task_data]

        print('Running evaluation on: %s' % data_split)

        task_accuracies = run_pathnet_evaluation(
            session=session,
            p_inputs=p_inputs,
            p_task_id=p_task_id,
            out_logits=out_logits_eval,
            task_names=task_names,
            eval_data=eval_data)

        run_accuracy_summary_ops(
            session,
            p_task_accuracies[data_split],
            task_accuracies,
            accuracy_summary_op[data_split])

    if (step + 1) % save_checkpoint_every_n_steps == 0:
      path = summary_dir + '/chkpt'
      saver.save(
          session, path, global_step=tf.train.get_or_create_global_step())
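
The evaluation graph above is built inside tf.variable_scope(tf.get_variable_scope(), reuse=True) so that it shares every weight with the training graph. This is the core TF 1.x reuse idiom; a minimal self-contained sketch of it (the build helper is illustrative, and the 1.x API is reached through tensorflow.compat.v1 so the snippet runs on current TensorFlow):

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def build(x):
    # Creates (or, under reuse, looks up) "net/dense/kernel" and
    # "net/dense/bias".
    with tf.variable_scope('net'):
        return tf.layers.dense(x, 10, name='dense')

x = tf.placeholder(tf.float32, shape=[None, 4])
train_out = build(x)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    eval_out = build(x)  # shares weights; no new variables are created

# Without the reuse wrapper, the second build(x) would raise, because
# "net/dense/kernel" already exists.
assert len(tf.trainable_variables()) == 2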
Example #9
def main(argv):
  tf.logging.set_verbosity(tf.logging.INFO)
  trainer_lib.set_random_seed(FLAGS.random_seed)
  usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
  t2t_trainer.maybe_log_registry_and_exit()

  if FLAGS.cloud_mlengine:
    cloud_mlengine.launch()
    return

  if FLAGS.generate_data:
    t2t_trainer.generate_data()

  if cloud_mlengine.job_dir():
    FLAGS.output_dir = cloud_mlengine.job_dir()

  if argv:
    t2t_trainer.set_hparams_from_args(argv[1:])

  if FLAGS.surrogate_attack:
    tf.logging.warn("Performing surrogate model attack.")
    sur_hparams = create_surrogate_hparams()
    trainer_lib.add_problem_hparams(sur_hparams, FLAGS.problem)

  hparams = t2t_trainer.create_hparams()
  trainer_lib.add_problem_hparams(hparams, FLAGS.problem)

  attack_params = create_attack_params()
  attack_params.add_hparam(attack_params.epsilon_name, 0.0)

  if FLAGS.surrogate_attack:
    sur_config = create_surrogate_run_config(sur_hparams)
  config = t2t_trainer.create_run_config(hparams)
  params = {
      "batch_size": hparams.batch_size,
      "use_tpu": FLAGS.use_tpu,
  }

  # add "_rev" as a hack to avoid image standardization
  problem = registry.problem(FLAGS.problem + "_rev")

  inputs, labels, features = prepare_data(problem, hparams, params, config)

  sess = tf.Session()

  if FLAGS.surrogate_attack:
    sur_model_fn = t2t_model.T2TModel.make_estimator_model_fn(
        FLAGS.surrogate_model, sur_hparams, use_tpu=FLAGS.use_tpu)
    sur_ch_model = adv_attack_utils.T2TAttackModel(
        sur_model_fn, features, params, sur_config, scope="surrogate")
    # Dummy call to construct graph
    sur_ch_model.get_probs(inputs)

    checkpoint_path = os.path.expanduser(FLAGS.surrogate_output_dir)
    tf.train.init_from_checkpoint(
        tf.train.latest_checkpoint(checkpoint_path), {"/": "surrogate/"})
    sess.run(tf.global_variables_initializer())

  other_vars = set(tf.global_variables())

  model_fn = t2t_model.T2TModel.make_estimator_model_fn(
      FLAGS.model, hparams)
  ch_model = adv_attack_utils.T2TAttackModel(model_fn, features, params, config)

  acc_mask = None
  probs = ch_model.get_probs(inputs)
  if FLAGS.ignore_incorrect:
    preds = tf.argmax(probs, -1, output_type=labels.dtype)
    preds = tf.reshape(preds, labels.shape)
    acc_mask = tf.to_float(tf.equal(labels, preds))
  one_hot_labels = tf.one_hot(labels, probs.shape[-1])

  if FLAGS.surrogate_attack:
    attack = create_attack(attack_params.attack)(sur_ch_model, sess=sess)
  else:
    attack = create_attack(attack_params.attack)(ch_model, sess=sess)

  new_vars = set(tf.global_variables()) - other_vars

  # Restore weights
  saver = tf.train.Saver(new_vars)
  checkpoint_path = os.path.expanduser(FLAGS.output_dir)
  saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))

  # reuse variables
  tf.get_variable_scope().reuse_variables()

  def compute_accuracy(x, l, mask):
    """Compute model accuracy."""
    preds = ch_model.get_probs(x)
    preds = tf.squeeze(preds)
    preds = tf.argmax(preds, -1, output_type=l.dtype)

    _, acc_update_op = tf.metrics.accuracy(l, preds, weights=mask)

    if FLAGS.surrogate_attack:
      preds = sur_ch_model.get_probs(x)
      preds = tf.squeeze(preds)
      preds = tf.argmax(preds, -1, output_type=l.dtype)
      acc_update_op = tf.tuple((acc_update_op,
                                tf.metrics.accuracy(l, preds, weights=mask)[1]))

    sess.run(tf.initialize_local_variables())
    for i in range(FLAGS.eval_steps):
      tf.logging.info(
          "\tEvaluating batch [%d / %d]" % (i + 1, FLAGS.eval_steps))
      acc = sess.run(acc_update_op)
    if FLAGS.surrogate_attack:
      tf.logging.info("\tFinal acc: (%.4f, %.4f)" % (acc[0], acc[1]))
    else:
      tf.logging.info("\tFinal acc: %.4f" % acc)
    return acc

  epsilon_acc_pairs = []
  for epsilon in attack_params.attack_epsilons:
    tf.logging.info("Attacking @ eps=%.4f" % epsilon)
    attack_params.set_hparam(attack_params.epsilon_name, epsilon)
    adv_x = attack.generate(inputs, y=one_hot_labels, **attack_params.values())
    acc = compute_accuracy(adv_x, labels, acc_mask)
    epsilon_acc_pairs.append((epsilon, acc))

  for epsilon, acc in epsilon_acc_pairs:
    if FLAGS.surrogate_attack:
      tf.logging.info(
          "Accuracy @ eps=%.4f: (%.4f, %.4f)" % (epsilon, acc[0], acc[1]))
    else:
      tf.logging.info("Accuracy @ eps=%.4f: %.4f" % (epsilon, acc))
Example #10
    def __process(self, all_frames, all_actions, all_rewards, all_raw_frames):
        """Main video processing function."""
        hparams = self.hparams
        all_frames_copy = [tf.identity(frame) for frame in all_frames]
        orig_frame_shape = common_layers.shape_list(all_frames[0])
        batch_size = orig_frame_shape[0]
        ss_func = self.get_scheduled_sample_func(batch_size)
        target_frames = []
        extra_loss = 0.0

        # Any extra info required by the model goes in here.
        video_features = self.video_features(all_frames, all_actions,
                                             all_rewards, all_raw_frames)

        num_frames = len(all_frames)
        if self.is_recurrent_model:
            input_index_range = range(num_frames - 1)
        else:
            input_index_range = range(hparams.video_num_target_frames)

        # Set up the internal states as well as an auxiliary tf op
        # to enforce synchronization between prediction steps.
        if self.internal_states is None:
            internal_states = None
            sync_op = tf.no_op()
        else:
            internal_states = self.load_internal_states_ops()
            with tf.control_dependencies(flat_lists(internal_states)):
                sync_op = tf.no_op()

        res_frames, sampled_frames, res_rewards, res_policies, res_values = \
            [], [], [], [], []
        for i in input_index_range:
            with tf.control_dependencies([sync_op]):
                frames, actions, rewards, target_index = self.__get_next_inputs(
                    i, all_frames, all_actions, all_rewards)
                target_frame = all_frames[target_index]
                target_frames.append(tf.identity(target_frame))

                with tf.variable_scope(tf.get_variable_scope(),
                                       reuse=tf.AUTO_REUSE):
                    float_frames = [tf.to_float(frame) for frame in frames]
                    func_out = self.next_frame(float_frames, actions, rewards,
                                               tf.to_float(target_frame),
                                               internal_states, video_features)
                    res_frame, res_reward, res_policy, res_value, res_extra_loss, \
                        internal_states = func_out
                    res_frames.append(res_frame)
                    res_rewards.append(res_reward)
                    res_policies.append(res_policy)
                    res_values.append(res_value)
                    extra_loss += res_extra_loss / float(
                        len(input_index_range))

                    # Synchronizing the internal states.
                    # Some TensorFlow magic to make sure everything happens
                    # as it should.
                    with tf.control_dependencies([res_frame]):
                        sync_op = tf.no_op()
                        if self.is_predicting and self.is_recurrent_model and i == 0:
                            # The internal state save happens at the end of the 1st iteration
                            # which essentially allows recurrent models to continue
                            # running after one prediction.
                            # Necessary for planning/rl applications.
                            save_ops = self.save_internal_states_ops(
                                internal_states)
                            with tf.control_dependencies(flat_lists(save_ops)):
                                sync_op = tf.no_op()

                # Only for Softmax loss: sample frame so we can keep iterating.
                sampled_frame = self.get_sampled_frame(res_frame)
                sampled_frames.append(sampled_frame)

                # Check whether we are done with context frames or not
                if self.is_recurrent_model:
                    done_warm_start = (i >= hparams.video_num_input_frames - 1)
                else:
                    done_warm_start = True  # Always true for non-recurrent networks.

                if self.is_predicting and done_warm_start:
                    all_frames[target_index] = sampled_frame

                # Scheduled sampling during training.
                if self.is_training:
                    groundtruth_items = [tf.to_float(target_frame)]
                    generated_items = [sampled_frame]
                    ss_frame, = self.get_scheduled_sample_inputs(
                        done_warm_start, groundtruth_items, generated_items,
                        ss_func)
                    all_frames[target_index] = ss_frame

        video_extra_loss = self.video_extra_loss(sampled_frames, target_frames,
                                                 internal_states,
                                                 video_features)
        tf.summary.scalar("video_extra_loss", video_extra_loss)
        extra_loss += video_extra_loss

        if self.is_recurrent_model:
            has_input_predictions = hparams.video_num_input_frames > 1
            if self.is_training and hparams.internal_loss and has_input_predictions:
                # add the loss for input frames as well.
                extra_gts = all_frames_copy[1:hparams.video_num_input_frames]
                extra_raw_gts = all_raw_frames[1:hparams.
                                               video_num_input_frames]
                extra_pds = res_frames[:hparams.video_num_input_frames - 1]
                recon_loss = self.get_extra_internal_loss(
                    extra_raw_gts, extra_gts, extra_pds)
                extra_loss += recon_loss
            # Cut the predicted input frames.
            res_frames = res_frames[hparams.video_num_input_frames - 1:]
            res_rewards = res_rewards[hparams.video_num_input_frames - 1:]
            res_policies = res_policies[hparams.video_num_input_frames - 1:]
            res_values = res_values[hparams.video_num_input_frames - 1:]
            sampled_frames = sampled_frames[hparams.video_num_input_frames -
                                            1:]
            target_frames = target_frames[hparams.video_num_input_frames - 1:]

        self.visualize_predictions(sampled_frames,
                                   [tf.to_float(f) for f in target_frames])

        output_frames = tf.stack(res_frames, axis=1)
        targets = output_frames

        if any((self.has_rewards, self.has_policies, self.has_values)):
            targets = {"targets": output_frames}
            if self.has_rewards:
                targets["target_reward"] = tf.stack(res_rewards, axis=1)
            if self.has_policies:
                targets["target_policy"] = tf.stack(res_policies, axis=1)
            if self.has_values:
                targets["target_value"] = tf.stack(res_values, axis=1)

        return targets, extra_loss
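
get_scheduled_sample_inputs above mixes ground-truth and generated frames during training. A common count-based scheduled-sampling scheme looks roughly like the sketch below; the names and the count-based policy are illustrative, not necessarily what this particular model uses:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

def scheduled_sample_count(ground_truth, generated, batch_size,
                           num_ground_truth):
    # Randomly choose which batch entries keep the ground-truth frame and
    # which take the model's own sample, then stitch them back together.
    idx = tf.random_shuffle(tf.range(batch_size))
    gt_idx = tf.gather(idx, tf.range(num_ground_truth))
    gen_idx = tf.gather(idx, tf.range(num_ground_truth, batch_size))
    gt = tf.gather(ground_truth, gt_idx)
    gen = tf.gather(generated, gen_idx)
    return tf.dynamic_stitch([gt_idx, gen_idx], [gt, gen])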
Example #11
    train_data, train_label = read_tfrecords(train_set)
    train_data, train_label = tf.train.shuffle_batch([train_data, train_label],
                                                     batch_size=BATCH_SIZE,
                                                     capacity=6400,
                                                     min_after_dequeue=3200)
    test_data, test_label = read_tfrecords(test_set)
    test_data, test_label = tf.train.shuffle_batch([test_data, test_label],
                                                   batch_size=BATCH_SIZE,
                                                   capacity=3000,
                                                   min_after_dequeue=1500)

train_z, var_dict1 = encode(train_data)
train_logits, var_dict2 = decode(train_z)
train_out = tf.nn.sigmoid(train_logits)
train_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=train_label,
                                            logits=train_logits))
optimizer = tf.train.AdamOptimizer(0.001).minimize(train_loss)

tf.get_variable_scope().reuse_variables()
test_z, _ = encode(test_data)
test_logits, _ = decode(test_z)
test_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=test_label,
                                            logits=test_logits))

var_dict = dict(var_dict1.items() + var_dict2.items())

saver1 = tf.train.Saver(var_list=var_dict)
saver2 = tf.train.Saver(max_to_keep=2)
sess = tf.Session(config=tf.ConfigProto(log_device_placement=True))

sess.run(tf.global_variables_initializer())

model_path = tf.train.latest_checkpoint(folder_of_well_trained_model)
Example #12
    def build_sample_graph(self,
                           input_pianorolls=None,
                           outer_masks=None,
                           total_gibbs_steps=None):
        """Builds the tf.while_loop based sampling graph.

    Args:
      input_pianorolls: Optional input pianorolls override. If None, uses the
          pianorolls placeholder.
      outer_masks: Optional input outer_masks override. If None, uses the
          outer_masks placeholder.
      total_gibbs_steps: Optional input total_gibbs_steps override. If None,
          uses the total_gibbs_steps placeholder.
    Returns:
      The output op of the graph.
    """
        if input_pianorolls is None:
            input_pianorolls = self.inputs["pianorolls"]
        if outer_masks is None:
            outer_masks = self.inputs["outer_masks"]

        tt = tf.shape(input_pianorolls)[1]
        sample_steps = tf.to_float(self.inputs["sample_steps"])
        if total_gibbs_steps is None:
            total_gibbs_steps = self.inputs["total_gibbs_steps"]
        temperature = self.inputs["temperature"]

        input_pianorolls = tf.to_float(input_pianorolls)
        outer_masks = self.make_outer_masks(outer_masks, input_pianorolls)

        # Default total_gibbs_steps to num_timesteps * num_instruments if not given.
        total_gibbs_steps = tf.cond(
            tf.equal(total_gibbs_steps, 0),
            lambda: tf.to_float(tt * self.hparams.num_instruments),
            lambda: tf.to_float(total_gibbs_steps))

        # sample_steps is set to total_gibbs_steps if not given.
        sample_steps = tf.cond(
            tf.equal(sample_steps, 0),
            lambda: total_gibbs_steps,
            lambda: tf.to_float(sample_steps))

        def infer_step(pianorolls, step_count):
            """Called by tf.while_loop, takes a Gibbs step."""
            mask_prob = compute_mask_prob_from_yao_schedule(
                step_count, total_gibbs_steps)
            # 1 indicates masked out, 0 means not masked.
            masks = make_bernoulli_masks(tf.shape(pianorolls), mask_prob,
                                         outer_masks)

            logits = self.predict(pianorolls, masks)
            samples = sample_with_temperature(logits, temperature=temperature)

            outputs = pianorolls * (1 - masks) + samples * masks

            check_completion_op = tf.assert_equal(
                tf.where(tf.equal(tf.reduce_max(masks, axis=2), 1.),
                         tf.reduce_max(outputs, axis=2),
                         tf.reduce_max(pianorolls, axis=2)), 1.)
            with tf.control_dependencies([check_completion_op]):
                outputs = tf.identity(outputs)

            step_count += 1
            return outputs, step_count

        current_step = tf.to_float(self.inputs["current_step"])

        # Initializes pianorolls by evaluating the model once to fill in all gaps.
        logits = self.predict(tf.to_float(input_pianorolls), outer_masks)
        samples = sample_with_temperature(logits, temperature=temperature)
        tf.get_variable_scope().reuse_variables()

        self.samples, current_step = tf.while_loop(
            lambda samples, current_step: current_step < sample_steps,
            infer_step, [samples, current_step],
            shape_invariants=[
                tf.TensorShape([None, None, None, None]),
                tf.TensorShape(None),
            ],
            back_prop=False,
            parallel_iterations=1,
            name="coco_while")
        self.samples.set_shape(input_pianorolls.shape)
        return self.samples
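
The Gibbs loop above uses tf.while_loop with back_prop=False and explicit shape_invariants; the invariants declare which dimensions of a loop variable may be unknown or may change between iterations. A toy version of the same setup, with a tensor whose leading dimension grows each step:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

x = tf.zeros([1, 3])
i = tf.constant(0)

def cond(x, i):
    return i < 4

def body(x, i):
    return tf.concat([x, x], axis=0), i + 1  # leading dim doubles each step

x_final, _ = tf.while_loop(
    cond, body, [x, i],
    shape_invariants=[tf.TensorShape([None, 3]), i.get_shape()],
    back_prop=False)

with tf.Session() as sess:
    print(sess.run(x_final).shape)  # (16, 3) after four doublings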
Example #13
def clean_bert_model(model_file,
                     save_file,
                     waste_name_: List[str] = None,
                     num_new_tokens=None,
                     word_embedding_name='word_embeddings',
                     output_bias_name='cls/predictions/output_bias'):
    '''
    Strip the optimizer parameters from a saved BERT-family checkpoint.
    :param model_file: the original ckpt file
    :param save_file: where to save the processed model
    :param waste_name_: extra parameter-name substrings to strip
    :param num_new_tokens: if not None, resize the word-embedding part of the
        weights, which allows a custom vocabulary size
    :param word_embedding_name: name of the word-embedding variable
    :param output_bias_name: name of the output-bias variable
    :return:
    '''
    tf.reset_default_graph()
    var_list = tf.train.list_variables(model_file)
    var_values, var_dtypes = {}, {}

    waste_name = [
        'global_step',
        'adam',
        'Adam',  # for bert
        'lamb',
        'bad_steps',
        'good_steps',
        'loss_scale',  # for nezha
    ]
    if isinstance(waste_name_, list):
        waste_name.extend(waste_name_)

    for (name, shape) in var_list:
        if not any(n in name for n in waste_name):
            var_values[name] = None

    reader = pywrap_tensorflow.NewCheckpointReader(model_file)
    for name in var_values:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        if num_new_tokens is not None and num_new_tokens != tensor.shape[0]:
            if word_embedding_name in name:
                temp_tensor = get_truncated_normal_values(
                    shape=[num_new_tokens, tensor.shape[1]])
                min_size = min(num_new_tokens, tensor.shape[0])
                temp_tensor[:min_size, :] = tensor[:min_size, :]
                tensor = temp_tensor
            elif output_bias_name in name:
                temp_tensor = np.zeros([
                    num_new_tokens,
                ], dtype=tensor.dtype)
                min_size = min(num_new_tokens, tensor.shape[0])
                temp_tensor[:min_size] = tensor[:min_size]
                tensor = temp_tensor
        var_values[name] = tensor

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
            for v in var_values
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]

    saver = tf.train.Saver(tf.all_variables())

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                               six.iteritems(var_values)):
            sess.run(assign_op, {p: value})

        # Use the built saver to save the averaged checkpoint.
        saver.save(sess, save_file)
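
A hypothetical call, with placeholder paths and vocabulary size: since Adam keeps two slot variables per weight, stripping them typically shrinks a checkpoint to roughly a third of its original size.

clean_bert_model(model_file='bert/model.ckpt-100000',
                 save_file='bert_slim/model.ckpt',
                 num_new_tokens=21128)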
Example #14
 def __init__(self, object_factory):
   self._object_factory = object_factory
   self._wrapped_object = self._object_factory()
   self._variable_scope = tf.get_variable_scope()
   self._captured_calls = {}
   self._captured_attrs = {}
Example #15
 def build_validator(self):
     with self.graph.as_default():
         iterator, val_image, val_segmentation = self.build_inputs(self.val_data_path)
         tf.get_variable_scope().reuse_variables()
         probs = model.forward_network_softmax(val_image, self.bright_dark)
     return iterator, probs, val_segmentation, val_image
Example #16
  def model_fn(self, features, labels, mode, params):
    """TPUEstimator compatible model function."""
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    data_shape = features.get_shape().as_list()[1:]
    data_shape[0] = int(data_shape[0] / 2)
    features_1 = features[:, :data_shape[0], :, :]
    features_2 = features[:, data_shape[0]:, :, :]
    with tf.variable_scope(
        tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      z_mean, z_logvar = self.gaussian_encoder(features_1,
                                               is_training=is_training)
      z_mean_2, z_logvar_2 = self.gaussian_encoder(features_2,
                                                   is_training=is_training)
    labels = tf.squeeze(tf.one_hot(labels, z_mean.get_shape().as_list()[1]))
    kl_per_point = compute_kl(z_mean, z_mean_2, z_logvar, z_logvar_2)

    new_mean = 0.5 * z_mean + 0.5 * z_mean_2
    var_1 = tf.exp(z_logvar)
    var_2 = tf.exp(z_logvar_2)
    new_log_var = tf.math.log(0.5*var_1 + 0.5*var_2)

    mean_sample_1, log_var_sample_1 = self.aggregate(
        z_mean, z_logvar, new_mean, new_log_var, labels, kl_per_point)
    mean_sample_2, log_var_sample_2 = self.aggregate(
        z_mean_2, z_logvar_2, new_mean, new_log_var, labels, kl_per_point)
    z_sampled_1 = self.sample_from_latent_distribution(
        mean_sample_1, log_var_sample_1)
    z_sampled_2 = self.sample_from_latent_distribution(
        mean_sample_2, log_var_sample_2)
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
      reconstructions_1 = self.decode(z_sampled_1, data_shape, is_training)
      reconstructions_2 = self.decode(z_sampled_2, data_shape, is_training)
    per_sample_loss_1 = losses.make_reconstruction_loss(
        features_1, reconstructions_1)
    per_sample_loss_2 = losses.make_reconstruction_loss(
        features_2, reconstructions_2)
    reconstruction_loss_1 = tf.reduce_mean(per_sample_loss_1)
    reconstruction_loss_2 = tf.reduce_mean(per_sample_loss_2)
    reconstruction_loss = (0.5 * reconstruction_loss_1 +
                           0.5 * reconstruction_loss_2)
    kl_loss_1 = vae.compute_gaussian_kl(mean_sample_1, log_var_sample_1)
    kl_loss_2 = vae.compute_gaussian_kl(mean_sample_2, log_var_sample_2)
    kl_loss = 0.5 * kl_loss_1 + 0.5 * kl_loss_2
    regularizer = self.regularizer(
        kl_loss, None, None, None)

    loss = tf.add(reconstruction_loss,
                  regularizer,
                  name="loss")
    elbo = tf.add(reconstruction_loss, kl_loss, name="elbo")
    if mode == tf.estimator.ModeKeys.TRAIN:
      optimizer = optimizers.make_vae_optimizer()
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      train_op = optimizer.minimize(
          loss=loss, global_step=tf.train.get_global_step())
      train_op = tf.group([train_op, update_ops])
      tf.summary.scalar("reconstruction_loss", reconstruction_loss)
      tf.summary.scalar("elbo", -elbo)
      logging_hook = tf.train.LoggingTensorHook(
          {
              "loss": loss,
              "reconstruction_loss": reconstruction_loss,
              "elbo": -elbo,
          },
          every_n_iter=100)
      return TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=train_op,
          training_hooks=[logging_hook])
    elif mode == tf.estimator.ModeKeys.EVAL:
      return TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=(make_metric_fn("reconstruction_loss", "elbo",
                                       "regularizer", "kl_loss"),
                        [reconstruction_loss, -elbo, regularizer, kl_loss]))
    else:
      raise NotImplementedError("Unsupported mode: %s" % mode)
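
Both this example and Example #1 delegate to compute_gaussian_kl. For a diagonal Gaussian posterior against a standard-normal prior, the KL term has the usual closed form; a sketch of such a helper:

import tensorflow.compat.v1 as tf

def compute_gaussian_kl(z_mean, z_logvar):
    # KL(N(mu, sigma^2) || N(0, I)), summed over latent dimensions and
    # averaged over the batch.
    kl_per_point = 0.5 * tf.reduce_sum(
        tf.square(z_mean) + tf.exp(z_logvar) - z_logvar - 1.0, axis=1)
    return tf.reduce_mean(kl_per_point, name="kl_loss")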
Example #17
 def __init__(self, name, *args, **kwargs):
     with tf.variable_scope(name):
         self._init(*args, **kwargs)
         self.scope = tf.get_variable_scope().name
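
Recording tf.get_variable_scope().name at construction time lets the object fetch its own variables later by scope prefix. A typical companion method (hypothetical, not part of the original class):

 def get_trainable_variables(self):
     # All trainable variables created under this object's scope.
     return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                              scope=self.scope)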
Example #18
def clean_bert_model(model_file,
                     save_file,
                     remove_ori_model=False,
                     waste_name_: List[str] = None):
    '''
    Strip the optimizer parameters from a saved BERT-family checkpoint.
    :param model_file: the original ckpt file
    :param save_file: where to save the processed model
    :param remove_ori_model: whether to delete the original model files
    :param waste_name_: extra parameter-name substrings to strip
    :return:
    '''
    tf.reset_default_graph()
    var_list = tf.train.list_variables(model_file)
    var_values, var_dtypes = {}, {}

    waste_name = [
        'global_step',
        'adam',
        'Adam',  # for bert
        'lamb',
        'bad_steps',
        'good_steps',
        'loss_scale',  # for nezha
    ]
    if isinstance(waste_name_, list):
        waste_name.extend(waste_name_)

    for (name, shape) in var_list:
        if not any(n in name for n in waste_name):
            var_values[name] = None

    reader = contrib.framework.load_checkpoint(model_file)
    for name in var_values:
        tensor = reader.get_tensor(name)
        var_dtypes[name] = tensor.dtype
        var_values[name] = tensor

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
            for v in var_values
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for (v, p) in zip(tf_vars, placeholders)]

    saver = tf.train.Saver(tf.all_variables())

    # Remove the original model files
    if remove_ori_model:
        dir, filename = os.path.split(model_file)
        for file in os.listdir(dir):
            file = os.path.join(dir, file)
            if model_file in file:
                os.remove(file)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        for p, assign_op, (name, value) in zip(placeholders, assign_ops,
                                               six.iteritems(var_values)):
            sess.run(assign_op, {p: value})

        # Use the built saver to save the averaged checkpoint.
        saver.save(sess, save_file)
Example #19
    def test_multitower_examples_model(self):
        """Ensure graph search runs properly on a multitower setup.

    This test uses linear_model from examples/convnets.
    """
        with tf.Graph().as_default():

            def linear_model(images, labels, num_classes):
                """Creates a linear model.

        Args:
          images: The input image tensors, a tensor of size
              (batch_size x height_in x width_in x channels).
          labels: The sparse target labels, a tensor of size (batch_size x 1).
          num_classes: The number of classes, needed for one-hot encoding (int).

        Returns:
          loss: The total loss for this model (0-D tensor).
          logits: Predictions for this model (batch_size x num_classes).
        """
                images = tf.reshape(images, [images.shape[0], -1])
                logits = tf.layers.dense(images, num_classes, name='logits')
                loss = sparse_softmax_cross_entropy(labels, logits,
                                                    num_classes)
                return loss, logits

            model = linear_model
            layer_collection = lc.LayerCollection()
            num_towers = 2
            batch_size = num_towers
            num_classes = 2

            # Set up data.
            images = tf.random_uniform(shape=[batch_size, 32, 32, 1])
            labels = tf.random_uniform(dtype=tf.int64,
                                       shape=[batch_size, 1],
                                       maxval=num_classes)

            tower_images = tf.split(images, num_towers)
            tower_labels = tf.split(labels, num_towers)

            # Build model.
            losses = []
            logits = []
            for tower_id in range(num_towers):
                tower_name = 'tower%d' % tower_id
                with tf.name_scope(tower_name):
                    with tf.variable_scope(tf.get_variable_scope(),
                                           reuse=(tower_id > 0)):
                        current_loss, current_logits = model(
                            tower_images[tower_id], tower_labels[tower_id],
                            num_classes + 1)
                        layer_collection.register_categorical_predictive_distribution(
                            current_logits, name='logits')
                        losses.append(current_loss)
                        logits.append(current_logits)

            # Run the graph scanner.
            with tf.variable_scope(tf.get_variable_scope(),
                                   reuse=tf.AUTO_REUSE):
                gs.register_layers(layer_collection, tf.trainable_variables())
            self.assertEqual(len(layer_collection.fisher_blocks), 1)
            fisher_block = list(layer_collection.fisher_blocks.values())[0]
            self.assertIsInstance(fisher_block, fb.FullyConnectedKFACBasicFB)
            self.assertEqual(fisher_block.num_registered_towers, num_towers)

            global_step = tf.train.get_or_create_global_step()
            opt = optimizer.KfacOptimizer(learning_rate=0.1,
                                          cov_ema_decay=0.1,
                                          damping=0.1,
                                          layer_collection=layer_collection,
                                          momentum=0.1)
            cost = tf.reduce_mean(losses)
            (cov_update_thunks,
             inv_update_thunks) = opt.make_vars_and_create_op_thunks()
            cov_update_op = tf.group(*(thunk() for thunk in cov_update_thunks))
            inv_update_op = tf.group(*(thunk() for thunk in inv_update_thunks))
            train_op = opt.minimize(cost, global_step=global_step)
            init = tf.global_variables_initializer()

            # Run a single training step.
            with self.test_session() as sess:
                sess.run(init)
                sess.run([cov_update_op])
                sess.run([inv_update_op])
                sess.run([train_op])
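
The tower loop above pairs a per-tower name_scope (which only affects op names) with a shared variable_scope whose reuse flag switches on after the first tower, so every tower reads the same weights. Condensed to its essentials:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

outputs = []
for tower_id in range(2):
    with tf.name_scope('tower%d' % tower_id):
        with tf.variable_scope(tf.get_variable_scope(),
                               reuse=(tower_id > 0)):
            x = tf.ones([1, 3])
            outputs.append(tf.layers.dense(x, 2, name='fc'))

# One set of weights despite two towers:
print([v.name for v in tf.trainable_variables()])
# ['fc/kernel:0', 'fc/bias:0']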
Example #20
def main(unused_argv):
  """Builds the graph and then runs training and validation."""
  print('TensorFlow version:', tf.__version__)

  tf.logging.set_verbosity(tf.logging.INFO)

  if FLAGS.data_dir is None:
    tf.logging.fatal('No input directory was provided.')

  print(FLAGS.maskout_method, 'separate', FLAGS.separate_instruments)

  hparams = _hparams_from_flags()

  # Get data.
  print('dataset:', FLAGS.dataset, FLAGS.data_dir)
  print('current dir:', os.path.curdir)
  train_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'train')
  valid_data = lib_data.get_dataset(FLAGS.data_dir, hparams, 'valid')
  print('# of train_data:', train_data.num_examples)
  print('# of valid_data:', valid_data.num_examples)
  if train_data.num_examples < hparams.batch_size:
    print('reducing batch_size to %i' % train_data.num_examples)
    hparams.batch_size = train_data.num_examples

  train_data.update_hparams(hparams)

  # Save hparam configs.
  logdir = os.path.join(FLAGS.logdir, hparams.log_subdir_str)
  tf.gfile.MakeDirs(logdir)
  config_fpath = os.path.join(logdir, 'config')
  tf.logging.info('Writing to %s', config_fpath)
  with tf.gfile.Open(config_fpath, 'w') as p:
    hparams.dump(p)

  # Build the graph and then run it for training and validation.
  with tf.Graph().as_default():
    no_op = tf.no_op()

    # Build placeholders and training graph, and validation graph with reuse.
    m = lib_graph.build_graph(is_training=True, hparams=hparams)
    tf.get_variable_scope().reuse_variables()
    mvalid = lib_graph.build_graph(is_training=False, hparams=hparams)

    tracker = Tracker(
        label='validation loss',
        patience=FLAGS.patience,
        decay_op=m.decay_op,
        save_path=os.path.join(FLAGS.logdir, hparams.log_subdir_str,
                               'best_model.ckpt'))

    # Graph will be finalized after instantiating supervisor.
    sv = tf.train.Supervisor(
        logdir=logdir,
        saver=tf.train.Supervisor.USE_DEFAULT if FLAGS.log_progress else None,
        summary_op=None,
        save_model_secs=FLAGS.save_model_secs)
    with sv.PrepareSession() as sess:
      epoch_count = 0
      while epoch_count < FLAGS.num_epochs or not FLAGS.num_epochs:
        if sv.should_stop():
          break

        # Run training.
        run_epoch(sv, sess, m, train_data, hparams, m.train_op, 'train',
                  epoch_count)

        # Run validation.
        if epoch_count % hparams.eval_freq == 0:
          estimate_popstats(sv, sess, m, train_data, hparams)
          loss = run_epoch(sv, sess, mvalid, valid_data, hparams, no_op,
                           'valid', epoch_count)
          tracker(loss, sess)
          if tracker.should_stop():
            break

        epoch_count += 1

    print('best', tracker.label, tracker.best)
    print('Done.')
    return tracker.best
Example #21
File: network.py Project: mecha2k/tf-1.15
    def _init_graph(self) -> None:
        # Collect inputs.
        self.input_names = []

        for param in inspect.signature(self._build_func).parameters.values():
            if param.kind == param.POSITIONAL_OR_KEYWORD and param.default is param.empty:
                self.input_names.append(param.name)

        self.num_inputs = len(self.input_names)
        assert self.num_inputs >= 1

        # Choose name and scope.
        if self.name is None:
            self.name = self._build_func_name
        assert re.match("^[A-Za-z0-9_.\\-]*$", self.name)
        with tf.name_scope(None):
            self.scope = tf.get_default_graph().unique_name(self.name, mark_as_used=True)

        # Finalize build func kwargs.
        build_kwargs = dict(self.static_kwargs)
        build_kwargs["is_template_graph"] = True
        build_kwargs["components"] = self.components

        # Build template graph.
        with tfutil.absolute_variable_scope(self.scope, reuse=tf.AUTO_REUSE), tfutil.absolute_name_scope(self.scope):  # ignore surrounding scopes
            assert tf.get_variable_scope().name == self.scope
            assert tf.get_default_graph().get_name_scope() == self.scope
            with tf.control_dependencies(None):  # ignore surrounding control dependencies
                self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
                out_expr = self._build_func(*self.input_templates, **build_kwargs)

        # Collect outputs.
        assert tfutil.is_tf_expression(out_expr) or isinstance(out_expr, tuple)
        self.output_templates = [out_expr] if tfutil.is_tf_expression(out_expr) else list(out_expr)
        self.num_outputs = len(self.output_templates)
        assert self.num_outputs >= 1
        assert all(tfutil.is_tf_expression(t) for t in self.output_templates)

        # Perform sanity checks.
        if any(t.shape.ndims is None for t in self.input_templates):
            raise ValueError("Network input shapes not defined. Please call x.set_shape() for each input.")
        if any(t.shape.ndims is None for t in self.output_templates):
            raise ValueError("Network output shapes not defined. Please call x.set_shape() where applicable.")
        if any(not isinstance(comp, Network) for comp in self.components.values()):
            raise ValueError("Components of a Network must be Networks themselves.")
        if len(self.components) != len(set(comp.name for comp in self.components.values())):
            raise ValueError("Components of a Network must have unique names.")

        # List inputs and outputs.
        self.input_shapes = [tfutil.shape_to_list(t.shape) for t in self.input_templates]
        self.output_shapes = [tfutil.shape_to_list(t.shape) for t in self.output_templates]
        self.input_shape = self.input_shapes[0]
        self.output_shape = self.output_shapes[0]
        self.output_names = [t.name.split("/")[-1].split(":")[0] for t in self.output_templates]

        # List variables.
        self.own_vars = OrderedDict((var.name[len(self.scope) + 1:].split(":")[0], var) for var in tf.global_variables(self.scope + "/"))
        self.vars = OrderedDict(self.own_vars)
        self.vars.update((comp.name + "/" + name, var) for comp in self.components.values() for name, var in comp.vars.items())
        self.trainables = OrderedDict((name, var) for name, var in self.vars.items() if var.trainable)
        self.var_global_to_local = OrderedDict((var.name.split(":")[0], name) for name, var in self.vars.items())
Example #22
def train(args):
    """Train different architectures for a number of epochs."""

    with tf.Graph().as_default(), tf.device('/cpu:0'):

        # Read data from disk
        images, labels = data_loader.read_inputs(True, args)

        epoch_number = tf.get_variable('epoch_number', [],
                                       dtype=tf.int32,
                                       initializer=tf.constant_initializer(0),
                                       trainable=False)

        # Decay the learning rate
        lr = tf.train.piecewise_constant(epoch_number,
                                         args.LR_steps,
                                         args.LR_values,
                                         name='LearningRate')
        # Weight Decay policy
        wd = tf.train.piecewise_constant(epoch_number,
                                         args.WD_steps,
                                         args.WD_values,
                                         name='WeightDecay')

        is_training = args.transfer_mode[0] != 1

        # Create an optimizer that performs gradient descent.
        opt = tf.train.MomentumOptimizer(lr, 0.9)

        # Calculate the gradients for each model tower.
        tower_grads = []
        tower_auxgrads = []
        with tf.variable_scope(tf.get_variable_scope()):
            for i in xrange(args.num_gpus):
                with tf.device('/gpu:%d' % i):
                    with tf.name_scope('Tower_%d' % i) as scope:
                        # Calculate the loss for one tower. This function
                        # constructs the entire model but shares the variables across
                        # all towers.
                        logits = arch.get_model(images, wd, is_training, args)

                        # Top-1 accuracy
                        top1acc = tf.reduce_mean(
                            tf.cast(tf.nn.in_top_k(logits, labels, 1),
                                    tf.float32))
                        # Top-5 accuracy
                        topnacc = tf.reduce_mean(
                            tf.cast(tf.nn.in_top_k(logits, labels, args.top_n),
                                    tf.float32))

                        # Build the portion of the Graph calculating the losses. Note that we will
                        # assemble the total_loss using a custom function below.
                        cross_entropy_mean = loss(logits, labels)

                        # Get all the regularization losses and add them up
                        regularization_losses = tf.get_collection(
                            tf.GraphKeys.REGULARIZATION_LOSSES)

                        reg_loss = tf.add_n(regularization_losses)

                        # Add a TensorBoard summary
                        tf.summary.scalar('Regularization Loss', reg_loss)

                        # Compute the total loss (cross entropy loss + regularization loss)
                        total_loss = tf.add(cross_entropy_mean, reg_loss)

                        # Attach a scalar summary for the total loss and top-1 and top-5 accuracies
                        tf.summary.scalar('Total Loss', total_loss)
                        tf.summary.scalar('Top-1 Accuracy', top1acc)
                        tf.summary.scalar(
                            'Top-' + str(args.top_n) + ' Accuracy', topnacc)

                        # Reuse variables for the next tower.
                        tf.get_variable_scope().reuse_variables()

                        # Retain the summaries from the final tower.
                        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                      scope)

                        # Gather batch normalization update operations
                        batchnorm_updates = tf.get_collection(
                            tf.GraphKeys.UPDATE_OPS, scope)

                        # Calculate the gradients for the batch of data on this CIFAR tower.
                        if args.transfer_mode[0] == 3:
                            grads = opt.compute_gradients(
                                total_loss,
                                var_list=tf.get_collection(
                                    tf.GraphKeys.VARIABLES, scope='output'))
                            auxgrads = opt.compute_gradients(total_loss)
                            tower_auxgrads.append(auxgrads)
                        elif args.transfer_mode[0] == 1:
                            grads = opt.compute_gradients(
                                total_loss,
                                var_list=tf.get_collection(
                                    tf.GraphKeys.VARIABLES, scope='output'))
                        else:
                            grads = opt.compute_gradients(total_loss)

                        # Keep track of the gradients across all towers.
                        tower_grads.append(grads)
        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)
        if args.transfer_mode[0] == 3:
            auxgrads = average_gradients(tower_auxgrads)

        # Add a summary to track the learning rate and weight decay
        summaries.append(tf.summary.scalar('learning_rate', lr))
        summaries.append(tf.summary.scalar('weight_decay', wd))

        # Group all updates to into a single train op.
        #with tf.control_dependencies(bn_update_ops):
        # Setup the train operation
        if args.transfer_mode[0] == 3:
            train_op = tf.cond(
                tf.less(epoch_number, args.transfer_mode[1]),
                lambda: tf.group(opt.apply_gradients(grads),
                                 *batchnorm_updates),
                lambda: tf.group(opt.apply_gradients(auxgrads),
                                 *batchnorm_updates))
        elif args.transfer_mode[0] == 1:
            train_op = opt.apply_gradients(grads)
        else:
            batchnorm_updates_op = tf.group(*batchnorm_updates)
            train_op = tf.group(opt.apply_gradients(grads),
                                batchnorm_updates_op)

        # a loader for loading the pretrained model (it does not load the last layer)
        if args.retrain_from is not None:
            if args.transfer_mode[0] == 0:
                pretrained_loader = tf.train.Saver()
            else:
                pretrained_loader = tf.train.Saver(var_list=exclude())

        # Create a saver.
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=5)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge_all()

        # Build an initialization operation to run below.
        init = tf.global_variables_initializer()

        # Logging the runtime information if requested
        if args.log_debug_info:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
        else:
            run_options = None
            run_metadata = None

        # Creating a session to run the built graph
        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=args.log_device_placement))

        sess.run(init)

        # Continue training from a saved snapshot, load a pre-trained model
        if args.retrain_from is not None:
            ckpt = tf.train.get_checkpoint_state(args.retrain_from)
            if ckpt and ckpt.model_checkpoint_path:
                # Restores from checkpoint
                pretrained_loader.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("checkpoint not found: " + args.retrain_from)
                return

        # Start the queue runners.
        tf.train.start_queue_runners(sess=sess)

        # Setup a summary writer
        summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

        # Set the start epoch number
        start_epoch = sess.run(epoch_number + 1)

        # The main training loop
        for epoch in xrange(start_epoch, start_epoch + args.num_epochs):
            # update epoch_number
            sess.run(epoch_number.assign(epoch))

            # Training batches
            for step in xrange(args.num_batches):

                start_time = time.time()
                _, loss_value, top1_accuracy, topn_accuracy = sess.run(
                    [train_op, cross_entropy_mean, top1acc, topnacc],
                    options=run_options,
                    run_metadata=run_metadata)
                duration = time.time() - start_time

                # Check for errors
                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                # Logging and writing tensorboard summaries
                if step % 10 == 0:
                    num_examples_per_step = args.chunked_batch_size * args.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / args.num_gpus

                    format_str = (
                        '%s: epoch %d, step %d, loss = %.2f, Top-1 = %.2f Top-'
                        + str(args.top_n) + ' = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    print(format_str %
                          (datetime.now(), epoch, step, loss_value,
                           top1_accuracy, topn_accuracy, examples_per_sec,
                           sec_per_batch))
                    sys.stdout.flush()
                if step % 100 == 0:
                    summary_str = sess.run(summary_op)
                    summary_writer.add_summary(summary_str,
                                               args.num_batches * epoch + step)
                    if args.log_debug_info:
                        summary_writer.add_run_metadata(
                            run_metadata, 'epoch%d step%d' % (epoch, step))

            # Save the model checkpoint periodically after each training epoch
            checkpoint_path = os.path.join(args.log_dir, args.snapshot_prefix)
            saver.save(sess, checkpoint_path, global_step=epoch)
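
train() assumes an average_gradients helper that reduces the per-tower (gradient, variable) lists at the synchronization point. A standard implementation of that helper, sketched here under the assumption that no gradient is None:

import tensorflow.compat.v1 as tf

def average_gradients(tower_grads):
    # tower_grads: one list of (grad, var) pairs per tower.
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        grads = [tf.expand_dims(g, 0) for g, _ in grad_and_vars]
        grad = tf.reduce_mean(tf.concat(grads, axis=0), axis=0)
        # Variables are shared across towers, so the first tower's suffices.
        average_grads.append((grad, grad_and_vars[0][1]))
    return average_grads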
Example #23
    def _init_model(self, **kwargs):
        dtypes = (tf.float32, tf.float32)
        output_shapes = ([None, None, None, self.input_size[-1]], None)
        self._init_unprocessing(**kwargs)
        self.denoising_losses = []
        self._make_filters()
        self._set_next_elements(dtypes, output_shapes)
        self._init_vgg_net(**kwargs)
        vgg_input_gt = dict()
        vgg_input_pred = dict()
        with tf.variable_scope(tf.get_variable_scope()):
            for i in range(self.device_offset,
                           self.num_devices + self.device_offset):
                self._curr_device = i
                self._curr_block = None
                self._curr_dependent_op = 0  # For ops with dependencies between GPUs such as BN
                device = '/{}:'.format(self.compute_device) + str(i)
                with tf.device(device):
                    with tf.name_scope(self.compute_device + '_' + str(i) +
                                       '/'):
                        self.X, _ = self.next_elements[device]
                        self.X_in.append(self.X)

                        self.X = self.zero_pad(self.X,
                                               pad_value=self.pad_value)
                        self.X = tf.cond(
                            self.augmentation,
                            lambda: self.augment_images(self.X, **kwargs),
                            lambda: self.center_crop(self.X),
                            name='augmentation')

                        image, bayer_img, noisy_img, variance, metadata = tf.map_fn(
                            self.unprocess_images,
                            self.X,
                            dtype=(tf.float32, tf.float32, tf.float32,
                                   tf.float32, [
                                       tf.float32, tf.float32, tf.float32,
                                       tf.float32, tf.float32, tf.float32
                                   ]),
                            parallel_iterations=32,
                            back_prop=False)

                        self.Y = self.process(image,
                                              metadata[2],
                                              metadata[3],
                                              metadata[0],
                                              simple=self.simple_unprocessing)
                        self.Y.set_shape([None] + list(self.input_size))
                        vgg_input_gt[device] = self.Y

                        self.Y_mosaic = process.process(
                            bayer_img,
                            metadata[2],
                            metadata[3],
                            metadata[0],
                            simple=self.simple_unprocessing)
                        self.Y_mosaic.set_shape([None] + list(self.input_size))
                        noisy = process.process(
                            noisy_img,
                            metadata[2],
                            metadata[3],
                            metadata[0],
                            simple=self.simple_unprocessing)
                        noisy.set_shape([None] + list(self.input_size))
                        self.Xs.append(noisy)
                        self.Ys.append(self.Y)

                        self.X = tf.concat([noisy_img, variance], axis=-1)
                        self.X = tf.math.subtract(self.X,
                                                  self.image_mean,
                                                  name='zero_center')
                        if self.channel_first:
                            self.X = tf.transpose(self.X, perm=[0, 3, 1, 2])
                        if self.dtype is not tf.float32:
                            with tf.name_scope(
                                    '{}/cast/'.format(self.compute_device +
                                                      '_' + str(i))):
                                self.X = tf.cast(self.X, dtype=self.dtype)

                        self._shot_noise_tensor = metadata[4]
                        self._read_noise_tensor = metadata[5]
                        with tf.name_scope(
                                'nn'
                        ) if self.model_scope is None else tf.name_scope(
                                self.model_scope):
                            self.d = self._build_model()
                        if self.dtype is not tf.float32:
                            with tf.name_scope(
                                    '{}/cast/'.format(self.compute_device +
                                                      '_' + str(i))):
                                self.d['pred'] = tf.cast(self.d['pred'],
                                                         dtype=tf.float32)
                        if self.channel_first:
                            self.d['pred'] = tf.transpose(self.d['pred'],
                                                          perm=[0, 2, 3, 1])
                        tf.get_variable_scope().reuse_variables()

                        self.dicts.append(self.d)
                        self.pred = self.process(
                            self.d['pred'],
                            metadata[2],
                            metadata[3],
                            metadata[0],
                            simple=self.simple_unprocessing)
                        self.pred.set_shape([None] + list(self.input_size))
                        if 'denoised' in self.d:
                            if self.dtype is not tf.float32:
                                with tf.name_scope(
                                        '{}/cast/'.format(self.compute_device +
                                                          '_' + str(i))):
                                    self.d['denoised'] = tf.cast(
                                        self.d['denoised'], dtype=tf.float32)
                            if self.channel_first:
                                self.d['denoised'] = tf.transpose(
                                    self.d['denoised'], perm=[0, 2, 3, 1])
                            self.denoised = process.process(
                                self.d['denoised'],
                                metadata[2],
                                metadata[3],
                                metadata[0],
                                simple=self.simple_unprocessing)
                            self.denoised.set_shape([None] +
                                                    list(self.input_size))
                        else:
                            self.denoised = None
                        self.preds.append(self.pred)
                        vgg_input_pred[device] = self.pred

                        self.losses.append(self._build_loss(**kwargs))

        self._build_perceptual_loss(vgg_input_gt, vgg_input_pred, **kwargs)
        self._make_debug_images()
        with tf.device(self.param_device):
            with tf.variable_scope('calc/'):
                self.debug_values.append(
                    tf.reduce_mean(self.denoising_losses,
                                   name='denoising_loss'))
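
A side note on the tf.map_fn call above: when the mapped function returns a structure different from its input, the dtype argument must describe that output structure, which is what the nested dtype list in _init_model does. A toy analogue:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

images = tf.random_uniform([8, 4, 4, 3])
means, maxima = tf.map_fn(
    lambda img: (tf.reduce_mean(img), tf.reduce_max(img)),
    images,
    dtype=(tf.float32, tf.float32),  # output structure != input structure
    parallel_iterations=8,
    back_prop=False)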
Example #24
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    processors = {
        "sst-2": rc.SST2Processor,
        "mnli": rc.MnliProcessor,
    }

    tokenization.validate_case_matches_checkpoint(FLAGS.do_lower_case,
                                                  FLAGS.init_checkpoint)
    bert_config = modeling.BertConfig.from_json_file(FLAGS.bert_config_file)

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.do_lower_case)

    task_name = FLAGS.task_name.lower()
    processor = processors[task_name]()
    num_labels = len(processor.get_labels())

    # This flexible input variable will be optimized to carry out model inversion
    flex_input = tf.get_variable(name="flex_input",
                                 shape=[
                                     FLAGS.train_batch_size,
                                     FLAGS.max_seq_length,
                                     bert_config.hidden_size
                                 ])

    # Use the input template to mix original embeddings with flex input embeddings
    # Different segments in template are separated by " <piece> "
    # Each segment is associated with a word piece (or [EMPTY] to get flex inputs)
    # and a frequency. (which is separated by "<freq>"). * can be used to choose a
    # frequency till the end of the string
    #
    # Here is an example 2-sequence template for tasks like MNLI to optimize
    # 20 vectors, (10 for each sequence)
    # [CLS]<freq>1 <piece> [EMPTY]<freq>10 <piece> [SEP]<freq>1 <piece> \
    # [EMPTY]<freq>10 <piece> [SEP]<freq>1 <piece> [PAD]<freq>*
    (input_tensor, embed_var, flex_input_mask, bert_input_mask,
     token_type_ids) = em_util.template_to_input_tensor(
         template=FLAGS.input_template,
         flex_input=flex_input,
         config=bert_config,
         tokenizer=tokenizer,
         max_seq_length=FLAGS.max_seq_length)

    # Get the nearest neighbours of the input tensor
    # Useful for converting input tensor back to a string representation
    nearest_neighbours, cos_sim = em_util.get_nearest_neighbour(
        source=input_tensor, reference=embed_var)

    # Convert the nearest neighbours back into embeddings. This is done since text
    # is discrete, and we want to create actual textual outputs.
    nn_embeddings, _ = em_util.run_bert_embeddings(
        input_ids=nearest_neighbours, config=bert_config)

    mean_masked_cos_sim = tf.reduce_mean(
        tf.boolean_mask(cos_sim, flex_input_mask))

    # With this probability vector, a custom cross-entropy goal can be specified.
    # When this is used, the inputs are optimized to encourage the classifier to
    # produce a softmax output similar to prob_vector.
    prob_vector = tf.constant(
        [[float(x) for x in FLAGS.prob_vector.split(",")]])

    model_fn_partial = functools.partial(em_util.model_fn,
                                         bert_input_mask=bert_input_mask,
                                         token_type_ids=token_type_ids,
                                         bert_config=bert_config,
                                         num_labels=num_labels,
                                         obj_type=FLAGS.obj_type,
                                         prob_vector=prob_vector)

    parent_scope = tf.get_variable_scope()
    with tf.variable_scope(parent_scope):
        flex_input_obj, _, _ = model_fn_partial(input_tensor=input_tensor)

    if FLAGS.obj_type[:3] == "max":
        flex_input_loss = -1 * flex_input_obj
    elif FLAGS.obj_type[:3] == "min":
        flex_input_loss = flex_input_obj
    else:
        raise ValueError("obj_type must start with 'max' or 'min'")

    with tf.variable_scope(parent_scope, reuse=True):
        nn_input_obj, _, _ = model_fn_partial(input_tensor=nn_embeddings)

    opt = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    invert_op = opt.minimize(flex_input_loss, var_list=[flex_input])

    tvars = tf.trainable_variables()

    assignment_map, _ = modeling.get_assignment_map_from_checkpoint(
        tvars, FLAGS.init_checkpoint)

    tf.logging.info("Variables mapped = %d / %d", len(assignment_map),
                    len(tvars))

    tf.train.init_from_checkpoint(FLAGS.init_checkpoint, assignment_map)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Optimize part of the flex_input (which part depends on the template)
    for i in range(FLAGS.total_steps):
        fio, nio, _, mcs = sess.run(
            [flex_input_obj, nn_input_obj, invert_op, mean_masked_cos_sim])
        tf.logging.info(
            "Step %d / %d. flex-input obj = %.4f, nn obj = %.4f, cos sim = %.4f",
            i, FLAGS.total_steps, fio, nio, mcs)

    # Find nearest neighbours for the final optimized vectors
    batched_nn, batched_nn_sim = sess.run([nearest_neighbours, cos_sim])

    for nn, _ in zip(batched_nn, batched_nn_sim):
        tf.logging.info("Sentence = %s", em_util.detokenize(nn, tokenizer))

    return
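
The "<piece>"/"<freq>" template format documented above is a small DSL of its own. Below is a minimal, hypothetical parser, not the actual em_util.template_to_input_tensor, that only shows how a template string expands into a flat list of word pieces:

def expand_template(template, max_seq_length):
    # Illustrative only: expand "[CLS]<freq>1 <piece> [PAD]<freq>*" style
    # templates into a list of word pieces of length max_seq_length.
    pieces = []
    for segment in template.split(" <piece> "):
        piece, _, freq = segment.partition("<freq>")
        if freq == "*":
            # "*" repeats the piece until max_seq_length is reached.
            pieces.extend([piece] * (max_seq_length - len(pieces)))
        else:
            pieces.extend([piece] * int(freq))
    return pieces[:max_seq_length]

# The 2-sequence MNLI-style template from the comment above:
template = ("[CLS]<freq>1 <piece> [EMPTY]<freq>10 <piece> [SEP]<freq>1"
            " <piece> [EMPTY]<freq>10 <piece> [SEP]<freq>1 <piece> [PAD]<freq>*")
print(expand_template(template, max_seq_length=32))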
Example #25
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               parallel_iterations=10,
               name=None):
    """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the
  first dimension, accessed in order. Additionally, it calls `trace_fn` on the
  return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are
  stacked and returned from this function, such that the first dimension of
  those `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and returns
      a `Tensor` or a nested collection of `Tensor`s.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
    with tf1.name_scope(name, 'trace_scan',
                        [initial_state, elems]), tf1.variable_scope(
                            tf1.get_variable_scope()) as vs:
        if vs.caching_device is None and not tf.executing_eagerly():
            vs.set_caching_device(lambda op: op.device)

        initial_state = tf.nest.map_structure(
            lambda x: tf.convert_to_tensor(value=x, name='initial_state'),
            initial_state)
        elems = tf.convert_to_tensor(value=elems, name='elems')

        static_length = elems.shape[0]
        if tf.compat.dimension_value(static_length) is None:
            length = tf.shape(input=elems)[0]
        else:
            length = tf.convert_to_tensor(value=static_length,
                                          dtype=tf.int32,
                                          name='length')

        # This is a TensorArray in part because of XLA, which had trouble with
        # non-statically known indices. I.e. elems[i] errored, but
        # elems_array.read(i) worked.
        elems_array = tf.TensorArray(elems.dtype,
                                     size=length,
                                     element_shape=elems.shape[1:])
        elems_array = elems_array.unstack(elems)

        trace_arrays = tf.nest.map_structure(
            lambda x: tf.TensorArray(
                x.dtype, size=length, element_shape=x.shape),
            trace_fn(initial_state))

        def _body(i, state, trace_arrays):
            state = loop_fn(state, elems_array.read(i))
            trace_arrays = tf.nest.pack_sequence_as(trace_arrays, [
                a.write(i, v) for a, v in zip(tf.nest.flatten(trace_arrays),
                                              tf.nest.flatten(trace_fn(state)))
            ])
            return i + 1, state, trace_arrays

        _, final_state, trace_arrays = tf.while_loop(
            cond=lambda i, *args: i < length,
            body=_body,
            loop_vars=(0, initial_state, trace_arrays),
            parallel_iterations=parallel_iterations)

        stacked_trace = tf.nest.map_structure(lambda x: x.stack(),
                                              trace_arrays)

        # Restore the static length if we know it.
        def _merge_static_length(x):
            x.set_shape(tf.TensorShape(static_length).concatenate(x.shape[1:]))
            return x

        stacked_trace = tf.nest.map_structure(_merge_static_length,
                                              stacked_trace)
        return final_state, stacked_trace
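
A minimal usage sketch (assuming the `tf`/`tf1` compat imports this module already uses): accumulate a running sum over `elems` while tracing both the running sum and its square at every step.

elems = tf.constant([1., 2., 3., 4.])
final_state, trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0.),
    elems=elems,
    trace_fn=lambda state: {'sum': state, 'square': state ** 2})
# final_state == 10.0
# trace['sum'] == [1., 3., 6., 10.], trace['square'] == [1., 9., 36., 100.]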
Example #26
def main(unused_argv: Any) -> None:
    tf.logging.info("Saving model saves and results to " + FLAGS.model_dir)

    global_seed(42)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train`, `do_eval` must be True.")

    config = model_config.load_config(FLAGS.config)

    if FLAGS.do_train:
        tf.logging.info("Training with train filenames: " +
                        str(FLAGS.training_filename))

    # Training allows noisy examples, so do not use a clean output vocab
    model_fn = model_builder.build_model_fn(config,
                                            FLAGS.output_vocab_filepath,
                                            clean_output_vocab_path="")

    # region training
    if FLAGS.do_train:
        # for keepsake CLI (helps track experiment results)
        experiment = keepsake.init(params={
            "learning_rate": config.training_options.optimizer_learning_rate,
            "batch_size": config.training_options.batch_size,
            "training_steps": config.training_options.training_steps,
            "eval_batch_size": FLAGS.eval_batch_size,
            "training_data": FLAGS.training_filename,
            "eval_data": FLAGS.eval_filename,
        }, )

        train_input_fn = input_pipeline.create_training_input_fn(
            config,
            FLAGS.tf_examples_dir,
            [name for name in FLAGS.training_filename if name],
        )

        train_features, train_labels = train_input_fn()
        train_model = model_fn(train_features, train_labels,
                               tf.estimator.ModeKeys.TRAIN)

        tf.get_variable_scope().reuse_variables()

        inference_config = inference.Config(
            FLAGS.eval_dataset_name,
            FLAGS.eval_splits.split(","),
            FLAGS.output_vocab_filepath,
            FLAGS.clean_output_vocab_filepath,
            FLAGS.eval_beam_size,
            FLAGS.using_abstract_sql,
            FLAGS.database_directory,
            FLAGS.empty_database_directory,
            FLAGS.original_data_directory,
            model_config.load_config(FLAGS.config),
        )

        saver = tf.train.Saver(max_to_keep=None)

        global_step = 0
        checkpoint = checkpoint_path(FLAGS.model_dir, global_step)

        validation_query_cache: Dict[str, Any] = {}

        with tf.Session() as init_sess:
            init_sess.run(tf.global_variables_initializer())
            saver.save(init_sess, checkpoint)

        while global_step < config.training_options.training_steps:
            # region training loop
            with tf.Session() as train_sess:
                tf.logging.info(
                    "Training from step %s to step %s",
                    global_step,
                    global_step + FLAGS.steps_between_saves,
                )
                saver.restore(train_sess, checkpoint)

                train_losses = []

                for step in range(FLAGS.steps_between_saves):
                    _, train_loss = train_sess.run(
                        [train_model.train_op, train_model.loss])

                    train_losses.append(train_loss)

                    if step % 100 == 0:
                        tf.logging.info(
                            "Step %s's training loss: %s",
                            global_step + step,
                            train_loss,
                        )

                train_loss = statistics.mean(train_losses)

                global_step += FLAGS.steps_between_saves
                checkpoint = checkpoint_path(FLAGS.model_dir, global_step)
                saver.save(train_sess, checkpoint)
            # endregion

            # region eval loop
            tf.logging.info("Evaluating checkpoint %s", checkpoint)

            examples = inference.load_tf_examples(
                os.path.join(FLAGS.tf_examples_dir, FLAGS.eval_filename))
            random.shuffle(examples)

            tf.logging.info("Running inference on %s", FLAGS.eval_filename)
            predictions = inference.inference(
                examples,
                checkpoint,
                inference_config,
            )

            examples_to_execute = get_examples_to_execute(
                predictions, inference_config)

            # Only update cache when it's empty
            should_update_cache = len(validation_query_cache) == 0

            # Scholar is the only case-insensitive dataset
            case_sensitive = "scholar" not in FLAGS.eval_dataset_name.lower()

            results, validation_query_cache = official_evaluation.execute_predictions(
                instructions=examples_to_execute,
                cache_dict=validation_query_cache,
                case_sensitive=case_sensitive,
                verbose=False,
                update_cache=should_update_cache,
            )

            metrics = official_evaluation.aggregate_metrics(
                results, FLAGS.use_empty_tables)
            tf.logging.info("Validation Results:\n\tExecution F1: %s",
                            metrics.execution_f1)
            # endregion

            experiment.checkpoint(
                step=global_step,
                metrics={
                    "train_loss": train_loss,
                    "eval_execution_f1": metrics.execution_f1,
                    "eval_string_match": metrics.string_same,
                },
                primary_metric=("eval_execution_f1", "maximize"),
            )

            # region disk management

            for step in checkpoints_to_delete(experiment):
                assert (
                    step != global_step
                ), f"Can't delete step {step}; need it for next training epoch starting at step {global_step}"
                print(f"Deleting checkpoint {step}")
                delete_checkpoint(FLAGS.model_dir, step)
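
The checkpoint_path and delete_checkpoint helpers are not shown in this snippet. Here is a plausible minimal sketch, assuming checkpoints are named by global step and written by tf.train.Saver (which produces .index, .meta, and .data-* files per save):

import glob
import os

def checkpoint_path(model_dir, step):
    # One checkpoint prefix per global step.
    return os.path.join(model_dir, "ckpt-{}".format(step))

def delete_checkpoint(model_dir, step):
    # Remove every file the Saver wrote for this checkpoint prefix.
    for filename in glob.glob(checkpoint_path(model_dir, step) + ".*"):
        os.remove(filename)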
Example #27
def dense(
    x,
    units,
    activation=None,
    use_bias=True,
    kernel_initializer="glorot_uniform",
    bias_initializer="zeros",
    sparsity_technique="variational_dropout",
    auxiliary_initializer=None,
    threshold=3.0,
    clip_alpha=None,
    training=True,
    dtype=tf.float32,
    name=None,
    initial_sparsity=None):
  """Matmul & bias add that supports broadcasting for batched gemm.

  Supports a constrained subset of the functionality provided by
  tf.layers.dense.

  Args:
    x: input tensor.
    units: number of units in the dense layer.
    activation: activation function to use in the layer.
    use_bias: whether or not to add a bias to the output.
    kernel_initializer: weight initializer for the layer.
    bias_initializer: weight initializer for the bias.
    sparsity_technique: sparsification technique to apply to the weights.
    auxiliary_initializer: initializer for the auxiliary variables used in
      variational dropout and l0 regularization.
    threshold: log-alpha threshold for variational dropout.
    clip_alpha: whether to clip the alpha values for variational dropout.
    training: whether this run is training or evaluating the model.
    dtype: data type for the weights and computation.
    name: name for the layer.
    initial_sparsity: initial weight sparsity at the start of training.

  Returns:
    Tensor representing the output of the layer.
  """
  activation = activations.get(activation)
  kernel_initializer = initializers.get(kernel_initializer)
  bias_initializer = initializers.get(bias_initializer)

  if (sparsity_technique == "magnitude_pruning" or
      sparsity_technique == "random_pruning"):
    if initial_sparsity is not None:
      # If the initial sparsity value is passed in, use the sparse glorot
      # uniform initializer to account for the zero valued weights.
      kernel_initializer = common_init.SparseGlorotUniform(
          initial_sparsity, dtype=dtype)
      tf.logging.info(
          "Using sparse initialization with sparsity {} for variable {}"
          .format(initial_sparsity, tf.get_variable_scope().name))

    # If the sparsity technique is magnitude_pruning or random_pruning,
    # use the model_pruning masked_fully_connected layer.
    #
    # masked_fully_connected doesn't take a use_bias arg, so pass None for
    # the bias initializer if we don't want a bias variable.
    bias_initializer = bias_initializer if use_bias else None
    with tf.variable_scope(name, default_name="dense"):
      return pruning_layers.masked_fully_connected(
          inputs=x,
          num_outputs=units,
          activation_fn=activation,
          weights_initializer=kernel_initializer,
          biases_initializer=bias_initializer)
  if initial_sparsity is not None:
    raise ValueError("initial_sparsity only supported for mp & rp")

  input_shape = x.get_shape().as_list()
  if input_shape[-1] is None:
    raise ValueError("The last dimension of the inputs to `Dense` "
                     "should be defined. Found `None`.")

  with tf.variable_scope(name, default_name="dense") as vs:
    kernel = tf.get_variable(
        "kernel",
        shape=[input_shape[-1], units],
        initializer=kernel_initializer,
        dtype=dtype,
        trainable=True)

    bias = None
    if use_bias:
      bias = tf.get_variable(
          "bias",
          shape=[units,],
          initializer=bias_initializer,
          dtype=dtype,
          trainable=True)

  # Compute the dense layer
  if sparsity_technique == "variational_dropout":
    log_sigma2_initializer = initializers.get(auxiliary_initializer)

    if not log_sigma2_initializer:
      log_sigma2_initializer = tf.constant_initializer(value=-10, dtype=dtype)

    with tf.variable_scope(vs, auxiliary_name_scope=False) as vs1:
      with tf.name_scope(vs1.original_name_scope):
        log_sigma2 = tf.get_variable(
            "log_sigma2",
            shape=[input_shape[-1], units],
            initializer=log_sigma2_initializer,
            dtype=dtype,
            trainable=True)

    variational_parameters = (kernel, log_sigma2)
    tf.add_to_collection(
        VARIATIONAL_DROPOUT_PARAMETERS,
        variational_parameters)

    input_rank = x.get_shape().ndims
    if input_rank > 2:
      if training:
        outputs = vd.nn.broadcast_matmul_train(
            x,
            variational_parameters,
            clip_alpha=clip_alpha)
      else:
        outputs = vd.nn.broadcast_matmul_eval(
            x,
            variational_parameters,
            threshold)
    else:
      if training:
        outputs = vd.nn.matmul_train(
            x,
            variational_parameters,
            clip_alpha=clip_alpha)
      else:
        outputs = vd.nn.matmul_eval(
            x,
            variational_parameters,
            threshold)
  else:
    if sparsity_technique != "l0_regularization":
      raise ValueError("Unsupported sparsity technique {}"
                       .format(sparsity_technique))
    log_alpha_initializer = initializers.get(auxiliary_initializer)

    if not log_alpha_initializer:
      # Default to \alpha / (\alpha + 1) = 0.9, i.e. log_alpha = ln(9) ~= 2.197
      log_alpha_initializer = tf.random_normal_initializer(
          mean=2.197, stddev=0.01, dtype=dtype)

    with tf.variable_scope(vs, auxiliary_name_scope=False) as vs1:
      with tf.name_scope(vs1.original_name_scope):
        log_alpha = tf.get_variable(
            "log_alpha",
            shape=[input_shape[-1], units],
            initializer=log_alpha_initializer,
            dtype=dtype,
            trainable=True)

    weight_parameters = (kernel, log_alpha)
    tf.add_to_collection(
        L0_REGULARIZATION_PARAMETERS,
        weight_parameters)

    input_rank = x.get_shape().ndims
    if input_rank > 2:
      if training:
        outputs = l0.nn.broadcast_matmul_train(x, weight_parameters)
      else:
        outputs = l0.nn.broadcast_matmul_eval(x, weight_parameters)
    else:
      if training:
        outputs = l0.nn.matmul_train(x, weight_parameters)
      else:
        outputs = l0.nn.matmul_eval(x, weight_parameters)

  # Handle the bias and activation
  if use_bias:
    outputs = tf.nn.bias_add(outputs, bias)
  if activation is not None:
    return activation(outputs)
  return outputs
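
A usage sketch for the layer above (hypothetical shapes; assumes this module's imports are in scope): a hidden layer trained with variational dropout, and the same technique in eval mode, where weights whose log-alpha exceeds `threshold` are pruned.

x = tf.placeholder(tf.float32, shape=[None, 784])
hidden = dense(x, 256, activation="relu",
               sparsity_technique="variational_dropout",
               training=True, name="fc1")
logits = dense(hidden, 10,
               sparsity_technique="variational_dropout",
               training=False, threshold=3.0, name="fc2")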
Example #28
def trace_scan(loop_fn,
               initial_state,
               elems,
               trace_fn,
               trace_criterion_fn=None,
               static_trace_allocation_size=None,
               parallel_iterations=10,
               name=None):
  """A simplified version of `tf.scan` that has configurable tracing.

  This function repeatedly calls `loop_fn(state, elem)`, where `state` is the
  `initial_state` during the first iteration, and the return value of `loop_fn`
  for every iteration thereafter. `elem` is a slice of `elems` along the
  first dimension, accessed in order. Additionally, it calls `trace_fn` on the
  return value of `loop_fn`. The `Tensor`s in return values of `trace_fn` are
  stacked and returned from this function, such that the first dimension of
  those `Tensor`s matches the size of `elems`.

  Args:
    loop_fn: A callable that takes in a `Tensor` or a nested collection of
      `Tensor`s with the same structure as `initial_state`, a slice of `elems`
      and returns the same structure as `initial_state`.
    initial_state: A `Tensor` or a nested collection of `Tensor`s passed to
      `loop_fn` in the first iteration.
    elems: A `Tensor` that is split along the first dimension and each element
      of which is passed to `loop_fn`.
    trace_fn: A callable that takes in the return value of `loop_fn` and returns
      a `Tensor` or a nested collection of `Tensor`s.
    trace_criterion_fn: Optional callable that takes in the return value of
      `loop_fn` and returns a boolean `Tensor` indicating whether to trace it.
      If `None`, all steps are traced.
      Default value: `None`.
    static_trace_allocation_size: Optional Python `int` size of trace to
      allocate statically. This should be an upper bound on the number of steps
      traced and is used only when the length cannot be
      statically inferred (for example, if a `trace_criterion_fn` is specified).
      It is primarily intended for contexts where static shapes are required,
      such as in XLA-compiled code.
      Default value: `None`.
    parallel_iterations: Passed to the internal `tf.while_loop`.
    name: Name scope used in this function. Default: 'trace_scan'.

  Returns:
    final_state: The final return value of `loop_fn`.
    trace: The same structure as the return value of `trace_fn`, but with each
      `Tensor` being a stack of the corresponding `Tensors` in the return value
      of `trace_fn` for each slice of `elems`.
  """
  with tf.name_scope(name or 'trace_scan'), tf1.variable_scope(
      tf1.get_variable_scope()) as vs:
    if vs.caching_device is None and not tf.executing_eagerly():
      vs.set_caching_device(lambda op: op.device)

    initial_state = tf.nest.map_structure(
        lambda x: tf.convert_to_tensor(x, name='initial_state'),
        initial_state, expand_composites=True)
    elems = tf.convert_to_tensor(elems, name='elems')

    length = ps.size0(elems)

    # This is a TensorArray in part because of XLA, which had trouble with
    # non-statically known indices. I.e. elems[i] errored, but
    # elems_array.read(i) worked.
    elems_array = tf.TensorArray(
        elems.dtype, size=length, element_shape=elems.shape[1:])
    elems_array = elems_array.unstack(elems)

    # Initialize trace arrays.
    if trace_criterion_fn is None:
      dynamic_size, initial_size = tf.is_tensor(length), length
    elif static_trace_allocation_size is not None:
      dynamic_size, initial_size = False, static_trace_allocation_size
    elif JAX_MODE or (not tf.executing_eagerly() and
                      control_flow_util.GraphOrParentsInXlaContext(
                          tf1.get_default_graph())):
      dynamic_size, initial_size = False, length
    else:
      dynamic_size, initial_size = True, 0
    initial_trace = trace_fn(initial_state)
    flat_initial_trace = tf.nest.flatten(initial_trace, expand_composites=True)
    trace_arrays = []
    for trace_elt in flat_initial_trace:
      trace_arrays.append(
          tf.TensorArray(
              trace_elt.dtype,
              size=initial_size,
              dynamic_size=dynamic_size,
              element_shape=trace_elt.shape))

    # Helper for writing a (structured) state to (structured) arrays.
    def trace_one_step(num_steps_traced, trace_arrays, state):
      return [ta.write(num_steps_traced, x) for ta, x in
              zip(trace_arrays,
                  tf.nest.flatten(trace_fn(state), expand_composites=True))]

    def _body(i, state, num_steps_traced, trace_arrays):
      elem = elems_array.read(i)
      state = loop_fn(state, elem)

      trace_arrays, num_steps_traced = ps.cond(
          trace_criterion_fn(state) if trace_criterion_fn else True,
          lambda: (trace_one_step(num_steps_traced, trace_arrays, state),  # pylint: disable=g-long-lambda
                   num_steps_traced + 1),
          lambda: (trace_arrays, num_steps_traced))

      return i + 1, state, num_steps_traced, trace_arrays

    _, final_state, _, trace_arrays = tf.while_loop(
        cond=lambda i, *_: i < length,
        body=_body,
        loop_vars=(0, initial_state, 0, trace_arrays),
        parallel_iterations=parallel_iterations)

    # unflatten
    stacked_trace = tf.nest.pack_sequence_as(
        initial_trace, [ta.stack() for ta in trace_arrays],
        expand_composites=True)

    # Restore the static length if we know it.
    static_length = tf.TensorShape(None if dynamic_size else initial_size)

    def _merge_static_length(x):
      tensorshape_util.set_shape(x, static_length.concatenate(x.shape[1:]))
      return x

    stacked_trace = tf.nest.map_structure(
        _merge_static_length, stacked_trace, expand_composites=True)
    return final_state, stacked_trace
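
A sketch of the `trace_criterion_fn` extension (same assumptions as the previous example): trace only the running sums that are even, so the returned trace can be shorter than `elems`.

final_state, trace = trace_scan(
    loop_fn=lambda state, elem: state + elem,
    initial_state=tf.constant(0),
    elems=tf.range(1, 6),
    trace_fn=lambda state: state,
    trace_criterion_fn=lambda state: tf.equal(state % 2, 0))
# Running sums are 1, 3, 6, 10, 15, so trace == [6, 10]; final_state == 15.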
Example #29
    def _build_single_q_network(self, observations, head, state_t, state_tp1,
                                done_mask, reward_t, error_weight):
        """Builds the computational graph for a single Q network.

    Briefly, this part calculates the following two quantities:
    1. q_value = q_fn(observations)
    2. td_error = q_fn(state_t) - reward_t - gamma * q_fn(state_tp1)
    The optimization target is to minimize the td_error.

    Args:
      observations: shape = [batch_size, hparams.fingerprint_length].
        The input of the Q function.
      head: shape = [1].
        The index of the head chosen for decision in bootstrap DQN.
      state_t: shape = [batch_size, hparams.fingerprint_length].
        The state at time step t.
      state_tp1: a list of batch_size tensors, each with
        shape = [num_actions, hparams.fingerprint_length].
        Note that num_actions can differ between tensors.
        The state at time step t+1 (tp1 is short for t plus 1).
      done_mask: shape = [batch_size, 1]
        Whether state_tp1 is the terminal state.
      reward_t: shape = [batch_size, 1]
        the reward at time step t.
      error_weight: shape = [batch_size, 1]
        weight for the loss.

    Returns:
      q_values: Tensor of [batch_size, 1]. The q values for the observations.
      td_error: Tensor of [batch_size, 1]. The TD error.
      weighted_error: Tensor of [batch_size, 1]. The TD error weighted by
        error_weight.
      q_fn_vars: List of tf.Variables. The variables of q_fn when computing
        the q_values of state_t.
      q_tp1_vars: List of tf.Variables. The variables of q_fn when computing
        the q_values of state_tp1.

    """
        with tf.variable_scope('q_fn'):
            # q_values has shape [batch_size, 1].
            q_values = tf.gather(self.q_fn(observations), head, axis=-1)

        # calculating q_fn(state_t)
        # The Q network shares parameters with the action graph.
        with tf.variable_scope('q_fn', reuse=True):
            q_t = self.q_fn(state_t, reuse=True)
        q_fn_vars = tf.trainable_variables(scope=tf.get_variable_scope().name +
                                           '/q_fn')

        # calculating q_fn(state_tp1)
        with tf.variable_scope('q_tp1', reuse=tf.AUTO_REUSE):
            q_tp1 = [
                self.q_fn(s_tp1, reuse=tf.AUTO_REUSE) for s_tp1 in state_tp1
            ]
        q_tp1_vars = tf.trainable_variables(
            scope=tf.get_variable_scope().name + '/q_tp1')

        if self.double_q:
            with tf.variable_scope('q_fn', reuse=True):
                q_tp1_online = [
                    self.q_fn(s_tp1, reuse=True) for s_tp1 in state_tp1
                ]
            if self.num_bootstrap_heads:
                num_heads = self.num_bootstrap_heads
            else:
                num_heads = 1
            # determine the action to choose based on online Q estimator.
            q_tp1_online_idx = [
                tf.stack([
                    tf.argmax(q, axis=0),
                    tf.range(num_heads, dtype=tf.int64)
                ],
                         axis=1) for q in q_tp1_online
            ]
            # use the index from max online q_values to compute the value
            # function
            v_tp1 = tf.stack([
                tf.gather_nd(q, idx) for q, idx in zip(q_tp1, q_tp1_online_idx)
            ],
                             axis=0)
        else:
            v_tp1 = tf.stack([tf.reduce_max(q) for q in q_tp1], axis=0)

        # if s_{t+1} is the terminal state, we do not evaluate the Q value of
        # the state.
        q_tp1_masked = (1.0 - done_mask) * v_tp1

        q_t_target = reward_t + self.gamma * q_tp1_masked

        # Stop gradients from flowing into the graph that computes
        # the Q value of s_{t+1}.
        # td_error has shape [batch_size, 1]
        td_error = q_t - tf.stop_gradient(q_t_target)

        # With bootstrapping, each head is trained on a different subset of
        # the training samples, similar in spirit to dropout.
        if self.num_bootstrap_heads:
            head_mask = tf.keras.backend.random_binomial(
                shape=(1, self.num_bootstrap_heads), p=0.6)
            td_error = tf.reduce_mean(td_error * head_mask, axis=1)
        # The loss uses a traditional trick from convex optimization
        # (http://web.stanford.edu/~boyd/cvxbook/, Chapter 6, p. 298) that
        # makes the optimization more robust: switch from an l2 to an l1 loss
        # once the td error exceeds 1.0. The l2 loss tends to be dominated by
        # outliers; in terms of estimation theory, the asymptotic relative
        # efficiency of the l1 estimator is better for heavy-tailed
        # distributions.
        errors = tf.where(
            tf.abs(td_error) < 1.0,
            tf.square(td_error) * 0.5, 1.0 * (tf.abs(td_error) - 0.5))
        weighted_error = tf.reduce_mean(error_weight * errors)
        return q_values, td_error, weighted_error, q_fn_vars, q_tp1_vars
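
The piecewise `errors` above is the Huber loss with delta = 1.0: quadratic for |td_error| < 1 and linear beyond. A quick standalone check with hypothetical values:

td = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
huber = tf.where(tf.abs(td) < 1.0,
                 0.5 * tf.square(td),
                 tf.abs(td) - 0.5)
# -> [1.5, 0.125, 0.0, 0.125, 1.5]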
Example #30
def bilinear(inputs1,
             inputs2,
             output_size,
             n_splits=1,
             add_bias1=True,
             add_bias2=True,
             initializer=None,
             moving_params=None):
    """ """

    # Prepare the input
    if not isinstance(inputs1, (list, tuple)):
        inputs1 = [inputs1]
    n_dims1 = len(inputs1[0].get_shape().as_list())
    all_inputs1 = tf.concat(inputs1, n_dims1 - 1)
    inputs1_size = all_inputs1.get_shape().as_list()[-1]
    inputs1_bucket_size = tf.shape(all_inputs1)[-2]

    if not isinstance(inputs2, (list, tuple)):
        inputs2 = [inputs2]
    n_dims2 = len(inputs2[0].get_shape().as_list())
    all_inputs2 = tf.concat(inputs2, n_dims2 - 1)
    inputs2_size = all_inputs2.get_shape().as_list()[-1]
    inputs2_bucket_size = tf.shape(all_inputs2)[-2]

    # Prepare the output
    output_size *= n_splits
    output_shape = []
    shape1 = tf.shape(all_inputs1)
    for i in range(n_dims1 - 1):
        output_shape.append(shape1[i])
    output_shape.append(output_size)
    output_shape.append(inputs2_bucket_size)
    output_shape = tf.stack(output_shape)

    all_inputs1 = tf.reshape(all_inputs1,
                             tf.stack([-1, inputs1_bucket_size, inputs1_size]))
    if add_bias1:
        bias1 = tf.ones(
            tf.stack([tf.shape(all_inputs1)[0], inputs1_bucket_size, 1]))
        all_inputs1 = tf.concat([all_inputs1, bias1], 2)
        inputs1_size += 1
    all_inputs2 = tf.reshape(all_inputs2,
                             tf.stack([-1, inputs2_bucket_size, inputs2_size]))
    if add_bias2:
        bias2 = tf.ones(
            tf.stack([tf.shape(all_inputs2)[0], inputs2_bucket_size, 1]))
        all_inputs2 = tf.concat([all_inputs2, bias2], 2)
        inputs2_size += 1
    with tf.variable_scope('Bilinear'):
        # Get the matrix
        if initializer is None and tf.get_variable_scope().reuse is None:
            mat = orthonormal_initializer(inputs1_size, inputs2_size)[:,
                                                                      None, :]
            mat = np.concatenate([mat] * output_size, axis=1)
            # Seed the weights with the orthonormal matrix built above.
            initializer = tf.constant_initializer(mat)
        weights = tf.get_variable('Weights',
                                  [inputs1_size, output_size, inputs2_size],
                                  initializer=initializer)
        if moving_params is not None:
            weights = moving_params.average(weights)
        else:
            tf.add_to_collection('Weights', weights)

        # Do the multiplication
        # (bn x d) (d x rd) -> (bn x rd)
        lin = tf.matmul(tf.reshape(all_inputs1, [-1, inputs1_size]),
                        tf.reshape(weights, [inputs1_size, -1]))
        # (b x nr x d) (b x n x d)T -> (b x nr x n)
        bilin = tf.matmul(tf.reshape(
            lin,
            tf.stack([-1, inputs1_bucket_size * output_size, inputs2_size])),
                          all_inputs2,
                          transpose_b=True)
        # (bn x r x n)
        bilin = tf.reshape(bilin, output_shape)

        if n_splits > 1:
            # Split along the output_size axis (second-to-last dim of bilin).
            return tf.split(bilin, n_splits, axis=-2)
        else:
            return bilin
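
A shape sanity check with hypothetical sizes: for batch b = 8, n = 20 tokens, d = 100 features, and output_size r = 1, the result has shape (b, n, r, n), matching the (b x nr x n) reshape above.

h_dep = tf.random_normal([8, 20, 100])
h_head = tf.random_normal([8, 20, 100])
scores = bilinear(h_dep, h_head, output_size=1)
# scores.shape -> (8, 20, 1, 20)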