    def infer(self, features, **kwargs):
        with tf.variable_scope("sparse_transformer", reuse=tf.AUTO_REUSE):
            features = self.bottom(features)
        decode_length = self.hparams.max_target_length
        cache = {}
        decoding_stats = {}
        targets_old = features.get("targets")
        start_step = 0
        initial_output = tf.zeros((self.batch_size, decode_length, 1, 1),
                                  dtype=tf.int32)
        initial_logits = tf.zeros(
            (self.batch_size, decode_length, self.vocab_size))

        # call body once to initialize cache with representations of input frames.
        features["targets"] = initial_output
        # Set shape of inputs
        if "inputs" in features:
            features["inputs"].set_shape([
                self.batch_size, self.hparams.max_length, 1,
                self.hparams.hidden_size
            ])
        with tf.variable_scope("sparse_transformer/body", reuse=tf.AUTO_REUSE):
            self.body(features,
                      decode_step=None,
                      cache=cache,
                      decoding_stats=decoding_stats)

        def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
            """Inference step."""
            features_copy = features.copy()
            features_copy["targets"] = recent_output
            cur_sample, cur_logit = self.sample(features_copy,
                                                decode_step=i,
                                                cache=cache,
                                                decoding_stats=decoding_stats)
            pos = i
            samples = recent_output + tf.scatter_nd(
                indices=[[b, pos, 0, 0] for b in range(self.batch_size)],
                updates=cur_sample,
                shape=utils.shape_list(recent_output))
            logits = recent_logits + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_logit,
                shape=utils.shape_list(recent_logits))
            return i + 1, samples, logits, cache, decoding_stats

        def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
            """Exit the loop if it reaches decode_length."""
            not_overflow = i < decode_length
            return not_overflow

        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            while_exit_cond,
            infer_step, [
                start_step, initial_output, initial_logits, cache,
                decoding_stats
            ],
            back_prop=False,
            parallel_iterations=1)

        original_shape = [decode_length]

        blocks_per_dim = [
            s // q for s, q in zip(original_shape, self.hparams.query_shape)
        ]
        final_result_shape = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result,
            [final_result_shape[0], -1,
             np.prod(self.hparams.query_shape), 1])
        final_logits_shape = utils.shape_list(final_logits)
        final_logits = tf.reshape(final_logits, [
            final_logits_shape[0], -1,
            np.prod(self.hparams.query_shape), final_logits_shape[-1]
        ])
        final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
        final_result = utils.put_back_blocks_nd(final_result,
                                                self.hparams.query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
        final_logits = utils.put_back_blocks_nd(final_logits,
                                                self.hparams.query_shape)

        for name, value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % name, value / decode_length)

        # Reassign targets back to the previous value.
        if targets_old is not None:
            features["targets"] = targets_old

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }
    def lstm_decoder_infer(self,
                           inputs,
                           sequence_length,
                           hparams,
                           clss,
                           train,
                           initial_state=None,
                           bottleneck=None):
        # IN PREDICT MODE, RUN tf.while RNN
        max_decode_length = 51
        batch_size = common_layers.shape_list(inputs)[0]
        zero_pad, logits_so_far = self.create_initial_input_for_decode(
            batch_size)

        layers = contrib_rnn.MultiRNNCell([
            self.lstm_cell(hparams, train)
            for _ in range(hparams.num_hidden_layers)
        ])

        if initial_state is None:
            raise ValueError('initial_state must be initialized from the bottleneck!')

        # append one-hot class to bottleneck, which will be given per step
        clss = tf.reshape(clss, [-1])
        if not hparams.use_cls:
            clss = tf.zeros_like(clss)
        if hparams.condition_on_sln:
            sln = tf.reshape(sequence_length, [-1])
            bottleneck = tf.concat(
                (bottleneck, tf.one_hot(clss, hparams.num_categories),
                 tf.one_hot(sln, max_decode_length)), -1)
        else:
            bottleneck = tf.concat(
                (bottleneck, tf.one_hot(clss, hparams.num_categories)), -1)

        def infer_step(logits_so_far, current_hidden):
            """Inference step of LSTM while loop."""
            # unflatten hidden:
            current_hidden = tuple(
                tf.nn.rnn_cell.LSTMStateTuple(c=s[0], h=s[1])
                for s in current_hidden)

            # put logits_so_far through top
            tm = self._problem_hparams.modality['targets']
            # need to reuse top params
            reset_scope = tf.variable_scope(tf.VariableScope(
                tf.AUTO_REUSE, ''),
                                            reuse=tf.AUTO_REUSE,
                                            auxiliary_name_scope=False)
            top_scope = tf.variable_scope('svg_decoder/{}_modality'.format(tm),
                                          reuse=tf.AUTO_REUSE)
            with reset_scope, top_scope:
                samples_so_far = self.hparams.top['targets'](
                    logits_so_far, None, self.hparams,
                    self.problem_hparams.vocab_size)
            # append a zero pad to the samples. this effectively shifts the samples
            # right, but, unlike shift_right, it keeps the last element, so an empty
            # samples_so_far is no longer empty after padding
            samples_so_far = tf.concat([zero_pad, samples_so_far], axis=1)
            shifted_targets = common_layers.flatten4d3d(samples_so_far)
            # now take the very last one here, will be the actual input to the rnn
            shifted_targets = shifted_targets[:, -1:, :]

            # tile and append the bottleneck to inputs
            sln_offset = 0
            if hparams.condition_on_sln:
                sln_offset = 51
            pre_tile_y = tf.reshape(bottleneck, [
                common_layers.shape_list(bottleneck)[0], 1,
                hparams.bottleneck_bits + hparams.num_categories + sln_offset
            ])
            overlay_x = tf.tile(
                pre_tile_y,
                [1, common_layers.shape_list(shifted_targets)[1], 1])
            inputs = tf.concat([shifted_targets, overlay_x], -1)

            seq_len_batch = tf.ones([common_layers.shape_list(inputs)[0]])

            # RUN PRE-LSTM LAYER
            with tf.variable_scope('pre_decoder', reuse=tf.AUTO_REUSE):
                inputs = tf.layers.dense(inputs,
                                         hparams.hidden_size,
                                         name='bottom')
                inputs = tf.nn.tanh(inputs)

            # RUN LSTM
            with tf.variable_scope('lstm_decoder', reuse=tf.AUTO_REUSE):
                next_step, next_state = tf.nn.dynamic_rnn(
                    layers,
                    inputs,
                    seq_len_batch,
                    initial_state=current_hidden,
                    dtype=tf.float32,
                    time_major=False)

            next_step = tf.expand_dims(next_step, [1])

            logits_so_far = tf.concat([logits_so_far, next_step], 1)
            # flatten state
            next_state = tuple((s.c, s.h) for s in next_state)

            return logits_so_far, next_state

        def while_exit_cond(logits_so_far, unused_current_hidden):
            length = common_layers.shape_list(logits_so_far)[1]
            return length < max_decode_length

        # passing state must be flattened:
        initial_state = tuple([(s.c, s.h) for s in initial_state])

        # actually run tf.while:
        logits, final_state = tf.while_loop(
            while_exit_cond,
            infer_step, [logits_so_far, initial_state],
            shape_invariants=[
                tf.TensorShape([None, None, 1, hparams.hidden_size]),
                tuple([(s[0].get_shape(), s[1].get_shape())
                       for s in initial_state]),
            ],
            back_prop=False,
            parallel_iterations=1)

        # logits should be returned in 3d mode:
        logits = common_layers.flatten4d3d(logits)

        return logits, final_state
Example #3
def sorted_non_max_suppression_padded(scores, boxes, max_output_size,
                                      iou_threshold):
    """A wrapper that handles non-maximum suppression.

  Assumption:
    * The boxes are sorted by scores unless the box is a dot (all coordinates
      are zero).
    * Boxes with higher scores can be used to suppress boxes with lower scores.

  The overall design of the algorithm is to handle boxes tile-by-tile:

  boxes = boxes.pad_to_multiple_of(tile_size)
  num_tiles = len(boxes) // tile_size
  output_boxes = []
  for i in range(num_tiles):
    box_tile = boxes[i*tile_size : (i+1)*tile_size]
    for j in range(i):
      suppressing_tile = boxes[j*tile_size : (j+1)*tile_size]
      iou = bbox_overlap(box_tile, suppressing_tile)
      # if the box is suppressed in iou, clear it to a dot
      box_tile *= _update_boxes(iou)
    # Iteratively handle the diagonal tile.
    iou = bbox_overlap(box_tile, box_tile)
    iou_changed = True
    while iou_changed:
      # boxes that are not suppressed by anything else
      suppressing_boxes = _get_suppressing_boxes(iou)
      # boxes that are suppressed by suppressing_boxes
      suppressed_boxes = _get_suppressed_boxes(iou, suppressing_boxes)
      # clear iou to 0 for boxes that are suppressed, as they cannot be used
      # to suppress other boxes any more
      new_iou = _clear_iou(iou, suppressed_boxes)
      iou_changed = (new_iou != iou)
      iou = new_iou
    # remaining boxes that can still suppress others are the selected boxes.
    output_boxes.append(_get_suppressing_boxes(iou))
    if len(output_boxes) >= max_output_size:
      break

  Args:
    scores: a tensor with a shape of [batch_size, anchors].
    boxes: a tensor with a shape of [batch_size, anchors, 4].
    max_output_size: a scalar integer `Tensor` representing the maximum number
      of boxes to be selected by non max suppression.
    iou_threshold: a float representing the threshold for deciding whether boxes
      overlap too much with respect to IOU.

  Returns:
    nms_scores: a tensor with a shape of [batch_size, max_output_size]. It has the
      same dtype as the input scores.
    nms_proposals: a tensor with a shape of [batch_size, max_output_size, 4]. It has
      the same dtype as the input boxes.
  """
    batch_size = tf.shape(boxes)[0]
    num_boxes = tf.shape(boxes)[1]
    pad = tf.cast(tf.ceil(tf.cast(num_boxes, tf.float32) / NMS_TILE_SIZE),
                  tf.int32) * NMS_TILE_SIZE - num_boxes
    boxes = tf.pad(tf.cast(boxes, tf.float32), [[0, 0], [0, pad], [0, 0]])
    scores = tf.pad(tf.cast(scores, tf.float32), [[0, 0], [0, pad]])
    num_boxes += pad

    def _loop_cond(unused_boxes, unused_threshold, output_size, idx):
        return tf.logical_and(
            tf.reduce_min(output_size) < max_output_size,
            idx < num_boxes // NMS_TILE_SIZE)

    selected_boxes, _, output_size, _ = tf.while_loop(
        _loop_cond, _suppression_loop_body, [
            boxes, iou_threshold,
            tf.zeros([batch_size], tf.int32),
            tf.constant(0)
        ])
    idx = num_boxes - tf.cast(
        tf.nn.top_k(
            tf.cast(tf.reduce_any(selected_boxes > 0, [2]), tf.int32) *
            tf.expand_dims(tf.range(num_boxes, 0, -1), 0), max_output_size)[0],
        tf.int32)
    idx = tf.minimum(idx, num_boxes - 1)
    idx = tf.reshape(
        idx + tf.reshape(tf.range(batch_size) * num_boxes, [-1, 1]), [-1])
    boxes = tf.reshape(tf.gather(tf.reshape(boxes, [-1, 4]), idx),
                       [batch_size, max_output_size, 4])
    boxes = boxes * tf.cast(
        tf.reshape(tf.range(max_output_size), [1, -1, 1]) < tf.reshape(
            output_size, [-1, 1, 1]), boxes.dtype)
    scores = tf.reshape(tf.gather(tf.reshape(scores, [-1, 1]), idx),
                        [batch_size, max_output_size])
    scores = scores * tf.cast(
        tf.reshape(tf.range(max_output_size), [1, -1]) < tf.reshape(
            output_size, [-1, 1]), scores.dtype)
    return scores, boxes
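# A minimal usage sketch for the wrapper above (illustrative only: it assumes that
# NMS_TILE_SIZE and _suppression_loop_body are defined elsewhere in this module, and
# _nms_usage_sketch is a hypothetical helper added just for this example). The boxes
# are random placeholders, so only the shapes are meaningful.
def _nms_usage_sketch():
    # scores sorted in descending order, as the docstring assumes: [batch_size, anchors]
    scores = tf.tile(tf.linspace(1.0, 0.0, 1000)[tf.newaxis, :], [2, 1])
    boxes = tf.random.uniform([2, 1000, 4])  # [batch_size, anchors, 4], same ordering
    nms_scores, nms_boxes = sorted_non_max_suppression_padded(
        scores, boxes, max_output_size=100, iou_threshold=0.5)
    return nms_scores, nms_boxes  # shapes [2, 100] and [2, 100, 4]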
Example #4
        def compute_gradients(self,
                              loss,
                              var_list,
                              gate_gradients=GATE_OP,
                              aggregation_method=None,
                              colocate_gradients_with_ops=False,
                              grad_loss=None,
                              gradient_tape=None):
            if callable(loss):
                # TF is running in Eager mode, check we received a vanilla tape.
                if not gradient_tape:
                    raise ValueError(
                        'When in Eager mode, a tape needs to be passed.')

                vector_loss = loss()
                if self._num_microbatches is None:
                    self._num_microbatches = tf.shape(input=vector_loss)[0]
                sample_state = self._dp_sum_query.initial_sample_state(
                    var_list)
                microbatches_losses = tf.reshape(vector_loss,
                                                 [self._num_microbatches, -1])
                sample_params = (self._dp_sum_query.derive_sample_params(
                    self._global_state))

                def process_microbatch(i, sample_state):
                    """Process one microbatch (record) with privacy helper."""
                    microbatch_loss = tf.reduce_mean(
                        input_tensor=tf.gather(microbatches_losses, [i]))
                    grads = gradient_tape.gradient(microbatch_loss, var_list)
                    sample_state = self._dp_sum_query.accumulate_record(
                        sample_params, sample_state, grads)
                    return sample_state

                for idx in range(self._num_microbatches):
                    sample_state = process_microbatch(idx, sample_state)

                grad_sums, self._global_state = (
                    self._dp_sum_query.get_noised_result(
                        sample_state, self._global_state))

                def normalize(v):
                    return v / tf.cast(self._num_microbatches, tf.float32)

                final_grads = tf.nest.map_structure(normalize, grad_sums)

                grads_and_vars = list(zip(final_grads, var_list))
                return grads_and_vars

            else:
                # TF is running in graph mode, check we did not receive a gradient tape.
                if gradient_tape:
                    raise ValueError(
                        'When in graph mode, a tape should not be passed.')

                # Note: it would be closer to the correct i.i.d. sampling of records if
                # we sampled each microbatch from the appropriate binomial distribution,
                # although that still wouldn't be quite correct because it would be
                # sampling from the dataset without replacement.
                if self._num_microbatches is None:
                    self._num_microbatches = tf.shape(input=loss)[0]

                microbatches_losses = tf.reshape(loss,
                                                 [self._num_microbatches, -1])
                sample_params = (self._dp_sum_query.derive_sample_params(
                    self._global_state))

                def process_microbatch(i, sample_state):
                    """Process one microbatch (record) with privacy helper."""
                    grads, _ = zip(*super(cls, self).compute_gradients(
                        tf.reduce_mean(
                            input_tensor=tf.gather(microbatches_losses, [i])),
                        var_list, gate_gradients, aggregation_method,
                        colocate_gradients_with_ops, grad_loss))
                    grads_list = [
                        g if g is not None else tf.zeros_like(v)
                        for (g, v) in zip(list(grads), var_list)
                    ]
                    sample_state = self._dp_sum_query.accumulate_record(
                        sample_params, sample_state, grads_list)
                    return sample_state

                if var_list is None:
                    var_list = (tf.trainable_variables() + tf.get_collection(
                        tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

                sample_state = self._dp_sum_query.initial_sample_state(
                    var_list)

                if self._unroll_microbatches:
                    for idx in range(self._num_microbatches):
                        sample_state = process_microbatch(idx, sample_state)
                else:
                    # Use of while_loop here requires that sample_state be a nested
                    # structure of tensors. In general, we would prefer to allow it to be
                    # an arbitrary opaque type.
                    cond_fn = lambda i, _: tf.less(i, self._num_microbatches)
                    body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)]  # pylint: disable=line-too-long
                    idx = tf.constant(0)
                    _, sample_state = tf.while_loop(
                        cond=cond_fn,
                        body=body_fn,
                        loop_vars=[idx, sample_state])

                grad_sums, self._global_state = (
                    self._dp_sum_query.get_noised_result(
                        sample_state, self._global_state))

                def normalize(v):
                    return tf.truediv(
                        v, tf.cast(self._num_microbatches, tf.float32))

                final_grads = tf.nest.map_structure(normalize, grad_sums)

                return list(zip(final_grads, var_list))
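# The comment above notes that tf.while_loop requires sample_state to be a nested
# structure of tensors. A small self-contained illustration of that constraint
# (hypothetical shapes, unrelated to the DP query machinery):
def _nested_loop_vars_sketch():
    state = (tf.zeros([3]), tf.zeros([2, 2]))      # a nested structure of tensors
    cond_fn = lambda i, _: tf.less(i, 4)
    body_fn = lambda i, s: [tf.add(i, 1), (s[0] + 1.0, s[1] + 1.0)]
    _, final_state = tf.while_loop(cond=cond_fn,
                                   body=body_fn,
                                   loop_vars=[tf.constant(0), state])
    return final_state                             # each entry equals 4.0 after the loop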
Example #5
      def task_metalearn(inp, reuse=True):
        """Run meta learning."""
        TRAIN = 'train' in prefix  # pylint: disable=invalid-name
        # Perform gradient descent for one task in the meta-batch.
        inputa, inputb, labela, labelb = inp
        task_outputbs, task_lossesb = [], []
        task_msesb = []

        # support_pred and loss, (n_data_per_task, out_dim)
        task_outputa = self.forward(
            inputa, weights, reuse=reuse)  # reuse is False only on the first iteration
        # labela is (n_data_per_task, out_dim)
        task_lossa = self.loss_func(task_outputa, labela)

        # INNER LOOP (no change with ib)
        grads = tf.gradients(task_lossa, list(weights.values()))
        if FLAGS.stop_grad:
          grads = [tf.stop_gradient(grad) for grad in grads]
        gradients = dict(zip(weights.keys(), grads))
        ## theta_pi = theta - alpha * grads
        fast_weights = dict(
            zip(weights.keys(), [
                weights[key] - self.update_lr * gradients[key]
                for key in weights.keys()
            ]))

        # use theta_pi to forward meta-test
        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        # meta-test loss
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))

        def while_body(fast_weights_values):
          """Update params."""
          loss = self.loss_func(
              self.forward(
                  inputa,
                  dict(zip(fast_weights.keys(), fast_weights_values)),
                  reuse=True), labela)
          grads = tf.gradients(loss, fast_weights_values)
          fast_weights_values = [
              v - self.update_lr * g for v, g in zip(fast_weights_values, grads)
          ]
          return fast_weights_values

        fast_weights_values = tf.while_loop(
            lambda _: True,
            while_body,
            loop_vars=[list(fast_weights.values())],
            maximum_iterations=num_updates - 1,
            back_prop=TRAIN)
        fast_weights = dict(zip(fast_weights.keys(), fast_weights_values))

        output = self.forward(inputb, fast_weights, reuse=True)
        task_outputbs.append(output)
        task_msesb.append(self.loss_func(output, labelb))
        task_lossesb.append(self.loss_func(output, labelb))
        task_output = [
            task_outputa, task_outputbs, task_lossa, task_lossesb, task_msesb
        ]

        return task_output
Example #6
def sample_sequence(
    hparams,
    length,
    start_token=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=0.0,
):
    if start_token is None:
        assert (context
                is not None), 'Specify exactly one of start_token and context!'
    else:
        assert (context is
                None), 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = gpt2_model.model(hparams=hparams,
                                     X=tokens,
                                     past=past,
                                     reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            gpt2_model.past_shape(hparams=hparams, batch_size=batch_size))
        return {'logits': logits, 'presents': presents}

    with tf.name_scope('sample_sequence'):
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(logits,
                                            num_samples=1,
                                            dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[context_output['presents'], context[:, -1], context],
            shape_invariants=[
                tf.TensorShape(
                    gpt2_model.past_shape(hparams=hparams,
                                          batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
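# A hedged usage sketch for sample_sequence. It assumes gpt2_model exposes
# default_hparams() (as the original GPT-2 code does) and that trained weights are
# restored before sampling; with only the initializer below the sampled tokens are
# meaningless and the snippet merely exercises the graph. The context ids and the
# helper name are hypothetical.
def _sample_sequence_sketch():
    hparams = gpt2_model.default_hparams()         # assumption: helper exists on the module
    context = tf.constant([[464, 3290, 318],
                           [464, 3290, 318]], dtype=tf.int32)   # [batch_size, context_len]
    tokens = sample_sequence(hparams,
                             length=40,
                             batch_size=2,
                             context=context,
                             top_k=40)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())   # or restore a checkpoint instead
        return sess.run(tokens)                       # [2, 3 + 40] token ids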
Example #7
    def _build_adapted_parameters(
        self,
        inputs,
        labels,
        initial_parameters,
        num_steps,
        back_prop=False,
        parallel_iterations=1,
        shuffle=True,
    ):
        """Builds adapted model parameters dynamically using tf.while_loop.

        Parameters
        ----------
        inputs : Tensor <float32> [None, ...]
            Inputs of the samples used for building adapted parameters.

        labels : Tensor <float32> [None, num_classes]
            Labels of the samples used for building adapted parameters.

        initial_parameters : dict of Tensors
            A dictionary with initial parameters of the model.

        num_steps : int or Tensor <int32> []
            Number of gradient steps used for adaptation.

        back_prop : bool, optional (default: False)
            Indicates whether backprop is allowed through the adapted parameters.

        parallel_iterations : int, optional (default=1)
            Parallel iterations parameter for the tf.while_loop.

        shuffle : bool, optional (default=True)
            Whether to shuffle the samples before batching.

        Returns
        -------
        adapted_parameters : dict of Tensors
            A dictionary with adapted parameters of the model.
        """
        # If batch size not specified, use all inputs.
        batch_size = self.batch_size or tf.shape(inputs)[0]
        # Build batched indices.
        # <int32> [batch_size * num_steps].
        indices = tf.math.mod(tf.range(batch_size * num_steps, dtype=tf.int32),
                              tf.shape(inputs)[0])
        if shuffle:
            indices = tf.random.shuffle(indices)
        # <int32> [num_steps, batch_size].
        batched_indices = tf.reshape(indices, shape=(num_steps, batch_size))

        def cond_fn(step, _unused_params):
            return tf.less(step, num_steps)

        def body_fn(step, parameters):
            x = tf.gather(inputs, batched_indices[step], axis=0)
            y = tf.gather(labels, batched_indices[step], axis=0)
            # Build a model with new parameters.
            with utils.custom_make_variable(parameters, self.model.name):
                self.inner_adapted_models.append(self.model_builder())
            loss = self.inner_adapted_models[-1].loss(x, y)
            # Build new parameters.
            new_parameters = utils.build_new_parameters(
                loss,
                parameters,
                optimizer=self.inner_optimizer,
                first_order=self.first_order,
            )
            return [tf.add(step, 1), new_parameters]

        _, adapted_parameters = tf.while_loop(
            cond=cond_fn,
            body=body_fn,
            loop_vars=[tf.constant(0), initial_parameters],
            parallel_iterations=parallel_iterations,
            back_prop=back_prop,
            name="adapt",
        )

        return adapted_parameters
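# A self-contained sketch of the index-batching step used above, with hypothetical
# sizes: indices cycle over the available samples via tf.math.mod, are optionally
# shuffled, and are reshaped to one row of indices per adaptation step.
def _batched_indices_sketch(num_samples=5, batch_size=4, num_steps=3, shuffle=True):
    indices = tf.math.mod(
        tf.range(batch_size * num_steps, dtype=tf.int32), num_samples)
    if shuffle:
        indices = tf.random.shuffle(indices)
    # <int32> [num_steps, batch_size]; row `step` holds the sample indices for that step.
    return tf.reshape(indices, shape=(num_steps, batch_size))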
Example #8
def multiply2n_ragged(tensor1, tensor2):
    # This function multiplies two ragged tensors of rank 2. The outermost ranks of the two tensors must be equal.
    # set up variables and constants
    outerloop_counter = tf.constant(0, dtype=tf.int32)
    carry_on = tf.constant(0, dtype=tf.int32)
    taValues = tf.TensorArray(tf.float32,
                              size=0,
                              dynamic_size=True,
                              clear_after_read=False,
                              infer_shape=False)
    taL2Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    taL1Splits = tf.TensorArray(tf.int32,
                                size=0,
                                dynamic_size=True,
                                clear_after_read=False,
                                infer_shape=False)
    taL1Splits = taL1Splits.write(
        0, [0])  ## required initialization for the L1 split only
    innerloop_processing_graphed = tf.function(innerloop_processing)
    generateL1Tensor_writeback_graphed = tf.function(
        generateL1Tensor_writeback)

    def outerloop_cond(counter, input1, input2, taValues, taL2Splits,
                       taL1Splits, carry_on):
        value = tf.shape(input1[2])[0] - 1
        return counter < value  ## this is the length of the outermost dimension; stop once the counter reaches it

    def outloop_body(counter, input1, input2, taValues, taL2Splits, taL1Splits,
                     carry_on):
        l1_comp_begin = input1[2][
            counter]  ## the begin position of the current row in the outer split (i.e. the ith value in the outer row split tensor)
        l1_comp_end = input1[2][
            counter +
            1]  ## the end position of the current row in the outer split (i.e. the (i+1)th value in the outer row split tensor)
        l1_comp2_begin = input2[2][
            counter]  ## we do the same for the second tensor
        l1_comp2_end = input2[2][
            counter + 1]  ## we do the same for the second tensor
        comp = innerloop_processing_graphed(
            l1_comp_begin, l1_comp_end, input1
        )  ## now retrieve the data to be processed for the selected rows from vector1
        comp2 = innerloop_processing_graphed(
            l1_comp2_begin, l1_comp2_end, input2)  ## do the same for vector 2

        #comp2 = tf.transpose(comp2) ### desired operation
        multiply = tf.matmul(comp, comp2)  #### This is the desired operation

        myshape = tf.shape(
            multiply
        )  ## calculate the shape of the result in order to prepare to write the result in a ragged tensor format.
        offset = tf.cond(
            taValues.size() > 0, lambda: tf.shape(taValues.concat())[0],
            lambda: 0
        )  ### this is a hack, TensorArray.concat returns an error if the array is empty. Thus we check before calling this.
        l2v = generateL1Tensor_writeback_graphed(
            offset, myshape[1], myshape[0]
        )  # generate the inner row split of the result for the current element
        taL2Splits = taL2Splits.write(
            counter, l2v)  # write back the inner row split to a TensorArray
        taValues = taValues.write(
            counter, tf.reshape(multiply, [-1])
        )  # write back the actual ragged tensor elements into another TensorArray
        carry_on = carry_on + myshape[
            0]  ## required to calculate the outer row split
        taL1Splits = taL1Splits.write(
            counter + 1, [carry_on])  ## This is the outermost row split.
        with tf.control_dependencies(
            [comp, comp2, myshape, l2v, carry_on, multiply]):
            counter = counter + 1
        return counter, input1, input2, taValues, taL2Splits, taL1Splits, carry_on

    with tf.name_scope("RaggedMultiply"):
        outerloop_finalcounter, _, _, ta1, ta2, ta3, _ = tf.while_loop(
            outerloop_cond,
            outloop_body, [
                outerloop_counter, tensor1, tensor2, taValues, taL2Splits,
                taL1Splits, carry_on
            ],
            back_prop=True)
    unique_ta2, _ = tf.unique(
        ta2.concat()
    )  # required since some boundary values are duplicated in the concatenated row splits
    t1 = ta1.concat()
    t3 = ta3.concat()
    final_values = t1, unique_ta2, t3
    return final_values
    def _create_cross_entropy_action_tensors(self,
                                             num_samples=200,
                                             top_k_portion=0.5):
        """Create tensorflow operations for cross_entropy max_actions."""
        top_k_num = int(top_k_portion * num_samples)

        self._dynamic_batch_size = tf.placeholder(dtype=tf.int32,
                                                  name="dynamic_batch_size")
        self._action_init_tensor = tf.placeholder(dtype=tf.float32,
                                                  name="action_init_tensor",
                                                  shape=(None,
                                                         self.action_dim))
        self._tolerance_tensor = tf.placeholder(dtype=tf.float32,
                                                name="tolerance_tensor",
                                                shape=())

        sample_mean_init = self._action_init_tensor
        sample_covariance_diag_init = tf.ones_like(self._action_init_tensor)
        top_k_value_init = tf.constant(
            [np.inf]) * tf.ones(shape=(self._dynamic_batch_size, 1))
        top_k_action_samples_init = tf.tile(
            tf.expand_dims(tf.zeros_like(self._action_init_tensor), axis=1),
            [1, top_k_num, 1])
        random_sampler = tfp.distributions.MultivariateNormalDiag(
            loc=np.zeros(self.action_dim), scale_diag=np.ones(self.action_dim))

        def cond_cross_entropy(itr, cond_terminate, sample_mean,
                               sample_covariance_diag, top_k_value,
                               top_k_action_samples):
            del sample_mean, sample_covariance_diag, top_k_value, top_k_action_samples
            cond_1 = tf.math.less(itr, self.action_maximization_iterations)
            return tf.math.logical_and(cond_1, tf.logical_not(cond_terminate))

        def body_cross_entropy(itr, cond_terminate, sample_mean,
                               sample_covariance_diag, top_k_value,
                               top_k_action_samples):
            """Function for cross entropy search of actions."""
            del top_k_action_samples
            top_k_value_prev = top_k_value
            batch_sample_mean = tf.reshape(
                tf.tile(sample_mean, [1, num_samples]),
                [self._dynamic_batch_size * num_samples, self.action_dim])
            batch_sample_covariance_diag = tf.reshape(
                tf.tile(sample_covariance_diag, [1, num_samples]),
                [self._dynamic_batch_size * num_samples, self.action_dim])

            action_samples = self._action_projection(
                batch_sample_mean +
                batch_sample_covariance_diag * tf.cast(random_sampler.sample(
                    sample_shape=[self._dynamic_batch_size * num_samples]),
                                                       dtype=tf.float32))

            state_samples = tf.reshape(
                tf.tile(self._state_tensor, [1, num_samples]),
                [self._dynamic_batch_size * num_samples, self.state_dim])
            action_samples = tf.reshape(
                action_samples,
                [self._dynamic_batch_size * num_samples, self.action_dim])
            values = tf.reshape(
                self._build_q_function_net(state_samples, action_samples),
                [self._dynamic_batch_size, num_samples])

            # everything is in batch mode
            top_k_index = tf.argsort(values, axis=1,
                                     direction="DESCENDING")[:, 0:top_k_num]
            top_k_index_1d = tf.reshape(
                top_k_index, [self._dynamic_batch_size * top_k_num, 1])
            counter_tensor_1d = tf.reshape(
                tf.tile(
                    tf.reshape(tf.range(self._dynamic_batch_size),
                               [self._dynamic_batch_size, 1]), [1, top_k_num]),
                [self._dynamic_batch_size * top_k_num, 1])

            top_k_index_2d = tf.concat([counter_tensor_1d, top_k_index_1d],
                                       axis=1)

            action_samples = tf.reshape(
                action_samples,
                [self._dynamic_batch_size, num_samples, self.action_dim])
            top_k_action_samples = tf.gather_nd(action_samples, top_k_index_2d)
            top_k_action_samples = tf.reshape(
                top_k_action_samples,
                [self._dynamic_batch_size, top_k_num, self.action_dim])

            top_k_values = tf.gather_nd(values, top_k_index_2d)
            top_k_values = tf.reshape(top_k_values,
                                      [self._dynamic_batch_size, top_k_num])

            # it's a batch_size x 1 tensor
            top_k_value = tf.reshape(tf.reduce_mean(top_k_values, axis=1),
                                     [self._dynamic_batch_size, 1])

            sample_mean = tf.reduce_mean(top_k_action_samples, axis=1)
            sample_covariance_diag = tf.math.reduce_variance(
                top_k_action_samples, axis=1)

            itr = itr + 1
            cond_terminate = tf.less_equal(
                tf.reduce_mean(tf.math.abs(top_k_value - top_k_value_prev)),
                self._tolerance_tensor)
            return itr, cond_terminate, sample_mean, sample_covariance_diag, \
                top_k_value, top_k_action_samples

        self.cost_optimizer = tf.while_loop(
            cond_cross_entropy, body_cross_entropy, [
                tf.constant(0),
                tf.constant(False), sample_mean_init,
                sample_covariance_diag_init, top_k_value_init,
                top_k_action_samples_init
            ])
Example #10
def run(params, y_data_test, siz_x_data, y_normscale, load_dir):

    multi_modal = True

    # USEFUL SIZES
    xsh1 = siz_x_data
    if params['by_channel'] == True:
        ysh0 = np.shape(y_data_test)[0]
        ysh1 = np.shape(y_data_test)[1]
    else:
        ysh0 = np.shape(y_data_test)[1]
        ysh1 = np.shape(y_data_test)[2]
    z_dimension = params['z_dimension']
    n_weights_r1 = params['n_weights_r1']
    n_weights_r2 = params['n_weights_r2']
    n_weights_q = params['n_weights_q']
    n_modes = params['n_modes']
    n_hlayers_r1 = len(params['n_weights_r1'])
    n_hlayers_r2 = len(params['n_weights_r2'])
    n_hlayers_q = len(params['n_weights_q'])
    n_conv_r1 = len(params['n_filters_r1'])
    n_conv_r2 = len(params['n_filters_r2'])
    n_conv_q = len(params['n_filters_q'])
    n_filters_r1 = params['n_filters_r1']
    n_filters_r2 = params['n_filters_r2']
    n_filters_q = params['n_filters_q']
    filter_size_r1 = params['filter_size_r1']
    filter_size_r2 = params['filter_size_r2']
    filter_size_q = params['filter_size_q']
    n_convsteps = params['n_convsteps']
    batch_norm = params['batch_norm']
    red = params['reduce']
    if n_convsteps != None:
        ysh_conv_r1 = int(ysh1*n_filters_r1/2**n_convsteps) if red==True else int(ysh1/2**n_convsteps)
        ysh_conv_r2 = int(ysh1*n_filters_r2/2**n_convsteps) if red==True else int(ysh1/2**n_convsteps)
        ysh_conv_q = int(ysh1*n_filters_q/2**n_convsteps) if red==True else int(ysh1/2**n_convsteps)
    else:
        ysh_conv_r1 = int(ysh1)
        ysh_conv_r2 = int(ysh1)
        ysh_conv_q = int(ysh1)
    drate = params['drate']
    maxpool_r1 = params['maxpool_r1']
    maxpool_r2 = params['maxpool_r2']
    maxpool_q = params['maxpool_q']
    conv_strides_r1 = params['conv_strides_r1']
    conv_strides_r2 = params['conv_strides_r2']
    conv_strides_q = params['conv_strides_q']
    pool_strides_r1 = params['pool_strides_r1']
    pool_strides_r2 = params['pool_strides_r2']
    pool_strides_q = params['pool_strides_q']
    if params['reduce'] == True or n_filters_r1 != None:
        if params['by_channel'] == True:
            num_det = np.shape(y_data_test)[2]
        else:
            num_det = ysh0
    else:
        num_det = None
    # identify the indices of different sets of physical parameters
    vonmise_mask, vonmise_idx_mask, vonmise_len = get_param_index(params['inf_pars'],params['vonmise_pars'])
    gauss_mask, gauss_idx_mask, gauss_len = get_param_index(params['inf_pars'],params['gauss_pars'])
    sky_mask, sky_idx_mask, sky_len = get_param_index(params['inf_pars'],params['sky_pars'])
    ra_mask, ra_idx_mask, ra_len = get_param_index(params['inf_pars'],['ra'])
    dec_mask, dec_idx_mask, dec_len = get_param_index(params['inf_pars'],['dec'])
    m1_mask, m1_idx_mask, m1_len = get_param_index(params['inf_pars'],['mass_1'])
    m2_mask, m2_idx_mask, m2_len = get_param_index(params['inf_pars'],['mass_2'])
    idx_mask = np.argsort(gauss_idx_mask + vonmise_idx_mask + m1_idx_mask + m2_idx_mask + sky_idx_mask) # + dist_idx_mask)
    masses_len = m1_len + m2_len

   
    graph = tf.Graph()
    session = tf.Session(graph=graph)
    with graph.as_default():
        tf.set_random_seed(np.random.randint(0,10))
        SMALL_CONSTANT = 1e-12

        # PLACEHOLDERS
        bs_ph = tf.placeholder(dtype=tf.int64, name="bs_ph")                       # batch size placeholder
        y_ph = tf.placeholder(dtype=tf.float32, shape=[None, params['ndata'], num_det], name="y_ph")

        # LOAD VICI NEURAL NETWORKS
        r2_xzy = VICI_decoder.VariationalAutoencoder('VICI_decoder', vonmise_mask, gauss_mask, m1_mask, m2_mask, sky_mask, n_input1=z_dimension, 
                                                     n_input2=params['ndata'], n_output=xsh1, n_channels=num_det, n_weights=n_weights_r2, 
                                                     drate=drate, n_filters=n_filters_r2, 
                                                     filter_size=filter_size_r2, maxpool=maxpool_r2)
        r1_zy = VICI_encoder.VariationalAutoencoder('VICI_encoder', n_input=params['ndata'], n_output=z_dimension, n_channels=num_det, n_weights=n_weights_r1,   # generates params for r1(z|y)
                                                    n_modes=n_modes, drate=drate, n_filters=n_filters_r1, 
                                                    filter_size=filter_size_r1, maxpool=maxpool_r1)
        q_zxy = VICI_VAE_encoder.VariationalAutoencoder('VICI_VAE_encoder', n_input1=xsh1, n_input2=params['ndata'], n_output=z_dimension, 
                                                     n_channels=num_det, n_weights=n_weights_q, drate=drate, 
                                                     n_filters=n_filters_q, filter_size=filter_size_q, maxpool=maxpool_q)

        # reduce the y data size
        y_conv = y_ph

        # GET r1(z|y)
        r1_loc, r1_scale, r1_weight = r1_zy._calc_z_mean_and_sigma(y_conv)
        temp_var_r1 = SMALL_CONSTANT + tf.exp(r1_scale)


        # define the r1(z|y) mixture model
        bimix_gauss = tfd.MixtureSameFamily(
                          mixture_distribution=tfd.Categorical(logits=r1_weight),
                          components_distribution=tfd.MultivariateNormalDiag(
                          loc=r1_loc,
                          scale_diag=tf.sqrt(temp_var_r1)))


        # DRAW FROM r1(z|y)
        r1_zy_samp = bimix_gauss.sample()


        # GET r2(x|z,y) from r1(z|y) samples
        reconstruction_xzy = r2_xzy.calc_reconstruction(r1_zy_samp,y_ph)

        # ugly but needed for now
        # extract the means and variances of the physical parameter distributions
        r2_xzy_mean_gauss = reconstruction_xzy[0]
        r2_xzy_log_sig_sq_gauss = reconstruction_xzy[1]
        r2_xzy_mean_vonmise = reconstruction_xzy[2]
        r2_xzy_log_sig_sq_vonmise = reconstruction_xzy[3]
        r2_xzy_mean_m1 = reconstruction_xzy[4]
        r2_xzy_log_sig_sq_m1 = reconstruction_xzy[5]
        r2_xzy_mean_m2 = reconstruction_xzy[6]
        r2_xzy_log_sig_sq_m2 = reconstruction_xzy[7]
        r2_xzy_mean_sky = reconstruction_xzy[8]
        r2_xzy_log_sig_sq_sky = reconstruction_xzy[9]

        # draw from r2(x|z,y) - the masses
        temp_var_r2_m1 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m1)     # the m1 variance
        temp_var_r2_m2 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m2)     # the m2 variance
        joint = tfd.JointDistributionSequential([
                       tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m1,tf.sqrt(temp_var_r2_m1),0,1,validate_args=True,allow_nan_stats=True),reinterpreted_batch_ndims=0),  # m1
            lambda b0: tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m2,tf.sqrt(temp_var_r2_m2),0,b0,validate_args=True,allow_nan_stats=True),reinterpreted_batch_ndims=0)],    # m2
            validate_args=True)
        r2_xzy_samp_masses = tf.transpose(tf.reshape(joint.sample(),[2,-1]))  # sample from the m1.m2 space

        # draw from r2(x|z,y) - the truncated gaussian 
        temp_var_r2_gauss = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_gauss)
        @tf.function    # make this a tensorflow function
        def truncnorm(idx,output):    # loop body: draws truncated-Gaussian samples for one parameter and increments the counter
            loc = tf.slice(r2_xzy_mean_gauss,[0,idx],[-1,1])            # take each specific parameter mean using slice
            std = tf.sqrt(tf.slice(temp_var_r2_gauss,[0,idx],[-1,1]))   # take each specific parameter std using slice
            tn = tfd.TruncatedNormal(loc,std,0.0,1.0)                   # define the truncated Gaussian distribution
            return [idx+1, tf.concat([output,tf.reshape(tn.sample(),[bs_ph,1])],axis=1)] # return the updated index and the new samples concatenated to the input
        # we loop until we've hit all the truncated gaussian parameters - idx starts at 0 and the samples start with a column of zeros that we cut out later
        idx = tf.constant(0)              # initialise counter
        nsamp = params['n_samples']       # define the number of samples (MUST be a normal int NOT tensor so can't use bs_ph)
        output = tf.zeros([nsamp,1],dtype=tf.float32)    # initialise the output (we cut this first set of zeros out later)
        condition = lambda i,output: i<gauss_len         # define the while loop stopping condition
        _,r2_xzy_samp_gauss = tf.while_loop(condition, truncnorm, loop_vars=[idx,output],shape_invariants=[idx.get_shape(), tf.TensorShape([nsamp,None])])
        r2_xzy_samp_gauss = tf.slice(tf.reshape(r2_xzy_samp_gauss,[-1,gauss_len+1]),[0,1],[-1,-1])   # cut out the actual samples - delete the initial vector of zeros

        # draw from r2(x|z,y) - the vonmises part
        temp_var_r2_vonmise = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_vonmise)
        con = tf.reshape(tf.math.reciprocal(temp_var_r2_vonmise),[-1,vonmise_len])   # modelling wrapped scale output as log variance
        von_mises = tfp.distributions.VonMises(loc=2.0*np.pi*(r2_xzy_mean_vonmise-0.5), concentration=con)
        r2_xzy_samp_vonmise = tf.reshape(von_mises.sample()/(2.0*np.pi) + 0.5,[-1,vonmise_len])   # sample from the von mises distribution and shift and scale from -pi-pi to 0-1
        
        # draw from r2(x|z,y) - the von mises Fisher 
        temp_var_r2_sky = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_sky)
        con = tf.reshape(tf.math.reciprocal(temp_var_r2_sky),[bs_ph])   # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
        von_mises_fisher = tfp.distributions.VonMisesFisher(
                          mean_direction=tf.math.l2_normalize(tf.reshape(r2_xzy_mean_sky,[bs_ph,3]),axis=1),
                          concentration=con)   # define p_vm(2*pi*mu,con=1/sig^2)
        xyz = tf.reshape(von_mises_fisher.sample(),[bs_ph,3])          # sample the distribution
        samp_ra = tf.math.floormod(tf.atan2(tf.slice(xyz,[0,1],[-1,1]),tf.slice(xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)   # convert to the rescaled 0->1 RA from the unit vector
        samp_dec = (tf.asin(tf.slice(xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi                       # convert to the rescaled 0->1 dec from the unit vector
        r2_xzy_samp_sky = tf.reshape(tf.concat([samp_ra,samp_dec],axis=1),[bs_ph,2])             # group the sky samples

        # combine the samples
        r2_xzy_samp = tf.concat([r2_xzy_samp_gauss,r2_xzy_samp_vonmise,r2_xzy_samp_masses,r2_xzy_samp_sky],axis=1)
        r2_xzy_samp = tf.gather(r2_xzy_samp,tf.constant(idx_mask),axis=1)

        # VARIABLES LISTS
        var_list_VICI = [var for var in tf.trainable_variables() if var.name.startswith("VICI")]

        # INITIALISE AND RUN SESSION
        init = tf.initialize_all_variables()
        session.run(init)
        saver_VICI = tf.train.Saver(var_list_VICI)
        saver_VICI.restore(session,load_dir)

    # ESTIMATE TEST SET RECONSTRUCTION PER-PIXEL APPROXIMATE MARGINAL LIKELIHOOD and draw from q(x|y)
    ns = params['n_samples'] # number of samples to save per reconstruction 

    y_data_test_exp = np.tile(y_data_test,(ns,1))/y_normscale
    y_data_test_exp = y_data_test_exp.reshape(-1,params['ndata'],num_det)
    run_startt = time.time()
    xs, mode_weights = session.run([r2_xzy_samp,r1_weight],feed_dict={bs_ph:ns,y_ph:y_data_test_exp})
    run_endt = time.time()

    return xs, (run_endt - run_startt), mode_weights
def _greedy_decode(input_embeddings,
                   output_vocab_size,
                   target_end_id,
                   target_start_id,
                   output_vocab_embeddings_table,
                   source_len,
                   model_config,
                   mode,
                   input_copy_mask=None,
                   clean_output_mask=None):
    """Fast decoding."""
    encoder_output = common_layers.linear_transform(
        input_embeddings,
        output_size=model_config.model_parameters.encoder_dims,
        scope="bert_to_transformer")

    decode_length = model_config.data_options.max_decode_length

    # Expand the inputs into the beam width.
    def symbols_to_logits_fn(logit_indices, current_index):
        """Go from targets to logits."""
        logit_indices = tf.expand_dims(logit_indices, 0)
        decode_steps = decode_utils.get_decode_steps(logit_indices,
                                                     output_vocab_size,
                                                     model_config)
        target_embeddings = _get_target_embeddings(
            input_embeddings, output_vocab_embeddings_table, decode_steps,
            model_config)
        decoder_output = _build_transformer_decoder(
            encoder_output,
            source_len,
            target_embeddings,
            mode,
            model_config,
            single_step_index=current_index)

        logits = _get_action_logits(encoder_output,
                                    decoder_output,
                                    output_vocab_embeddings_table,
                                    output_vocab_size,
                                    model_config,
                                    input_copy_mask=input_copy_mask,
                                    clean_output_mask=clean_output_mask)

        # Squeeze batch dimension and length dimension, as both should be 1.
        logits = tf.squeeze(logits, axis=[0, 1])
        # Shape of logits should now be (output_vocab_size).
        return logits

    def loop_cond(i, decoded_ids, unused_logprobs):
        """Loop conditional that returns false to stop loop."""
        return tf.logical_and(
            tf.reduce_all(tf.not_equal(decoded_ids, target_end_id)),
            tf.less(i, decode_length))

    def inner_loop(i, decoded_ids, logprobs):
        """Decoder function invoked on each while loop iteration."""
        logits = symbols_to_logits_fn(decoded_ids, i)
        next_id = tf.argmax(logits, axis=0)
        softmax = tf.nn.softmax(logits)
        extended_vocab_size = tf.shape(softmax)[-1]
        mask = tf.one_hot(next_id, extended_vocab_size)
        prob = tf.reduce_sum(softmax * mask)
        logprob = tf.log(prob)

        # Add one-hot values to output Tensors, since values at index > i+1 should
        # still be zero.
        logprobs += tf.one_hot(i + 1,
                               decode_length + 1,
                               on_value=logprob,
                               dtype=tf.float32)
        decoded_ids += tf.one_hot(i + 1,
                                  decode_length + 1,
                                  on_value=next_id,
                                  dtype=tf.int64)

        return i + 1, decoded_ids, logprobs

    initial_ids = tf.zeros(dtype=tf.int64, shape=[decode_length + 1])
    initial_ids += tf.one_hot(0,
                              decode_length + 1,
                              on_value=tf.cast(target_start_id, tf.int64))
    initial_logprob = tf.zeros(dtype=tf.float32, shape=[decode_length + 1])
    initial_i = tf.constant(0)

    initial_values = [initial_i, initial_ids, initial_logprob]

    _, decoded_ids, logprobs = tf.while_loop(loop_cond, inner_loop,
                                             initial_values)

    # Remove <START> symbol.
    decoded_ids = decoded_ids[1:]
    logprobs = logprobs[1:]
    # Sum logprobs to get scores for overall sequence.
    logprobs = tf.reduce_sum(logprobs, axis=0)

    # Expand decoded_ids and logprobs to reflect beam width dimension of 1.
    decoded_ids = tf.expand_dims(decoded_ids, 0)
    logprobs = tf.expand_dims(logprobs, 0)

    # This is the output dict that the function returns.
    output_decode_steps = decode_utils.get_decode_steps(
        decoded_ids, output_vocab_size, model_config)
    predictions = decode_utils.get_predictions(output_decode_steps)
    predictions[constants.SCORES_KEY] = logprobs

    return predictions
Example #12
def train(params, x_data, y_data, x_data_test, y_data_test, y_data_test_noisefree, y_normscale, save_dir, truth_test, bounds, fixed_vals, posterior_truth_test,snrs_test=None):    

    # if True, do multi-modal
    multi_modal = True

    # USEFUL SIZES
    xsh = np.shape(x_data)
   
    ysh = np.shape(y_data)[1]
    n_convsteps = params['n_convsteps']
    z_dimension = params['z_dimension']
    bs = params['batch_size']
    n_weights_r1 = params['n_weights_r1']
    n_weights_r2 = params['n_weights_r2']
    n_weights_q = params['n_weights_q']
    n_modes = params['n_modes']
    n_hlayers_r1 = len(params['n_weights_r1'])
    n_hlayers_r2 = len(params['n_weights_r2'])
    n_hlayers_q = len(params['n_weights_q'])
    n_conv_r1 = len(params['n_filters_r1'])
    n_conv_r2 = len(params['n_filters_r2'])
    n_conv_q = len(params['n_filters_q'])
    n_filters_r1 = params['n_filters_r1']
    n_filters_r2 = params['n_filters_r2']
    n_filters_q = params['n_filters_q']
    filter_size_r1 = params['filter_size_r1']
    filter_size_r2 = params['filter_size_r2']
    filter_size_q = params['filter_size_q']
    maxpool_r1 = params['maxpool_r1']
    maxpool_r2 = params['maxpool_r2']
    maxpool_q = params['maxpool_q']
    conv_strides_r1 = params['conv_strides_r1']
    conv_strides_r2 = params['conv_strides_r2']
    conv_strides_q = params['conv_strides_q']
    pool_strides_r1 = params['pool_strides_r1']
    pool_strides_r2 = params['pool_strides_r2']
    pool_strides_q = params['pool_strides_q']
    batch_norm = params['batch_norm']
    red = params['reduce']
    if n_convsteps != None:
        ysh_conv_r1 = int(ysh*n_filters_r1/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
        ysh_conv_r2 = int(ysh*n_filters_r2/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
        ysh_conv_q = int(ysh*n_filters_q/2**n_convsteps) if red==True else int(ysh/2**n_convsteps)
    else:
        ysh_conv_r1 = int(ysh)
        ysh_conv_r2 = int(ysh)
        ysh_conv_q = int(ysh)
    drate = params['drate']
    ramp_start = params['ramp_start']
    ramp_end = params['ramp_end']
    num_det = len(fixed_vals['det'])


    # identify the indices of different sets of physical parameters
    vonmise_mask, vonmise_idx_mask, vonmise_len = get_param_index(params['inf_pars'],params['vonmise_pars'])
    gauss_mask, gauss_idx_mask, gauss_len = get_param_index(params['inf_pars'],params['gauss_pars'])
    sky_mask, sky_idx_mask, sky_len = get_param_index(params['inf_pars'],params['sky_pars'])
    ra_mask, ra_idx_mask, ra_len = get_param_index(params['inf_pars'],['ra'])
    dec_mask, dec_idx_mask, dec_len = get_param_index(params['inf_pars'],['dec'])
    m1_mask, m1_idx_mask, m1_len = get_param_index(params['inf_pars'],['mass_1'])
    m2_mask, m2_idx_mask, m2_len = get_param_index(params['inf_pars'],['mass_2'])
    idx_mask = np.argsort(gauss_idx_mask + vonmise_idx_mask + m1_idx_mask + m2_idx_mask + sky_idx_mask) # + dist_idx_mask)

    graph = tf.Graph()
    session = tf.Session(graph=graph)
    with graph.as_default():

        # PLACEHOLDERS
        bs_ph = tf.placeholder(dtype=tf.int64, name="bs_ph")                       # batch size placeholder
        x_ph = tf.placeholder(dtype=tf.float32, shape=[None, xsh[1]], name="x_ph") # params placeholder
        y_ph = tf.placeholder(dtype=tf.float32, shape=[None, params['ndata'], num_det], name="y_ph")
        ramp = tf.placeholder(dtype=tf.float32)    # the ramp to slowly increase the KL contribution

        # LOAD VICI NEURAL NETWORKS
        r1_zy = VICI_encoder.VariationalAutoencoder('VICI_encoder', n_input=params['ndata'], n_output=z_dimension, n_channels=num_det, n_weights=n_weights_r1,   # generates params for r1(z|y)
                                                    n_modes=n_modes, drate=drate, n_filters=n_filters_r1, 
                                                    filter_size=filter_size_r1, maxpool=maxpool_r1)
        r2_xzy = VICI_decoder.VariationalAutoencoder('VICI_decoder', vonmise_mask, gauss_mask, m1_mask, m2_mask, sky_mask, n_input1=z_dimension, 
                                                     n_input2=params['ndata'], n_output=xsh[1], n_channels=num_det, n_weights=n_weights_r2, 
                                                     drate=drate, n_filters=n_filters_r2, 
                                                     filter_size=filter_size_r2, maxpool=maxpool_r2)
        q_zxy = VICI_VAE_encoder.VariationalAutoencoder('VICI_VAE_encoder', n_input1=xsh[1], n_input2=params['ndata'], n_output=z_dimension, 
                                                     n_channels=num_det, n_weights=n_weights_q, drate=drate, 
                                                     n_filters=n_filters_q, filter_size=filter_size_q, maxpool=maxpool_q) 
        tf.set_random_seed(np.random.randint(0,10))

        # reduce the y data size
        y_conv = y_ph

        # GET r1(z|y)
        # run inverse autoencoder to generate mean and logvar of z given y data - these are the parameters for r1(z|y)
        r1_loc, r1_scale, r1_weight = r1_zy._calc_z_mean_and_sigma(y_conv)
        temp_var_r1 = SMALL_CONSTANT + tf.exp(r1_scale)

        
        # define the r1(z|y) mixture model
        bimix_gauss = tfd.MixtureSameFamily(
                          mixture_distribution=tfd.Categorical(logits=r1_weight),
                          components_distribution=tfd.MultivariateNormalDiag(
                          loc=r1_loc,
                          scale_diag=tf.sqrt(temp_var_r1)))


        # DRAW FROM r1(z|y) - given the Gaussian parameters generate z samples
        r1_zy_samp = bimix_gauss.sample()        
        
        # GET q(z|x,y)
        q_zxy_mean, q_zxy_log_sig_sq = q_zxy._calc_z_mean_and_sigma(x_ph,y_conv)

        # DRAW FROM q(z|x,y)
        temp_var_q = SMALL_CONSTANT + tf.exp(q_zxy_log_sig_sq)
        mvn_q = tfp.distributions.MultivariateNormalDiag(
                          loc=q_zxy_mean,
                          scale_diag=tf.sqrt(temp_var_q))
        q_zxy_samp = mvn_q.sample()  
       
        # GET r2(x|z,y)
        eps = tf.random.normal([bs_ph, params['ndata'], num_det], 0, 1., dtype=tf.float32)
        y_ph_ramp = tf.add(tf.multiply(ramp,y_conv), tf.multiply((1.0-ramp), eps))
        reconstruction_xzy = r2_xzy.calc_reconstruction(q_zxy_samp,y_ph_ramp)

        # ugly but required for now - unpack the r2 output params
        r2_xzy_mean_gauss = reconstruction_xzy[0]           # truncated gaussian mean
        r2_xzy_log_sig_sq_gauss = reconstruction_xzy[1]     # truncated gaussian log var
        r2_xzy_mean_vonmise = reconstruction_xzy[2]         # vonmises means
        r2_xzy_log_sig_sq_vonmise = reconstruction_xzy[3]   # vonmises log var
        r2_xzy_mean_m1 = reconstruction_xzy[4]              # m1 mean
        r2_xzy_log_sig_sq_m1 = reconstruction_xzy[5]        # m1 var
        r2_xzy_mean_m2 = reconstruction_xzy[6]              # m2 mean (m2 will be conditional on m1)
        r2_xzy_log_sig_sq_m2 = reconstruction_xzy[7]        # m2 log var (m2 will be conditional on m1)
        r2_xzy_mean_sky = reconstruction_xzy[8]             # sky mean unit vector (3D)
        r2_xzy_log_sig_sq_sky = reconstruction_xzy[9]       # sky log var (1D)

        # COST FROM RECONSTRUCTION - the masses
        # this sets up a joint distribution on m1 and m2 with m2 being conditional on m1
        # the ramp evolves the truncation boundaries from far away to 0->1 for m1 and 0->m1 for m2
        if m1_len>0 and m2_len>0:
            temp_var_r2_m1 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m1)     # the safe r2 variance
            temp_var_r2_m2 = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_m2)
            joint = tfd.JointDistributionSequential([    # shrink the truncation with the ramp
                       tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m1,tf.sqrt(temp_var_r2_m1),-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + 1.0),reinterpreted_batch_ndims=0),  # m1
                lambda b0: tfd.Independent(tfd.TruncatedNormal(r2_xzy_mean_m2,tf.sqrt(temp_var_r2_m2),-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + ramp*b0),reinterpreted_batch_ndims=0)],    # m2
            )
            reconstr_loss_masses = joint.log_prob((tf.boolean_mask(x_ph,m1_mask,axis=1),tf.boolean_mask(x_ph,m2_mask,axis=1)))

        # COST FROM RECONSTRUCTION - Truncated Gaussian parts
        # this sets up a loop over uncorrelated truncated Gaussians
        # the ramp evolves the boundaries from far away to 0->1 
        if gauss_len>0:
            temp_var_r2_gauss = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_gauss)
            gauss_x = tf.boolean_mask(x_ph,gauss_mask,axis=1)
            @tf.function
            def truncnorm(i,lp):    # we set up a function that adds the log-likelihoods and also increments the counter
                loc = tf.slice(r2_xzy_mean_gauss,[0,i],[-1,1])
                std = tf.sqrt(tf.slice(temp_var_r2_gauss,[0,i],[-1,1]))
                pos = tf.slice(gauss_x,[0,i],[-1,1])  
                tn = tfd.TruncatedNormal(loc,std,-GAUSS_RANGE*(1.0-ramp),GAUSS_RANGE*(1.0-ramp) + 1.0)   # shrink the truncation with the ramp
                return [i+1, lp + tn.log_prob(pos)]
            # we do the loop until we've hit all the truncated gaussian parameters - i starts at 0 and the logprob starts at 0 
            _,reconstr_loss_gauss = tf.while_loop(lambda i,reconstr_loss_gauss: i<gauss_len, truncnorm, [0,tf.zeros([bs_ph],dtype=tf.dtypes.float32)])

        # COST FROM RECONSTRUCTION - Von Mises parts for single parameters that wrap over 2pi
        if vonmise_len>0:
            temp_var_r2_vonmise = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_vonmise)
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_vonmise),[-1,vonmise_len])   # modelling wrapped scale output as log variance - convert to concentration
            von_mises = tfp.distributions.VonMises(
                          loc=2.0*np.pi*(tf.reshape(r2_xzy_mean_vonmise,[-1,vonmise_len])-0.5),   # remap 0->1 mean onto -pi->pi range
                          concentration=con)
            reconstr_loss_vonmise = von_mises.log_prob(2.0*np.pi*(tf.reshape(tf.boolean_mask(x_ph,vonmise_mask,axis=1),[-1,vonmise_len]) - 0.5))   # 2pi is the von mises input range
            
            reconstr_loss_vonmise = reconstr_loss_vonmise[:,0] + reconstr_loss_vonmise[:,1]

            # computing Gaussian likelihood for von mises parameters to be faded away with the ramp
            gauss_vonmises = tfp.distributions.MultivariateNormalDiag(
                         loc=r2_xzy_mean_vonmise,
                         scale_diag=tf.sqrt(temp_var_r2_vonmise))
            reconstr_loss_gauss_vonmise = gauss_vonmises.log_prob(tf.boolean_mask(x_ph,vonmise_mask,axis=1))        
            reconstr_loss_vonmise = ramp*reconstr_loss_vonmise + (1.0-ramp)*reconstr_loss_gauss_vonmise    # start with a Gaussian model and fade in the true vonmises
        else:
            reconstr_loss_vonmise = 0.0

        # COST FROM RECONSTRUCTION - Von Mises Fisher (sky) parts
        if sky_len>0:
            temp_var_r2_sky = SMALL_CONSTANT + tf.exp(r2_xzy_log_sig_sq_sky)
            con = tf.reshape(tf.math.reciprocal(temp_var_r2_sky),[bs_ph])   # modelling wrapped scale output as log variance - only 1 concentration parameter for all sky
            loc_xyz = tf.math.l2_normalize(tf.reshape(r2_xzy_mean_sky,[-1,3]),axis=1)    # take the 3 output mean params from r2 and normalise them so they form a unit vector
            von_mises_fisher = tfp.distributions.VonMisesFisher(
                          mean_direction=loc_xyz,
                          concentration=con)
            ra_sky = 2.0*np.pi*tf.reshape(tf.boolean_mask(x_ph,ra_mask,axis=1),[-1,1])       # convert the scaled 0->1 true RA value back to radians
            dec_sky = np.pi*(tf.reshape(tf.boolean_mask(x_ph,dec_mask,axis=1),[-1,1]) - 0.5) # convert the scaled 0->1 true dec value back to radians
            xyz_unit = tf.reshape(tf.concat([tf.cos(ra_sky)*tf.cos(dec_sky),tf.sin(ra_sky)*tf.cos(dec_sky),tf.sin(dec_sky)],axis=1),[-1,3])   # construct the true parameter unit vector
            reconstr_loss_sky = von_mises_fisher.log_prob(tf.math.l2_normalize(xyz_unit,axis=1))   # normalise it for safety (should already be normalised) and compute the logprob

            # computing Gaussian likelihood for von mises Fisher (sky) parameters to be faded away with the ramp
            mean_ra = tf.math.floormod(tf.atan2(tf.slice(loc_xyz,[0,1],[-1,1]),tf.slice(loc_xyz,[0,0],[-1,1])),2.0*np.pi)/(2.0*np.pi)    # convert the unit vector to scaled 0->1 RA 
            mean_dec = (tf.asin(tf.slice(loc_xyz,[0,2],[-1,1])) + 0.5*np.pi)/np.pi        # convert the unit vector to scaled 0->1 dec
            mean_sky = tf.reshape(tf.concat([mean_ra,mean_dec],axis=1),[bs_ph,2])        # package up the scaled RA and dec 
            gauss_sky = tfp.distributions.MultivariateNormalDiag(
                         loc=mean_sky,
                         scale_diag=tf.concat([tf.sqrt(temp_var_r2_sky),tf.sqrt(temp_var_r2_sky)],axis=1))   # use the same 1D concentration parameter for both RA and dec dimensions
            reconstr_loss_gauss_sky = gauss_sky.log_prob(tf.boolean_mask(x_ph,sky_mask,axis=1))     # compute the logprob at the true sky location
            reconstr_loss_sky = ramp*reconstr_loss_sky + (1.0-ramp)*reconstr_loss_gauss_sky   # start with a Gaussian model and fade in the true vonmises Fisher

        cost_R = -1.0*tf.reduce_mean(reconstr_loss_gauss + reconstr_loss_vonmise + reconstr_loss_masses + reconstr_loss_sky)
        r2_xzy_mean = tf.gather(tf.concat([r2_xzy_mean_gauss,r2_xzy_mean_vonmise,r2_xzy_mean_m1,r2_xzy_mean_m2,r2_xzy_mean_sky],axis=1),tf.constant(idx_mask),axis=1)      # put the elements back in order
        r2_xzy_scale = tf.gather(tf.concat([r2_xzy_log_sig_sq_gauss,r2_xzy_log_sig_sq_vonmise,r2_xzy_log_sig_sq_m1,r2_xzy_log_sig_sq_m2,r2_xzy_log_sig_sq_sky],axis=1),tf.constant(idx_mask),axis=1)   # put the elements back in order
        
        log_q_q = mvn_q.log_prob(q_zxy_samp)
        log_r1_q = bimix_gauss.log_prob(q_zxy_samp)   # evaluate the log prob of r1 at the q samples
        KL = tf.reduce_mean(log_q_q - log_r1_q)      # average over batch

        # THE VICI COST FUNCTION
        COST = cost_R + ramp*KL  # + L1_weight_reg

        # VARIABLES LISTS
        var_list_VICI = [var for var in tf.trainable_variables() if var.name.startswith("VICI")]
        
        # DEFINE OPTIMISER (using ADAM here)
        optimizer = tf.train.AdamOptimizer(params['initial_training_rate']) 
#        optimizer = tf.train.RMSPropOptimizer(params['initial_training_rate'])
        minimize = optimizer.minimize(COST,var_list = var_list_VICI)
        
        # INITIALISE AND RUN SESSION
        init = tf.global_variables_initializer()
        session.run(init)
        saver = tf.train.Saver()

    print('Training Inference Model...')    
    # START OPTIMISATION OF OELBO
    indices_generator = batch_manager.SequentialIndexer(params['batch_size'], xsh[0])
    plotdata = []

    load_chunk_it = 1
    for i in range(params['num_iterations']):

        next_indices = indices_generator.next_indices()

        # if load chunks true, load in data by chunks
        if params['load_by_chunks'] == True and i == int(params['load_iteration']*load_chunk_it):
            x_data, y_data = load_chunk(params['train_set_dir'],params['inf_pars'],params,bounds,fixed_vals)
            load_chunk_it += 1

        # Make noise realizations and add to training data
        next_x_data = x_data[next_indices,:]
        if params['reduce'] == True or params['n_filters_r1'] != None:
            next_y_data = y_data[next_indices,:] + np.random.normal(0,1,size=(params['batch_size'],int(params['ndata']),len(fixed_vals['det'])))
        else:
            next_y_data = y_data[next_indices,:] + np.random.normal(0,1,size=(params['batch_size'],int(params['ndata']*len(fixed_vals['det']))))
        next_y_data /= y_normscale  # required for fast convergence

        if params['by_channel'] == False:
            next_y_data_new = [] 
            for sig in next_y_data:
                next_y_data_new.append(sig.T)
            next_y_data = np.array(next_y_data_new)
            del next_y_data_new
      
        # restore session if wanted
        if params['resume_training'] == True and i == 0:
            print(save_dir)
            saver.restore(session, save_dir)

        # compute the ramp value
        rmp = 0.0
        if params['ramp'] == True:
            if i>ramp_start:
                rmp = (np.log10(float(i)) - np.log10(ramp_start))/(np.log10(ramp_end) - np.log10(ramp_start))
            if i>ramp_end:
                rmp = 1.0  
        else:
            rmp = 1.0              

        # train the network 
        session.run(minimize, feed_dict={bs_ph:bs, x_ph:next_x_data, y_ph:next_y_data, ramp:rmp}) 
 
        # if we are in a report iteration extract cost function values
        if i % params['report_interval'] == 0 and i > 0:

            # get training loss
            cost, kl, AB_batch = session.run([cost_R, KL, r1_weight], feed_dict={bs_ph:bs, x_ph:next_x_data, y_ph:next_y_data, ramp:rmp})

            # get validation loss on test set
            cost_val, kl_val = session.run([cost_R, KL], feed_dict={bs_ph:y_data_test.shape[0], x_ph:x_data_test, y_ph:y_data_test/y_normscale, ramp:rmp})
            plotdata.append([cost,kl,cost+kl,cost_val,kl_val,cost_val+kl_val])

           
            try:
                # Make loss plot
                plt.figure()
                xvec = params['report_interval']*np.arange(np.array(plotdata).shape[0])
                plt.semilogx(xvec,np.array(plotdata)[:,0],label='recon',color='blue',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,1],label='KL',color='orange',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,2],label='total',color='green',alpha=0.5)
                plt.semilogx(xvec,np.array(plotdata)[:,3],label='recon_val',color='blue',linestyle='dotted')
                plt.semilogx(xvec,np.array(plotdata)[:,4],label='KL_val',color='orange',linestyle='dotted')
                plt.semilogx(xvec,np.array(plotdata)[:,5],label='total_val',color='green',linestyle='dotted')
                plt.ylim([-25,15])
                plt.xlabel('iteration')
                plt.ylabel('cost')
                plt.legend()
                plt.savefig('%s/latest_%s/cost_%s.png' % (params['plot_dir'],params['run_label'],params['run_label']))
                plt.ylim([np.min(np.array(plotdata)[-int(0.9*np.array(plotdata).shape[0]):,0]), np.max(np.array(plotdata)[-int(0.9*np.array(plotdata).shape[0]):,1])])
                plt.savefig('%s/latest_%s/cost_zoom_%s.png' % (params['plot_dir'],params['run_label'],params['run_label']))
                plt.close('all')
                
            except Exception:
                pass

            if params['print_values']==True:
                print('--------------------------------------------------------------')
                print('Iteration:',i)
                print('Training -ELBO:',cost)
                print('Validation -ELBO:',cost_val)
                print('Training KL Divergence:',kl)
                print('Validation KL Divergence:',kl_val)
                print('Training Total cost:',kl + cost) 
                print('Validation Total cost:',kl_val + cost_val)
                print()

                # terminate training if vanishing gradient
                if np.isnan(kl+cost) == True or np.isnan(kl_val+cost_val) == True or kl+cost > int(1e5):
                    print('Network is returning NaN values')
                    print('Terminating network training')
                    if params['hyperparam_optim'] == True:
                        save_path = saver.save(session,save_dir)
                        return 5000.0, session, saver, save_dir
                    else:
                        exit()
                try:
                    # Save loss plot data
                    np.savetxt(save_dir.split('/')[0] + '/loss_data.txt', np.array(plotdata))
                except FileNotFoundError as err:
                    print(err)
                    pass

        if i % params['save_interval'] == 0 and i > 0:

            if params['hyperparam_optim'] == False:
                # Save model 
                save_path = saver.save(session,save_dir)
            else:
                pass


        # stop training during hyperparameter optimisation and return the latest total training cost as the figure of merit
        if params['hyperparam_optim'] == True and i == params['hyperparam_optim_stop']:
            save_path = saver.save(session,save_dir)

            return np.array(plotdata)[-1,2], session, saver, save_dir

        if i % params['plot_interval'] == 0 and i>0:

            n_mode_weight_copy = 100 # must be a multiple of 50
            # just run the network on the test data
            for j in range(params['r']*params['r']):

                # The trained inverse model weights can then be used to infer a probability density of solutions given new measurements
                if params['reduce'] == True or params['n_filters_r1'] != None:
                    XS, dt, _  = VICI_inverse_model.run(params, y_data_test[j].reshape([1,y_data_test.shape[1],y_data_test.shape[2]]), np.shape(x_data_test)[1],
                                                 y_normscale, 
                                                 "inverse_model_dir_%s/inverse_model.ckpt" % params['run_label'])
                else:
                    XS, dt, _  = VICI_inverse_model.run(params, y_data_test[j].reshape([1,-1]), np.shape(x_data_test)[1],
                                                 y_normscale, 
                                                 "inverse_model_dir_%s/inverse_model.ckpt" % params['run_label'])
                print('Runtime to generate {} samples = {} sec'.format(params['n_samples'],dt))            
                # Make corner plots
                # Get corner parnames to use in plotting labels
                parnames = []
                for k_idx,k in enumerate(params['rand_pars']):
                    if np.isin(k, params['inf_pars']):
                        parnames.append(params['cornercorner_parnames'][k_idx])

                defaults_kwargs = dict(
                    bins=50, smooth=0.9, label_kwargs=dict(fontsize=16),
                    title_kwargs=dict(fontsize=16),
                    truth_color='tab:orange', quantiles=[0.16, 0.84],
                    levels=(0.68,0.90,0.95), density=True,
                    plot_density=False, plot_datapoints=True,
                    max_n_ticks=3)

                figure = corner.corner(posterior_truth_test[j], **defaults_kwargs,labels=parnames,
                       color='tab:blue',truths=x_data_test[j,:],
                       show_titles=True)
                # compute weights, otherwise the 1d histograms will be different scales, could remove this
                corner.corner(XS,**defaults_kwargs,labels=parnames,
                       color='tab:red',
                       fill_contours=True,
                       show_titles=True, fig=figure)


                plt.savefig('%s/corner_plot_%s_%d-%d.png' % (params['plot_dir'],params['run_label'],i,j))
                plt.savefig('%s/latest_%s/corner_plot_%s_%d.png' % (params['plot_dir'],params['run_label'],params['run_label'],j))
                plt.close('all')
                print('Made corner plot %d' % j)

    return            
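
# A minimal standalone sketch (illustration only; the function name kl_ramp is
# invented) of the KL annealing schedule used inside the training loop above:
# the KL weight rises logarithmically from 0 at params['ramp_start'] iterations
# to 1 at params['ramp_end'].
import numpy as np

def kl_ramp(iteration, ramp_start, ramp_end):
    """Return the KL weight in [0, 1] for a given training iteration."""
    if iteration <= ramp_start:
        return 0.0
    if iteration >= ramp_end:
        return 1.0
    return (np.log10(float(iteration)) - np.log10(ramp_start)) / (
        np.log10(ramp_end) - np.log10(ramp_start))

# e.g. kl_ramp(3162, 1e3, 1e4) is roughly 0.5 (halfway along a log10 axis).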
Example #13
    def __call__(self,
                 learner,
                 meta_batches=None,
                 inner_batches=None,
                 init_state=None,
                 unroll_n_steps=None):
        if unroll_n_steps is None:
            unroll_n_steps = self.unroll_n_steps
        else:
            print("Using passed in unroll steps")

        if inner_batches is None:
            inner_batches = self.inner_batches
        else:
            # convert the batches object to a tensorarray.
            def to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.unroll_n_steps).unstack(t)

            inner_batches = nest.map_structure(to_ta, inner_batches)

        if meta_batches is None:
            meta_batches = self.meta_batches
        else:
            # convert the batches object to a tensorarray.
            def ml_to_ta(t):
                return tf.TensorArray(dtype=t.dtype,
                                      size=self.meta_loss_evals *
                                      self.unroll_n_steps).unstack(t)

            meta_batches = nest.map_structure(ml_to_ta, meta_batches)

        if init_state is None:
            init_state = learner.current_state()
            init_state = tf_utils.force_copy(init_state)

        current_state = (tf.constant(0, dtype=tf.int32),
                         tf.constant(0., dtype=tf.float32), init_state)

        def loss_and_next_state_fn(loop_vars):
            idx, l, state = loop_vars
            batch = self.get_batch(idx, batches=inner_batches)
            l, s = learner.loss_and_next_state(state, loss_state=batch)
            return (idx + 1, l, s)

        def accumulate_fn(loop_vars, accumulators):
            """Accumulate loss for fold learning process."""
            idx, _, s = loop_vars
            a_meta, a_inner = accumulators
            cond = lambda i, a: tf.less(i, self.meta_loss_evals)

            def body_meta(i, a):
                # minus 1 as this takes the following step.
                batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.meta_loss(s, loss_state=batch))

            _, extra_losses = tf.while_loop(cond, body_meta, loop_vars=[0, 0.])

            def body_inner(i, a):
                # minus 1 as this takes the following step.
                batch = self.get_batch((idx - 1) * (self.meta_loss_evals) + i,
                                       batches=meta_batches)
                return (i + 1, a + learner.inner_loss(s, loss_state=batch))

            _, inner_losses = tf.while_loop(cond,
                                            body_inner,
                                            loop_vars=[0, 0.])

            return a_meta + extra_losses, a_inner + inner_losses
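
# Minimal sketch (illustration only; shapes invented, TF 1.x assumed) of the
# TensorArray conversion performed by `to_ta`/`ml_to_ta` above: a stacked batch
# tensor is unstacked along axis 0 so that each unroll step can read one batch
# with `.read(i)`.
import tensorflow as tf

batches = tf.random_normal([5, 32, 10])   # [unroll_n_steps, batch, features]
batches_ta = tf.TensorArray(dtype=batches.dtype, size=5).unstack(batches)
step0 = batches_ta.read(0)                # [32, 10] batch consumed at step 0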
Example #14
    def _slow_greedy_infer_guess_and_check(self, features, decode_length):
        assert self._hparams.block_size > 0
        assert self._hparams.force_full_predict
        assert self._hparams.sampling_method == "argmax"
        assert self._decode_hparams.batch_size == 1
        assert self._decode_hparams.block_size > 0
        assert self._decode_hparams.block_size <= self._hparams.block_size
        assert self._decode_hparams.guess_and_check_top_k > 0

        inputs_old = features["inputs"]
        assert "targets" not in features

        assert len(features["inputs"].shape) in [3, 4]
        if len(features["inputs"].shape) < 4:
            features["inputs"] = tf.expand_dims(features["inputs"], 2)

        block_size = self._decode_hparams.block_size
        decode_length += tf.shape(features["inputs"])[1]

        def while_exit_cond(result, length):  # pylint: disable=unused-argument
            return tf.logical_and(
                length < decode_length,
                tf.reduce_all(
                    tf.not_equal(result[:, :length, :, :],
                                 text_encoder.EOS_ID)))

        def infer_step(result, length):
            """Inference step."""
            def print_info(result, length, new_length):
                vocab = self.problem_hparams.vocabulary["targets"]
                tf.logging.info(
                    "length=%s new_length=%s length_diff=%s new_suffix=%s",
                    length,
                    new_length,
                    new_length - length,
                    str([
                        vocab._subtoken_id_to_subtoken_string(index)  # pylint: disable=protected-access
                        for index in result[0, -block_size:, 0,
                                            0][:new_length - length]
                    ]).decode("unicode-escape"),
                )

            features["targets"] = tf.pad(result,
                                         [[0, 0], [0, 1], [0, 0], [0, 0]])
            samples, logits, losses = self.sample(features)  # pylint: disable=unused-variable

            _, top_k_indices = tf.nn.top_k(
                logits[:, :-1, :1, :, :],
                k=self._decode_hparams.guess_and_check_top_k)
            in_top_k = tf.reduce_any(tf.equal(tf.to_int64(top_k_indices),
                                              tf.expand_dims(result, 4)),
                                     axis=4)

            eos_cumsum = tf.cumsum(tf.to_int32(
                tf.equal(result, text_encoder.EOS_ID)),
                                   axis=1)
            after_eos = tf.greater(common_layers.shift_right(eos_cumsum), 0)

            correct = tf.logical_and(in_top_k, tf.logical_not(after_eos))
            correct_cumsum = tf.cumsum(tf.to_int32(correct), axis=1)
            perfect_cumsum = 1 + tf.range(tf.shape(correct)[1])
            for axis in [0, 2, 3]:
                perfect_cumsum = tf.expand_dims(perfect_cumsum, axis=axis)

            new_length = tf.reduce_sum(tf.to_int32(
                tf.equal(correct_cumsum, perfect_cumsum)),
                                       axis=1)
            new_length = tf.squeeze(new_length, axis=[0, 1, 2])
            new_length = tf.minimum(new_length, decode_length)

            new_result = tf.concat([
                result[:, :new_length, :, :],
                tf.reshape(samples[:, new_length, :block_size, :],
                           [1, block_size, 1, 1])
            ],
                                   axis=1)

            with tf.control_dependencies(
                [tf.py_func(print_info, [result, length, new_length], [])]):
                new_result = tf.identity(new_result)

            return new_result, new_length

        result = tf.zeros((1, 0, 1, 1), dtype=tf.int64)
        length = tf.squeeze(tf.zeros(1, dtype=tf.int32))

        result, length = tf.while_loop(while_exit_cond,
                                       infer_step, [result, length],
                                       shape_invariants=[
                                           tf.TensorShape([1, None, 1, 1]),
                                           tf.TensorShape([]),
                                       ],
                                       back_prop=False,
                                       parallel_iterations=1)

        result = result[:, :length, :, :]

        features["inputs"] = inputs_old

        return {
            "outputs": result,
            "scores": None,
        }
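
# Illustration (not from the original code; values invented) of the cumsum
# trick used in infer_step above to find how many guessed tokens are correct
# before the first mistake: a position is kept only while the running count of
# correct guesses still equals the "perfect" count 1, 2, 3, ...
import numpy as np

correct = np.array([True, True, False, True])           # per-position top-k hits
correct_cumsum = np.cumsum(correct.astype(np.int32))    # [1, 2, 2, 3]
perfect_cumsum = 1 + np.arange(len(correct))            # [1, 2, 3, 4]
new_length = np.sum(correct_cumsum == perfect_cumsum)   # 2 leading correct guesses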
Example #15
def pgd_attack(loss_fn,
               input_image,
               epsilon,
               num_steps,
               optimizer=UnrolledGradientDescent(),
               project_perturbation=_project_perturbation,
               image_bounds=None,
               random_init=1.):
    """Projected gradient descent for generating adversarial images.

  Args:
    loss_fn: A callable which takes `input_image` and `label` as arguments, and
      returns the loss, a scalar Tensor, which will be minimized
    input_image: Tensor, a batch of images
    epsilon: float, the L-infinity norm of the maximum allowable perturbation
    num_steps: int, the number of steps of gradient descent
    optimizer: An `UnrolledOptimizer` object
    project_perturbation: A function, which will be used to enforce some
      constraint. It should have the same signature as `_project_perturbation`.
      Note that if you use a custom projection function, you should double-check
      your implementation, since an incorrect implementation will not error,
      and will appear to work fine.
    image_bounds: A pair of floats: minimum and maximum pixel value. If None
      (default), the bounds are assumed to be 0 and 1.
    random_init: Probability of starting from random location rather than
      nominal input image.

  Returns:
    adversarial version of `input_image`, with L-infinity difference less than
      epsilon, which tries to minimize loss_fn.
  """
    image_bounds = image_bounds or (0., 1.)
    random_shape = [tf.shape(input_image)[0]
                    ] + [1] * (len(input_image.shape) - 1)
    use_random_init = tf.cast(
        tf.random_uniform(random_shape) < float(random_init), tf.float32)
    init_perturbation = use_random_init * tf.random_uniform(
        tf.shape(input_image), minval=-epsilon, maxval=epsilon)
    init_perturbation = project_perturbation(init_perturbation, epsilon,
                                             input_image, image_bounds)
    init_optim_state = optimizer.init_state([init_perturbation])

    def loop_body(i, perturbation, flat_optim_state):
        """Update perturbation to input image."""
        optim_state = nest.pack_sequence_as(structure=init_optim_state,
                                            flat_sequence=flat_optim_state)
        loss = loss_fn(input_image + perturbation)
        new_perturbation_list, new_optim_state = optimizer.minimize(
            loss, [perturbation], optim_state)
        projected_perturbation = project_perturbation(new_perturbation_list[0],
                                                      epsilon, input_image,
                                                      image_bounds)
        return i + 1, projected_perturbation, nest.flatten(new_optim_state)

    def cond(i, *_):
        return tf.less(i, num_steps)

    flat_init_optim_state = nest.flatten(init_optim_state)
    _, final_perturbation, _ = tf.while_loop(
        cond,
        loop_body,
        loop_vars=[tf.constant(0.), init_perturbation, flat_init_optim_state],
        parallel_iterations=1,
        back_prop=False)

    adversarial_image = input_image + final_perturbation
    return tf.stop_gradient(adversarial_image)
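
# A hedged usage sketch for pgd_attack above (TF 1.x graph mode). The toy
# linear "classifier", the random images and the labels are invented for the
# example, and the module-level helpers referenced by the defaults
# (UnrolledGradientDescent, _project_perturbation) are assumed to be available
# as in the snippet. Minimizing the negative cross-entropy of the true labels
# drives the prediction away from them, matching the "tries to minimize
# loss_fn" contract in the docstring.
import tensorflow as tf

clean_images = tf.random_uniform([4, 32, 32, 3])        # batch of images in [0, 1]
labels = tf.constant([0, 1, 2, 3], dtype=tf.int64)
w = tf.Variable(tf.random_normal([32 * 32 * 3, 10]))    # toy linear model weights

def loss_fn(images):
    logits = tf.matmul(tf.reshape(images, [-1, 32 * 32 * 3]), w)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits)
    return -tf.reduce_sum(xent)   # minimized => true labels become less likely

adv_images = pgd_attack(loss_fn, clean_images, epsilon=8.0 / 255.0, num_steps=10)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    adv_batch = sess.run(adv_images)   # same shape as clean_images, within epsilon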
Example #16
    def _create_gradient_ascent_action_tensors(self, eps=1e-6):
        """Create tensorflow operations for gradient ascent max_actions."""
        self._action_init_tensor = tf.placeholder(dtype=tf.float32,
                                                  name="action_init_tensor",
                                                  shape=(None,
                                                         self.action_dim))
        self._tolerance_tensor = tf.placeholder(dtype=tf.float32,
                                                name="tolerance_tensor",
                                                shape=())

        with tf.variable_scope("{}_{}".format(self.name, "action_variable")):
            self._action_variable_tensor = tf.Variable(
                initial_value=self._action_init_tensor,
                trainable=True,
                name="action_var")

            # gradient ascent
            self.cost_now = -tf.reduce_mean(
                self._build_q_function_net(self._state_tensor,
                                           self._action_variable_tensor))
            self.action_gradient = tf.gradients(
                self.cost_now, self._action_variable_tensor)[0]
            # normalize the gradient
            self.normalized_action_gradient = self.action_gradient / (
                eps + tf.linalg.norm(self.action_gradient))

            if self.sufficient_ascent_flag:

                def cond_sufficient_descent(learning_rate_action,
                                            cond_sufficient_descent,
                                            cost_perturbed):
                    del cost_perturbed
                    cond_1 = tf.math.greater(learning_rate_action,
                                             self.learning_rate_action)
                    return tf.math.logical_and(
                        cond_1, tf.logical_not(cond_sufficient_descent))

                def body_sufficient_descent(learning_rate_action,
                                            cond_sufficient_descent,
                                            cost_perturbed,
                                            c_armijo=0.01,
                                            c_goldstein=0.25,
                                            lr_decay=0.1):
                    """Function for sufficient descent."""
                    del cond_sufficient_descent, cost_perturbed
                    action_variable_perturbed_tensor = self._action_variable_tensor - \
                      learning_rate_action * self.normalized_action_gradient

                    cost_perturbed = -tf.reduce_mean(
                        self._build_q_function_net(
                            self._state_tensor,
                            action_variable_perturbed_tensor))

                    # Here the negative gradient corresponds to maximization of Q fun.
                    sufficient_descent = tf.reduce_sum(
                        self.action_gradient *
                        -self.normalized_action_gradient)

                    goldstein_condition = tf.greater_equal(
                        cost_perturbed, self.cost_now + c_goldstein *
                        learning_rate_action * sufficient_descent)
                    armijo_condition = tf.less_equal(
                        cost_perturbed, self.cost_now +
                        c_armijo * learning_rate_action * sufficient_descent)
                    cond_sufficient_descent = tf.logical_and(
                        goldstein_condition, armijo_condition)

                    with tf.control_dependencies([cond_sufficient_descent]):
                        learning_rate_action = learning_rate_action * lr_decay

                    return learning_rate_action, cond_sufficient_descent, cost_perturbed

            # Construct the while loop.
            def cond_gradient_ascent(itr, cond_terminate):
                cond_1 = tf.math.less(itr, self.action_maximization_iterations)
                return tf.math.logical_and(cond_1,
                                           tf.logical_not(cond_terminate))

            def body_gradient_ascent(itr, cond_terminate, lr_init=100.0):
                """Function for gradient descent."""
                del cond_terminate
                if self.sufficient_ascent_flag:
                    # first calculate a sufficient-descent step size
                    result_sufficient_descent = tf.while_loop(
                        cond_sufficient_descent, body_sufficient_descent, [
                            tf.constant(lr_init),
                            tf.constant(False),
                            tf.constant(np.inf)
                        ])
                    lr_action = result_sufficient_descent[0]
                    cost_perturbed = result_sufficient_descent[2]

                    cond_terminate = tf.less_equal(
                        tf.math.abs(cost_perturbed - self.cost_now),
                        self._tolerance_tensor)
                else:
                    # no sufficient descent step
                    lr_action = self.learning_rate_ga
                    action_variable_perturbed_tensor = self._action_variable_tensor - \
                      lr_action * self.normalized_action_gradient

                    cost_perturbed = -tf.reduce_mean(
                        self._build_q_function_net(
                            self._state_tensor,
                            action_variable_perturbed_tensor))
                    cond_terminate = tf.less_equal(
                        tf.math.abs(cost_perturbed - self.cost_now),
                        self._tolerance_tensor)

                train_op = tf.train.GradientDescentOptimizer(
                    learning_rate=lr_action).apply_gradients(
                        grads_and_vars=[(self.normalized_action_gradient,
                                         self._action_variable_tensor)])
                # Ensure that the update is applied before continuing.
                with tf.control_dependencies([train_op]):
                    itr = itr + 1

                    return itr, cond_terminate

            self.cost_optimizer = tf.while_loop(
                cond_gradient_ascent, body_gradient_ascent,
                [tf.constant(0), tf.constant(False)])

        self.action_init_op = tf.initializers.variables(
            tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                              scope="{}_{}".format(self.name,
                                                   "action_variable")))
Example #17
    def _build(self, inputs, labels):
        batch_size, input_shape, duplicated_inputs = self.prepare_inputs(
            inputs)
        if (self._max_specifications > 0 and self._max_specifications <
                self._specification.num_specifications):
            num_specs = self._max_specifications
            model_logits = self._eval_fn(inputs)
            bounds = self._specification.evaluate(model_logits)
            _, idx = tf.math.top_k(bounds, k=num_specs, sorted=False)
            if self._random_specifications:
                idx = tf.random.uniform(
                    shape=tf.shape(idx),
                    maxval=self._specification.num_specifications,
                    dtype=idx.dtype)
            idx = tf.tile(tf.expand_dims(idx, 0), [self._num_restarts, 1, 1])

            def select_fn(x, i):
                return tf.squeeze(tf.gather(x,
                                            tf.expand_dims(idx[:, :, i], -1),
                                            batch_dims=len(idx.shape) - 1),
                                  axis=-1)
        else:
            num_specs = self._specification.num_specifications
            select_fn = lambda x, i: x[:, :, i]

        def objective_fn(x):
            model_logits = self._eval_fn(x)  # [restarts * batch_size, output].
            model_logits = tf.reshape(model_logits,
                                      [self._num_restarts, batch_size, -1])
            # Output has dimension [num_restarts, batch_size, num_specifications].
            return self._specification.evaluate(model_logits)

        def flat_objective_fn(x):
            return _maximize_margin(objective_fn(x))

        def build_loss_fn(idx):
            def _reduced_loss_fn(x):
                # Pick worse attack, output has shape [num_restarts, batch_size].
                return -tf.reduce_sum(select_fn(objective_fn(x), idx))

            return _reduced_loss_fn

        if _is_spsa_optimizer(self._optimizer_builder):
            raise ValueError('"UnrolledSPSA*" unsupported in '
                             'MultiTargetedPGDAttack')
        optimizer = self._optimizer_builder(lr=self._lr, lr_fn=self._lr_fn)

        # Run a separate PGD attack for each specification.
        def cond(spec_idx, unused_attack, success):
            # If we are already successful, we break.
            return tf.logical_and(spec_idx < num_specs,
                                  tf.logical_not(tf.reduce_all(success)))

        def body(spec_idx, attack, success):
            """Runs a separate PGD attack for each specification."""
            adversarial_input = pgd_attack(
                build_loss_fn(spec_idx),
                duplicated_inputs,
                epsilon=self._epsilon,
                num_steps=self._num_steps,
                image_bounds=self._input_bounds,
                random_init=self._random_init,
                optimizer=optimizer,
                project_perturbation=self._project_perturbation)
            new_attack = self.find_worst_attack(flat_objective_fn,
                                                adversarial_input, batch_size,
                                                input_shape)
            new_logits = self._eval_fn(new_attack)
            # Count the number of samples that violate any specification.
            new_success = _any_greater(
                self._specification.evaluate(new_logits))
            # The first iteration always sets the attack and logits.
            use_new_values = tf.logical_or(tf.equal(spec_idx, 0), new_success)
            print_op = tf.print('Processed specification #', spec_idx)
            with tf.control_dependencies([print_op]):
                new_spec_idx = spec_idx + 1
            return (new_spec_idx, tf.where(use_new_values, new_attack, attack),
                    tf.logical_or(success, new_success))

        _, self._attack, self._success = tf.while_loop(
            cond,
            body,
            back_prop=False,
            parallel_iterations=1,
            loop_vars=[
                tf.constant(0, dtype=tf.int32),
                inputs,
                tf.zeros([tf.shape(inputs)[0]], dtype=tf.bool),
            ])
        self._logits = self._eval_fn(self._attack, mode='final')
        return self._attack
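
# Illustration (values invented) of the row-wise selection used in `body`
# above: with a rank-1 boolean condition, tf.where keeps whole rows from the
# new attack for successful samples and the previous attack otherwise.
import tensorflow as tf

use_new = tf.constant([True, False])
new_attack = tf.ones([2, 3])
attack = tf.zeros([2, 3])
merged = tf.where(use_new, new_attack, attack)   # row 0 from new_attack, row 1 from attack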
Example #18
def ssd_decode_and_crop(image_buffer, boxes, classes, raw_shape):
    """Crop image randomly and decode the cropped region.

  This function will crop an image to meet the following requirements:
  1. height to width ratio between 0.5 and 2;
  2. IoUs of some boxes exceed specified threshold;
  3. At least one box center is in the cropped region.
  We defer the jpeg decoding task until after the crop to avoid wasted work.

  Reference: https://github.com/chauhan-utk/ssd.DomainAdaptation

  Args:
    image_buffer: Tensor tf.string containing the contents of a JPEG file.
    boxes: Tensor tf.float32 of shape [num_boxes, 4], containing coordinates of
      object bounding boxes.
    classes: Tensor tf.int64 of shape [num_boxes, 1], containing class labels
      of objects.
    raw_shape: [height, width, 3].

  Returns:
    resized_image: decoded, cropped, and resized image Tensor tf.float32 of
      shape [ssd_constants.IMAGE_SIZE, ssd_constants.IMAGE_SIZE, 3], value
      range 0--255.
    cropped_boxes: box coordinates for objects in the cropped region.
    cropped_classes: class labels for objects in the cropped region.
  """

    num_boxes = tf.shape(boxes)[0]

    def no_crop_check():
        return (tf.random_uniform(
            shape=(), minval=0, maxval=1, dtype=tf.float32) <
                ssd_constants.P_NO_CROP_PER_PASS)

    def no_crop_proposal():
        return (
            tf.ones((), tf.bool),
            tf.convert_to_tensor([0, 0, 1, 1], dtype=tf.float32),
            tf.ones((num_boxes, ), tf.bool),
        )

    def crop_proposal():
        rand_vec = lambda minval, maxval: tf.random_uniform(shape=(
            ssd_constants.NUM_CROP_PASSES, 1),
                                                            minval=minval,
                                                            maxval=maxval,
                                                            dtype=tf.float32)

        width, height = rand_vec(0.3, 1), rand_vec(0.3, 1)
        left, top = rand_vec(0, 1 - width), rand_vec(0, 1 - height)

        right = left + width
        bottom = top + height

        ltrb = tf.concat([left, top, right, bottom], axis=1)

        min_iou = tf.random_shuffle(ssd_constants.CROP_MIN_IOU_CHOICES)[0]
        ious = calc_iou_tensor(ltrb, boxes)

        # discard any bboxes whose center is not in the cropped image
        xc, yc = [
            tf.tile(0.5 * (boxes[:, i + 0] + boxes[:, i + 2])[tf.newaxis, :],
                    (ssd_constants.NUM_CROP_PASSES, 1)) for i in range(2)
        ]

        masks = tf.reduce_all(tf.stack([
            tf.greater(xc, tf.tile(left, (1, num_boxes))),
            tf.less(xc, tf.tile(right, (1, num_boxes))),
            tf.greater(yc, tf.tile(top, (1, num_boxes))),
            tf.less(yc, tf.tile(bottom, (1, num_boxes))),
        ],
                                       axis=2),
                              axis=2)

        # Checks whether a crop is valid.
        valid_aspect = tf.logical_and(tf.less(height / width, 2),
                                      tf.less(width / height, 2))
        valid_ious = tf.reduce_all(tf.greater(ious, min_iou),
                                   axis=1,
                                   keepdims=True)
        valid_masks = tf.reduce_any(masks, axis=1, keepdims=True)

        valid_all = tf.cast(
            tf.reduce_all(tf.concat([valid_aspect, valid_ious, valid_masks],
                                    axis=1),
                          axis=1), tf.int32)

        # One indexed, as zero is needed for the case of no matches.
        index = tf.range(1, 1 + ssd_constants.NUM_CROP_PASSES, dtype=tf.int32)

        # Either one-hot, or zeros if there is no valid crop.
        selection = tf.equal(tf.reduce_max(index * valid_all), index)

        use_crop = tf.reduce_any(selection)
        output_ltrb = tf.reduce_sum(tf.multiply(
            ltrb,
            tf.tile(tf.cast(selection, tf.float32)[:, tf.newaxis], (1, 4))),
                                    axis=0)
        output_masks = tf.reduce_any(tf.logical_and(
            masks, tf.tile(selection[:, tf.newaxis], (1, num_boxes))),
                                     axis=0)

        return use_crop, output_ltrb, output_masks

    def proposal(*args):
        return tf.cond(
            pred=no_crop_check(),
            true_fn=no_crop_proposal,
            false_fn=crop_proposal,
        )

    _, crop_bounds, box_masks = tf.while_loop(
        cond=lambda x, *_: tf.logical_not(x),
        body=proposal,
        loop_vars=[
            tf.zeros((), tf.bool),
            tf.zeros((4, ), tf.float32),
            tf.zeros((num_boxes, ), tf.bool)
        ],
    )

    filtered_boxes = tf.boolean_mask(boxes, box_masks, axis=0)

    mlperf.logger.log(key=mlperf.tags.NUM_CROPPING_ITERATIONS,
                      value=ssd_constants.NUM_CROP_PASSES)

    # Clip boxes to the cropped region.
    filtered_boxes = tf.stack([
        tf.maximum(filtered_boxes[:, 0], crop_bounds[0]),
        tf.maximum(filtered_boxes[:, 1], crop_bounds[1]),
        tf.minimum(filtered_boxes[:, 2], crop_bounds[2]),
        tf.minimum(filtered_boxes[:, 3], crop_bounds[3]),
    ],
                              axis=1)

    left = crop_bounds[0]
    top = crop_bounds[1]
    width = crop_bounds[2] - left
    height = crop_bounds[3] - top

    cropped_boxes = tf.stack([
        (filtered_boxes[:, 0] - left) / width,
        (filtered_boxes[:, 1] - top) / height,
        (filtered_boxes[:, 2] - left) / width,
        (filtered_boxes[:, 3] - top) / height,
    ],
                             axis=1)

    # crop_window contains integer coordinates of the cropped region. A normalized
    # coordinate value y maps to the image coordinate y * (height - 1).
    raw_shape = tf.cast(raw_shape, tf.float32)
    crop_window = tf.stack([
        left * (raw_shape[0] - 1), top * (raw_shape[1] - 1),
        width * raw_shape[0], height * raw_shape[1]
    ])
    crop_window = tf.cast(crop_window, tf.int32)

    # Fused op only decodes the cropped portion of an image
    cropped_image = tf.image.decode_and_crop_jpeg(image_buffer,
                                                  crop_window,
                                                  channels=3)

    # Resize converts image dtype from uint8 to float32, without rescaling values.
    resized_image = tf.image.resize_images(
        cropped_image, [ssd_constants.IMAGE_SIZE, ssd_constants.IMAGE_SIZE])
    mlperf.logger.log(key=mlperf.tags.INPUT_SIZE,
                      value=ssd_constants.IMAGE_SIZE)

    cropped_classes = tf.boolean_mask(classes, box_masks, axis=0)

    return resized_image, cropped_boxes, cropped_classes
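
# Illustration (values invented) of the "one-indexed" selection trick used in
# crop_proposal above: among the candidate crop passes, pick the last valid
# one, or none at all (when nothing is valid, reduce_max(index * valid) is 0
# and matches no index, so the selection mask is all False).
import numpy as np

valid_all = np.array([0, 1, 0, 1, 0])            # which of 5 crop passes were valid
index = np.arange(1, 6)                          # one-indexed: 1..5
selection = (index * valid_all).max() == index   # [False, False, False, True, False]
use_crop = selection.any()                       # True; all-False when nothing valid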
Example #19
def _rnn_fn(sample_arc, x, prev_s, w_prev, w_skip, input_mask, layer_mask,
            params):
    """Multi-layer LSTM.

  Args:
    sample_arc: [num_layers * 2], sequence of tokens representing architecture.
    x: [batch_size, num_steps, hidden_size].
    prev_s: [batch_size, hidden_size].
    w_prev: [2 * hidden_size, 2 * hidden_size].
    w_skip: [None, [hidden_size, 2 * hidden_size] * (num_layers-1)].
    input_mask: `[batch_size, hidden_size]`.
    layer_mask: `[batch_size, hidden_size]`.
    params: hyper-params object.

  Returns:
    next_s: [batch_size, hidden_size].
    all_s: [[batch_size, num_steps, hidden_size] * num_layers].
  """
    batch_size = x.get_shape()[0].value
    num_steps = tf.shape(x)[1]
    num_layers = len(sample_arc) // 2

    all_s = tf.TensorArray(dtype=tf.float32, size=num_steps, infer_shape=False)

    # extract the relevant variables, so that you only do L2-reg on them.
    u_skip = []
    start_idx = 0
    for layer_id in range(num_layers):
        prev_idx = sample_arc[start_idx]
        func_idx = sample_arc[start_idx + 1]
        u_skip.append(w_skip[layer_id][func_idx, prev_idx])
        start_idx += 2
    w_skip = u_skip
    var_s = [w_prev] + w_skip[1:]

    def _select_function(h, function_id):
        h = tf.stack([tf.tanh(h), tf.nn.relu(h), tf.sigmoid(h), h], axis=0)
        h = h[function_id]
        return h

    def _condition(step, *unused_args):
        return tf.less(step, num_steps)

    def _body(step, prev_s, all_s):
        """Body function."""
        inp = x[:, step, :]

        # important change: first input uses a tanh()
        if layer_mask is not None:
            assert input_mask is not None
            ht = tf.matmul(
                tf.concat([inp * input_mask, prev_s * layer_mask], axis=1),
                w_prev)
        else:
            ht = tf.matmul(tf.concat([inp, prev_s], axis=1), w_prev)
        h, t = tf.split(ht, 2, axis=1)
        h = tf.tanh(h)
        t = tf.sigmoid(t)
        s = prev_s + t * (h - prev_s)
        layers = [s]

        start_idx = 0
        used = []
        for layer_id in range(num_layers):
            prev_idx = sample_arc[start_idx]
            func_idx = sample_arc[start_idx + 1]
            used.append(tf.one_hot(prev_idx, depth=num_layers, dtype=tf.int32))
            prev_s = tf.stack(layers, axis=0)[prev_idx]
            if layer_mask is not None:
                ht = tf.matmul(prev_s * layer_mask, w_skip[layer_id])
            else:
                ht = tf.matmul(prev_s, w_skip[layer_id])
            h, t = tf.split(ht, 2, axis=1)

            h = _select_function(h, func_idx)
            t = tf.sigmoid(t)
            s = prev_s + t * (h - prev_s)
            s.set_shape([batch_size, params.hidden_size])
            layers.append(s)
            start_idx += 2

        next_s = tf.add_n(layers[1:]) / tf.cast(num_layers, dtype=tf.float32)
        all_s = all_s.write(step, next_s)

        return step + 1, next_s, all_s

    loop_inps = [tf.constant(0, dtype=tf.int32), prev_s, all_s]
    _, next_s, all_s = tf.while_loop(_condition, _body, loop_inps)
    all_s = tf.transpose(all_s.stack(), [1, 0, 2])

    return next_s, all_s, var_s
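
# Small illustration (values invented) of how `sample_arc` is consumed above:
# it is read two tokens at a time as (prev_idx, func_idx) pairs, one pair per
# layer. prev_idx selects which earlier state to read (0 is the base recurrent
# state) and func_idx selects the activation in _select_function
# (0=tanh, 1=relu, 2=sigmoid, 3=identity).
sample_arc = [0, 1,   # layer 0: base state, relu
              1, 0,   # layer 1: output of layer 0, tanh
              0, 3]   # layer 2: base state again, identity
pairs = [(sample_arc[i], sample_arc[i + 1]) for i in range(0, len(sample_arc), 2)]
# pairs == [(0, 1), (1, 0), (0, 3)]; num_layers == len(sample_arc) // 2 == 3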
Example #20
File: base.py  Project: YuTpa/meta-blocks
    def select_indices_stratified(size,
                                  scores,
                                  clusters,
                                  indices=None,
                                  soft=False,
                                  parallel_iterations=8) -> tf.Tensor:
        """Selects indices of the instances to label given the scores.

        Parameters
        ----------
        size : int
            Number of samples to label.

        scores : Tensor <float32> [num_samples]
            A vector of scores that are used to select which sample to label.

        clusters : Tensor <int32> [num_samples]
            A vector of cluster indices used for sampling stratification.

        indices : Tensor <int32> [num_instances], optional
            A vector of absolute indices of the samples in a larger collection.
            If not None, the method returns `selected_indices` from `indices`.
            Otherwise, `selected_indices` are relative.

        soft : bool, optional (default=False)
            Whether to select top indices softly by sampling a categorical
            distribution with logits proportional to the scores.

        parallel_iterations : int (default: 8)
            Number of parallel iterations passed to tf.while_loop.

        Returns
        -------
            selected_indices : Tensor <int32> [size]
        """
        # size_per_cluster: <int32> [num_unique_clusters].
        # unique_clusters: <int32> [num_unique_clusters].
        size_per_cluster, unique_clusters = Sampler.stratify_by_cluster(
            size, clusters, parallel_iterations=parallel_iterations)

        def cond_fn(step, _unused_indices):
            return tf.less(step, tf.size(size_per_cluster))

        def body_fn(step, selected_indices):
            cluster_mask = tf.equal(clusters, unique_clusters[step])
            cluster_indices = tf.where(cluster_mask)[:, 0]
            cluster_scores = tf.gather(scores, cluster_indices, axis=0)
            selected_idx = tf.cond(
                pred=tf.greater(size_per_cluster[step], 0),
                true_fn=lambda: Sampler.select_indices(
                    size=size_per_cluster[step],
                    scores=cluster_scores,
                    indices=cluster_indices,
                    soft=soft,
                ),
                false_fn=lambda: tf.constant([], dtype=tf.int32),
            )
            return [
                tf.add(step, 1),
                selected_indices.write(step, selected_idx)
            ]

        # Select indices for each cluster.
        _, selected_indices_ta = tf.while_loop(
            cond=cond_fn,
            body=body_fn,
            loop_vars=[
                tf.constant(0),
                tf.TensorArray(dtype=tf.int32,
                               infer_shape=False,
                               size=tf.size(unique_clusters)),
            ],
            back_prop=False,
            parallel_iterations=parallel_iterations,
            name="stratified-index-selection",
        )

        selected_indices = selected_indices_ta.concat()
        selected_indices = tf.reshape(selected_indices, shape=(size, ))
        if indices is not None:
            selected_indices = tf.gather(indices, selected_indices, axis=0)

        return selected_indices
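
# A hedged usage sketch (TF 1.x) for select_indices_stratified above. It
# assumes the surrounding Sampler class, including its stratify_by_cluster and
# select_indices helpers, is available; the scores and cluster ids are invented
# and the exact per-cluster split depends on stratify_by_cluster.
import tensorflow as tf

scores = tf.constant([0.9, 0.1, 0.5, 0.7, 0.2, 0.8])
clusters = tf.constant([0, 0, 0, 1, 1, 1], dtype=tf.int32)
selected = Sampler.select_indices_stratified(size=2, scores=scores, clusters=clusters)
# `selected` holds 2 high-scoring indices spread across the two clusters
# (relative indices, since no absolute `indices` vector was passed).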
Example #21
def hmc(energy_fn,
        init_X,
        L=20,
        step_size=1.0,
        burn_in=100,
        num_samples=1000,
        thinning_steps=1,
        max_steps=None):

    samples = tf.TensorArray(init_X.dtype,
                             size=num_samples * thinning_steps,
                             dynamic_size=False,
                             name='samples_ta')
    #init_X = init_X[tf.newaxis,:]
    X_shape = tf.shape(init_X)

    if max_steps is None:
        max_steps = 1000 * num_samples * thinning_steps

    def hmc_step(i, num_accepted, q, samples):
        # Sample momentum variables as standard Gaussians.
        p = tf.random.normal(X_shape, mean=0., stddev=1.)
        init_q = q
        # Compute initial kinetic and potential energies.
        init_K = tf.reduce_sum(tf.square(p)) / 2.
        init_U = energy_fn(q)

        # Do first half-step
        p = p - step_size * tf.gradients(init_U, q)[0] / 2.
        # Run for L steps.
        for step in range(L):
            q = q + step_size * p
            if step != L - 1:
                p = p - step_size * tf.gradients(energy_fn(q), q)[0]
        proposed_U = energy_fn(q)
        p = p - step_size * tf.gradients(proposed_U, q)[0] / 2.
        p = -p
        proposed_K = tf.reduce_sum(tf.square(p)) / 2.

        p = tf.debugging.check_numerics(p, "Nans in p.")
        q = tf.debugging.check_numerics(q, "Nans in q.")

        accept = tf.random.uniform(
            []) < tf.exp(init_U - proposed_U + init_K - proposed_K)
        accept_samples = tf.logical_and(accept, i > burn_in)
        samples = tf.cond(accept_samples,
                          lambda: samples.write(num_accepted, q),
                          lambda: samples)
        accept_samples = tf.squeeze(accept_samples)
        q = tf.cond(accept, lambda: q, lambda: init_q)
        return i + 1, num_accepted + tf.to_int32(accept_samples), q, samples

    def hmc_predicate(i, num_accepted, unused_q, unused_samples):
        return tf.logical_and(
            tf.less(i, burn_in + max_steps),
            tf.less(num_accepted, num_samples * thinning_steps))

    results = tf.while_loop(hmc_predicate,
                            hmc_step, (0, 0, init_X, samples),
                            back_prop=False)
    #[num_samples, data_dim]
    samples = results[-1].stack()
    samples = tf.reshape(samples, [num_samples, thinning_steps, -1])
    samples = samples[:, -1, :]

    num_steps = results[0]
    num_accepted = results[1]
    accept_ratio = num_accepted / (num_steps - burn_in)
    tf.summary.scalar("acceptance_ratio", accept_ratio)
    tf.summary.scalar("num_hmc_steps", num_steps - burn_in)
    return samples
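# Hedged usage sketch for the hmc() helper above, assuming TensorFlow 1.x in
# graph mode. The quadratic energy below corresponds to a standard 2-D
# Gaussian, so the drawn samples should have roughly zero mean and unit
# variance; the hyperparameter values are illustrative, not tuned.
import tensorflow as tf

def gaussian_energy(x):
    # U(x) = sum(x^2) / 2, i.e. the target density is proportional to exp(-U).
    return tf.reduce_sum(tf.square(x)) / 2.

init_x = tf.zeros([2], dtype=tf.float32)
gaussian_samples = hmc(gaussian_energy, init_x,
                       L=10, step_size=0.1,
                       burn_in=100, num_samples=500, thinning_steps=2)

with tf.Session() as sess:
    drawn = sess.run(gaussian_samples)  # numpy array of shape [500, 2]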
Example #22
    def _slow_greedy_infer(self, features, decode_length):
        """A slow greedy inference method.

        Quadratic time in decode_length.

        Args:
          features: a map of string to `Tensor`
          decode_length: an integer.  How many additional timesteps to decode.

        Returns:
          A dict of decoding results {
              "outputs": integer `Tensor` of decoded ids of shape
                  [batch_size, <= decode_length] if beam_size == 1 or
                  [batch_size, top_beams, <= decode_length]
              "scores": None
              "logits": `Tensor` of shape [batch_size, time, 1, 1, vocab_size].
              "losses": a dictionary: {loss-name (string): floating point `Scalar`}
          }
        """
        if not features:
            features = {}
        inputs_old = None
        # process all conditioning features
        if "inputs" in features:
            if len(features["inputs"].shape) < 4:
                inputs_old = features["inputs"]
                features["inputs"] = tf.expand_dims(features["inputs"], 2)
        else:  # this would be for melody decoding
            if "melody" in features:
                if len(features["melody"].shape) < 4:
                    inputs_old = features["melody"]
                    features["melody"] = tf.expand_dims(features["melody"], 2)
            if "performance" in features:
                if len(features["performance"].shape) < 4:
                    inputs_old = features["performance"]
                    features["performance"] = tf.expand_dims(
                        features["performance"], 2)
        if not self.has_input:
            # Prepare partial targets.
            # In either features["inputs"] or features["targets"].
            # We force the outputs to begin with these sequences.
            partial_targets = features.get("inputs")
            if partial_targets is None:
                partial_targets = features["targets"]
            features["partial_targets"] = tf.to_int64(partial_targets)
        # Save the targets in a var and reassign it after the tf.while loop to avoid
        # having targets being in a 'while' frame. This ensures that targets, when
        # used in metric functions, stay in the same frame as other vars.
        targets_old = features.get("targets", None)

        target_modality = self._problem_hparams.modality["targets"]

        def infer_step(recent_output, recent_logits, unused_loss):
            """Inference step."""
            if not tf.executing_eagerly():
                if self._target_modality_is_real:
                    dim = self._problem_hparams.vocab_size["targets"]
                    if dim is not None and hasattr(self._hparams,
                                                   "vocab_divisor"):
                        dim += (-dim) % self._hparams.vocab_divisor
                    recent_output.set_shape([None, None, None, dim])
                else:
                    recent_output.set_shape([None, None, None, 1])
            padded = tf.pad(recent_output, [[0, 0], [0, 1], [0, 0], [0, 0]])
            features["targets"] = padded
            # This is inefficient in that it generates samples at all timesteps,
            # not just the last one, except if target_modality is pointwise.
            samples, logits, losses = self.sample(features)
            # Concatenate the already-generated recent_output with last timestep
            # of the newly-generated samples.
            top = self._hparams.top.get("targets",
                                        modalities.get_top(target_modality))
            if getattr(top, "pointwise", False):
                cur_sample = samples[:, -1, :, :]
            else:
                cur_sample = samples[
                    :, common_layers.shape_list(recent_output)[1], :, :]
            if self._target_modality_is_real:
                cur_sample = tf.expand_dims(cur_sample, axis=1)
                samples = tf.concat([recent_output, cur_sample], axis=1)
            else:
                cur_sample = tf.to_int64(tf.expand_dims(cur_sample, axis=1))
                samples = tf.concat([recent_output, cur_sample], axis=1)
                if not tf.executing_eagerly():
                    samples.set_shape([None, None, None, 1])

            # Assuming we have one shard for logits.
            logits = tf.concat([recent_logits, logits[:, -1:]], 1)
            loss = sum([l for l in losses.values() if l is not None])
            return samples, logits, loss

        # Create an initial output tensor. This will be passed
        # to the infer_step, which adds one timestep at every iteration.
        if "partial_targets" in features:
            initial_output = tf.to_int64(features["partial_targets"])
            while len(initial_output.get_shape().as_list()) < 4:
                initial_output = tf.expand_dims(initial_output, 2)
            batch_size = common_layers.shape_list(initial_output)[0]
        else:
            batch_size = common_layers.shape_list(features["performance"])[0]
            if self._target_modality_is_real:
                dim = self._problem_hparams.vocab_size["targets"]
                if dim is not None and hasattr(self._hparams, "vocab_divisor"):
                    dim += (-dim) % self._hparams.vocab_divisor
                initial_output = tf.zeros((batch_size, 0, 1, dim),
                                          dtype=tf.float32)
            else:
                initial_output = tf.zeros((batch_size, 0, 1, 1),
                                          dtype=tf.int64)
        # Hack: foldl complains when the output shape is less specified than the
        # input shape, so we confuse it about the input shape.
        initial_output = tf.slice(initial_output, [0, 0, 0, 0],
                                  common_layers.shape_list(initial_output))
        target_modality = self._problem_hparams.modality["targets"]
        if target_modality == modalities.ModalityType.CLASS_LABEL:
            decode_length = 1
        else:
            if "partial_targets" in features:
                prefix_length = common_layers.shape_list(
                    features["partial_targets"])[1]
            else:
                # This code generates outputs that tend to be long, in order to
                # avoid the case where the melody is extremely short. Change this
                # to features["melody"] for the original behavior.
                prefix_length = common_layers.shape_list(
                    features["performance"])[1]
            decode_length = prefix_length + decode_length

        # Initial values of result, logits and loss.
        result = initial_output
        vocab_size = self._problem_hparams.vocab_size["targets"]
        if vocab_size is not None and hasattr(self._hparams, "vocab_divisor"):
            vocab_size += (-vocab_size) % self._hparams.vocab_divisor
        if self._target_modality_is_real:
            logits = tf.zeros((batch_size, 0, 1, vocab_size))
            logits_shape_inv = [None, None, None, None]
        else:
            # tensor of shape [batch_size, time, 1, 1, vocab_size]
            logits = tf.zeros((batch_size, 0, 1, 1, vocab_size))
            logits_shape_inv = [None, None, None, None, None]
        if not tf.executing_eagerly():
            logits.set_shape(logits_shape_inv)

        loss = 0.0

        def while_exit_cond(result, logits, loss):  # pylint: disable=unused-argument
            """Exit the loop either if reach decode_length or EOS."""
            length = common_layers.shape_list(result)[1]

            not_overflow = length < decode_length

            if self._problem_hparams.stop_at_eos:

                def fn_not_eos():
                    return tf.not_equal(  # Check if the last predicted element is an EOS
                        tf.squeeze(result[:, -1, :, :]), text_encoder.EOS_ID)

                not_eos = tf.cond(
                    # We only check for early stopping if there is at least 1 element (
                    # otherwise not_eos will crash).
                    tf.not_equal(length, 0),
                    fn_not_eos,
                    lambda: True,
                )

                return tf.cond(
                    tf.equal(batch_size, 1),
                    # If batch_size == 1, we check EOS for early stopping.
                    lambda: tf.logical_and(not_overflow, not_eos),
                    # Else, just wait for max length
                    lambda: not_overflow)
            return not_overflow

        result, logits, loss = tf.while_loop(
            while_exit_cond,
            infer_step, [result, logits, loss],
            shape_invariants=[
                tf.TensorShape([None, None, None, None]),
                tf.TensorShape(logits_shape_inv),
                tf.TensorShape([]),
            ],
            back_prop=False,
            parallel_iterations=1)
        if inputs_old is not None:  # Restore to not confuse Estimator.
            features["inputs"] = inputs_old
        # Reassign targets back to the previous value.
        if targets_old is not None:
            features["targets"] = targets_old
        losses = {"training": loss}
        if "partial_targets" in features:
            partial_target_length = common_layers.shape_list(
                features["partial_targets"])[1]
            result = tf.slice(result, [0, partial_target_length, 0, 0],
                              [-1, -1, -1, -1])
        return {
            "outputs": result,
            "scores": None,
            "logits": logits,
            "losses": losses,
        }
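# A minimal, standalone sketch of the growing-output greedy loop used by
# _slow_greedy_infer, assuming TensorFlow 1.x. `toy_logits_fn` is a stand-in
# for the model's sample() call; VOCAB, EOS_ID, MAX_LEN and BATCH are
# illustrative constants, not values from the surrounding class.
import tensorflow as tf

VOCAB, EOS_ID, MAX_LEN, BATCH = 5, 4, 8, 2

def toy_logits_fn(ids):
    # Hypothetical scorer: always prefers token (current length mod VOCAB).
    length = tf.shape(ids)[1]
    return tf.one_hot(tf.fill([BATCH], length % VOCAB), VOCAB)

def toy_infer_step(ids):
    next_id = tf.argmax(toy_logits_fn(ids), axis=-1, output_type=tf.int32)
    return tf.concat([ids, next_id[:, None]], axis=1)

def toy_exit_cond(ids):
    length = tf.shape(ids)[1]
    not_overflow = length < MAX_LEN
    all_eos = tf.reduce_all(tf.equal(ids[:, -1], EOS_ID))
    return tf.logical_and(not_overflow, tf.logical_not(all_eos))

initial_ids = tf.zeros([BATCH, 1], tf.int32)
decoded = tf.while_loop(
    toy_exit_cond, toy_infer_step, [initial_ids],
    shape_invariants=[tf.TensorShape([BATCH, None])],
    back_prop=False, parallel_iterations=1)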
Example #23
    def build_sample_graph(self,
                           input_pianorolls=None,
                           outer_masks=None,
                           total_gibbs_steps=None):
        """Builds the tf.while_loop based sampling graph.

    Args:
      input_pianorolls: Optional input pianorolls override. If None, uses the
          pianorolls placeholder.
      outer_masks: Optional input outer_masks override. If None, uses the
          outer_masks placeholder.
      total_gibbs_steps: Optional input total_gibbs_steps override. If None,
          uses the total_gibbs_steps placeholder.
    Returns:
      The output op of the graph.
    """
        if input_pianorolls is None:
            input_pianorolls = self.inputs["pianorolls"]
        if outer_masks is None:
            outer_masks = self.inputs["outer_masks"]

        tt = tf.shape(input_pianorolls)[1]
        sample_steps = tf.to_float(self.inputs["sample_steps"])
        if total_gibbs_steps is None:
            total_gibbs_steps = self.inputs["total_gibbs_steps"]
        temperature = self.inputs["temperature"]

        input_pianorolls = tf.to_float(input_pianorolls)
        outer_masks = self.make_outer_masks(outer_masks, input_pianorolls)

        # Calculate total_gibbs_steps as steps * num_instruments if not given.
        total_gibbs_steps = tf.cond(
            tf.equal(total_gibbs_steps, 0),
            lambda: tf.to_float(tt * self.hparams.num_instruments),
            lambda: tf.to_float(total_gibbs_steps))

        # sample_steps is set to total_gibbs_steps if not given.
        sample_steps = tf.cond(
            tf.equal(sample_steps, 0),
            lambda: total_gibbs_steps,
            lambda: tf.to_float(sample_steps))

        def infer_step(pianorolls, step_count):
            """Called by tf.while_loop, takes a Gibbs step."""
            mask_prob = compute_mask_prob_from_yao_schedule(
                step_count, total_gibbs_steps)
            # 1 indicates a masked-out position, 0 indicates not masked.
            masks = make_bernoulli_masks(tf.shape(pianorolls), mask_prob,
                                         outer_masks)

            logits = self.predict(pianorolls, masks)
            samples = sample_with_temperature(logits, temperature=temperature)

            outputs = pianorolls * (1 - masks) + samples * masks

            check_completion_op = tf.assert_equal(
                tf.where(tf.equal(tf.reduce_max(masks, axis=2), 1.),
                         tf.reduce_max(outputs, axis=2),
                         tf.reduce_max(pianorolls, axis=2)), 1.)
            with tf.control_dependencies([check_completion_op]):
                outputs = tf.identity(outputs)

            step_count += 1
            return outputs, step_count

        current_step = tf.to_float(self.inputs["current_step"])

        # Initializes pianorolls by evaluating the model once to fill in all gaps.
        logits = self.predict(tf.to_float(input_pianorolls), outer_masks)
        samples = sample_with_temperature(logits, temperature=temperature)
        tf.get_variable_scope().reuse_variables()

        self.samples, current_step = tf.while_loop(
            lambda samples, current_step: current_step < sample_steps,
            infer_step, [samples, current_step],
            shape_invariants=[
                tf.TensorShape([None, None, None, None]),
                tf.TensorShape(None),
            ],
            back_prop=False,
            parallel_iterations=1,
            name="coco_while")
        self.samples.set_shape(input_pianorolls.shape)
        return self.samples
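# A toy sketch of the Gibbs-style resampling loop above, assuming TensorFlow
# 1.x. Everything here is a placeholder rather than the Coconet implementation:
# a linearly decaying mask probability stands in for
# compute_mask_prob_from_yao_schedule, and `toy_predict` stands in for
# self.predict().
import tensorflow as tf

total_steps = tf.constant(32.)
rolls0 = tf.cast(tf.random_uniform([1, 16, 4]) > 0.5,
                 tf.float32)  # [batch, time, voices], binary

def toy_predict(rolls, masks):
    # Placeholder model: logits that ignore the input entirely.
    return tf.zeros_like(rolls)

def toy_gibbs_step(rolls, step):
    mask_prob = tf.maximum(0.05, 1. - step / total_steps)  # placeholder schedule
    masks = tf.cast(tf.random_uniform(tf.shape(rolls)) < mask_prob, tf.float32)
    logits = toy_predict(rolls, masks)
    samples = tf.cast(
        tf.random_uniform(tf.shape(rolls)) < tf.sigmoid(logits), tf.float32)
    # Keep unmasked positions, resample masked ones (same blend as above).
    return rolls * (1. - masks) + samples * masks, step + 1.

resampled, _ = tf.while_loop(
    lambda rolls, step: step < total_steps,
    toy_gibbs_step, (rolls0, tf.constant(0.)),
    back_prop=False, parallel_iterations=1)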
Example #24
    def infer(self, features, **kwargs):
        decode_length = (self.frame_height * self.frame_width *
                         self.num_channels)
        cache = {}
        decoding_stats = {}
        targets_old = features.get("targets", None)
        initial_output = tf.zeros((self.batch_size, decode_length),
                                  dtype=tf.int32)
        initial_logits = tf.zeros(
            (self.batch_size, decode_length, self.targets_vocab_size))
        # call body once to initialize cache with representations of input frames.
        features["targets"] = initial_output
        with tf.variable_scope("sparse_imagetransformer/body",
                               reuse=tf.AUTO_REUSE,
                               use_resource=True):
            self.body(features,
                      decode_step=None,
                      cache=cache,
                      decoding_stats=decoding_stats)

        def infer_step(i, recent_output, recent_logits, cache, decoding_stats):
            """Inference step."""
            features_copy = features.copy()
            features_copy["targets"] = recent_output
            cur_sample, cur_logit = self.sample(features_copy,
                                                decode_step=i,
                                                cache=cache,
                                                decoding_stats=decoding_stats)
            pos = i
            samples = recent_output + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_sample,
                shape=utils.shape_list(recent_output))
            logits = recent_logits + tf.scatter_nd(
                indices=[[b, pos] for b in range(self.batch_size)],
                updates=cur_logit,
                shape=utils.shape_list(recent_logits))
            return i + 1, samples, logits, cache, decoding_stats

        def while_exit_cond(i, result, logits, cache, decoding_stats):  # pylint: disable=unused-argument
            """Exit the loop if it reaches decode_length."""
            not_overflow = i < decode_length
            return not_overflow

        _, final_result, final_logits, _, decoding_stats = tf.while_loop(
            while_exit_cond,
            infer_step, [
                tf.constant(0), initial_output, initial_logits, cache,
                decoding_stats
            ],
            back_prop=False,
            parallel_iterations=1)

        original_shape = self.get_shape_for_decoder()

        blocks_per_dim = [
            s // q for s, q in zip(original_shape, self.hparams.query_shape)
        ]
        final_result_shape = utils.shape_list(final_result)
        final_result = tf.reshape(
            final_result,
            [final_result_shape[0], -1,
             np.prod(self.hparams.query_shape), 1])
        final_logits_shape = utils.shape_list(final_logits)
        final_logits = tf.reshape(final_logits, [
            final_logits_shape[0], -1,
            np.prod(self.hparams.query_shape), final_logits_shape[-1]
        ])
        final_result = utils.unflatten_blocks_nd(final_result, blocks_per_dim)
        final_result = utils.put_back_blocks_nd(final_result,
                                                self.hparams.query_shape)
        final_logits = utils.unflatten_blocks_nd(final_logits, blocks_per_dim)
        final_logits = utils.put_back_blocks_nd(final_logits,
                                                self.hparams.query_shape)

        final_result = tf.reshape(
            final_result,
            [-1, self.frame_height, self.frame_width, self.num_channels])
        final_logits = tf.reshape(final_logits, [
            -1, self.frame_height, self.frame_width, self.num_channels,
            self.targets_vocab_size
        ])

        if utils.is_xla_compiled():
            _IMGS["decodes"] = final_result

        for name, value in decoding_stats.items():
            tf.summary.scalar("decodes/%s" % name, value / decode_length)

        # Reassign targets back to the previous value.
        if targets_old is not None:
            features["targets"] = targets_old

        return {
            "outputs": final_result,
            "scores": None,
            "logits": final_logits,
            "losses": None,
        }
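# A small illustration of the tf.scatter_nd trick used in infer_step above,
# assuming TensorFlow 1.x: one freshly sampled token per batch entry is placed
# into an otherwise-zero update tensor, so summing it into the running buffer
# only touches the current decode position. Shapes are toy values.
import tensorflow as tf

batch_size, decode_length = 2, 6
recent_output = tf.zeros([batch_size, decode_length], tf.int32)
cur_sample = tf.constant([7, 9], tf.int32)  # one new token per batch entry
pos = 3                                     # decode step being filled in

update = tf.scatter_nd(
    indices=[[b, pos] for b in range(batch_size)],
    updates=cur_sample,
    shape=tf.shape(recent_output))
samples = recent_output + update
# samples[:, 3] is now [7, 9]; every other position remains 0.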
Example #25
def dynamic_decode(decoder,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
    """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object.

  Args:
    decoder: A `Decoder` instance.
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
    if not isinstance(decoder, Decoder):
        raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                        type(decoder))

    with tf.variable_scope(scope, "decoder") as varscope:
        # Determine context types.
        ctxt = tf.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
        is_xla = control_flow_util.GetContainingXLAContext(ctxt) is not None
        in_while_loop = (control_flow_util.GetContainingWhileContext(ctxt)
                         is not None)
        # Properly cache variable values inside the while_loop.
        # Don't set a caching device when running in a loop, since it is possible
        # that train steps could be wrapped in a tf.while_loop. In that scenario
        # caching prevents forward computations in loop iterations from re-reading
        # the updated weights.
        if not tf.executing_eagerly() and not in_while_loop:
            if varscope.caching_device is None:
                varscope.set_caching_device(lambda op: op.device)

        if maximum_iterations is not None:
            maximum_iterations = tf.convert_to_tensor(
                maximum_iterations, dtype=tf.int32, name="maximum_iterations")
            if maximum_iterations.get_shape().ndims != 0:
                raise ValueError("maximum_iterations must be a scalar")

        initial_finished, initial_inputs, initial_state = decoder.initialize()

        zero_outputs = _create_zero_outputs(decoder.output_size,
                                            decoder.output_dtype,
                                            decoder.batch_size)

        if is_xla and maximum_iterations is None:
            raise ValueError(
                "maximum_iterations is required for XLA compilation.")
        if maximum_iterations is not None:
            initial_finished = tf.logical_or(initial_finished,
                                             0 >= maximum_iterations)
        initial_sequence_lengths = tf.zeros_like(initial_finished,
                                                 dtype=tf.int32)
        initial_time = tf.constant(0, dtype=tf.int32)

        def _create_ta(s, d):
            return tf.zeros([maximum_iterations, decoder.batch_size, s],
                            dtype=d)

        initial_outputs_ta = contrib_framework.nest.map_structure(
            _create_ta, decoder.output_size, decoder.output_dtype)

        def condition(unused_time, unused_outputs_ta, unused_state,
                      unused_inputs, finished, unused_sequence_lengths):
            return True

        def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
            """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
            (next_outputs, decoder_state, next_inputs,
             decoder_finished) = decoder.step(time, inputs, state)
            if decoder.tracks_own_finished:
                next_finished = decoder_finished
            else:
                next_finished = tf.logical_or(decoder_finished, finished)
            next_sequence_lengths = tf.where(
                tf.logical_not(finished),
                tf.fill(tf.shape(sequence_lengths), time + 1),
                sequence_lengths)

            contrib_framework.nest.assert_same_structure(state, decoder_state)
            contrib_framework.nest.assert_same_structure(
                outputs_ta, next_outputs)
            contrib_framework.nest.assert_same_structure(inputs, next_inputs)

            # Zero out output values past finish
            if impute_finished:
                emit = contrib_framework.nest.map_structure(
                    lambda out, zero: tf.where(finished, zero, out),
                    next_outputs, zero_outputs)
            else:
                emit = next_outputs

            # Copy through states past finish
            def _maybe_copy_state(new, cur):
                # TensorArrays and scalar states get passed through.
                if isinstance(cur, tf.TensorArray):
                    pass_through = True
                else:
                    new.set_shape(cur.shape)
                    pass_through = (new.shape.ndims == 0)
                return new if pass_through else tf.where(finished, cur, new)

            if impute_finished:
                next_state = contrib_framework.nest.map_structure(
                    _maybe_copy_state, decoder_state, state)
            else:
                next_state = decoder_state

            outputs_ta = contrib_framework.nest.map_structure(
                lambda ta, out: inplace_ops.alias_inplace_update(
                    ta, time, out), outputs_ta, emit)
            return (time + 1, outputs_ta, next_state, next_inputs,
                    next_finished, next_sequence_lengths)

        res = tf.while_loop(condition,
                            body,
                            loop_vars=(
                                initial_time,
                                initial_outputs_ta,
                                initial_state,
                                initial_inputs,
                                initial_finished,
                                initial_sequence_lengths,
                            ),
                            parallel_iterations=parallel_iterations,
                            maximum_iterations=maximum_iterations,
                            swap_memory=swap_memory)

        final_outputs_ta = res[1]
        final_state = res[2]
        final_sequence_lengths = res[5]

        final_outputs = final_outputs_ta

        try:
            final_outputs, final_state = decoder.finalize(
                final_outputs, final_state, final_sequence_lengths)
        except NotImplementedError:
            pass

        pred_ids = tf.transpose(final_state.pred_ids, [2, 0, 1])

    return pred_ids
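# A minimal illustration of the finished / sequence_lengths bookkeeping inside
# body() above, assuming TensorFlow 1.x: lengths only advance for batch entries
# that had not finished before this step. The concrete values are made up.
import tensorflow as tf

time = tf.constant(4)
finished = tf.constant([False, True, False])         # state before this step
sequence_lengths = tf.constant([4, 2, 4])            # lengths recorded so far
decoder_finished = tf.constant([True, True, False])  # decoder's finish signal

next_finished = tf.logical_or(decoder_finished, finished)
next_sequence_lengths = tf.where(
    tf.logical_not(finished),
    tf.fill(tf.shape(sequence_lengths), time + 1),
    sequence_lengths)
# next_sequence_lengths evaluates to [5, 2, 5]: the already-finished entry
# keeps its recorded length, while unfinished entries advance to time + 1.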
Example #26
def beam_search(symbols_to_logits_fn,
                initial_ids,
                beam_size,
                decode_length,
                vocab_size,
                alpha,
                states=None,
                eos_id=EOS_ID,
                stop_early=True,
                use_tpu=False,
                use_top_k_with_unique=True):
  """Beam search with length penalties.

  Requires a function that can take the currently decoded symbols and return
  the logits for the next symbol. The implementation is inspired by
  https://arxiv.org/abs/1609.08144.

  When running, the beam search steps can be visualized by using tfdbg to watch
  the operations generating the output ids for each beam step.  These operations
  have the pattern:
    (alive|finished)_topk_(seq,scores)

  Operations marked `alive` represent the new beam sequences that will be
  processed in the next step.  Operations marked `finished` represent the
  completed beam sequences, which may be padded with 0s if no beams finished.

  Operations marked `seq` store the full beam sequence for the time step.
  Operations marked `scores` store the sequence's final log scores.

  The beam search steps will be processed sequentially in order, so when
  capturing observed tensors from these operations, clients can make
  assumptions about which step is being recorded.

  WARNING: Assumes the 2nd dimension of tensors in `states` is not invariant;
  this means that the shape of the 2nd dimension of these tensors will not be
  available (i.e. set to None) inside symbols_to_logits_fn.

  Args:
    symbols_to_logits_fn: Interface to the model, to provide logits.
        Should take [batch_size, decoded_ids] and return [batch_size, vocab_size]
    initial_ids: Ids to start off the decoding, this will be the first thing
        handed to symbols_to_logits_fn (after expanding to beam size)
        [batch_size]
    beam_size: Size of the beam.
    decode_length: Number of steps to decode for.
    vocab_size: Size of the vocab, must equal the size of the logits returned by
        symbols_to_logits_fn
    alpha: alpha for length penalty.
    states: dict (possibly nested) of decoding states.
    eos_id: ID for end of sentence.
    stop_early: a boolean - stop once best sequence is provably determined.
    use_tpu: A bool, whether to do beam search on TPU.
    use_top_k_with_unique: bool, whether to use a fast (but decreased precision)
      top_k during TPU beam search.

  Returns:
    Tuple of
    (decoded beams [batch_size, beam_size, decode_length]
     decoding probabilities [batch_size, beam_size])
  """
  batch_size = common_layers.shape_list(initial_ids)[0]

  # Assume initial_ids are prob 1.0
  initial_log_probs = tf.constant([[0.] + [-INF] * (beam_size - 1)])
  # Expand to beam_size (batch_size, beam_size)
  alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

  # Expand each batch and state to beam_size
  alive_seq = _expand_to_beam_size(initial_ids, beam_size)
  alive_seq = tf.expand_dims(alive_seq, axis=2)  # (batch_size, beam_size, 1)
  if use_tpu:
    alive_seq = tf.tile(alive_seq, [1, 1, decode_length + 1])
  if states:
    states = nest.map_structure(
        lambda state: _expand_to_beam_size(state, beam_size), states)
  else:
    states = {}

  # Finished will keep track of all the sequences that have finished so far
  # Finished log probs will be negative infinity in the beginning
  # finished_flags will keep track of booleans
  finished_seq = tf.zeros(common_layers.shape_list(alive_seq), tf.int32)
  # Setting the scores of the initial to negative infinity.
  finished_scores = tf.ones([batch_size, beam_size]) * -INF
  finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

  def grow_finished(finished_seq, finished_scores, finished_flags, curr_seq,
                    curr_scores, curr_finished):
    """Given sequences and scores, will gather the top k=beam size sequences.

    Args:
      finished_seq: Current finished sequences.
        [batch_size, beam_size, current_decoded_length]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_flags: finished bools for each of these sequences.
        [batch_size, beam_size]
      curr_seq: current topk sequence that has been grown by one position.
        [batch_size, beam_size, current_decoded_length]
      curr_scores: scores for each of these sequences. [batch_size, beam_size]
      curr_finished: Finished flags for each of these sequences.
        [batch_size, beam_size]
    Returns:
      Tuple of
        (Topk sequences based on scores,
         log probs of these sequences,
         Finished flags of these sequences)
    """
    if not use_tpu:
      # First append a column of 0-ids to finished_seq so that it has the same
      # length as curr_seq.
      finished_seq = tf.concat(
          [finished_seq,
           tf.zeros([batch_size, beam_size, 1], tf.int32)], axis=2)

    # Set the scores of the unfinished seq in curr_seq to large negative
    # values
    curr_scores += (1. - tf.to_float(curr_finished)) * -INF
    # concatenating the sequences and scores along beam axis
    curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
    curr_finished_scores = tf.concat([finished_scores, curr_scores], axis=1)
    curr_finished_flags = tf.concat([finished_flags, curr_finished], axis=1)
    return compute_topk_scores_and_seq(
        curr_finished_seq,
        curr_finished_scores,
        curr_finished_scores,
        curr_finished_flags,
        beam_size,
        batch_size,
        "grow_finished",
        use_tpu=use_tpu,
        use_top_k_with_unique=use_top_k_with_unique)

  def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished, states):
    """Given sequences and scores, will gather the top k=beam size sequences.

    Args:
      curr_seq: current topk sequence that has been grown by one position.
        [batch_size, beam_size, i+1]
      curr_scores: scores for each of these sequences. [batch_size, beam_size]
      curr_log_probs: log probs for each of these sequences.
        [batch_size, beam_size]
      curr_finished: Finished flags for each of these sequences.
        [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
        (Topk sequences based on scores,
         log probs of these sequences,
         Finished flags of these sequences)
    """
    # Set the scores of the finished seq in curr_seq to large negative
    # values
    curr_scores += tf.to_float(curr_finished) * -INF
    return compute_topk_scores_and_seq(curr_seq, curr_scores, curr_log_probs,
                                       curr_finished, beam_size, batch_size,
                                       "grow_alive", states, use_tpu=use_tpu)

  def grow_topk(i, alive_seq, alive_log_probs, states):
    r"""Inner beam search loop.

    This function takes the current alive sequences, and grows them to topk
    sequences where k = 2*beam. We use 2*beam because we could have beam_size
    number of sequences that might hit <EOS> and there would then be no alive
    sequences to continue. With 2*beam_size, this will not happen. This relies
    on the assumption that the vocab size is > beam size. If this is true, we'll
    have at least beam_size non-<EOS> extensions if we extract the next top
    2*beam words.
    The length penalty is ((5 + len(decode)) / 6) ^ alpha, and scores are
    divided by it. Please refer to https://arxiv.org/abs/1609.08144.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1]
      alive_log_probs: probabilities of these sequences. [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         dict of transformed decoding states)
    """
    # Get the logits for all the possible next symbols
    if use_tpu and states:
      flat_ids = tf.reshape(
          tf.slice(alive_seq, [0, 0, i], [batch_size, beam_size, 1]),
          [batch_size * beam_size, -1])
    else:
      flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    # (batch_size * beam_size, decoded_length)
    if states:
      flat_states = nest.map_structure(_merge_beam_dim, states)
      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
      states = nest.map_structure(
          lambda t: _unmerge_beam_dim(t, batch_size, beam_size), flat_states)
    elif use_tpu:
      flat_logits = symbols_to_logits_fn(flat_ids, i)
    else:
      flat_logits = symbols_to_logits_fn(flat_ids)

    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

    # Convert logits to normalized log probs
    candidate_log_probs = common_layers.log_prob_from_logits(logits)

    # Multiply the probabilities by the current probabilities of the beam.
    # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

    curr_scores = log_probs / length_penalty
    # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
    flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])

    if use_tpu and use_top_k_with_unique:
      topk_scores, topk_ids = top_k_with_unique(
          flat_curr_scores, k=beam_size * 2)
    else:
      topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

    # Recovering the log probs because we will need to send them back
    topk_log_probs = topk_scores * length_penalty

    # Work out what beam the top probs are in.
    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids

    if not use_tpu:
      # The next three steps are to create coordinates for tf.gather_nd to pull
      # out the correct sequences from id's that we need to grow.
      # We will also use the coordinates to gather the booleans of the beam
      # items that survived.
      batch_pos = compute_batch_indices(batch_size, beam_size * 2)

      # top beams will give us the actual coordinates to do the gather.
      # stacking will create a tensor of dimension batch * beam * 2, where the
      # last dimension contains the i,j gathering coordinates.
      topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

      # Gather up the most probable 2*beams both for the ids and
      # finished_in_alive bools
      topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
      if states:
        states = nest.map_structure(
            lambda state: tf.gather_nd(state, topk_coordinates), states)

      # Append the most probable alive
      topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
    else:
      # Gather up the most probable 2*beams both for the ids and
      # finished_in_alive bools
      topk_seq = fast_tpu_gather(alive_seq, topk_beam_index)

      if states:
        states = nest.map_structure(
            lambda state: fast_tpu_gather(state, topk_beam_index), states)

      # Update the most probable alive
      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
      topk_seq = inplace_ops.alias_inplace_update(topk_seq, i + 1, topk_ids)
      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states

  def inner_loop(i, alive_seq, alive_log_probs, finished_seq, finished_scores,
                 finished_flags, states):
    """Inner beam search loop.

    There are three groups of tensors, alive, finished, and topk.
    The alive group contains information about the current alive sequences
    The topk group contains information about alive + topk current decoded words
    the finished group contains information about finished sentences, that is,
    the ones that have decoded to <EOS>. These are what we return.
    The general beam search algorithm is as follows:
    While we haven't terminated (see the termination condition):
      1. Grow the current alive set to get beam*2 topk sequences.
      2. Among the topk, keep the top beam_size ones that haven't reached EOS
      in alive.
      3. Among the topk, keep the top beam_size ones that have reached EOS in
      finished.
    Repeat.
    To keep things simple with fixed-size tensors, we end up inserting
    unfinished sequences into finished at the beginning. To prevent that, we add
    -INF to the scores of the unfinished sequences, so that when a true finished
    sequence does appear, it will have a higher score than all the unfinished
    ones.

    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1]
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_seq: Current finished sequences.
        [batch_size, beam_size, i+1]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]
      finished_flags: finished bools for each of these sequences.
        [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.

    Returns:
      Tuple of
        (Incremented loop index
         New alive sequences,
         Log probs of the alive sequences,
         New finished sequences,
         Scores of the new finished sequences,
         Flags indicating which sequences in finished have reached EOS,
         dict of final decoding states)
    """

    # Each inner loop, we carry out three steps:
    # 1. Get the current topk items.
    # 2. Extract the ones that have finished and haven't finished
    # 3. Recompute the contents of finished based on scores.
    topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
        i, alive_seq, alive_log_probs, states)
    alive_seq, alive_log_probs, _, states = grow_alive(
        topk_seq, topk_scores, topk_log_probs, topk_finished, states)
    finished_seq, finished_scores, finished_flags, _ = grow_finished(
        finished_seq, finished_scores, finished_flags, topk_seq, topk_scores,
        topk_finished)

    return (i + 1, alive_seq, alive_log_probs, finished_seq, finished_scores,
            finished_flags, states)

  def _is_not_finished(i, unused_alive_seq, alive_log_probs,
                       unused_finished_seq, finished_scores,
                       unused_finished_in_finished, unused_states):
    """Checking termination condition.

    We terminate when we have decoded up to decode_length, or when the lowest
    scoring item in finished has a greater score than the highest-prob item in
    alive divided by the max length penalty.

    Args:
      i: loop index
      alive_log_probs: probabilities of the beams. [batch_size, beam_size]
      finished_scores: scores for each of these sequences.
        [batch_size, beam_size]

    Returns:
      Bool.
    """
    max_length_penalty = tf.pow(((5. + tf.to_float(decode_length)) / 6.), alpha)
    # The best possible score of the most likely alive sequence.
    lower_bound_alive_scores = alive_log_probs[:, 0] / max_length_penalty

    if not stop_early:
      # by considering the min score (in the top N beams) we ensure that
      # the decoder will keep decoding until there is at least one beam
      # (in the top N) that can be improved (w.r.t. the alive beams).
      # any unfinished beam will have score -INF - thus the min
      # will always be -INF if there is at least one unfinished beam -
      # which means the bound_is_met condition cannot be true in this case.
      lowest_score_of_finished_in_finished = tf.reduce_min(finished_scores)
    else:
      # by taking the max score we only care about the first beam;
      # as soon as this first beam cannot be beaten from the alive beams
      # the beam decoder can stop.
      # similarly to the above, if the top beam is not completed, its
      # finished_score is -INF, thus it will not activate the
      # bound_is_met condition. (i.e., decoder will keep going on).
      # note we need to find the max for every sequence separately - so, we need
      # to keep the batch dimension (see axis=1)
      lowest_score_of_finished_in_finished = tf.reduce_max(finished_scores,
                                                           axis=1)

    bound_is_met = tf.reduce_all(
        tf.greater(lowest_score_of_finished_in_finished,
                   lower_bound_alive_scores))

    return tf.logical_and(
        tf.less(i, decode_length), tf.logical_not(bound_is_met))

  inner_shape = tf.TensorShape([None, None, None])
  if use_tpu:
    inner_shape = tf.TensorShape([batch_size, beam_size, decode_length + 1])
  if use_tpu:
    state_struc = nest.map_structure(lambda state: state.get_shape(), states)
  else:
    state_struc = nest.map_structure(get_state_shape_invariants, states)
  (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
   finished_flags, states) = tf.while_loop(
       _is_not_finished,
       inner_loop, [
           tf.constant(0), alive_seq, alive_log_probs, finished_seq,
           finished_scores, finished_flags, states
       ],
       shape_invariants=[
           tf.TensorShape([]),
           inner_shape,
           alive_log_probs.get_shape(),
           inner_shape,
           finished_scores.get_shape(),
           finished_flags.get_shape(),
           state_struc
       ],
       parallel_iterations=1,
       back_prop=False)

  alive_seq.set_shape((None, beam_size, None))
  finished_seq.set_shape((None, beam_size, None))

  # Accounting for corner case: It's possible that no sequence in alive for a
  # particular batch item ever reached EOS. In that case, we should just copy
  # the contents of alive for that batch item. tf.reduce_any(finished_flags, 1)
  # if 0, means that no sequence for that batch index had reached EOS. We need
  # to do the same for the scores as well.
  finished_seq = tf.where(
      tf.reduce_any(finished_flags, 1), finished_seq, alive_seq)
  finished_scores = tf.where(
      tf.reduce_any(finished_flags, 1), finished_scores, alive_log_probs)
  return finished_seq, finished_scores, states
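# A short numeric sketch of the length normalization used in grow_topk and
# _is_not_finished above, assuming TensorFlow 1.x: the penalty is
# ((5 + length) / 6) ** alpha and scores are log probs divided by it. The
# concrete alpha, length and log prob values are made up.
import tensorflow as tf

alpha = 0.6
decoded_length = tf.constant(9.)   # plays the role of i + 1 in grow_topk
log_prob = tf.constant(-4.2)

length_penalty = tf.pow((5. + decoded_length) / 6., alpha)
score = log_prob / length_penalty
# Since log probs are negative, dividing by a penalty > 1 moves the score
# toward zero, so longer hypotheses are penalized less than raw log-prob
# ranking would penalize them.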
Example #27
def setup(act_fun):
    channel_num = 3
    if FLAGS.mnist_model:
        print("------------------Using MNIST model------------")
        model = MnistNet(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)
    elif FLAGS.large_model:
        print("------------------Using ResNet32Large model------------")
        model = ResNet32Large(
            num_channels=channel_num,
            num_filters=128,
            train=True,
            act_fun=act_fun)
    elif FLAGS.larger_model:
        print("------------------Using ResNet32Larger model------------")
        model = ResNet32Larger(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)
    elif FLAGS.wider_model:
        print("------------------Using ResNet32Wider model------------")
        model = ResNet32Wider(
            num_channels=channel_num,
            num_filters=192,
            act_fun=act_fun)
    else:
        print("------------------Using ResNet32 model------------")
        model = ResNet32(
            num_channels=channel_num,
            num_filters=128,
            act_fun=act_fun)

    batch_size = FLAGS.batch_size
    weights = [model.construct_weights('context_0')]

    Y = tf.placeholder(shape=(None), dtype=tf.int32)
    LABEL = None
    X_NOISE = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    X = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32)
    LABEL = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    LABEL_POS = tf.placeholder(shape=(None, 10), dtype=tf.float32)
    # Variables to run in training
    X_SPLIT = tf.split(X, FLAGS.num_gpus)
    X_NOISE_SPLIT = tf.split(X_NOISE, FLAGS.num_gpus)
    LABEL_SPLIT = tf.split(LABEL, FLAGS.num_gpus)
    LABEL_POS_SPLIT = tf.split(LABEL_POS, FLAGS.num_gpus)
    LABEL_SPLIT_INIT = list(LABEL_SPLIT)
    tower_grads = []
    tower_gen_grads = []
    x_mod_list = []

    optimizer = AdamOptimizer(FLAGS.lr, beta1=0.0, beta2=0.999)
    optimizer = hvd.DistributedOptimizer(optimizer)

    for j in range(FLAGS.num_gpus):
        if FLAGS.model_cclass:
            ind_batch_size = FLAGS.batch_size // FLAGS.num_gpus
            label_tensor = tf.Variable(
                tf.convert_to_tensor(
                    np.reshape(
                        np.tile(np.eye(10), (FLAGS.batch_size, 1, 1)),
                        (FLAGS.batch_size * 10, 10)),
                    dtype=tf.float32),
                trainable=False,
                dtype=tf.float32)
            x_split = tf.tile(
                tf.reshape(
                    X_SPLIT[j], (ind_batch_size, 1, 32, 32, 3)), (1, 10, 1, 1, 1))
            x_split = tf.reshape(x_split, (ind_batch_size * 10, 32, 32, 3))
            energy_pos = model.forward(
                x_split,
                weights[0],
                label=label_tensor,
                stop_at_grad=False)

            energy_pos_full = tf.reshape(energy_pos, (ind_batch_size, 10))
            energy_partition_est = tf.reduce_logsumexp(
                energy_pos_full, axis=1, keepdims=True)
            uniform = tf.random_uniform(tf.shape(energy_pos_full))
            label_tensor = tf.argmax(-energy_pos_full -
                                     tf.log(-tf.log(uniform)) - energy_partition_est, axis=1)
            label = tf.one_hot(label_tensor, 10, dtype=tf.float32)
            label = tf.Print(label, [label_tensor, energy_pos_full])
            LABEL_SPLIT[j] = label
            energy_pos = tf.concat(energy_pos, axis=0)
        else:
            energy_pos = [
                model.forward(
                    X_SPLIT[j],
                    weights[0],
                    label=LABEL_POS_SPLIT[j],
                    stop_at_grad=False)]
            energy_pos = tf.concat(energy_pos, axis=0)

        print("Building graph...")
        x_mod = x_orig = X_NOISE_SPLIT[j]

        x_grads = []

        energy_negs = []
        loss_energys = []

        energy_negs.extend([model.forward(tf.stop_gradient(
            x_mod), weights[0], label=LABEL_SPLIT[j], stop_at_grad=False, reuse=True)])
        eps_begin = tf.zeros(1)

        steps = tf.constant(0)
        c = lambda i, x: tf.less(i, FLAGS.num_steps)

        def langevin_step(counter, x_mod):
            x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
                                             mean=0.0,
                                             stddev=0.005 * FLAGS.rescale * FLAGS.noise_scale)

            energy_noise = energy_start = tf.concat(
                [model.forward(
                        x_mod,
                        weights[0],
                        label=LABEL_SPLIT[j],
                        reuse=True,
                        stop_at_grad=False,
                        stop_batch=True)],
                axis=0)

            x_grad, label_grad = tf.gradients(
                FLAGS.temperature * energy_noise, [x_mod, LABEL_SPLIT[j]])
            energy_noise_old = energy_noise

            lr = FLAGS.step_lr

            if FLAGS.proj_norm != 0.0:
                if FLAGS.proj_norm_type == 'l2':
                    x_grad = tf.clip_by_norm(x_grad, FLAGS.proj_norm)
                elif FLAGS.proj_norm_type == 'li':
                    x_grad = tf.clip_by_value(
                        x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
                else:
                    print("Other types of projection are not supported!!!")
                    assert False

            # Clip gradient norm for now
            if FLAGS.hmc:
                # Step size should be tuned to get around 65% acceptance
                def energy(x):
                    return FLAGS.temperature * \
                        model.forward(x, weights[0], label=LABEL_SPLIT[j], reuse=True)

                x_last = hmc(x_mod, 15., 10, energy)
            else:
                x_last = x_mod - (lr) * x_grad

            x_mod = x_last
            x_mod = tf.clip_by_value(x_mod, 0, FLAGS.rescale)

            counter = counter + 1

            return counter, x_mod

        steps, x_mod = tf.while_loop(c, langevin_step, (steps, x_mod))

        energy_eval = model.forward(x_mod, weights[0], label=LABEL_SPLIT[j],
                                    stop_at_grad=False, reuse=True)
        x_grad = tf.gradients(FLAGS.temperature * energy_eval, [x_mod])[0]
        x_grads.append(x_grad)

        energy_negs.append(
            model.forward(
                tf.stop_gradient(x_mod),
                weights[0],
                label=LABEL_SPLIT[j],
                stop_at_grad=False,
                reuse=True))

        test_x_mod = x_mod

        temp = FLAGS.temperature

        energy_neg = energy_negs[-1]
        x_off = tf.reduce_mean(
            tf.abs(x_mod[:tf.shape(X_SPLIT[j])[0]] - X_SPLIT[j]))

        loss_energy = model.forward(
            x_mod,
            weights[0],
            reuse=True,
            label=LABEL,
            stop_grad=True)

        print("Finished processing loop construction ...")

        target_vars = {}

        if FLAGS.cclass or FLAGS.model_cclass:
            label_sum = tf.reduce_sum(LABEL_SPLIT[0], axis=0)
            label_prob = label_sum / tf.reduce_sum(label_sum)
            label_ent = -tf.reduce_sum(label_prob *
                                       tf.math.log(label_prob + 1e-7))
        else:
            label_ent = tf.zeros(1)

        target_vars['label_ent'] = label_ent

        if FLAGS.train:

            if FLAGS.objective == 'logsumexp':
                pos_term = temp * energy_pos
                energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
                coeff = tf.stop_gradient(tf.exp(-temp * energy_neg_reduced))
                norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = coeff * (-1 * temp * energy_neg) / norm_constant
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'cd':
                pos_loss = tf.reduce_mean(temp * energy_pos)
                neg_loss = -tf.reduce_mean(temp * energy_neg)
                loss_ml = FLAGS.ml_coeff * (pos_loss + tf.reduce_sum(neg_loss))
            elif FLAGS.objective == 'softplus':
                loss_ml = FLAGS.ml_coeff * \
                    tf.nn.softplus(temp * (energy_pos - energy_neg))

            loss_total = tf.reduce_mean(loss_ml)

            if not FLAGS.zero_kl:
                loss_total = loss_total + tf.reduce_mean(loss_energy)

            loss_total = loss_total + \
                FLAGS.l2_coeff * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square((energy_neg))))

            print("Started gradient computation...")
            gvs = optimizer.compute_gradients(loss_total)
            gvs = [(k, v) for (k, v) in gvs if k is not None]

            print("Applying gradients...")

            tower_grads.append(gvs)

            print("Finished applying gradients.")

            target_vars['loss_ml'] = loss_ml
            target_vars['total_loss'] = loss_total
            target_vars['loss_energy'] = loss_energy
            target_vars['weights'] = weights
            target_vars['gvs'] = gvs

        target_vars['X'] = X
        target_vars['Y'] = Y
        target_vars['LABEL'] = LABEL
        target_vars['LABEL_POS'] = LABEL_POS
        target_vars['X_NOISE'] = X_NOISE
        target_vars['energy_pos'] = energy_pos
        target_vars['energy_start'] = energy_negs[0]

        if len(x_grads) >= 1:
            target_vars['x_grad'] = x_grads[-1]
            target_vars['x_grad_first'] = x_grads[0]
        else:
            target_vars['x_grad'] = tf.zeros(1)
            target_vars['x_grad_first'] = tf.zeros(1)

        target_vars['x_mod'] = x_mod
        target_vars['x_off'] = x_off
        target_vars['temp'] = temp
        target_vars['energy_neg'] = energy_neg
        target_vars['test_x_mod'] = test_x_mod
        target_vars['eps_begin'] = eps_begin

    if FLAGS.train:
        grads = average_gradients(tower_grads)
        train_op = optimizer.apply_gradients(grads)
        target_vars['train_op'] = train_op

    config = tf.ConfigProto()

    if hvd.size() > 1:
        config.gpu_options.visible_device_list = str(hvd.local_rank())

    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=30, keep_checkpoint_every_n_hours=6)

    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print("Model has a total of {} parameters".format(total_parameters))

    sess.run(tf.global_variables_initializer())

    resume_itr = 0

    if (FLAGS.resume_iter != -1 or not FLAGS.train) and hvd.rank() == 0:
        model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
        resume_itr = FLAGS.resume_iter
        # saver.restore(sess, model_file)
        optimistic_restore(sess, model_file)

    sess.run(hvd.broadcast_global_variables(0))
    return target_vars, saver, sess, resume_itr
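# A minimal sketch of the Langevin sampling loop that setup() builds with
# tf.while_loop, assuming TensorFlow 1.x. `toy_energy`, the step size, the
# noise scale and the clipping range are placeholders for model.forward and
# the FLAGS-controlled hyperparameters.
import tensorflow as tf

num_steps, step_lr, noise_scale = 20, 10.0, 0.005

def toy_energy(x):
    return tf.reduce_sum(tf.square(x))  # stand-in for model.forward(...)

def toy_langevin_step(counter, x):
    x = x + tf.random_normal(tf.shape(x), stddev=noise_scale)
    grad = tf.gradients(toy_energy(x), [x])[0]
    x = tf.clip_by_value(x - step_lr * grad, 0., 1.)
    return counter + 1, x

x_init = tf.random_uniform([4, 32, 32, 3])
_, x_mod_toy = tf.while_loop(
    lambda i, x: tf.less(i, num_steps),
    toy_langevin_step, (tf.constant(0), x_init))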
Example #28
def _spsa_gradients(loss_fn, x, delta=0.01, num_samples=16, num_iterations=4):
    """Compute gradient estimates using SPSA.

  Args:
    loss_fn: Callable that takes a single argument of shape [batch_size, ...]
      and returns the loss contribution of each element of the batch as a
      tensor of shape [batch_size].
    x: List of tensors with a single element. We only support computation of
      the gradient of the loss with respect to x[0]. We take a list as input to
      keep the same API call as tf.gradients.
    delta: The gradients are computed by computing the loss within x - delta and
      x + delta.
    num_samples: The total number of random samples used to compute the gradient
      is `num_samples` times `num_iterations`. `num_samples` contributes to the
      gradient by tiling `x` `num_samples` times.
    num_iterations: The total number of random samples used to compute the
      gradient is `num_samples` times `num_iterations`. `num_iterations`
      contributes to the gradient by iterating using a `tf.while_loop`.

  Returns:
    List of tensors with a single element corresponding to the gradient of
    loss_fn(x[0]) with respect to x[0].
  """

    if len(x) != 1:
        raise NotImplementedError('SPSA gradients with respect to multiple '
                                  'variables are not supported.')
    # loss_fn takes a single argument.
    tensor = x[0]

    def _get_delta(x):
        return delta * tf.sign(
            tf.random_uniform(
                tf.shape(x), minval=-1., maxval=1., dtype=x.dtype))

    # Accumulate the estimate over `num_iterations` rounds of `num_samples`
    # perturbation pairs each.
    def cond(i, *_):
        return tf.less(i, num_iterations)

    def loop_body(i, total_grad):
        """Compute gradient estimate."""
        batch_size = tf.shape(tensor)[0]
        # The tiled tensor has shape [num_samples, batch_size, ...]
        tiled_tensor = tf.expand_dims(tensor, axis=0)
        tiled_tensor = tf.tile(tiled_tensor,
                               [num_samples] + [1] * len(tensor.shape))
        delta = _get_delta(tiled_tensor)
        # After stacking the +delta and -delta copies, the tiled tensor has
        # shape [2, num_samples, batch_size, ...].
        tiled_tensor = tf.stack([tiled_tensor + delta, tiled_tensor - delta],
                                axis=0)
        # Flatten the leading dimensions into a single batch so that loss_fn
        # receives a [batch, ...] input, as its contract requires; the losses
        # are reshaped back to [2, num_samples, batch_size] below.
        losses = loss_fn(
            tf.reshape(tiled_tensor, [2 * num_samples * batch_size] +
                       tensor.shape.as_list()[1:]))
        losses = tf.reshape(losses, [2, num_samples, batch_size])

        # Compute approximate gradient using broadcasting.
        shape = losses.shape.as_list() + [1] * (len(tensor.shape) - 1)
        shape = [(s or -1) for s in shape]  # Replace unknown (None) dims with -1.
        losses = tf.reshape(losses, shape)
        g = tf.reduce_mean((losses[0] - losses[1]) / (2. * delta), axis=0)
        return [i + 1, g / num_iterations + total_grad]

    _, g = tf.while_loop(cond,
                         loop_body,
                         loop_vars=[tf.constant(0.),
                                    tf.zeros_like(tensor)],
                         parallel_iterations=1,
                         back_prop=False)
    return [g]
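The loop above implements the standard SPSA estimate: the gradient is approximated elementwise by (loss(x + delta) - loss(x - delta)) / (2 * delta), with delta a random +/-delta sign tensor, averaged over num_samples perturbation pairs and num_iterations loop steps. A minimal usage sketch with a hypothetical per-example loss (toy_loss_fn and the placeholder shape are illustrative only):

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[8, 32])   # [batch_size, features]

def toy_loss_fn(inputs):
    # Per-example loss of shape [batch]; it does not depend on a fixed batch
    # size, so the internal tiling in _spsa_gradients can reuse it unchanged.
    return tf.reduce_sum(tf.square(inputs - 0.5), axis=-1)

# Gradient estimate of toy_loss_fn w.r.t. x, without calling tf.gradients.
grad_estimate = _spsa_gradients(toy_loss_fn, [x], delta=0.01,
                                num_samples=16, num_iterations=4)[0]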
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    tf.logging.info("*** Features ***")
    for name in sorted(features.keys()):
      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))

    unique_ids = features["unique_ids"]
    input_ids = features["input_ids"]
    segment_ids = features["segment_ids"]

    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
    seq_length = modeling.get_shape_list(input_ids)[1]
    query_length = FLAGS.max_query_length
    batch_size = params["batch_size"]

    _, attention_mask = make_attention_mask(batch_size, query_length,
                                            seq_length)

    with tf.variable_scope("bert") as scope:
      word_logits = create_model(
          bert_config=bert_config,
          is_training=is_training,
          input_ids=input_ids,
          input_mask=attention_mask,
          segment_ids=segment_ids,
          use_one_hot_embeddings=use_one_hot_embeddings,
          scope=scope)

    if not is_training:
      with tf.variable_scope("bert", reuse=True) as scope:
        output_ids = input_ids
        word_id = tf.argmax(word_logits, axis=2, output_type=tf.int32)

        # This operation implements: output_ids[:, 2] = word_id[:, 0]
        word_id = tf.pad(word_id, [[0, 0], [2, seq_length - query_length]])
        output_ids = input_ids + word_id * tf.one_hot(
            2, seq_length, dtype=tf.int32)

        def body(i, ids):
          """A decoding step."""
          word_logits = create_model(
              bert_config=bert_config,
              is_training=is_training,
              input_ids=ids,
              input_mask=attention_mask,
              segment_ids=segment_ids,
              use_one_hot_embeddings=use_one_hot_embeddings,
              scope=scope)

          word_id = tf.argmax(word_logits, axis=2, output_type=tf.int32)

          # This operation implements: output_ids[:, 1 + i] = word_id[:, i - 1]
          word_id = tf.pad(word_id, [[0, 0], [2, seq_length - query_length]])
          return [
              i + 1,
              ids + word_id * tf.one_hot(i + 1, seq_length, dtype=tf.int32)
          ]

        i0 = tf.constant(2)
        c = lambda i, _: i < query_length - 1
        _, output_ids = tf.while_loop(c, body, loop_vars=[i0, output_ids])

    tvars = tf.trainable_variables()

    initialized_variable_names = {}
    scaffold_fn = None
    if init_checkpoint:
      (assignment_map, initialized_variable_names
      ) = modeling.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
      if use_tpu:

        def tpu_scaffold():
          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
          return tf.train.Scaffold()

        scaffold_fn = tpu_scaffold
      else:
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    tf.logging.info("**** Trainable Variables ****")
    for var in tvars:
      init_string = ""
      if var.name in initialized_variable_names:
        init_string = ", *INIT_FROM_CKPT*"
      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
                      init_string)

    output_spec = None
    if mode == tf.estimator.ModeKeys.TRAIN:
      # Computes the loss for word prediction.
      loss = tf.losses.sparse_softmax_cross_entropy(
          input_ids[:, 2:query_length],
          word_logits,
          reduction=tf.losses.Reduction.MEAN)

      train_op = optimization.create_optimizer(loss, learning_rate,
                                               num_train_steps,
                                               num_warmup_steps, use_tpu)

      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn)

    elif mode == tf.estimator.ModeKeys.PREDICT:
      predictions = {
          "unique_ids": tf.identity(unique_ids),
          "input_ids": output_ids,
          "segment_ids": tf.minimum(segment_ids, 1),
          "input_mask": tf.to_int32(tf.not_equal(output_ids, 0)),
          "start_positions": tf.identity(features["start_positions"]),
          "end_positions": tf.identity(features["end_positions"]),
          "answer_types": tf.identity(features["answer_types"])
      }
      output_spec = tf.estimator.tpu.TPUEstimatorSpec(
          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
    else:
      raise ValueError("Only TRAIN and PREDICT modes are supported: %s" %
                       (mode))

    return output_spec
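The "output_ids[:, i] = word_id[:, j]" comments above describe a scatter implemented with tf.pad plus a tf.one_hot mask, a TPU-friendly way to overwrite a single column of an int tensor without scatter ops. A standalone sketch of the same trick with toy shapes (illustrative values only, not the model's real tensors):

import tensorflow as tf

batch_size, seq_length = 2, 6
ids = tf.zeros([batch_size, seq_length], dtype=tf.int32)
word_id = tf.constant([[7, 8], [9, 10]], dtype=tf.int32)   # [batch, 2]

# Emulate ids[:, 3] = word_id[:, 1]: shift word_id right by 2 so its column 1
# lands at position 3, then keep only position 3 via a one-hot mask.
shifted = tf.pad(word_id, [[0, 0], [2, seq_length - 2 - 2]])
updated = ids + shifted * tf.one_hot(3, seq_length, dtype=tf.int32)

with tf.Session() as sess:
    print(sess.run(updated))   # column 3 now holds [8, 10]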
Example #30
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / \
                tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
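A minimal usage sketch of sample_sequence, assuming the rest of the openai/gpt-2 code it relies on (the `model` module with default_hparams, a downloaded checkpoint directory, and the checkpoint's hparams.json); the paths and token id below are assumptions for illustration:

import json
import os
import tensorflow as tf

batch_size = 1
checkpoint_dir = 'models/124M'   # hypothetical path to a downloaded GPT-2 checkpoint

# `model` is the gpt-2 module used by `step` above; its default hparams are
# overridden with the values stored alongside the checkpoint.
hparams = model.default_hparams()
with open(os.path.join(checkpoint_dir, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

context = tf.placeholder(tf.int32, [batch_size, None])
output = sample_sequence(hparams=hparams, length=40, context=context,
                         batch_size=batch_size, temperature=0.8, top_k=40)

with tf.Session() as sess:
    tf.train.Saver().restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
    # Tokens would normally come from the repository's BPE encoder; 50256 is
    # GPT-2's <|endoftext|> id and serves as a stand-in prompt here.
    samples = sess.run(output, feed_dict={context: [[50256]]})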