예제 #1
0
    def rollback(self, frames):
        '''
        Roll the environment back in time by ``frames`` frames.

        Added primarily for reward-analysis purposes. The frame counter is
        moved back and the obstacles are reloaded from the annotation of the
        new frame (when that frame is annotated). Under external control the
        agent state and heading are restored from the recorded histories,
        which are then shortened by ``frames`` entries.

        If the request exceeds the recorded history, a warning is printed and
        the environment is left unchanged.

        :param frames: number of frames to step back.
        :return: ``self.state`` after the rollback.
        '''
        # Validate before touching any state: reading history[-frames-1]
        # needs at least frames+1 entries. (The old check ran after the
        # indexing and used '>', so it could never fire.)
        if self.external_control:
            history_len = min(len(self.pos_history), len(self.heading_dir_history))
            if frames >= history_len:
                print('Trying to rollback more than the size of current history!')
                return self.state

        self.current_frame = self.current_frame - frames

        frame_key = str(self.current_frame)
        if frame_key in self.annotation_dict:
            self.get_state_from_frame_universal(self.annotation_dict[frame_key])

        if self.external_control:
            # NOTE(review): reset_and_replace appends (agent_state, frame)
            # tuples to pos_history; confirm what entries look like here
            # before relying on agent_state being a dict.
            self.agent_state = utils.copy_dict(self.pos_history[-frames - 1])
            self.cur_heading_dir = self.heading_dir_history[-frames - 1]

            # Drop the rolled-back entries. Slicing from len-frames is used
            # because ``del lst[-0:]`` would clear the whole list for frames == 0.
            del self.heading_dir_history[len(self.heading_dir_history) - frames:]
            del self.pos_history[len(self.pos_history) - frames:]

        if self.release_control:
            self.release_control = False
        self.state['agent_state'] = utils.copy_dict(self.agent_state)
        if self.display:
            self.render()

        return self.state
예제 #2
0
 def _init(self, *args, **kwargs):
     """Prepare the xyzr conversion, the MSMS command line and, when
     ``self.envelope`` is set, an additional envelope MSMS sub-tool.

     Sub-tools are created with ``run=False``; their output files are
     registered in ``self.output_files``.
     """
     xyzr_kwargs = copy_dict(kwargs, run=False, verbose=False)
     self.pdb2xyzr = Pdb2xyzr(self.pdb_file, **xyzr_kwargs)
     # xyzr outputs go in front of anything registered earlier
     self.output_files = self.pdb2xyzr.output_files + self.output_files

     command = [MSMS_CMD, "-if", self.pdb2xyzr.xyzr_file]
     command += ["-probe_radius", self.probe_radius]
     command += ["-af", "area", "-of", "tri_surface"]
     command += ["-density", self.density, "-hdensity", self.hdensity]
     # optional switches, always in this order
     optional_flags = (
         (self.all_components, "-all_components"),
         (self.no_area, "-no_area"),
     )
     for enabled, flag in optional_flags:
         if enabled:
             command.append(flag)
     self.cmd = command

     if not self.envelope:
         return

     self.envelope_pdb = self.outpath("envelope.pdb")
     envelope_kwargs = copy_dict(
         kwargs,
         run=False,
         envelope=0.0,
         density=0.3,
         hdensity=0.3,
         probe_radius=self.envelope,
         all_components=False,
         output_dir=self.subdir("envelope"))
     self.envelope_msms = Msms(self.pdb_file, **envelope_kwargs)
     self.output_files += self.envelope_msms.output_files
예제 #3
0
    def return_position(self, ped_id, frame_id):
        '''
        Return a copy of pedestrian ``ped_id``'s record at ``frame_id``.

        If that exact frame is not recorded, the closest earlier recorded
        frame is used instead.

        :raises KeyError: if ``ped_id`` is unknown, or if the pedestrian has
            no recorded frame at or before ``frame_id`` (this case used to
            loop forever).
        '''
        ped_frames = self.pedestrian_dict[str(ped_id)]
        if str(frame_id) in ped_frames:
            return utils.copy_dict(ped_frames[str(frame_id)])

        # Numeric frame keys only; bookkeeping keys such as 'initial_frame'
        # are skipped. Walking below the lowest one can never succeed.
        frame_numbers = [int(k) for k in ped_frames
                         if str(k).lstrip('-').isdigit()]
        lowest = min(frame_numbers) if frame_numbers else None

        requested = frame_id
        while str(frame_id) not in ped_frames:
            if lowest is None or frame_id < lowest:
                raise KeyError('pedestrian {} has no frame at or before {}'
                               .format(ped_id, requested))
            frame_id -= 1
        return utils.copy_dict(ped_frames[str(frame_id)])
예제 #4
0
def collect_trajectories(
    env, feature_extractor, policy, num_trajectories, max_episode_length,
):
    """
    Roll out ``policy`` in ``env`` and record the visited states.

    Each trajectory starts from ``env.reset()`` and runs until the
    environment reports ``done`` or ``max_episode_length`` steps have been
    taken. The initial state and every subsequent state are stored as copies
    made with ``copy_dict``.

    :param env: environment to collect trajectories from.
    :type env: any gym-like environment.

    :param feature_extractor: a feature extractor to translate the state
    dictionary to a feature vector (a numpy array).
    :type feature_extractor: feature extractor class.

    :param policy: Policy to extract actions from.
    :type policy: standard policy child of BasePolicy.

    :param num_trajectories: Number of trajectories to sample.
    :type num_trajectories: int.

    :param max_episode_length: Maximum number of steps per trajectory.
    :type max_episode_length: int.

    :return: one list per trajectory, in collection order; each list holds
    the initial state followed by the state after every step.
    :rtype: list of lists of dict
    """
    all_trajectories = []

    for traj_idx in range(num_trajectories):
        # progress indicator, overwritten in place on the same line
        print("Collecting trajectory {}".format(traj_idx), end="\r")

        state = env.reset()
        traj = [copy_dict(state)]
        done = False
        steps = 0

        while not done and steps < max_episode_length:
            feat = feature_extractor.extract_features(state)
            # the policy is fed a float32 tensor on the configured device
            feat = torch.from_numpy(feat).type(torch.FloatTensor).to(DEVICE)

            action = policy.eval_action(feat)
            state, _, done, _ = env.step(action)
            traj.append(copy_dict(state))
            steps += 1

        all_trajectories.append(traj)

    return all_trajectories
예제 #5
0
    def _init(self, *args, **kwargs):
        """Set up records/parallelism and, for a single input, the sub-tools.

        In parallel mode only the record and parallel setup runs. Otherwise a
        PdbInfo sub-tool and one instance of every entry of ``tool_list`` per
        water variant are created (all with ``run=False``), and the expected
        output files are registered.
        """
        self._init_records(None, **kwargs)
        self._init_parallel(self.pdb_input, **kwargs)

        if self.parallel:
            return

        self.pdb_id = self.pdb_input
        self.npdb_features = {"sstruc": False, "phi_psi": False}

        self.output_files = []
        info_kwargs = copy_dict(kwargs,
                                run=False,
                                output_dir=self.subdir("pdb_info"))
        self.pdb_info = PdbInfo(self.pdb_id, **info_kwargs)
        self.output_files.append(self.pdb_info.info_file)

        # only the original structure is processed
        self.water_variants = [("ori", self.pdb_input)]

        voronoia_options = {
            "ex": self.ex,
            "shuffle": self.voro_shuffle,
            "make_reference": False,
            "get_nrholes": False,
        }
        self.tool_list = [("voronoia", Voronoia, voronoia_options)]

        def select(spec, candidates):
            # empty spec -> all candidates; leading "!" -> drop the listed
            # names; otherwise -> keep only the listed names
            if not spec:
                return candidates
            if spec[0] == "!":
                excluded = spec[1:]
                return [entry for entry in candidates if entry[0] not in excluded]
            return [entry for entry in candidates if entry[0] in spec]

        self.do_tool_list = select(self.tools, self.tool_list)

        # NOTE(review): every entry of tool_list is instantiated here, not only
        # do_tool_list; confirm the selection is applied elsewhere.
        for suffix, variant_pdb in self.water_variants:
            for prefix, tool_cls, extra_options in self.tool_list:
                attr_name = "%s_%s" % (prefix, suffix)
                tool_kwargs = copy_dict(kwargs,
                                        run=False,
                                        output_dir=self.subdir(attr_name),
                                        **extra_options)
                instance = tool_cls(variant_pdb, **tool_kwargs)
                self.__dict__[attr_name] = instance
                self.output_files += instance.output_files

        self.output_files += [
            self.info_file,
            self.outpath("voronoia.provi"),
            self.stats_file,
        ]
예제 #6
0
def combine_dicts(_td, _td2copy, inplace=True):
    """
    Merge the fields of ``_td2copy`` into ``_td``.

    Keys already present in ``_td`` are kept as they are (the first dict
    dominates); keys missing from ``_td`` are taken from a copy of
    ``_td2copy``. Lists of dicts are merged element by element.

    Parameters
    ----------
    _td : dict/list of dict
        Trial data that receives the new fields.
    _td2copy : dict/list of dict
        Trial data providing the fields to add. It is copied with
        ``copy_dict`` before use, so it is never modified.
    inplace : bool, optional
        If True (the default), modify ``_td`` in place and return None.
        If False, work on a copy of ``_td`` and return the result.

    Returns
    -------
    td_out : dict/list of dict or None
        The merged trial data when ``inplace`` is False, otherwise None.

    Raises
    ------
    ValueError
        If the two inputs contain a different number of trials.
    """
    td = _td if inplace else copy_dict(_td)
    td2copy = copy_dict(_td2copy)

    # Work on lists internally; remember whether to unwrap at the end.
    input_dict = isinstance(td, dict)
    if input_dict:
        td = [td]
    if isinstance(td2copy, dict):
        td2copy = [td2copy]

    # check that tds have the same dimension
    if len(td) != len(td2copy):
        raise ValueError('ERROR: The 2 tds have different dimension!')

    for td1_el, td2_el in zip(td, td2copy):
        for k, v in td2_el.items():
            # direct dict membership: O(1), no per-key set rebuild
            if k not in td1_el:
                td1_el[k] = v

    if input_dict:
        td = td[0]

    if not inplace:
        return td
예제 #7
0
    def get_state_from_frame_universal(self, frame_info):
        '''
        Rebuild the obstacle list (and, when not under external control, the
        agent) from the annotation rows of the current frame. Used for
        processed datasets.

        The pedestrian id of each row is read from ``row[1]``.
        '''
        frame_key = str(self.current_frame)
        obstacle_list = []
        self.obstacles = obstacle_list

        for row in frame_info:
            ped_key = row[1]
            ped_num = float(ped_key)

            # every pedestrian not in the skip list becomes an obstacle
            if ped_num not in self.skip_list:
                obstacle = self.pedestrian_dict[ped_key][frame_key]
                obstacle['id'] = ped_key
                obstacle_list.append(obstacle)

            # the agent is taken from the data only while no external
            # controller (e.g. a training loop) drives it
            if not self.external_control and ped_num == self.cur_ped:
                self.agent_state = self.pedestrian_dict[ped_key][frame_key]
                self.state['agent_state'] = utils.copy_dict(self.agent_state)
                orientation = self.state['agent_state']['orientation']
                if orientation is None:
                    self.cur_heading_dir = 0
                else:
                    ref_vector = np.asarray([-1, 0])
                    angle = rad_to_deg(total_angle_between(orientation, ref_vector))
                    self.cur_heading_dir = (360 + angle) % 360

            # record the ghost pedestrian's state for this frame
            if ped_num == self.ghost:
                self.ghost_state = self.pedestrian_dict[ped_key][frame_key]
                snapshot = copy.deepcopy(self.ghost_state)
                self.ghost_state_history.append((snapshot, self.current_frame))

        self.state['obstacles'] = self.obstacles
예제 #8
0
 def _init( self, *args, **kwargs ):
     """Create the SpiderDeleteBackbone sub-tool (run=False) for this map,
     model and pixel size."""
     backbone_kwargs = copy_dict( kwargs, run=False )
     self.delete_backbone = SpiderDeleteBackbone(
         self.map_file, self.pdb_file, self.pixelsize, **backbone_kwargs
     )
예제 #9
0
 def func(self):
     """ Initial procedure by Dominic Theune.

     Run Dowser repeatedly. Before each run, the waters collected so far are
     written into that run's ``wat_file``/``watall_file``. After the run, both
     files are read back and the residue numbers of their HETATM records are
     shifted by the highest residue number seen in the previous run's watall
     file.

     If both files have more than one line (and ``max_repeats``, when set,
     is not reached), the new lines are accumulated and another run starts.
     Otherwise the loop stops. Finally the accumulated lines are written to
     ``self.wat_file``/``self.watall_file``, and the last run's
     ``intsurf_file`` is copied to ``self.intsurf_file``.
     """
     def renumber(lines, offset):
         # Add ``offset`` to the residue number (columns 23-26) of every
         # HETATM record; all other lines pass through unchanged.
         # NOTE(review): numbers above 9999 no longer fit the 4-char field.
         shifted = []
         for line in lines:
             if line[0:6] == "HETATM":
                 resno = int(line[22:26]) + offset
                 line = line[0:22] + "{:>4}".format(resno) + line[26:]
             shifted.append(line)
         return shifted

     rep_count = 0
     dowserwat = []
     dowserwat_all = []
     max_resno = 0
     while True:
         alt = "repeat" if rep_count else self.alt
         rep_out = os.path.join(self.repeat_dir, str(rep_count))
         dowser = Dowser(
             self.pdb_file,
             **copy_dict(self.kwargs,
                         run=False,
                         output_dir=rep_out,
                         alt=alt))
         # write dowser waters found until now
         with open(dowser.wat_file, "w") as fp:
             fp.writelines(dowserwat)
         with open(dowser.watall_file, "w") as fp:
             fp.writelines(dowserwat_all)
         # exec dowser
         dowser()
         # read newly found dowser waters, renumbered past the previous run
         with open(dowser.wat_file, "r") as fp:
             new_wat2 = renumber(fp.readlines(), max_resno)
         with open(dowser.watall_file, "r") as fp:
             new_watall2 = renumber(fp.readlines(), max_resno)
         # Only HETATM records carry a residue number here; parsing the other
         # lines (which renumber() deliberately passes through) would make
         # int() raise.
         watall_resnos = [int(l[22:26]) for l in new_watall2
                          if l[0:6] == "HETATM"]
         if watall_resnos:
             max_resno = max(watall_resnos)
         # check if there are new waters
         if (len(new_wat2) > 1 and len(new_watall2) > 1 and
             (not self.max_repeats or rep_count < self.max_repeats - 1)):
             dowserwat += new_wat2
             dowserwat_all += new_watall2
             rep_count += 1
         else:
             break
     # write every dowser water found
     with open(self.wat_file, "w") as fp:
         fp.writelines(dowserwat)
     with open(self.watall_file, "w") as fp:
         fp.writelines(dowserwat_all)
     # copy intsurf of the last run
     shutil.copy(dowser.intsurf_file, self.intsurf_file)
예제 #10
0
 def _init(self, *args, **kwargs):
     """Create a Mapman sub-tool for the brix input (newformat="map",
     run=False) and register its outputs and the tool itself."""
     mapman_kwargs = copy_dict(kwargs,
                               run=False,
                               newformat="map",
                               output_dir=self.outpath("mapman"))
     mapman = Mapman(self.brix_input, "brix", **mapman_kwargs)
     self.mapman = mapman
     self.output_files += mapman.output_files
     self.sub_tool_list.append(mapman)
예제 #11
0
    def test_make_dict_copy(self):
        """copy_dict must return an independent copy with the same content:
        modifying the copy must leave the original untouched."""
        original_dict = {
            'name': 'Topicos Avançados em Engenharia de Software'
        }
        original_dict_copy = copy_dict(original_dict)

        # the copy starts out equal, but as a distinct object
        self.assertEqual(original_dict, original_dict_copy)
        self.assertIsNot(original_dict, original_dict_copy)

        original_dict_copy['description'] = 'Ter - Qui'

        # changing the copy diverges it without touching the original
        self.assertNotEqual(
            original_dict,
            original_dict_copy
        )
        self.assertNotIn('description', original_dict)
        self.assertEqual(
            original_dict_copy['name'],
            'Topicos Avançados em Engenharia de Software'
        )
예제 #12
0
 def _init( self, *args, **kwargs ):
     """Resolve ``pdb_input`` to a PDB file.

     A 4-character input is handed to a PdbDownload sub-tool (run=False),
     whose first listed file becomes ``pdb_file``. Any other input is
     treated as a path and made absolute.
     """
     if len(self.pdb_input) != 4:
         self.pdb_download = None
         self.pdb_file = self.abspath( self.pdb_input )
         return

     download_kwargs = copy_dict(
         kwargs, run=False, output_dir=self.subdir("download")
     )
     self.pdb_download = PdbDownload( self.pdb_input, **download_kwargs )
     self.pdb_file = self.pdb_download.pdb_file_list[0]
예제 #13
0
 def _init( self, *args, **kwargs ):
     """Chain the SpiderConvert -> SpiderCropMap -> SpiderReConvert sub-tools
     (all run=False), each writing into its own subdirectory."""
     def sub_kwargs(subdir_name):
         # kwargs for one sub-tool: deferred run, private output directory
         return copy_dict(kwargs, run=False, output_dir=self.subdir(subdir_name))

     self.spider_convert = SpiderConvert(self.map_file, **sub_kwargs("convert"))
     self.spider_crop = SpiderCropMap(
         self.spider_convert.map_file,
         self.pdb_file,
         self.pixelsize,
         **sub_kwargs("crop")
     )
     self.spider_reconvert = SpiderReConvert(
         self.spider_crop.box_file,
         self.spider_convert.map_file,
         self.spider_crop.box_map_file,
         self.spider_convert.map_file,
         **sub_kwargs("mrc")
     )
예제 #14
0
    def __init__(
        self,
        display=False,
        is_onehot=False,
        seed=7,
        obstacles=None,
        show_trail=False,
        is_random=False,
        annotation_file="./expert_datasets/university_students/"+
                        "annotation/processed/frame_skip_1/"+
                        "students003_processed_corrected.txt",
        subject=None,
        tick_speed=60,
        obs_width=10,
        step_size=2,
        agent_width=10,
        replace_subject=True,
        segment_size=None,
        external_control=True,
        step_reward=0.001,
        show_comparison=True,
        consider_heading=True,
        show_orientation=True,
        rows=576,
        cols=720,
        width=10,
        ):
        """Forward all constructor arguments to the parent class, resolving
        ``annotation_file`` relative to this script's directory.

        A falsy ``annotation_file`` (e.g. None) is passed through unchanged.
        """
        # Must be the first statement so locals() holds only the parameters.
        args = utils.copy_dict(locals())

        # Artefacts of calling locals() in a method: 'self' is the instance,
        # '__class__' is the implicit cell created by the zero-argument
        # super() call below.
        args.pop('self', None)
        args.pop('__class__', None)

        # construct path from current script directory; os.path.join would
        # raise a TypeError on None, so only resolve real paths
        if annotation_file:
            args['annotation_file'] = os.path.join(
                os.path.dirname(__file__),
                annotation_file
                )

        super().__init__(**args)
예제 #15
0
파일: smc.py 프로젝트: zcmail/hdhp.py
def _normalized_particle_weights(particles):
    """Return the particles' importance weights, normalised to sum to 1.

    Log-weights are shifted by their maximum before exponentiating, so the
    largest weight is exactly 1 and exp() cannot overflow. No small-weight
    cutoff is applied (an earlier epsilon test assigned the same weight on
    both of its branches, so it never had an effect).
    """
    max_logweight = max(p.logweight for p in particles)
    weights = [exp(p.logweight - max_logweight) for p in particles]
    total = sum(weights)
    return [w / total for w in weights]


def _merge_user_dishes(target, additions):
    """Copy every ``additions[user][t]`` entry into ``target`` in place.

    Each (user, t) pair must not already be present in ``target``.
    """
    for user in additions:
        user_entries = target.setdefault(user, {})
        for t in additions[user]:
            assert t not in user_entries
            user_entries[t] = additions[user][t]


def _infer_single_thread(history, params):
    """Run sequential Monte Carlo inference over ``history`` on one thread.

    ``params.num_particles`` particles are fitted to the events one at a
    time. Every ``params.resample_every`` events (after the first), the
    particle weights are normalised. When more than one particle is used
    and the sum of squared normalised weights exceeds
    ``params.particle_weight_threshold / params.num_particles``, the
    particles are resampled (never on the last event). Finally, one particle
    is drawn according to its weight, and its per-user histories are filled
    in.

    :param history: sequence of ``(t_i, d_i, u_i, q_i)`` events (time,
        document, user, question).
    :param params: settings object (seed, num_particles, resample_every,
        particle_weight_threshold, progress_file, ...).
    :return: ``(final_particle, square_norms)``; ``square_norms`` holds the
        squared norm of the normalised weights at every resampling check.
    """
    prng = RandomState(seed=params.seed)
    time_history_per_user = defaultdict(list)
    doc_history_per_user = defaultdict(list)
    question_history_per_user = defaultdict(list)
    # History already flushed out of the particles at earlier resampling
    # steps, one entry per (resampled) particle. Empty until the first resample.
    table_history_with_user = []
    dish_on_table_per_user = []
    square_norms = []
    count_resamples = 0

    with open(params.progress_file, 'a') as out:
        out.write('Starting %d particles on %d thread.\n' % (params.num_particles,
                                                             params.threads))

    start_tic = time()

    # Initialize the particles
    particles = [Particle(theta_0=params.theta_0, alpha_0=params.alpha_0,
                          mu_0=params.mu_0,
                          uid=prng.randint(maxint), seed=prng.randint(maxint),
                          vocabulary_length=len(params.vocabulary),
                          update_kernels=params.update_kernels,
                          omega=params.omega, beta=params.beta,
                          num_users=len(params.users),
                          keep_alpha_history=params.keep_alpha_history,
                          mu_rate=params.mu_rate)
                 for _ in range(params.num_particles)]

    # for each particle, save the topic history
    inferred_tables = {p.uid: [] for p in particles}

    # Fit each particle to the history
    for i, h_i in enumerate(history):
        t_i, d_i, u_i, q_i = h_i
        time_history_per_user[u_i].append(t_i)
        doc_history_per_user[u_i].append(d_i)
        question_history_per_user[u_i].append(q_i)

        for p_i in particles:
            # Fit each particle to the next event
            b_i, z_i = p_i.update(h_i)
            inferred_tables[p_i.uid].append((b_i, z_i))

        if i == 0 or i % params.resample_every != 0:
            continue

        normalized = _normalized_particle_weights(particles)
        # Check if resampling is needed
        norm2 = sum(w ** 2 for w in normalized)
        square_norms.append(norm2)
        needs_resampling = (
            params.num_particles > 1
            and norm2 > params.particle_weight_threshold / params.num_particles
            and i < len(history) - 1)  # never resample on the last event
        if not needs_resampling:
            continue

        count_resamples += 1
        new_particle_indices = pick_new_particles(particles, normalized, prng)
        new_table_history_with_user = []
        new_dish_on_table_per_user = []
        for index in new_particle_indices:
            # flushed table history of the chosen particle + what it has
            # accumulated since the previous resampling
            if table_history_with_user:
                old_history = copy(table_history_with_user[index])
            else:
                old_history = []
            old_history.extend(copy(particles[index].table_history_with_user))
            new_table_history_with_user.append(old_history)

            if dish_on_table_per_user:
                dish_table_user = copy_dict(dish_on_table_per_user[index])
            else:
                dish_table_user = {}
            _merge_user_dishes(
                dish_table_user,
                copy_dict(particles[index].dish_on_table_todelete))
            new_dish_on_table_per_user.append(dish_table_user)

        # delete the flushed history from the surviving particles
        for index in new_particle_indices:
            particles[index].table_history_with_user = []
            for user in particles[index].dish_on_table_todelete:
                particles[index].dish_on_table_todelete[user] = {}

        new_particles = []
        for index in new_particle_indices:
            new_particle = particles[index].copy()
            new_particle.reseed(prng.randint(maxint))
            new_particle.reset_weight()
            new_particles.append(new_particle)
            inferred_tables[new_particle.uid] = \
                copy(inferred_tables[particles[index].uid])
        particles = new_particles
        table_history_with_user = new_table_history_with_user
        dish_on_table_per_user = new_dish_on_table_per_user

        # If inferred tables dictionary grows too big, prune it
        if len(inferred_tables) > 50 * params.num_particles:
            inferred_tables = {p.uid: copy(inferred_tables[p.uid])
                               for p in particles}
        with open(params.progress_file, mode='a') as temp:
            temp.write("Time: %.2f (%d)\n" % (time() - start_tic, i))

    # Finally sample a single particle according to its weight. Weights are
    # recomputed from scratch: the loop's list would be doubled if the last
    # event was a resampling check, and would not exist for an empty history.
    normalized = _normalized_particle_weights(particles)
    final_particle_id = pick_new_particles(particles, normalized, prng)[0]
    final_particle = particles[final_particle_id]

    # Without any resampling nothing was flushed, so start from empty
    # containers instead of indexing an empty list.
    if table_history_with_user:
        table_history = table_history_with_user[final_particle_id]
    else:
        table_history = []
    table_history.extend(copy(final_particle.table_history_with_user))
    final_particle.table_history_with_user = table_history

    if dish_on_table_per_user:
        final_dishes = dish_on_table_per_user[final_particle_id]
    else:
        final_dishes = {}
    _merge_user_dishes(final_dishes,
                       copy_dict(final_particle.dish_on_table_per_user))
    _merge_user_dishes(final_dishes, final_particle.dish_on_table_todelete)
    final_particle.dish_on_table_per_user = final_dishes

    final_particle.time_history_per_user = copy(time_history_per_user)
    final_particle.doc_history_per_user = copy(doc_history_per_user)
    final_particle.question_history_per_user = copy(question_history_per_user)
    final_particle.table_history_per_user = {}
    for (u_i, table) in final_particle.table_history_with_user:
        final_particle.table_history_per_user.setdefault(u_i, []).append(table)
    final_particle.vocabulary = params.vocabulary
    with open(params.progress_file, mode='a') as temp:
        temp.write("Resampled %d times\n" % (count_resamples))
        temp.write("Finished in time: %.2f\n" %
                   (time() - start_tic))
    return final_particle, square_norms
예제 #16
0
 def _init( self, *args, **kwargs ):
     """Assemble the Spider pipeline sub-tools (all run=False), each writing
     into its own subdirectory, and register the outputs of the convert,
     delete-filled-densities, box, reconvert and crosscorrelation steps."""
     def sub_kwargs(subdir_name, **extra):
         # kwargs for one sub-tool: deferred run, private output directory
         return copy_dict(
             kwargs, run=False, output_dir=self.subdir(subdir_name), **extra
         )

     shift = SpiderShift(self.mrc_file, self.pdb_file, **sub_kwargs("shift"))
     self.spider_shift = shift

     convert = SpiderConvert(shift.map_shift, **sub_kwargs("convert"))
     self.spider_convert = convert

     box = SpiderBox(
         shift.map_shift,
         convert.map_file,
         shift.edited_pdb_file, self.res1, self.res2,
         self.length, self.resolution,
         **sub_kwargs("box")
     )
     self.spider_box = box

     pdb_box = SpiderPdbBox(
         shift.edited_pdb_file,
         box.box_file,
         shift.map_shift,
         **sub_kwargs("pdbbox")
     )
     self.pdb_box = pdb_box

     delete_filled = SpiderDeleteFilledDensities(
         shift.map_shift,
         box.box_map_file,
         pdb_box.edited_pdb_file,
         box.box_file,
         self.resolution, self.res1, self.res2,
         **sub_kwargs("delete_filled_densities")
     )
     self.spider_delete_filled_densities = delete_filled

     reconvert = SpiderReConvert(
         box.box_file,
         convert.map_file,
         box.box_map_file,
         convert.map_file,
         **sub_kwargs("reconvert")
     )
     self.spider_reconvert = reconvert

     crosscorrelation = SpiderCrosscorrelation(
         convert.map_file,
         delete_filled.empty_map_file,
         box.box_file,
         self.loop_file,
         self.linkerinfo,
         **sub_kwargs("crosscorrelation", max_loops=self.max_loops)
     )
     self.spider_crosscorrelation = crosscorrelation

     # the shift and pdbbox outputs are not registered
     registered = (convert, delete_filled, box, reconvert, crosscorrelation)
     self.output_files.extend(
         [path for tool in registered for path in tool.output_files]
     )
예제 #17
0
    def reset_and_replace(self, ped=None):
        '''
        Resets the environment and replaces one of the existing pedestrians
        from the video feed in the environment with the agent.

        The replaced pedestrian is chosen as follows:
        - ``self.subject``, if set;
        - otherwise the first existing id >= ``ped``, if ``ped`` is given;
        - otherwise a random id (``self.is_random``), or the id after the
          previous one, wrapping to 1 after ``self.last_pedestrian``.

        Pro tip: Use this for testing the result.

        :param ped: optional pedestrian id to start the search from.
        :return: the freshly built ``self.state`` dict.
        :raises ValueError: if ``ped`` is larger than every numeric
            pedestrian id (this used to loop forever).
        '''
        if self.subject is not None:
            self.cur_ped = self.subject
        elif ped is not None:
            numeric_ids = [int(k) for k in self.pedestrian_dict if str(k).isdigit()]
            largest_id = max(numeric_ids) if numeric_ids else None
            requested = ped
            while str(ped) not in self.pedestrian_dict:
                if largest_id is None or ped > largest_id:
                    raise ValueError('No pedestrian with id >= {} in the dataset'
                                     .format(requested))
                ped += 1
            self.cur_ped = ped
        else:
            no_of_peds = len(self.pedestrian_dict)
            while True:
                if self.is_random:
                    self.cur_ped = np.random.randint(1, no_of_peds + 1)
                elif self.cur_ped is None or self.cur_ped == self.last_pedestrian:
                    self.cur_ped = 1
                else:
                    self.cur_ped += 1
                if str(self.cur_ped) in self.pedestrian_dict:
                    break

        if self.show_comparison:
            self.ghost = self.cur_ped

        self.skip_list = [self.cur_ped]

        ped_data = self.pedestrian_dict[str(self.cur_ped)]
        first_frame = int(ped_data['initial_frame'])
        last_frame = int(ped_data['final_frame'])
        if self.segment_size is None:
            # play the pedestrian's whole appearance
            self.current_frame = first_frame
            self.final_frame = last_frame
        else:
            # pick a random segment of segment_size frames within it
            total_frames = last_frame - first_frame
            total_segments = int(total_frames / self.segment_size) + 1
            cur_segment = np.random.randint(total_segments)
            self.current_frame = first_frame + cur_segment * self.segment_size
            self.final_frame = min(self.current_frame + self.segment_size, last_frame)
        self.goal_state = ped_data[str(self.final_frame)]['position']

        self.get_state_from_frame_universal(self.annotation_dict[str(self.current_frame)])

        self.agent_state = utils.copy_dict(
            self.pedestrian_dict[str(self.cur_ped)][str(self.current_frame)])
        # the starting state of a pedestrian in the dict has no orientation or
        # speed: start at rest, oriented along [-1, 0]
        self.agent_state['speed'] = 0
        self.cur_heading_dir = 0
        self.agent_state['orientation'] = np.matmul(get_rot_matrix(deg_to_rad(self.cur_heading_dir)),
                                                    np.array([-1, 0]))

        self.release_control = False

        self.state = {}
        self.state['agent_state'] = utils.copy_dict(self.agent_state)
        self.state['agent_head_dir'] = self.cur_heading_dir  # starts heading towards top
        self.state['goal_state'] = self.goal_state
        self.state['release_control'] = self.release_control
        self.state['obstacles'] = self.obstacles

        self.pos_history = [(utils.copy_dict(self.agent_state), self.current_frame)]
        if self.ghost:
            # the ghost starts from the agent's initial state
            self.ghost_state = utils.copy_dict(self.agent_state)
            self.ghost_state_history = [(utils.copy_dict(self.ghost_state), self.current_frame)]

        self.state['ghost_state'] = utils.copy_dict(self.ghost_state)
        # np.linalg.norm(..., 1): L1 distance between agent and goal positions
        self.distanceFromgoal = np.linalg.norm(self.agent_state['position'] - self.goal_state, 1)
        self.heading_dir_history = [self.cur_heading_dir]

        if self.display:
            pygame.display.set_caption('Your not so friendly continuous environment')
            self.render()

        return self.state
예제 #18
0
    def reset(self, ped=None):
        '''
        Resets the environment, starting the obstacles from the start.
        If subject is specified, then the initial frame and final frame is set
        to the time frame when the subject is in the scene.

        If no subject is specified then the initial frame is set to the overall
        initial frame and goes till the last frame available in the annotation file.

        Also, the agent and goal positions are initialized at random.

        Pro tip: Use this function while training the agent.

        :param ped: forwarded to ``self.reset_and_replace`` when
            ``self.replace_subject`` is set; ignored otherwise.
        :return: the freshly built ``self.state`` dict (keys ``agent_state``,
            ``agent_head_dir``, ``goal_state``, ``release_control``,
            ``obstacles`` and ``ghost_state``).
        '''
        if self.replace_subject:
            return self.reset_and_replace(ped)

        self.current_frame = self.initial_frame
        self.pos_history = []
        self.ghost_state_history = []
        # minimum distance the goal must keep from every obstacle
        dist_g = self.goal_spawn_clearance

        if self.annotation_file:
            self.get_state_from_frame_universal(self.annotation_dict[str(self.current_frame)])

        num_obs = len(self.obstacles)

        # The goal and the agent are only placed at random when no subject
        # (pedestrian) is specified.
        if self.cur_ped is None:

            # placing the goal: rejection-sample until it clears every obstacle
            while True:
                self.goal_state = np.asarray([np.random.randint(self.lower_limit_goal[0], self.upper_limit_goal[0]),
                                              np.random.randint(self.lower_limit_goal[1], self.upper_limit_goal[1])])
                if not any(np.linalg.norm(self.obstacles[i]['position'] - self.goal_state) < dist_g
                           for i in range(num_obs)):
                    break

            # placing the agent: same rejection sampling with the agent clearance
            dist = self.agent_spawn_clearance
            while True:
                self.agent_state['position'] = np.asarray([np.random.randint(self.lower_limit_agent[0], self.upper_limit_agent[0]),
                                                           np.random.randint(self.lower_limit_agent[1], self.upper_limit_agent[1])])
                if not any(np.linalg.norm(self.obstacles[i]['position'] - self.agent_state['position']) < dist
                           for i in range(num_obs)):
                    break

            # add speed and orientation to the agent state after it is placed successfully
            self.agent_state['speed'] = 0  # dead stop
            self.cur_heading_dir = 0  # pointing upwards
            self.agent_state['orientation'] = np.matmul(get_rot_matrix(deg_to_rad(self.cur_heading_dir)),
                                                        np.array([self.agent_state['speed'], 0]))

        self.release_control = False

        self.state = {}
        self.state['agent_state'] = utils.copy_dict(self.agent_state)
        self.state['agent_head_dir'] = self.cur_heading_dir  # starts heading towards top
        self.state['goal_state'] = self.goal_state

        self.state['release_control'] = self.release_control
        self.state['obstacles'] = self.obstacles

        self.pos_history.append((utils.copy_dict(self.agent_state), self.current_frame))
        if self.ghost:
            self.ghost_state_history.append((utils.copy_dict(self.ghost_state), self.current_frame))

        self.state['ghost_state'] = utils.copy_dict(self.ghost_state)
        self.distanceFromgoal = np.linalg.norm(self.agent_state['position'] - self.goal_state, 1)
        # NOTE(review): the heading is reset to 0 here even when cur_ped is set,
        # after state['agent_head_dir'] was already recorded above - confirm intended.
        self.cur_heading_dir = 0
        self.heading_dir_history = [self.cur_heading_dir]

        # Only touch the pygame window when rendering is enabled (consistent with
        # the other reset path, which also guards set_caption behind self.display).
        if self.display:
            pygame.display.set_caption('Your not so friendly continuous environment')
            self.render()

        return self.state
예제 #19
0
    def step(self, action=None):
        '''
        Advance the environment by one frame.

        The annotation data of the new frame (if there is any) is always loaded.
        If ``self.external_control`` is set and control has not been released,
        the agent state is additionally updated from ``action``.

        The reward and the done flag are then computed as usual.

        :param action: ``None`` (agent left untouched), a discrete action index
            that is decoded into an orientation change and a speed change, or -
            when ``self.continuous_action`` is set - a
            ``(speed_change, orient_change)`` pair.
        :return: ``(state, reward, done, None)``.
        '''
        self.current_frame += 1

        if str(self.current_frame) in self.annotation_dict:
            self.get_state_from_frame_universal(self.annotation_dict[str(self.current_frame)])

        if self.external_control:

            if not self.release_control and action is not None:

                if not self.continuous_action:
                    # The flat index enumerates (speed, orientation) pairs:
                    # action = speed_idx * len(orientation_array) + orient_idx
                    action_orient = int(action % len(self.orientation_array))
                    action_speed = int(action / len(self.orientation_array))
                    orient_change = self.orientation_array[action_orient]
                    speed_change = self.speed_array[action_speed]
                else:
                    speed_change = action[0]
                    orient_change = action[1]

                # heading wraps around after 360 degrees
                self.cur_heading_dir = (self.cur_heading_dir + orient_change) % 360
                # speed is clamped to [0, self.max_speed]
                agent_cur_speed = max(0, min(self.agent_state['speed'] + speed_change, self.max_speed))

                rot_mat = get_rot_matrix(deg_to_rad(-self.cur_heading_dir))
                # cur_displacement is a 2 dim vector in the form [row, col]
                cur_displacement = np.matmul(rot_mat, np.array([-agent_cur_speed, 0]))
                # new position is clipped component-wise to
                # [lower_limit_agent, upper_limit_agent]
                self.agent_state['position'] = np.maximum(
                    np.minimum(self.agent_state['position'] + cur_displacement, self.upper_limit_agent),
                    self.lower_limit_agent)

                self.agent_state['speed'] = agent_cur_speed
                self.agent_state['orientation'] = np.matmul(rot_mat, np.array([-1, 0]))

            self.heading_dir_history.append(self.cur_heading_dir)

            self.pos_history.append((utils.copy_dict(self.agent_state), self.current_frame))

            if self.ghost:
                self.ghost_state_history.append((utils.copy_dict(self.ghost_state), self.current_frame))

        # calculate the reward and completion condition
        reward, done = self.calculate_reward(action)
        self.prev_action = action

        # if you are done ie hit an obstacle or the goal
        # you leave control of the agent and you are forced to
        # suffer/enjoy the consequences of your actions for the
        # rest of your miserable/awesome life

        if self.display:
            self.render()

        # only the agent-related entries of the state change per step;
        # they are frozen once control has been released
        if not self.release_control:
            self.state['agent_state'] = utils.copy_dict(self.agent_state)
            self.state['agent_head_dir'] = self.cur_heading_dir
            self.state['ghost_state'] = utils.copy_dict(self.ghost_state)
        if self.external_control and done:
            self.release_control = True

        # NOTE: the fourth element ('info') is returned as None
        return self.state, reward, done, None
예제 #20
0
def remove_fields(_td, _field, exact_field=False, inplace=True):
    '''
    This function removes fields from a dict.

    Parameters
    ----------
    _td : dict / list of dict
        dict of the trial data.
    _field : str / list of str
        Fields to remove.
    exact_field : bool, optional
        If True, remove only the keys equal to a name in ``_field``; otherwise
        remove every key that contains a name in ``_field`` as a substring.
        The default is False.
    inplace : bool, optional
        Perform operation on the input data dict. The default is True.

    Returns
    -------
    td : dict/list of dict
        trial data dict with the fields removed. Only returned when ``inplace``
        is False; otherwise the input is modified and None is returned.

    Raises
    ------
    Exception
        If ``_td`` is not a dict / list, or ``_field`` is not a str / list.

    '''

    if inplace:
        td = _td
    else:
        td = copy_dict(_td)

    # check dict input variable
    input_dict = isinstance(td, dict)
    if input_dict:
        td = [td]

    if not isinstance(td, list):
        raise Exception('ERROR: _td must be a list of dictionaries!')

    # check string input variable
    if isinstance(_field, str):
        _field = [_field]

    if not isinstance(_field, list):
        raise Exception('ERROR: _field must be a list of strings!')

    for td_tmp in td:
        # Snapshot the keys: td_tmp shrinks while we delete from it.
        keys = list(td_tmp.keys())
        for iStr in _field:
            if exact_field:
                matches = [key for key in keys if key == iStr]
            else:
                matches = [key for key in keys if iStr in key]
            for key in matches:
                # A key can match several requested names (e.g. 'a' and 'ab'
                # both match 'ab', or a name is listed twice); only the first
                # match actually deletes it.
                td_tmp.pop(key, None)
            if not matches:
                print('Field {} not found. I could not be removed...'.format(
                    iStr))

    if input_dict:
        td = td[0]

    if not inplace:
        return td
예제 #21
0
def add_params(_td, params, **kwargs):
    '''
    This function adds parameters to the trial data dictionary

    Parameters
    ----------
    _td : dict / list of dict
        dict of trial data.
    params : dict
        Params of the data. See below for examples
    data_struct : str, optional
        Type of _td struct in input. The default value is 'flat'.
        Options:
            flat (all data available on the first layer of the dict)
            layer (data separated among inside layers of the dict)
    inplace : bool, optional
        Perform operation on the input data dict. The default is True.

    Returns
    -------
    td : dict / list of dict.
        dict of trial data. Only returned when ``inplace`` is False.

    Raises
    ------
    Exception
        If ``data_struct`` is neither 'flat' nor 'layer', or if ``params``
        contains an unsupported key.

    Notes
    -----
    For each trial, every field except 'params' and the fields listed under
    'signals' / 'time' (plus a generated '<name>_time' field) is removed.
    If a trial already has a 'params' field, all of its fields are kept and its
    existing 'params' entry is processed instead of the ``params`` argument.

    Keys inside each 'data' entry must be 'signals', 'time' or 'fs' (lower case);
    keys inside each 'event' entry must be 'signals' or 'kind'.

    Examples for params in input:
    params = {'folder' : 'FOLDER_FIELD',
              'file' : 'FILE_FIELD',
              'data':{'EMG': {'signals':['SIGNAL_1','...','SIGNAL_N'], 'fs': 'FS_FIELD', 'time': 'TIME_FIELD'},
                      'LFP': {'signals':['SIGNAL_1','...','SIGNAL_N'], 'fs': 'FS_FIELD', 'time': 'TIME_FIELD'},
                      'KIN': {'signals':['SIGNAL_1','...','SIGNAL_N'], 'fs': 'FS_FIELD', 'time': 'TIME_FIELD'}}}

    params = {'folder' : 'FOLDER_FIELD',
              'file' : 'FILE_FIELD',
              'data': {'Data':{'signals':['SIGNAL_1','...','SIGNAL_N'], 'fs': 'FS_FIELD', 'time': 'TIME_FIELD'}},
              'event': {'EVENT': {'signals': ['EVENT_FIELD'], 'kind': 'KIND_VALUE'}}}
    '''

    data_struct = 'flat'
    inplace = True

    # Check input variables
    for key, value in kwargs.items():
        key = key.lower()
        if key == 'data_struct':
            data_struct = value
        elif key == 'inplace':
            inplace = value

    if data_struct not in ['flat', 'layer']:
        raise Exception(
            'ERROR: data_struct nor "flat", or "layer". It is: {}'.format(
                data_struct))

    if inplace:
        td = _td
    else:
        td = copy_dict(_td)

    input_dict = False
    if isinstance(td, dict):
        input_dict = True
        td = [td]

    # Loop over the trials
    for td_tmp in td:
        # Params initiation
        if 'params' not in td_tmp.keys():
            signals_2_use = ['params']
            td_tmp['params'] = dict()
            td_tmp['params']['data'] = dict()
            td_tmp['params']['event'] = dict()
        else:
            # Must be a list: field names are appended / extended below.
            signals_2_use = list(td_tmp.keys())
            # NOTE(review): this rebinds the ``params`` argument, so the
            # following trials also use this trial's params - confirm intended.
            params = td_tmp['params']

        params_c = copy.deepcopy(params)
        # Check input variables
        for key, val in params_c.items():
            if key in ['folder', 'file']:
                td_tmp['params'][key] = td_tmp[val]

            elif key in ['data']:
                for ke, va in val.items():
                    td_tmp['params']['data'][ke] = dict()
                    for k, v in va.items():
                        if k in ['signals', 'time']:
                            td_tmp['params']['data'][ke][k] = v
                            if type(v) is list:
                                signals_2_use.extend(v)
                                if data_struct == 'layer':  # place data on the main layer
                                    for el in v:
                                        td_tmp[el] = td_tmp[ke][el]
                            elif type(v) is str:
                                signals_2_use.append(v)
                                if data_struct == 'layer':  # place data on the main layer
                                    td_tmp[v] = td_tmp[ke][v]
                            else:
                                raise Exception(
                                    'ERROR: Value "{}" in params is nor str or list! It is {}...'
                                    .format(k, v))
                        elif k in ['fs']:
                            # v names the field holding the sampling frequency;
                            # a size-1 array is unwrapped to its single element.
                            if data_struct == 'layer':
                                val2take = td_tmp[ke][v]
                            else:  # 'flat' (validated above)
                                val2take = td_tmp[v]
                            if type(val2take) is np.ndarray and val2take.size == 1:
                                val2take = val2take[0]
                            td_tmp['params']['data'][ke][k] = val2take
                        else:
                            raise Exception(
                                'ERROR: Value in params dict "{}" is not "signal", "time", or "fs"! It is {}...'
                                .format(k, v))

                    # Check existance of 'fs' and 'time' data info
                    keys_in_dict = set(td_tmp['params']['data'][ke].keys())
                    if 'fs' not in keys_in_dict and 'time' not in keys_in_dict:
                        print(
                            'WARNING: neither time or fs are available for "{}"'
                            .format(ke))
                    elif 'fs' not in keys_in_dict and 'time' in keys_in_dict:
                        # derive fs from the first two time samples
                        td_tmp['params']['data'][ke]['fs'] = 1 / np.diff(
                            td_tmp[td_tmp['params']['data'][ke]
                                   ['time']][:2])[0]
                    elif 'fs' in keys_in_dict and 'time' not in keys_in_dict:
                        # build a '<ke>_time' axis from fs and the signal length
                        fs = td_tmp['params']['data'][ke]['fs']
                        sign_len = np.max(td_tmp[td_tmp['params']['data'][ke]
                                                 ['signals'][0]].shape)
                        td_tmp['params']['data'][ke]['time'] = ke + '_time'
                        td_tmp[ke + '_time'] = np.linspace(
                            0, sign_len / fs, sign_len)
                        signals_2_use.append(ke + '_time')

            elif key in ['event']:
                for ke, va in val.items():
                    td_tmp['params']['event'][ke] = dict()
                    for k, v in va.items():
                        if k in ['signals']:
                            td_tmp['params']['event'][ke][k] = v
                            if type(v) is list:
                                signals_2_use.extend(v)
                                if data_struct == 'layer':  # place data on the main layer
                                    for el in v:
                                        td_tmp[el] = td_tmp[ke][el]
                            elif type(v) is str:
                                signals_2_use.append(v)
                                if data_struct == 'layer':  # place data on the main layer
                                    td_tmp[v] = td_tmp[ke][v]
                            else:
                                raise Exception(
                                    'ERROR: Value "{}" in params is nor str or list! It is {}...'
                                    .format(k, v))
                        elif k in ['kind']:
                            td_tmp['params']['event'][ke][k] = v
                        else:
                            raise Exception(
                                'ERROR: Value in params dict "{}" is not "signals" or "kind"! It is {}...'
                                .format(k, v))

        # Prune each trial with its own list of used signals (previously only the
        # last trial's list was applied to all trials, and an empty input crashed).
        remove_all_fields_but(td_tmp,
                              flatten_list(set(signals_2_use)),
                              exact_field=True,
                              inplace=True)

    if input_dict:
        td = td[0]

    if not inplace:
        return td
예제 #22
0
파일: smc.py 프로젝트: zcmail/hdhp.py
    def copy(self):
        """Return a new Particle carrying a copy of this particle's state.

        Side effects on ``self``: for every user ``u``,
        ``self.dish_on_table_todelete[u]`` is rebuilt from scratch, and tables
        that are not in ``self.active_tables_per_user[u]`` are dropped from
        ``self.user_table_cache[u]``. Those inactive tables are recorded (with
        their dish) in ``dish_on_table_todelete`` of both ``self`` and the new
        particle, and are left out of the new particle's
        ``dish_on_table_per_user``.
        """
        new_p = Particle(num_users=self.num_users,
                         vocabulary_length=self.vocabulary_length,
                         seed=self.seed, mu_rate=self.mu_rate,
                         theta_0=self.theta_0,
                         omega=self.omega,
                         beta=self.beta,
                         mu_0=self.mu_0,
                         uid=self.uid,
                         logweight=self.logweight,
                         update_kernels=self.update_kernels,
                         keep_alpha_history=self.keep_alpha_history)
        # ``copy`` here is the module-level name (presumably copy.copy, i.e. a
        # shallow copy) - the method name itself is not in scope.
        new_p.alpha_0 = copy(self.alpha_0)
        new_p.num_events = self.num_events
        new_p.topic_previous_event = self.topic_previous_event
        new_p.total_tables = self.total_tables
        new_p._max_dish = self._max_dish

        new_p.time_previous_user_event = copy(self.time_previous_user_event)
        new_p.total_tables_per_user = copy(self.total_tables_per_user)
        new_p.first_observed_time = copy(self.first_observed_time)
        new_p.first_observed_user_time = copy(self.first_observed_user_time)
        new_p.table_history_with_user = copy(self.table_history_with_user)

        new_p.dish_cache = copy_dict(self.dish_cache)
        new_p.dish_counters = copy_dict(self.dish_counters)

        # Rebuild the per-user table->dish maps, keeping only active tables in
        # the copy and moving inactive ones to the "todelete" maps.
        new_p.dish_on_table_per_user = {}
        new_p.dish_on_table_todelete = {}
        for u in self.dish_on_table_per_user:
            new_p.dish_on_table_per_user[u] = {}
            new_p.dish_on_table_todelete[u] = {}
            self.dish_on_table_todelete[u] = {}

            for t in self.dish_on_table_per_user[u]:
                if t in self.active_tables_per_user[u]:
                    new_p.dish_on_table_per_user[u][t] = \
                        self.dish_on_table_per_user[u][t]
                else:
                    dish = self.dish_on_table_per_user[u][t]
                    self.dish_on_table_todelete[u][t] = dish
                    new_p.dish_on_table_todelete[u][t] = dish
                    if t in self.user_table_cache[u]:
                        del self.user_table_cache[u][t]

        new_p.per_topic_word_counts = copy_dict(self.per_topic_word_counts)
        new_p.per_topic_word_count_total = copy_dict(self.per_topic_word_count_total)
        new_p.time_kernels = copy_dict(self.time_kernels)
        new_p.time_kernel_prior = copy_dict(self.time_kernel_prior)
        # copied after the pruning above, so the copy sees the pruned cache too
        new_p.user_table_cache = copy_dict(self.user_table_cache)
        if self.keep_alpha_history:
            new_p.alpha_history = copy_dict(self.alpha_history)
            new_p.alpha_distribution_history = \
                copy_dict(self.alpha_distribution_history)
        new_p.mu_per_user = copy_dict(self.mu_per_user)
        new_p.active_tables_per_user = copy_dict(self.active_tables_per_user)
        return new_p
예제 #23
0
    def _init(self, *args, **kwargs):
        """Construct this pipeline's sub-tool objects and collect their output files.

        Nothing beyond ``_init_records`` / ``_init_parallel`` happens when
        ``self.parallel`` is truthy. Otherwise every sub-tool is created with
        ``run=False`` (plus a per-tool ``output_dir`` from ``self.subdir``), and
        their output / info files are appended to ``self.output_files`` in
        construction order.
        """
        self._init_records(None, **kwargs)
        self._init_parallel(self.pdb_input, **kwargs)

        if not self.parallel:
            self.pdb_id = self.pdb_input
            self.npdb_features = {"sstruc": False, "phi_psi": False}

            self.output_files = []

            self.pdb_info = PdbInfo(
                self.pdb_id,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("pdb_info")))
            self.output_files += [self.pdb_info.info_file]

            self.opm_info = OpmInfo(
                self.pdb_id,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("opm_info")))

            # self.opm.info_file can be used to distinguish
            if self.use_ppm2:
                self.opm = Ppm2(
                    self.pdb_id,
                    **copy_dict(kwargs,
                                run=False,
                                output_dir=self.subdir("opm")))
            else:
                self.opm = Opm(
                    self.pdb_id,
                    **copy_dict(kwargs,
                                run=False,
                                output_dir=self.subdir("opm")))
            self.output_files += self.opm.output_files
            self.output_files += [self.processed_pdb, self.no_water_file]
            self.dowser = DowserRepeat(
                self.processed_pdb,
                **copy_dict(kwargs,
                            run=False,
                            alt='x',
                            output_dir=self.subdir("dowser"),
                            max_repeats=self.dowser_max))
            self.output_files += self.dowser.output_files
            # Shared Msms settings; msms0 and the "msms_vdw" tool below override
            # probe_radius (msms0 also turns all_components off).
            msms_kwargs = {
                "all_components": True,
                "density": 1.0,
                "hdensity": 3.0,
                "envelope": self.probe_radius * 2,
                "envelope_hclust": self.envelope_hclust,
                "atom_radius_add": self.atom_radius_add,
            }
            self.dowserPlus2 = DowserPlus2(
                self.processed_pdb,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("dowser++")))
            self.output_files += self.dowserPlus2.output_files
            self.msms0 = Msms(
                self.no_water_file,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("msms0"),
                            **copy_dict(msms_kwargs,
                                        probe_radius=self.vdw_probe_radius,
                                        all_components=False)))
            self.output_files += self.msms0.output_files
            self.output_files += [self.original_dry_pdb, self.final_pdb]

            self.dssp = Dssp(
                self.final_pdb,
                **copy_dict(kwargs, run=False, output_dir=self.subdir("dssp")))
            self.output_files += self.dssp.output_files

            self.mpstruc_download = MpstrucDownload(
                self.pdb_id,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("mpstruc_download")))
            self.output_files += self.mpstruc_download.output_files

            self.mpstruc_info = MpstrucInfo(
                self.pdb_id,
                mpstruc_xml=self.mpstruc_download.mpstruc_xml,
                **copy_dict(kwargs,
                            run=False,
                            output_dir=self.subdir("mpstruc_info")))

            # (suffix, pdb file) pairs: each tool below is run on every variant.
            self.water_variants = [
                ("non", self.no_water_file),
                ("org", self.original_dry_pdb),
                ("dow", self.dowser_dry_pdb),
                ("fin", self.final_pdb),
            ]

            # (prefix, tool class, extra kwargs) triples.
            self.tool_list = [
                ("voronoia", Voronoia, {
                    "ex": 0.2,
                    "shuffle": self.voro_shuffle
                }),
                ("hbexplore", HBexplore, {}),
                ("msms_vdw", Msms,
                 copy_dict(msms_kwargs, probe_radius=self.vdw_probe_radius)),
                # ( "msms_coulomb", Msms, copy_dict( msms_kwargs,
                #     probe_radius=self.probe_radius) )
            ]

            # Select entries of lst_all by their first element (the name):
            # an empty selection keeps everything, a leading "!" makes the rest
            # of lst an exclusion list, otherwise lst is an inclusion list.
            def filt(lst, lst_all):
                if not lst:
                    return lst_all
                if lst[0] == "!":
                    return [e for e in lst_all if e[0] not in lst[1:]]
                else:
                    return [e for e in lst_all if e[0] in lst]

            self.water_variants = filt(self.variants, self.water_variants)
            self.do_tool_list = filt(self.tools, self.tool_list)

            # One tool instance per (water variant, tool) pair, stored as the
            # attribute "<tool prefix>_<variant suffix>".
            # NOTE(review): this iterates the unfiltered self.tool_list, not
            # self.do_tool_list computed just above - confirm whether the tool
            # filter is meant to apply here or only where the tools are run.
            for suffix, pdb_file in self.water_variants:
                for prefix, tool, tool_kwargs in self.tool_list:
                    name = "%s_%s" % (prefix, suffix)
                    self.__dict__[name] = tool(
                        pdb_file,
                        **copy_dict(kwargs,
                                    run=False,
                                    output_dir=self.subdir(name),
                                    **tool_kwargs))
                    self.output_files += self.__dict__[name].output_files

            self.output_files += [self.mpstruc_info.info_file]
            self.output_files += [self.opm_info.info_file]
            self.output_files += [self.info_file]
            self.output_files += [self.outpath("mppd.provi")]
            self.output_files += [self.stats_file]
예제 #24
0
def remove_all_fields_but(_td, _field, exact_field=False, inplace=True):
    '''
    This function removes all fields from a dict but the selected ones.

    Parameters
    ----------
    _td : dict / list of dict
        dict of the trial data.
    _field : str / list of str
        Field(s) to keep. An empty list removes every field.
    exact_field : bool, optional
        If True, keep only the keys equal to a name in ``_field``; otherwise
        keep every key that contains a name in ``_field`` as a substring.
        The default is False.
    inplace : bool, optional
        Perform operation on the input data dict. The default is True.

    Returns
    -------
    td : dict/list of dict
        trial data dict with only the selected fields. Only returned when
        ``inplace`` is False; otherwise the input is modified and None is
        returned.

    Raises
    ------
    Exception
        If ``_td`` is not a dict / list, or ``_field`` is not a str / list.

    '''

    if inplace:
        td = _td
    else:
        td = copy_dict(_td)

    # check dict input variable
    input_dict = isinstance(td, dict)
    if input_dict:
        td = [td]

    if not isinstance(td, list):
        raise Exception('ERROR: _td must be a list of dictionaries!')

    # check string input variable
    if isinstance(_field, str):
        _field = [_field]

    if not isinstance(_field, list):
        raise Exception('ERROR: _field must be a list of strings!')

    for td_tmp in td:
        # Iterate over a snapshot of the keys since td_tmp is modified in the loop.
        for field in list(td_tmp.keys()):
            if exact_field:
                keep = any(field_name == field for field_name in _field)
            else:
                keep = any(field_name in field for field_name in _field)
            if not keep:
                del td_tmp[field]

    if input_dict:
        td = td[0]

    if not inplace:
        return td
예제 #25
0
def collect_trajectories_and_metrics_non_NN(
    env,
    agent,
    num_trajectories,
    max_episode_length,
    metric_applicator,
    feature_extractor=None,
    disregard_collisions=False,
):
    """
    Helper function that collects trajectories and applies metrics from a
    metric applicator on a per trajectory basis.

    :param env: environment to collect trajectories from.
    :type env: any gym-like environment.

    :param agent: Agent to extract actions from (via ``agent.eval_action``).
    :type agent: A non neural network based controller.

    :param num_trajectories: Number of trajectories to sample.
    :type num_trajectories: int.

    :param max_episode_length: Maximum length of individual trajectories.
    :type max_episode_length: int.

    :param metric_applicator: a metric applicator class containing all
    metrics that need to be applied.
    :type metric_applicator: Instance, child, or similar to
    metric_utils.MetricApplicator.

    :param feature_extractor: a feature extractor to translate state
    dictionary to a feature vector. When given, the features are converted to
    a float tensor on ``DEVICE`` before being passed to the agent; otherwise
    the raw state dict is passed.
    :type feature_extractor: feature extractor class.

    :param disregard_collisions: if True, ``done`` is overwritten after every
    step with ``env.check_overlap`` between the agent position and the goal,
    so an episode only stops when that check is true (presumably the agent
    reaching the goal) or at ``max_episode_length``.
    :type disregard_collisions: bool.

    :return: dictionary mapping ``env.cur_ped`` (read right after each reset)
    to the metric results of that trajectory. NOTE(review): trajectories with
    the same ``cur_ped`` (e.g. None when no subject is set) overwrite each
    other's results.
    :rtype: dictionary
    """

    metric_results = {}

    for _ in range(num_trajectories):

        state = env.reset()
        current_pedestrian = env.cur_ped
        print("Collecting trajectory {} \r".format(current_pedestrian))
        done = False
        t = 0
        traj = [copy_dict(state)]

        while not done and t < max_episode_length:

            if feature_extractor is not None:
                feat = feature_extractor.extract_features(state)
                feat = (
                    torch.from_numpy(feat).type(torch.FloatTensor).to(DEVICE)
                )
            else:
                feat = state
            action = agent.eval_action(feat)
            state, _, done, _ = env.step(action)

            if done:
                # presumably undoes the control release the env performs on
                # `done`, so the agent keeps acting if the episode continues
                env.release_control = False

            if disregard_collisions:
                done = env.check_overlap(
                    env.agent_state["position"],
                    env.goal_state,
                    env.cellWidth,
                    0,
                )

            traj.append(copy_dict(state))

            t += 1

        # metrics
        traj_metric_result = metric_applicator.apply([traj])
        metric_results[current_pedestrian] = traj_metric_result

    return metric_results
예제 #26
0
def separate_fields(_td, _fields, **kwargs):
    '''
    This function splits each multi-column field into one field per column.

    Each field is converted with ``transpose(np.array(field), 'column')`` and
    column ``i`` of the result is stored under its own new field; the original
    fields are then removed.

    Parameters
    ----------
    _td : dict / list of dict
        Trial data.
    _fields : str / list of str
        Fields to split.
    new_names : list of str / list of list of str, optional
        Names of the new fields, one list per entry of ``_fields``; each list
        must have as many names as the field has columns (size along axis 1
        after the transpose above). The default is ``'<field>_<i>'``, derived
        from the first trial.
    inplace : bool, optional
        Perform operation on the input data dict. The default is True.
    save_to_params : str, optional
        Sub-field (resolved with ``td_subfield``) holding a ``'signals'`` list;
        the split fields are replaced there by the new names.

    Returns
    -------
    td : dict / list of dict
        Trial data. Only returned when ``inplace`` is False.

    Raises
    ------
    Exception
        On invalid input types, unknown fields, mismatched ``new_names``
        lengths, or a ``save_to_params`` sub-field without ``'signals'``.

    '''

    new_names = None
    inplace = True
    save_to_params = False

    # Check input variables
    for key, value in kwargs.items():
        key = key.lower()
        if key == 'new_names':
            new_names = value
        elif key == 'inplace':
            inplace = value
        elif key == 'save_to_params':
            save_to_params = True
            save_to_params_field = value

    if inplace:
        td = _td
    else:
        td = copy_dict(_td)

    # check dict input variable
    input_dict = False
    if type(td) is dict:
        input_dict = True
        td = [td]

    if type(td) is not list:
        raise Exception('ERROR: _td must be a list of dictionaries!')

    if not is_field(td, _fields):
        raise Exception('ERROR: some _fileds are not in td!')

    # check _fields input variable
    if type(_fields) is str:
        _fields = [_fields]
    if type(_fields) is not list:
        raise Exception('ERROR: _fields must be a list!')

    # Check new names
    if new_names is None:
        # Only one list of names per field is used below, so derive the
        # defaults from the first trial only.
        new_names = []
        for td_tmp in td[:1]:
            for field in _fields:
                field_size = transpose(np.array(td_tmp[field]),
                                       'column').shape[1]
                new_names.append(
                    [field + '_' + str(iS) for iS in range(field_size)])
    else:
        if type(new_names) is not list:
            raise Exception('ERROR: new_names must be a list!')
        if type(new_names[0]) is str:
            new_names = [new_names]

        for td_tmp in td:
            for iN, (names, field) in enumerate(zip(new_names, _fields)):
                field_size = transpose(np.array(td_tmp[field]),
                                       'column').shape[1]
                if len(names) != field_size:
                    raise Exception(
                        'ERROR: new_names[{}] has different length compared to the dimension of field "{}"!'
                        .format(iN, field))

    if save_to_params:
        subfield = td_subfield(td[0], save_to_params_field)
        if 'signals' not in set(subfield.keys()):
            raise Exception(
                'ERROR: field "signals" does not exist in "{}"'.format(
                    save_to_params_field))

    for td_tmp in td:
        for names, field in zip(new_names, _fields):
            columns = transpose(np.array(td_tmp[field]), 'column')
            for iN, name in enumerate(names):
                td_tmp[name] = columns[:, iN]

        if save_to_params:
            subfield = td_subfield(td_tmp, save_to_params_field)
            for field in _fields:
                subfield['signals'].remove(field)
            subfield['signals'].extend(flatten_list(new_names))

    # td is already our own working copy when inplace is False, so always
    # remove in place: with inplace=False, remove_fields would edit a second,
    # discarded copy and the split fields would survive in the result.
    remove_fields(td, flatten_list(_fields), exact_field=True, inplace=True)

    if input_dict:
        td = td[0]

    if not inplace:
        return td
예제 #27
0
 def __init__(self, app):
     """Attach to *app*, remembering its config file and a copy of its tool paths."""
     self.app = app
     self.config_file = app.config_file
     # Keep the result of utils.copy_dict rather than aliasing app.tool_path
     # (presumably an independent copy - confirm against utils.copy_dict).
     tool_paths = utils.copy_dict(app.tool_path)
     self.tool_path = tool_paths
     # starts out clean
     self.dirty = False
예제 #28
0
def get_field(_td, _signals, save_signals_name=False):
    '''
    Collect selected fields from the trial data into new dict(s).

    Parameters
    ----------
    _td : dict / list of dict
        Trial data from which the fields are collected. It is copied with
        ``copy_dict`` before being read.
    _signals : str / list
        Fields to collect. Each entry is either a field name (str) or a
        two-element list ``[a, b]``; a pair produces the output field
        ``'a - b'`` holding ``np.array(td[a]) - np.array(td[b])``.
        A single string is accepted and wrapped into a list.
    save_signals_name : bool, optional
        If True, store the list of output field names under
        ``td_out['params']['signals']``. The default is False.

    Returns
    -------
    td_out : dict / list of dict
        New dict(s) containing the selected fields as numpy arrays: a single
        dict when the (copied) input is a dict, otherwise one dict per trial.

    Raises
    ------
    Exception
        If ``is_field`` reports that the selected signals are not in the data.
    '''
    td = copy_dict(_td)
    td_out = []

    # isinstance (not an exact type comparison) so dict subclasses such as
    # OrderedDict are treated as a single trial instead of being iterated.
    input_dict = isinstance(td, dict)
    if input_dict:
        td = [td]

    # isinstance also catches str subclasses such as numpy.str_, which would
    # otherwise be iterated character by character.
    if isinstance(_signals, str):
        print(
            'Signals input must be a list. You inputes a string --> converting to list...'
        )
        _signals = [_signals]

    # Check that _signals are in the dictionary
    if not is_field(td, _signals):
        raise Exception('ERROR: Selected signals are not in the dict')

    # Loop over the trials
    for td_tmp in td:
        td_out_tmp = dict()
        # Names of the output fields, in the order of _signals
        signals_name = []
        for sgl in _signals:
            if isinstance(sgl, list):
                # Pair [a, b] -> difference a - b stored under 'a - b'
                signal_name = '{} - {}'.format(sgl[0], sgl[1])
                td_out_tmp[signal_name] = np.array(td_tmp[sgl[0]]) - np.array(
                    td_tmp[sgl[1]])
            else:
                signal_name = sgl
                td_out_tmp[signal_name] = np.array(td_tmp[sgl])
            signals_name.append(signal_name)

        if save_signals_name:
            td_out_tmp['params'] = dict()
            td_out_tmp['params']['signals'] = signals_name

        td_out.append(td_out_tmp)

    if input_dict:
        td_out = td_out[0]

    return td_out
예제 #29
0
 def _init(self, *args, **kwargs):
     """Build the ``PdbAssembly`` sub-object for ``self.pdb_id``.

     The keyword arguments are passed through ``copy_dict`` together with
     ``run=False`` and an ``output_dir`` of ``self.subdir("assembly")``
     (presumably a copy of ``kwargs`` with those overrides -- confirm).
     Positional ``args`` are not used.
     """
     pdb_id = self.pdb_id
     assembly_kwargs = copy_dict(kwargs,
                                 run=False,
                                 output_dir=self.subdir("assembly"))
     self.pdb_assembly = PdbAssembly(pdb_id, **assembly_kwargs)
예제 #30
0
def combine_fields(_td, _fields, **kwargs):
    '''
    Combine pairs of fields of the trial data into new fields.

    For each pair ``[a, b]`` in ``_fields`` a new field is added to every
    trial: ``'a-b'`` (computed with ``bipolar``) for ``method='subtract'``,
    or ``'a+b'`` (computed with ``add``) for ``method='add'``.

    Parameters
    ----------
    _td : dict / list of dict
        Trial data.
    _fields : list of str / list of list of str
        A pair of field names ``[a, b]``, or a list of such pairs.
    method : str, optional
        How each pair is combined: 'subtract' or 'add'. 'multiply' and
        'divide' are accepted names but not implemented yet and raise.
        The default is 'subtract'.
    remove_fields : bool, optional
        Remove the source fields before returning the dataset. When
        ``save_to_params`` is also given, True replaces the params signal
        list with the new names and False appends to it. The default is True.
    save_to_params : str, optional
        Sub-field (resolved with ``td_subfield``) whose ``'signals'`` list is
        updated with the new field names. Not used by default.
    inplace : bool, optional
        Perform operation on the input data dict. The default is True.

    Returns
    -------
    td : dict / list of dict
        Trial data, returned only when ``inplace`` is False; otherwise the
        input is modified and None is returned.

    Raises
    ------
    Exception
        On malformed ``_td`` / ``_fields``, an unknown or unimplemented
        method, a missing ``'signals'`` entry in the params sub-field, or
        paired arrays of different length.
    '''

    method = 'subtract'
    remove_selected_fields = True
    save_to_params = False
    save_to_params_field = None
    inplace = True

    # Check input variables (keyword names are matched case-insensitively)
    for key, value in kwargs.items():
        key = key.lower()
        if key == 'method':
            method = value
        elif key == 'remove_fields':
            remove_selected_fields = value
        elif key == 'save_to_params':
            save_to_params = True
            save_to_params_field = value
        elif key == 'inplace':
            inplace = value

    td = _td if inplace else copy_dict(_td)

    # check dict input variable
    input_dict = isinstance(td, dict)
    if input_dict:
        td = [td]

    if not isinstance(td, list):
        raise Exception('ERROR: _td must be a list of dictionaries!')

    # check _fields input variable
    if not isinstance(_fields, list):
        raise Exception('ERROR: _fields must be a list!')

    if not _fields:
        raise Exception('ERROR: _fields must not be empty!')

    # A single pair ['a', 'b'] is promoted to a list of pairs
    if not isinstance(_fields[0], list):
        _fields = [_fields]

    for field in _fields:
        if len(field) != 2:
            raise Exception(
                'ERROR: lists in _fields must contain exactly 2 strings!')

    # Every accepted name must have a matching branch below, otherwise
    # signal_name would be left unbound.
    if method not in ['subtract', 'multiply', 'divide', 'add']:
        raise Exception('ERROR: specified method has not been implemented!')

    if save_to_params:
        subfield = td_subfield(td[0], save_to_params_field)
        if 'signals' not in set(subfield.keys()):
            raise Exception(
                'ERROR: field "signals" does not exist in "{}"'.format(
                    save_to_params_field))

    for td_tmp in td:
        signals_name = []
        for field in _fields:
            # Compare the lengths of the data arrays, not of the field names.
            if len(td_tmp[field[0]]) != len(td_tmp[field[1]]):
                raise Exception(
                    'ERROR: the 2 arrays must have the same length!')

            if method == 'subtract':
                signal_name = '{}-{}'.format(field[0], field[1])
                td_tmp[signal_name] = bipolar(td_tmp[field[0]],
                                              td_tmp[field[1]])
            elif method == 'add':
                signal_name = '{}+{}'.format(field[0], field[1])
                td_tmp[signal_name] = add(td_tmp[field[0]], td_tmp[field[1]])
            elif method == 'multiply':
                raise Exception('Method must be implemented!')
            elif method == 'divide':
                raise Exception('Method must be implemented!')
            signals_name.append(signal_name)

        if save_to_params:
            subfield = td_subfield(td_tmp, save_to_params_field)
            if remove_selected_fields:
                subfield['signals'] = signals_name
            else:
                subfield['signals'].extend(signals_name)

    if remove_selected_fields:
        # When inplace is False, td is already a private copy, so remove the
        # fields from it directly. Passing inplace=False here and discarding
        # the result (assuming remove_fields follows this module's convention
        # of returning a new copy in that mode) would leave the fields behind.
        remove_fields(td,
                      flatten_list(_fields),
                      exact_field=True,
                      inplace=True)

    if input_dict:
        td = td[0]

    if not inplace:
        return td