Example #1
def conv1d(x, filters, kernel_size, strides=1, padding='causal', dilation_rate=1, act=None,
           init=None, scope="conv1d", use_bias=True):
    batch_size, seq_len, h = x.get_shape().as_list()
    # Taken from Keras; a faster version exists in Magenta.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # assert seq_len % dilation_rate == 0

        w = tf.get_variable('kernel', shape=(kernel_size, h, filters), dtype=tf.float32, initializer=init)

        if padding == 'causal':
            # causal (dilated) convolution:
            left_pad = dilation_rate * (kernel_size - 1)
            pattern = [[0, 0], [left_pad, 0], [0, 0]]
            x = tf.pad(x, pattern)
            padding = 'VALID'

        out = tf.nn.convolution(
            input=x,
            filter=w,
            dilation_rate=(dilation_rate,),
            strides=(strides,),
            padding=padding)
        if use_bias:
            b = tf.get_variable('bias', shape=(filters,), dtype=tf.float32, initializer=tf.initializers.zeros)
            out = tf.add(out, b)
        if act is not None:
            return act(out)
    return out
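A minimal usage sketch for conv1d above (assuming TF 1.x graph mode; the placeholder shape and parameter values are illustrative):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[8, 100, 16])  # [batch, time, channels]
y = conv1d(x, filters=32, kernel_size=3, dilation_rate=2,
           act=tf.nn.relu, scope='causal_conv')     # causal output: [8, 100, 32]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(y, feed_dict={x: np.random.randn(8, 100, 16).astype(np.float32)})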
Example #2
def eq_cifar_fn(x, output_dim=10, trainable=True):
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(h_input='Z2',
                                                            h_output='C4',
                                                            in_channels=3,
                                                            out_channels=8,
                                                            ksize=3)
    w = tf.get_variable('w1', shape=w_shape)

    conv1 = gconv2d(input=x,
                    filter=w,
                    strides=[1, 2, 2, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(h_input='C4',
                                                            h_output='C4',
                                                            in_channels=8,  # must match out_channels of the previous gconv
                                                            out_channels=32,
                                                            ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1,
                    filter=w,
                    strides=[1, 2, 2, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)
    out_channels = 2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(h_input='C4',
                                                            h_output='C4',
                                                            in_channels=32,  # must match out_channels of the previous gconv
                                                            out_channels=out_channels,
                                                            ksize=5)
    w = tf.get_variable('w3', shape=w_shape)
    conv3 = gconv2d(input=conv2,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    conv3 = tf.reshape(conv3,
                       conv3.get_shape().as_list()[:3] + [4] + [out_channels])
    conv3 = tf.reduce_mean(conv3, axis=3)
    pool3 = tf.layers.max_pooling2d(inputs=conv3, pool_size=[2, 2], strides=2)
    pool3_flat = tf.layers.flatten(pool3)
    u = pool3_flat
    u = tf.layers.dense(inputs=pool3_flat,
                        units=output_dim,
                        activation=tf.nn.relu,
                        trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    return u
Example #3
def eq_cnn_fn(x, output_dim=10, trainable=True, group='C4', num_filters=2):
    nchannels = int(x.shape[3])  # static channel count expected by gconv2d_util
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='Z2',
        h_output='C4',
        in_channels=nchannels,
        out_channels=2,
        ksize=5)
    w = tf.get_variable('w1', shape=w_shape)

    conv1 = gconv2d(input=x,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    tf.add_to_collection('conv_output1', conv1)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # pool1 = layers.Dropout(0.25)(pool1)
    out_channels = 2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4',
        h_output='C4',
        in_channels=2,
        out_channels=out_channels,
        ksize=5)
    w = tf.get_variable('w2', shape=w_shape)
    conv2 = gconv2d(input=conv1,
                    filter=w,
                    strides=[1, 1, 1, 1],
                    padding='SAME',
                    gconv_indices=gconv_indices,
                    gconv_shape_info=gconv_shape_info)
    conv2 = tf.reshape(conv2,
                       conv2.get_shape().as_list()[:3] + [4] + [out_channels])
    conv2 = tf.reduce_mean(conv2, axis=3)
    conv2 = tf.reshape(conv2, conv2.get_shape().as_list()[:3] + [out_channels])
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    pool2_flat = tf.layers.flatten(pool2)
    u = pool2_flat
    print(u.shape)
    u = tf.layers.dense(inputs=pool2_flat,
                        units=output_dim,
                        activation=tf.nn.relu,
                        trainable=trainable)
    tf.add_to_collection('conv_output2', conv2)
    return u
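A minimal usage sketch for eq_cnn_fn above (assumptions: GrouPy's TensorFlow gconv2d/gconv2d_util ops are importable, and the batch size is fixed because the reshapes rely on static shapes):

import numpy as np
import tensorflow as tf

images = tf.placeholder(tf.float32, shape=[32, 28, 28, 1])
with tf.variable_scope('eq_cnn', reuse=tf.AUTO_REUSE):
    logits = eq_cnn_fn(images, output_dim=10)  # -> [32, 10]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits,
                   feed_dict={images: np.random.rand(32, 28, 28, 1).astype(np.float32)})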
Example #4
    def _build_select_slate_op(self):
        p_no_click = self._prob_no_click_ph
        p = self._doc_affinity_scores_ph
        q = self._net_outputs.q_values[0]
        with tf.name_scope('select_slate'):
            self._output_slate = self._select_slate_fn(self._slate_size,
                                                       p_no_click, p, q)

        self._output_slate = tf.Print(
            self._output_slate,
            [tf.constant('cp 1'), self._output_slate, p, q],
            summarize=10000)
        self._output_slate = tf.reshape(self._output_slate,
                                        (self._slate_size, ))

        self._action_counts = tf.get_variable(
            'action_counts',
            shape=[self._num_candidates],
            initializer=tf.zeros_initializer())
        output_slate = tf.reshape(self._output_slate, [-1])
        output_one_hot = tf.one_hot(output_slate, self._num_candidates)
        update_ops = []
        for i in range(self._slate_size):
            update_ops.append(
                tf.assign_add(self._action_counts, output_one_hot[i]))
        self._select_action_update_op = tf.group(*update_ops)
Example #5
def weight_variable(shape, weight_decay):
    """weight_variable generates a weight variable of a given shape."""
    initial = tf.initializers.truncated_normal(stddev=0.1)
    return tf.get_variable('weight',
                           shape=shape,
                           initializer=initial,
                           regularizer=tf.keras.regularizers.l2(weight_decay))
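A minimal usage sketch for weight_variable above (each call needs its own variable scope, since the variable name 'weight' is fixed; the scope names are illustrative):

import tensorflow as tf

with tf.variable_scope('fc1'):
    w1 = weight_variable([128, 64], weight_decay=1e-4)
with tf.variable_scope('fc2'):
    w2 = weight_variable([64, 10], weight_decay=1e-4)
reg_loss = tf.losses.get_regularization_loss()  # sums the L2 penalties registered by the regularizer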
Example #6
    def __init__(self, input_shape, dict_size=(-1., 1., 20), gamma=None):
        self.d = tf.linspace(*dict_size)
        if gamma is None:
            # Default bandwidth: 0.5 / (2 * (d_stop - d_start))^2
            self.gamma = .5 / tf.square(2 * (self.d[-1] - self.d[0]))
        else:
            self.gamma = gamma
        self.alpha = tf.get_variable('alpha', shape=(1, input_shape, self.d.get_shape()[0]),
                                     initializer=RidgeInit(gauss_kernel(self.d, self.d, self.gamma), self.d))
Example #7
def kaf(linear, name, kernel='rbf', D=None, gamma=None):
    if D is None:
        D = tf.linspace(start=-2., stop=2., num=20)

    with tf.variable_scope('kaf', reuse=tf.AUTO_REUSE):
        if kernel == "rbf":
            K = gauss_kernel(linear, D, gamma=gamma)
            alpha = tf.get_variable(name, shape=(1, linear.get_shape()[-1], D.get_shape()[0]),
                                    initializer=tf.random_normal_initializer(stddev=0.1))
        elif kernel == 'rbf2d':
            Dx, Dy = tf.meshgrid(D, D)
            K = gauss_kernel2D(linear, Dx, Dy, gamma=gamma)

            alpha = tf.get_variable(name,
                                    shape=(1, linear.get_shape()[-1] // 2, D.get_shape()[0] * D.get_shape()[0]),
                                    initializer=tf.random_normal_initializer(stddev=0.1))
        else:
            raise NotImplementedError()
        act = tf.reduce_sum(tf.multiply(K, alpha), axis=-1)
        # act = tf.squeeze(act, axis=0)
    return act
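A minimal usage sketch for kaf above (assumption: a compatible gauss_kernel helper, as referenced in the body, is available in scope; the layer sizes and names are illustrative):

import tensorflow as tf

h = tf.placeholder(tf.float32, shape=[None, 64])
linear = tf.layers.dense(h, units=32, activation=None)       # pre-activations
out = kaf(linear, name='alpha_fc1', kernel='rbf', gamma=1.0)  # learned kernel activation, same shape as `linear`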
Example #8
def separate_head_linear_classifier(embeddings, num_classes, dataset_idx,
                                    start_idx, cosine_classifier,
                                    cosine_logits_multiplier, learnable_scale,
                                    weight_decay):
    """A linear classifier with num_sets heads, for different datasets.

  Args:
    embeddings: A Tensor of size [batch size, embedding dim].
    num_classes: A list of integers; the dimension of the classifier layers of
      the different heads.
    dataset_idx: An int Tensor. The index of the dataset head to use.
    start_idx: An int Tensor. The index of the first class of the given dataset.
    cosine_classifier: A bool. If true, a cosine classifier is used, which does
      not require a bias.
    cosine_logits_multiplier: A float. Only used if cosine_classifier is True,
      and multiplies the resulting logits.
    learnable_scale: A bool. Whether to make the cosine_logits_multiplier a
      learnable parameter. Only applies if cosine_classifier is True.
    weight_decay: A float; the scalar multiple on the L2 regularization of the
      weight matrix.

  Returns:
    logits: A Tensor of size [batch size, num outputs].
  """
    if not cosine_classifier:
        raise NotImplementedError(
            '`separate_head_linear_classifier` currently '
            'only supports `cosine_classifier` True.')

    if learnable_scale:
        cosine_logits_multiplier = tf.get_variable(
            'cosine_scale',
            initializer=cosine_logits_multiplier,
            dtype=tf.float32,
            trainable=True)

    embedding_dims = embeddings.get_shape().as_list()[-1]
    w_fc = functional_backbones.weight_variable(
        [embedding_dims, sum(num_classes)], weight_decay=weight_decay)

    # Select the output "head" to use in the forward pass.
    dataset_num_classes = tf.gather(num_classes, dataset_idx)
    w_fc = w_fc[:, start_idx:start_idx + dataset_num_classes]

    logits = linear_classifier_forward_pass(embeddings, w_fc, None,
                                            cosine_classifier,
                                            cosine_logits_multiplier, False)
    return logits
Example #9
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # linear layer(WX + b)
            with tf.variable_scope('last_layer/dense') as scope:
                weights = tf.get_variable('kernel',
                                          [self.dim_input, self.num_classes],
                                          dtype=tf.float32)
                biases = tf.get_variable('bias', [self.num_classes],
                                         dtype=tf.float32)
                wb = tf.concat([weights, tf.expand_dims(biases, axis=0)], 0)
                wb_renorm = tf.matmul(self.sigma_half_inv, wb)
                weights_renorm = wb_renorm[:self.dim_input, :]
                biases_renorm = wb_renorm[-1, :]
                self.z = tf.add(tf.matmul(self.x, weights_renorm),
                                biases_renorm,
                                name=scope.name)

            # Gaussian prior
            # prior = tf.nn.l2_loss(weights) + tf.nn.l2_loss(biases)

            # Non normalized loss, because of the preconditioning
            self.loss = self.n * tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            # Bayesian loss
            self.bayesian_loss = self.loss  # + prior

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # saving the weights of last layer when running SGLD/SGD/MCMC algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars,
                                        max_to_keep=self.num_samples)

            self.gd_opt = tf.train.GradientDescentOptimizer(self.step_size)
            # SGLD optimizer for the last layer
            if self.sampler in ['sgld', 'lmc']:
                grads_vars = self.gd_opt.compute_gradients(self.bayesian_loss)
                grads_vars_sgld = []

                for g, v in grads_vars:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        # Adding Gaussian noise to the gradient
                        gaussian_noise = (np.sqrt(2. / self.step_size) *
                                          tf.random_normal(tf.shape(g)))
                        g_sgld = g + gaussian_noise
                        tf.summary.histogram(''.join(s) + '/grad_hist_mcmc', g)
                        tf.summary.histogram(
                            ''.join(s) + '/gaussian_noise_hist_mcmc',
                            gaussian_noise)
                        tf.summary.histogram(
                            ''.join(s) + '/grad_total_hist_mcmc', g_sgld)
                        grads_vars_sgld.append((g_sgld, v))

                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgld)

            # SGD optimizer for the last layer
            if self.sampler == 'sgd':
                grads_vars_sgd = self.gd_opt.compute_gradients(self.loss)
                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgd)

                for g, v in grads_vars_sgd:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        tf.summary.histogram(''.join(s) + '/grad_hist_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('loading the network ...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
Example #10
def linear_classifier(embeddings, num_classes, cosine_classifier,
                      cosine_logits_multiplier, use_weight_norm, weight_decay):
    """Forward pass through a linear classifier, or possibly a cosine classifier.

  Args:
    embeddings: A Tensor of size [batch size, embedding dim].
    num_classes: An integer; the dimension of the classification.
    cosine_classifier: A bool. If true, a cosine classifier is used, which does
      not require a bias.
    cosine_logits_multiplier: A float. Only used if cosine_classifier is True,
      and multiplies the resulting logits.
    use_weight_norm: A bool. Whether weight norm was used. If so, then if using
      cosine classifier, normalize only the embeddings but not the weights.
    weight_decay: A float; the scalar multiple on the L2 regularization of the
      weight matrix.

  Returns:
    logits: A Tensor of size [batch size, num outputs].
  """

    embedding_dims = embeddings.get_shape().as_list()[-1]

    if use_weight_norm:
        # A variable to keep track of whether the initialization has already
        # happened.
        data_dependent_init_done = tf.get_variable('data_dependent_init_done',
                                                   initializer=0,
                                                   dtype=tf.int32,
                                                   trainable=False)

        w_fc = tf.get_variable('w_fc', [embedding_dims, num_classes],
                               initializer=tf.random_normal_initializer(
                                   0, 0.05),
                               trainable=True)
        # This init is temporary as it needs to be done in a data-dependent way.
        # It will be overwritten during the first forward pass through this layer.
        g = tf.get_variable('g',
                            dtype=tf.float32,
                            initializer=tf.ones([num_classes]),
                            trainable=True)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = tf.get_variable('b_fc',
                                   initializer=tf.zeros([num_classes]),
                                   trainable=True)

        def _do_data_dependent_init():
            """Returns ops for the data-dependent init of g and maybe b_fc."""
            w_fc_normalized = tf.nn.l2_normalize(w_fc.read_value(), [0])
            output_init = tf.matmul(embeddings, w_fc_normalized)
            mean_init, var_init = tf.nn.moments(output_init, [0])
            # Data-dependent init values.
            g_init_value = 1. / tf.sqrt(var_init + 1e-10)
            ops = [tf.assign(g, g_init_value)]
            if not cosine_classifier:
                # Also initialize a bias in a data-dependent way.
                b_fc_init_value = -mean_init * g_init_value
                ops.append(tf.assign(b_fc, b_fc_init_value))
            # Mark that the data-dependent initialization is done to prevent it from
            # happening again in the future.
            ops.append(tf.assign(data_dependent_init_done, 1))
            return tf.group(*ops)

        # Possibly perform data-dependent init (if it hasn't been done already).
        init_op = tf.cond(tf.equal(data_dependent_init_done, 0),
                          _do_data_dependent_init, tf.no_op)

        with tf.control_dependencies([init_op]):
            # Apply weight normalization.
            w_fc *= g / tf.sqrt(tf.reduce_sum(tf.square(w_fc), [0]))
            # Forward pass through the layer defined by w_fc and b_fc.
            logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                    cosine_classifier,
                                                    cosine_logits_multiplier,
                                                    True)

    else:
        # No weight norm.
        w_fc = functional_backbones.weight_variable(
            [embedding_dims, num_classes], weight_decay=weight_decay)
        b_fc = None
        if not cosine_classifier:
            # Also initialize a bias.
            b_fc = functional_backbones.bias_variable([num_classes])
        # Forward pass through the layer defined by w_fc and b_fc.
        logits = linear_classifier_forward_pass(embeddings, w_fc, b_fc,
                                                cosine_classifier,
                                                cosine_logits_multiplier,
                                                False)
    return logits
Example #11
def bias_variable(shape):
  """bias_variable generates a bias variable of a given shape."""
  initial = tf.initializers.constant(0.1)
  return tf.get_variable('bias', shape=shape, initializer=initial)
Example #12
def bn(x,
       params=None,
       moments=None,
       backprop_through_moments=True,
       use_ema=False,
       is_training=True,
       ema_epsilon=.9):
  """Batch normalization.

  The usage should be as follows: If x is the support images, moments should be
  None so that they are computed from the support set examples. On the other
  hand, if x is the query images, the moments argument should be used in order
  to pass in the mean and var that were computed from the support set.

  Args:
    x: inputs.
    params: None or a dict containing the values of the offset and scale params.
    moments: None or a dict containing the values of the mean and var to use for
      batch normalization.
    backprop_through_moments: Whether to allow gradients to flow through the
      given support set moments. Only applies to non-transductive batch norm.
    use_ema: apply moving averages of batch norm statistics, or update them,
      depending on whether we are training or testing.  Note that passing
      moments will override this setting, and result in neither updating or
      using ema statistics.  This is important to make sure that episodic
      learners don't update ema statistics a second time when processing
      queries.
    is_training: if use_ema=True, this determines whether to apply the moving
      averages, or update them.
    ema_epsilon: if updating moving averages, use this value for the
      exponential moving averages.

  Returns:
    output: The result of applying batch normalization to the input.
    params: The updated params.
    moments: The updated moments.
  """
  params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

  with tf.variable_scope('batch_norm'):
    scope_name = tf.get_variable_scope().name

    if use_ema:
      ema_shape = [1, 1, 1, x.get_shape().as_list()[-1]]
      mean_ema = tf.get_variable(
          'mean_ema',
          shape=ema_shape,
          initializer=tf.initializers.zeros(),
          trainable=False)
      var_ema = tf.get_variable(
          'var_ema',
          shape=ema_shape,
          initializer=tf.initializers.ones(),
          trainable=False)

    if moments is not None:
      if backprop_through_moments:
        mean = moments[scope_name + '/mean']
        var = moments[scope_name + '/var']
      else:
        # This variant does not yield good results.
        mean = tf.stop_gradient(moments[scope_name + '/mean'])
        var = tf.stop_gradient(moments[scope_name + '/var'])
    elif use_ema and not is_training:
      mean = mean_ema
      var = var_ema
    else:
      # If not provided, compute the mean and var of the current batch.

      replica_ctx = tf.distribute.get_replica_context()
      if replica_ctx:
        # from third_party/tensorflow/python/keras/layers/normalization_v2.py
        axes = list(range(len(x.shape) - 1))
        local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)
        local_squared_sum = tf.reduce_sum(
            tf.square(x), axis=axes, keepdims=True)
        batch_size = tf.cast(tf.shape(x)[0], tf.float32)
        x_sum, x_squared_sum, global_batch_size = (
            replica_ctx.all_reduce('sum',
                                   [local_sum, local_squared_sum, batch_size]))

        axes_vals = [(tf.shape(x))[i] for i in range(1, len(axes))]
        multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
        multiplier = multiplier * global_batch_size

        mean = x_sum / multiplier
        x_squared_mean = x_squared_sum / multiplier
        # var = E(x^2) - E(x)^2
        var = x_squared_mean - tf.square(mean)
      else:
        mean, var = tf.nn.moments(
            x, axes=list(range(len(x.shape) - 1)), keep_dims=True)

    # Only update ema's if training and we computed the moments in the current
    # call.  Note: at test time for episodic learners, ema's may be passed
    # from the support set to the query set, even if it's not really needed.
    if use_ema and is_training and moments is None:
      replica_ctx = tf.distribute.get_replica_context()
      mean_upd = tf.assign(mean_ema,
                           mean_ema * ema_epsilon + mean * (1.0 - ema_epsilon))
      var_upd = tf.assign(var_ema,
                          var_ema * ema_epsilon + var * (1.0 - ema_epsilon))
      updates = tf.group([mean_upd, var_upd])
      if replica_ctx:
        tf.add_to_collection(
            tf.GraphKeys.UPDATE_OPS,
            tf.cond(
                tf.equal(replica_ctx.replica_id_in_sync_group, 0),
                lambda: updates, tf.no_op))
      else:
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, updates)

    moments_keys += [scope_name + '/mean']
    moments_vars += [mean]
    moments_keys += [scope_name + '/var']
    moments_vars += [var]

    if params is None:
      offset = tf.get_variable(
          'offset',
          shape=mean.get_shape().as_list(),
          initializer=tf.initializers.zeros())
      scale = tf.get_variable(
          'scale',
          shape=var.get_shape().as_list(),
          initializer=tf.initializers.ones())
    else:
      offset = params[scope_name + '/offset']
      scale = params[scope_name + '/scale']

    params_keys += [scope_name + '/offset']
    params_vars += [offset]
    params_keys += [scope_name + '/scale']
    params_vars += [scale]

    output = tf.nn.batch_normalization(x, mean, var, offset, scale, 0.00001)

    params = collections.OrderedDict(zip(params_keys, params_vars))
    moments = collections.OrderedDict(zip(moments_keys, moments_vars))

    return output, params, moments
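A minimal sketch of the support/query pattern described in the docstring above (tensor shapes and scope names are illustrative):

import tensorflow as tf

support_features = tf.placeholder(tf.float32, shape=[25, 14, 14, 64])
query_features = tf.placeholder(tf.float32, shape=[75, 14, 14, 64])

with tf.variable_scope('block1', reuse=tf.AUTO_REUSE):
    # Support pass: moments=None, so mean/var are computed from the support batch.
    support_out, bn_params, support_moments = bn(support_features, params=None, moments=None)

with tf.variable_scope('block1', reuse=tf.AUTO_REUSE):
    # Query pass: reuse the support-set moments (and the same offset/scale params).
    query_out, _, _ = bn(query_features, params=bn_params, moments=support_moments)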
Example #13
def bn(x, params=None, moments=None, backprop_through_moments=True):
    """Batch normalization.

  The usage should be as follows: If x is the support images, moments should be
  None so that they are computed from the support set examples. On the other
  hand, if x is the query images, the moments argument should be used in order
  to pass in the mean and var that were computed from the support set.

  Args:
    x: inputs.
    params: None or a dict containing the values of the offset and scale params.
    moments: None or a dict containing the values of the mean and var to use for
      batch normalization.
    backprop_through_moments: Whether to allow gradients to flow through the
      given support set moments. Only applies to non-transductive batch norm.

  Returns:
    output: The result of applying batch normalization to the input.
    params: The updated params.
    moments: The updated moments.
  """
    params_keys, params_vars, moments_keys, moments_vars = [], [], [], []

    with tf.variable_scope('batch_norm'):
        scope_name = tf.get_variable_scope().name
        if moments is None:
            # If not provided, compute the mean and var of the current batch.
            mean, var = tf.nn.moments(x,
                                      axes=list(range(len(x.shape) - 1)),
                                      keep_dims=True)
        else:
            if backprop_through_moments:
                mean = moments[scope_name + '/mean']
                var = moments[scope_name + '/var']
            else:
                # This variant does not yield good results.
                mean = tf.stop_gradient(moments[scope_name + '/mean'])
                var = tf.stop_gradient(moments[scope_name + '/var'])

        moments_keys += [scope_name + '/mean']
        moments_vars += [mean]
        moments_keys += [scope_name + '/var']
        moments_vars += [var]

        if params is None:
            offset = tf.get_variable('offset',
                                     shape=mean.get_shape().as_list(),
                                     initializer=tf.initializers.zeros())
            scale = tf.get_variable('scale',
                                    shape=var.get_shape().as_list(),
                                    initializer=tf.initializers.ones())
        else:
            offset = params[scope_name + '/offset']
            scale = params[scope_name + '/scale']

        params_keys += [scope_name + '/offset']
        params_vars += [offset]
        params_keys += [scope_name + '/scale']
        params_vars += [scale]

        output = tf.nn.batch_normalization(x, mean, var, offset, scale,
                                           0.00001)
        params = collections.OrderedDict(zip(params_keys, params_vars))
        moments = collections.OrderedDict(zip(moments_keys, moments_vars))
        return output, params, moments
Example #14
def get_train_op(loss,
                 learning_rate=0.001,
                 lr_decay_steps=10000,
                 lr_decay_rate=0.98,
                 gradient_clip_norm=3.0,
                 use_tpu=True,
                 variables=None):
    """Get training operation with gradient clipping and learning rate decay.

  Distilled from tf.contrib.layers.optimize_loss().

  Args:
    loss: Scalar tensor of the loss function.
    learning_rate: Scalar initial learning rate.
    lr_decay_steps: Exponential decay timescale.
    lr_decay_rate: Exponential decay magnitude.
    gradient_clip_norm: Global norm by which to scale gradients.
    use_tpu: Use tpu for training.
    variables: List of variables to optimize. tf.trainable_variables() if None.

  Returns:
    train_op: Operation that runs one iteration of training.
  """
    global_step = tf.train.get_or_create_global_step()

    with tf.variable_scope('training', values=[loss, global_step]):
        # Make sure update ops run before computing loss.
        update_ops = list(set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))
        with tf.control_dependencies(update_ops):
            loss = tf.identity(loss)

        # Learning rate variable, with decay.
        learning_rate_decay_fn = functools.partial(tf.train.exponential_decay,
                                                   decay_steps=lr_decay_steps,
                                                   decay_rate=lr_decay_rate,
                                                   staircase=True)
        lr = tf.get_variable(
            'learning_rate', [],
            trainable=False,
            initializer=tf.constant_initializer(learning_rate))
        lr = learning_rate_decay_fn(lr, global_step)

        # Optimizer.
        opt = tf.train.AdamOptimizer(lr)
        if use_tpu:
            opt = tf.tpu.CrossShardOptimizer(opt)

        # All trainable variables, if specific variables are not specified.
        if variables is None:
            variables = tf.trainable_variables()

        # Compute gradients.
        gradients = opt.compute_gradients(loss,
                                          variables,
                                          colocate_gradients_with_ops=False)

        # Optionally clip gradients by global norm.
        if isinstance(gradient_clip_norm, float):
            gradients = _clip_gradients_by_norm(gradients, gradient_clip_norm)

        # Create gradient updates.
        grad_updates = opt.apply_gradients(gradients,
                                           global_step=global_step,
                                           name='train')

        # Ensure the train_op computes grad_updates.
        with tf.control_dependencies([grad_updates]):
            train_op = tf.identity(loss)

        return train_op
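A minimal usage sketch for get_train_op above (use_tpu=False for a local session; gradient_clip_norm=None skips the _clip_gradients_by_norm helper referenced in the body, which is assumed to live in the same module):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 10])
y = tf.placeholder(tf.float32, shape=[None, 1])
loss = tf.losses.mean_squared_error(y, tf.layers.dense(x, 1))

train_op = get_train_op(loss, learning_rate=1e-3, use_tpu=False, gradient_clip_norm=None)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # train_op returns the loss value while applying the gradient updates via control deps.
    loss_val = sess.run(train_op, feed_dict={x: np.random.randn(32, 10),
                                             y: np.random.randn(32, 1)})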