def make_test_cstr_fns(env, feature_fn, states, cstr_mu, cstr_feat_id, ths=1.):
    """Construct one constraint function per partition.

    Each returned function maps state ids to validity: a state is valid
    (True) when every constrained feature lies within ths of the
    partition's feature mean.
    """
    n_partitions = len(cstr_mu)
    features = fu.get_features_from_states(env, states, feature_fn)

    def make_cstr_fn(z_idx):
        mu = np.array(cstr_mu[z_idx])  # length equals the number of features

        def cstr_fn(state_ids):
            f = features[state_ids]
            if len(np.shape(f)) == 2:
                # batch of states: count per-state constraint violations
                ret = np.sum(
                    np.abs(f[:, cstr_feat_id] - mu[cstr_feat_id]) > ths,
                    axis=-1)
            else:
                # single state
                ret = np.sum(np.abs(f[cstr_feat_id] - mu[cstr_feat_id]) > ths)
            return ret == 0
        return cstr_fn

    return [make_cstr_fn(i) for i in range(n_partitions)]
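
# Illustration only: a self-contained toy version of the validity test that
# each cstr_fn above applies. The feature values, mean, and threshold below
# are assumptions for the sketch, not pipeline data.
def _demo_cstr_validity():
    """Hedged sketch; not called anywhere in the pipeline."""
    features = np.array([[0.0, 1.0, 2.0],
                         [0.1, 1.2, 2.1],
                         [0.0, 1.1, 5.0]])  # 3 states x 3 features
    mu = np.array([0.0, 1.0, 2.0])          # partition feature mean
    cstr_feat_id = [1, 2]                   # constrained feature indices
    ths = 0.5
    # count, per state, how many constrained features deviate beyond ths
    n_violations = np.sum(
        np.abs(features[:, cstr_feat_id] - mu[cstr_feat_id]) > ths, axis=-1)
    return n_violations == 0                # [True, True, False]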
def find_goal(mdp, env, log, states, feature_fn, roadmap, error=1e-10,
              ths=1e-3, queue_size=1000, enable_cstr=True, cstr_ths=2.33,
              use_discrete_state=True, use_nearest_goal=True,
              return_policy=True, **kwargs):
    irl_support_feature_ids = log['support_feature_ids']
    irl_support_feature_values = log['support_feature_values']
    cstr_feat_id = log['cstr_feat_id']

    # expected goals and constraints from a demonstration
    goal_features, cstr_ids, cstr_mus, _ = bic.get_expected_goal(
        log, states, enable_cstr=enable_cstr, queue_size=queue_size)

    # Now find the goals and constraints on a new environment.
    # n_cstr equals n_partitions.
    cstr_fns = make_test_cstr_fns(env, feature_fn, states, cstr_mus,
                                  cstr_feat_id, ths=cstr_ths)

    if use_discrete_state:
        T_org = copy.copy(mdp.T)

        # find feature goals
        features = fu.get_features_from_states(env, states, feature_fn)
        distFunc = kwargs.get('distFunc', None)

        # compute q_mat per sub-goal
        new_goals = []
        for i, (f_id, c_id, c_mu) in enumerate(
                zip(goal_features, cstr_ids, cstr_mus)):
            print("Find {}th goal, cstr={}".format(i, c_id))

            # feature goal
            idx = irl_support_feature_ids.index(f_id)
            f = irl_support_feature_values[idx]

            # get rewards and clear the old goal rewards
            rewards = np.array(mdp.get_rewards())
            rewards[np.where(rewards > 0)] = 0.

            # distance from each state's features to the goal feature
            d = np.linalg.norm(features - f, ord=np.inf, axis=-1)
            dist_ths = ths

            goal_found = False
            bad_goals = []
            while True:
                s_ids = [j for j in range(len(d)) if d[j] <= dist_ths]
                if len(s_ids) > 0:
                    goal_found = False
                    for idx in s_ids:
                        if idx in bad_goals:
                            continue
                        # skip goals that violate the constraint
                        # (note: "is False" would always fail on a numpy
                        # bool, so use "not" instead)
                        if c_id == 0 and not cstr_fns[i](idx):
                            print("Removed a bad goal violating constraints")
                            print(features[idx])
                            bad_goals.append(idx)
                            continue
                        rx1, _ = dijkstra_planning.dijkstra_planning(
                            env, env.start_state, states[idx],
                            env.roadmap, env.states, distFunc=distFunc)
                        if rx1 is not None:
                            goal_found = True
                            break
                        bad_goals.append(idx)
                    print("{} : Goal found? {} dist ths: {}".format(
                        s_ids, goal_found, dist_ths))
                    if goal_found is False:
                        print("Goal feature may not match the current "
                              "goal setup?")
                if goal_found:
                    break
                dist_ths += ths
            print("Found goals: {}".format(s_ids))

            # Select the state nearest to both the start and goal states
            if len(s_ids) > 1 and use_nearest_goal is False:
                dist = []
                for j, idx in enumerate(s_ids):
                    rx1, _ = dijkstra_planning.dijkstra_planning(
                        env, env.start_state, states[idx], env.roadmap,
                        env.states, distFunc=distFunc)
                    if rx1 is None:
                        dist.append(np.inf)
                        continue
                    # path from the candidate to the environment goal
                    # (the original duplicated the start-to-candidate query)
                    rx2, _ = dijkstra_planning.dijkstra_planning(
                        env, states[idx], env.goal_state, env.roadmap,
                        env.states, distFunc=distFunc)
                    if rx2 is None:
                        dist.append(np.inf)
                        continue
                    dist.append(len(rx1) + len(rx2))
                min_j = np.argmin(dist)
                s_ids = s_ids[min_j:min_j + 1]
                print("Selected a reachable state as a goal {}".format(s_ids))

            if return_policy:
                rewards[s_ids] = 1.
                mdp.set_rewards(rewards)

                # NOTE: we only use a single constraint
                # (c_id 0: constrained, otherwise: free)
                if enable_cstr is False or cstr_fns is None or c_id != 0:
                    # no-constraint case
                    mdp.T = copy.copy(T_org)
                else:
                    # constraint case: zero out transitions into invalid
                    # states, then renormalize the transition matrix
                    validity_map = cstr_fns[i](range(len(states)))[roadmap]
                    validity_map[:, 0] = True
                    T = T_org * validity_map[:, np.newaxis, :]
                    sum_T = np.sum(T, axis=-1)
                    sum_T[np.where(sum_T == 0.)] = 1.
                    T /= sum_T[:, :, np.newaxis]
                    mdp.T = T

                mdp.set_goal(s_ids)
                policy, values = mdp.find_policy(error)
            else:
                policy = []

            # among the selected states, keep the one closest to the start
            if distFunc is None:
                idx = np.argmin(np.linalg.norm(
                    states[s_ids] - env.get_start_state(), axis=-1))
            else:
                idx = np.argmin(distFunc(states[s_ids],
                                         env.get_start_state()))

            if enable_cstr:
                new_goals.append(
                    [s_ids[idx], copy.copy(policy), f_id, c_mu, c_id])
            else:
                new_goals.append([s_ids[idx], copy.copy(policy), f_id])
        return new_goals

    else:
        new_goals = []
        state = env.get_start_state()
        for i, (f_id, c_id, c_mu) in enumerate(
                zip(goal_features, cstr_ids, cstr_mus)):
            print("Find {}th goal, cstr={}".format(i, c_id))

            # feature goal
            idx = irl_support_feature_ids.index(f_id)
            f = irl_support_feature_values[idx]

            if enable_cstr:
                # find the closest state to the feature f, respecting
                # the constraint
                s = find_minimum_cost_state(state, env, f, feature_fn,
                                            cstr_feat_id, c_id, c_mu,
                                            cstr_ths)
                new_goals.append([s, None, f_id, c_mu, c_id])
            else:
                # find the closest state to the feature f
                s = find_minimum_cost_state(state, env, f, feature_fn)
                new_goals.append([s, None, f_id])
            state = s
        return new_goals
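
# Illustration only: a self-contained toy version of the transition masking
# used in the constrained branch of find_goal above. Transitions into
# invalid neighbors are zeroed and each row is renormalized. The shapes
# assume T[state, action, neighbor], matching the pipeline's usage.
def _demo_masked_transitions():
    """Hedged sketch; not called anywhere in the pipeline."""
    T_org = np.full((2, 2, 3), 1. / 3)               # uniform toy transitions
    validity_map = np.array([[True, True, False],
                             [True, False, False]])  # per-state neighbor validity
    validity_map[:, 0] = True            # the first neighbor stays reachable
    T = T_org * validity_map[:, np.newaxis, :]
    sum_T = np.sum(T, axis=-1)
    sum_T[np.where(sum_T == 0.)] = 1.    # avoid division by zero
    return T / sum_T[:, :, np.newaxis]   # rows sum to 1 again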
def bn_irl(env, roadmap, skdtree, states, T, gamma, trajs, idx_trajs,
           feature_fn, alphas=(0.1, 1.0), eta=0.5, punishment=0.,
           Ts=[0.1, 0.7], **kwargs):
    """Bayesian Nonparametric Inverse Reinforcement Learning (BN-IRL).

    Inputs:
        gamma    float - RL discount factor
        trajs    a list of demonstrations
        alphas   confidences for resampling and assignments
        eta      concentration (the CRP's alpha)
    Returns:
        log      dict - sampling history of sub-goals and assignments
    """
    n_goals = kwargs.get('n_goals', 2)
    burn_in = kwargs.get('burn_in', 1000)
    n_iters = kwargs.get('n_iter', 2000)
    n_iters = n_iters if n_iters > burn_in else int(n_iters * 1.5)
    use_clusters = True
    use_action_prior = kwargs.get('use_action_prior', False)
    alpha1, alpha2 = alphas
    max_cnt = kwargs.get('max_cnt', 100)
    return_best = False

    N_STATES = len(states)
    N_ACTIONS = len(roadmap[0])

    # support observations (i.e., demonstrations)
    observations = ut.trajs2observations(trajs, idx_trajs, roadmap, states)
    support_states = bic.get_support_space(observations)
    n_observations = len(observations)
    n_support_states = len(support_states)
    observation_states = [obs.state for obs in observations]
    observation_actions = [obs.action for obs in observations]

    # value-iteration agent
    agent = vi.valueIterAgent(N_ACTIONS, N_STATES,
                              roadmap, skdtree, states,
                              rewards=kwargs.get('rewards', None),
                              gamma=gamma, T=T)

    # precompute Q and pi per support feature; each entry is a goal
    # [feature, policy]
    support_values = None
    support_validity = None
    if os.path.isfile(kwargs['sav_filenames']['Q']) is False \
       or kwargs.get('vi_renew', False):
        assert len(np.shape(trajs)) == 3, \
            "wrong shape or number of trajectories"
        print("Renewing q functions")
        feat_map = fu.get_features_from_states(env, states, feature_fn)

        # get features given each feature sub-goal
        support_feature_ids, support_feature_values = \
            bic.get_support_feature_space(support_states, states,
                                          feat_map, env)
        support_feature_state_dict = bic.get_feature_to_state_dict(
            support_states, support_feature_ids)

        support_policy, support_values = \
            biq.computeQ(agent, support_states,
                         support_features=(support_feature_ids,
                                           support_feature_values),
                         support_feature_state_dict=support_feature_state_dict,
                         max_cnt=max_cnt)

        d = {}
        d['support_policy'] = support_policy
        d['support_feature_ids'] = support_feature_ids
        d['support_feature_values'] = support_feature_values
        d['support_feature_state_dict'] = support_feature_state_dict
        pickle.dump(d, open(kwargs['sav_filenames']['Q'], "wb"))
    else:
        d = pickle.load(open(kwargs['sav_filenames']['Q'], "rb"))
        support_policy = d['support_policy']
        support_feature_ids = d['support_feature_ids']
        support_feature_values = d['support_feature_values']
        support_feature_state_dict = d['support_feature_state_dict']

    # initialization
    renew_log = True
    if renew_log:
        goals, z = bic.init_irl_params(n_observations, n_goals,
                                       support_policy, support_states,
                                       support_feature_ids,
                                       support_feature_state_dict,
                                       observations)
        # log
        log = {'goals': [], 'z': []}
        log['observations'] = observations
        log['eta'] = eta
        log['alphas'] = alphas
        log['support_states'] = support_states
        log['support_feature_ids'] = support_feature_ids
        log['support_feature_values'] = support_feature_values
        log['support_feature_state_dict'] = support_feature_state_dict
    else:
        print("Loaded saved log")
        log = pickle.load(open(kwargs['sav_filenames']['irl'], "rb"))
        z = log['z'][-1]
        goals = log['goals'][-1]
        burn_in = 0

    # ======================================================================
    tqdm_e = tqdm(range(n_iters), desc='Score', leave=True, unit=" episodes")
    for iteration in tqdm_e:

        # sample sub-goals and constraints
        for j, goal in enumerate(goals):
            observation_states_part = [
                s for z_i, s in zip(z, observation_states) if z_i == j]
            observation_actions_part = [
                a for z_i, a in zip(z, observation_actions) if z_i == j]

            goals[j] = bic.resample(observation_states_part,
                                    observation_actions_part,
                                    support_states, support_policy,
                                    alpha=alpha1,
                                    support_feature_ids=support_feature_ids,
                                    support_feature_state_dict=support_feature_state_dict,
                                    T=Ts[0],
                                    punishment=punishment,
                                    return_best=return_best)

        if iteration > burn_in:
            # re-order z and goals
            new_z, new_goals = bic.reorder(z, goals, support_states)
            # drop policies to reduce the memory load
            log_goals = copy.deepcopy(new_goals)
            for i in range(len(log_goals)):
                log_goals[i][1] = None
            log['goals'].append(log_goals)
            log['z'].append(copy.deepcopy(new_z))

        # sample an assignment per observation
        tmp_use_clusters = use_clusters
        for i, obs in enumerate(observations):
            goal_state_support_ids = [support_states.index(goal[0])
                                      for goal in goals]

            # reassignment
            z, goals = bic.sample_partition_assignment(
                obs, i, z, goals,
                support_states, support_policy,
                use_clusters=tmp_use_clusters,
                eta=eta, alpha=alpha2,
                states=states, roadmap=roadmap,
                support_feature_ids=support_feature_ids,
                support_feature_state_dict=support_feature_state_dict,
                punishment=punishment, T=Ts[1],
                return_best=return_best)

            # post-process
            if use_clusters:
                z, goals = bic.post_process(z, goals)
            if use_clusters is False:
                if len(goals) != n_goals:
                    tmp_use_clusters = True
                else:
                    tmp_use_clusters = False

        tqdm_e.set_description("goals: {0:.1f}".format(len(goals)))
        tqdm_e.refresh()

    pickle.dump(log, open(kwargs['sav_filenames']['irl'], "wb"))
    return log
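
# Illustration only: BN-IRL reassigns observations with a Chinese restaurant
# process: an observation joins an existing partition with weight
# proportional to partition size times its likelihood under that partition's
# sub-goal policy, or opens a new partition with weight eta. The likelihood
# numbers below are toy assumptions; in the pipeline these come from the
# precomputed support policies inside bic.sample_partition_assignment.
def _demo_crp_assignment(rng=np.random):
    """Hedged sketch; not called anywhere in the pipeline."""
    z = [0, 0, 1, 1, 1]                  # current assignments per observation
    eta = 0.5                            # CRP concentration
    likelihoods = np.array([0.2, 0.6])   # p(obs | sub-goal j), toy values
    new_lik = 0.3                        # p(obs | a freshly sampled sub-goal)
    counts = np.bincount(z)              # observations per partition: [2, 3]
    weights = np.append(counts * likelihoods, eta * new_lik)
    probs = weights / np.sum(weights)
    # sampling index len(counts) means "open a new partition"
    return rng.choice(len(probs), p=probs)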
def bn_irl(env, roadmap, skdtree, states, T, gamma, trajs, idx_trajs,
           feature_fn, alphas=(0.1, 1.0), eta=0.5, punishment=0.,
           Ts=[0.1, 0.7], num_feat=100, cstr_ths=2.33, window_size=5,
           **kwargs):
    """Bayesian Nonparametric Inverse Reinforcement Learning (BN-IRL)
    with goal constraints.

    Inputs:
        gamma    float - RL discount factor
        trajs    a list of demonstrations
        alphas   confidences for resampling and assignments
        eta      concentration (the CRP's alpha)
    Returns:
        log      dict - sampling history of sub-goals, constraints, and
                 assignments
    """
    n_goals = kwargs.get('n_goals', 3)
    burn_in = kwargs.get('burn_in', 1000)
    n_iters = kwargs.get('n_iter', 2000)
    n_iters = n_iters if n_iters > burn_in else int(n_iters * 1.5)
    use_clusters = True
    use_action_prior = kwargs.get('use_action_prior', False)
    alpha1, alpha2 = alphas
    max_cnt = kwargs.get('max_cnt', 100)
    return_best = False

    N_STATES = len(states)
    N_ACTIONS = len(roadmap[0])

    # support observations (i.e., demonstrations)
    observations = ut.trajs2observations(trajs, idx_trajs, roadmap, states)
    support_states = bic.get_support_space(observations)
    n_observations = len(observations)
    n_support_states = len(support_states)
    observation_states = [obs.state for obs in observations]
    observation_actions = [obs.action for obs in observations]

    # precompute Q and pi per support feature; each entry is a goal
    # [feature, policy]
    time0 = time.time()
    if os.path.isfile(kwargs['sav_filenames']['Q']) is False \
       or kwargs.get('vi_renew', True):
        assert len(np.shape(trajs)) == 3, \
            "wrong shape or number of trajectories"

        # trajs, idx_trajs, and feat_trajs share the same order
        feat_map = fu.get_features_from_states(env, states, feature_fn)
        feat_trajs = np.array(ut.trajs2featTrajs(idx_trajs, feat_map))

        # get features given each feature sub-goal:
        # support_feature_ids: each support state's feature index,
        #   {indices | s_idx in S, f(s) == f(s_g)}
        # support_feature_values: {feature values | ...}
        support_feature_ids, support_feature_values = \
            bic.get_support_feature_space(support_states, states,
                                          feat_map, env)
        support_feature_state_dict = bic.get_feature_to_state_dict(
            support_states, support_feature_ids)

        # Select cstr_feat_id: a feature whose deviation over moving
        # windows stays below cstr_ths is nearly constant along the
        # demonstration and becomes a constraint candidate.
        cstr_feat_id = []
        for i in range(np.shape(feat_trajs)[-1]):
            fs = feat_trajs[0, :, i]
            stds = []
            for j in range(len(fs) - window_size - 1):
                stds.append(np.std(fs[j:j + window_size]))
            stds = np.array(stds)
            print("{} {} {}".format(np.mean(stds * 2.), cstr_ths,
                                    cstr_ths / float(window_size)))

            # check the average of deviations over moving windows
            if np.mean(stds * 2.) < cstr_ths:
                cstr_feat_id.append(i)
        print("CSTR ID: {}".format(cstr_feat_id))

        # TODO: currently, the feature range and pairs are simplified!
        # discretize the constrained features
        f_min = np.amin(feat_trajs[0, :, cstr_feat_id], axis=-1)
        f_max = np.amax(feat_trajs[0, :, cstr_feat_id], axis=-1)
        feat_range = np.linspace(f_min, f_max, num_feat)  # num_feat x n_cstr

        support_policy_dict = {}
        support_validity_dict = {}

        # compute policies for states x features
        for i in range(len(feat_range)):
            print("f_idx = {}/{}".format(i, len(feat_range)))

            # compute a constraint function over observations
            feat = np.zeros(len(feat_map[0]))
            feat[cstr_feat_id] = feat_range[i]
            cstr_fns, cstr_params = make_train_cstr_fns(feat_map, feat,
                                                        states, cstr_feat_id,
                                                        ths=cstr_ths)
            cstr_map = cstr_fns[0](None, f=feat_map) * 1.

            # NOTE: a skip for zero-score constraints was disabled here;
            # policies are computed for every discretized feature value.

            # value-iteration agent
            agent = vi.valueIterAgent(N_ACTIONS, N_STATES,
                                      roadmap, skdtree, states,
                                      rewards=kwargs.get('rewards', None),
                                      gamma=gamma, T=T)
            support_policy, support_values, support_validity = \
                biq.computeQ(agent, support_states,
                             support_features=(support_feature_ids,
                                               support_feature_values),
                             support_feature_state_dict=support_feature_state_dict,
                             cstr_fn=cstr_fns,
                             add_no_cstr=False,
                             max_cnt=max_cnt,
                             feat_map=feat_map,
                             roadmap=roadmap)
            support_policy_dict[i] = support_policy
            support_validity_dict[i] = support_validity

            # print out the constraint scores
            cstr_score = 0
            for idx in idx_trajs[0]:
                if support_values is None:
                    continue
                cstr_score += cstr_map[idx] * \
                    support_values[idx_trajs[0][-1]][0, idx]
            print("{}th feat's cstr score = {}".format(i, cstr_score))

            ## debug visualization
            ## from viz import viz as v
            ## v.reward_value_plot(agent.rewards,
            ##                     support_values[idx_trajs[0][-1]][0],
            ##                     states, trajs=trajs)

        # value-iteration agent (no constraints)
        agent = vi.valueIterAgent(N_ACTIONS, N_STATES,
                                  roadmap, skdtree, states,
                                  rewards=kwargs.get('rewards', None),
                                  gamma=gamma, T=T)
        cstr_fns = []
        support_policy, support_values, support_validity = \
            biq.computeQ(agent, support_states,
                         support_features=(support_feature_ids,
                                           support_feature_values),
                         support_feature_state_dict=support_feature_state_dict,
                         cstr_fn=cstr_fns,
                         add_no_cstr=True,
                         max_cnt=max_cnt)

        # append the unconstrained policy to each feature's policy set
        for i in range(len(feat_range)):
            if len(support_policy_dict[i].keys()) == 0:
                support_policy_dict[i] = []
                support_validity_dict[i] = []
            else:
                for j in support_policy_dict[i].keys():
                    support_policy_dict[i][j] = np.vstack(
                        [support_policy_dict[i][j], support_policy[j]])
                    support_validity_dict[i][j] = np.hstack(
                        [support_validity_dict[i][j], support_validity[j]])
        support_policy_dict[-1] = support_policy
        support_validity_dict[-1] = support_validity

        cstr_score = 0
        for idx in idx_trajs[0]:
            cstr_score += support_values[idx_trajs[0][-1]][0][idx]
        print("No cstr score = {}".format(cstr_score))

        d = {}
        d['support_policy_dict'] = support_policy_dict
        d['support_validity_dict'] = support_validity_dict
        d['feat_trajs'] = feat_trajs
        d['feat_map'] = feat_map
        d['feat_range'] = feat_range
        d['cstr_feat_id'] = cstr_feat_id
        d['support_feature_ids'] = support_feature_ids
        d['support_feature_values'] = support_feature_values
        d['support_feature_state_dict'] = support_feature_state_dict
        pickle.dump(d, open(kwargs['sav_filenames']['Q'], "wb"))
        del support_values, support_validity_dict
        del d
    else:
        d = pickle.load(open(kwargs['sav_filenames']['Q'], "rb"))
        support_policy_dict = d['support_policy_dict']
        feat_trajs = d['feat_trajs']
        feat_map = d['feat_map']
        feat_range = d['feat_range']
        support_feature_ids = d['support_feature_ids']
        support_feature_values = d['support_feature_values']
        support_feature_state_dict = d['support_feature_state_dict']
        cstr_feat_id = d['cstr_feat_id']
    time1 = time.time()

    # initialization
    goals, z = bic.init_irl_params(n_observations, n_goals,
                                   support_policy_dict, support_states,
                                   support_feature_ids,
                                   support_feature_state_dict,
                                   observations)

    # log
    log = {'goals': [], 'z': []}
    log['observations'] = observations
    log['eta'] = eta
    log['alphas'] = alphas
    log['support_states'] = support_states
    log['support_feature_ids'] = support_feature_ids
    log['support_feature_values'] = support_feature_values
    log['support_feature_state_dict'] = support_feature_state_dict
    log['cstr_ths'] = cstr_ths
    log['cstr_feat_id'] = cstr_feat_id

    # ======================================================================
    time2 = time.time()
    tqdm_e = tqdm(range(n_iters), desc='Score', leave=True, unit=" episodes")
    for iteration in tqdm_e:

        # sample sub-goals and constraints
        for j, goal in enumerate(goals):
            observation_states_part = [
                s for z_i, s in zip(z, observation_states) if z_i == j]
            observation_actions_part = [
                a for z_i, a in zip(z, observation_actions) if z_i == j]

            # pick the policy set matching the partition's mean feature
            f = feat_map[observation_states_part]
            f_mu = np.mean(f, axis=0)
            f_idx = np.argmin(
                np.linalg.norm(feat_range - f_mu[cstr_feat_id], axis=-1))
            support_policy = support_policy_dict[f_idx]

            goals[j] = bic.resample_gc(
                observation_states_part,
                observation_actions_part,
                support_states, support_policy,
                alpha=alpha1,
                support_feature_ids=support_feature_ids,
                support_feature_state_dict=support_feature_state_dict,
                punishment=punishment, T=Ts[0],
                return_best=return_best)
            goals[j][-2] = {'mu': f_mu}

        if iteration > burn_in:
            # re-order z and goals
            new_z, new_goals = bic.reorder(z, goals, support_states)
            # drop policies to reduce the memory load
            log_goals = copy.deepcopy(new_goals)
            for i in range(len(log_goals)):
                log_goals[i][1] = None
            log['goals'].append(log_goals)
            log['z'].append(copy.deepcopy(new_z))

        # sample an assignment per observation
        tmp_use_clusters = use_clusters
        for i, obs in enumerate(observations):
            goal_state_support_ids = [support_states.index(goal[0])
                                      for goal in goals]

            # reassignment
            z, goals = bic.sample_partition_assignment(
                obs, i, z, goals,
                support_states, support_policy_dict,
                use_clusters=tmp_use_clusters,
                eta=eta, alpha=alpha2,
                states=states, roadmap=roadmap,
                support_feature_ids=support_feature_ids,
                support_feature_state_dict=support_feature_state_dict,
                punishment=punishment, T=Ts[1],
                enable_cstr=True,
                return_best=return_best,
                feat_range=feat_range)

            # post-process
            if use_clusters:
                z, goals = bic.post_process(z, goals)
            if use_clusters is False:
                if len(goals) != n_goals:
                    tmp_use_clusters = True
                else:
                    tmp_use_clusters = False

        tqdm_e.set_description("goals: {0:.1f}".format(len(goals)))
        tqdm_e.refresh()

    time3 = time.time()
    print("---------------------------------------------")
    print("Precomputation Time: {}".format(time1 - time0))
    print("Gibbs Sampling Time: {}".format(time3 - time2))
    print("---------------------------------------------")

    pickle.dump(log, open(kwargs['sav_filenames']['irl'], "wb"))
    return log
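
# Illustration only: a self-contained toy version of the moving-window test
# used above to select constraint candidates. A feature whose mean 2*std
# over sliding windows stays below cstr_ths is treated as nearly constant
# along the demonstration. The trajectory values are toy assumptions.
def _demo_cstr_feature_selection():
    """Hedged sketch; not called anywhere in the pipeline."""
    feat_traj = np.array([[1.0, 10.0], [1.1, 4.0],
                          [0.9, 7.0], [1.0, 1.0],
                          [1.1, 9.0], [1.0, 2.0]])  # T x n_features
    window_size, cstr_ths = 3, 0.5
    cstr_feat_id = []
    for i in range(feat_traj.shape[-1]):
        fs = feat_traj[:, i]
        stds = [np.std(fs[j:j + window_size])
                for j in range(len(fs) - window_size - 1)]
        if np.mean(np.array(stds) * 2.) < cstr_ths:
            cstr_feat_id.append(i)
    return cstr_feat_id  # [0]: only the near-constant feature qualifies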
def find_goal(mdp, env, log, states, feature_fn, cstr_fn=None, error=1e-10,
              ths=1e-3, queue_size=1000, use_nearest_goal=True, **kwargs):
    irl_support_feature_ids = log['support_feature_ids']
    irl_support_feature_values = log['support_feature_values']

    goal_features, _, _, _ = bic.get_expected_goal(log, states,
                                                   queue_size=queue_size)
    T_org = copy.copy(mdp.T)

    # find feature goals
    features = fu.get_features_from_states(env, states, feature_fn)
    distFunc = kwargs.get('distFunc', None)

    # compute q_mat per sub-goal
    new_goals = []
    for i, f_id in enumerate(goal_features):
        print("Find {}th goal".format(i))

        # feature goal
        idx = irl_support_feature_ids.index(f_id)
        f = irl_support_feature_values[idx]

        # get rewards and clear the old goal rewards
        rewards = np.array(mdp.get_rewards())
        rewards[np.where(rewards > 0)] = 0.

        # distance from each state's features to the goal feature
        d = np.linalg.norm(features - f, ord=np.inf, axis=-1)
        dist_ths = ths
        if np.amin(d) > dist_ths:
            dist_ths = np.amin(d)

        goal_found = False
        bad_goals = []
        while True:
            s_ids = [j for j in range(len(d)) if d[j] <= dist_ths]
            if len(s_ids) > 0:
                goal_found = False
                for idx in s_ids:
                    if idx in bad_goals:
                        continue
                    rx1, _ = dijkstra_planning.dijkstra_planning(
                        env, env.start_state, states[idx], env.roadmap,
                        env.states, distFunc=distFunc)
                    if rx1 is not None:
                        goal_found = True
                        break
                    bad_goals.append(idx)
                print("{} {} {}".format(s_ids, goal_found, dist_ths))
            if goal_found:
                break
            dist_ths += ths

        print("----------------------------")
        print("Found {} sub-goals: {}".format(len(s_ids), s_ids))
        print("----------------------------")

        # Select the state nearest to both the start and goal states
        if len(s_ids) > 1 and use_nearest_goal is False:
            dist = []
            for j, idx in enumerate(s_ids):
                rx1, _ = dijkstra_planning.dijkstra_planning(
                    env, env.start_state, states[idx], env.roadmap,
                    env.states, distFunc=distFunc)
                if rx1 is None:
                    dist.append(np.inf)
                    continue
                rx2, _ = dijkstra_planning.dijkstra_planning(
                    env, states[idx], env.goal_state, env.roadmap,
                    env.states, distFunc=distFunc)
                if rx2 is None:
                    dist.append(np.inf)
                    continue
                dist.append(len(rx1) + len(rx2))
            min_j = np.argmin(dist)
            s_ids = s_ids[min_j:min_j + 1]
            print("Selected a reachable state as a goal {}".format(s_ids))

        # set the new rewards and solve for a policy with the original T
        rewards[s_ids] = 1.
        mdp.set_rewards(rewards)
        mdp.T = copy.copy(T_org)
        policy, values = mdp.find_policy(error)

        new_goals.append([s_ids[0], copy.copy(policy), f_id])

    return new_goals
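
# Illustration only: a self-contained toy version of the expanding-threshold
# goal search used by both find_goal variants. Candidate goal states are
# those whose features lie within dist_ths (infinity norm) of the target
# feature; the radius grows until a candidate appears. Toy features assumed.
def _demo_feature_goal_search():
    """Hedged sketch; not called anywhere in the pipeline."""
    features = np.array([[0.0, 0.0], [0.4, 0.1], [1.0, 1.0]])
    f = np.array([0.45, 0.1])                # target goal feature
    ths = 1e-1
    d = np.linalg.norm(features - f, ord=np.inf, axis=-1)
    dist_ths = ths
    while True:
        s_ids = [j for j in range(len(d)) if d[j] <= dist_ths]
        if len(s_ids) > 0:
            return s_ids                     # [1] for these toy numbers
        dist_ths += ths                      # expand the search radius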