def _encoder_preprocessor(self, velocity_sequence, n_nodes, n_conn,
                              node_locations, node_connections):
        """Assembles the input `GraphsTuple` from velocities and connectivity.

        Node features are the velocity sequence with axes 1 and 2 merged.
        Edge features are the sender-minus-receiver location difference
        followed by its norm over the last axis.

        Args:
          velocity_sequence: tensor whose axes 1 and 2 are merged into a single
            per-node feature axis.
          n_nodes: forwarded to the connectivity helper and stored as `n_node`.
          n_conn: forwarded to the connectivity helper.
          node_locations: tensor gathered by sender/receiver indices to build
            the edge features.
          node_connections: forwarded to the connectivity helper.

        Returns:
          A `gn.graphs.GraphsTuple` with `globals=None`.
        """
        connectivity = connectivity_utils.get_connectivity_for_batch_pyfunc(
            node_locations, node_connections, n_nodes, n_conn)
        senders, receivers, n_edge = connectivity

        # Node features: one flattened velocity history per node.
        merge_axes = snt.MergeDims(start=1, size=2)
        node_features = [merge_axes(velocity_sequence)]

        # Edge features: displacement between the endpoints and its length.
        sender_locations = tf.gather(node_locations, senders)
        receiver_locations = tf.gather(node_locations, receivers)
        displacements = sender_locations - receiver_locations
        distances = tf.norm(displacements, axis=-1, keepdims=True)
        edge_features = [displacements, distances]

        return gn.graphs.GraphsTuple(
            nodes=tf.concat(node_features, axis=-1),
            edges=tf.concat(edge_features, axis=-1),
            globals=None,  # No global context is supplied by this variant.
            n_node=n_nodes,
            n_edge=n_edge,
            senders=senders,
            receivers=receivers,
        )
Exemplo n.º 2
0
    def sample(self, sample_shape=(), y=None, mean=False):
        """Draws a sample from the learnt distribution p(x).

        Args:
          sample_shape: forwarded to `self.compute_prior().sample(...)`; only
            used when `y` is None.
          y: Optional label to condition on. When None, it is sampled from the
            prior and cast to float. If it has more than two dimensions, all
            leading dimensions are merged so that it becomes rank 2.
          mean: Boolean, if True the mean of the output distribution is
            returned, otherwise a sample from it.

        Returns:
          The mean or a sample of `self.predict(z, y)`, where `z` is sampled
          from the latent decoder conditioned on `y`. Presumably of shape
          `[B * N, ...]` as documented previously -- TODO confirm.
        """
        scope = '{}_sample'.format(self.scope_name)
        with tf.name_scope(scope):
            if y is None:
                prior = self.compute_prior()
                y = tf.to_float(prior.sample(sample_shape))

            rank = y.shape.ndims
            if rank > 2:
                # Collapse every leading axis into one batch axis.
                merge_leading = snt.MergeDims(start=0,
                                              size=rank - 1,
                                              name='merge_y')
                y = merge_leading(y)

            z = self._latent_decoder(y, is_training=self._is_training)
            output_dist = self.predict(z.sample(), y)
            samples = output_dist.mean() if mean else output_dist.sample()
        return samples
Exemplo n.º 3
0
    def _build(self, h, x, presence=None):
        """Builds the module.

        Args:
          h: Tensor of encodings of shape [B, n_enc_dims].
          x: Tensor of inputs of shape [B, n_points, n_input_dims].
          presence: Optional tensor indicating which input points exist. It is
            expanded on its last axis and multiplied into the per-point
            capsule one-hot, so presumably it is [B, n_points] here -- TODO
            confirm (the previous docstring stated [B, n_points, 1]).

        Returns:
          The capsule layer output, updated with the likelihood results and
          with `mixing_kl`, `sparsity_loss`, `caps_presence_prob` and
          `mean_scale`.
        """
        batch_size, n_input_points, _ = x.shape.as_list()

        capsule_layer = _capsule.CapsuleLayer(self._n_caps,
                                              self._n_caps_dims,
                                              self._n_votes,
                                              **self._capsule_kwargs)
        res = capsule_layer(h)

        # Keep the raw vote under `transform`; `vote` becomes its transformed
        # version.
        res.transform = res.vote
        res.vote = math_ops.apply_transform(transform=res.vote)

        # Merge axes 1 and 2 of every non-scalar entry.
        for key, value in res.items():
            if value.shape.ndims > 0:
                res[key] = snt.MergeDims(1, 2)(value)

        likelihood = _capsule.OrderInvariantCapsuleLikelihood(
            self._n_votes, res.vote, res.scale, res.vote_presence)
        ll_res = likelihood(x, presence)
        res.update(ll_res._asdict())

        # KL-style term between the mixing distribution and a uniform prior
        # over the input points.
        mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
        uniform_log_prob = tf.log(1. / n_input_points)
        kl_terms = mixing_probs * (ll_res.mixing_log_prob - uniform_log_prob)
        mixing_kl = tf.reduce_mean(tf.reduce_sum(kl_terms, -1))

        # Count, per capsule, how many (present) points it won.
        wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)
        if presence is not None:
            wins_per_caps *= tf.expand_dims(presence, -1)
        wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

        has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
        should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

        # Cross-entropy on presence logits, counted only for capsules that won
        # at least one point.
        xent = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=should_be_active, logits=res.pres_logit_per_caps)
        sparsity_loss = tf.reduce_mean(tf.reduce_sum(xent * has_any_wins, -1))

        vote_presence = tf.reshape(
            res.vote_presence, [batch_size, self._n_caps, self._n_votes])
        caps_presence_prob = tf.reduce_max(vote_presence, 2)

        res.update({
            'mixing_kl': mixing_kl,
            'sparsity_loss': sparsity_loss,
            'caps_presence_prob': caps_presence_prob,
            'mean_scale': tf.reduce_mean(res.scale),
        })
        return res
Exemplo n.º 4
0
  def _build_layers_v2(self, input_dict, num_outputs, options):
    """Builds an MLP policy that also sees the queue of delayed actions.

    The recurrent state holds the `delay - 1` most recent past actions. The
    previous action is appended to it, the oldest entry is dropped to form the
    next state, and the whole queue is one-hot embedded, flattened and
    concatenated with the observation before being fed to the MLP trunk.

    Args:
      input_dict: dict providing "obs" and "prev_actions" tensors.
      num_outputs: number of logits to produce; also the one-hot depth used to
        embed the delayed actions.
      options: model options; reads `options["custom_options"]["delay"]`,
        `options["fcnet_hiddens"]` and `options["fcnet_activation"]`.

    Returns:
      A `(logits, trunk_outputs)` tuple.

    Raises:
      AssertionError: if `delay` is not positive.
    """
    delay = options["custom_options"]["delay"]
    assert delay > 0, "`delay` must be positive."

    self.state_init = np.zeros([delay-1], np.int64)

    if not self.state_in:
      self.state_in = tf.placeholder(
          tf.int64, [None, delay-1], name="delayed_actions")

    # Append the newest action to the queue of pending actions.
    delayed_actions = tf.concat([
        self.state_in,
        tf.expand_dims(input_dict["prev_actions"], 1)
    ], axis=1)

    # The next state drops the oldest action in the queue.
    self.state_out = delayed_actions[:, 1:]

    embedded_delayed_actions = tf.one_hot(delayed_actions, num_outputs)
    embedded_delayed_actions = snt.MergeDims(1, 2)(embedded_delayed_actions)

    trunk = snt.nets.MLP(
        output_sizes=options["fcnet_hiddens"],
        activation=getattr(tf.nn, options["fcnet_activation"]),
        activate_final=True)

    # Feed the observation together with the embedded delayed actions. The
    # trunk was previously given only the observation, which left the
    # concatenated tensor unused and made the policy blind to pending actions.
    inputs = tf.concat([input_dict["obs"], embedded_delayed_actions], 1)
    trunk_outputs = trunk(inputs)

    logits = snt.Linear(num_outputs)(trunk_outputs)

    return logits, trunk_outputs
Exemplo n.º 5
0
  def _build(self, x):
    """Encodes `x` into per-capsule pose, feature and presence.

    The image embedding from `self._encoder` is mapped to
    `[batch_size, n_caps, n_dims]` by a head chosen via `self._encoder_type`
    ('linear', 'conv' or 'conv_att'), then split along the last axis into
    pose, feature and presence-logit parts.

    Args:
      x: input tensor; its leading dimension is used as the batch size.

    Returns:
      `self.OutputTuple(pose, feature, pres, pres_logit, img_embedding)`, where
      `feature` is None when `self._n_features` is 0.

    Raises:
      ValueError: if `self._encoder_type` is not one of the supported values.
    """
    batch_size = x.shape[0]
    img_embedding = self._encoder(x)

    # Sizes of the pose, feature and presence (always 1) parts per capsule.
    splits = [self._n_caps_dims, self._n_features, 1]
    n_dims = sum(splits)
    n_caps = self._n_caps
    encoder_type = self._encoder_type

    if encoder_type == 'linear':
      flat_embedding = snt.BatchFlatten()(img_embedding)
      h = snt.Linear(n_caps * n_dims)(flat_embedding)

    else:
      h = snt.AddBias(bias_dims=[1, 2, 3])(img_embedding)

      if encoder_type == 'conv':
        h = snt.Conv2D(n_dims * n_caps, 1, 1)(h)
        h = tf.reduce_mean(h, (1, 2))
        h = tf.reshape(h, [batch_size, n_caps, n_dims])

      elif encoder_type == 'conv_att':
        n_caps_params = n_dims * n_caps
        # The extra `n_caps` channels are attention logits.
        h = snt.Conv2D(n_caps_params + n_caps, 1, 1)(h)
        h = snt.MergeDims(1, 2)(h)
        h, attention = tf.split(h, [n_caps_params, n_caps], -1)

        h = tf.reshape(h, [batch_size, -1, n_dims, n_caps])
        attention = tf.nn.softmax(attention, 1)
        attention = tf.reshape(attention, [batch_size, -1, 1, n_caps])
        # Attention-weighted sum over the merged axis.
        h = tf.reduce_sum(h * attention, 1)

      else:
        raise ValueError('Invalid encoder type="{}".'.format(
            self._encoder_type))

    h = tf.reshape(h, [batch_size, n_caps, n_dims])

    pose, feature, pres_logit = tf.split(h, splits, -1)
    if self._n_features == 0:
      feature = None

    pres_logit = tf.squeeze(pres_logit, -1)
    if self._noise_scale > 0.:
      # Uniform noise centred on zero, scaled by `self._noise_scale`.
      centred_noise = tf.random.uniform(pres_logit.shape) - .5
      pres_logit += centred_noise * self._noise_scale

    pres = tf.nn.sigmoid(pres_logit)
    pose = math_ops.geometric_transform(pose, self._similarity_transform)
    return self.OutputTuple(pose, feature, pres, pres_logit, img_embedding)
Exemplo n.º 6
0
    def _build(self, data):
        """Builds the model graph and collects reconstructions and losses.

        The input image is encoded into primary capsules, which are passed
        through `self._encoder` and `self._decoder`. Several reconstructions
        are then produced with `self._primary_decoder`, and reconstruction,
        sparsity, classification-probe and weight-decay terms are attached to
        the result.

        Args:
          data: passed to `self._img` and `self._label`; must also support
            `data.get('labeled', None)` when a label is available.

        Returns:
          The decoder result `res` with additional fields set on it (e.g.
          `bottom_up_rec`, `top_down_rec`, `rec_ll`, `mse`, the sparsity
          losses and `weight_decay_loss`). The classification fields
          (`*_cls_xe`, `*_cls_acc`, `best_cls_acc`) are only set when
          `self._label(data)` is not None.

        Raises:
          ValueError: if `self._vote_type` or `self._pres_type` is not one of
            'enc', 'soft' or 'hard'.
        """

        input_x = self._img(data, False)
        target_x = self._img(data, prep=self._prep)
        batch_size = int(input_x.shape[0])

        primary_caps = self._primary_encoder(input_x)
        pres = primary_caps.presence

        expanded_pres = tf.expand_dims(pres, -1)
        pose = primary_caps.pose
        input_pose = tf.concat([pose, 1. - expanded_pres], -1)

        input_pres = pres
        if self._stop_grad_caps_inpt:
            input_pose = tf.stop_gradient(input_pose)
            input_pres = tf.stop_gradient(pres)

        target_pose, target_pres = pose, pres
        if self._stop_grad_caps_target:
            target_pose = tf.stop_gradient(target_pose)
            target_pres = tf.stop_gradient(target_pres)

        # skip connection from the img to the higher level capsule
        if primary_caps.feature is not None:
            input_pose = tf.concat([input_pose, primary_caps.feature], -1)

        # try to feed presence as a separate input
        # and if that works, concatenate templates to poses
        # this is necessary for set transformer
        n_templates = int(primary_caps.pose.shape[1])
        templates = self._primary_decoder.make_templates(
            n_templates, primary_caps.feature)

        # NOTE(review): the TypeError fallback is meant for encoders that do
        # not accept a presence argument, but it also catches any TypeError
        # raised inside this block.
        try:
            if self._feed_templates:
                inpt_templates = templates
                if self._stop_grad_caps_inpt:
                    inpt_templates = tf.stop_gradient(inpt_templates)

                # Templates with a leading dimension of 1 are tiled across
                # the batch.
                if inpt_templates.shape[0] == 1:
                    inpt_templates = snt.TileByDim(
                        [0], [batch_size])(inpt_templates)
                inpt_templates = snt.BatchFlatten(2)(inpt_templates)
                pose_with_templates = tf.concat([input_pose, inpt_templates],
                                                -1)
            else:
                pose_with_templates = input_pose

            h = self._encoder(pose_with_templates, input_pres)

        except TypeError:
            h = self._encoder(input_pose)

        res = self._decoder(h, target_pose, target_pres)
        res.primary_presence = primary_caps.presence

        # Select which pose is fed to the primary decoder for `rec`.
        if self._vote_type == 'enc':
            primary_dec_vote = primary_caps.pose
        elif self._vote_type == 'soft':
            primary_dec_vote = res.soft_winner
        elif self._vote_type == 'hard':
            primary_dec_vote = res.winner
        else:
            raise ValueError('Invalid vote_type="{}"".'.format(
                self._vote_type))

        # Select which presence is fed to the primary decoder for `rec`.
        if self._pres_type == 'enc':
            primary_dec_pres = pres
        elif self._pres_type == 'soft':
            primary_dec_pres = res.soft_winner_pres
        elif self._pres_type == 'hard':
            primary_dec_pres = res.winner_pres
        else:
            raise ValueError('Invalid pres_type="{}"".'.format(
                self._pres_type))

        res.bottom_up_rec = self._primary_decoder(
            primary_caps.pose,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        res.top_down_rec = self._primary_decoder(
            res.winner,
            primary_caps.presence,
            template_feature=primary_caps.feature,
            img_embedding=primary_caps.img_embedding)

        rec = self._primary_decoder(primary_dec_vote,
                                    primary_dec_pres,
                                    template_feature=primary_caps.feature,
                                    img_embedding=primary_caps.img_embedding)

        # Tile the primary-capsule tensors along axis 0 so they line up with
        # the votes after their first two axes are merged below.
        tile = snt.TileByDim([0], [res.vote.shape[1]])
        tiled_presence = tile(primary_caps.presence)

        tiled_feature = primary_caps.feature
        if tiled_feature is not None:
            tiled_feature = tile(tiled_feature)

        tiled_img_embedding = tile(primary_caps.img_embedding)

        res.top_down_per_caps_rec = self._primary_decoder(
            snt.MergeDims(0, 2)(res.vote),
            snt.MergeDims(0, 2)(res.vote_presence) * tiled_presence,
            template_feature=tiled_feature,
            img_embedding=tiled_img_embedding)

        res.templates = templates
        res.template_pres = pres
        res.used_templates = rec.transformed_templates

        res.rec_mode = rec.pdf.mode()
        res.rec_mean = rec.pdf.mean()

        res.mse_per_pixel = tf.square(target_x - res.rec_mode)
        res.mse = math_ops.flat_reduce(res.mse_per_pixel)

        res.rec_ll_per_pixel = rec.pdf.log_prob(target_x)
        res.rec_ll = math_ops.flat_reduce(res.rec_ll_per_pixel)

        n_points = int(res.posterior_mixing_probs.shape[1])
        mass_explained_by_capsule = tf.reduce_sum(res.posterior_mixing_probs,
                                                  1)

        (res.posterior_within_sparsity_loss,
         res.posterior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._posterior_sparsity_loss_type,
             mass_explained_by_capsule / n_points,
             num_classes=self._n_classes)

        (res.prior_within_sparsity_loss,
         res.prior_between_sparsity_loss) = _capsule.sparsity_loss(
             self._prior_sparsity_loss_type,
             res.caps_presence_prob,
             num_classes=self._n_classes,
             within_example_constant=self._prior_within_example_constant)

        label = self._label(data)
        if label is not None:
            res.posterior_cls_xe, res.posterior_cls_acc = probe.classification_probe(
                mass_explained_by_capsule,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))
            res.prior_cls_xe, res.prior_cls_acc = probe.classification_probe(
                res.caps_presence_prob,
                label,
                self._n_classes,
                labeled=data.get('labeled', None))

            # Both accuracies are only assigned when a label exists, so the
            # maximum must be taken under the same guard; otherwise unlabeled
            # data fails on the missing fields.
            res.best_cls_acc = tf.maximum(res.prior_cls_acc,
                                          res.posterior_cls_acc)

        res.primary_caps_l1 = math_ops.flat_reduce(res.primary_presence)

        if self._weight_decay > 0.0:
            # L2 penalty over trainable variables whose name marks them as
            # weights ('w:' or 'weights:').
            decay_losses_list = []
            for var in tf.trainable_variables():
                if 'w:' in var.name or 'weights:' in var.name:
                    decay_losses_list.append(tf.nn.l2_loss(var))
            res.weight_decay_loss = tf.reduce_sum(decay_losses_list)
        else:
            res.weight_decay_loss = 0.0

        return res
Exemplo n.º 7
0
def render_constellations(pred_points,
                          capsule_num,
                          canvas_size,
                          gt_points=None,
                          n_caps=2,
                          gt_presence=None,
                          pred_presence=None,
                          caps_presence_prob=None):
    """Renders predicted and ground-truth points as gaussian blobs.

    Args:
      pred_points: [B, m, 2].
      capsule_num: [B, m] tensor indicating which capsule the corresponding
        point comes from. Points from different capsules are plotted with
        different colors. Currently supported values: {0, 1, ..., 11}.
      canvas_size: tuple of ints.
      gt_points: [B, k, 2]; plots ground-truth points if present.
      n_caps: integer, number of capsules.
      gt_presence: [B, k] binary tensor.
      pred_presence: [B, m] binary tensor.
      caps_presence_prob: [B, m], a tensor of presence probabilities for caps.

    Returns:
      [B, *canvas_size] tensor with plotted points, clipped to [0, 1]. When
      `gt_points` is given, the ground-truth rendering, a color separator and
      the prediction rendering are concatenated along axis 1.
    """

    # Convert coords to be in [0, side_length].
    pred_points = denormalize_coords(pred_points, canvas_size, rounded=True)

    # Pick one color per predicted point via a one-hot over capsules.
    batch_size, n_points = pred_points.shape[:2].as_list()
    caps_one_hot = tf.to_float(tf.one_hot(capsule_num, depth=n_caps))
    caps_one_hot = tf.reshape(caps_one_hot,
                              [batch_size, n_points, 1, 1, n_caps, 1])

    palette = tf.convert_to_tensor(_COLORS[:n_caps])
    palette = tf.reshape(palette, [1, 1, 1, 1, n_caps, 3])
    point_color = tf.reduce_sum(palette * caps_one_hot, -2)
    point_color = tf.squeeze(tf.squeeze(point_color, 3), 2)

    colored = render_by_scatter(canvas_size, pred_points, point_color,
                                pred_presence)

    # Build the separator between ground-truth and predicted points. It is
    # composed of all supported colors and also serves as a legend.
    # [b, h, w, 3]
    n_colors = _COLORS.shape[0]
    canvas_width = int(colored.shape[2])
    n_tiles = canvas_width // n_colors

    sep = tf.reshape(tf.convert_to_tensor(_COLORS), [1, 1, n_colors, 3])
    sep = snt.TileByDim([0, 1, 3], [batch_size, 3, n_tiles])(sep)
    sep = tf.reshape(sep, [batch_size, 3, n_tiles * n_colors, 3])

    # Width not covered by the tiled colors, split between both sides.
    pad, r = divmod(canvas_width - n_colors * n_tiles, 2)

    if caps_presence_prob is not None:
        # Pad the probabilities up to the number of colors and broadcast them
        # into rows placed under the legend.
        n_probs = int(caps_presence_prob.shape[1])
        caps_presence_prob = tf.pad(caps_presence_prob,
                                    ([0, 0], [0, n_colors - n_probs]))
        caps_presence_prob = tf.reshape(caps_presence_prob,
                                        [batch_size, 1, n_colors, 1, 1])

        zeros = tf.zeros([batch_size, 3, n_colors, n_tiles, 3],
                         dtype=tf.float32)
        prob_vals = snt.MergeDims(2, 2)(caps_presence_prob + zeros)

        white_row = tf.ones_like(sep[:, :1])
        sep = tf.concat([sep, white_row, prob_vals], 1)

    sep = tf.pad(sep, [(0, 0), (1, 1), (pad, pad + r), (0, 0)],
                 constant_values=1.)

    # Render gt points.
    if gt_points is not None:
        gt_points = denormalize_coords(gt_points, canvas_size, rounded=True)
        gt_rendered = render_by_scatter(canvas_size,
                                        gt_points,
                                        colors=None,
                                        gt_presence=gt_presence)

        # Show ground truth wherever the prediction rendering is zero.
        has_prediction = tf.cast(colored, bool)
        colored = tf.where(has_prediction, colored, gt_rendered)
        colored = tf.concat([gt_rendered, sep, colored], 1)

    return tf.clip_by_value(colored, 0., 1.)
Exemplo n.º 8
0
    def _init_helper(self,
                     observation_space,
                     action_space,
                     config,
                     existing_inputs=None):
        """Builds placeholders, model, losses and initializes the policy graph.

        Creates the input placeholders, instantiates `HumanActionModel`, builds
        the action distribution (autoregressive or catalog-provided), computes
        an imitation loss plus a TD(lambda) value loss, and then initializes
        `LearningRateSchedule` and `TFPolicyGraph` with the resulting tensors.

        Args:
          observation_space: observation space; its shape is used for the
            observation placeholder when `config["imitation"]` is falsy.
          action_space: a `gym.spaces.Discrete` or `MultiDiscrete` space.
          config: overrides applied on top of `impala.impala.DEFAULT_CONFIG`.
          existing_inputs: not used by this method.

        Raises:
          AssertionError: if `batch_mode` is not "truncate_episodes".
          UnsupportedSpaceException: if `action_space` is neither Discrete nor
            MultiDiscrete.
        """
        print(get_available_gpus())
        config = dict(impala.impala.DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
          "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config

        self.sess = tf.get_default_session()
        self.grads = None

        imitation = config["imitation"]

        if imitation:
            # Time-major [T, B] batches.
            T = config["sample_batch_size"]
            B = config["train_batch_size"] // T
            batch_shape = (T, B)
        else:
            batch_shape = (None, )

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            actions_shape = batch_shape
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            actions_shape = batch_shape + (len(action_space.nvec), )
            output_hidden_shape = action_space.nvec.astype(np.int32)
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))

        if imitation:
            make_action_ph = lambda: ssbm_actions.make_ph(
                ssbm_actions.flat_repeated_config, batch_shape)
            actions = make_action_ph()
            prev_actions = make_action_ph()
        else:
            actions = tf.placeholder(tf.int64, actions_shape, name="actions")
            prev_actions = tf.placeholder(tf.int64,
                                          actions_shape,
                                          name="prev_actions")

        # Create input placeholders
        dones = tf.placeholder(tf.bool, batch_shape, name="dones")
        rewards = tf.placeholder(tf.float32, batch_shape, name="rewards")
        if imitation:
            observations = ssbm_spaces.slippi_conv_list[0].make_ph(batch_shape)
        else:
            observations = tf.placeholder(tf.float32, [None] +
                                          list(observation_space.shape))

        existing_state_in = None
        existing_seq_lens = None

        # Setup the policy
        autoregressive = config.get("autoregressive")
        if autoregressive:
            logit_dim = 128  # not really logits
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        prev_rewards = tf.placeholder(tf.float32,
                                      batch_shape,
                                      name="prev_reward")
        self.model = HumanActionModel(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            imitation=imitation,
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)

        if autoregressive:
            action_dist = ssbm_actions.AutoRegressive(
                nest.map_structure(lambda conv: conv.build_dist(),
                                   ssbm_actions.flat_repeated_config),
                residual=config.get("residual"))
            actions_logp = snt.BatchApply(action_dist.logp)(self.model.outputs,
                                                            actions)
            action_sampler, sampled_logp = snt.BatchApply(action_dist.sample)(
                self.model.outputs)
            sampled_prob = tf.exp(sampled_logp)
        else:
            dist_inputs = tf.split(self.model.outputs,
                                   output_hidden_shape,
                                   axis=-1)
            action_dist = dist_class(snt.MergeDims(0, 2)(dist_inputs))
            int64_actions = [tf.cast(x, tf.int64) for x in actions]
            actions_logp = action_dist.logp(snt.MergeDims(0, 2)(int64_actions))
            action_sampler = action_dist.sample()
            # No trailing comma here: it would wrap the tensor in a 1-tuple,
            # which is then handed to TFPolicyGraph as `action_prob`.
            sampled_prob = action_dist.sampled_action_prob()

        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        # actual loss computation
        imitation_loss = -tf.reduce_mean(actions_logp)

        tm_values = self.model.values
        baseline_values = tm_values[:-1]

        if config.get("soft_horizon"):
            discounts = config["gamma"]
        else:
            discounts = tf.to_float(~dones[:-1]) * config["gamma"]

        # The last time step is only used for bootstrapping.
        td_lambda = trfl.td_lambda(state_values=baseline_values,
                                   rewards=rewards[:-1],
                                   pcontinues=discounts,
                                   bootstrap_value=tm_values[-1],
                                   lambda_=config.get("lambda", 1.))

        # td_lambda.loss has shape [B] after a reduce_sum
        # NOTE(review): `T` is only bound when `imitation` is truthy; with
        # imitation disabled this line raises NameError -- confirm whether the
        # non-imitation path is meant to be supported here.
        vf_loss = tf.reduce_mean(td_lambda.loss) / T

        self.total_loss = imitation_loss + self.config[
            "vf_loss_coeff"] * vf_loss

        # Initialize TFPolicyGraph
        loss_in = [
            (SampleBatch.ACTIONS, actions),
            (SampleBatch.DONES, dones),
            # (BEHAVIOUR_LOGITS, behaviour_logits),
            (SampleBatch.REWARDS, rewards),
            (SampleBatch.CUR_OBS, observations),
            (SampleBatch.PREV_ACTIONS, prev_actions),
            (SampleBatch.PREV_REWARDS, prev_rewards),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_sampler,
            action_prob=sampled_prob,
            loss=self.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self._loss_input_dict = dict(self._loss_inputs,
                                     state_in=self._state_inputs)

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            LEARNER_STATS_KEY: {
                "cur_lr":
                tf.cast(self.cur_lr, tf.float64),
                "imitation_loss":
                imitation_loss,
                #"entropy": self.loss.entropy,
                "grad_gnorm":
                tf.global_norm(self._grads),
                "var_gnorm":
                tf.global_norm(self.var_list),
                "vf_loss":
                vf_loss,
                "vf_explained_var":
                explained_variance(
                    tf.reshape(td_lambda.extra.discounted_returns, [-1]),
                    tf.reshape(baseline_values, [-1])),
            },
            "state_out": self.model.state_out,
        }
    def _encoder_preprocessor(self, position_sequence, n_node, global_context,
                              particle_types):
        """Turns a position history into the encoder's input `GraphsTuple`.

        Node features are, in order: the normalized velocity sequence with
        axes 1 and 2 merged, the clipped distances to the lower and upper
        boundaries in units of the connectivity radius, and (only when there
        is more than one particle type) the particle type embedding.
        Edge features are the sender-minus-receiver displacement divided by
        the connectivity radius, followed by its norm.

        Args:
          position_sequence: tensor of positions; index -1 on axis 1 is taken
            as the most recent position.
          n_node: forwarded to the connectivity helper and stored as `n_node`.
          global_context: optional global features; normalized with the
            "context" statistics when not None.
          particle_types: indices looked up in the particle type embedding.

        Returns:
          A `gn.graphs.GraphsTuple` holding the features described above.
        """
        latest_position = position_sequence[:, -1]
        velocities = time_diff(position_sequence)  # Finite-difference.

        # Graph connectivity, computed from the most recent positions.
        senders, receivers, n_edge = (
            connectivity_utils.compute_connectivity_for_batch_pyfunc(
                latest_position, n_node, self._connectivity_radius))

        # --- Node features ---
        # Normalized velocity sequence, with axes 1 and 2 merged.
        velocity_stats = self._normalization_stats["velocity"]
        scaled_velocities = (
            velocities - velocity_stats.mean) / velocity_stats.std
        node_features = [snt.MergeDims(start=1, size=2)(scaled_velocities)]

        # Clipped distances to the boundaries in units of the radius.
        # `boundaries` is indexed as [:, 0] for lower and [:, 1] for upper.
        boundaries = tf.constant(self._boundaries, dtype=tf.float32)
        lower_bound = tf.expand_dims(boundaries[:, 0], 0)
        upper_bound = tf.expand_dims(boundaries[:, 1], 0)
        boundary_distances = tf.concat(
            [latest_position - lower_bound, upper_bound - latest_position],
            axis=1)
        node_features.append(
            tf.clip_by_value(
                boundary_distances / self._connectivity_radius, -1., 1.))

        # Particle type embedding, skipped when there is a single type.
        if self._num_particle_types > 1:
            node_features.append(
                tf.nn.embedding_lookup(self._particle_type_embedding,
                                       particle_types))

        # --- Edge features ---
        # Relative displacement and distance, normalized to the radius.
        sender_positions = tf.gather(latest_position, senders)
        receiver_positions = tf.gather(latest_position, receivers)
        scaled_displacements = (
            sender_positions - receiver_positions) / self._connectivity_radius
        scaled_distances = tf.norm(
            scaled_displacements, axis=-1, keepdims=True)
        edge_features = [scaled_displacements, scaled_distances]

        # --- Global context ---
        if global_context is not None:
            context_stats = self._normalization_stats["context"]
            # Context in some datasets is all zero, so the std is floored at
            # an epsilon for numerical stability.
            safe_std = tf.math.maximum(context_stats.std, STD_EPSILON)
            global_context = (global_context - context_stats.mean) / safe_std

        return gn.graphs.GraphsTuple(
            nodes=tf.concat(node_features, axis=-1),
            edges=tf.concat(edge_features, axis=-1),
            globals=global_context,
            n_node=n_node,
            n_edge=n_edge,
            senders=senders,
            receivers=receivers,
        )
Exemplo n.º 10
0
    def _init_helper(self,
                     observation_space,
                     action_space,
                     config,
                     existing_inputs=None):
        config = dict(DEFAULT_CONFIG, **config)
        assert config["batch_mode"] == "truncate_episodes", \
          "Must use `truncate_episodes` batch mode with V-trace."
        self.config = config

        self.sess = tf.get_default_session()
        self.grads = None

        imitation = config["imitation"]
        assert not imitation

        if imitation:
            T = config["sample_batch_size"]
            B = config["train_batch_size"] // T
            batch_shape = (T, B)
        else:
            batch_shape = (None, )

        if isinstance(action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            actions_shape = batch_shape
            output_hidden_shape = [action_space.n]
        elif isinstance(action_space, gym.spaces.multi_discrete.MultiDiscrete):
            is_multidiscrete = True
            actions_shape = batch_shape + (len(action_space.nvec), )
            output_hidden_shape = action_space.nvec.astype(np.int32)
        else:
            raise UnsupportedSpaceException(
                "Action space {} is not supported for IMPALA.".format(
                    action_space))

        assert is_multidiscrete

        if imitation:
            make_action_ph = lambda: ssbm_actions.make_ph(
                ssbm_actions.flat_repeated_config, batch_shape)
            actions = make_action_ph()
            prev_actions = make_action_ph()
        else:  # actions are stacked "multidiscrete"
            actions = tf.placeholder(tf.int64, actions_shape, name="actions")
            prev_actions = tf.placeholder(tf.int64,
                                          actions_shape,
                                          name="prev_actions")

        # Create input placeholders
        dones = tf.placeholder(tf.bool, batch_shape, name="dones")
        rewards = tf.placeholder(tf.float32, batch_shape, name="rewards")
        if imitation:
            observations = ssbm_spaces.slippi_conv_list[0].make_ph(batch_shape)
        else:
            observations = tf.placeholder(tf.float32, [None] +
                                          list(observation_space.shape))
            behavior_logp = tf.placeholder(tf.float32, batch_shape)

        existing_state_in = None
        existing_seq_lens = None

        # Setup the policy
        autoregressive = config.get("autoregressive")
        if autoregressive:
            logit_dim = 128  # not really logits
        else:
            dist_class, logit_dim = ModelCatalog.get_action_dist(
                action_space, self.config["model"])

        prev_rewards = tf.placeholder(tf.float32,
                                      batch_shape,
                                      name="prev_reward")
        self.model = HumanActionModel(
            {
                "obs": observations,
                "prev_actions": prev_actions,
                "prev_rewards": prev_rewards,
                "is_training": self._get_is_training_placeholder(),
            },
            observation_space,
            action_space,
            logit_dim,
            self.config["model"],
            imitation=imitation,
            state_in=existing_state_in,
            seq_lens=existing_seq_lens)

        # HumanActionModel doesn't flatten outputs
        flat_outputs = snt.MergeDims(0, 2)(self.model.outputs)

        if autoregressive:
            action_dist = ssbm_actions.AutoRegressive(
                nest.map_structure(lambda conv: conv.build_dist(),
                                   ssbm_actions.flat_repeated_config),
                residual=config.get("residual"))
            actions_logp, actions_entropy = action_dist.logp(
                flat_outputs, tf.unstack(actions, axis=-1))
            action_sampler, self.sampled_logp = action_dist.sample(
                flat_outputs)
            action_sampler = tf.stack(
                [tf.cast(t, tf.int64) for t in nest.flatten(action_sampler)],
                axis=-1)
            sampled_prob = tf.exp(self.sampled_logp)
        else:
            dist_inputs = tf.split(flat_outputs, output_hidden_shape, axis=-1)
            action_dist = dist_class(dist_inputs)
            int64_actions = [tf.cast(x, tf.int64) for x in actions]
            actions_logp = action_dist.logp(int64_actions)
            actions_entropy = action_dist.entropy()
            action_sampler = action_dist.sample()
            sampled_prob = action_dist.sampled_action_prob()
            self.sampled_logp = tf.log(sampled_prob)

        self.var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          tf.get_variable_scope().name)

        def make_time_major(tensor, drop_last=False):
            """Reshape a flat batch to [B, T, ...] and return it as [T, B, ...].

            Args:
                tensor: A tensor, or a list of tensors (each converted
                    independently), whose leading axis is the flattened
                    B * T axis.
                drop_last: When True, the last entry along the leading
                    (time) axis of the result is removed.

            Returns:
                The time-major tensor, or a list of time-major tensors when
                a list was passed in.
            """
            if isinstance(tensor, list):
                return [make_time_major(item, drop_last) for item in tensor]

            if self.model.state_init:
                # The model exposes initial state: the batch size is the
                # length of seq_lens and the time extent is derived from it.
                batch_dim = tf.shape(self.model.seq_lens)[0]
                time_dim = tf.shape(tensor)[0] // batch_dim
            else:
                # Important: chop the tensor into batches at known episode cut
                # boundaries. TODO(ekl) this is kind of a hack
                time_dim = self.config["sample_batch_size"]
                batch_dim = tf.shape(tensor)[0] // time_dim

            trailing_dims = tf.shape(tensor)[1:]
            batch_major = tf.reshape(
                tensor,
                tf.concat([[batch_dim, time_dim], trailing_dims], axis=0))

            # batch_major has one more axis than `tensor`: swap the first two
            # (B and T) and leave every remaining axis where it is.
            input_rank = int(tf.shape(tensor).shape[0])
            permutation = [1, 0] + list(range(2, 1 + input_rank))
            time_major = tf.transpose(batch_major, permutation)

            return time_major[:-1] if drop_last else time_major

        # actual loss computation
        values_tm = make_time_major(self.model.value_function())
        baseline_values = values_tm[:-1]
        actions_logp_tm = make_time_major(actions_logp, True)
        behavior_logp_tm = make_time_major(behavior_logp, True)
        log_rhos_tm = actions_logp_tm - behavior_logp_tm

        discounts = tf.fill(tf.shape(baseline_values), config["gamma"])
        if not config.get("soft_horizon"):
            discounts *= tf.to_float(~make_time_major(dones, True))

        vtrace_returns = vtrace.from_importance_weights(
            log_rhos=log_rhos_tm,
            discounts=discounts,
            rewards=make_time_major(rewards, True),
            values=baseline_values,
            bootstrap_value=values_tm[-1])

        vf_loss = tf.reduce_mean(
            tf.squared_difference(vtrace_returns.vs, baseline_values))
        pi_loss = -tf.reduce_mean(
            actions_logp_tm * vtrace_returns.pg_advantages)
        entropy_mean = tf.reduce_mean(actions_entropy)

        total_loss = pi_loss
        total_loss += self.config["vf_loss_coeff"] * vf_loss
        total_loss -= self.config["entropy_coeff"] * entropy_mean
        self.total_loss = total_loss

        kl_mean = -tf.reduce_mean(log_rhos_tm)

        # Initialize TFPolicyGraph
        loss_in = [
            (SampleBatch.ACTIONS, actions),
            (SampleBatch.DONES, dones),
            ("behavior_logp", behavior_logp),
            (SampleBatch.REWARDS, rewards),
            (SampleBatch.CUR_OBS, observations),
            (SampleBatch.PREV_ACTIONS, prev_actions),
            (SampleBatch.PREV_REWARDS, prev_rewards),
        ]
        LearningRateSchedule.__init__(self, self.config["lr"],
                                      self.config["lr_schedule"])
        TFPolicyGraph.__init__(
            self,
            observation_space,
            action_space,
            self.sess,
            obs_input=observations,
            action_sampler=action_sampler,
            action_prob=sampled_prob,
            loss=self.total_loss,
            model=self.model,
            loss_inputs=loss_in,
            state_inputs=self.model.state_in,
            state_outputs=self.model.state_out,
            prev_action_input=prev_actions,
            prev_reward_input=prev_rewards,
            seq_lens=self.model.seq_lens,
            max_seq_len=self.config["model"]["max_seq_len"],
            batch_divisibility_req=self.config["sample_batch_size"])

        self.sess.run(tf.global_variables_initializer())

        self.stats_fetches = {
            LEARNER_STATS_KEY: {
                "cur_lr":
                tf.cast(self.cur_lr, tf.float64),
                "pi_loss":
                pi_loss,
                "entropy":
                entropy_mean,
                "grad_gnorm":
                tf.global_norm(self._grads),
                "var_gnorm":
                tf.global_norm(self.var_list),
                "vf_loss":
                vf_loss,
                "vf_explained_var":
                explained_variance(tf.reshape(vtrace_returns.vs, [-1]),
                                   tf.reshape(baseline_values, [-1])),
                "kl_mean":
                kl_mean,
            },
        }
Exemplo n.º 11
0
  def _build(self, h, x, presence=None):
    """Builds the module.

    Predicts per-capsule and per-vote parameters from `h` with a batch MLP,
    evaluates `x` under the resulting order-invariant capsule likelihood and
    computes the mixing-KL and sparsity terms from the likelihood's outputs.

    Args:
      h: Tensor of encodings of shape [B, n_enc_dims].
      x: Tensor of inputs of shape [B, n_points, n_input_dims]. `B` and
        `n_points` must be statically known: they are used as Python values
        when reshaping and when computing the uniform prior.
      presence: Tensor of shape [B, n_points, 1] or None; if it exists, it
        indicates which input points exist.
        NOTE(review): the code applies `tf.expand_dims(presence, -1)` before
        multiplying it with the per-capsule one-hot, which suggests a
        [B, n_points] tensor is what is actually expected -- confirm.

    Returns:
      An `AttrDict` with the entries `dynamic_weights_l2`,
      `pres_logit_per_caps`, `pres_logit_per_vote`, `scale`, `vote_presence`,
      `vote`, every field of the likelihood result, and the summaries
      `mixing_kl`, `sparsity_loss`, `caps_presence_prob` and `mean_scale`.
    """
    batch_size, n_input_points, _ = x.shape.as_list()
    res = AttrDict(
        dynamic_weights_l2=tf.constant(0.)
    )

    output_shapes = (
        [1],  # per-capsule presence
        [self._n_votes],  # per-vote-presence
        [self._n_votes],  # per-vote scale
        [self._n_votes, self._n_caps_dims]  # per-vote prediction
    )

    # Flattened width of each parameter group in the MLP's last dimension.
    splits = [np.prod(i).astype(np.int32) for i in output_shapes]
    n_outputs = sum(splits)
    batch_mlp = neural.BatchMLP([self._n_hiddens, self._n_hiddens, n_outputs],
                                use_bias=True)

    all_params = batch_mlp(h)
    all_params = tf.split(all_params, splits, -1)
    batch_shape = [batch_size, self._n_caps]
    all_params = [tf.reshape(i, batch_shape + s)
                  for (i, s) in zip(all_params, output_shapes)]

    def add_noise(tensor):
      # Uniform noise in [-2, 2). It has to be added to the predicted logits:
      # returning the noise alone would make both presence logits independent
      # of the MLP output and cut them off from the sparsity loss below.
      noise = tf.random.uniform(tensor.shape, minval=-.5, maxval=.5) * 4.
      return tensor + noise

    res.pres_logit_per_caps = add_noise(all_params[0])
    res.pres_logit_per_vote = add_noise(all_params[1])
    # The small constant keeps the scale strictly positive.
    res.scale = tf.nn.softplus(all_params[2] + .5) + 1e-6
    res.vote_presence = (tf.nn.sigmoid(res.pres_logit_per_caps)
                         * tf.nn.sigmoid(res.pres_logit_per_vote))
    res.vote = all_params[3]

    # Merge dimensions 1 and 2 of every non-scalar entry. Iterate over a
    # snapshot because entries are reassigned inside the loop.
    for k, v in list(res.items()):
      if v.shape.ndims > 0:
        res[k] = snt.MergeDims(1, 2)(v)

    likelihood = _capsule.OrderInvariantCapsuleLikelihood(self._n_votes,
                                                          res.vote, res.scale,
                                                          res.vote_presence)
    ll_res = likelihood(x, presence)
    res.update(ll_res._asdict())

    # post processing
    # The mixing log-probabilities are compared against a uniform prior of
    # 1 / n_points.
    mixing_probs = tf.nn.softmax(ll_res.mixing_logits, 1)
    prior_mixing_log_prob = tf.log(1. / n_input_points)
    mixing_kl = mixing_probs * (ll_res.mixing_log_prob - prior_mixing_log_prob)
    mixing_kl = tf.reduce_mean(tf.reduce_sum(mixing_kl, -1))

    # One-hot of `is_from_capsule` per point, masked by `presence` if given,
    # then summed over axis 1 to count wins per capsule.
    wins_per_caps = tf.one_hot(ll_res.is_from_capsule, depth=self._n_caps)

    if presence is not None:
      wins_per_caps *= tf.expand_dims(presence, -1)

    wins_per_caps = tf.reduce_sum(wins_per_caps, 1)

    has_any_wins = tf.to_float(tf.greater(wins_per_caps, 0))
    should_be_active = tf.to_float(tf.greater(wins_per_caps, 1))

    sparsity_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=should_be_active, logits=res.pres_logit_per_caps)

    # Only capsules with at least one win contribute to the sparsity loss.
    sparsity_loss = tf.reduce_sum(sparsity_loss * has_any_wins, -1)
    sparsity_loss = tf.reduce_mean(sparsity_loss)

    # A capsule's presence probability is the maximum over its votes.
    caps_presence_prob = tf.reduce_max(
        tf.reshape(res.vote_presence,
                   [batch_size, self._n_caps, self._n_votes]), 2)

    res.update(dict(
        mixing_kl=mixing_kl,
        sparsity_loss=sparsity_loss,
        caps_presence_prob=caps_presence_prob,
        mean_scale=tf.reduce_mean(res.scale)
    ))
    return res
Exemplo n.º 12
0
 def value_function(self):
   """Return `self.values` with its first two dimensions merged into one."""
   merge_leading_dims = snt.MergeDims(0, 2)
   return merge_leading_dims(self.values)