Example #1
    def create_outfile_for_arg(self, fname: str, identifier: str):
        path = '{}/{}/{}.pkl'.format(self.args.outdir, fname, identifier)
        if os.path.exists(path) and self.args.append_arg_level:
            self.file_map[fname][identifier] = IndexedFileWriter(path, mode='a')
        else:
            self.file_map[fname][identifier] = IndexedFileWriter(path, mode='w')
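Example #1 (and the init methods in the examples below) makes the same decision each time: reopen an existing output file in append mode when the corresponding append flag is set, otherwise truncate and write from scratch. IndexedFileWriter is internal to autopandas, so the sketch below only illustrates that decision with plain pickle streams; the helper name open_outfile and its append parameter are hypothetical.

import os
import pickle

def open_outfile(path: str, append: bool):
    # Hypothetical stand-in for IndexedFileWriter: reuse the existing stream
    # only when it is present and appending was explicitly requested.
    mode = 'ab' if append and os.path.exists(path) else 'wb'
    return open(path, mode)

with open_outfile('out.pkl', append=True) as f:
    f.write(pickle.dumps({'example': 1}))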
Example #2
    def init(self):
        if os.path.exists(self.args.outfile) and self.args.append:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='a')
        else:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='w')

        self.blacklist = set()
        self.whitelist = set()
        self.error_cnt_map = collections.defaultdict(int)
Example #3
    def load_data(self, path, is_training_data=False):
        reader = IndexedFileReader(path)
        num_training_points = self.params.args.get('num_training_points', -1)
        if self.params.args.get('use_memory', False):
            result = self.process_raw_graphs(reader, is_training_data)
            reader.close()
            if is_training_data and num_training_points != -1:
                return result[:num_training_points]

            return result

        if self.params.args.get('use_disk', False):
            wpath = path + '.processed'

            if not os.path.exists(wpath):
                w = IndexedFileWriter(wpath)
                for d in tqdm.tqdm(reader,
                                   desc='Dumping processed graphs to disk'):
                    processed = self.process_raw_graph(d)
                    if processed is not None:
                        w.append(pickle.dumps(processed))

                w.close()

            #  The raw reader is no longer needed once the processed cache exists.
            reader.close()
            result = IndexedFileReader(wpath)

            if is_training_data and self.params.get('load_shuffle', False):
                result.shuffle()

            if is_training_data and num_training_points != -1:
                result.indices = result.indices[:num_training_points]

            return result

        #  We won't pre-process anything. We'll convert on-the-fly. Saves memory but is very slow and wasteful
        if is_training_data and self.params.get('load_shuffle', False):
            reader.shuffle()

        if is_training_data and num_training_points != -1:
            reader.indices = reader.indices[:num_training_points]

        reader.set_loader(lambda x: self.process_raw_graph(pickle.load(x)))
        return reader
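load_data above picks one of three loading strategies from flags in params.args: keep everything in memory (use_memory), cache processed graphs on disk next to the raw file (use_disk, writing to path + '.processed'), or deserialize and process lazily on every pass. Below is a rough, self-contained sketch of the disk-caching strategy using plain pickle streams in place of the IndexedFileReader/IndexedFileWriter pair; process_record and the file layout are assumptions for illustration only.

import os
import pickle

def process_record(record):
    # Placeholder for the real per-record processing step (process_raw_graph above).
    return record

def load_processed(path: str):
    wpath = path + '.processed'
    if not os.path.exists(wpath):
        # One-time pass: process every raw record and cache the results on disk.
        with open(path, 'rb') as src, open(wpath, 'wb') as dst:
            while True:
                try:
                    processed = process_record(pickle.load(src))
                except EOFError:
                    break
                if processed is not None:
                    dst.write(pickle.dumps(processed))
    return open(wpath, 'rb')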
Example #4
    def init(self):
        if os.path.exists(self.args.outfile) and self.args.append:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='a')
        else:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='w')
Example #5
class FunctionSeqDataGenerator:
    """
    This generator implements parallelized computation of training data for training function sequence predictors.
    """
    class Worker:
        args: ArgNamespace = None
        generators: Dict[str, BaseGenerator] = None

        @classmethod
        def init(cls, args: ArgNamespace):
            cls.args = args
            cls.generators = load_generators()
            if cls.args.debug:
                logger.info("Loaded {} generators in process {}".format(
                    len(cls.generators), os.getpid()))

        @classmethod
        def process(cls, raw_data: Dict):
            if raw_data is None:
                return None

            try:
                graph = RelationGraph(GraphOptions())
                inputs = raw_data['inputs']
                output = raw_data['output']
                graph.from_input_output(inputs, output)

                encoding = graph.get_encoding()
                encoding['label'] = raw_data['function_sequence']
                return encoding

            except SilentException:
                return None

            except Exception as e:
                try:
                    logger.warn("Failed for {}".format(raw_data))
                    logging.exception(e)
                    return None

                except:
                    pass

                return None

    def __init__(self, args: ArgNamespace):
        self.args = args
        self.fwriter: IndexedFileWriter = None

    def raw_data_iterator(self):
        def valid(dpoint):
            for depth, record in enumerate(dpoint['generator_tracking']):
                record = record.record
                for k, v in record.items():
                    if (k.startswith("ext_")
                            and v['source'] == 'intermediates'
                            and v['idx'] >= depth):
                        return False

            return True

        with open(self.args.raw_data_path, 'rb') as f:
            while True:
                try:
                    point = pickle.load(f)
                    if 'args' not in point and 'generator_tracking' not in point:
                        logger.warn(
                            "Raw data points are missing the 'args' attribute. Did you generate this "
                            "data using the smart-generators branch of autopandas?"
                        )
                        return

                    if valid(point):
                        yield point

                except EOFError:
                    break

    def init(self):
        if os.path.exists(self.args.outfile) and self.args.append:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='a')
        else:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='w')

    def process_result(self, training_point: Dict):
        self.fwriter.append(pickle.dumps(training_point))

    def generate(self):
        self.init()
        num_generated = 0
        num_processed = 0
        num_raw_points = -1
        if os.path.exists(self.args.raw_data_path + '.index'):
            reader = IndexedFileReader(self.args.raw_data_path)
            num_raw_points = len(reader)
            reader.close()

        start_time = time.time()
        with pebble.ProcessPool(
                max_workers=self.args.processes,
                initializer=FunctionSeqDataGenerator.Worker.init,
                initargs=(self.args, )) as p:

            chunksize = self.args.processes * self.args.chunksize
            for chunk in misc.grouper(chunksize, self.raw_data_iterator()):
                future = p.map(FunctionSeqDataGenerator.Worker.process,
                               chunk,
                               timeout=self.args.task_timeout)
                res_iter = future.result()

                idx = -1
                while True:
                    idx += 1
                    #  The grouper may pad the final chunk with None; count only real entries
                    if idx < len(chunk) and chunk[idx] is not None:
                        num_processed += 1

                    try:
                        result = next(res_iter)
                        if chunk[idx] is None:
                            continue

                        if result is not None:
                            self.process_result(result)
                            num_generated += 1

                    except StopIteration:
                        break

                    except TimeoutError:
                        #  This chunk element exceeded task_timeout; skip it and keep iterating
                        pass

                    except Exception as e:
                        try:
                            logger.warn("Failed for", chunk[idx])
                            logging.exception(e)

                        except:
                            pass

                    finally:

                        speed = round(
                            num_processed / (time.time() - start_time), 1)
                        if num_raw_points != -1:
                            time_remaining = round(
                                (num_raw_points - num_processed) / speed, 1)
                        else:
                            time_remaining = '???'

                        logger.log(
                            "Generated/Processed : {}/{} ({}/s, TTC={}s)".
                            format(num_generated, num_processed, speed,
                                   time_remaining),
                            end='\r')

            p.stop()
            try:
                p.join(10)
            except:
                pass

        self.fwriter.close()

        logger.log("\n-------------------------------------------------")
        logger.info("Total Time : {:.2f}s".format(time.time() - start_time))
        logger.info(
            "Generated {} training points from {} raw data points".format(
                num_generated, num_processed))
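The loop above relies on pebble's map semantics: map() returns a single future, future.result() yields an iterator, and each next() call either returns a result, raises TimeoutError for an element that exceeded the per-task timeout, or re-raises the exception the worker hit, so the remaining results keep streaming in even when individual tasks fail. A stripped-down, runnable sketch of that pattern (the worker function and the timeout values are placeholders, not part of autopandas):

import time
import pebble
from concurrent.futures import TimeoutError

def work(x):
    # Placeholder worker; sleeping longer than the timeout triggers TimeoutError below.
    time.sleep(x)
    return x

if __name__ == '__main__':
    with pebble.ProcessPool(max_workers=2) as pool:
        future = pool.map(work, [0.1, 5, 0.2], timeout=1)
        results = future.result()
        while True:
            try:
                print(next(results))
            except StopIteration:
                break
            except TimeoutError:
                print('task timed out, skipping')
            except Exception as e:
                print('task failed:', e)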
Example #6
class RawDataGenerator:
    """
    This generator implements parallelized computation of raw training data that contains random I/O examples
    along with the programs producing the output, and the generator choices made in producing that program
    """
    class Worker:
        args: ArgNamespace = None
        generators: Dict[str, BaseGenerator] = None

        @classmethod
        def init(cls, args: ArgNamespace):
            cls.args = args
            cls.generators = load_randomized_generators()
            if cls.args.debug:
                logger.info("Loaded {} generators in process {}".format(
                    len(cls.generators), os.getpid()))

        @classmethod
        def process(cls, named_seqs: List[List[str]]):
            if named_seqs is None:
                return 0, None

            seqs: List[List[BaseGenerator]] = [
                list(map(lambda x: cls.generators[x], s)) for s in named_seqs
            ]
            max_seq_trials = cls.args.max_seq_trials
            results: List[Dict] = []

            for idx, seq in enumerate(seqs):
                engine = RandProgEngine(seq, cls.args)
                for trial in range(max_seq_trials):
                    try:
                        spec: ExplorationSpec = engine.generate()
                    except Exception as e:
                        if cls.args.debug:
                            logger.warn("Encountered exception for",
                                        named_seqs[idx])
                            logger.log(e)
                            logging.exception(e)

                        continue

                    if spec is None:
                        continue

                    dpoint = {
                        'inputs': spec.inputs,
                        'output': spec.output,
                        'intermediates': spec.intermediates,
                        'program_str': str(spec.program),
                        'program': spec.program,
                        'function_sequence': named_seqs[idx],
                        'generator_tracking': spec.tracking
                    }

                    # print("-" * 50)
                    # print(dpoint)
                    # print("-" * 50)
                    # print([t.record for t in spec.tracking])
                    # print(spec.program)

                    #  Confirm it's picklable. Sometimes, unpickling throws an error
                    #  when the main process is receiving the msg, and things break down
                    #  in a very, very nasty manner
                    #  TODO : Can we switch to dill while using multiprocessing/pebble?
                    try:
                        a = pickle.dumps(dpoint)
                        pickle.loads(a)
                    except:
                        continue

                    results.append(dpoint)
                    break

            return len(named_seqs), results

    def __init__(self, args: ArgNamespace):
        self.args = args
        self.fwriter: IndexedFileWriter = None
        self.blacklist: Set[Tuple[str]] = set()
        self.whitelist: Set[Tuple[str]] = set()
        self.error_cnt_map: Dict[Tuple[str],
                                 int] = collections.defaultdict(int)
        self.sequences: List[List[str]] = None

    def load_sequences(self) -> List[List[str]]:
        generators: Dict[str, BaseGenerator] = load_randomized_generators()
        generator_name_map: Dict[str,
                                 List[str]] = collections.defaultdict(list)
        for k, v in generators.items():
            generator_name_map[v.name].append(v.qual_name)
            generator_name_map[v.qual_name].append(v.qual_name)

        sequences_src: str = self.args.sequences
        unimplemented_funcs: Set[str] = set()
        if sequences_src.endswith(".pkl"):
            with open(sequences_src, 'rb') as f:
                sequences: List[List[str]] = list(map(list, pickle.load(f)))

        else:
            sequences: List[List[str]] = [
                list(i.split(':')) for i in sequences_src.split(',')
            ]

        def get_valid_sequences(seq: List[str]):
            for i in seq:
                if i not in generator_name_map:
                    unimplemented_funcs.add(i)
                    return

            if not (self.args.min_depth <= len(seq) <= self.args.max_depth):
                return

            for seq in itertools.product(*[generator_name_map[i]
                                           for i in seq]):
                yield list(seq)

        final_sequences: List[List[str]] = []
        for seq in sequences:
            final_sequences.extend(get_valid_sequences(seq))

        for i in unimplemented_funcs:
            logger.warn("Generator not implemented for : {}".format(i))

        logger.info("Found {} sequences. "
                    "Filtered out {}. "
                    "Returning {}.".format(
                        len(sequences),
                        len(sequences) - len(final_sequences),
                        len(final_sequences)))
        return final_sequences

    def gen_named_seqs(self) -> Generator[List[List[str]], Any, Any]:
        while True:
            self.blacklist -= self.whitelist
            if len(self.blacklist) > 0:
                for seq in self.blacklist:
                    logger.warn(
                        "Blacklisting {} because of too many errors".format(
                            seq))

                self.sequences = [
                    i for i in self.sequences if tuple(i) not in self.blacklist
                ]
                self.blacklist = set()

            for seq in self.sequences:
                yield [seq]

            if self.args.no_repeat:
                break

    def init(self):
        if os.path.exists(self.args.outfile) and self.args.append:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='a')
        else:
            self.fwriter = IndexedFileWriter(self.args.outfile, mode='w')

        self.blacklist = set()
        self.whitelist = set()
        self.error_cnt_map = collections.defaultdict(int)

    def process_dpoint(self, dpoint: Dict):
        self.fwriter.append(pickle.dumps(dpoint))

    def report_error_seqs(self, seqs: List[List[str]]):
        if seqs is None:
            return

        for seq in seqs:
            key = tuple(seq)
            if key in self.whitelist:
                continue

            self.error_cnt_map[key] += self.args.max_seq_trials
            if self.error_cnt_map[key] > self.args.blacklist_threshold:
                self.blacklist.add(key)

    def generate(self):
        self.init()
        num_generated = 0
        num_processed = 0
        num_required = self.args.num_training_points
        self.sequences = self.load_sequences()
        start_time = time.time()
        speed = 0
        time_remaining = 'inf'

        with pebble.ProcessPool(max_workers=self.args.processes,
                                initializer=RawDataGenerator.Worker.init,
                                initargs=(self.args, )) as p:

            #  First do smaller chunksizes to allow the blacklist to take effect
            chunksize = self.args.processes * self.args.chunksize

            if self.args.blacklist_threshold == -1:
                chunksize_blacklist = chunksize
            else:
                chunksize_blacklist = max(
                    (self.args.blacklist_threshold //
                     self.args.max_seq_trials), 1) * len(self.sequences)

            for chunk in misc.grouper([chunksize_blacklist, chunksize],
                                      self.gen_named_seqs()):
                if not p.active:
                    break

                future = p.map(RawDataGenerator.Worker.process,
                               chunk,
                               timeout=self.args.task_timeout)
                res_iter = future.result()

                idx = -1
                while True:
                    idx += 1
                    if num_generated >= num_required:
                        p.stop()
                        try:
                            p.join(10)
                        except:
                            pass
                        break

                    try:
                        returned = next(res_iter)
                        if returned is None:
                            self.report_error_seqs(chunk[idx])
                            continue

                        num_input_seqs, results = returned
                        num_processed += num_input_seqs
                        if results is not None and len(results) > 0:
                            for seq in chunk[idx]:
                                self.whitelist.add(tuple(seq))

                            for result in results:
                                num_generated += 1
                                self.process_dpoint(result)

                            speed = round(
                                num_generated / (time.time() - start_time), 1)
                            time_remaining = round(
                                (num_required - num_generated) / speed, 1)

                        elif num_input_seqs > 0:
                            self.report_error_seqs(chunk[idx])

                        logger.log("Num Generated : {} ({}/s, TTC={}s)".format(
                            num_generated, speed, time_remaining),
                                   end='\r')

                    except StopIteration:
                        break

                    except TimeoutError:
                        #  This chunk element exceeded task_timeout; skip it and keep iterating
                        pass

                    except Exception as e:
                        logger.warn("Failed for", chunk[idx])

            p.stop()
            try:
                p.join(10)
            except:
                pass

        self.fwriter.close()
        logger.log("\n-------------------------------------------------")
        logger.info("Total Time : {:.2f}s".format(time.time() - start_time))
        logger.info("Number of sequences processed :", num_processed)
        logger.info("Number of training points generated :", num_generated)