Example 1
def setup_to_run(m, args, is_training, batch_norm_is_training, summary_mode):
  assert(args.arch.multi_scale), 'removed support for old single scale code.'
  # Set up the model.
  tf.set_random_seed(args.solver.seed)
  task_params = args.navtask.task_params

  batch_norm_is_training_op = \
      tf.placeholder_with_default(batch_norm_is_training, shape=[],
                                  name='batch_norm_is_training_op') 

  # Set up the inputs.
  m.input_tensors = {}
  m.train_ops = {}
  m.input_tensors['common'], m.input_tensors['step'], m.input_tensors['train_bkp'] = \
      _inputs(task_params)

  m.init_fn = None

  if task_params.input_type == 'vision':
    m.vision_ops = get_map_from_images(
        m.input_tensors['step']['imgs'], args.mapper_arch,
        task_params, args.solver.freeze_conv,
        args.solver.wt_decay, is_training, batch_norm_is_training_op,
        num_maps=len(task_params.map_crop_sizes))

    # Load variables from snapshot if needed.
    if args.solver.pretrained_path is not None:
      m.init_fn = slim.assign_from_checkpoint_fn(args.solver.pretrained_path,
                                                 m.vision_ops.vars_to_restore)

    # Set up caching of vision features if needed.
    if args.solver.freeze_conv:
      m.train_ops['step_data_cache'] = [m.vision_ops.encoder_output]
    else:
      m.train_ops['step_data_cache'] = []

    # Set up blobs that are needed for the computation in the rest of the graph.
    m.ego_map_ops = m.vision_ops.fss_logits
    m.coverage_ops = m.vision_ops.confs_probs
    
    # Zero pad these to make them the same size as what the planner expects.
    for i in range(len(m.ego_map_ops)):
      if args.mapper_arch.pad_map_with_zeros_each[i] > 0:
        paddings = np.zeros((5,2), dtype=np.int32)
        paddings[2:4,:] = args.mapper_arch.pad_map_with_zeros_each[i]
        paddings_op = tf.constant(paddings, dtype=tf.int32)
        m.ego_map_ops[i] = tf.pad(m.ego_map_ops[i], paddings=paddings_op)
        m.coverage_ops[i] = tf.pad(m.coverage_ops[i], paddings=paddings_op)
  
  elif task_params.input_type == 'analytical_counts':
    m.ego_map_ops = []; m.coverage_ops = []
    for i in range(len(task_params.map_crop_sizes)):
      ego_map_op = m.input_tensors['step']['analytical_counts_{:d}'.format(i)]
      coverage_op = tf.cast(tf.greater_equal(
          tf.reduce_max(ego_map_op, reduction_indices=[4],
                        keep_dims=True), 1), tf.float32)
      coverage_op = tf.ones_like(ego_map_op) * coverage_op
      m.ego_map_ops.append(ego_map_op)
      m.coverage_ops.append(coverage_op)
      m.train_ops['step_data_cache'] = []
  
  num_steps = task_params.num_steps
  num_goals = task_params.num_goals

  map_crop_size_ops = []
  for map_crop_size in task_params.map_crop_sizes:
    map_crop_size_ops.append(tf.constant(map_crop_size, dtype=tf.int32, shape=(2,)))

  with tf.name_scope('check_size'):
    is_single_step = tf.equal(tf.unstack(tf.shape(m.ego_map_ops[0]), num=5)[1], 1)

  fr_ops = []; value_ops = [];
  fr_intermediate_ops = []; value_intermediate_ops = [];
  crop_value_ops = [];
  resize_crop_value_ops = [];
  confs = []; occupancys = [];

  previous_value_op = None
  updated_state = []; state_names = [];

  for i in range(len(task_params.map_crop_sizes)):
    map_crop_size = task_params.map_crop_sizes[i]
    with tf.variable_scope('scale_{:d}'.format(i)): 
      # Accumulate the map.
      fn = lambda ns: running_combine(
             m.ego_map_ops[i],
             m.coverage_ops[i],
             m.input_tensors['step']['incremental_locs'] * task_params.map_scales[i],
             m.input_tensors['step']['incremental_thetas'],
             m.input_tensors['step']['running_sum_num_{:d}'.format(i)],
             m.input_tensors['step']['running_sum_denom_{:d}'.format(i)],
             m.input_tensors['step']['running_max_denom_{:d}'.format(i)],
             map_crop_size, ns)

      running_sum_num, running_sum_denom, running_max_denom = \
          tf.cond(is_single_step, lambda: fn(1), lambda: fn(num_steps*num_goals))
      updated_state += [running_sum_num, running_sum_denom, running_max_denom]
      state_names += ['running_sum_num_{:d}'.format(i),
                      'running_sum_denom_{:d}'.format(i),
                      'running_max_denom_{:d}'.format(i)]

      # Compute occupancy and confidence from the accumulated map.
      occupancy = running_sum_num / tf.maximum(running_sum_denom, 0.001)
      conf = running_max_denom
      # print occupancy.get_shape().as_list()

      # Concatenate the occupancy, the confidence, and the goal.
      with tf.name_scope('concat'):
        sh = [-1, map_crop_size, map_crop_size, task_params.map_channels]
        occupancy = tf.reshape(occupancy, shape=sh)
        conf = tf.reshape(conf, shape=sh)

        sh = [-1, map_crop_size, map_crop_size, task_params.goal_channels]
        goal = tf.reshape(m.input_tensors['step']['ego_goal_imgs_{:d}'.format(i)], shape=sh)
        to_concat = [occupancy, conf, goal]

        if previous_value_op is not None:
          to_concat.append(previous_value_op)

        x = tf.concat(to_concat, 3)

      # Pass the map, the previous value map, and the goal through a few
      # convolutional layers to get fR.
      fr_op, fr_intermediate_op = fr_v2(
         x, output_neurons=args.arch.fr_neurons,
         inside_neurons=args.arch.fr_inside_neurons,
         is_training=batch_norm_is_training_op, name='fr',
         wt_decay=args.solver.wt_decay, stride=args.arch.fr_stride)

      # Do Value Iteration on the fR
      if args.arch.vin_num_iters > 0:
        value_op, value_intermediate_op = value_iteration_network(
            fr_op, num_iters=args.arch.vin_num_iters,
            val_neurons=args.arch.vin_val_neurons,
            action_neurons=args.arch.vin_action_neurons,
            kernel_size=args.arch.vin_ks, share_wts=args.arch.vin_share_wts,
            name='vin', wt_decay=args.solver.wt_decay)
      else:
        value_op = fr_op
        value_intermediate_op = []

      # Crop the value map; it will be upsampled to feed the next scale.
      remove = args.arch.crop_remove_each
      if remove > 0:
        crop_value_op = value_op[:, remove:-remove, remove:-remove,:]
      else:
        crop_value_op = value_op
      crop_value_op = tf.reshape(crop_value_op, shape=[-1, args.arch.value_crop_size,
                                                       args.arch.value_crop_size,
                                                       args.arch.vin_val_neurons])
      if i < len(task_params.map_crop_sizes)-1:
        # Resize it to the shape of the next scale.
        previous_value_op = tf.image.resize_bilinear(crop_value_op,
                                                     map_crop_size_ops[i+1],
                                                     align_corners=True)
        resize_crop_value_ops.append(previous_value_op)
      
      occupancys.append(occupancy)
      confs.append(conf)
      value_ops.append(value_op)
      crop_value_ops.append(crop_value_op)
      fr_ops.append(fr_op)
      fr_intermediate_ops.append(fr_intermediate_op)
  
  m.value_ops = value_ops
  m.value_intermediate_ops = value_intermediate_ops
  m.fr_ops = fr_ops
  m.fr_intermediate_ops = fr_intermediate_ops
  m.final_value_op = crop_value_op
  m.crop_value_ops = crop_value_ops
  m.resize_crop_value_ops = resize_crop_value_ops
  m.confs = confs
  m.occupancys = occupancys

  sh = [-1, args.arch.vin_val_neurons*((args.arch.value_crop_size)**2)]
  m.value_features_op = tf.reshape(m.final_value_op, sh, name='reshape_value_op')
  
  # Determine what action to take.
  with tf.variable_scope('action_pred'):
    batch_norm_param = args.arch.pred_batch_norm_param
    if batch_norm_param is not None:
      batch_norm_param['is_training'] = batch_norm_is_training_op
    m.action_logits_op, _ = tf_utils.fc_network(
        m.value_features_op, neurons=args.arch.pred_neurons,
        wt_decay=args.solver.wt_decay, name='pred', offset=0,
        num_pred=task_params.num_actions,
        batch_norm_param=batch_norm_param) 
    m.action_prob_op = tf.nn.softmax(m.action_logits_op)

  init_state = tf.constant(0., dtype=tf.float32, shape=[
      task_params.batch_size, 1, map_crop_size, map_crop_size,
      task_params.map_channels])

  m.train_ops['state_names'] = state_names
  m.train_ops['updated_state'] = updated_state
  m.train_ops['init_state'] = [init_state for _ in updated_state]

  m.train_ops['step'] = m.action_prob_op
  m.train_ops['common'] = [m.input_tensors['common']['orig_maps'],
                           m.input_tensors['common']['goal_loc']]
  m.train_ops['batch_norm_is_training_op'] = batch_norm_is_training_op
  m.loss_ops = []; m.loss_ops_names = [];

  if args.arch.readout_maps:
    with tf.name_scope('readout_maps'):
      all_occupancys = tf.concat(m.occupancys + m.confs, 3)
      readout_maps, probs = readout_general(
          all_occupancys, num_neurons=args.arch.rom_arch.num_neurons,
          strides=args.arch.rom_arch.strides, 
          layers_per_block=args.arch.rom_arch.layers_per_block, 
          kernel_size=args.arch.rom_arch.kernel_size,
          batch_norm_is_training_op=batch_norm_is_training_op,
          wt_decay=args.solver.wt_decay)

      gt_ego_maps = [m.input_tensors['step']['readout_maps_{:d}'.format(i)]
                     for i in range(len(task_params.readout_maps_crop_sizes))]
      m.readout_maps_gt = tf.concat(gt_ego_maps, 4)
      gt_shape = tf.shape(m.readout_maps_gt)
      m.readout_maps_logits = tf.reshape(readout_maps, gt_shape)
      m.readout_maps_probs = tf.reshape(probs, gt_shape)

      # Add a loss op
      m.readout_maps_loss_op = tf.losses.sigmoid_cross_entropy(
          tf.reshape(m.readout_maps_gt, [-1, len(task_params.readout_maps_crop_sizes)]), 
          tf.reshape(readout_maps, [-1, len(task_params.readout_maps_crop_sizes)]),
          scope='loss')
      m.readout_maps_loss_op = 10.*m.readout_maps_loss_op

  ewma_decay = 0.99 if is_training else 0.0
  weight = tf.ones_like(m.input_tensors['train_bkp']['action'], dtype=tf.float32,
                        name='weight')
  m.reg_loss_op, m.data_loss_op, m.total_loss_op, m.acc_ops = \
    compute_losses_multi_or(m.action_logits_op,
                            m.input_tensors['train_bkp']['action'], weights=weight,
                            num_actions=task_params.num_actions,
                            data_loss_wt=args.solver.data_loss_wt,
                            reg_loss_wt=args.solver.reg_loss_wt,
                            ewma_decay=ewma_decay)
  
  if args.arch.readout_maps:
    m.total_loss_op = m.total_loss_op + m.readout_maps_loss_op
    m.loss_ops += [m.readout_maps_loss_op]
    m.loss_ops_names += ['readout_maps_loss']

  m.loss_ops += [m.reg_loss_op, m.data_loss_op, m.total_loss_op]
  m.loss_ops_names += ['reg_loss', 'data_loss', 'total_loss']

  if args.solver.freeze_conv:
    vars_to_optimize = list(set(tf.trainable_variables()) -
                            set(m.vision_ops.vars_to_restore))
  else:
    vars_to_optimize = None

  m.lr_op, m.global_step_op, m.train_op, m.should_stop_op, m.optimizer, \
  m.sync_optimizer = tf_utils.setup_training(
      m.total_loss_op, 
      args.solver.initial_learning_rate, 
      args.solver.steps_per_decay,
      args.solver.learning_rate_decay, 
      args.solver.momentum,
      args.solver.max_steps, 
      args.solver.sync, 
      args.solver.adjust_lr_sync,
      args.solver.num_workers, 
      args.solver.task,
      vars_to_optimize=vars_to_optimize,
      clip_gradient_norm=args.solver.clip_gradient_norm,
      typ=args.solver.typ, momentum2=args.solver.momentum2,
      adam_eps=args.solver.adam_eps)

  if args.arch.sample_gt_prob_type == 'inverse_sigmoid_decay':
    m.sample_gt_prob_op = tf_utils.inverse_sigmoid_decay(args.arch.isd_k,
                                                         m.global_step_op)
  elif args.arch.sample_gt_prob_type == 'zero':
    m.sample_gt_prob_op = tf.constant(-1.0, dtype=tf.float32)

  elif args.arch.sample_gt_prob_type.split('_')[0] == 'step':
    step = int(args.arch.sample_gt_prob_type.split('_')[1])
    m.sample_gt_prob_op = tf_utils.step_gt_prob(
        step, m.input_tensors['step']['step_number'][0,0,0])

  m.sample_action_type = args.arch.action_sample_type
  m.sample_action_combine_type = args.arch.action_sample_combine_type

  m.summary_ops = {
      summary_mode: _add_summaries(m, args, summary_mode,
                                   args.summary.arop_full_summary_iters)}

  m.init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
  m.saver_op = tf.train.Saver(keep_checkpoint_every_n_hours=4,
                              write_version=tf.train.SaverDef.V2)
  return m
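
The sketch below is not part of the example above; it shows one way the returned container might be driven in a TF 1.x session. The model `m` and the per-step feed dict are assumed to come from the surrounding training harness, and only attributes created by setup_to_run are touched.

import tensorflow as tf

def initialize_model(m, sess):
  # Initialize all variables and, if a pretrained snapshot was configured,
  # restore the convolutional weights through the slim-generated init_fn.
  sess.run(m.init_op)
  if m.init_fn is not None:
    m.init_fn(sess)

def train_one_step(m, sess, feed_dict):
  # Run a single parameter update; feed_dict is assumed to map the
  # placeholders built by _inputs() to numpy batches prepared elsewhere.
  _, total_loss, step = sess.run(
      [m.train_op, m.total_loss_op, m.global_step_op], feed_dict=feed_dict)
  return total_loss, step
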
Example 2
def setup_to_run(m, args, is_training, batch_norm_is_training, summary_mode):
    # Set up the model.
    tf.set_random_seed(args.solver.seed)
    task_params = args.navtask.task_params
    num_steps = task_params.num_steps
    num_goals = task_params.num_goals
    num_actions = task_params.num_actions
    num_actions_ = num_actions

    n_views = task_params.n_views

    batch_norm_is_training_op = \
        tf.placeholder_with_default(batch_norm_is_training, shape=[],
                                    name='batch_norm_is_training_op')
    # Set up the inputs.
    m.input_tensors = {}
    lstm_states = []
    lstm_state_dims = []
    state_names = []
    updated_state_ops = []
    init_state_ops = []
    if args.arch.lstm_output:
        lstm_states += ['lstm_output']
        lstm_state_dims += [
            args.arch.lstm_output_dim + task_params.num_actions
        ]
    if args.arch.lstm_ego:
        lstm_states += ['lstm_ego']
        lstm_state_dims += [args.arch.lstm_ego_dim + args.arch.lstm_ego_out]
        lstm_states += ['lstm_img']
        lstm_state_dims += [args.arch.lstm_img_dim + args.arch.lstm_img_out]
    elif args.arch.lstm_img:
        # An LSTM only on the image
        lstm_states += ['lstm_img']
        lstm_state_dims += [args.arch.lstm_img_dim + args.arch.lstm_img_out]
    else:
        # No LSTMs involved here.
        pass

    m.input_tensors['common'], m.input_tensors['step'], m.input_tensors['train'] = \
        _inputs(task_params, lstm_states, lstm_state_dims)

    with tf.name_scope('check_size'):
        is_single_step = tf.equal(
            tf.unstack(tf.shape(m.input_tensors['step']['imgs']), num=6)[1], 1)

    images_reshaped = tf.reshape(m.input_tensors['step']['imgs'],
                                 shape=[
                                     -1, task_params.img_height,
                                     task_params.img_width,
                                     task_params.img_channels
                                 ],
                                 name='re_image')

    rel_goal_loc_reshaped = tf.reshape(
        m.input_tensors['step']['rel_goal_loc'],
        shape=[-1, task_params.rel_goal_loc_dim],
        name='re_rel_goal_loc')

    x, vars_ = get_repr_from_image(images_reshaped, task_params.modalities,
                                   task_params.data_augment, args.arch.encoder,
                                   args.solver.freeze_conv,
                                   args.solver.wt_decay, is_training)

    # Reshape so that features can be accumulated over time steps for faster
    # backprop.
    sh_before = x.get_shape().as_list()
    m.encoder_output = tf.reshape(x,
                                  shape=[task_params.batch_size, -1, n_views] +
                                  sh_before[1:])
    x = tf.reshape(m.encoder_output, shape=[-1] + sh_before[1:])

    # Add a layer to reduce dimensions for a fc layer.
    if args.arch.dim_reduce_neurons > 0:
        ks = 1
        neurons = args.arch.dim_reduce_neurons
        init_var = np.sqrt(2.0 / (ks**2) / neurons)
        batch_norm_param = args.arch.batch_norm_param
        batch_norm_param['is_training'] = batch_norm_is_training_op
        m.conv_feat = slim.conv2d(
            x,
            neurons,
            kernel_size=ks,
            stride=1,
            normalizer_fn=slim.batch_norm,
            normalizer_params=batch_norm_param,
            padding='SAME',
            scope='dim_reduce',
            weights_regularizer=slim.l2_regularizer(args.solver.wt_decay),
            weights_initializer=tf.random_normal_initializer(stddev=init_var))
        reshape_conv_feat = slim.flatten(m.conv_feat)
        sh = reshape_conv_feat.get_shape().as_list()
        m.reshape_conv_feat = tf.reshape(reshape_conv_feat,
                                         shape=[-1, sh[1] * n_views])

    # Restore these from a checkpoint.
    if args.solver.pretrained_path is not None:
        m.init_fn = slim.assign_from_checkpoint_fn(args.solver.pretrained_path,
                                                   vars_)
    else:
        m.init_fn = None

    # Pass the goal location through a stack of fully connected layers to
    # embed it.
    with tf.variable_scope('embed_goal'):
        batch_norm_param = args.arch.batch_norm_param
        batch_norm_param['is_training'] = batch_norm_is_training_op
        m.embed_goal, _ = tf_utils.fc_network(
            rel_goal_loc_reshaped,
            neurons=args.arch.goal_embed_neurons,
            wt_decay=args.solver.wt_decay,
            name='goal_embed',
            offset=0,
            batch_norm_param=batch_norm_param,
            dropout_ratio=args.arch.fc_dropout,
            is_training=is_training)

    if args.arch.embed_goal_for_state:
        with tf.variable_scope('embed_goal_for_state'):
            batch_norm_param = args.arch.batch_norm_param
            batch_norm_param['is_training'] = batch_norm_is_training_op
            m.embed_goal_for_state, _ = tf_utils.fc_network(
                m.input_tensors['common']['rel_goal_loc_at_start'][:, 0, :],
                neurons=args.arch.goal_embed_neurons,
                wt_decay=args.solver.wt_decay,
                name='goal_embed',
                offset=0,
                batch_norm_param=batch_norm_param,
                dropout_ratio=args.arch.fc_dropout,
                is_training=is_training)

    # Pass the image features through a stack of fully connected layers to
    # embed them.
    with tf.variable_scope('embed_img'):
        batch_norm_param = args.arch.batch_norm_param
        batch_norm_param['is_training'] = batch_norm_is_training_op
        m.embed_img, _ = tf_utils.fc_network(
            m.reshape_conv_feat,
            neurons=args.arch.img_embed_neurons,
            wt_decay=args.solver.wt_decay,
            name='img_embed',
            offset=0,
            batch_norm_param=batch_norm_param,
            dropout_ratio=args.arch.fc_dropout,
            is_training=is_training)

    # For lstm_ego and lstm_img, embed the ego motion, accumulate it in an
    # LSTM, combine it with the image features and accumulate those in an
    # LSTM. Finally, combine the image LSTM output with the goal to output an
    # action.
    if args.arch.lstm_ego:
        ego_reshaped = preprocess_egomotion(
            m.input_tensors['step']['incremental_locs'],
            m.input_tensors['step']['incremental_thetas'])
        with tf.variable_scope('embed_ego'):
            batch_norm_param = args.arch.batch_norm_param
            batch_norm_param['is_training'] = batch_norm_is_training_op
            m.embed_ego, _ = tf_utils.fc_network(
                ego_reshaped,
                neurons=args.arch.ego_embed_neurons,
                wt_decay=args.solver.wt_decay,
                name='ego_embed',
                offset=0,
                batch_norm_param=batch_norm_param,
                dropout_ratio=args.arch.fc_dropout,
                is_training=is_training)

        state_name, state_init_op, updated_state_op, out_op = lstm_setup(
            'lstm_ego', m.embed_ego, task_params.batch_size, is_single_step,
            args.arch.lstm_ego_dim, args.arch.lstm_ego_out,
            num_steps * num_goals, m.input_tensors['step']['lstm_ego'])
        state_names += [state_name]
        init_state_ops += [state_init_op]
        updated_state_ops += [updated_state_op]

        # Combine the output with the vision features.
        m.img_ego_op = combine_setup('img_ego', args.arch.combine_type_ego,
                                     m.embed_img, out_op,
                                     args.arch.img_embed_neurons[-1],
                                     args.arch.lstm_ego_out)

        # LSTM on these vision features.
        state_name, state_init_op, updated_state_op, out_op = lstm_setup(
            'lstm_img', m.img_ego_op, task_params.batch_size, is_single_step,
            args.arch.lstm_img_dim, args.arch.lstm_img_out,
            num_steps * num_goals, m.input_tensors['step']['lstm_img'])
        state_names += [state_name]
        init_state_ops += [state_init_op]
        updated_state_ops += [updated_state_op]

        m.img_for_goal = out_op
        num_img_for_goal_neurons = args.arch.lstm_img_out

    elif args.arch.lstm_img:
        # LSTM on just the image features.
        state_name, state_init_op, updated_state_op, out_op = lstm_setup(
            'lstm_img', m.embed_img, task_params.batch_size, is_single_step,
            args.arch.lstm_img_dim, args.arch.lstm_img_out,
            num_steps * num_goals, m.input_tensors['step']['lstm_img'])
        state_names += [state_name]
        init_state_ops += [state_init_op]
        updated_state_ops += [updated_state_op]
        m.img_for_goal = out_op
        num_img_for_goal_neurons = args.arch.lstm_img_out

    else:
        m.img_for_goal = m.embed_img
        num_img_for_goal_neurons = args.arch.img_embed_neurons[-1]

    if args.arch.use_visit_count:
        m.embed_visit_count = visit_count_fc(
            m.input_tensors['step']['visit_count'],
            m.input_tensors['step']['last_visit'],
            args.arch.goal_embed_neurons,
            args.solver.wt_decay,
            args.arch.fc_dropout,
            is_training=is_training)
        m.embed_goal = m.embed_goal + m.embed_visit_count

    m.combined_f = combine_setup('img_goal', args.arch.combine_type,
                                 m.img_for_goal, m.embed_goal,
                                 num_img_for_goal_neurons,
                                 args.arch.goal_embed_neurons[-1])

    # LSTM on the combined representation.
    if args.arch.lstm_output:
        name = 'lstm_output'
        # A few fully connected layers here.
        with tf.variable_scope('action_pred'):
            batch_norm_param = args.arch.batch_norm_param
            batch_norm_param['is_training'] = batch_norm_is_training_op
            x, _ = tf_utils.fc_network(m.combined_f,
                                       neurons=args.arch.pred_neurons,
                                       wt_decay=args.solver.wt_decay,
                                       name='pred',
                                       offset=0,
                                       batch_norm_param=batch_norm_param,
                                       dropout_ratio=args.arch.fc_dropout)

        if args.arch.lstm_output_init_state_from_goal:
            # Use the goal embedding to initialize the LSTM state.
            # UGLY, KLUDGY HACK: if this is computing a single time step, no
            # backprop is involved, so we can use the state input from the
            # feed dict; otherwise we compute the state representation from
            # the goal and feed that in. Necessary for using the goal location
            # to generate the state representation.
            m.embed_goal_for_state = tf.expand_dims(m.embed_goal_for_state,
                                                    dim=1)
            state_op = tf.cond(is_single_step,
                               lambda: m.input_tensors['step'][name],
                               lambda: m.embed_goal_for_state)
            state_name, state_init_op, updated_state_op, out_op = lstm_setup(
                name, x, task_params.batch_size, is_single_step,
                args.arch.lstm_output_dim, num_actions_, num_steps * num_goals,
                state_op)
            init_state_ops += [m.embed_goal_for_state]
        else:
            state_op = m.input_tensors['step'][name]
            state_name, state_init_op, updated_state_op, out_op = lstm_setup(
                name, x, task_params.batch_size, is_single_step,
                args.arch.lstm_output_dim, num_actions_, num_steps * num_goals,
                state_op)
            init_state_ops += [state_init_op]

        state_names += [state_name]
        updated_state_ops += [updated_state_op]

        out_op = tf.reshape(out_op, shape=[-1, num_actions_])
        if num_actions_ > num_actions:
            m.action_logits_op = out_op[:, :num_actions]
            m.baseline_op = out_op[:, num_actions:]
        else:
            m.action_logits_op = out_op
            m.baseline_op = None
        m.action_prob_op = tf.nn.softmax(m.action_logits_op)

    else:
        # A few fully connected layers here.
        with tf.variable_scope('action_pred'):
            batch_norm_param = args.arch.batch_norm_param
            batch_norm_param['is_training'] = batch_norm_is_training_op
            out_op, _ = tf_utils.fc_network(m.combined_f,
                                            neurons=args.arch.pred_neurons,
                                            wt_decay=args.solver.wt_decay,
                                            name='pred',
                                            offset=0,
                                            num_pred=num_actions_,
                                            batch_norm_param=batch_norm_param,
                                            dropout_ratio=args.arch.fc_dropout,
                                            is_training=is_training)
            if num_actions_ > num_actions:
                m.action_logits_op = out_op[:, :num_actions]
                m.baseline_op = out_op[:, num_actions:]
            else:
                m.action_logits_op = out_op
                m.baseline_op = None
            m.action_prob_op = tf.nn.softmax(m.action_logits_op)

    m.train_ops = {}
    m.train_ops['step'] = m.action_prob_op
    m.train_ops['common'] = [
        m.input_tensors['common']['orig_maps'],
        m.input_tensors['common']['goal_loc'],
        m.input_tensors['common']['rel_goal_loc_at_start']
    ]
    m.train_ops['state_names'] = state_names
    m.train_ops['init_state'] = init_state_ops
    m.train_ops['updated_state'] = updated_state_ops
    m.train_ops['batch_norm_is_training_op'] = batch_norm_is_training_op

    # Flat list of ops which cache the step data.
    m.train_ops['step_data_cache'] = [tf.no_op()]

    if args.solver.freeze_conv:
        m.train_ops['step_data_cache'] = [m.encoder_output]
    else:
        m.train_ops['step_data_cache'] = []

    ewma_decay = 0.99 if is_training else 0.0
    weight = tf.ones_like(m.input_tensors['train']['action'],
                          dtype=tf.float32,
                          name='weight')

    m.reg_loss_op, m.data_loss_op, m.total_loss_op, m.acc_ops = \
      compute_losses_multi_or(
          m.action_logits_op, m.input_tensors['train']['action'],
          weights=weight, num_actions=num_actions,
          data_loss_wt=args.solver.data_loss_wt,
          reg_loss_wt=args.solver.reg_loss_wt, ewma_decay=ewma_decay)

    if args.solver.freeze_conv:
        vars_to_optimize = list(set(tf.trainable_variables()) - set(vars_))
    else:
        vars_to_optimize = None

    m.lr_op, m.global_step_op, m.train_op, m.should_stop_op, m.optimizer, \
    m.sync_optimizer = tf_utils.setup_training(
        m.total_loss_op,
        args.solver.initial_learning_rate,
        args.solver.steps_per_decay,
        args.solver.learning_rate_decay,
        args.solver.momentum,
        args.solver.max_steps,
        args.solver.sync,
        args.solver.adjust_lr_sync,
        args.solver.num_workers,
        args.solver.task,
        vars_to_optimize=vars_to_optimize,
        clip_gradient_norm=args.solver.clip_gradient_norm,
        typ=args.solver.typ, momentum2=args.solver.momentum2,
        adam_eps=args.solver.adam_eps)

    if args.arch.sample_gt_prob_type == 'inverse_sigmoid_decay':
        m.sample_gt_prob_op = tf_utils.inverse_sigmoid_decay(
            args.arch.isd_k, m.global_step_op)
    elif args.arch.sample_gt_prob_type == 'zero':
        m.sample_gt_prob_op = tf.constant(-1.0, dtype=tf.float32)
    elif args.arch.sample_gt_prob_type.split('_')[0] == 'step':
        step = int(args.arch.sample_gt_prob_type.split('_')[1])
        m.sample_gt_prob_op = tf_utils.step_gt_prob(
            step, m.input_tensors['step']['step_number'][0, 0, 0])

    m.sample_action_type = args.arch.action_sample_type
    m.sample_action_combine_type = args.arch.action_sample_combine_type
    _add_summaries(m, summary_mode, args.summary.arop_full_summary_iters)

    m.init_op = tf.group(tf.global_variables_initializer(),
                         tf.local_variables_initializer())
    m.saver_op = tf.train.Saver(keep_checkpoint_every_n_hours=4,
                                write_version=tf.train.SaverDef.V2)

    return m
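
A small self-contained sketch of the tf.placeholder_with_default pattern used for batch_norm_is_training_op in these examples: the op evaluates to the Python-level flag by default, but any individual sess.run can override it through the feed dict (for instance to force inference-mode batch norm during rollouts). Names below are illustrative only.

import tensorflow as tf

batch_norm_is_training = True
bn_is_training_op = tf.placeholder_with_default(
    batch_norm_is_training, shape=[], name='batch_norm_is_training_op')

with tf.Session() as sess:
  print(sess.run(bn_is_training_op))  # True, the default
  print(sess.run(bn_is_training_op,
                 feed_dict={bn_is_training_op: False}))  # False, overridden
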
Example 3
def setup_to_run(m, args, is_training, batch_norm_is_training, summary_mode):
  # Set up the model.
  tf.set_random_seed(args.solver.seed)
  task_params = args.navtask.task_params
  num_steps = task_params.num_steps
  num_goals = task_params.num_goals
  num_actions = task_params.num_actions
  num_actions_ = num_actions

  n_views = task_params.n_views

  batch_norm_is_training_op = \
      tf.placeholder_with_default(batch_norm_is_training, shape=[],
                                  name='batch_norm_is_training_op') 
  # Set up the inputs.
  m.input_tensors = {}
  lstm_states = []; lstm_state_dims = [];
  state_names = []; updated_state_ops = []; init_state_ops = [];
  if args.arch.lstm_output:
    lstm_states += ['lstm_output']
    lstm_state_dims += [args.arch.lstm_output_dim+task_params.num_actions]
  if args.arch.lstm_ego:
    lstm_states += ['lstm_ego']
    lstm_state_dims += [args.arch.lstm_ego_dim + args.arch.lstm_ego_out]
    lstm_states += ['lstm_img']
    lstm_state_dims += [args.arch.lstm_img_dim + args.arch.lstm_img_out]
  elif args.arch.lstm_img:
    # An LSTM only on the image
    lstm_states += ['lstm_img']
    lstm_state_dims += [args.arch.lstm_img_dim + args.arch.lstm_img_out]
  else:
    # No LSTMs involved here.
    pass

  m.input_tensors['common'], m.input_tensors['step'], m.input_tensors['train'] = \
      _inputs(task_params, lstm_states, lstm_state_dims)

  with tf.name_scope('check_size'):
    is_single_step = tf.equal(tf.unstack(tf.shape(m.input_tensors['step']['imgs']), 
                                        num=6)[1], 1)

  images_reshaped = tf.reshape(m.input_tensors['step']['imgs'], 
      shape=[-1, task_params.img_height, task_params.img_width,
             task_params.img_channels], name='re_image')

  rel_goal_loc_reshaped = tf.reshape(m.input_tensors['step']['rel_goal_loc'], 
      shape=[-1, task_params.rel_goal_loc_dim], name='re_rel_goal_loc')

  x, vars_ = get_repr_from_image(
      images_reshaped, task_params.modalities, task_params.data_augment,
      args.arch.encoder, args.solver.freeze_conv, args.solver.wt_decay,
      is_training)

  # Reshape so that features can be accumulated over time steps for faster
  # backprop.
  sh_before = x.get_shape().as_list()
  m.encoder_output = tf.reshape(
      x, shape=[task_params.batch_size, -1, n_views] + sh_before[1:])
  x = tf.reshape(m.encoder_output, shape=[-1] + sh_before[1:])

  # Add a layer to reduce dimensions for a fc layer.
  if args.arch.dim_reduce_neurons > 0:
    ks = 1; neurons = args.arch.dim_reduce_neurons;
    init_var = np.sqrt(2.0/(ks**2)/neurons)
    batch_norm_param = args.arch.batch_norm_param
    batch_norm_param['is_training'] = batch_norm_is_training_op
    m.conv_feat = slim.conv2d(
        x, neurons, kernel_size=ks, stride=1, normalizer_fn=slim.batch_norm,
        normalizer_params=batch_norm_param, padding='SAME', scope='dim_reduce',
        weights_regularizer=slim.l2_regularizer(args.solver.wt_decay),
        weights_initializer=tf.random_normal_initializer(stddev=init_var))
    reshape_conv_feat = slim.flatten(m.conv_feat)
    sh = reshape_conv_feat.get_shape().as_list()
    m.reshape_conv_feat = tf.reshape(reshape_conv_feat, 
                                     shape=[-1, sh[1]*n_views])

  # Restore these from a checkpoint.
  if args.solver.pretrained_path is not None:
    m.init_fn = slim.assign_from_checkpoint_fn(args.solver.pretrained_path,
                                               vars_)
  else:
    m.init_fn = None

  # Pass the goal location through a stack of fully connected layers to embed
  # it.
  with tf.variable_scope('embed_goal'):
    batch_norm_param = args.arch.batch_norm_param
    batch_norm_param['is_training'] = batch_norm_is_training_op
    m.embed_goal, _ = tf_utils.fc_network(
        rel_goal_loc_reshaped, neurons=args.arch.goal_embed_neurons,
        wt_decay=args.solver.wt_decay, name='goal_embed', offset=0,
        batch_norm_param=batch_norm_param, dropout_ratio=args.arch.fc_dropout,
        is_training=is_training)
  
  if args.arch.embed_goal_for_state:
    with tf.variable_scope('embed_goal_for_state'):
      batch_norm_param = args.arch.batch_norm_param
      batch_norm_param['is_training'] = batch_norm_is_training_op
      m.embed_goal_for_state, _ = tf_utils.fc_network(
          m.input_tensors['common']['rel_goal_loc_at_start'][:,0,:],
          neurons=args.arch.goal_embed_neurons, wt_decay=args.solver.wt_decay,
          name='goal_embed', offset=0, batch_norm_param=batch_norm_param,
          dropout_ratio=args.arch.fc_dropout, is_training=is_training)

  # Pass the image features through a stack of fully connected layers to embed
  # them.
  with tf.variable_scope('embed_img'):
    batch_norm_param = args.arch.batch_norm_param
    batch_norm_param['is_training'] = batch_norm_is_training_op
    m.embed_img, _ = tf_utils.fc_network(
        m.reshape_conv_feat, neurons=args.arch.img_embed_neurons,
        wt_decay=args.solver.wt_decay, name='img_embed', offset=0,
        batch_norm_param=batch_norm_param, dropout_ratio=args.arch.fc_dropout,
        is_training=is_training)

  # For lstm_ego and lstm_img, embed the ego motion, accumulate it in an LSTM,
  # combine it with the image features and accumulate those in an LSTM.
  # Finally, combine the image LSTM output with the goal to output an action.
  if args.arch.lstm_ego:
    ego_reshaped = preprocess_egomotion(m.input_tensors['step']['incremental_locs'], 
                                        m.input_tensors['step']['incremental_thetas'])
    with tf.variable_scope('embed_ego'):
      batch_norm_param = args.arch.batch_norm_param
      batch_norm_param['is_training'] = batch_norm_is_training_op
      m.embed_ego, _ = tf_utils.fc_network(
          ego_reshaped, neurons=args.arch.ego_embed_neurons,
          wt_decay=args.solver.wt_decay, name='ego_embed', offset=0,
          batch_norm_param=batch_norm_param, dropout_ratio=args.arch.fc_dropout,
          is_training=is_training)

    state_name, state_init_op, updated_state_op, out_op = lstm_setup(
        'lstm_ego', m.embed_ego, task_params.batch_size, is_single_step, 
        args.arch.lstm_ego_dim, args.arch.lstm_ego_out, num_steps*num_goals,
        m.input_tensors['step']['lstm_ego'])
    state_names += [state_name]
    init_state_ops += [state_init_op]
    updated_state_ops += [updated_state_op]

    # Combine the output with the vision features.
    m.img_ego_op = combine_setup('img_ego', args.arch.combine_type_ego,
                                 m.embed_img, out_op,
                                 args.arch.img_embed_neurons[-1],
                                 args.arch.lstm_ego_out)

    # LSTM on these vision features.
    state_name, state_init_op, updated_state_op, out_op = lstm_setup(
        'lstm_img', m.img_ego_op, task_params.batch_size, is_single_step, 
        args.arch.lstm_img_dim, args.arch.lstm_img_out, num_steps*num_goals,
        m.input_tensors['step']['lstm_img'])
    state_names += [state_name]
    init_state_ops += [state_init_op]
    updated_state_ops += [updated_state_op]

    m.img_for_goal = out_op
    num_img_for_goal_neurons = args.arch.lstm_img_out

  elif args.arch.lstm_img:
    # LSTM on just the image features.
    state_name, state_init_op, updated_state_op, out_op = lstm_setup(
        'lstm_img', m.embed_img, task_params.batch_size, is_single_step,
        args.arch.lstm_img_dim, args.arch.lstm_img_out, num_steps*num_goals,
        m.input_tensors['step']['lstm_img'])
    state_names += [state_name]
    init_state_ops += [state_init_op]
    updated_state_ops += [updated_state_op]
    m.img_for_goal = out_op
    num_img_for_goal_neurons = args.arch.lstm_img_out

  else:
    m.img_for_goal = m.embed_img
    num_img_for_goal_neurons = args.arch.img_embed_neurons[-1]


  if args.arch.use_visit_count:
    m.embed_visit_count = visit_count_fc(
        m.input_tensors['step']['visit_count'],
        m.input_tensors['step']['last_visit'], args.arch.goal_embed_neurons,
        args.solver.wt_decay, args.arch.fc_dropout, is_training=is_training)
    m.embed_goal = m.embed_goal + m.embed_visit_count
  
  m.combined_f = combine_setup('img_goal', args.arch.combine_type,
                               m.img_for_goal, m.embed_goal,
                               num_img_for_goal_neurons,
                               args.arch.goal_embed_neurons[-1])

  # LSTM on the combined representation.
  if args.arch.lstm_output:
    name = 'lstm_output'
    # A few fully connected layers here.
    with tf.variable_scope('action_pred'):
      batch_norm_param = args.arch.batch_norm_param
      batch_norm_param['is_training'] = batch_norm_is_training_op
      x, _ = tf_utils.fc_network(
          m.combined_f, neurons=args.arch.pred_neurons,
          wt_decay=args.solver.wt_decay, name='pred', offset=0,
          batch_norm_param=batch_norm_param, dropout_ratio=args.arch.fc_dropout)

    if args.arch.lstm_output_init_state_from_goal:
      # Use the goal embedding to initialize the LSTM state.
      # UGLY, KLUDGY HACK: if this is computing a single time step, no
      # backprop is involved, so we can use the state input from the feed
      # dict; otherwise we compute the state representation from the goal and
      # feed that in. Necessary for using the goal location to generate the
      # state representation.
      m.embed_goal_for_state = tf.expand_dims(m.embed_goal_for_state, dim=1)
      state_op = tf.cond(is_single_step, lambda: m.input_tensors['step'][name],
                         lambda: m.embed_goal_for_state)
      state_name, state_init_op, updated_state_op, out_op = lstm_setup(
          name, x, task_params.batch_size, is_single_step,
          args.arch.lstm_output_dim,
          num_actions_,
          num_steps*num_goals, state_op)
      init_state_ops += [m.embed_goal_for_state]
    else:
      state_op = m.input_tensors['step'][name]
      state_name, state_init_op, updated_state_op, out_op = lstm_setup(
          name, x, task_params.batch_size, is_single_step,
          args.arch.lstm_output_dim,
          num_actions_, num_steps*num_goals, state_op)
      init_state_ops += [state_init_op]

    state_names += [state_name]
    updated_state_ops += [updated_state_op]

    out_op = tf.reshape(out_op, shape=[-1, num_actions_])
    if num_actions_ > num_actions:
      m.action_logits_op = out_op[:,:num_actions]
      m.baseline_op = out_op[:,num_actions:]
    else:
      m.action_logits_op = out_op
      m.baseline_op = None
    m.action_prob_op = tf.nn.softmax(m.action_logits_op)

  else:
    # A few fully connected layers here.
    with tf.variable_scope('action_pred'):
      batch_norm_param = args.arch.batch_norm_param
      batch_norm_param['is_training'] = batch_norm_is_training_op
      out_op, _ = tf_utils.fc_network(
          m.combined_f, neurons=args.arch.pred_neurons,
          wt_decay=args.solver.wt_decay, name='pred', offset=0,
          num_pred=num_actions_,
          batch_norm_param=batch_norm_param,
          dropout_ratio=args.arch.fc_dropout, is_training=is_training)
      if num_actions_ > num_actions:
        m.action_logits_op = out_op[:,:num_actions]
        m.baseline_op = out_op[:,num_actions:]
      else:
        m.action_logits_op = out_op 
        m.baseline_op = None
      m.action_prob_op = tf.nn.softmax(m.action_logits_op)

  m.train_ops = {}
  m.train_ops['step'] = m.action_prob_op
  m.train_ops['common'] = [m.input_tensors['common']['orig_maps'],
                           m.input_tensors['common']['goal_loc'],
                           m.input_tensors['common']['rel_goal_loc_at_start']]
  m.train_ops['state_names'] = state_names
  m.train_ops['init_state'] = init_state_ops
  m.train_ops['updated_state'] = updated_state_ops
  m.train_ops['batch_norm_is_training_op'] = batch_norm_is_training_op

  # Flat list of ops which cache the step data.
  m.train_ops['step_data_cache'] = [tf.no_op()]

  if args.solver.freeze_conv:
    m.train_ops['step_data_cache'] = [m.encoder_output]
  else:
    m.train_ops['step_data_cache'] = []

  ewma_decay = 0.99 if is_training else 0.0
  weight = tf.ones_like(m.input_tensors['train']['action'], dtype=tf.float32,
                        name='weight')

  m.reg_loss_op, m.data_loss_op, m.total_loss_op, m.acc_ops = \
    compute_losses_multi_or(
        m.action_logits_op, m.input_tensors['train']['action'],
        weights=weight, num_actions=num_actions,
        data_loss_wt=args.solver.data_loss_wt,
        reg_loss_wt=args.solver.reg_loss_wt, ewma_decay=ewma_decay)


  if args.solver.freeze_conv:
    vars_to_optimize = list(set(tf.trainable_variables()) - set(vars_))
  else:
    vars_to_optimize = None

  m.lr_op, m.global_step_op, m.train_op, m.should_stop_op, m.optimizer, \
  m.sync_optimizer = tf_utils.setup_training(
      m.total_loss_op, 
      args.solver.initial_learning_rate, 
      args.solver.steps_per_decay,
      args.solver.learning_rate_decay, 
      args.solver.momentum,
      args.solver.max_steps, 
      args.solver.sync, 
      args.solver.adjust_lr_sync,
      args.solver.num_workers, 
      args.solver.task,
      vars_to_optimize=vars_to_optimize,
      clip_gradient_norm=args.solver.clip_gradient_norm,
      typ=args.solver.typ, momentum2=args.solver.momentum2,
      adam_eps=args.solver.adam_eps)
  
  
  if args.arch.sample_gt_prob_type == 'inverse_sigmoid_decay':
    m.sample_gt_prob_op = tf_utils.inverse_sigmoid_decay(args.arch.isd_k,
                                                         m.global_step_op)
  elif args.arch.sample_gt_prob_type == 'zero':
    m.sample_gt_prob_op = tf.constant(-1.0, dtype=tf.float32)
  elif args.arch.sample_gt_prob_type.split('_')[0] == 'step':
    step = int(args.arch.sample_gt_prob_type.split('_')[1])
    m.sample_gt_prob_op = tf_utils.step_gt_prob(
        step, m.input_tensors['step']['step_number'][0,0,0])
  
  m.sample_action_type = args.arch.action_sample_type
  m.sample_action_combine_type = args.arch.action_sample_combine_type
  _add_summaries(m, summary_mode, args.summary.arop_full_summary_iters)
  
  m.init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
  m.saver_op = tf.train.Saver(keep_checkpoint_every_n_hours=4,
                              write_version=tf.train.SaverDef.V2)
  
  return m
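
A toy, self-contained sketch of the is_single_step check used in the examples above: the dynamic time dimension of the image tensor is inspected and tf.cond picks between a single-step and a fully unrolled computation. Shapes and names here are made up for illustration.

import numpy as np
import tensorflow as tf

num_steps, num_goals = 8, 2
imgs = tf.placeholder(tf.float32, shape=[None, None, 4, 32, 32, 3], name='imgs')

# True when the graph is run one step at a time (e.g. during on-policy rollouts).
is_single_step = tf.equal(tf.unstack(tf.shape(imgs), num=6)[1], 1)
unroll_length = tf.cond(is_single_step,
                        lambda: tf.constant(1),
                        lambda: tf.constant(num_steps * num_goals))

with tf.Session() as sess:
  one_step_batch = np.zeros((1, 1, 4, 32, 32, 3), dtype=np.float32)
  print(sess.run(unroll_length, feed_dict={imgs: one_step_batch}))  # prints 1
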
Example 4
def setup_to_run(m, args, is_training, batch_norm_is_training, summary_mode):
  assert(args.arch.multi_scale), 'removed support for old single scale code.'
  # Set up the model.
  tf.set_random_seed(args.solver.seed)
  task_params = args.navtask.task_params

  batch_norm_is_training_op = \
      tf.placeholder_with_default(batch_norm_is_training, shape=[],
                                  name='batch_norm_is_training_op') 

  # Set up the inputs.
  m.input_tensors = {}
  m.train_ops = {}
  m.input_tensors['common'], m.input_tensors['step'], m.input_tensors['train'] = \
      _inputs(task_params)

  m.init_fn = None

  if task_params.input_type == 'vision':
    m.vision_ops = get_map_from_images(
        m.input_tensors['step']['imgs'], args.mapper_arch,
        task_params, args.solver.freeze_conv,
        args.solver.wt_decay, is_training, batch_norm_is_training_op,
        num_maps=len(task_params.map_crop_sizes))

    # Load variables from snapshot if needed.
    if args.solver.pretrained_path is not None:
      m.init_fn = slim.assign_from_checkpoint_fn(args.solver.pretrained_path,
                                                 m.vision_ops.vars_to_restore)

    # Set up caching of vision features if needed.
    if args.solver.freeze_conv:
      m.train_ops['step_data_cache'] = [m.vision_ops.encoder_output]
    else:
      m.train_ops['step_data_cache'] = []

    # Set up blobs that are needed for the computation in the rest of the graph.
    m.ego_map_ops = m.vision_ops.fss_logits
    m.coverage_ops = m.vision_ops.confs_probs
    
    # Zero pad these to make them the same size as what the planner expects.
    for i in range(len(m.ego_map_ops)):
      if args.mapper_arch.pad_map_with_zeros_each[i] > 0:
        paddings = np.zeros((5,2), dtype=np.int32)
        paddings[2:4,:] = args.mapper_arch.pad_map_with_zeros_each[i]
        paddings_op = tf.constant(paddings, dtype=tf.int32)
        m.ego_map_ops[i] = tf.pad(m.ego_map_ops[i], paddings=paddings_op)
        m.coverage_ops[i] = tf.pad(m.coverage_ops[i], paddings=paddings_op)
  
  elif task_params.input_type == 'analytical_counts':
    m.ego_map_ops = []; m.coverage_ops = []
    for i in range(len(task_params.map_crop_sizes)):
      ego_map_op = m.input_tensors['step']['analytical_counts_{:d}'.format(i)]
      coverage_op = tf.cast(tf.greater_equal(
          tf.reduce_max(ego_map_op, reduction_indices=[4],
                        keep_dims=True), 1), tf.float32)
      coverage_op = tf.ones_like(ego_map_op) * coverage_op
      m.ego_map_ops.append(ego_map_op)
      m.coverage_ops.append(coverage_op)
      m.train_ops['step_data_cache'] = []
  
  num_steps = task_params.num_steps
  num_goals = task_params.num_goals

  map_crop_size_ops = []
  for map_crop_size in task_params.map_crop_sizes:
    map_crop_size_ops.append(tf.constant(map_crop_size, dtype=tf.int32, shape=(2,)))

  with tf.name_scope('check_size'):
    is_single_step = tf.equal(tf.unstack(tf.shape(m.ego_map_ops[0]), num=5)[1], 1)

  fr_ops = []; value_ops = [];
  fr_intermediate_ops = []; value_intermediate_ops = [];
  crop_value_ops = [];
  resize_crop_value_ops = [];
  confs = []; occupancys = [];

  previous_value_op = None
  updated_state = []; state_names = [];

  for i in range(len(task_params.map_crop_sizes)):
    map_crop_size = task_params.map_crop_sizes[i]
    with tf.variable_scope('scale_{:d}'.format(i)): 
      # Accumulate the map.
      fn = lambda ns: running_combine(
             m.ego_map_ops[i],
             m.coverage_ops[i],
             m.input_tensors['step']['incremental_locs'] * task_params.map_scales[i],
             m.input_tensors['step']['incremental_thetas'],
             m.input_tensors['step']['running_sum_num_{:d}'.format(i)],
             m.input_tensors['step']['running_sum_denom_{:d}'.format(i)],
             m.input_tensors['step']['running_max_denom_{:d}'.format(i)],
             map_crop_size, ns)

      running_sum_num, running_sum_denom, running_max_denom = \
          tf.cond(is_single_step, lambda: fn(1), lambda: fn(num_steps*num_goals))
      updated_state += [running_sum_num, running_sum_denom, running_max_denom]
      state_names += ['running_sum_num_{:d}'.format(i),
                      'running_sum_denom_{:d}'.format(i),
                      'running_max_denom_{:d}'.format(i)]

      # Compute occupancy and confidence from the accumulated map.
      occupancy = running_sum_num / tf.maximum(running_sum_denom, 0.001)
      conf = running_max_denom
      # print occupancy.get_shape().as_list()

      # Concatenate the occupancy, the confidence, and the goal.
      with tf.name_scope('concat'):
        sh = [-1, map_crop_size, map_crop_size, task_params.map_channels]
        occupancy = tf.reshape(occupancy, shape=sh)
        conf = tf.reshape(conf, shape=sh)

        sh = [-1, map_crop_size, map_crop_size, task_params.goal_channels]
        goal = tf.reshape(m.input_tensors['step']['ego_goal_imgs_{:d}'.format(i)], shape=sh)
        to_concat = [occupancy, conf, goal]

        if previous_value_op is not None:
          to_concat.append(previous_value_op)

        x = tf.concat(to_concat, 3)

      # Pass the map, the previous value map, and the goal through a few
      # convolutional layers to get fR.
      fr_op, fr_intermediate_op = fr_v2(
         x, output_neurons=args.arch.fr_neurons,
         inside_neurons=args.arch.fr_inside_neurons,
         is_training=batch_norm_is_training_op, name='fr',
         wt_decay=args.solver.wt_decay, stride=args.arch.fr_stride)

      # Do Value Iteration on the fR
      if args.arch.vin_num_iters > 0:
        value_op, value_intermediate_op = value_iteration_network(
            fr_op, num_iters=args.arch.vin_num_iters,
            val_neurons=args.arch.vin_val_neurons,
            action_neurons=args.arch.vin_action_neurons,
            kernel_size=args.arch.vin_ks, share_wts=args.arch.vin_share_wts,
            name='vin', wt_decay=args.solver.wt_decay)
      else:
        value_op = fr_op
        value_intermediate_op = []

      # Crop the value map; it will be upsampled to feed the next scale.
      remove = args.arch.crop_remove_each
      if remove > 0:
        crop_value_op = value_op[:, remove:-remove, remove:-remove,:]
      else:
        crop_value_op = value_op
      crop_value_op = tf.reshape(crop_value_op, shape=[-1, args.arch.value_crop_size,
                                                       args.arch.value_crop_size,
                                                       args.arch.vin_val_neurons])
      if i < len(task_params.map_crop_sizes)-1:
        # Resize it to the shape of the next scale.
        previous_value_op = tf.image.resize_bilinear(crop_value_op,
                                                     map_crop_size_ops[i+1],
                                                     align_corners=True)
        resize_crop_value_ops.append(previous_value_op)
      
      occupancys.append(occupancy)
      confs.append(conf)
      value_ops.append(value_op)
      crop_value_ops.append(crop_value_op)
      fr_ops.append(fr_op)
      fr_intermediate_ops.append(fr_intermediate_op)
  
  m.value_ops = value_ops
  m.value_intermediate_ops = value_intermediate_ops
  m.fr_ops = fr_ops
  m.fr_intermediate_ops = fr_intermediate_ops
  m.final_value_op = crop_value_op
  m.crop_value_ops = crop_value_ops
  m.resize_crop_value_ops = resize_crop_value_ops
  m.confs = confs
  m.occupancys = occupancys

  sh = [-1, args.arch.vin_val_neurons*((args.arch.value_crop_size)**2)]
  m.value_features_op = tf.reshape(m.final_value_op, sh, name='reshape_value_op')
  
  # Determine what action to take.
  with tf.variable_scope('action_pred'):
    batch_norm_param = args.arch.pred_batch_norm_param
    if batch_norm_param is not None:
      batch_norm_param['is_training'] = batch_norm_is_training_op
    m.action_logits_op, _ = tf_utils.fc_network(
        m.value_features_op, neurons=args.arch.pred_neurons,
        wt_decay=args.solver.wt_decay, name='pred', offset=0,
        num_pred=task_params.num_actions,
        batch_norm_param=batch_norm_param) 
    m.action_prob_op = tf.nn.softmax(m.action_logits_op)

  init_state = tf.constant(0., dtype=tf.float32, shape=[
      task_params.batch_size, 1, map_crop_size, map_crop_size,
      task_params.map_channels])

  m.train_ops['state_names'] = state_names
  m.train_ops['updated_state'] = updated_state
  m.train_ops['init_state'] = [init_state for _ in updated_state]

  m.train_ops['step'] = m.action_prob_op
  m.train_ops['common'] = [m.input_tensors['common']['orig_maps'],
                           m.input_tensors['common']['goal_loc']]
  m.train_ops['batch_norm_is_training_op'] = batch_norm_is_training_op
  m.loss_ops = []; m.loss_ops_names = [];

  if args.arch.readout_maps:
    with tf.name_scope('readout_maps'):
      all_occupancys = tf.concat(m.occupancys + m.confs, 3)
      readout_maps, probs = readout_general(
          all_occupancys, num_neurons=args.arch.rom_arch.num_neurons,
          strides=args.arch.rom_arch.strides, 
          layers_per_block=args.arch.rom_arch.layers_per_block, 
          kernel_size=args.arch.rom_arch.kernel_size,
          batch_norm_is_training_op=batch_norm_is_training_op,
          wt_decay=args.solver.wt_decay)

      gt_ego_maps = [m.input_tensors['step']['readout_maps_{:d}'.format(i)]
                     for i in range(len(task_params.readout_maps_crop_sizes))]
      m.readout_maps_gt = tf.concat(gt_ego_maps, 4)
      gt_shape = tf.shape(m.readout_maps_gt)
      m.readout_maps_logits = tf.reshape(readout_maps, gt_shape)
      m.readout_maps_probs = tf.reshape(probs, gt_shape)

      # Add a loss op
      m.readout_maps_loss_op = tf.losses.sigmoid_cross_entropy(
          tf.reshape(m.readout_maps_gt, [-1, len(task_params.readout_maps_crop_sizes)]), 
          tf.reshape(readout_maps, [-1, len(task_params.readout_maps_crop_sizes)]),
          scope='loss')
      m.readout_maps_loss_op = 10.*m.readout_maps_loss_op

  ewma_decay = 0.99 if is_training else 0.0
  weight = tf.ones_like(m.input_tensors['train']['action'], dtype=tf.float32,
                        name='weight')
  m.reg_loss_op, m.data_loss_op, m.total_loss_op, m.acc_ops = \
    compute_losses_multi_or(m.action_logits_op,
                            m.input_tensors['train']['action'], weights=weight,
                            num_actions=task_params.num_actions,
                            data_loss_wt=args.solver.data_loss_wt,
                            reg_loss_wt=args.solver.reg_loss_wt,
                            ewma_decay=ewma_decay)
  
  if args.arch.readout_maps:
    m.total_loss_op = m.total_loss_op + m.readout_maps_loss_op
    m.loss_ops += [m.readout_maps_loss_op]
    m.loss_ops_names += ['readout_maps_loss']

  m.loss_ops += [m.reg_loss_op, m.data_loss_op, m.total_loss_op]
  m.loss_ops_names += ['reg_loss', 'data_loss', 'total_loss']

  if args.solver.freeze_conv:
    vars_to_optimize = list(set(tf.trainable_variables()) -
                            set(m.vision_ops.vars_to_restore))
  else:
    vars_to_optimize = None

  m.lr_op, m.global_step_op, m.train_op, m.should_stop_op, m.optimizer, \
  m.sync_optimizer = tf_utils.setup_training(
      m.total_loss_op, 
      args.solver.initial_learning_rate, 
      args.solver.steps_per_decay,
      args.solver.learning_rate_decay, 
      args.solver.momentum,
      args.solver.max_steps, 
      args.solver.sync, 
      args.solver.adjust_lr_sync,
      args.solver.num_workers, 
      args.solver.task,
      vars_to_optimize=vars_to_optimize,
      clip_gradient_norm=args.solver.clip_gradient_norm,
      typ=args.solver.typ, momentum2=args.solver.momentum2,
      adam_eps=args.solver.adam_eps)

  if args.arch.sample_gt_prob_type == 'inverse_sigmoid_decay':
    m.sample_gt_prob_op = tf_utils.inverse_sigmoid_decay(args.arch.isd_k,
                                                         m.global_step_op)
  elif args.arch.sample_gt_prob_type == 'zero':
    m.sample_gt_prob_op = tf.constant(-1.0, dtype=tf.float32)

  elif args.arch.sample_gt_prob_type.split('_')[0] == 'step':
    step = int(args.arch.sample_gt_prob_type.split('_')[1])
    m.sample_gt_prob_op = tf_utils.step_gt_prob(
        step, m.input_tensors['step']['step_number'][0,0,0])

  m.sample_action_type = args.arch.action_sample_type
  m.sample_action_combine_type = args.arch.action_sample_combine_type

  m.summary_ops = {
      summary_mode: _add_summaries(m, args, summary_mode,
                                   args.summary.arop_full_summary_iters)}

  m.init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
  m.saver_op = tf.train.Saver(keep_checkpoint_every_n_hours=4,
                              write_version=tf.train.SaverDef.V2)
  return m
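
A toy, self-contained sketch of the zero padding applied to the mapper outputs above: only the two spatial dimensions of a 5-D [batch, time, height, width, channels] tensor are padded, mirroring the paddings[2:4, :] pattern. Sizes are illustrative.

import numpy as np
import tensorflow as tf

pad_each = 4                                  # e.g. one entry of pad_map_with_zeros_each
ego_map = tf.zeros([2, 1, 24, 24, 8], dtype=tf.float32)

paddings = np.zeros((5, 2), dtype=np.int32)
paddings[2:4, :] = pad_each                   # pad height and width only
padded = tf.pad(ego_map, paddings=tf.constant(paddings, dtype=tf.int32))

with tf.Session() as sess:
  print(sess.run(tf.shape(padded)))           # [ 2  1 32 32  8]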