Code example #1
File: network.py Project: LaneWei/doom-rl
    def optimize(self):
        sample = self.get_sample()
        if sample is None:
            return

        states, actions, rewards, states_, masks = sample
        states = process_batch(states)
        states_ = process_batch(states_)
        # Bootstrap the n-step return with the value of the successor states;
        # the masks zero out the bootstrap term for terminal transitions.
        v_ = self.session.run(self.v, {self.s_input: states_})
        target = rewards + self.gamma_n * v_ * masks
        _, a, loss = self.session.run(
            [self.train_op, self.advantage, self.loss],
            {self.s_input: states, self.a_input: actions, self.r_input: target})
        return loss
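
process_batch itself is not shown in this example; here it appears to turn a list of raw game states into a network-ready batch. A minimal sketch of such a helper, assuming the states are uint8 image frames (the actual doom-rl implementation may resize, crop, or stack frames differently):

import numpy as np

def process_batch(states):
    # Hypothetical sketch only: convert raw uint8 frames to a float32 batch in [0, 1].
    batch = np.asarray(states, dtype=np.float32) / 255.0
    # Add a trailing channel axis if the frames are grayscale (N, H, W).
    if batch.ndim == 3:
        batch = batch[..., np.newaxis]
    return batch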
Code example #2
def get_search_results_with_relevancies(user_relevancies=None):
    if user_relevancies is None:
        user_relevancies = {}

    print("Generating search results...")
    # document list with relevances (search results) for each (query, algorithm) pair
    results = {(query, algo): results
               for query, algo, results in process_batch(
                   [(query, SEARCH_LIMIT, algo, user_relevancies)
                    for query in queries
                    for algo in algorithms], search_wrapper)}
    return results
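
Examples #2, #3 and #5 all call process_batch with a list of argument tuples plus a worker function, so the helper most likely fans those calls out over a pool of workers. A minimal sketch under that assumption (the real helper may also report progress, or use threads instead of processes):

from multiprocessing import Pool

def process_batch(arg_tuples, worker, processes=None):
    # Hypothetical sketch: apply `worker` to each argument tuple in parallel
    # and return the results in order, mirroring how the examples unpack them.
    with Pool(processes=processes) as pool:
        return pool.starmap(worker, arg_tuples)

On this reading, the extra 1 passed in example #5's commented-out call would be the number of worker processes.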
Code example #3
File: parsing.py Project: amankhoza/data-mining
    def validate_docs_links_out(docs):
        msg = "Validating documents links_out"
        logger.info('%s %s', MSG_START, msg)

        # path_to_url = {}
        # for doc in docs:
        #     path_to_url[doc.path] = doc.url
        internal_urls = set([doc.url for doc in docs])

        docs = process_batch([(doc, internal_urls) for doc in docs], UCLParser.validate_doc_links_out)

        # logger.info('Validated unique urls: %d', len(urls))
        logger.info('%s %s', MSG_SUCCESS, msg)
        return docs
Code example #4
File: main.py Project: grill-lab/trec-cast-tools
    output_path = f"{args.output_dir}/{args.output_type}"
    md5_dir_path = f"{args.output_dir}/md5_hashes"
    Path(output_path).mkdir(parents=True, exist_ok=True)
    Path(md5_dir_path).mkdir(parents=True, exist_ok=True)

    if not args.skip_process_kilt:
        print("--- Processing KILT ---")

        kilt_generator: KILTGenerator = KILTGenerator(
            args.kilt_collection, args.duplicates_file, args.batch_size
        ).generate_documents()

        process_batch(
            collection_name='KILT',
            generator=kilt_generator,
            passage_chunker=passage_chunker,
            output_type=args.output_type,
            output_path=output_path,
            md5_dir_path=md5_dir_path
        )

    if not args.skip_process_marco_v2:
        print("Processing MARCO v2")

        marco_v2_generator: MARCO_v2_Generator = MARCO_v2_Generator(
            args.marco_v2_collection, args.duplicates_file, args.batch_size
        ).generate_documents()

        process_batch(
            collection_name='MARCO_v2',
            generator=marco_v2_generator,
            passage_chunker=passage_chunker,
            output_type=args.output_type,
            output_path=output_path,
            md5_dir_path=md5_dir_path
        )
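
The body of process_batch is not included in this snippet. Judging from the keyword arguments, it streams documents from the generator, chunks them into passages, and writes the collection together with per-document md5 hashes. The skeleton below is only an assumed illustration of that flow; the field names, file layout, and the callable passage_chunker interface are guesses, not the real trec-cast-tools code:

import hashlib
import json
from pathlib import Path

def process_batch(collection_name, generator, passage_chunker,
                  output_type, output_path, md5_dir_path):
    # Assumed skeleton only; `generator` is taken to yield plain document dicts
    # and `passage_chunker` to be a callable returning an iterable of passage dicts.
    out_file = Path(output_path) / f"{collection_name}.{output_type}"
    md5_file = Path(md5_dir_path) / f"{collection_name}.md5"
    with open(out_file, "w") as out, open(md5_file, "w") as md5_out:
        for document in generator:
            raw = json.dumps(document, sort_keys=True)
            md5_out.write(hashlib.md5(raw.encode("utf-8")).hexdigest() + "\n")
            for passage in passage_chunker(document):
                out.write(json.dumps(passage) + "\n")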
Code example #5
File: parsing.py Project: amankhoza/data-mining
    def parse_website(dir, use_cache=False, multithreading=True):
        msg = "Parsing website"
        logger.info('%s %s', MSG_START, msg)
        logger.info("From directory %s", dir)

        if not os.path.isdir(dir):
            logger.error('ERROR: Directory %s does not exist!', dir)
            logger.info('%s %s', MSG_FAILED, msg)
            exit()

        pickle_file = docs_cache_dir + re.sub('/|:|\\\\', '', dir) + '.pickle'
        loaded_from_cache = False

        docs = []

        if use_cache:
            # Read docs from cache
            try:
                with open(pickle_file, 'rb') as handle:
                    docs = pickle.load(handle)
                    loaded_from_cache = True
                    logger.info("Loaded %d documents from cache.", len(docs))
            except Exception:
                logger.info("No cached documents found")

        if not loaded_from_cache:
            logger.info("Getting file paths...")
            files = get_files(dir, '.html')
            logger.info("Found %d html files", len(files))
            logger.info("Parsing files...")
            if multithreading:
                docs = [doc for doc in process_batch([(file,) for file in files], UCLParser.parse_file) if doc is not None]
            else:
                docs = []
                total_docs = len(files)
                for i, file in enumerate(files):
                    doc = UCLParser.parse_file(file)
                    if doc is not None:
                        docs.append(doc)
                    print_progress(i + 1, total_docs, 'Progress:')
                # docs = [doc for doc in process_batch([(file,) for file in files], UCLParser.parse_file, 1) if doc is not None]
            logger.info("Successfully parsed %d files", len(docs))

            docs = UCLParser.validate_docs_links_out(docs)
            # UCLParser.remove_duplicate_docs(docs)
            docs = UCLParser.add_links_in(docs)

            UCLParser.add_pagerank(docs)

            if not os.path.isdir(docs_cache_dir):
                os.mkdir(docs_cache_dir)
            # Cache documents
            try:
                with open(pickle_file, 'wb') as handle:
                    pickle.dump(docs, handle, protocol=pickle.HIGHEST_PROTOCOL)
                    logger.info("Successfully cached %d documents.", len(docs))
            except Exception as e:
                logger.info("Failed to cache documents: %s", e)

        logger.info('%s %s', MSG_SUCCESS, msg)
        return docs
Code example #6
File: network.py Project: LaneWei/doom-rl
    def predict(self, s):
        s = process_batch(s)
        v, pi = self.session.run([self.v, self.pi], {self.s_input: s})
        return v, pi
Code example #7
File: app.py Project: falopez10/BioCicle
def post_compare_sequence():
    output = {}

    try:

        merged_tree = {'name': '', 'children': {}, 'SCORE': []}

        data = request.get_json()

        if not "batch_size" in data:
            data["batch_size"] = 1

        data["sequences"] = [
            sequence.strip(" \t\n\r") for sequence in data["sequences"]
        ]

        # Detect sequences processed before
        saved_sequences, tmp_sequences = utils.get_unsaved_sequences(
            data["sequences"])

        # Include previously saved sequences
        processed_batch = saved_sequences.copy()

        for saved_sequence in processed_batch:
            utils.get_hierarchy_from_dict(saved_sequence['sequence_id'],
                                          saved_sequence['comparisons'],
                                          target=merged_tree)

        counter = 0
        current_batch_stop = counter
        pieces_left = len(tmp_sequences) > 0

        while pieces_left:

            tmp_sequences = tmp_sequences[current_batch_stop:]

            num_sequences_left = len(tmp_sequences)

            if data["batch_size"] < num_sequences_left:
                current_batch_stop = data["batch_size"]

            else:
                current_batch_stop = num_sequences_left
                pieces_left = False

            # Compare only the still-unprocessed sequences in the current batch
            current_batch = tmp_sequences[:current_batch_stop]
            file_batch = [
                utils.compare_sequence(sequence) for sequence in current_batch
            ]

            counter += len(current_batch)
            log.datetime_log("{} sequences compared.".format(counter))

            # Merge the comparisons for this batch into the running tree
            merged_tree, unsaved_batch = utils.process_batch(
                current_batch, file_batch, merged_tree)

            processed_batch.extend(unsaved_batch)

        # Prepare output
        hierarchy, aggregated_score = utils.form_hierarchy(merged_tree)
        output["merged_tree"] = hierarchy

        output["taxonomies_batch"] = processed_batch

        log.datetime_log("{} hierarchies formed.".format(counter))

        return jsonify(output)

    except Exception as e:
        output["Error"] = str(e)
        log.datetime_log("Error: {}".format(e))
        return jsonify(output)
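
For reference, the handler above reads its input with request.get_json() and expects a "sequences" list plus an optional "batch_size". A minimal client call might look like the following; the URL and route path are assumptions, since the route decorator is not part of the snippet:

import requests

payload = {
    "sequences": ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"],  # illustrative sequence only
    "batch_size": 2,
}
# Route path and port are assumptions; adjust to the actual Flask app.
response = requests.post("http://localhost:5000/compare_sequence", json=payload)
print(response.json().get("merged_tree"))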
Code example #8
File: app.py Project: falopez10/BioCicle
def upload_file():

    output = {}

    try:
        data = request.get_json()

        if data["file"] is not None and data["filename"] is not None:

            taxonomy = []
            parsed_filename = data["filename"].split(".")[0]
            merged_tree = {'name': '', 'children': {}, 'SCORE': []}

            try:
                file_path = utils.try_to_save_file(data["file"],
                                                   data["filename"])

                log.datetime_log("Succeded saving file.")

                merged_tree, taxonomy = utils.process_batch([parsed_filename],
                                                            [file_path],
                                                            merged_tree)

            except utils.FileExists as e:
                taxonomy, tmp_sequences = utils.get_unsaved_sequences(
                    [parsed_filename])

                if len(taxonomy) == 0:
                    sequence_id = utils.get_sequence_id(data["filename"])
                    if sequence_id is not None:
                        log.datetime_log(
                            "File existed and sequence {} parsed succesfully.".
                            format(sequence_id))
                        taxonomy, tmp_sequences = utils.get_unsaved_sequences(
                            [sequence_id])

                if len(taxonomy) > 0:
                    utils.get_hierarchy_from_dict(taxonomy[0]['sequence_id'],
                                                  taxonomy[0]['comparisons'],
                                                  target=merged_tree)

                else:
                    log.datetime_log(
                        "File existed but sequence not parsed: trying to write a new file."
                    )
                    file_path = ""
                    cont = 0
                    while len(file_path) == 0 and cont < 50:
                        try:
                            file_path = utils.try_to_save_file(
                                data["file"], data["filename"], modifier=cont)

                        except utils.FileExists as e:
                            cont += 1

                    log.datetime_log(
                        "File succesfully saved at {}.".format(file_path))

                    merged_tree, taxonomy = utils.process_batch(
                        [parsed_filename], [file_path], merged_tree)

            # Prepare output
            print("merged_tree")
            print(merged_tree)
            hierarchy, aggregated_score = utils.form_hierarchy(merged_tree)
            print("hierarchy")
            print(hierarchy)
            output["merged_tree"] = hierarchy['children'][0]

            output["taxonomies_batch"] = taxonomy
            return jsonify(output)

    except Exception as e:
        output["Error"] = str(e)
        log.datetime_log("Error: {}".format(e))
        return jsonify(output)
Code example #9
            x = tensor(previous_state, args.device)[None]
            action = int(torch.argmax(model(x).detach().cpu()))

        observation, reward, done, info = env.step(action)

        buff.pop(0)
        buff.append(observation)

        next_state = preprocess(buff)

        REPLAY_MEMORY.pop(0)
        REPLAY_MEMORY.append([previous_state, action, reward, next_state, done])

        # Do the actual learning: sample a minibatch and take one gradient step.
        prev_states, ys = process_batch(
            random.sample(REPLAY_MEMORY, args.batch_size),
            target_model, env.action_space.n, args.gamma, args.device)

        optimizer.zero_grad()
        loss = huber_loss(model(prev_states.to(device=args.device)),
                          ys.to(device=args.device))
        loss.backward()
        optimizer.step()

        episode_loss += loss.item()
        episode_steps += 1
        episode_reward += reward

        if args.render_env:
            env.render()

    if args.epsilon > args.min_epsilon:
        args.epsilon *= args.eps_decay
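
Here process_batch converts a sampled minibatch of transitions into network inputs and Q-learning targets for the Huber loss. Its definition is not part of the snippet; the sketch below shows one standard way to build those targets, assuming each transition is the [state, action, reward, next_state, done] list appended to REPLAY_MEMORY above (the original project's helper may differ in detail):

import numpy as np
import torch

def process_batch(transitions, target_model, n_actions, gamma, device):
    # Hypothetical sketch: build (prev_states, ys) for the Huber loss above.
    # n_actions is kept to match the call site even though the target network's
    # output width already determines it.
    states, actions, rewards, next_states, dones = zip(*transitions)

    prev_states = torch.as_tensor(np.array(states), dtype=torch.float32)
    next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32,
                                    device=device)

    with torch.no_grad():
        # Targets start as the target network's current estimates...
        ys = target_model(prev_states.to(device)).cpu()
        next_q = target_model(next_states_t).max(dim=1).values.cpu()

    for i, (action, reward, done) in enumerate(zip(actions, rewards, dones)):
        # ...and only the taken action's entry is replaced by the Bellman target.
        ys[i, action] = reward if done else reward + gamma * next_q[i]

    return prev_states, ys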