Example #1
def equirectangular_padding(images, num_paddings):
    """Pad equirectangular panorama images.

  Args:
    images: a 4-D tensor of shape `[BATCH, HEIGHT, WIDTH, CHANNELS]`.
    num_paddings: a 2x2 integer list [[n_top, n_bottom], [n_left, n_right]]
      representing the number of rows or columns to pad around the images.

  Returns:
    a 4-D tensor representing padded images with a shape of
    `[BATCH, n_top+HEIGHT+n_bottom, n_left+WIDTH+n_right, CHANNELS]`.

  Raises:
    ValueError: 'images' has the wrong dimensions.
                The number of paddings exceeds the height or width dimension.
  """
    with tf.name_scope(None, 'equirectangular_padding',
                       [images, num_paddings]):
        if len(images.shape) != 4:
            raise ValueError("'images' has the wrong dimensions.")

        shape = images.shape.as_list()
        height, width = shape[1], shape[2]
        top, down = num_paddings[0][0], num_paddings[0][1]
        left, right = num_paddings[1][0], num_paddings[1][1]
        if top > height or down > height:
            raise ValueError(
                'The number of paddings exceeds the height dimension.')
        if left > width or right > width:
            raise ValueError(
                'The number of paddings exceeds the width dimension.')

        semicircle = tf.cast(width / 2, tf.int32)
        # The padded rows should be symmetric to 'images', but they should be
        # shifted by 180 degrees. Copy the rightmost column (2*pi-w) as the padded
        # column on the left and copy the leftmost column (0+w) as the padded
        # column on the right.
        top_padding = tf.reverse(tf.roll(images[:, :top, :, :],
                                         axis=2,
                                         shift=semicircle),
                                 axis=[1])
        bottom_padding = tf.roll(tf.reverse(images, axis=[1])[:, :down, :, :],
                                 axis=2,
                                 shift=semicircle)
        padded_images = tf.concat([top_padding, images, bottom_padding], 1)
        left_padding = tf.reverse(tf.reverse(padded_images,
                                             axis=[2])[:, :, :left, :],
                                  axis=[2])
        right_padding = padded_images[:, :, :right, :]
        padded_images = tf.concat([left_padding, padded_images, right_padding],
                                  2)
        return padded_images
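A quick shape check for the helper above (a minimal sketch; the three-argument tf.name_scope call above is the TF 1.x signature, so this assumes a TF 1.x-compatible graph context):

import tensorflow as tf

images = tf.ones([1, 4, 8, 3])                             # [BATCH, HEIGHT, WIDTH, CHANNELS]
padded = equirectangular_padding(images, [[1, 1], [1, 1]])
# Expected static shape: [1, 1+4+1, 1+8+1, 3] == [1, 6, 10, 3].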
Example #2
def create_seq(character, action_metadata, direction, length=8, start=0):
    """Creates a sequence.

  Args:
    character: A character sprite tensor.
    action_metadata: An action metadata tuple.
    direction: An integer representing the direction, i.e., the row
      offset within each action group corresponding to a particular
      direction.
    length: Desired length of the sequence. If this is longer than
      the number of available frames, it will roll over to the
      beginning.
    start: Index of possible frames at which to start the sequence.

  Returns:
    A sequence tensor.
  """
    sprite_start = (action_metadata[0] + direction) * FRAME_SIZE
    sprite_end = (action_metadata[0] + direction + 1) * FRAME_SIZE
    sprite_line = character[sprite_start:sprite_end, ...]

    # Extract 64x64 patches that are side-by-side in the sprite, and limit
    # to the actual number of frames for the given action.
    frames = tf.stack(tf.split(sprite_line, 13, axis=1))  # 13 is a hack
    frames = frames[0:action_metadata[1]]

    # Extract a slice of the desired length.
    # NOTE: Length could be longer than the number of frames, so tile as needed.
    frames = tf.roll(frames, shift=-start, axis=0)
    frames = tf.tile(frames, [2, 1, 1, 1])  # 2 is a hack
    frames = frames[:length]
    frames = tf.cast(frames, dtype=tf.float32)
    frames.set_shape([length, FRAME_SIZE, FRAME_SIZE, CHANNELS])
    return frames
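The roll-tile-slice pattern above implements wrap-around frame selection (the hard-coded 13 and 2 are acknowledged hacks in the source). A tiny standalone sketch of the same idiom, assuming eager TF:

import tensorflow as tf

frames = tf.range(5)                                       # stand-in for 5 animation frames
start, length = 3, 8
looped = tf.tile(tf.roll(frames, shift=-start, axis=0), [2])[:length]
print(looped.numpy())                                      # expected: [3 4 0 1 2 3 4 0]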
Example #3
 def _aug_roll(self, images):
     """Smooths the image using a 5x5 uniform kernel."""
     width, height = images.shape[1], images.shape[2]
     width, height = tf.cast(width, tf.float32), tf.cast(height, tf.float32)
     x = tf.cast(self.roll_x.value * width, tf.int32)
     y = tf.cast(self.roll_y.value * height, tf.int32)
     return tf.roll(images, [x, y], axis=[1, 2])
Example #4
 def build_graph(parameters):
     input_value = tf.compat.v1.placeholder(dtype=parameters["input_dtype"],
                                            name="input",
                                            shape=parameters["input_shape"])
     outs = tf.roll(input_value,
                    shift=parameters["shift"],
                    axis=parameters["axis"])
     return [input_value], [outs]
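This is a TFLite test-graph builder, so the parameters dict comes from the surrounding test harness. A hypothetical entry, with values assumed only from the keys used above, might look like:

parameters = {
    "input_dtype": tf.float32,
    "input_shape": [4, 8],
    "shift": 2,     # scalar shift, or a list of shifts
    "axis": 1,      # axis to roll along, or a list of axes
}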
Example #5
def _smoothness_helper(motion_map):
    """Calculates L1 (total variation) smoothness loss of a tensor.
    Args:
      motion_map: A tensor to be smoothed, of shape [B, H, W, C].
    Returns:
      A scalar tf.Tensor, The total variation loss.
    """
    # We roll in order to impose continuity across the boundary. The motivation is
    # that there is some ambiguity between rotation and spatial gradients of
    # translation maps. We would like to discourage spatial gradients of the
    # translation field, and to absorb such gradients into the rotation as much as
    # possible. This is why we impose continuity across the spatial boundary.
    motion_map_dx = motion_map - tf.roll(motion_map, 1, 1)
    motion_map_dy = motion_map - tf.roll(motion_map, 1, 2)
    sm_loss = tf.sqrt(1e-24 + tf.square(motion_map_dx) +
                      tf.square(motion_map_dy))
    tf.summary.image("motion_sm", sm_loss)
    return tf.reduce_mean(sm_loss)
Example #6
    def test_pano_shift_transform(self):
        random_state = np.random.RandomState(seed=0)
        with tf.name_scope("test_shift_pano_forward_pass"), self.session():
            pano_stack = tf.constant(
                random_state.uniform(size=[1, 128, 256, 3]), dtype=tf.float32)
            # This is a 90 degree rotation about the z-axis with the convention of
            # the z-axis facing down in the world coordinate system.
            yaw_rotation_radians = 0.5 * np.pi * tf.ones(shape=[1],
                                                         dtype=tf.float32)
            rotated_panorama = pano_transformer.rotate_pano_horizontally(
                pano_stack, yaw_rotation_radians)
            # A 90 degree rotation corresponds to a 256 / 4 = 64 pixel shift of the
            # panorama. The shift is in the opposite direction to the rotation due
            # to tf.roll's pixel shifting behavior.
            rolled_pano = tf.roll(pano_stack, -64, axis=2)
            self.assertAllClose(rotated_panorama.eval(), rolled_pano.eval())

            # Assert the opposite rotation as well.
            yaw_rotation_radians = -0.5 * np.pi * tf.ones(shape=[1],
                                                          dtype=tf.float32)
            rotated_panorama = pano_transformer.rotate_pano_horizontally(
                pano_stack, yaw_rotation_radians)
            rolled_pano = tf.roll(pano_stack, 64, axis=2)
            self.assertAllClose(rotated_panorama.eval(), rolled_pano.eval())
Example #7
def _pad_mics_tf(signal, new_mics):
    """Pads new mic channels to an input tensor and returns the updated tensor.

  Args:
    signal: A tf.tensor of shape (input_mics, samples)
    new_mics: The number of new mic channels to be added (integer scalar tensor)
  Returns:
    padded_signal: A tf.tensor of shape (input_mics + new_mics, samples)
  """
    # Take first new_mics channels and shift them by 1 sample.
    new_inputs = tf.roll(signal[:new_mics, :], shift=1, axis=-1)
    # Add noise 1e-3 times the RMS value in the signal.
    noise_scale = 1e-3 * tf.sqrt(tf.reduce_mean(tf.square(new_inputs)))
    new_inputs += noise_scale * tf.random.normal(tf.shape(new_inputs))
    return tf.concat((signal, new_inputs), axis=0)
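A quick shape check for the helper above (hypothetical sizes, assuming eager TF and new_mics <= input_mics so the copied channels exist):

import tensorflow as tf

signal = tf.random.normal([4, 16000])    # 4 input mics, 1 second at 16 kHz
padded = _pad_mics_tf(signal, 2)
print(padded.shape)                      # expected: (6, 16000)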
Example #8
 def build_graph(parameters):
     input_tensor = tf.compat.v1.placeholder(
         dtype=parameters["input_dtype"],
         name="input",
         shape=parameters["input_shape"])
     shift_tensor = tf.compat.v1.placeholder(dtype=tf.int64,
                                             name="shift",
                                             shape=get_shape(
                                                 parameters["shift"]))
     axis_tensor = tf.compat.v1.placeholder(dtype=tf.int64,
                                            name="axis",
                                            shape=get_shape(
                                                parameters["axis"]))
     outs = tf.roll(input_tensor, shift_tensor, axis_tensor)
     return [input_tensor, shift_tensor, axis_tensor], [outs]
Example #9
def l1smoothness(tensor, wrap_around=True):
    """Calculates L1 (total variation) smoothness loss of a tensor.

  Args:
    tensor: A tensor to be smoothed, of shape [B, H, W, C].
    wrap_around: True to wrap around the last pixels to the first.

  Returns:
    A scalar tf.Tensor, The total variation loss.
  """
    with tf.name_scope('l1smoothness'):
        tensor_dx = tensor - tf.roll(tensor, 1, 1)
        tensor_dy = tensor - tf.roll(tensor, 1, 2)
        # We optionally wrap around in order to impose continuity across the
        # boundary. The motivation is that there is some ambiguity between rotation
        # and spatial gradients of translation maps. We would like to discourage
        # spatial gradients of the translation field, and to absorb such gradients
        # into the rotation as much as possible. This is why we impose continuity
        # across the spatial boundary.
        if not wrap_around:
            tensor_dx = tensor_dx[:, 1:, 1:, :]
            tensor_dy = tensor_dy[:, 1:, 1:, :]
        return tf.reduce_mean(
            tf.sqrt(1e-24 + tf.square(tensor_dx) + tf.square(tensor_dy)))
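For intuition, the roll-based difference above is a wrap-around finite difference: each element is compared with its predecessor, and the first element with the last. A one-axis sketch, assuming eager TF:

import tensorflow as tf

x = tf.reshape(tf.constant([1., 2., 4., 7.]), [1, 1, 4, 1])   # [B, H, W, C]
dx = x - tf.roll(x, 1, 2)
print(tf.squeeze(dx).numpy())   # expected: [-6.  1.  2.  3.]  (the -6 is the wrap-around term 1 - 7)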
Example #10
def ir2tf(imp_resp, shape, sess, dim=None, is_real=True):
    """Compute the transfer function of an impulse response (IR).
    This function applies the necessary zero padding, zero-phase
    convention, correct fft2, etc. to compute the transfer function
    of the IR. Intended for use with a unitary Fourier transform of the
    signal (ufftn or equivalent).

    Parameters
    ----------
    imp_resp : ndarray
        The impulse responses.
    shape : tuple of int
        A tuple of integers corresponding to the target shape of the
        transfer function.
    sess : tf.Session
        Session used to initialize and fill the zero-padded variable.
    dim : int, optional
        The last axis along which to compute the transform. All
        axes by default.
    is_real : boolean, optional
        If True (default), imp_resp is assumed to be real and the Hermitian
        property is used with the rfftn Fourier transform.

    Returns
    -------
    y : complex ndarray
       The transfer function of shape ``shape``.
    """
    if not dim:
        dim = len(imp_resp.shape)
    # Zero padding and fill
    irpadded = tf.Variable(tf.zeros(shape))
    sess.run(tf.variables_initializer([irpadded]))
    sess.run(
        tf.assign(irpadded[tuple([slice(0, s) for s in imp_resp.shape])],
                  imp_resp))

    # Roll for zero convention of the fft to avoid the phase
    # problem. Work with odd and even size.
    for axis, axis_size in enumerate(imp_resp.shape):
        if axis >= len(imp_resp.shape) - dim:
            irpadded = tf.roll(
                irpadded,
                shift=-tf.cast(tf.floor(tf.cast(axis_size, tf.int32) / 2),
                               tf.int32),
                axis=axis)
    if is_real:
        return tf.signal.rfft2d(irpadded)
    else:
        return tf.fft2d(tf.cast(irpadded, tf.complex64))
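The per-axis roll in ir2tf moves the kernel centre to index 0, which avoids the linear phase ramp in the resulting transfer function. A one-dimensional sketch of the same idea, assuming eager TF:

import tensorflow as tf

kernel = tf.constant([1., 2., 3., 2., 1.])             # centred impulse response of size 5
centred = tf.roll(kernel, shift=-(5 // 2), axis=0)     # peak moves to index 0
print(centred.numpy())                                 # expected: [3. 2. 1. 1. 2.]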
Example #11
def diffsort(offsets):
    """Calculate the argsort of the difference between the input offsets.

  Useful for sorting row indices in sparse matrices.

  Args:
    offsets: Tensor, array of offsets for the start of each row, where
      `offsets[i+1] - offsets[i]` is the length of row i. Length `m+1`,
      where `m` is the number of rows.

  Returns:
    Tensor, array of row indices sorted by row length, from largest to
      smallest.
  """
    diffs = (offsets - tf.roll(offsets, shift=-1, axis=0))[:-1]
    return tf.cast(tf.argsort(diffs, direction="DESCENDING"), tf.uint32)
Example #12
def benoit_noise(image):

    dnoise = 0

    dnoise = dnoise + tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[1, 0], axis=[0, 1])))
    dnoise = dnoise + tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[0, 1], axis=[0, 1])))
    dnoise = dnoise + tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[0, -1], axis=[0, 1])))
    dnoise = dnoise + tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[-1, 0], axis=[0, 1])))

    dnoise = dnoise + 0.5 * tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[1, 1], axis=[0, 1])))
    dnoise = dnoise + 0.5 * tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[-1, 1], axis=[0, 1])))
    dnoise = dnoise + 0.5 * tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[1, -1], axis=[0, 1])))
    dnoise = dnoise + 0.5 * tf.math.reduce_mean(
        tf.math.square(image - tf.roll(image, shift=[-1, -1], axis=[0, 1])))

    return tf.math.sqrt(dnoise)
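benoit_noise relies on tf.roll accepting list-valued shift and axis arguments, which rolls along several axes at once. A tiny sketch, assuming eager TF:

import tensorflow as tf

img = tf.reshape(tf.range(9, dtype=tf.float32), [3, 3])
shifted = tf.roll(img, shift=[1, 0], axis=[0, 1])   # rows move down by one; the last row wraps to the top
print(shifted.numpy())
# expected:
# [[6. 7. 8.]
#  [0. 1. 2.]
#  [3. 4. 5.]]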
Example #13
    def build(self, input_layer, ntm_layer, idx=0):
        w = super().build(input_layer, ntm_layer, idx)
        self.shifts = tf.keras.layers.Dense(self.shift_size,
                                            activation=self.shift_activation,
                                            name=self.name +
                                            "_gate")(input_layer)
        self.gamma = tf.keras.layers.Dense(1,
                                           activation="relu",
                                           name=self.name +
                                           "_gamma")(input_layer)

        self.layers.append(self.shifts)
        self.layers.append(self.gamma)

        paddings = tf.Variable(tf.zeros((tf.rank(self.shifts), 2),
                                        dtype=tf.int32),
                               dtype=tf.int32)
        npaddings = paddings[-1, 1].assign(ntm_layer.lines - self.shift_size)

        self.shifts = tf.pad(self.shifts, npaddings)
        rows = []
        for i in range(ntm_layer.lines):
            p = (i + 1) + ntm_layer.lines - self.m
            rows.append(tf.roll(self.shifts, shift=p, axis=-1))
        mshifts = tf.stack(rows, axis=1)
        #[None, M] -->  [None, M, 1]
        bkp_shape = tf.shape(w)
        w = tf.expand_dims(w, axis=-1)
        #[None, M, M] x [None, M, 1]
        w = tf.matmul(mshifts, w)
        w = tf.reshape(w, bkp_shape)
        self.w = tf.keras.layers.Softmax(name='%s_output' % (self.name))(
            w * (self.gamma + 1))
        self.layers.append(self.w)
        if ntm_layer.network is not None:
            self.pw_slot.update.append(self.w)
        return self.w
Example #14
            def dp_body(k, logZ, lastZ):
                '''
        case j < N-k + 1:
          logZ[j,k] = log_sum( log(1-pi(j)) + logZ[j+1,k], logp(j) + logZ[j+1,k-1])
        case j = N-k + 1
          logZ[j,k] = accum_logp[j]
        case j > N-k + 1
          logZ[j,k] = 0
        '''

                # shift lastZ one step
                shifted_lastZ = tf.roll(lastZ[:, :-1], shift=1,
                                        axis=1)  #logZ[j+1,k-1]
                log_yes = logp + shifted_lastZ  # b x N
                logZ_j = tf.TensorArray(tf.float32, size=seq_len + 1)
                init_value = accum_logp[:, seq_len - k]
                logZ_j = logZ_j.write(seq_len - k, init_value)
                _, logZ_j, logb, loga = tf.while_loop(
                    accum_cond, accum_body,
                    [seq_len - k - 1, logZ_j, log_yes, init_value])
                logZ_j = logZ_j.stack()  # N x b
                logZ_j = tf.transpose(logZ_j, [1, 0])  # b x N
                logZ = logZ.write(k, logZ_j)
                return [tf.add(k, 1), logZ, logZ_j]
Example #15
def sequentially_resolve_permutation(
        blocks: tf.Tensor,
        window: tf.Tensor,
        loss_fn: LossType = _mag_stft_mse_loss) -> tf.Tensor:
    """Resolves permutation between overlapping blocks.

  Args:
    blocks: Waveform in blocks (blocks, sources, samples)
    window: Window function used to obtain the blocks.
    loss_fn: A loss function operating on two tensors and returning
      a loss tensor, reducing loss value over the last dimension.
  Returns:
    perm_blocks: Permute sources to be consistent across blocks with shape
      (blocks, sources, samples)
  """
    num_blocks = tf.shape(blocks)[0]
    num_sources = blocks.shape[1]
    block_samples = blocks.shape[-1]
    hop_samples = block_samples // 2
    # If the blocks were obtained after windowing, we need to compensate
    # for the effect of the window. We could do this by dividing by the window,
    # but that can be unstable, so we instead roll the window and multiply the
    # blocks with the rolled window, which applies the same windowing effect to
    # the begin and end parts of the block. This is numerically more stable and
    # does not give large weight to samples at the edges, which can cause
    # problems in practice.
    rolled_window = tf.roll(window, shift=hop_samples, axis=0)
    reweighted_blocks = rolled_window * blocks
    # Split samples dimension into two: begin-part and end-part.
    begin_parts, end_parts = tf.split(reweighted_blocks, 2, axis=-1)

    # We assume axis=1 is the reference, axis=2 is the estimate.
    # The end parts of the previous frame are the reference; the begin parts
    # of the current frame are the estimate.
    begin_parts = tf.expand_dims(begin_parts, axis=1)
    end_parts = tf.expand_dims(end_parts, axis=2)
    loss_matrix = loss_fn(end_parts[:-1], begin_parts[1:])
    # loss_matrix has shape (blocks - 1, sources, sources).
    # We find the permutations that align each current block with its previous
    # block using batch mode permutation resolution, where the blocks
    # dimension is considered as the batch dimension.
    permutation = _resolve_permutation(loss_matrix)
    # permutation has shape (blocks - 1, sources, 1), so we reshape.
    permutation = tf.reshape(permutation, (num_blocks - 1, num_sources))
    # Sequentially update permutations to consider the first block as the
    # reference instead of the previous block to align them all.
    # We do this by sequentially realigning each block with the previous block.
    # Since the previous block is already updated to align with the first block,
    # due to the iterative induction here, we get all aligned signals.
    updated_perm_init = tf.TensorArray(tf.int32,
                                       size=num_blocks,
                                       element_shape=(num_sources, ))
    previous_perm = tf.range(num_sources)
    # First block of updated perm is identity permutation.
    updated_perm_init = updated_perm_init.write(0, previous_perm)

    def update_perms(i, previous_perm, updated_perm):
        # Note: permutation is missing first block, so we access (i-1)th entry.
        perm_now = tf.gather(permutation[i - 1], previous_perm, axis=0)
        updated_perm = updated_perm.write(i, perm_now)
        previous_perm = perm_now
        i = i + 1
        return i, previous_perm, updated_perm

    while_cond = lambda i, prev_perm, updated_perm: tf.less(i, num_blocks)
    _, _, updated_perm = tf.while_loop(while_cond,
                                       update_perms,
                                       [1, previous_perm, updated_perm_init],
                                       back_prop=False,
                                       parallel_iterations=1)
    updated_perm = updated_perm.stack()
    # updated_perm has shape (num_blocks, num_sources).
    # Obtain the indices for gather_nd.
    batch_index = tf.tile(tf.expand_dims(tf.range(num_blocks), axis=1),
                          [1, num_sources])
    batch_index = tf.reshape(batch_index, [num_blocks, num_sources])
    perm_index = tf.stack([batch_index, updated_perm], axis=2)
    # Now, permute blocks tensor to align all blocks.
    perm_blocks = tf.gather_nd(blocks, perm_index)
    return perm_blocks
Example #16
    def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
        """The `model_fn` for TPUEstimator."""

        logging.info("*** Model: Params ***")
        for name in sorted(params.keys()):
            logging.info("  %s = %s", name, params[name])
        logging.info("*** Model: Features ***")
        for name in sorted(features.keys()):
            logging.info("  name = %s, shape = %s", name, features[name].shape)

        model = modeling.ReadItTwiceBertModel(
            config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

        span_prediction_layer = modeling.SpanPredictionHead(
            intermediate_size=model_config.intermediate_size,
            dropout_rate=model_config.hidden_dropout_prob)

        # [batch_size, main_seq_length]
        token_ids = features["token_ids"]
        main_seq_length = tf.shape(token_ids)[1]
        block_ids = features["block_ids"]
        block_pos = features["block_pos"]

        annotation_begins = features.get("entity_annotation_begins")
        annotation_ends = features.get("entity_annotation_ends")
        annotation_labels = features.get("entity_annotation_labels")

        # Do not attend to padding tokens
        # [batch_size, main_seq_length, main_seq_length]
        att_mask = tf.tile(
            tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
            [1, main_seq_length, 1])
        att_mask = tf.cast(att_mask, dtype=tf.int32)

        main_output = model(
            token_ids=token_ids,
            training=(mode == tf.estimator.ModeKeys.TRAIN),
            block_ids=block_ids,
            block_pos=block_pos,
            att_mask=att_mask,
            annotation_begins=annotation_begins,
            annotation_ends=annotation_ends,
            annotation_labels=annotation_labels,
            enable_side_inputs=enable_side_inputs,
            num_replicas_concat=num_replicas_concat,
            cross_block_attention_mode=cross_block_attention_mode)

        span_logits = span_prediction_layer(
            hidden_states=main_output.final_hidden_states,
            token_ids=token_ids,
            padding_token_id=padding_token_id,
            ignore_prefix_length=features["prefix_length"],
            training=(mode == tf.estimator.ModeKeys.TRAIN))

        is_summary_loss_enabled = (mode == tf.estimator.ModeKeys.TRAIN
                                   and summary_loss_weight is not None
                                   and summary_loss_weight > 0)
        if is_summary_loss_enabled:
            logging.info("Using summary prediction loss with weight %.3f",
                         summary_loss_weight)
            summary_token_ids = features["summary_token_ids"]
            summary_labels = tf.roll(summary_token_ids, shift=-1, axis=1)
            decoder = modeling.ReadItTwiceDecoderModel(
                config=model_config,
                num_layers_override=summary_num_layers,
                num_cross_attention_heads=summary_num_cross_attention_heads,
                enable_default_side_input=summary_enable_default_side_input,
                use_one_hot_embeddings=use_one_hot_embeddings)
            summary_token_logits = decoder(
                token_ids=summary_token_ids,
                side_input=main_output.global_summary.states,
                token2side_input_att_mask=modeling.get_cross_block_att(
                    block_ids,
                    block_pos,
                    main_output.global_summary.block_ids,
                    main_output.global_summary.block_pos,
                    cross_block_attention_mode="doc"),
                training=True)
            language_model_loss_fn = losses.LanguageModelLoss(
                decoder.get_token_embedding_table(),
                hidden_size=model_config.hidden_size)
            language_model_loss = language_model_loss_fn(
                summary_token_logits,
                summary_labels,
                padding_token_id=padding_token_id).loss
        else:
            language_model_loss = None

        tvars = tf.trainable_variables()

        initialized_variable_names = {}
        scaffold_fn = None
        if init_checkpoint:
            (assignment_map, initialized_variable_names
             ) = checkpoint_utils.get_assignment_map_from_checkpoint(
                 tvars, init_checkpoint)
            if use_tpu:

                def tpu_scaffold():
                    tf.train.init_from_checkpoint(init_checkpoint,
                                                  assignment_map)
                    return tf.train.Scaffold()

                scaffold_fn = tpu_scaffold
            else:
                tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                         init_string)

        output_spec = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            host_inputs = dict()

            span_prediction_loss = losses.BatchSpanCrossEntropyLoss()

            qa_loss = span_prediction_loss(
                logits=span_logits,
                annotation_begins=features["answer_annotation_begins"],
                annotation_ends=features["answer_annotation_ends"],
                annotation_labels=features["answer_annotation_labels"],
                block_ids=block_ids,
                num_replicas=num_replicas_concat,
                eps=1e-5)
            host_inputs["train_metrics/qa_loss"] = tf.expand_dims(qa_loss, 0)

            if language_model_loss is not None:
                total_loss = (
                    1.0 / (1.0 + summary_loss_weight) * qa_loss +
                    summary_loss_weight /
                    (1.0 + summary_loss_weight) * language_model_loss)
                host_inputs["train_metrics/summary_lm_loss"] = tf.expand_dims(
                    language_model_loss, 0)
            else:
                total_loss = qa_loss

            # Add regularization losses.
            if model.losses:
                total_loss += tf.math.add_n(model.losses)

            train_op = optimization.create_optimizer(total_loss,
                                                     learning_rate,
                                                     num_train_steps,
                                                     num_warmup_steps,
                                                     use_tpu,
                                                     optimizer,
                                                     poly_power,
                                                     start_warmup_step,
                                                     learning_rate_schedule,
                                                     reduce_loss_sum=True)

            host_inputs.update({
                "global_step":
                tf.expand_dims(tf.train.get_or_create_global_step(), 0),
                "train_metrics/loss":
                tf.expand_dims(total_loss, 0),
            })

            host_call = (functools.partial(record_summary_host_fn,
                                           metrics_dir=os.path.join(
                                               FLAGS.output_dir,
                                               "train_metrics")), host_inputs)

            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn,
                host_call=host_call)
        elif mode == tf.estimator.ModeKeys.PREDICT:
            begin_logits_values, begin_logits_indices = tf.math.top_k(
                span_logits[:, :, 0],
                k=nbest_logits_for_eval,
            )
            end_logits_values, end_logits_indices = tf.math.top_k(
                span_logits[:, :, 1],
                k=nbest_logits_for_eval,
            )

            predictions = {
                "block_ids": tf.identity(block_ids),
                "begin_logits_values": begin_logits_values,
                "begin_logits_indices": begin_logits_indices,
                "end_logits_values": end_logits_values,
                "end_logits_indices": end_logits_indices,
                "token_ids": tf.identity(token_ids),
            }
            output_spec = tf.estimator.tpu.TPUEstimatorSpec(
                mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
        else:
            raise ValueError("Only TRAIN and PREDICT modes is supported: %s" %
                             (mode))

        return output_spec
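The shift=-1 roll used for summary_labels above turns the summary tokens into next-token targets for the language-model loss. A minimal standalone sketch of that construction, assuming eager TF (the wrapped final position is presumably masked out via padding_token_id in the loss):

import tensorflow as tf

token_ids = tf.constant([[11, 12, 13, 0]])
labels = tf.roll(token_ids, shift=-1, axis=1)
print(labels.numpy())   # expected: [[12 13  0 11]] -- note the last position wraps around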
Example #17
    def call(self, inputs, prev_state):
        """Evaluates one timestep of the current neural stack cell.

    See section 3.4 of Grefenstette et al., 2015.

    Args:
      inputs: The inputs to the neural stack cell should be a tf.float32 tensor
        with shape [batch_size, embedding_size]
      prev_state: The NeuralStackState from the previous timestep.

    Returns:
      A tuple of the output of the stack as well as the new NeuralStackState.
    """
        batch_size = tf.shape(inputs)[0]

        # Call the controller and get controller interface values.
        with tf.control_dependencies([prev_state.read_strengths]):
            controller_output = self.call_controller(
                inputs, prev_state.read_values, prev_state.controller_state,
                batch_size)

        # Always write input values to memory regardless of push strength.
        # See Equation-1 in Grefenstette et al., 2015.
        new_memory_values = prev_state.memory_values + tf.reduce_sum(
            tf.expand_dims(controller_output.write_values, axis=2) *
            prev_state.write_strengths,
            axis=1)

        # Attenuate the read strengths of existing memory values depending on the
        # current pop strength.
        # See Equation-2 in Grefenstette et al., 2015.
        new_read_strengths = prev_state.read_strengths
        for h in range(self._num_read_heads - 1, -1, -1):
            new_read_strengths = tf.nn.relu(new_read_strengths - tf.nn.relu(
                tf.slice(controller_output.pop_strengths, [0, h, 0, 0],
                         [-1, 1, -1, -1]) -
                tf.expand_dims(tf.reduce_sum(
                    new_read_strengths * self.get_read_mask(h), axis=2),
                               axis=3)))

        # Combine all write heads and their associated push values into a single set
        # of read weights.
        new_read_strengths += tf.reduce_sum(controller_output.push_strengths *
                                            prev_state.write_strengths,
                                            axis=1,
                                            keep_dims=True)

        # Calculate the "top" value of the stack by looking at read strengths.
        # See Equation-3 in Grefenstette et al., 2015.
        new_read_values = tf.reduce_sum(
            tf.minimum(
                new_read_strengths,
                tf.nn.relu(1 - tf.expand_dims(tf.reduce_sum(
                    new_read_strengths * tf.concat([
                        self.get_read_mask(h)
                        for h in range(self._num_read_heads)
                    ],
                                                   axis=1),
                    axis=2),
                                              axis=3))) *
            tf.expand_dims(new_memory_values, axis=1),
            axis=2)

        # Temporarily split write strengths apart so they can be shifted in
        # different directions.
        write_strengths_by_head = tf.split(prev_state.write_strengths,
                                           self._num_write_heads,
                                           axis=1)
        # Shift the write strengths for each write head in the direction indicated
        # by get_write_head_offset().
        new_write_strengths = tf.concat([
            tf.roll(
                write_strength, shift=self.get_write_head_offset(h), axis=2)
            for h, write_strength in enumerate(write_strengths_by_head)
        ],
                                        axis=1)

        return (controller_output.outputs,
                NeuralStackState(controller_state=controller_output.state,
                                 read_values=new_read_values,
                                 memory_values=new_memory_values,
                                 read_strengths=new_read_strengths,
                                 write_strengths=new_write_strengths))
Example #18
    def set_network(self, Gs, dtype='float16'):
        if Gs is None:
            self._Gs = None
            return
        self._Gs = Gs.clone(randomize_noise=False, dtype=dtype, num_fp16_res=0, fused_modconv=True)

        # Compute dlatent stats.
        self._info(f'Computing W midpoint and stddev using {self.dlatent_avg_samples} samples...')
        latent_samples = np.random.RandomState(123).randn(self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
        dlatent_samples = self._Gs.components.mapping.run(latent_samples, None)  # [N, L, C]
        dlatent_samples = dlatent_samples[:, :1, :].astype(np.float32)           # [N, 1, C]
        self._dlatent_avg = np.mean(dlatent_samples, axis=0, keepdims=True)      # [1, 1, C]
        self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg) ** 2) / self.dlatent_avg_samples) ** 0.5
        self._info(f'std = {self._dlatent_std:g}')

        # Setup noise inputs.
        self._info('Setting up noise inputs...')
        self._noise_vars = []
        noise_init_ops = []
        noise_normalize_ops = []
        while True:
            n = f'G_synthesis/noise{len(self._noise_vars)}'
            if not n in self._Gs.vars:
                break
            v = self._Gs.vars[n]
            self._noise_vars.append(v)
            noise_init_ops.append(tf.compat.v1.assign(v, tf.random.normal(tf.shape(input=v), dtype=tf.float32)))
            noise_mean = tf.reduce_mean(input_tensor=v)
            noise_std = tf.reduce_mean(input_tensor=(v - noise_mean)**2)**0.5
            noise_normalize_ops.append(tf.compat.v1.assign(v, (v - noise_mean) / noise_std))
        self._noise_init_op = tf.group(*noise_init_ops)
        self._noise_normalize_op = tf.group(*noise_normalize_ops)

        # Build image output graph.
        self._info('Building image output graph...')
        self._minibatch_size = 1
        self._dlatents_var = tf.Variable(tf.zeros([self._minibatch_size] + list(self._dlatent_avg.shape[1:])), name='dlatents_var')
        self._dlatent_noise_in = tf.compat.v1.placeholder(tf.float32, [], name='noise_in')
        dlatents_noise = tf.random.normal(shape=self._dlatents_var.shape) * self._dlatent_noise_in
        self._dlatents_expr = tf.tile(self._dlatents_var + dlatents_noise, [1, self._Gs.components.synthesis.input_shape[1], 1])
        self._images_float_expr = tf.cast(self._Gs.components.synthesis.get_output_for(self._dlatents_expr), tf.float32)
        self._images_uint8_expr = tflib.convert_images_to_uint8(self._images_float_expr, nchw_to_nhwc=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        proc_images_expr = (self._images_float_expr + 1) * (255 / 2)
        sh = proc_images_expr.shape.as_list()
        if sh[2] > 256:
            factor = sh[2] // 256
            proc_images_expr = tf.reduce_mean(input_tensor=tf.reshape(proc_images_expr, [-1, sh[1], sh[2] // factor, factor, sh[2] // factor, factor]), axis=[3,5])

        # Build loss graph.
        self._info('Building loss graph...')
        self._target_images_var = tf.Variable(tf.zeros(proc_images_expr.shape), name='target_images_var')
        if self._lpips is None:
            with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/metrics/vgg16_zhang_perceptual.pkl') as f:
                self._lpips = pickle.load(f)
        self._dist = self._lpips.get_output_for(proc_images_expr, self._target_images_var)
        self._loss = tf.reduce_sum(input_tensor=self._dist)

        # Build noise regularization graph.
        self._info('Building noise regularization graph...')
        reg_loss = 0.0
        for v in self._noise_vars:
            sz = v.shape[2]
            while True:
                reg_loss += tf.reduce_mean(input_tensor=v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(input_tensor=v * tf.roll(v, shift=1, axis=2))**2
                if sz <= 8:
                    break # Small enough already
                v = tf.reshape(v, [1, 1, sz//2, 2, sz//2, 2]) # Downscale
                v = tf.reduce_mean(input_tensor=v, axis=[3, 5])
                sz = sz // 2
        self._loss += reg_loss * self.regularize_noise_weight

        # Setup optimizer.
        self._info('Setting up optimizer...')
        self._lrate_in = tf.compat.v1.placeholder(tf.float32, [], name='lrate_in')
        self._opt = tflib.Optimizer(learning_rate=self._lrate_in)
        self._opt.register_gradients(self._loss, [self._dlatents_var] + self._noise_vars)
        self._opt_step = self._opt.apply_updates()
Example #19
def tgn_memory(
    n_nodes: int,
    memory_size: int,
    time_embedding_size: int,
    node_ids: tf.Tensor,
    write_idx: tf.Tensor,
    write_mask: tf.Tensor,
    write_features: tf.Tensor,
    write_times: tf.Tensor,
) -> TgnMemory:
    """Create TGN memory read & update operations.

    A trainable memory for nodes in a temporal interaction graph. The memory
    state is computed using the latest interaction event that touched a node.
    The update is a GRU cell, taking as input the previous memory of both the
    source and destination nodes of that edge, the edge feature vector, and the
    time difference from the interaction to the current time.

    Note that the GRU cell is computed lazily when the memory is read, rather than
    when it is stored, to support a single step of truncated backpropagation through
    time and obtain a gradient for GRU variables.

    Please see "Temporal Graph Network" (https://arxiv.org/abs/2006.10637) for full
    details.

    Arguments:

      n_nodes -- total number of slots in the memory

      memory_size -- size of stored state in the memory / GRU cell output size

      time_embedding_size -- size of the time encoding activation provided to the
                             GRU cell

      node_ids -- shape (n_read), (-1 <= ID < n_nodes), the memory locations to be read

      write_idx -- shape (2, n_write), (0 <= idx < n_read), the (src, dst) indices of
                   edges, selecting nodes that should be written with their updated
                   memory state

      write_mask -- shape (2, n_write), boolean tensor for elements in write_idx that
                    should be written (true) or skipped (false), such that each memory
                    location is written at most once

      write_features -- shape (n_write, feature_size), input features to be stored and
                        used to compute the memory when it is next accessed

      write_times -- shape (n_write), edge event times to be stored and used to compute
                     the memory when it is next accessed

    Returns:

      TgnMemory(
        output      -- tensor of shape (n_read, memory_size), current memory for node_ids
        last_update -- tensor of shape (n_read), last update of output
        updates     -- tuple of operations to run to update the memory
      )
    """
    assert_shape(node_ids, (None, ))
    _, n_write = assert_shape(write_idx, (2, None))
    assert_shape(write_mask, (2, n_write))
    _, feature_size = assert_shape(write_features, (n_write, None))
    assert_shape(write_times, (n_write, ))
    dtype = write_features.dtype

    # Declare memory
    # As an optimisation, we concatenate the 6 fields required by the memory
    # into 2 tensors, one consisting of ints, the other of floats.
    # This requires some extra code to slice and concat, but means we can use
    # 2 (dynamic) gather operations instead of 6.

    # Each row: [last_update, dt, neighbour]
    v_ints = tf.get_variable(
        "ints",
        shape=(1 + n_nodes, 3),
        dtype=tf.int32,
        trainable=False,
        initializer=tf.zeros_initializer(),
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, TGN_MEMORY_VARIABLES_KEY],
    )
    # Each row: [memory, features, direction]
    v_floats = tf.get_variable(
        "floats",
        shape=(1 + n_nodes, memory_size + feature_size + 2),
        dtype=dtype,
        trainable=False,
        initializer=tf.zeros_initializer(),
        collections=[tf.GraphKeys.GLOBAL_VARIABLES, TGN_MEMORY_VARIABLES_KEY],
    )

    # Memory[0] is used for padding (node_ids == -1)
    safe_node_ids = 1 + node_ids

    # Read memory for node_ids
    node_ints = tf.gather(v_ints, safe_node_ids)
    node_last_update, node_dt, node_neighbour_idx = tf.unstack(node_ints,
                                                               axis=1)
    node_neighbour = tf.gather(v_floats[:, :memory_size], node_neighbour_idx)
    node_time_encoding = time_encoder(tf.cast(node_dt, tf.float32),
                                      time_embedding_size, dtype)

    node_floats = tf.gather(v_floats, safe_node_ids)
    node_self = node_floats[:, :memory_size]
    node_features = node_floats[:, memory_size:memory_size + feature_size]
    node_direction = node_floats[:, memory_size + feature_size:]

    node_memory = gru_cell(
        node_self,
        tf.concat(
            [
                node_direction[:, 0, tf.newaxis] * node_self +
                node_direction[:, 1, tf.newaxis] * node_neighbour,
                node_direction[:, 1, tf.newaxis] * node_self +
                node_direction[:, 0, tf.newaxis] * node_neighbour,
                node_features,
                node_time_encoding,
            ],
            axis=1,
        ),
    )

    # Write memory according to (write_idx, write_mask)
    flat_write_idx = tf.reshape(write_idx, (-1, ))
    indices = tf.gather(safe_node_ids, flat_write_idx)
    masked_indices = indices * tf.cast(tf.reshape(write_mask,
                                                  (-1, )), indices.dtype)
    p_last_update = tf.reshape(tf.tile(write_times[tf.newaxis], (2, 1)),
                               (-1, ))
    p_dt = p_last_update - tf.gather(node_last_update, flat_write_idx)
    # Swap src and dst indices to get the neighbour index for each node
    p_neighbour = tf.roll(indices, n_write, 0)
    p_memory = tf.gather(node_memory, flat_write_idx)
    p_features = tf.tile(write_features, (2, 1))
    p_direction = tf.repeat(tf.eye(2, dtype=dtype), n_write,
                            0)  # src=[1, 0], dst=[0, 1]

    # There is already a data dependency, but just to be sure...
    with tf.control_dependencies([node_last_update, node_memory]):
        update_ints = v_ints.scatter_update(
            tf.IndexedSlices(
                tf.stack([p_last_update, p_dt, p_neighbour], axis=1),
                masked_indices))
        update_floats = v_floats.scatter_update(
            tf.IndexedSlices(
                tf.concat([p_memory, p_features, p_direction], axis=1),
                masked_indices))

    return TgnMemory(
        output=node_memory,
        last_update=node_last_update,
        updates=(update_ints, update_floats),
    )
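The tf.roll(indices, n_write, 0) above pairs each written node with its counterpart by swapping the src and dst halves of the flattened index vector. A tiny sketch with n_write = 2 and hypothetical indices, assuming eager TF:

import tensorflow as tf

indices = tf.constant([1, 2, 7, 8])            # [src0, src1, dst0, dst1]
neighbours = tf.roll(indices, shift=2, axis=0)
print(neighbours.numpy())                      # expected: [7 8 1 2]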
Example #20
#%-----------------Loss Function-------------------------------------------
loss = -tf.reduce_sum(tf.multiply(Y, tf.log(Y_pred)))

#%-----------------Backward Propagation for Output Layer-------------------

#delY formation
del_Y = -tf.divide(Y, Y_pred)

#del_W3 Calculation
#delz3_delW3 formation
delz3_delW3_elem = tf.concat([H2, tf.zeros_like(H2)], 1)
for k in range(8):
    delz3_delW3_elem = tf.concat([delz3_delW3_elem, tf.zeros_like(H2)], 1)
delz3_delW3_list = []
for k in range(10):
    delz3_delW3_list.append(tf.roll(delz3_delW3_elem, k, 1))
delz3_delW3 = tf.stack(delz3_delW3_list, 2)

#dely_delz3 formation
temp = -tf.matmul(Y_pred, Y_pred, transpose_b=True)
temp_diag = tf.reshape(tf.diag(Y_pred), [Y_pred.shape[0], Y_pred.shape[0]])
dely_delz3 = tf.add(temp, temp_diag)

tempz3 = tf.matmul(dely_delz3, del_Y)
del_W3 = tf.reshape(tf.matmul(delz3_delW3, tempz3), W3.shape)

#del_H2 Calculation
del_H2 = tf.reshape(tf.matmul(W3, tempz3), H2.shape)

#del_W3_0 Calculation
del_W3_0 = tempz3
Example #21
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    logging.info("*** Model: Params ***")
    for name in sorted(params.keys()):
      logging.info("  %s = %s", name, params[name])
    logging.info("*** Model: Features ***")
    for name in sorted(features.keys()):
      logging.info("  name = %s, shape = %s", name, features[name].shape)

    model = modeling.ReadItTwiceBertModel(
        config=model_config, use_one_hot_embeddings=use_one_hot_embeddings)

    # [batch_size, main_seq_length]
    token_ids = features["token_ids"]
    batch_size = tf.shape(token_ids)[0]
    main_seq_length = tf.shape(token_ids)[1]
    block_ids = features["block_ids"]
    block_pos = features["block_pos"]

    annotation_begins = features.get("annotation_begins")
    annotation_ends = features.get("annotation_ends")
    annotation_labels = features.get("annotation_labels")

    # Do not attend to padding tokens
    # [batch_size, main_seq_length, main_seq_length]
    att_mask = tf.tile(
        tf.expand_dims(tf.not_equal(token_ids, padding_token_id), 1),
        [1, main_seq_length, 1])
    att_mask = tf.cast(att_mask, dtype=tf.int32)

    main_output = model(
        token_ids=token_ids,
        training=(mode == tf.estimator.ModeKeys.TRAIN),
        block_ids=block_ids,
        block_pos=block_pos,
        att_mask=att_mask,
        annotation_begins=annotation_begins,
        annotation_ends=annotation_ends,
        annotation_labels=annotation_labels,
        enable_side_inputs=enable_side_inputs,
        num_replicas_concat=num_replicas_concat,
        cross_block_attention_mode=cross_block_attention_mode)

    mlm_loss_fn = losses.LanguageModelLoss(
        model.get_token_embedding_table(),
        hidden_size=model_config.hidden_size,
        name="mlm_loss")
    mlm_loss_output = mlm_loss_fn(
        input_tensor=main_output.final_hidden_states,
        label_ids=features["masked_lm_ids"],
        positions=features["masked_lm_positions"],
        label_weights=features["masked_lm_weights"],
        mlm_is_entity_mask=features.get("mlm_is_entity_mask"),
        mlm_is_not_entity_mask=features.get("mlm_is_not_entity_mask"),
        padding_token_id=padding_token_id)
    mlm_loss = mlm_loss_output.loss

    loss_to_log = dict(mlm_loss=tf.expand_dims(mlm_loss, 0))
    loss_weight_denominator = 1.0 + sum(extra_loss.values())
    total_loss = mlm_loss * (1.0 / loss_weight_denominator)
    for loss_name, loss_weight in extra_loss.items():
      logging.info("EXTRA LOSS: %s with weight %.2f", loss_name,
                   loss_weight / loss_weight_denominator)

      if model_config.summary_mode == "entity":
        # Entity label "1" corresponds to an unknown entity;
        # there is no need to compute the coreference resolution loss
        # for these unknown entities.
        labels_weight = tf.cast(
            tf.logical_and(
                tf.not_equal(
                    tf.expand_dims(main_output.local_summary.labels, 1), 1),
                tf.not_equal(
                    tf.expand_dims(main_output.global_summary.labels, 0), 1)),
            tf.float32)
      else:
        labels_weight = None

      if loss_name == "sdp":
        loss_fn = losses.BatchCoreferenceResolutionLoss(
            apply_linear_layer=False)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_weight=labels_weight)
      elif loss_name == "sdp_linear":
        loss_fn = losses.BatchCoreferenceResolutionLoss(apply_linear_layer=True)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_weight=labels_weight)
      elif loss_name == "spp_linear":
        loss_fn = losses.BatchCoreferenceResolutionLoss(apply_linear_layer=True)
        # Positive examples are blocks which go one after another in the
        # original document.
        labels_mask = tf.less_equal(
            tf.abs(
                tf.expand_dims(main_output.local_summary.block_pos, 1) -
                tf.expand_dims(main_output.global_summary.block_pos, 0)), 1)
        loss_value = loss_fn(
            main_output.local_summary.states,
            main_output.local_summary.labels,
            main_output.global_summary.states,
            main_output.global_summary.labels,
            labels_mask=labels_mask,
            labels_weight=labels_weight)
      elif loss_name == "lm":
        token_labels = tf.roll(token_ids, shift=-1, axis=1)
        # [batch_size, global_batch_size]
        token2side_input_att_mask = modeling.get_cross_block_att(
            block_ids,
            block_pos,
            main_output.global_summary.block_ids,
            main_output.global_summary.block_pos,
            cross_block_attention_mode=cross_block_attention_mode,
            cast_to_int32=False)
        # We want to exclude the summary of the block itself
        # from decoder side input. As a proxy for this, we use block_ids AND
        # block_pos.
        samples_are_the_same = tf.logical_and(
            tf.equal(
                tf.expand_dims(block_ids, 1),
                tf.expand_dims(main_output.global_summary.block_ids, 0)),
            tf.equal(
                tf.expand_dims(block_pos, 1),
                tf.expand_dims(main_output.global_summary.block_pos, 0)))
        token2side_input_att_mask = tf.stop_gradient(
            tf.cast(
                tf.logical_and(token2side_input_att_mask,
                               tf.logical_not(samples_are_the_same)),
                dtype=tf.int32))

        decoder = modeling.ReadItTwiceDecoderModel(
            config=model_config,
            num_layers_override=summary_num_layers,
            num_cross_attention_heads=summary_num_cross_attention_heads,
            enable_default_side_input=summary_enable_default_side_input,
            use_one_hot_embeddings=use_one_hot_embeddings)
        summary_token_logits = decoder(
            token_ids=token_ids,
            side_input=main_output.global_summary.states,
            token2side_input_att_mask=token2side_input_att_mask,
            training=True)
        language_model_loss_fn = losses.LanguageModelLoss(
            decoder.get_token_embedding_table(),
            hidden_size=model_config.hidden_size)

        # We don't penalize the first and last 32 tokens, so the model does not
        # have an incentive to memorize tokens at the border of blocks.
        labels_weights = tf.concat([
            tf.zeros([batch_size, 32], dtype=tf.bool),
            tf.ones([batch_size, main_seq_length - 32 * 2], dtype=tf.bool),
            tf.zeros([batch_size, 32], dtype=tf.bool)
        ],
                                   axis=1)
        labels_weights = tf.logical_and(
            labels_weights, tf.not_equal(token_labels, padding_token_id))
        labels_weights = tf.stop_gradient(
            tf.cast(labels_weights, dtype=tf.float32))

        loss_value = language_model_loss_fn(
            summary_token_logits, token_labels,
            label_weights=labels_weights).loss
      else:
        raise ValueError("Unknown extra loss: {}".format(loss_name))

      loss_to_log[loss_name] = tf.expand_dims(loss_value, 0)
      total_loss += loss_value * (loss_weight / loss_weight_denominator)

    if model.losses:
      total_loss += tf.math.add_n(model.losses)

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = checkpoint_utils.get_assignment_map_from_checkpoint(
          tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                   init_string)

    metric_fn_tensors = dict(
        mlm_loss_per_sample=mlm_loss_output.mlm_loss_per_sample,
        mlm_accuracy_per_sample=mlm_loss_output.mlm_accuracy_per_sample,
        mlm_weight_per_sample=mlm_loss_output.mlm_weight_per_sample,
        mlm_loss_per_entity_sample=mlm_loss_output.mlm_loss_per_entity_sample,
        mlm_accuracy_per_entity_sample=mlm_loss_output
        .mlm_accuracy_per_entity_sample,
        mlm_weight_per_entity_sample=mlm_loss_output
        .mlm_weight_per_entity_sample,
        mlm_loss_per_non_entity_sample=mlm_loss_output
        .mlm_loss_per_non_entity_sample,
        mlm_accuracy_per_non_entity_sample=mlm_loss_output
        .mlm_accuracy_per_non_entity_sample,
        mlm_weight_per_non_entity_sample=mlm_loss_output
        .mlm_weight_per_non_entity_sample,
        block_ids=block_ids)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      train_op = optimization.create_optimizer(
          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu,
          optimizer, poly_power, start_warmup_step, learning_rate_schedule)

      metric_fn_tensors.update({
          "global_step":
              tf.expand_dims(tf.train.get_or_create_global_step(), 0),
          "loss":
              tf.expand_dims(total_loss, 0),
      })
      metric_fn_tensors.update(loss_to_log)

      host_call = (functools.partial(
          record_summary_host_fn,
          metrics_dir=os.path.join(FLAGS.output_dir, "train_metrics"),
          metrics_name=metrics_name or "train_metrics"), metric_fn_tensors)

      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          train_op=train_op,
          scaffold_fn=scaffold_fn,
          host_call=host_call)

    elif mode == tf.estimator.ModeKeys.EVAL:

      eval_metrics = (functools.partial(
          metric_utils.masked_lm_metrics,
          is_train=False,
          metrics_name=metrics_name or "eval_metrics"), metric_fn_tensors)
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=total_loss,
          eval_metrics=eval_metrics,
          scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % mode)

    return output_spec
Example #22
    def set_network(self, Gs, lpips, minibatch_size=1):
        assert minibatch_size == 1
        self._Gs = Gs
        self._minibatch_size = minibatch_size
        if self._Gs is None:
            return
        if self.clone_net:
            self._Gs = self._Gs.clone()

        # Find dlatent stats.
        self._info('Finding W midpoint and stddev using %d samples...' %
                   self.dlatent_avg_samples)
        latent_samples = np.random.RandomState(123).randn(
            self.dlatent_avg_samples, *self._Gs.input_shapes[0][1:])
        dlatent_samples = self._Gs.components.mapping.run(
            latent_samples, None)[:, :1, :]  # [N, 1, 512]
        self._dlatent_avg = np.mean(dlatent_samples, axis=0,
                                    keepdims=True)  # [1, 1, 512]
        self._dlatent_std = (np.sum((dlatent_samples - self._dlatent_avg)**2) /
                             self.dlatent_avg_samples)**0.5
        self._info('std = %g' % self._dlatent_std)

        # Find noise inputs.
        self._info('Setting up noise inputs...')
        self._noise_vars = []
        noise_init_ops = []
        noise_normalize_ops = []
        while True:
            n = 'G_synthesis/noise%d' % len(self._noise_vars)
            if not n in self._Gs.vars:
                break
            v = self._Gs.vars[n]
            self._noise_vars.append(v)
            noise_init_ops.append(
                tf.assign(v, tf.random_normal(tf.shape(v), dtype=tf.float32)))
            noise_mean = tf.reduce_mean(v)
            noise_std = tf.reduce_mean((v - noise_mean)**2)**0.5
            noise_normalize_ops.append(
                tf.assign(v, (v - noise_mean) / noise_std))
            self._info(n, v)
        self._noise_init_op = tf.group(*noise_init_ops)
        self._noise_normalize_op = tf.group(*noise_normalize_ops)

        # Image output graph.
        self._info('Building image output graph...')
        if self.uniform_latents:
            self._dlatents_var = tf.Variable(
                tf.zeros([self._minibatch_size] +
                         list(self._dlatent_avg.shape[1:])),
                name='dlatents_var')
            self._noise_in = tf.placeholder(tf.float32, [], name='noise_in')
            dlatents_noise = tf.random.normal(
                shape=self._dlatents_var.shape) * self._noise_in
            self._dlatents_expr = tf.tile(
                self._dlatents_var + dlatents_noise,
                [1, self._Gs.components.synthesis.input_shape[1], 1])
        else:
            self._dlatents_var = tf.Variable(tf.zeros([
                self._minibatch_size,
                self._Gs.components.synthesis.input_shape[1],
                self._dlatent_avg.shape[2]
            ]),
                                             name='dlatents_var')
            self._noise_in = tf.placeholder(tf.float32, [], name='noise_in')
            dlatents_noise = tf.random.normal(
                shape=self._dlatents_var.shape) * self._noise_in
            self._dlatents_expr = self._dlatents_var + dlatents_noise
        self._images_expr = self._Gs.components.synthesis.get_output_for(
            self._dlatents_expr, randomize_noise=False)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        proc_images_expr = (self._images_expr + 1) * (255 / 2)
        proc_images_expr = downSampleImage(proc_images_expr, 256)

        # Loss graph.
        self._info('Building loss graph...')
        self._target_images_var = tf.Variable(tf.zeros(
            self._images_expr.shape),
                                              name='target_images_var')
        downsampled_target_image = downSampleImage(self._target_images_var,
                                                   256)
        downsampled_target_image = (downsampled_target_image + 1) * (255 / 2)

        #print('_target_images_var:', self._target_images_var.shape)
        #print('_images_expr:', self._images_expr.shape)
        self._lpips = lpips
        if self._lpips is None:
            self._lpips = misc.load_pkl(
                'https://drive.google.com/uc?id=1N2-m9qszOeVC9Tq77WxsLnuWwOedQiD2'
            )  # vgg16_zhang_perceptual.pkl
        self._perceptual_dist = self._lpips.get_output_for(
            proc_images_expr, downsampled_target_image)
        perceptual_dist_mag = tf.reduce_sum(self._perceptual_dist)

        # Euclidean distance
        #self._euclidean_dist = tf.reduce_mean(tf.math.square((self._target_images_var - proc_images_expr) / 255.))
        self._euclidean_dist = tf.reduce_mean(
            tf.math.square(
                (self._target_images_var - self._images_expr) / 2.))**0.5

        self._loss = perceptual_dist_mag + self.euclidean_dist_weight * self._euclidean_dist

        # latent magnitude regularization
        if self.regularize_magnitude_weight > 0:
            self._loss += (tf.reduce_mean(tf.math.square(self._dlatents_var))**
                           0.5) * self.regularize_magnitude_weight

        # Noise regularization graph.
        self._info('Building noise regularization graph...')
        reg_loss = 0.0
        for v in self._noise_vars:
            sz = v.shape[2]
            while True:
                reg_loss += tf.reduce_mean(
                    v * tf.roll(v, shift=1, axis=3))**2 + tf.reduce_mean(
                        v * tf.roll(v, shift=1, axis=2))**2
                if sz <= 8:
                    break  # Small enough already
                v = tf.reshape(v, [1, 1, sz // 2, 2, sz // 2, 2])  # Downscale
                v = tf.reduce_mean(v, axis=[3, 5])
                sz = sz // 2
        self._loss += reg_loss * self.regularize_noise_weight

        # Optimizer.
        self._info('Setting up optimizer...')
        self._lrate_in = tf.placeholder(tf.float32, [], name='lrate_in')
        self._opt = dnnlib.tflib.Optimizer(learning_rate=self._lrate_in)
        self._opt.register_gradients(self._loss,
                                     [self._dlatents_var] + self._noise_vars)
        self._opt_step = self._opt.apply_updates()
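For reference, the noise-regularization loop above penalizes the lag-one spatial autocorrelation of each noise map at several scales. A single-scale sketch, assuming eager TF and an NCHW-shaped noise tensor:

import tensorflow as tf

v = tf.random.normal([1, 1, 8, 8])
reg = (tf.reduce_mean(v * tf.roll(v, shift=1, axis=3))**2
       + tf.reduce_mean(v * tf.roll(v, shift=1, axis=2))**2)
print(float(reg))   # close to zero for spatially uncorrelated noise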