# Imports assumed by the snippets below (TF1-era code; project-specific
# modules -- nav_env, utils, tf_utils, sna, fu, graph_utils, gfile,
# variables, ops, and the FLAGS object -- come from the surrounding
# codebase and are not shown here):
import copy
import logging
import os
import time
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib import animation
from matplotlib.gridspec import GridSpec
from tensorflow.contrib import slim


def _train(args):
    container_name = ""

    R = lambda: nav_env.get_multiplexer_class(args.navtask, args.solver.task)
    m = utils.Foo()
    m.tf_graph = tf.Graph()

    # Request a single GPU for this session.
    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    with m.tf_graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(args.solver.ps_tasks,
                                               merge_devices=True)):
            with tf.container(container_name):
                m = args.setup_to_run(m,
                                      args,
                                      is_training=True,
                                      batch_norm_is_training=True,
                                      summary_mode='train')

                train_step_kwargs = args.setup_train_step_kwargs(
                    m,
                    R(),
                    os.path.join(args.logdir, 'train'),
                    rng_seed=args.solver.task,
                    is_chief=args.solver.task == 0,
                    num_steps=args.navtask.task_params.num_steps *
                    args.navtask.task_params.num_goals,
                    iters=1,
                    train_display_interval=args.summary.display_interval,
                    dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

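                # Stagger worker start-up quadratically: task k waits
                # k*(k+1)/2 * delay_start_iters steps (e.g. with
                # delay_start_iters = 20: task 0 -> 0, task 1 -> 20,
                # task 2 -> 60).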
                delay_start = (args.solver.task *
                               (args.solver.task + 1)) // 2 * FLAGS.delay_start_iters
                logging.error('Delaying start for task %d by %d steps.',
                              args.solver.task, delay_start)

                additional_args = {}
                final_loss = slim.learning.train(
                    train_op=m.train_op,
                    logdir=args.logdir,
                    master=args.solver.master,
                    is_chief=args.solver.task == 0,
                    number_of_steps=args.solver.max_steps,
                    train_step_fn=tf_utils.train_step_custom_online_sampling,
                    train_step_kwargs=train_step_kwargs,
                    global_step=m.global_step_op,
                    init_op=m.init_op,
                    init_fn=m.init_fn,
                    sync_optimizer=m.sync_optimizer,
                    saver=m.saver_op,
                    startup_delay_steps=delay_start,
                    summary_op=None,
                    session_config=config,
                    **additional_args)
Example #2
def _compute_hardness():
    # Load the stanford data to compute the hardness.
    if FLAGS.type == '':
        args = sna.get_args_for_config(FLAGS.config_name + '+bench_' +
                                       FLAGS.imset)
    else:
        args = sna.get_args_for_config(FLAGS.type + '.' + FLAGS.config_name +
                                       '+bench_' + FLAGS.imset)

    args.navtask.logdir = None
    R = lambda: nav_env.get_multiplexer_class(args.navtask, 0)
    R = R()

    # Two fixed-seed RNG streams so environment sampling is reproducible.
    rng_data = [np.random.RandomState(0), np.random.RandomState(0)]

    # Sample a room.
    h_dists = []
    gt_dists = []
    for i in range(250):
        e = R.sample_env(rng_data)
        nodes = e.task.nodes

        # Initialize the agent.
        init_env_state = e.reset(rng_data)

        gt_dist_to_goal = [
            e.episode.dist_to_goal[0][j][s]
            for j, s in enumerate(e.episode.start_node_ids)
        ]

        for j in range(args.navtask.task_params.batch_size):
            start_node_id = e.episode.start_node_ids[j]
            end_node_id = e.episode.goal_node_ids[0][j]
            h_dist = graph_utils.heuristic_fn_vec(
                nodes[[start_node_id], :],
                nodes[[end_node_id], :],
                n_ori=args.navtask.task_params.n_ori,
                step_size=args.navtask.task_params.step_size)[0][0]
            gt_dist = e.episode.dist_to_goal[0][j][start_node_id]
            h_dists.append(h_dist)
            gt_dists.append(gt_dist)

    h_dists = np.array(h_dists)
    gt_dists = np.array(gt_dists)
    e = R.sample_env([np.random.RandomState(0), np.random.RandomState(0)])
    inputs = e.get_common_data()  # `inputs` avoids shadowing the builtin input()
    orig_maps = inputs['orig_maps'][0, 0, :, :, 0]
    return h_dists, gt_dists, orig_maps
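
A minimal usage sketch (not part of the original file; it assumes "hardness"
means the gap between the true geodesic distance and the heuristic
straight-line lower bound computed above):

def _print_hardness_summary():
    h_dists, gt_dists, _ = _compute_hardness()
    # Relative underestimate of the heuristic; larger values = harder episodes.
    hardness = (gt_dists - h_dists) / np.maximum(gt_dists, 1.0)
    print('mean hardness: {:0.3f}'.format(float(np.mean(hardness))))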
Example #3
def plot_trajectory_first_person(dt, orig_maps, out_dir):
  out_dir = os.path.join(out_dir, FLAGS.config_name+_get_suffix_str(),
                         FLAGS.imset)
  fu.makedirs(out_dir)

  # Load the model so that we can render.
  plt.set_cmap('gray')
  samples_per_action = 8
  wait_at_action = 0

  # Requires mencoder; swap in animation.writers['ffmpeg'] if it is unavailable.
  Writer = animation.writers['mencoder']
  writer = Writer(fps=3*(samples_per_action+wait_at_action),
                  metadata=dict(artist='anonymous'), bitrate=1800)

  args = sna.get_args_for_config(FLAGS.config_name + '+bench_'+FLAGS.imset)
  args.navtask.logdir = None
  navtask_ = copy.deepcopy(args.navtask)
  navtask_.camera_param.modalities = ['rgb']
  navtask_.task_params.modalities = ['rgb']
  sz = 512
  navtask_.camera_param.height = sz
  navtask_.camera_param.width = sz
  navtask_.task_params.img_height = sz
  navtask_.task_params.img_width = sz
  R = lambda: nav_env.get_multiplexer_class(navtask_, 0)
  R = R()
  b = R.buildings[0]

  f = [0 for _ in range(wait_at_action)] + \
      [float(k) / samples_per_action for k in range(samples_per_action)]

  # Generate things for it to render.
  inds_to_do = []
  inds_to_do += [1, 4, 10] #1291, 1268, 1273, 1289, 1302, 1426, 1413, 1449, 1399, 1390]

  for i in inds_to_do:
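    # Figure layout: a 3x4 grid whose left three columns hold the first-person
    # view; the right column stacks three auxiliary panels (the bottom one is
    # used below for the top-view map).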
    fig = plt.figure(figsize=(10,8))
    gs = GridSpec(3,4)
    gs.update(wspace=0.05, hspace=0.05, left=0.0, top=0.97, right=1.0, bottom=0.)
    ax = fig.add_subplot(gs[:,:-1])
    ax1 = fig.add_subplot(gs[0,-1])
    ax2 = fig.add_subplot(gs[1,-1])
    ax3 = fig.add_subplot(gs[2,-1])
    axes = [ax, ax1, ax2, ax3]
    # ax = fig.add_subplot(gs[:,:])
    # axes = [ax]
    for ax in axes:
      ax.set_axis_off()

    node_ids = dt['all_node_ids'][i, :, 0]*1  # *1 forces a copy.
    # Prune the trajectory so the last node is not repeated more than 3 times.
    if np.all(node_ids[-4:] == node_ids[-1]):
      while node_ids[-4] == node_ids[-1]:
        node_ids = node_ids[:-1]
    num_steps = np.minimum(FLAGS.num_steps, len(node_ids))

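    # Build smooth camera poses: xyt_diff holds per-step (dx, dy, dtheta);
    # heading differences are wrapped mod 4 (with 3 mapped to -1 for the
    # shorter turn), then scaled by the fractional offsets in `f` to get
    # `samples_per_action` intermediate perturbations per action.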
    xyt = b.to_actual_xyt_vec(b.task.nodes[node_ids])
    xyt_diff = xyt[1:,:] - xyt[:-1:,:]
    xyt_diff[:,2] = np.mod(xyt_diff[:,2], 4)
    ind = np.where(xyt_diff[:,2] == 3)[0]
    xyt_diff[ind, 2] = -1
    xyt_diff = np.expand_dims(xyt_diff, axis=1)
    to_cat = [xyt_diff*_ for _ in f]
    perturbs_all = np.concatenate(to_cat, axis=1)
    perturbs_all = np.concatenate([perturbs_all, np.zeros_like(perturbs_all[:,:,:1])], axis=2)
    node_ids_all = np.expand_dims(node_ids, axis=1)*1
    node_ids_all = np.concatenate([node_ids_all for _ in f], axis=1)
    node_ids_all = np.reshape(node_ids_all[:-1,:], -1)
    perturbs_all = np.reshape(perturbs_all, [-1, 4])
    imgs = b.render_nodes(b.task.nodes[node_ids_all,:], perturb=perturbs_all)

    # Get action at each node.
    actions = []
    _, action_to_nodes = b.get_feasible_actions(node_ids)
    for j in range(num_steps-1):
      action_to_node = action_to_nodes[j]
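      # Invert the action -> node map so we can look up which action moves
      # the agent from node j to node j+1 on the recorded trajectory.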
      node_to_action = dict(zip(action_to_node.values(), action_to_node.keys()))
      actions.append(node_to_action[node_ids[j+1]])

    def init_fn():
      return fig,
    gt_dist_to_goal = []

    # Render trajectories.
    def worker(j):
      # Plot the image.
      step_number = j // (samples_per_action + wait_at_action)  # integer frame -> step
      img = imgs[j]
      ax = axes[0]
      ax.clear()
      ax.set_axis_off()
      img = img.astype(np.uint8)
      ax.imshow(img)
      tt = ax.set_title(
          "First Person View\n" +
          "Top corners show diagnostics (distance, agents' action) not input to agent.",
          fontsize=12)
      plt.setp(tt, color='white')

      # Distance to goal.
      t = 'Dist to Goal:\n{:2d} steps'.format(int(dt['all_d_at_t'][i, step_number]))
      t = ax.text(0.01, 0.99, t,
          horizontalalignment='left',
          verticalalignment='top',
          fontsize=20, color='red',
          transform=ax.transAxes, alpha=1.0)
      t.set_bbox(dict(color='white', alpha=0.85, pad=-0.1))

      # Action to take.
      action_latex = [r'$\odot$ ', r'$\curvearrowright$ ',
                      r'$\curvearrowleft$ ', r'$\Uparrow$ ']
      t = ax.text(0.99, 0.99, action_latex[actions[step_number]],
          horizontalalignment='right',
          verticalalignment='top',
          fontsize=40, color='green',
          transform=ax.transAxes, alpha=1.0)
      t.set_bbox(dict(color='white', alpha=0.85, pad=-0.1))


      # Plot the map top view.
      ax = axes[-1]
      if j == 0:
        # Plot the map
        locs = dt['all_locs'][i,:num_steps,:]
        goal_loc = dt['all_goal_locs'][i,:,:]
        xymin = np.minimum(np.min(goal_loc, axis=0), np.min(locs, axis=0))
        xymax = np.maximum(np.max(goal_loc, axis=0), np.max(locs, axis=0))
        xy1 = (xymax+xymin)/2. - 0.7*np.maximum(np.max(xymax-xymin), 24)
        xy2 = (xymax+xymin)/2. + 0.7*np.maximum(np.max(xymax-xymin), 24)

        ax.set_axis_on()
        ax.patch.set_facecolor((0.333, 0.333, 0.333))
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(orig_maps, origin='lower', vmin=-1.0, vmax=2.0)
        ax.plot(goal_loc[:,0], goal_loc[:,1], 'g*', markersize=12)

        locs = dt['all_locs'][i,:1,:]
        ax.plot(locs[:,0], locs[:,1], 'b.', markersize=12)

        ax.set_xlim([xy1[0], xy2[0]])
        ax.set_ylim([xy1[1], xy2[1]])

      locs = dt['all_locs'][i,step_number,:]
      locs = np.expand_dims(locs, axis=0)
      ax.plot(locs[:,0], locs[:,1], 'r.', alpha=1.0, linewidth=0, markersize=4)
      tt = ax.set_title('Trajectory in topview', fontsize=14)
      plt.setp(tt, color='white')
      return fig,

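    # One animation frame per rendered sub-sample: (num_steps-1) actions,
    # each expanded to (wait_at_action + samples_per_action) frames.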
    line_ani = animation.FuncAnimation(fig, worker,
                                       (num_steps-1)*(wait_at_action+samples_per_action),
                                       interval=500, blit=True, init_func=init_fn)
    tmp_file_name = 'tmp.mp4'
    line_ani.save(tmp_file_name, writer=writer, savefig_kwargs={'facecolor':'black'})
    out_file_name = os.path.join(out_dir, 'vis_{:04d}.mp4'.format(i))
    print(out_file_name)

    if fu.exists(out_file_name):
      gfile.Remove(out_file_name)
    gfile.Copy(tmp_file_name, out_file_name)
    gfile.Remove(tmp_file_name)
    plt.close(fig)
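
A hypothetical driver for the above (an assumption, not in the original file;
`dt` is the saved evaluation dict whose keys -- 'all_node_ids', 'all_locs',
'all_goal_locs', 'all_d_at_t' -- are the ones indexed in worker()):

# _, _, orig_maps = _compute_hardness()
# dt = ...  # load evaluation results with the keys listed above
# plot_trajectory_first_person(dt, orig_maps, out_dir='/tmp/vis')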
def _test(args):
    args.solver.master = ''
    container_name = ""
    checkpoint_dir = args.logdir
    logging.error('Checkpoint_dir: %s', args.logdir)

    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    m = utils.Foo()
    m.tf_graph = tf.Graph()

    rng_data_seed = 0
    rng_action_seed = 0
    R = lambda: nav_env.get_multiplexer_class(args.navtask, rng_data_seed)
    with m.tf_graph.as_default():
        with tf.container(container_name):
            m = args.setup_to_run(
                m,
                args,
                is_training=False,
                batch_norm_is_training=args.control.force_batchnorm_is_training_at_test,
                summary_mode=args.control.test_mode)
            train_step_kwargs = args.setup_train_step_kwargs(
                m,
                R(),
                os.path.join(args.logdir, args.control.test_name),
                rng_seed=rng_data_seed,
                is_chief=True,
                num_steps=args.navtask.task_params.num_steps *
                args.navtask.task_params.num_goals,
                iters=args.summary.test_iters,
                train_display_interval=None,
                dagger_sample_bn_false=args.arch.dagger_sample_bn_false)

            saver = slim.learning.tf_saver.Saver(
                variables.get_variables_to_restore())

            sv = slim.learning.supervisor.Supervisor(
                graph=ops.get_default_graph(),
                logdir=None,
                init_op=m.init_op,
                summary_op=None,
                summary_writer=None,
                global_step=None,
                saver=m.saver_op)

            last_checkpoint = None
            reported = False
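            # Poll for checkpoints: wait_for_new_checkpoint blocks for up to
            # `timeout` seconds and returns None on timeout, so the inner loop
            # retries until a checkpoint newer than `last_checkpoint` appears.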
            while True:
                last_checkpoint_ = None
                while last_checkpoint_ is None:
                    last_checkpoint_ = slim.evaluation.wait_for_new_checkpoint(
                        checkpoint_dir,
                        last_checkpoint,
                        seconds_to_sleep=10,
                        timeout=60)
                if last_checkpoint_ is None: break

                last_checkpoint = last_checkpoint_
                checkpoint_iter = int(
                    os.path.basename(last_checkpoint).split('-')[1])

                if (not args.control.only_eval_when_done
                        or checkpoint_iter >= args.solver.max_steps):
                    start = time.time()
                    logging.info(
                        'Starting evaluation at %s using checkpoint %s.',
                        time.strftime('%Y-%m-%d-%H:%M:%S', time.localtime()),
                        last_checkpoint)

                    with sv.managed_session(
                            args.solver.master,
                            config=config,
                            start_standard_services=False) as sess:
                        sess.run(m.init_op)
                        sv.saver.restore(sess, last_checkpoint)
                        sv.start_queue_runners(sess)
                        if args.control.reset_rng_seed:
                            train_step_kwargs['rng_data'] = [
                                np.random.RandomState(rng_data_seed),
                                np.random.RandomState(rng_data_seed)
                            ]
                            train_step_kwargs[
                                'rng_action'] = np.random.RandomState(
                                    rng_action_seed)
                        vals, _ = tf_utils.train_step_custom_online_sampling(
                            sess,
                            None,
                            m.global_step_op,
                            train_step_kwargs,
                            mode=args.control.test_mode)
                        should_stop = False

                        if checkpoint_iter >= args.solver.max_steps:
                            should_stop = True

                        if should_stop:
                            break
def _train(args):
    container_name = ""
    # Temporary settings (TRI):
    args.solver.max_steps = 500000
    args.solver.steps_per_decay = 50000
    args.solver.initial_learning_rate = 1e-8
    args.navtask.task_params.batch_size = 32

    R = lambda: nav_env.get_multiplexer_class(args.navtask, args.solver.task)
    m = utils.Foo()
    m_cloned = utils.Foo()
    m.tf_graph = tf.Graph()

    # Tri: add a cloned building object for checking the exploration result
    # during training.
    # m.cloned_obj = R()
    m.batch_size = args.navtask.task_params.batch_size
    m.train_type = 1
    m.suffle = False
    m.is_first_step = True
    m.save_pic_step = 10000
    m.save_pic_count = 0
    m.save_reward_step = 500
    m.save_reward_count = 0

    m.is_main = True
    m_cloned.is_main = False

    config = tf.ConfigProto()
    config.device_count['GPU'] = 1

    with m.tf_graph.as_default():
        with tf.device(
                tf.train.replica_device_setter(args.solver.ps_tasks,
                                               merge_devices=True)):
            with tf.container(container_name):
                m = args.setup_to_run(m,
                                      args,
                                      is_training=True,
                                      batch_norm_is_training=True,
                                      summary_mode='train')

                #with tf.name_scope('cloned'):
                m_cloned.x = m.x
                m_cloned.vars_to_restore = m.vars_to_restore
                m_cloned.batch_norm_is_training_op = m.batch_norm_is_training_op
                m_cloned.input_tensors = m.input_tensors
                with tf.variable_scope('cloned'):
                    m_cloned = args.setup_to_run(m_cloned,
                                                 args,
                                                 is_training=True,
                                                 batch_norm_is_training=True,
                                                 summary_mode='train')
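                # Presumably clonemodel mirrors the main model's weights into
                # the 'cloned' scope and set_copying_ops builds ops to copy
                # them across (inferred from the names; defined elsewhere).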
                clonemodel(m, m_cloned)
                set_copying_ops(m)
                m.init_op = tf.group(tf.global_variables_initializer(),
                                     tf.local_variables_initializer())
                m.saver_op = tf.train.Saver(keep_checkpoint_every_n_hours=6,
                                            write_version=tf.train.SaverDef.V2)
                train_step_kwargs = args.setup_train_step_kwargs(
                    m,
                    R(),
                    os.path.join(args.logdir, 'train'),
                    rng_seed=args.solver.task,
                    is_chief=args.solver.task == 0,
                    num_steps=args.navtask.task_params.num_steps *
                    args.navtask.task_params.num_goals,
                    iters=1,
                    train_display_interval=args.summary.display_interval,
                    dagger_sample_bn_false=args.arch.dagger_sample_bn_false)
                delay_start = (args.solver.task *
                               (args.solver.task + 1)) // 2 * FLAGS.delay_start_iters
                logging.error('Delaying start for task %d by %d steps.',
                              args.solver.task, delay_start)

                # Tri: generate probe data for checking the learning process
                # during training (the commented lines sample fixed episodes).
                # obj = train_step_kwargs['obj']
                rng_data = train_step_kwargs['rng_data']
                # m.e1 = obj.sample_env(rng_data)
                # m.init_env_state1 = m.e1.reset(rng_data)
                # m.e2 = obj.sample_env(rng_data)
                # m.init_env_state2 = m.e2.reset(rng_data)
                m.rng_data = deepcopy(rng_data)

                additional_args = {}
                final_loss = slim.learning.train(
                    train_op=m.train_op,
                    logdir=args.logdir,
                    master=args.solver.master,
                    is_chief=args.solver.task == 0,
                    number_of_steps=args.solver.max_steps,
                    train_step_fn=tf_utils.train_step_custom_online_sampling,
                    train_step_kwargs=train_step_kwargs,
                    global_step=m.global_step_op,
                    init_op=m.init_op,
                    init_fn=m.init_fn,
                    sync_optimizer=m.sync_optimizer,
                    saver=m.saver_op,
                    startup_delay_steps=delay_start,
                    summary_op=None,
                    session_config=config,
                    **additional_args)