Example #1
def convert_to_graph_tool(G):
    timer = utils.Timer()
    timer.tic()
    gtG = gt.Graph(directed=G.is_directed())
    gtG.ep['action'] = gtG.new_edge_property('int')

    nodes_list = list(G.nodes())
    nodes_array = np.array(nodes_list)

    nodes_id = np.zeros((nodes_array.shape[0], ), dtype=np.int64)

    for i in range(nodes_array.shape[0]):
        v = gtG.add_vertex()
        nodes_id[i] = int(v)

    d = dict(zip(nodes_list, nodes_id))

    for src, dst, data in G.edges(data=True):
        e = gtG.add_edge(d[src], d[dst])
        gtG.ep['action'][e] = data['action']
    nodes_to_id = d
    timer.toc(average=True,
              log_at=1,
              log_str='src.graph_utils.convert_to_graph_tool')
    return gtG, nodes_array, nodes_to_id
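
Every example in this listing times its work with a project-local utils.Timer, whose implementation is not shown. Purely to keep the snippets self-contained for experimentation, a minimal hypothetical stand-in consistent with the calls used here (tic()/toc(average=..., log_at=..., log_str=...) and the Timer(message) ... show() pattern in the later examples) could look like this:

import logging
import time


class Timer(object):
    """Hypothetical minimal stand-in for the project-local utils.Timer.

    It only mirrors the calls seen in these examples: tic()/toc(...) around a
    block of work, and Timer(message) ... show() around a whole run.
    """

    def __init__(self, msg=None):
        if msg is not None:
            logging.info(msg)
        self.start = time.time()
        self.calls = 0
        self.total_time = 0.

    def tic(self):
        # Mark the start of a timed block.
        self.start = time.time()

    def toc(self, average=False, log_at=1, log_str='', **kwargs):
        # Accumulate elapsed time and log every log_at-th call.
        self.calls += 1
        self.total_time += time.time() - self.start
        elapsed = self.total_time / self.calls if average else self.total_time
        if self.calls % log_at == 0:
            logging.info('%s: %f seconds.', log_str, elapsed)
        return elapsed

    def show(self):
        # Report wall-clock time since construction.
        logging.info('Done in %f seconds.', time.time() - self.start)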
Example #2
def compute_traversibility(map,
                           robot_base,
                           robot_height,
                           robot_radius,
                           valid_min,
                           valid_max,
                           num_point_threshold,
                           shapess,
                           sc=100.,
                           n_samples_per_face=200):
    """Returns a bit map with pixels that are traversible or not as long as the
  robot center is inside this volume we are good colisions can be detected by
  doing a line search on things, or walking from current location to final
  location in the bitmap, or doing bwlabel on the traversibility map."""

    tt = utils.Timer()
    tt.tic()
    num_obstcale_points = np.zeros((map.size[1], map.size[0]))
    num_points = np.zeros((map.size[1], map.size[0]))

    for i, shapes in enumerate(shapess):
        for j in range(shapes.get_number_of_meshes()):
            p, face_areas, face_idx = shapes.sample_points_on_face_of_shape(
                j, n_samples_per_face, sc)
            wt = face_areas[face_idx] / n_samples_per_face

            ind = np.all(np.concatenate(
                (p[:, [2]] > robot_base,
                 p[:, [2]] < robot_base + robot_height),
                axis=1),
                         axis=1)
            num_obstcale_points += _project_to_map(map, p[ind, :], wt[ind])

            ind = np.all(np.concatenate(
                (p[:, [2]] > valid_min, p[:, [2]] < valid_max), axis=1),
                         axis=1)
            num_points += _project_to_map(map, p[ind, :], wt[ind])

    selem = skimage.morphology.disk(robot_radius / map.resolution)
    obstacle_free = np.logical_not(
        skimage.morphology.binary_dilation(
            _fill_holes(num_obstcale_points > num_point_threshold, 20), selem))
    valid_space = _fill_holes(num_points > num_point_threshold, 20)
    traversible = np.all(np.concatenate(
        (obstacle_free[..., np.newaxis], valid_space[..., np.newaxis]),
        axis=2),
                         axis=2)
    # plt.imshow(np.concatenate((obstacle_free, valid_space, traversible), axis=1))
    # plt.show()

    map_out = copy.deepcopy(map)
    map_out.num_obstcale_points = num_obstcale_points
    map_out.num_points = num_points
    map_out.traversible = traversible
    map_out.obstacle_free = obstacle_free
    map_out.valid_space = valid_space
    tt.toc(log_at=1, log_str='src.map_utils.compute_traversibility: ')
    return map_out
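
The heart of compute_traversibility is standard binary morphology: mark the cells with enough obstacle samples, inflate them by the robot radius with a disk structuring element, and intersect the complement with the cells that have enough valid support points. A self-contained sketch of just that step on a toy grid (the helper name and array shapes below are illustrative, not from the source):

import numpy as np
import skimage.morphology


def traversible_mask(obstacle_count, point_count, threshold, radius_px):
    """Combine obstacle and support counts into a traversibility mask."""
    # Cells with enough obstacle samples, inflated by the robot radius.
    obstacles = obstacle_count > threshold
    inflated = skimage.morphology.binary_dilation(
        obstacles, skimage.morphology.disk(radius_px))
    # Cells with enough floor/support samples at a valid height.
    valid = point_count > threshold
    # Traversible = supported and not within one robot radius of an obstacle.
    return np.logical_and(valid, np.logical_not(inflated))


# Toy map: one obstacle cell blocks a disk of radius 2 around it.
obstacle_count = np.zeros((9, 9))
obstacle_count[4, 4] = 10
point_count = np.full((9, 9), 10.)
print(traversible_mask(obstacle_count, point_count,
                       threshold=1, radius_px=2).astype(int))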
Example #3
def make_prediction():
    '''
        Run price prediction for each stock.
    '''
    otimer = utils.Timer("Starting stock price prediction:")
    dh = db.Data_handler()
    stocks_list = dh.get_stocks_list()
    for one_stock in stocks_list:
        print("\nPredicting prices for {}, {}".format(one_stock[0],
                                                      one_stock[2]))
        make_stock_prediction(mfd_id=one_stock[1], db_connection=dh)
    otimer.show()
Example #4
def update_data(self, ppredict=False):
    otimer = utils.Timer("Starting to load all stock prices:")
    dh = db.Data_handler()
    stocks_list = dh.get_stocks_list()
    for one_stock in stocks_list:
        print("\nLoading prices for {}, {}".format(one_stock[0],
                                                   one_stock[2]))
        self.load_stock_prises_from_inet(mfd_id=one_stock[1])
        if ppredict:
            prediction.make_stock_prediction(mfd_id=one_stock[1],
                                             db_connection=dh)
    otimer.show()
Example #5
def generate_graph(valid_fn_vec=None,
                   sc=1.,
                   n_ori=6,
                   starting_location=(0, 0, 0),
                   vis=False,
                   directed=True):
    timer = utils.Timer()
    timer.tic()
    if directed: G = nx.DiGraph(directed=True)
    else: G = nx.Graph()
    G.add_node(starting_location)
    new_nodes = list(G.nodes())
    while len(new_nodes) != 0:
        nodes_to_add = []
        nodes_to_validate = []
        for n in new_nodes:
            if directed:
                na, nv = _get_next_nodes(n, sc, n_ori)
            else:
                na, nv = _get_next_nodes_undirected(n, sc, n_ori)
            nodes_to_add = nodes_to_add + na
            if valid_fn_vec is not None:
                nodes_to_validate = nodes_to_validate + nv
            else:
                nodes_to_add = nodes_to_add + nv

        # Validate nodes.
        vs = [_[1] for _ in nodes_to_validate]
        valids = valid_fn_vec(vs) if valid_fn_vec is not None else []

        for nva, valid in zip(nodes_to_validate, valids):
            if valid:
                nodes_to_add.append(nva)

        new_nodes = []
        for n, v, a in nodes_to_add:
            if not G.has_node(v):
                new_nodes.append(v)
            G.add_edge(n, v, action=a)

    timer.toc(average=True, log_at=1, log_str='src.graph_utils.generate_graph')
    #return (G)
    return G
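
generate_graph grows the graph breadth-first: expand the current frontier with _get_next_nodes (not shown in this listing), keep the proposals the validity callback accepts, and stop once a pass adds nothing new. A stripped-down sketch of the same loop on a 4-connected grid, with a hypothetical neighbour function standing in for _get_next_nodes:

import networkx as nx


def grid_neighbours(node):
    """Hypothetical stand-in for _get_next_nodes: (src, dst, action) triples
    for a 4-connected grid."""
    x, y = node
    moves = [(1, 0), (-1, 0), (0, 1), (0, -1)]
    return [((x, y), (x + dx, y + dy), a) for a, (dx, dy) in enumerate(moves)]


def generate_grid_graph(valid_fn, start=(0, 0)):
    """Grow a directed graph breadth-first from start, keeping valid nodes."""
    G = nx.DiGraph()
    G.add_node(start)
    frontier = [start]
    while frontier:
        proposals = [t for n in frontier for t in grid_neighbours(n)]
        frontier = []
        for src, dst, action in proposals:
            if not valid_fn(dst):
                continue
            if not G.has_node(dst):
                frontier.append(dst)
            G.add_edge(src, dst, action=action)
    return G


# Keep only nodes inside a 5x5 box.
G = generate_grid_graph(lambda p: 0 <= p[0] < 5 and 0 <= p[1] < 5)
print(G.number_of_nodes(), G.number_of_edges())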
Example #6
def train_step_fn(sess,
                  train_op,
                  global_step,
                  train_step_kwargs,
                  mode='train'):
    Z = train_step_kwargs['Z']
    agent = train_step_kwargs['agent']
    rng_data = train_step_kwargs['rng_data']
    rng_action = train_step_kwargs['rng_action']
    writer = train_step_kwargs['writer']
    iters = train_step_kwargs['iters']
    num_steps = train_step_kwargs['num_steps']
    logdir = train_step_kwargs['logdir']
    dagger_sample_bn_false = train_step_kwargs['dagger_sample_bn_false']
    train_display_interval = train_step_kwargs['train_display_interval']
    if 'outputs' not in Z.train_ops:
        Z.train_ops['outputs'] = []

    s_ops = Z.summary_ops[mode]
    val_additional_ops = []

    # Print all variables here.
    # if True:
    #     v = tf.get_collection(tf.GraphKeys.VARIABLES)
    #     v_op = [_.value() for _ in v]
    #     v_op_value = sess.run(v_op)
    #
    #     filter = lambda x, y: 'Adam' in x.name
    #     # filter = lambda x, y: np.is_any_nan(y)
    #     ind = [i for i, (_, __) in enumerate(zip(v, v_op_value)) if filter(_, __)]
    #     v = [v[i] for i in ind]
    #     v_op_value = [v_op_value[i] for i in ind]
    #
    #     for i in range(len(v)):
    #         logging.info('XXXX: variable: %30s, is_any_nan: %5s, norm: %f.',
    #                      v[i].name, np.any(np.isnan(v_op_value[i])),
    #                      np.linalg.norm(v_op_value[i]))

    tt = utils.Timer()

    total_loss = should_stop = None

    # Test outputs
    total_cases = np.zeros((Z.batch_size))
    succ = np.zeros((Z.batch_size))

    testcheck = {}
    testcheck['iter'] = []
    testcheck['step'] = []
    testcheck['localsmap'] = []
    testcheck['target'] = []
    testcheck['loc'] = []
    testcheck['fr'] = []
    testcheck['value'] = []
    testcheck['excuted_actions'] = []
    testcheck['reachgoal'] = []
    cnt = 0
    notwrite = True

    for i in range(iters):
        tt.tic()

        # Initialize the agent.
        init_env_state = agent.reset(
            rng_data[0],
            multi_target=['television', 'stand', 'desk', 'toilet'])
        # init_env_state = agent.reset(rng_data[0], single_target='sofa')
        # Given
        init_env_state = agent.startatPos([[-15, 15, 0]] * Z.batch_size)
        print(agent.epi.targets)

        # Get and process the common data.
        input = agent.get_common_data()
        feed_dict = prepare_feed_dict(Z.input_tensors['common'], input)
        if dagger_sample_bn_false:
            feed_dict[Z.train_ops['batch_norm_is_training_op']] = False
        common_data = sess.run(Z.train_ops['common'], feed_dict=feed_dict)

        states = []
        state_features = []
        state_target_actions = []
        executed_actions = []
        reachgoal = []
        rewards = []
        action_sample_wts = []
        states.append(init_env_state)

        num_steps = 80
        for j in range(num_steps):
            f = agent.get_step_data()
            f['stop_gt_act_step_number'] = np.ones(
                (1, 1, 1), dtype=np.int32) * j
            state_features.append(f)

            feed_dict = prepare_feed_dict(
                Z.input_tensors['step'],
                state_features[-1])  # Feed in latest state features
            optimal_action = agent.get_batch_gt_actions()
            for x, v in zip(Z.train_ops['common'], common_data):
                feed_dict[x] = v
            if dagger_sample_bn_false:
                feed_dict[Z.train_ops['batch_norm_is_training_op']] = False
            outs = sess.run([
                Z.train_ops['step'], Z.sample_gt_prob_op, Z.fr_ops, Z.value_ops
            ],
                            feed_dict=feed_dict)
            action_probs = outs[0]
            sample_gt_prob = outs[1]

            dic_optimal_actions = vars(Foo(action=optimal_action))
            state_target_actions.append(dic_optimal_actions)

            if j < num_steps - 1:
                # Sample from action_probs and optimal action.
                action, action_sample_wt = sample_action(
                    rng_action, action_probs, optimal_action, sample_gt_prob,
                    Z.sample_action_type, Z.sample_action_combine_type)
                # TODO get reward feedback
                # next_state, reward = agent.step(action)
                # locs = f['loc_on_map']

                # if mode == 'test' and cnt < 30:
                #
                #     for btch in range(Z.batch_size):
                #         target_loc = common_data[1][btch]; crnt_loc = f['loc_on_map'][btch]
                #         xt = target_loc[0][0]; yt = target_loc[0][1]
                #         xc = crnt_loc[0][0]; yc = crnt_loc[0][1]; orienc = crnt_loc[0][2]
                #         # if abs(xt - xc) + abs(yt - yc) < 10 and orienc == 0:
                #         if xc in range(-17, -12) and yc in range(10, 17) and \
                #                 agent.epi.targets[btch] in {'television', 'stand', 'toilet'}:  # and orienc == 0:
                #             testcheck['iter'] += [[i]]
                #             testcheck['step'] += [[j]]
                #             testcheck['localsmap'] += [[f['locsmap_{:d}'.format(_)][btch][0] for _ in range(len(agent.navi.map_orig_sizes))]]
                #             testcheck['target'] += [agent.epi.targets[btch]]
                #             testcheck['loc'] += [[xc, yc, orienc]]
                #             testcheck['fr'] += [[outs[2][sc][btch] for sc in range(3)]]
                #             testcheck['value'] += [[outs[3][sc][btch] for sc in range(3)]]
                #             testcheck['excuted_actions'] += [action[btch]]
                #             # testcheck['reachgoal'] += reachgoal
                #
                #             cnt += 1
                #
                if mode == 'test' and 30 >= cnt >= 20 and notwrite:
                    pickle.dump(testcheck,
                                open('%s/fr_value_124.pkl' % (logdir), 'wb'))
                    notwrite = False
                if mode == 'test' and notwrite:

                    target_loc = common_data[1]
                    crnt_loc = f['loc_on_map']
                    testcheck['step'] += [[j]]
                    testcheck['localsmap'] += [[
                        f['locsmap_{:d}'.format(_)]
                        for _ in range(len(agent.navi.map_orig_sizes))
                    ]]
                    testcheck['target'] = [agent.epi.targets]
                    testcheck['loc'] += [crnt_loc]
                    testcheck['fr'] += [outs[2]]
                    testcheck['value'] += [outs[3]]
                    testcheck['excuted_actions'] += [action]
                    testcheck['targetloc'] = agent.epi.target_locs

                # Step a batch of actions
                next_state = agent.step(action)
                reachgoal.append(agent.reachgoal)
                executed_actions.append(action)
                states.append(next_state)
                # rewards.append(reward)
                action_sample_wts.append(action_sample_wt)
                # net_state = dict(zip(Z.train_ops['state_names'], net_state))
                # net_state_to_input.append(net_state)
        if mode == 'test' and notwrite:
            pickle.dump(testcheck, open('%s/fr_value_1234.pkl' % (logdir),
                                        'wb'))
            notwrite = False

        # Concatenate things together for training.

        # rewards = np.array(rewards).T
        # action_sample_wts = np.array(action_sample_wts).T
        # executed_actions = np.array(executed_actions).T
        iter_final_state = state_features[-1]['if_reach_goal']
        assert iter_final_state.shape[0] == Z.batch_size
        succ += np.logical_xor(np.ones(Z.batch_size), iter_final_state[:, 0,
                                                                       0])

        total_cases += 1
        print('success rate in the %dth iteration.' % i,
              np.divide(succ, total_cases))

        all_state_targets = concat_state_x(state_target_actions, ['action'])
        all_state_features = concat_state_x(
            state_features,
            agent.get_step_data_names() + ['stop_gt_act_step_number'])
        # all_state_net = concat_state_x(net_state_to_input,
        # Z.train_ops['state_names'])
        # all_step_data_cache = concat_state_x(step_data_cache,
        #                                      Z.train_ops['step_data_cache'])

        dict_train = dict(input)
        dict_train.update(all_state_features)
        dict_train.update(all_state_targets)

        # dict_train.update({     # 'rewards': rewards,
        #                    'action_sample_wts': action_sample_wts,
        #                    'executed_actions': executed_actions})
        feed_dict = prepare_feed_dict(Z.input_tensors['train'], dict_train)

        if mode == 'train':
            n_step = sess.run(global_step)
            print(n_step)
            if np.mod(n_step, train_display_interval) == 0:
                total_loss, np_global_step, summary, print_summary = sess.run(
                    [
                        train_op, global_step, s_ops.summary_ops,
                        s_ops.print_summary_ops
                    ],
                    feed_dict=feed_dict)
                logging.error("")
            else:
                total_loss, np_global_step, summary = sess.run(
                    [train_op, global_step, s_ops.summary_ops],
                    feed_dict=feed_dict)

            if writer is not None and summary is not None:
                writer.add_summary(summary, np_global_step)

            should_stop = sess.run(Z.should_stop_op)

        if mode != 'train':
            arop = [[] for j in range(len(s_ops.additional_return_ops))]
            for j in range(len(s_ops.additional_return_ops)):
                if s_ops.arop_summary_iters[
                        j] < 0 or i < s_ops.arop_summary_iters[j]:
                    arop[j] = s_ops.additional_return_ops[j]
            val = sess.run(arop, feed_dict=feed_dict)
            val_additional_ops.append(val)
            tt.toc(log_at=60,
                   log_str='val timer {:d} / {:d}: '.format(i, iters),
                   type='time')

    if mode != 'train':
        # Write the default val summaries.
        summary, print_summary, np_global_step = sess.run(
            [s_ops.summary_ops, s_ops.print_summary_ops, global_step])
        if writer is not None and summary is not None:
            writer.add_summary(summary, np_global_step)

        # write custom validation ops
        val_summarys = []
        val_additional_ops = zip(*val_additional_ops)
        if len(s_ops.arop_eval_fns) > 0:
            val_metric_summary = tf.summary.Summary()
            for i in range(len(s_ops.arop_eval_fns)):
                val_summary = None
                if s_ops.arop_eval_fns[i] is not None:
                    val_summary = s_ops.arop_eval_fns[i](
                        val_additional_ops[i], np_global_step, logdir,
                        val_metric_summary, s_ops.arop_summary_iters[i])
                val_summarys.append(val_summary)
            if writer is not None:
                writer.add_summary(val_metric_summary, np_global_step)

        # Return the additional val_ops
        total_loss = (val_additional_ops, val_summarys)
        should_stop = None

    return total_loss, should_stop
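
concat_state_x is not part of this listing; from the call sites it evidently merges the list of per-step feature dicts into single arrays keyed by name, so the whole rollout can be fed to the training graph in one shot. A hypothetical sketch of such a helper, just to make the data flow concrete (the (batch, 1, ...) per-step shape is an assumption):

import numpy as np


def concat_state_x_sketch(state_list, names):
    """Hypothetical concat_state_x: stack per-step feature dicts along axis 1."""
    out = {}
    for name in names:
        # Each per-step entry is assumed to be (batch, 1, ...); concatenating
        # on axis 1 yields (batch, num_steps, ...).
        out[name] = np.concatenate([s[name] for s in state_list], axis=1)
    return out


steps = [{'action': np.zeros((4, 1, 3))} for _ in range(8)]
print(concat_state_x_sketch(steps, ['action'])['action'].shape)  # (4, 8, 3)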
Example #7
def train_step_custom_v2(sess,
                         train_op,
                         global_step,
                         train_step_kwargs,
                         mode='train'):
    m = train_step_kwargs['m']
    obj = train_step_kwargs['obj']
    rng = train_step_kwargs['rng']
    writer = train_step_kwargs['writer']
    iters = train_step_kwargs['iters']
    logdir = train_step_kwargs['logdir']
    train_display_interval = train_step_kwargs['train_display_interval']

    s_ops = m.summary_ops[mode]
    val_additional_ops = []

    # Print all variables here.
    if False:
        v = tf.get_collection(tf.GraphKeys.VARIABLES)
        v_op = [_.value() for _ in v]
        v_op_value = sess.run(v_op)

        filter = lambda x, y: 'Adam' in x.name
        # filter = lambda x, y: np.is_any_nan(y)
        ind = [
            i for i, (_, __) in enumerate(zip(v, v_op_value)) if filter(_, __)
        ]
        v = [v[i] for i in ind]
        v_op_value = [v_op_value[i] for i in ind]

        for i in range(len(v)):
            logging.info('XXXX: variable: %30s, is_any_nan: %5s, norm: %f.',
                         v[i].name, np.any(np.isnan(v_op_value[i])),
                         np.linalg.norm(v_op_value[i]))

    tt = utils.Timer()
    for i in range(iters):
        tt.tic()
        e = obj.sample_env(rng)
        rngs = e.gen_rng(rng)
        input_data = e.gen_data(*rngs)
        input_data = e.pre_data(input_data)
        feed_dict = prepare_feed_dict(m.input_tensors, input_data)

        if mode == 'train':
            n_step = sess.run(global_step)

            if np.mod(n_step, train_display_interval) == 0:
                total_loss, np_global_step, summary, print_summary = sess.run(
                    [
                        train_op, global_step, s_ops.summary_ops,
                        s_ops.print_summary_ops
                    ],
                    feed_dict=feed_dict)
            else:
                total_loss, np_global_step, summary = sess.run(
                    [train_op, global_step, s_ops.summary_ops],
                    feed_dict=feed_dict)

            if writer is not None and summary is not None:
                writer.add_summary(summary, np_global_step)

            should_stop = sess.run(m.should_stop_op)

        if mode != 'train':
            arop = [[] for j in range(len(s_ops.additional_return_ops))]
            for j in range(len(s_ops.additional_return_ops)):
                if s_ops.arop_summary_iters[
                        j] < 0 or i < s_ops.arop_summary_iters[j]:
                    arop[j] = s_ops.additional_return_ops[j]
            val = sess.run(arop, feed_dict=feed_dict)
            val_additional_ops.append(val)
            tt.toc(log_at=60,
                   log_str='val timer {:d} / {:d}: '.format(i, iters),
                   type='time')

    if mode != 'train':
        # Write the default val summaries.
        summary, print_summary, np_global_step = sess.run(
            [s_ops.summary_ops, s_ops.print_summary_ops, global_step])
        if writer is not None and summary is not None:
            writer.add_summary(summary, np_global_step)

        # write custom validation ops
        val_summarys = []
        val_additional_ops = zip(*val_additional_ops)
        if len(s_ops.arop_eval_fns) > 0:
            val_metric_summary = tf.summary.Summary()
            for i in range(len(s_ops.arop_eval_fns)):
                val_summary = None
                if s_ops.arop_eval_fns[i] is not None:
                    val_summary = s_ops.arop_eval_fns[i](
                        val_additional_ops[i], np_global_step, logdir,
                        val_metric_summary, s_ops.arop_summary_iters[i])
                val_summarys.append(val_summary)
            if writer is not None:
                writer.add_summary(val_metric_summary, np_global_step)

        # Return the additional val_ops
        total_loss = (val_additional_ops, val_summarys)
        should_stop = None

    return total_loss, should_stop
Example #8
def train_step_custom_online_sampling(sess,
                                      train_op,
                                      global_step,
                                      train_step_kwargs,
                                      mode='train'):
    m = train_step_kwargs['m']
    obj = train_step_kwargs['obj']
    rng_data = train_step_kwargs['rng_data']
    rng_action = train_step_kwargs['rng_action']
    writer = train_step_kwargs['writer']
    iters = train_step_kwargs['iters']
    num_steps = train_step_kwargs['num_steps']
    logdir = train_step_kwargs['logdir']
    dagger_sample_bn_false = train_step_kwargs['dagger_sample_bn_false']
    train_display_interval = train_step_kwargs['train_display_interval']
    if 'outputs' not in m.train_ops:
        m.train_ops['outputs'] = []

    s_ops = m.summary_ops[mode]
    val_additional_ops = []

    # Print all variables here.
    if False:
        v = tf.get_collection(tf.GraphKeys.VARIABLES)
        v_op = [_.value() for _ in v]
        v_op_value = sess.run(v_op)

        filter = lambda x, y: 'Adam' in x.name
        # filter = lambda x, y: np.is_any_nan(y)
        ind = [
            i for i, (_, __) in enumerate(zip(v, v_op_value)) if filter(_, __)
        ]
        v = [v[i] for i in ind]
        v_op_value = [v_op_value[i] for i in ind]

        for i in range(len(v)):
            logging.info('XXXX: variable: %30s, is_any_nan: %5s, norm: %f.',
                         v[i].name, np.any(np.isnan(v_op_value[i])),
                         np.linalg.norm(v_op_value[i]))

    tt = utils.Timer()
    for i in range(iters):
        tt.tic()
        # Sample a room.
        e = obj.sample_env(rng_data)

        # Initialize the agent.
        init_env_state = e.reset(rng_data)

        # Get and process the common data.
        input = e.get_common_data()
        input = e.pre_common_data(input)
        feed_dict = prepare_feed_dict(m.input_tensors['common'], input)
        if dagger_sample_bn_false:
            feed_dict[m.train_ops['batch_norm_is_training_op']] = False
        common_data = sess.run(m.train_ops['common'], feed_dict=feed_dict)

        states = []
        state_features = []
        state_targets = []
        net_state_to_input = []
        step_data_cache = []
        executed_actions = []
        rewards = []
        action_sample_wts = []
        states.append(init_env_state)

        net_state = sess.run(m.train_ops['init_state'], feed_dict=feed_dict)
        net_state = dict(zip(m.train_ops['state_names'], net_state))
        net_state_to_input.append(net_state)
        for j in range(num_steps):
            f = e.get_features(states[j], j)
            f = e.pre_features(f)
            f.update(net_state)
            f['step_number'] = np.ones((1, 1, 1), dtype=np.int32) * j
            state_features.append(f)

            feed_dict = prepare_feed_dict(m.input_tensors['step'],
                                          state_features[-1])
            optimal_action = e.get_optimal_action(states[j], j)
            for x, v in zip(m.train_ops['common'], common_data):
                feed_dict[x] = v
            if dagger_sample_bn_false:
                feed_dict[m.train_ops['batch_norm_is_training_op']] = False
            outs = sess.run([
                m.train_ops['step'], m.sample_gt_prob_op,
                m.train_ops['step_data_cache'], m.train_ops['updated_state'],
                m.train_ops['outputs']
            ],
                            feed_dict=feed_dict)
            action_probs = outs[0]
            sample_gt_prob = outs[1]
            step_data_cache.append(
                dict(zip(m.train_ops['step_data_cache'], outs[2])))
            net_state = outs[3]
            if hasattr(e, 'update_state'):
                outputs = outs[4]
                outputs = dict(zip(m.train_ops['output_names'], outputs))
                e.update_state(outputs, j)
            state_targets.append(e.get_targets(states[j], j))

            if j < num_steps - 1:
                # Sample from action_probs and optimal action.
                action, action_sample_wt = sample_action(
                    rng_action, action_probs, optimal_action, sample_gt_prob,
                    m.sample_action_type, m.sample_action_combine_type)
                next_state, reward = e.take_action(states[j], action, j)
                executed_actions.append(action)
                states.append(next_state)
                rewards.append(reward)
                action_sample_wts.append(action_sample_wt)
                net_state = dict(zip(m.train_ops['state_names'], net_state))
                net_state_to_input.append(net_state)

        # Concatenate things together for training.
        rewards = np.array(rewards).T
        action_sample_wts = np.array(action_sample_wts).T
        executed_actions = np.array(executed_actions).T
        all_state_targets = concat_state_x(state_targets, e.get_targets_name())
        all_state_features = concat_state_x(
            state_features,
            e.get_features_name() + ['step_number'])
        # all_state_net = concat_state_x(net_state_to_input,
        # m.train_ops['state_names'])
        all_step_data_cache = concat_state_x(step_data_cache,
                                             m.train_ops['step_data_cache'])

        dict_train = dict(input)
        dict_train.update(all_state_features)
        dict_train.update(all_state_targets)
        # dict_train.update(all_state_net)
        dict_train.update(net_state_to_input[0])
        dict_train.update(all_step_data_cache)
        dict_train.update({
            'rewards': rewards,
            'action_sample_wts': action_sample_wts,
            'executed_actions': executed_actions
        })
        feed_dict = prepare_feed_dict(m.input_tensors['train'], dict_train)
        for x in m.train_ops['step_data_cache']:
            feed_dict[x] = all_step_data_cache[x]
        if mode == 'train':
            n_step = sess.run(global_step)

            if np.mod(n_step, train_display_interval) == 0:
                total_loss, np_global_step, summary, print_summary = sess.run(
                    [
                        train_op, global_step, s_ops.summary_ops,
                        s_ops.print_summary_ops
                    ],
                    feed_dict=feed_dict)
                logging.error("")
            else:
                total_loss, np_global_step, summary = sess.run(
                    [train_op, global_step, s_ops.summary_ops],
                    feed_dict=feed_dict)

            if writer is not None and summary is not None:
                writer.add_summary(summary, np_global_step)

            should_stop = sess.run(m.should_stop_op)

        if mode != 'train':
            arop = [[] for j in range(len(s_ops.additional_return_ops))]
            for j in range(len(s_ops.additional_return_ops)):
                if s_ops.arop_summary_iters[
                        j] < 0 or i < s_ops.arop_summary_iters[j]:
                    arop[j] = s_ops.additional_return_ops[j]
            val = sess.run(arop, feed_dict=feed_dict)
            val_additional_ops.append(val)
            tt.toc(log_at=60,
                   log_str='val timer {:d} / {:d}: '.format(i, iters),
                   type='time')

    if mode != 'train':
        # Write the default val summaries.
        summary, print_summary, np_global_step = sess.run(
            [s_ops.summary_ops, s_ops.print_summary_ops, global_step])
        if writer is not None and summary is not None:
            writer.add_summary(summary, np_global_step)

        # write custom validation ops
        val_summarys = []
        val_additional_ops = zip(*val_additional_ops)
        if len(s_ops.arop_eval_fns) > 0:
            val_metric_summary = tf.summary.Summary()
            for i in range(len(s_ops.arop_eval_fns)):
                val_summary = None
                if s_ops.arop_eval_fns[i] is not None:
                    val_summary = s_ops.arop_eval_fns[i](
                        val_additional_ops[i], np_global_step, logdir,
                        val_metric_summary, s_ops.arop_summary_iters[i])
                val_summarys.append(val_summary)
            if writer is not None:
                writer.add_summary(val_metric_summary, np_global_step)

        # Return the additional val_ops
        total_loss = (val_additional_ops, val_summarys)
        should_stop = None

    return total_loss, should_stop
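
In both training loops the executed action comes from sample_action, which mixes the expert (optimal) action with the policy's action_probs according to sample_gt_prob, DAgger-style. The call sites show it also returns per-sample weights and takes a sample type and a combine type; those details are omitted in the simplified sketch below, which is an assumption about the basic idea rather than the project's actual implementation:

import numpy as np


def sample_action_sketch(rng, action_probs, optimal_action, gt_prob):
    """Hypothetical simplified sample_action: per example, follow the expert
    with probability gt_prob, otherwise sample from the policy."""
    batch, num_actions = action_probs.shape
    actions = np.zeros(batch, dtype=np.int64)
    for b in range(batch):
        if rng.rand() < gt_prob:
            actions[b] = optimal_action[b]  # expert (DAgger) action
        else:
            actions[b] = rng.choice(num_actions, p=action_probs[b])
    return actions


rng = np.random.RandomState(0)
probs = np.full((2, 4), 0.25)
print(sample_action_sketch(rng, probs, np.array([1, 2]), gt_prob=0.5))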