Example #1
def plot_sampled_minibatch(sampled_output,
                           experts=None,
                           figsize=(8, 20),
                           partial_write_np_image_to_tb=None,
                           without_samples=False,
                           plot_bev_kwargs={},
                           tensorstr='',
                           fig=None,
                           axes=None,
                           eager=False):
    fig, axes = get_figure(fig=fig, axes=axes, figsize=figsize)
    B = sampled_output.phi_metadata.B
    A = sampled_output.rollout.event_shape[0]
    limit = tensoru.size(sampled_output.phi.overhead_features, 1)
    # Make the expert color different from the model if there's only one agent.
    expert_kwargs = {'color': COLORS[1]} if A == 1 else {}
    for b in range(min(B, 10)):
        ax = axes.ravel()[b]
        magic_reset_axis(ax)
        if not without_samples:
            plot_single_sampled_output(sampled_output,
                                       batch_index=b,
                                       render=False,
                                       fig=fig,
                                       ax=ax)

        plot_past(sampled_output.phi.S_past_grid_frame,
                  b=b,
                  fig=fig,
                  ax=ax,
                  limit=limit)
        plot_rollout(sampled_output.rollout, fig=fig, ax=ax, b=b)
        plot_bev(sampled_output, batch_index=b, ax=ax, **plot_bev_kwargs)
        if experts is not None:
            plot_expert(experts,
                        batch_index=b,
                        fig=fig,
                        ax=ax,
                        render=False,
                        limit=limit,
                        **expert_kwargs)
    # TODO make work with eager mode.
    plot_figure('sampled_minibatch' + tensorstr,
                fig,
                partial_write_np_image_to_tb=partial_write_np_image_to_tb)
    if not eager: figsclose()
    return fig
Example #2
def get_map_feats(feature_map,
                  batch_shape,
                  batch_size,
                  last_agent_positions_grid,
                  A,
                  phi=None):
    """Interpolate the positions into the feature map.

    :param feature_map: (B, H, W, F), feature map
    :param batch_shape: first several dimensions of positions. could be inferred but we require it for sanity checks.
    :param batch_size: size of batch. could be inferred but we require it for sanity checks.
    :param last_agent_positions_grid: batch_shape + (A, D)
    :param A: number of agents. could be inferred but we require it for sanity checks.
    :param phi: DEPRECATED, unused
    :returns: (batch_size, A, F) feature-map values interpolated at the given positions
    :rtype: tf.Tensor

    """
    assert (tensoru.rank(feature_map) == 4)
    assert (tensoru.rank(last_agent_positions_grid) >= 3)
    F = tensoru.size(feature_map, -1)

    # (B, batch_shape[1:]*A, d)
    last_agent_positions_grid_r_xy = tf.reshape(last_agent_positions_grid,
                                                (batch_shape[0], -1, 2))
    last_agent_positions_grid_r_ij = last_agent_positions_grid_r_xy[..., ::-1]
    # (B, batch_shape[1:]*A, F)
    # N.B. the indexing order! Our points are stored in xy-format, with x corresponding to the W dimension of the feature grid.
    map_feat = tfu.interpolate_bilinear(feature_map,
                                        last_agent_positions_grid_r_ij,
                                        indexing='ij')
    # (BSize, A, F)
    map_feats = tf.reshape(map_feat, (batch_size, A, F))
    return map_feats
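A minimal, self-contained sketch of the bookkeeping above, with NumPy and a nearest-neighbour gather standing in for tfu.interpolate_bilinear: points stored in xy order (x indexes W) are reversed to ij, flattened to (B, -1, 2) for the lookup, and the gathered features are reshaped to (batch_size, A, F). All shapes are illustrative.

import numpy as np

B, H, W, F = 2, 8, 8, 3      # feature map shape
K, A = 4, 5                  # batch_shape = (B, K), A agents
feature_map = np.random.rand(B, H, W, F)
# Positions stored in xy order: x indexes the W dimension, y indexes the H dimension.
positions_xy = np.random.uniform(0, W - 1, size=(B, K, A, 2))

# Flatten everything after the leading batch dim, then reverse xy -> ij.
positions_r_xy = positions_xy.reshape(B, -1, 2)
positions_r_ij = positions_r_xy[..., ::-1]

# Nearest-neighbour gather as a stand-in for bilinear interpolation.
idx = np.rint(positions_r_ij).astype(int)
gathered = np.stack([feature_map[b, idx[b, :, 0], idx[b, :, 1]]
                     for b in range(B)])         # (B, K*A, F)

batch_size = B * K
map_feats = gathered.reshape(batch_size, A, F)   # (B*K, A, F)
print(map_feats.shape)                           # (8, 5, 3)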
Example #3
    def _static_prepare_samples(self, global_step):
        for t in self.input_singleton.sample_placeholders:
            self.model_collections.add_sample_input(t)
        for t in self.input_singleton.placeholders:
            self.model_collections.add_infer_input(t)

        # Track some intermediate tensors that we'll want to reconstruct when we load the model.
        self.model_collections.add_intermediate_input(
            self.input_singleton.phi.S_past_car_frames)
        self.model_collections.add_intermediate_input(
            self.input_singleton.phi.S_past_grid_frame)
        self.model_collections.add_intermediate_input(
            self.input_singleton.phi_m.agent_counts)
        self.model_collections.add_intermediate_label(
            self.input_singleton.experts.S_future_car_frames)
        self.model_collections.add_intermediate_label(
            self.input_singleton.experts.S_future_grid_frame)

        log.info("Computing static-mode samples.")
        self.sampled_output = self.model_distribution.sample(
            phi=self.input_singleton.phi,
            phi_metadata=self.input_singleton.phi_m,
            T=self.dataset.T)
        # dbg
        self.model_distribution.bijection.check_gradients(
            self.sampled_output.base_and_log_q.Z_sample)
        # Z input is a sample input
        self.model_collections.add_sample_input(
            self.sampled_output.base_and_log_q.Z_sample)

        # Record outputs for sampling.
        for t in self.sampled_output.rollout.rollout_outputs:
            self.model_collections.add_sample_output(t)
        # Log q is a sample output.
        self.model_collections.add_sample_output(
            self.sampled_output.base_and_log_q.log_q_samples)

        # Functions that write a numpy image into TensorBoard; each takes a single np.ndarray input.
        self.partial_np2tb_smbs = []
        for c in range(
                tensoru.size(self.input_singleton.phi.overhead_features, -1)):
            self.partial_np2tb_smbs.append(
                plot.bind_write_np_image_to_tb(
                    sess=self.sess,
                    writer=self.writer,
                    global_step=global_step,
                    key='sampled_minibatch_bev_{}'.format(c)))
        self.partial_np2tb_smb_joint = plot.bind_write_np_image_to_tb(
            sess=self.sess,
            writer=self.writer,
            global_step=global_step,
            key='sampled_minibatch_bev_joint')
        self.partial_np2tb_cnn0 = plot.bind_write_np_image_to_tb(
            sess=self.sess,
            writer=self.writer,
            global_step=global_step,
            key='CNN_0')

        log.info("Initializing variables...")
        self.sess.run(tfv1.global_variables_initializer())
Example #4
def plot_whiskers(rollout,
                  batch_index=0,
                  k_slice=slice(0, 1),
                  a_slice=slice(0, 5),
                  period=5,
                  fig=None,
                  ax=None,
                  render=True):
    A, T, d = rollout.event_shape
    whiskers = [
        _['whiskers_grid'][batch_index, k_slice, a_slice]
        for _ in rollout.metadata_list[::period]
    ]
    frames = [_['local2grid'] for _ in rollout.metadata_list]
    origins = tf.stack([_.t[batch_index, k_slice, a_slice] for _ in frames],
                       axis=-2)
    limit = tensoru.size(rollout.phi.overhead_features, 1)

    # plot_joint_trajectory(whiskers[0], key='whiskers0', render=render, fig=fig, ax=ax, marker='d', zorder=2, alpha=0.5, limit=limit)

    plot_joint_trajectory(origins,
                          key='origins',
                          render=render,
                          fig=fig,
                          ax=ax,
                          marker='+',
                          zorder=2,
                          alpha=0.5,
                          limit=limit)

    #    plot_joint_trajectory(whiskers[-1], key='whiskers1', render=render, fig=fig, ax=ax, marker='d', zorder=2, alpha=0.5, limit=limit)
    return fig, ax
Example #5
def get_whisker_map_feats(feature_map,
                          batch_shape,
                          batch_size,
                          cars2grid,
                          template_cars,
                          A,
                          phi=None):
    """Interpolate the positions into the feature map.

    :param feature_map: (B, H, W, F), feature map
    :param batch_shape: first several dimensions of positions. could be inferred but we require it for sanity checks.
    :param batch_size: size of batch. could be inferred but we require it for sanity checks.
    :param cars2grid: SimilarityTransform from car frames to grid
    :param template_cars: (N, D) template of positions in the local (car) frame at which to interpolate
    :param A: number of agents. could be inferred but we require it for sanity checks.
    :param phi: DEPRECATED, unused
    :returns: ((batch_size, A, n_whiskers * F) map features, whisker positions in the grid frame)
    :rtype: tuple

    """

    assert (tensoru.rank(feature_map) == 4)
    F = tensoru.size(feature_map, -1)

    if len(batch_shape) == 2:
        points_ein = 'bkaNj'
        template_cars = template_cars[None]
    elif len(batch_shape) == 1:
        points_ein = 'baNj'
    else:
        raise ValueError

    n_whiskers = tensoru.size(template_cars, -2)
    whiskers_grid = cars2grid.apply(template_cars, points_ein=points_ein)
    whiskers_grid_r_xy = tf.reshape(whiskers_grid, (batch_shape[0], -1, 2))
    # (B, batch_shape[1:]*A, F)
    # N.B. the indexing order! Our points are stored in xy-format, with x corresponding to the W dimension of the feature grid.
    map_feat = tfu.interpolate_bilinear(feature_map,
                                        whiskers_grid_r_xy,
                                        indexing='xy')
    # (B, ..., A, n_whiskers, F)
    map_feat_r = tf.reshape(map_feat, batch_shape + (A, n_whiskers, F))
    # (B*..., A, n_whiskers*F)
    map_feats = tf.reshape(map_feat_r, (batch_size, A, n_whiskers * F))
    return map_feats, whiskers_grid
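A rough illustration of the whisker path, under the assumption that cars2grid acts like a per-agent 2-D similarity transform (rotation plus translation): the (N, D) template is mapped into the grid frame for every agent with an einsum over the 'baNj' layout, and the per-point features are finally packed to (batch_size, A, n_whiskers * F). Names and shapes below are illustrative, not the module's API.

import numpy as np

B, A, N, D, F = 2, 3, 7, 2, 4            # batch, agents, whisker points, 2-D, features
template_cars = np.random.randn(N, D)    # whisker template in the local (car) frame

# Stand-in for a SimilarityTransform: per-agent rotation R (B, A, D, D) and translation t (B, A, D).
theta = np.random.uniform(0, 2 * np.pi, size=(B, A))
R = np.stack([np.stack([np.cos(theta), -np.sin(theta)], -1),
              np.stack([np.sin(theta),  np.cos(theta)], -1)], -2)
t = np.random.uniform(10, 50, size=(B, A, D))

# Roughly cars2grid.apply(template_cars, points_ein='baNj'): rotate, then translate each template point.
whiskers_grid = np.einsum('baij,Nj->baNi', R, template_cars) + t[:, :, None]   # (B, A, N, D)

# Pretend the feature map was interpolated at each whisker point, then pack per agent.
map_feat = np.random.randn(B, A, N, F)
map_feats = map_feat.reshape(B, A, N * F)        # (batch_size, A, n_whiskers * F)
print(whiskers_grid.shape, map_feats.shape)      # (2, 3, 7, 2) (2, 3, 28)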
Example #6
    def __init__(self,
                 S_past_world_frame,
                 yaws,
                 overhead_features,
                 agent_presence,
                 feature_pixels_per_meter,
                 is_training,
                 light_strings=None,
                 yaws_in_degrees=True,
                 past_perturb_epsilon=5e-2,
                 name=False):
        """

        :param S_past_world_frame: (B, A, T, D)
        :param yaws: (B, A)
        :param overhead_features: (B, H, W, C)
        :param agent_presence: (B, A)
        :param yaws_in_degrees: bool
        """
        assert (feature_pixels_per_meter >= 1)
        if light_strings is None:
            light_strings = np.array(['NONE'] * tensoru.size(yaws, 0),
                                     dtype=np.unicode_)

        # Overwrite these members with tensorized versions of them.
        self.tensor_init(S_past_world_frame, yaws, overhead_features,
                         agent_presence, light_strings, is_training)

        self._frames_init()

        # Store the perturbation scale, then draw noise with the same shape as the past trajectories.
        self.past_perturb_epsilon = past_perturb_epsilon
        past_noise_world_frame = tf.random.normal(
            mean=0.,
            stddev=self.past_perturb_epsilon,
            shape=tensoru.shape(self.S_past_world_frame),
            dtype=tf.float64)
        # Always create an alternative noisy-past (learning may not use it).
        self.S_past_world_frame_noisy = self.S_past_world_frame + past_noise_world_frame

        # (B, A, T, D)
        self.S_past_car_frames = self.world2local.apply(
            self.S_past_world_frame)
        self.S_past_car_frames_noisy = self.world2local.apply(
            self.S_past_world_frame_noisy)
        # (B, A, T, D)
        self.S_past_grid_frame = self.world2grid.apply(self.S_past_world_frame,
                                                       dtype=tf.float64)
        self.S_past_grid_frame_noisy = self.world2grid.apply(
            self.S_past_world_frame_noisy, dtype=tf.float64)

        self.light_features = one_hotify_light_strings(self.light_strings)

        if name:
            # Name some intermediate computations (not placeholders)
            self.S_past_grid_frame = tf.identity(self.S_past_grid_frame,
                                                 name='S_past_grid_frame')
            self.S_past_car_frames = tf.identity(self.S_past_car_frames,
                                                 name='S_past_car_frames')
            self.light_features = tf.identity(self.light_features,
                                              name='light_features')
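As a small usage-style sketch of the noisy-past construction above, assuming S_past_world_frame is an ordinary float64 tensor of shape (B, A, T, D): Gaussian noise with the same shape is drawn once and added, so both the clean and the perturbed pasts stay available downstream. The dummy tensor below is illustrative.

import tensorflow as tf

B, A, T, D = 2, 3, 10, 2
past_perturb_epsilon = 5e-2
S_past_world_frame = tf.random.uniform((B, A, T, D), dtype=tf.float64)

past_noise_world_frame = tf.random.normal(
    mean=0.,
    stddev=past_perturb_epsilon,
    shape=S_past_world_frame.shape,
    dtype=tf.float64)
S_past_world_frame_noisy = S_past_world_frame + past_noise_world_frame
print(S_past_world_frame_noisy.shape)   # (2, 3, 10, 2)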
Example #7
def plot_rollout(rollout, b=None, fig=None, ax=None):
    assert (b is not None)
    limit = tensoru.size(rollout.phi.overhead_features, 1)
    plot_joint_trajectory(joint_traj=rollout.S_grid_frame[b],
                          key='rollout_future',
                          render=False,
                          fig=fig,
                          ax=ax,
                          marker='o',
                          zorder=3,
                          alpha=0.5,
                          limit=limit)
Example #8
def plot_sample(sampled_output,
                expert=None,
                b=0,
                figsize=(4, 4),
                partial_write_np_image_to_tb=None,
                bev_kwargs={}):
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    magic_reset_axis(ax)
    plot_single_sampled_output(sampled_output,
                               batch_index=b,
                               render=False,
                               fig=fig,
                               ax=ax)
    plot_bev(sampled_output, batch_index=b, ax=ax, **bev_kwargs)
    A = tensoru.size(sampled_output.rollout.S_car_frames, 2)
    expert_kwargs = {'color': COLORS[1]} if A == 1 else {}
    limit = tensoru.size(sampled_output.phi.overhead_features, 1)
    plot_past(sampled_output.phi.S_past_grid_frame,
              b=b,
              fig=fig,
              ax=ax,
              limit=limit)
    if expert is not None:
        plot_expert(expert,
                    batch_index=b,
                    fig=fig,
                    ax=ax,
                    render=False,
                    limit=limit,
                    **expert_kwargs)
    res = plot_figure(
        'sampled_minibatch',
        fig,
        partial_write_np_image_to_tb=partial_write_np_image_to_tb)
    plt.close('all')
    return res
Example #9
    def __init__(self, phi, metadata_list, name=False):
        name = None if not name else 'agent_counts'
        # Store the inputs before the sanity checks and size computations below use them.
        self.phi = phi
        self.metadata_list = metadata_list
        assert (isinstance(self.metadata_list, MetadataList))
        assert (all([isinstance(_, MetadataItem) for _ in metadata_list]))
        self.agent_counts = tf.reduce_sum(tf.cast(phi.agent_presence,
                                                  tf.float64),
                                          axis=1,
                                          name=name)
        B0 = self.phi.S_past_car_frames.shape[0]
        B1 = self.phi.yaws.shape[0]
        B2 = self.phi.overhead_features.shape[0]
        B3 = self.phi.agent_presence.shape[0]
        assert (B0 == B1 == B2 == B3)
        self.B = B0.value
        self.H = tensoru.size(self.phi.overhead_features, 1)
        if len(self.metadata_list):
            self.tensor_init(**self.metadata_list.to_dict())
Example #10
    def _static_plot(self, minibatch):
        log.debug("Static plotting")
        sessrun = functools.partial(self.sess.run, feed_dict=minibatch)
        # Convert data to numpy to prepare for plotting.
        sampled_output_np = self.sampled_output.to_numpy(sessrun)
        experts_np = self.input_singleton.experts.to_numpy(sessrun)

        # Plot things over every channel of the BEV.
        for c in range(
                tensoru.size(self.input_singleton.phi.overhead_features, -1)):
            plot_bev_kwargs = {
                'onechannel': True,
                'channel_idx': c,
                'allchannel': False
            }
            plot.plot_sampled_minibatch(
                sampled_output=sampled_output_np,
                experts=experts_np,
                partial_write_np_image_to_tb=self.partial_np2tb_smbs[c],
                figsize=self.figsize,
                without_samples=self.plot_without_samples,
                plot_bev_kwargs=plot_bev_kwargs,
                tensorstr='_bev-{}'.format(c))

        plot_bev_kwargs = {'onechannel': False, 'allchannel': False}
        plot.plot_sampled_minibatch(
            sampled_output=sampled_output_np,
            experts=experts_np,
            partial_write_np_image_to_tb=self.partial_np2tb_smb_joint,
            figsize=self.figsize,
            without_samples=self.plot_without_samples,
            plot_bev_kwargs=plot_bev_kwargs,
            tensorstr='_bev-joint')

        try:
            feature_map = sessrun(
                self.model_distribution.bijection.feature_map)
            feature_map_fig_w = int(np.ceil(np.sqrt(feature_map.shape[-1])))
            plot.plot_feature_map(
                feature_map=feature_map,
                partial_write_np_image_to_tb=self.partial_np2tb_cnn0,
                nrows=feature_map_fig_w,
                ncols=feature_map_fig_w)
        except AttributeError as e:
            log.error(e)
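The ceil(sqrt(C)) sizing above lays the feature map's channels out on a near-square grid; a stand-alone sketch of the same idea with plain matplotlib (independent of plot.plot_feature_map, shapes illustrative):

import numpy as np
import matplotlib.pyplot as plt

feature_map = np.random.rand(4, 32, 32, 10)     # (B, H, W, C)
C = feature_map.shape[-1]
side = int(np.ceil(np.sqrt(C)))                 # e.g. a 4x4 grid for C=10

fig, axes = plt.subplots(side, side, figsize=(2 * side, 2 * side))
for c, ax in enumerate(axes.ravel()):
    if c < C:
        ax.imshow(feature_map[0, ..., c])       # first batch item, channel c
        ax.set_title('channel {}'.format(c))
    ax.axis('off')
fig.tight_layout()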
Example #11
def plot_single_sampled_output(sampled_output,
                               batch_index=0,
                               fig=None,
                               ax=None,
                               render=True):
    S = sampled_output.rollout.S_grid_frame
    S_past = sampled_output.phi.S_past_grid_frame
    limit = tensoru.size(sampled_output.phi.overhead_features, 1)
    # Plot future.
    fig, ax = plot_joint_trajectory(S[batch_index],
                                    key='sampled_trajectories',
                                    render=False,
                                    marker='o',
                                    zorder=1,
                                    alpha=.4,
                                    fig=fig,
                                    ax=ax,
                                    limit=limit)
    return fig, ax
Example #12
def plot_joint_trajectory(joint_traj,
                          key='joint_trajectories',
                          fig=None,
                          ax=None,
                          render=True,
                          limit=100,
                          **kwargs):
    """

    :param joint_traj: K x A x T x d
    :param key: name passed to plot_trajectory for this set of trajectories
    :param fig: matplotlib figure to draw into
    :param ax: matplotlib axes to draw into
    :param render: if True, render on the final agent's plot call
    :param limit: axis extent of the plot (e.g. the feature grid's spatial size)
    :param kwargs: forwarded to plot_trajectory
    :returns: (fig, ax)
    :rtype: tuple

    """

    assert (tensoru.rank(joint_traj) == 4)
    A = tensoru.size(joint_traj, 1)
    for a in range(A):
        render_a = (a == A - 1) and render
        color = kwargs.get('color', COLORS[a])
        if isinstance(joint_traj, tf.Tensor):
            single_traj = joint_traj[:, a].numpy()
        else:
            single_traj = joint_traj[:, a]
        kwargs.pop('color', None)
        fig, ax = plot_trajectory(key,
                                  single_traj,
                                  color=color,
                                  render=render_a,
                                  fig=fig,
                                  ax=ax,
                                  axis=[0, limit, limit, 0],
                                  **kwargs)
        assert (fig is not None)
    return fig, ax
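plot_trajectory and COLORS are internal to this module, so the following is a rough, self-contained equivalent of the per-agent loop: each agent's K sampled trajectories share one color, and the axes follow the same [0, limit, limit, 0] convention (y grows downward, matching image coordinates). Shapes and styling are illustrative.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

K, A, T, d = 12, 3, 20, 2
limit = 100
# Random-walk joint trajectories in grid coordinates, shaped (K, A, T, d).
joint_traj = limit / 2 + np.cumsum(np.random.randn(K, A, T, d), axis=-2)

fig, ax = plt.subplots(figsize=(4, 4))
for a in range(A):
    color = cm.get_cmap('tab10').colors[a]
    single_traj = joint_traj[:, a]                      # (K, T, d)
    for k in range(K):
        ax.plot(single_traj[k, :, 0], single_traj[k, :, 1],
                marker='o', alpha=0.4, color=color)
ax.axis([0, limit, limit, 0])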
Example #13
def plot_joint_trajectory(joint_traj,
                          limit,
                          scale=1,
                          agents=None,
                          fig=None,
                          ax=None,
                          **kwargs):
    assert (tensoru.rank(joint_traj) == 4)

    if agents is None:
        agents = range(tensoru.size(joint_traj, 1))

    for a in agents:

        if "color" in kwargs:
            color = kwargs.pop("color")
        else:
            color = cm.get_cmap("tab10").colors[a]

        if isinstance(joint_traj, tf.Tensor):
            single_traj = joint_traj[:, a].numpy().copy()
        else:
            single_traj = joint_traj[:, a].copy()

        if scale > 0:
            single_traj *= scale
        single_traj[..., -1] *= -1

        fig, ax = plot_trajectory(single_traj,
                                  color=color,
                                  fig=fig,
                                  ax=ax,
                                  axis=limit,
                                  **kwargs)
        assert (fig is not None)
    return fig, ax
Example #14
    def _prepare(self, batch_shape, phi):
        self.t = 0
        ldb = lambda x: log.info("Step {}, ".format(self.t) + x)
        self.step_generate_record.append({})
        assert (isinstance(batch_shape, tuple))
        # (B, ...), e.g. (B,); (B,K), etc.
        self.batch_shape = batch_shape
        self.batch_size = functools.reduce(operator.mul, self.batch_shape)

        self.S_past_car_frames = tf.cond(
            tf.logical_and(phi.is_training,
                           tf.convert_to_tensor(self.rnnconf.past_perturb)),
            lambda: phi.S_past_car_frames_noisy, lambda: phi.S_past_car_frames)

        # To represent tensors packed to shape (flattened_batch * A_agents, ...), e.g. (B*K*A, ...)
        self.A_batch_size = self.batch_size * self.A
        self.rnn_state = self.rnn.zero_state(self.A_batch_size,
                                             dtype=tf.float64)
        self.mu_shape = self.batch_shape + (self.A, self.D)
        self.sigma_shape = self.batch_shape + (self.A, self.D, self.D)

        if len(self.batch_shape) == 2: self.batch_str = 'bk'
        elif len(self.batch_shape) == 1: self.batch_str = 'b'
        else: raise ValueError("Unhandled batch size")

        if len(self.batch_shape) == 2:
            # (B, K, A)
            yaws_batch = tensoru.expand_and_tile_axis(phi.yaws,
                                                      axis=1,
                                                      N=self.batch_shape[1])
            # (BKA, 1)
            self.yaws_A_batch = tf.reshape(yaws_batch, (self.A_batch_size, 1))
        elif len(self.batch_shape) == 1:
            # (BA, 1)
            self.yaws_A_batch = tf.reshape(phi.yaws, (self.A_batch_size, 1))
        else:
            raise ValueError("Unhandled batch size")

        # Create past encodings once.
        if self.past_encodings is None:
            if self.rnnconf.past_do_preconv:
                ldb("Doing preconv")
                # 1D-convolve over the past states before plugging them into the RNN.
                self.preconv_W = tf.Variable(tf.ones(
                    (self.rnnconf.preconv_horizon, self.D,
                     self.rnnconf.past_gru_units),
                    dtype=tf.float64),
                                             name="W_preconv")
                self.preconv_b = tf.Variable(1e-5 * tf.ones(
                    (self.rnnconf.past_gru_units, ), dtype=tf.float64),
                                             name="b_preconv")
                past_rnn_inputs = [
                    tf.nn.conv1d(self.S_past_car_frames[:, a],
                                 self.preconv_W,
                                 stride=1,
                                 padding="SAME") + self.preconv_b
                    for a in range(self.A)
                ]
            else:
                ldb("Not doing preconv")
                past_rnn_inputs = [
                    self.S_past_car_frames[:, a] for a in range(self.A)
                ]
            # For every agent, run the RNN cell and retrieve the state at the last time step.
            self.past_encodings = [
                tf.nn.dynamic_rnn(cell=self.past_rnn,
                                  inputs=past_rnn_inputs[a],
                                  initial_state=None,
                                  time_major=False,
                                  dtype=tf.float64)[1] for a in range(self.A)
            ]
        # Create feature map once.
        if self.feature_map is None:
            if self.cnnconf.create_residual_connections:
                ldb("Creating a CNN with residual connections")
                if self.cnnconf.do_batchnorm:
                    self.feature_map = tensoru.convnet_with_residuals_and_batchnorm(
                        self.convnet,
                        self.batchnorms,
                        phi.overhead_features,
                        skip_indices=set([0, self.cnnconf.n_conv_layers - 1]),
                        is_training=phi.is_training)
                else:
                    self.feature_map = tensoru.convnet_with_residuals(
                        self.convnet,
                        phi.overhead_features,
                        skip_indices=set([0, self.cnnconf.n_conv_layers - 1]))
            else:
                ldb("Creating a CNN without residual connections")
                self.feature_map = tensoru.convnet(self.convnet,
                                                   phi.overhead_features)
            if self.cnnconf.append_cnn_input_to_cnn_output:
                ldb("append_cnn_input_to_cnn_output=True")
                self.feature_map = tf.concat(
                    (self.feature_map, phi.overhead_features), axis=-1)
            if self.cnnconf.create_overhead_feature:
                # Global avg, max, min.
                self.overhead_feature = tf.concat(
                    (tf.reduce_mean(self.feature_map, axis=(-3, -2)),
                     tf.reduce_max(self.feature_map, axis=(-3, -2)),
                     tf.reduce_min(self.feature_map, axis=(-3, -2))),
                    axis=-1)
                assert (len(batch_shape) > 1)
                tilex = functools.reduce(operator.mul, batch_shape[1:])
                self.overhead_feature = tf.tile(self.overhead_feature,
                                                (tilex, 1))
        self.feature_map_C = tensoru.size(self.feature_map, -1)
        self.radii_feat_size = self.feature_map_C * len(
            self.whiskerconf.radii) * self.whiskerconf.n_samples

        # Tile the past encodings for the current batch_shape
        assert (tensoru.size(phi.light_features, 0) == self.batch_shape[0])
        if len(batch_shape) > 1:
            assert (len(batch_shape) == 2)
            #[(BK, F) ... ]
            self.past_encodings_batch = [
                tensoru.expand_and_tile_and_pack(e, 1, 0, N=batch_shape[1])
                for e in self.past_encodings
            ]
            # (BK, F)
            light_features = tf.keras.backend.repeat_elements(
                phi.light_features, self.lightconf.lightrep, axis=-1)
            light_features_batch = tensoru.expand_and_tile_and_pack(
                light_features, 1, 0, N=batch_shape[1])
            # (BKA, F)
            self.light_features_A_batch = tensoru.expand_and_tile_and_pack(
                light_features_batch, 1, 0, N=self.A)
        else:
            # [(B, F) ... ]
            self.past_encodings_batch = self.past_encodings
            # (BA, F)
            light_features_batch = tf.keras.backend.repeat_elements(
                phi.light_features, self.lightconf.lightrep, axis=-1)
            self.light_features_A_batch = tensoru.expand_and_tile_and_pack(
                light_features_batch, 1, 0, N=self.A)

        self.past_encodings_joint = []
        if self.A > 1:
            for a in range(self.A):
                others_feat = tf.reduce_sum(self.past_encodings_batch[:a] +
                                            self.past_encodings_batch[a + 1:],
                                            axis=0)
                # Create a feature for each agent that depends on its own encoding and the sum of the other agents' encodings.
                self.past_encodings_joint.append(
                    tf.concat((self.past_encodings_batch[a], others_feat),
                              axis=-1))
            past_encodings_joint_shape = 2 * self.rnnconf.past_gru_units
            self.social_map_feat_size = 2 * self.feature_map_C
        else:
            self.past_encodings_joint = [self.past_encodings_batch[0]]
            past_encodings_joint_shape = self.rnnconf.past_gru_units
            self.social_map_feat_size = self.feature_map_C

        self.past_encodings_A_batch = tf.reshape(
            tf.stack(self.past_encodings_joint, axis=-2),
            (self.A_batch_size, past_encodings_joint_shape))
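A compact sketch of the "preconv then RNN" past encoding built in _prepare, with Keras layers standing in for the hand-built tf.nn.conv1d weights and tf.nn.dynamic_rnn: each agent's past trajectory (B, T, D) is 1-D convolved along time and the GRU's final state becomes that agent's encoding. Layer sizes and dtypes are illustrative.

import tensorflow as tf

B, A, T_past, D = 2, 3, 10, 2
past_gru_units, preconv_horizon = 32, 8
S_past_car_frames = tf.random.uniform((B, A, T_past, D))

preconv = tf.keras.layers.Conv1D(past_gru_units, preconv_horizon,
                                 strides=1, padding='same')
past_rnn = tf.keras.layers.GRU(past_gru_units)   # returns the final state by default

# One encoding per agent: convolve that agent's past along time, then run the shared GRU.
past_encodings = [past_rnn(preconv(S_past_car_frames[:, a])) for a in range(A)]
print([tuple(e.shape) for e in past_encodings])  # [(2, 32), (2, 32), (2, 32)]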