Example no. 1
def cli_main():
    import argparse
    parser = argparse.ArgumentParser(
        description=
        "Downloading/Decompressing CodeSearchNet dataset(s) or Tree-Sitter Library(ies)"
    )
    parser.add_argument("--yaml_file",
                        "-f",
                        type=str,
                        help="load {yaml_file}.yml for train",
                        default='config/python_wan/python')
    parser.add_argument(
        '--out_file',
        '-o',
        type=str,
        help='output generated file',
        default=None,
    )
    args = parser.parse_args()
    yaml_file = os.path.join(os.path.dirname(__file__),
                             '{}.yml'.format(args.yaml_file))
    out_file = args.out_file
    if out_file:
        dirname = os.path.dirname(out_file)
        os.makedirs(dirname, exist_ok=True)
        assert os.path.isdir(dirname)
    LOGGER.info('Load arguments in {}, output generated sentences at {} (if None, predictions are not recorded).' \
                .format(yaml_file, out_file))
    args = load_yaml(yaml_file)
    LOGGER.info(args)

    torch.cuda.set_device(args['distributed_training']['device_id'])
    main(args, out_file)
Example no. 2
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0

        dict = args['task'].get('dict', None)
        dict_type = args['task'].get('dict_type', None)
        if dict is None and dict_type is None:
            # load dictionaries
            src_dict = cls.load_dictionary(
                os.path.join(
                    paths[0],
                    '{}.dict.jsonl'.format(args['task']['source_lang'])))
            tgt_dict = cls.load_dictionary(
                os.path.join(
                    paths[0],
                    '{}.dict.jsonl'.format(args['task']['target_lang'])))
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
            LOGGER.info('[{}] dictionary: {} types'.format(
                args['task']['source_lang'], len(src_dict)))
            LOGGER.info('[{}] dictionary: {} types'.format(
                args['task']['target_lang'], len(tgt_dict)))
        else:
            raise NotImplementedError
        return cls(args, src_dict, tgt_dict)
Example no. 3
    def setup_task(cls, args, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            args (argparse.Namespace): parsed command-line arguments
        """
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0

        share_dict = args['task'].get('share_dict', False)
        if share_dict:
            src_dict = tgt_dict = cls.load_dictionary(
                os.path.join(paths[0], "dict.jsonl"))
        else:
            # load dictionaries
            src_dict = cls.load_dictionary(
                os.path.join(paths[0],
                             f"{args['task']['source_lang']}.dict.jsonl"))
            tgt_dict = cls.load_dictionary(
                os.path.join(paths[0],
                             f"{args['task']['target_lang']}.dict.jsonl"))
            assert src_dict.pad() == tgt_dict.pad()
            assert src_dict.eos() == tgt_dict.eos()
            assert src_dict.unk() == tgt_dict.unk()
            LOGGER.info('[{}] dictionary: {} types'.format(
                args['task']['source_lang'], len(src_dict)))
            LOGGER.info('[{}] dictionary: {} types'.format(
                args['task']['target_lang'], len(tgt_dict)))
        return cls(args, src_dict, tgt_dict)
Example no. 4
    def _inference_with_bleu(self, generator, sample, model):
        import sacrebleu

        def decode(toks, escape_unk=False):
            s = self.tgt_dict.string(
                toks.int().cpu(),
                self.args['task']['eval_bleu_remove_bpe'],
                escape_unk=escape_unk,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        gen_out = self.inference_step(generator, [model], sample, None)
        hyps, refs = [], []
        for i in range(len(gen_out)):
            hyps.append(decode(gen_out[i][0]['tokens']))
            refs.append(
                decode(
                    utils.strip_pad(sample['target'][i], self.tgt_dict.pad()),
                    escape_unk=True,  # don't count <unk> as matches to the hypo
                ))
        if self.args['task']['eval_bleu_print_samples']:
            LOGGER.info('example hypothesis: ' + hyps[0])
            LOGGER.info('example reference: ' + refs[0])
        # tokenize = sacrebleu.DEFAULT_TOKENIZER if not self.args['task']['eval_tokenized_bleu'] else 'none'
        # return sacrebleu.corpus_bleu(hyps, [refs], tokenize=tokenize)
        if self.args['task']['eval_tokenized_bleu']:
            return sacrebleu.corpus_bleu(hyps, [refs], tokenize='none')
        else:
            return sacrebleu.corpus_bleu(hyps, [refs])
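The call into sacrebleu above is the standard corpus-level API. A minimal, self-contained sketch of that call with toy strings (only the sacrebleu package itself is assumed):

import sacrebleu

# corpus_bleu expects a list of hypothesis strings and a list of reference
# streams (one inner list per reference set), mirroring hyps / [refs] above.
hyps = ["returns the sum of two numbers", "opens a file and reads it"]
refs = ["return the sum of two numbers", "open a file and read its contents"]

# default tokenization, as in the eval_tokenized_bleu=False branch
print(sacrebleu.corpus_bleu(hyps, [refs]).score)

# inputs already tokenized, as in the eval_tokenized_bleu=True branch
print(sacrebleu.corpus_bleu(hyps, [refs], tokenize='none').score)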
Example no. 5
    def setup_task(cls, args, **kwargs):
        """Setup the task.
        """
        # paths = args.data.split(':')
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        dictionary = Dictionary.load(os.path.join(paths[0], 'dict.jsonl'))

        data_path = paths[0]
        if args['task']['langs'] is None:
            languages = sorted([
                name for name in os.listdir(data_path)
                if os.path.isdir(os.path.join(data_path, name))
            ])
        else:
            languages = args['task']['langs']  # .split(',')

        if args['task']['add_lang_token']:
            for lang in languages:
                dictionary.add_symbol('[{}]'.format(lang))

        LOGGER.info("Loading dictionary: {} types".format(len(dictionary)))
        # if not hasattr(args, 'shuffle_instance'):
        #     args.shuffle_instance = False
        return cls(args, dictionary)
Example no. 6
def gnn_tensorize(datapoint, source_dictionary, edge_types):
    tensorized_data = TensorizedGraphData(
        adjacency_lists=list(__iterate_edge_types(datapoint, edge_types)),
        node_tensorized_data=[
            # enforce_not_None(self.__node_embedding_model.tensorize(ni))
            source_dictionary.index(ni) for ni in datapoint.node_information
        ],
        reference_nodes={
            n: np.array(refs, dtype=np.int32)
            for n, refs in datapoint.reference_nodes.items()
        },
        num_nodes=len(datapoint.node_information),
    )

    if tensorized_data.num_nodes > 80000:
        LOGGER.warning("Dropping graph with %s nodes." %
                       tensorized_data.num_nodes)
        return None

    num_edges = sum(len(adj) for adj in tensorized_data.adjacency_lists)
    if num_edges > 100000:
        LOGGER.warning("Dropping graph with %s edges." % num_edges)
        return None

    return tensorized_data
Example no. 7
    def binary_ast_fn(filename, dest_filename, idx, start=0, end=-1, *args):
        kwargs = args[0][0]  # cannot feed dict parameters in multi-processing

        dest_filename = dest_filename + str(idx)
        with file_io.open(filename, "r") as reader, file_io.open(dest_filename, 'w') as writer:
            reader.seek(start)
            line = safe_readline(reader)
            while line:
                if end > 0 and reader.tell() > end:
                    break
                ast = json_io.json_loads(line)
                if ast:
                    try:
                        ast = util_ast.value2children(ast)
                        ast = util_ast.remove_root_with_uni_child(ast)
                        root_idx = util_ast.get_root_idx(ast)
                        ast = util_ast.delete_node_with_uni_child(ast, idx=root_idx)
                        root_idx = util_ast.get_root_idx(ast)
                        bin_ast = util_ast.binarize_tree(ast, idx=root_idx)  # to binary ast tree
                        root_idx = util_ast.get_root_idx(ast)
                        bin_ast = util_ast.reset_indices(bin_ast, root_idx)  # reset node indices
                        bin_ast = util_ast.pad_leaf_node(bin_ast, MAX_SUB_TOKEN_LEN)
                    except RecursionError:
                        LOGGER.error('RecursionError, ignore this tree')
                        bin_ast = None
                    except Exception as err:
                        LOGGER.error(err)
                        bin_ast = None
                else:
                    bin_ast = None
                print(json_io.json_dumps(bin_ast), file=writer)
                line = safe_readline(reader)
Example no. 8
def torch_persistent_save(*args, **kwargs):
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                LOGGER.error(traceback.format_exc())

def check_alignment(alignment, src_len, tgt_len):
    if alignment is None or len(alignment) == 0:
        return False
    if alignment[:, 0].max().item() >= src_len - 1 or alignment[:, 1].max().item() >= tgt_len - 1:
        LOGGER.warning("alignment size mismatch found, skipping alignment!")
        return False
    return True
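For reference, a self-contained sketch of how the retry wrapper above can stand in for a plain torch.save call (the checkpoint path and the toy state dict are made up for illustration):

import logging
import traceback

import torch

LOGGER = logging.getLogger(__name__)

def torch_persistent_save(*args, **kwargs):
    # retry a few times; transient I/O failures are common on shared filesystems
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                LOGGER.error(traceback.format_exc())

# hypothetical usage: persist a small state dict to disk
state = {"step": 0, "weights": torch.zeros(2, 2)}
torch_persistent_save(state, "demo_checkpoint.pt")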
Example no. 10
def __collect_sample(ast, MAX_PATH: int):
    def _tokenize(s):
        pattern = re.compile(r"(?<!^)(?=[A-Z])")
        tokenized = pattern.sub("_", s).lower().split("_")
        return list(filter(None, tokenized))[:MAX_SUB_TOKEN_LEN]

    tree_paths = __raw_tree_paths(ast)
    contexts = []
    for tree_path in tree_paths:
        start, connector, finish = tree_path

        start = _tokenize(start)
        finish = _tokenize(finish)

        try:
            connector = [ast[connector[0]]['value']] + \
                        [ast[v]['type'] for v in connector[1:-1]] + \
                        [ast[connector[-1]]['value']]
        except Exception:
            # error path, skip it
            continue

        contexts.append([start, connector, finish])
    try:
        assert len(contexts) > 0, Exception('ast\'s path is None')
        if len(contexts) > MAX_PATH:
            shuffle(contexts)
            contexts = contexts[:MAX_PATH]
        return contexts
    except Exception as err:
        LOGGER.error(err)
        LOGGER.error(ast)
        return None
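The inner _tokenize helper splits camelCase identifiers into lower-cased sub-tokens. A self-contained sketch of the same regex trick, with MAX_SUB_TOKEN_LEN assumed to be a small constant such as 5:

import re

MAX_SUB_TOKEN_LEN = 5  # assumed cap on sub-tokens per identifier

def tokenize_identifier(s):
    # insert "_" before each inner capital letter, lower-case, then split
    pattern = re.compile(r"(?<!^)(?=[A-Z])")
    tokenized = pattern.sub("_", s).lower().split("_")
    return list(filter(None, tokenized))[:MAX_SUB_TOKEN_LEN]

print(tokenize_identifier("getFileName"))      # ['get', 'file', 'name']
print(tokenize_identifier("HTTPServerError"))  # ['h', 't', 't', 'p', 'server']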
Example no. 11
    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args['preprocess']['dataset_impl'] == "raw":
            in_file = file_name(input_prefix, lang)
            out_dir = args['preprocess']['destdir']
            os.makedirs(out_dir, exist_ok=True)
            LOGGER.info('Copying {} into {}'.format(in_file, out_dir))
            shutil.copy(src=in_file, dst=args['preprocess']['destdir'])
        else:
            in_file = file_name(input_prefix, lang)
            out_file = dest_path(output_prefix, lang)
            os.makedirs(os.path.dirname(out_file), exist_ok=True)
            offsets = find_offsets(in_file, num_workers)
            with Pool(num_workers) as mpool:
                results = [
                    mpool.apply_async(
                        build_dgl_graph,
                        (vocab, in_file, f'{out_file}{worker_id}.mmap',
                         offsets[worker_id], offsets[worker_id + 1]),
                    ) for worker_id in range(num_workers)
                ]
                results = [res.get() for res in results]
            graph_batch = []
            for worker_id in range(num_workers):
                sub_file = f'{out_file}{worker_id}.mmap'
                glist, _ = load_graphs(sub_file)
                graph_batch.extend(glist)
                os.remove(sub_file)
            save_graphs(f'{out_file}.mmap', graph_batch)
Example no. 12
    def _log_oom(self, exc):
        msg = "OOM: Ran out of memory with exception: {}".format(exc)
        LOGGER.warning(msg)
        if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
            for device_idx in range(torch.cuda.device_count()):
                LOGGER.warning(torch.cuda.memory_summary(device=device_idx))
        sys.stderr.flush()
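_log_oom leans on torch.cuda.memory_summary, which only yields useful output on CUDA builds. A guarded one-off call looks like this (device index 0 assumed):

import torch

if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"):
    # human-readable allocator statistics for the first GPU
    print(torch.cuda.memory_summary(device=0))
else:
    print("CUDA not available; nothing to summarize.")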
Example no. 13
    def filter_indices_by_size(
        self, indices, dataset, max_positions=None, ignore_invalid_inputs=False
    ):
        """
        Filter examples that are too large

        Args:
            indices (np.array): original array of sample indices
            dataset (~fairseq.data.FairseqDataset): dataset to batch
            max_positions (optional): max sentence length supported by the
                model (default: None).
            ignore_invalid_inputs (bool, optional): don't raise Exception for
                sentences that are too long (default: False).
        Returns:
            np.array: array of filtered sample indices
        """
        indices, ignored = dataset.filter_indices_by_size(indices, max_positions)
        if len(ignored) > 0:
            if not ignore_invalid_inputs:
                raise Exception(
                    (
                        "Size of sample #{} is invalid (={}) since max_positions={}, "
                        "skip this example with --skip-invalid-size-inputs-valid-test"
                    ).format(ignored[0], dataset.size(ignored[0]), max_positions)
                )
            LOGGER.warning(
                (
                    "{:,} samples have invalid sizes and will be skipped, "
                    "max_positions={}, first few sample ids={}"
                ).format(len(ignored), max_positions, ignored[:10])
            )
        return indices
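A toy illustration of the filtering contract described in the docstring, using plain numpy arrays instead of a FairseqDataset (the sizes and the cutoff are made up):

import numpy as np

sizes = np.array([12, 256, 48, 1024, 5])   # hypothetical example lengths
indices = np.arange(len(sizes))
max_positions = 512

kept = indices[sizes <= max_positions]
ignored = indices[sizes > max_positions]
print(kept.tolist())     # [0, 1, 2, 4]
print(ignored.tolist())  # [3]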
Example no. 14
def load_tokens_dataset(
    data_path, split, src, src_dict, tgt, tgt_dict, dataset_impl,
    max_source_positions=None, max_target_positions=None, max_positions=None,
    append_source_eos=False, append_target_eos=False,
    shuffle=False,
):
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(src_path, dataset_impl)
    if max_source_positions is not None:
        src_dataset = TruncateDataset(src_dataset, max_source_positions)
    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset), src_path))

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(tgt_path, dataset_impl)
    if max_target_positions is not None:
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))

    return BertDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset.sizes, tgt_dict,
        max_source_positions=max_source_positions, max_target_positions=max_target_positions,
        max_positions=max_positions,
        append_source_eos=append_source_eos, append_target_eos=append_target_eos,
        shuffle=shuffle,
    )
Example no. 15
def cli_main():
    import argparse
    parser = argparse.ArgumentParser(
        description=
        "Downloading/Decompressing CodeSearchNet dataset(s) or Tree-Sitter Library(ies)"
    )
    parser.add_argument("--yaml_file",
                        "-f",
                        type=str,
                        help="load {language}.yml for train",
                        default='config/csn_feng/ruby')
    parser.add_argument(
        '--out_file',
        '-o',
        type=str,
        help='output generated file',
        default=None,
    )
    args = parser.parse_args()
    yaml_file = os.path.join(os.path.dirname(__file__),
                             f"{args.yaml_file}.yml")
    out_file = None if args.out_file is None else recursive_expanduser(
        args.out_file)
    LOGGER.info('Load arguments in {}'.format(yaml_file))
    args = load_yaml(yaml_file)
    LOGGER.info(args)
    main(args, out_file)
Example no. 16
def cli_main():
    # modal_path = '~/.ncc/demo/summarization/neural_transformer/python_wan.pt'
    modal_path = '~/.ncc/demo/summarization/seq2seq/python_wan.pt'
    code = "def positional(max_positional_args):\n\tdef positional_decorator(wrapped):\n\t\[email protected](wrapped)\n\t\tdef positional_wrapper(*args, **kwargs):\n\t\t\tif (len(args) > max_posi      tional_args):\n\t\t\t\tplural_s = ''\n\t\t\t\tif (max_positional_args != 1):\n\t\t\t\t\tplural_s = 's'\n\t\t\t\tmessage = ('%s()\ttakes\tat\tmost\t%d\tpositional\targument%s\t(%d\tgive      n)' % (wrapped.__name__, max_positional_args, plural_s, len(args)))\n\t\t\t\tif (positional_parameters_enforcement == POSITIONAL_EXCEPTION):\n\t\t\t\t\traise TypeError(message)\n\t\t\t      \telif (positional_parameters_enforcement == POSITIONAL_WARNING):\n\t\t\t\t\tlogger.warning(message)\n\t\t\t\telse:\n\t\t\t\t\tpass\n\t\t\treturn wrapped(*args, **kwargs)\n\t\treturn p      ositional_wrapper\n\tif isinstance(max_positional_args, six.integer_types):\n\t\treturn positional_decorator\n\telse:\n\t\t(args, _, _, defaults) = inspect.getargspec(max_positional_ar      gs)\n\t\treturn positional((len(args) - len(defaults)))(max_positional_args)"
    # ground truth: "a decorator to declare that only the first n arguments my be positional ."

    # modal_path = '~/.ncc/demo/completion/seqrnn/py150.pt'
    # code = "body_content = self._serialize.body(parameters, 'ServicePrincipalCreateParameters')\nrequest = self._client.post(url, query_parameters)\nresponse = self._client.send( request, header_parameters, body_content, operation_config)"
    # ground truth: "(request, header_parameters, body_content, **operation_config)"

    import argparse
    parser = argparse.ArgumentParser(description="Command Interface")
    parser.add_argument("--model",
                        "-m",
                        type=str,
                        help="pytorch model path",
                        default=modal_path)
    parser.add_argument("--input",
                        "-i",
                        type=str,
                        help="model input",
                        default=code)
    args = parser.parse_args()
    args.model = os.path.expanduser(args.model)

    model_output = main(args.model, args.input)
    LOGGER.info(model_output)
Example no. 17
def load_langpair_dataset(
    data_path, split,
    src, src_dict,
    tgt, tgt_dict,
    dataset_impl,

    left_pad_source, left_pad_target,
    max_source_positions, max_target_positions,
    prepend_bos=False, load_alignments=False,
    truncate_source=False, append_source_id=False,
    truncate_target=False,
    append_eos_to_target=False,
    portion=None,
):
    src_path = os.path.join(data_path, '{}.{}'.format(split, src))
    src_dataset = _load_dataset(path=src_path, impl=dataset_impl, dict=src_dict)

    if portion is not None and split == 'train':
        LOGGER.info('set {}.{} portion to {}'.format(split, src, portion))
        src_dataset = PortionDataset(src_dataset, portion)

    tgt_path = os.path.join(data_path, '{}.{}'.format(split, tgt))
    tgt_dataset = _load_dataset(path=tgt_path, impl=dataset_impl, dict=tgt_dict)
    if truncate_target:
        LOGGER.info('truncate {}.{} to {}'.format(split, tgt, max_target_positions))
        tgt_dataset = TruncateDataset(tgt_dataset, max_target_positions)

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())

    eos = None
    if append_source_id:
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(tgt_dataset, tgt_dict.index('[{}]'.format(tgt)))
        eos = tgt_dict.index('[{}]'.format(tgt))

    if portion is not None and split == 'train':
        LOGGER.info('set {}.{} portion to {}'.format(split, tgt, portion))
        tgt_dataset = PortionDataset(tgt_dataset, portion)

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None

    LOGGER.info('loaded {} examples from: {}'.format(len(src_dataset), src_path))
    LOGGER.info('loaded {} examples from: {}'.format(len(tgt_dataset), tgt_path))
    return GraphLanguagePairDataset(
        src_dataset, src_dataset.sizes, src_dict,
        tgt_dataset, tgt_dataset_sizes, tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        max_source_positions=max_source_positions,
        max_target_positions=max_target_positions,
        align_dataset=None, eos=eos,
        remove_eos_from_source=True,
        append_eos_to_target=append_eos_to_target,
        shuffle=True,
    )
Example no. 18
    def build_dataset(args: Dict, src_dicts: Dict[str, Dictionary],
                      tgt_dict: Dictionary):
        """build dataset for each modality"""
        for modality, src_dict in src_dicts.items():
            LOGGER.info('Building dataset for {}'.format(modality))
            for lang, data_prefs in args['preprocess']['dataprefs'].items():
                make_all(modality, src_dict, lang, data_prefs)
Example no. 19
    def make_binary_dataset(vocab: Dictionary, input_file, output_file,
                            attr: str, num_workers: int):
        """make binary dataset"""
        LOGGER.info("[{}] Dictionary: {} types".format(attr, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()  # save un-recorded tokens

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        # split the file into chunks by byte offset
        # with multi-processing, worker processes handle the 2nd..last chunks
        # e.g. 1.txt with 10 workers: p0 -> bytes 0-99, p1 -> bytes 100-199, ...
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # p1-pN -> (1 bin-txt, 1 idx), (N bin-txt, N idx)
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_file, worker_id)
                pool.apply_async(binarize,
                                 (args, input_file, vocab, prefix, attr,
                                  offsets[worker_id], offsets[worker_id + 1]),
                                 callback=merge_result)
            pool.close()
        # the main process handles the 1st chunk; without multi-processing it handles the whole file
        # p0 -> [0, end)
        ds_file = '{}.mmap'.format(output_file)
        ds = indexed_dataset.make_builder(
            ds_file,
            impl=args['preprocess']['dataset_impl'],
            vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize_bpe(input_file,
                                   vocab,
                                   lambda t: ds.add_item(t),
                                   offset=0,
                                   end=offsets[1]))
        if num_workers > 1:
            # p1-pN
            pool.join()
            # merge sub-processors' index and data files into final files and delete them.
            for worker_id in range(1, num_workers):
                temp_file_path = "{}{}".format(output_file, worker_id)
                ds.merge_file_(temp_file_path)
                # idx, txt
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))
        ds.finalize('{}.idx'.format(output_file))

        LOGGER.info(
            "[{}] {}: {} sents, {} tokens, BPE no replaced token".format(
                attr,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
            ))
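make_binary_dataset splits the input file by byte offsets so each worker binarizes its own chunk. A rough, self-contained sketch of that idea (not the actual Binarizer.find_offsets implementation): seek to an even split point, then advance to the next line boundary so no chunk cuts a line in half.

import os

def find_line_offsets(path, num_chunks):
    # return num_chunks + 1 byte offsets delimiting line-aligned chunks
    size = os.path.getsize(path)
    offsets = [0]
    with open(path, "rb") as f:
        for i in range(1, num_chunks):
            f.seek(size * i // num_chunks)
            f.readline()  # move to the start of the next line
            offsets.append(f.tell())
    offsets.append(size)
    return offsets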
Example no. 20
    def setup_task(cls, args, **kwargs):
        """Setup the task.
        """
        dictionary = Dictionary.load(os.path.join(args.data, 'dict.txt'))
        LOGGER.info('dictionary: {} types'.format(len(dictionary)))
        if not hasattr(args, 'shuffle_instance'):
            args.shuffle_instance = False
        return cls(args, dictionary)
Example no. 21
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        dictionary = cls.load_dictionary(
            os.path.join(paths[0],
                         'dict.{}.json'.format(args['task']['source_lang'])))
        LOGGER.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
Example no. 22
def normalize_program(fn, **kwargs):
    if not isinstance(fn, (str, bytes)):
        LOGGER.error(f"normalize_program got non-str: {type(fn)}, {fn}")
    fn = NEWLINE_REGEX.sub(rf" {constants.EOL}", fn)
    if kwargs.get('remove_eol', False):
        fn = str.replace(fn, constants.EOL, ' ')
    fn = WHITESPACE_REGEX.sub(" ", fn)
    return fn
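normalize_program depends on NEWLINE_REGEX, WHITESPACE_REGEX and constants.EOL defined elsewhere; a self-contained sketch with assumed definitions (the real patterns and sentinel token may differ):

import re

EOL = "<EOL>"                              # assumed end-of-line sentinel
NEWLINE_REGEX = re.compile(r"\r?\n")       # assumed newline pattern
WHITESPACE_REGEX = re.compile(r"[ \t]+")   # assumed whitespace pattern

def normalize_program_demo(fn, remove_eol=False):
    # replace newlines with the EOL token, optionally drop it again,
    # then collapse runs of spaces/tabs into a single space
    fn = NEWLINE_REGEX.sub(f" {EOL}", fn)
    if remove_eol:
        fn = fn.replace(EOL, " ")
    fn = WHITESPACE_REGEX.sub(" ", fn)
    return fn

print(normalize_program_demo("def f(x):\n    return x"))
# def f(x): <EOL> return x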
Example no. 23
    def valid_step(self, sample, raise_oom=False):
        """Do forward pass in evaluation mode."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = sample

        with torch.no_grad():
            self.model.eval()
            self.criterion.eval()

            sample = self._prepare_sample(sample)
            if sample is None:
                sample = self._prepare_sample(self._dummy_batch)
                is_dummy_batch = True
            else:
                is_dummy_batch = False

            try:
                _loss, sample_size, logging_output = self.task.valid_step(
                    sample, self.model, self.criterion)
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if not raise_oom:
                        LOGGER.warning(
                            "ran out of memory in validation step, retrying batch"
                        )
                        for p in self.model.parameters():
                            if p.grad is not None:
                                p.grad = None  # free some memory
                        if self.cuda:
                            torch.cuda.empty_cache()
                        return self.valid_step(sample, raise_oom=True)
                raise e

            logging_outputs = [logging_output]
            if is_dummy_batch:
                sample_size *= 0  # multiply by 0 to preserve device

        # gather logging outputs from all replicas
        if self.args['distributed_training']['distributed_world_size'] > 1:
            logging_outputs, (sample_size, ) = self._aggregate_logging_outputs(
                logging_outputs,
                sample_size,
                ignore=is_dummy_batch,
            )
            if 'bleu' in logging_outputs[0]:
                logging_outputs[0]['bleu'] /= self.args[
                    'distributed_training']['distributed_world_size']
            if 'rouge_l' in logging_outputs[0]:
                logging_outputs[0]['rouge_l'] /= self.args[
                    'distributed_training']['distributed_world_size']
            if 'meteor' in logging_outputs[0]:
                logging_outputs[0]['meteor'] /= self.args[
                    'distributed_training']['distributed_world_size']
        # log validation stats
        logging_output = self._reduce_and_log_stats(logging_outputs,
                                                    sample_size)
        return logging_output
Example no. 24
    def setup_task(cls, args, **kwargs):
        paths = utils.split_paths(args['task']['data'])
        assert len(paths) > 0
        # dictionary = cls.load_dictionary(os.path.join(paths[0], 'codesearchnet_ruby.dict.txt'))
        dictionary = cls.load_dictionary(
            os.path.join(paths[0], 'csnjs_8k_9995p_unigram_url.dict.txt'))
        # dictionary = cls.load_dictionary(args['dataset']['srcdict'])
        LOGGER.info('dictionary: {} types'.format(len(dictionary)))
        return cls(args, dictionary)
Example no. 25
    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        paths = utils.split_paths(self.args.data)
        assert len(paths) > 0
        data_path = paths[(epoch - 1) % len(paths)]
        split_path = os.path.join(data_path, split)

        dataset = data_utils.load_indexed_dataset(
            split_path,
            self.dictionary,
            self.args.dataset_impl,
            combine=combine,
        )
        if dataset is None:
            raise FileNotFoundError('Dataset not found: {} ({})'.format(
                split, split_path))

        dataset = StripTokenDataset(dataset, self.dictionary.eos())

        # create continuous blocks of tokens
        dataset = TokenBlockDataset(
            dataset,
            dataset.sizes,
            self.args.tokens_per_sample -
            2,  # one less for <s> and one for </s>
            pad=self.dictionary.pad(),
            eos=self.dictionary.eos(),
            break_mode=self.args.sample_break_mode,
            document_sep_len=0)

        # prepend beginning-of-sentence token (<s>, equiv. to [CLS] in BERT)
        dataset = PrependTokenDataset(dataset, self.source_dictionary.bos())
        dataset = AppendTokenDataset(dataset, self.source_dictionary.eos())

        mask_whole_words = get_whole_word_mask(self.args, self.source_dictionary) \
            if self.args.mask_length != 'subword' else None

        self.datasets[split] = DenoisingDataset(
            dataset,
            dataset.sizes,
            self.dictionary,
            self.mask_idx,
            mask_whole_words,
            shuffle=self.args.shuffle_instance,
            seed=self.seed,
            args=self.args)
        LOGGER.info(
            "Split: {0}, Loaded {1} samples of denoising_dataset".format(
                split,
                len(self.datasets[split]),
            ))
Example no. 26
    def __init__(self, args, params):
        super().__init__(args)
        fused_adam_cls = get_fused_adam_class()
        use_fused_adam = (not args['optimization']['adam']['use_old_adam']
                          and fused_adam_cls is not None
                          and torch.cuda.is_available())
        if use_fused_adam:
            LOGGER.info('using FusedAdam')
            self._optimizer = fused_adam_cls(params, **self.optimizer_config)
        else:
            self._optimizer = Adam(params, **self.optimizer_config)
Example no. 27
def flatten(raw_file, dst_dir, mode):
    """flatten attributes of raw data"""
    data_frame = pd.read_csv(raw_file)
    attrs = data_frame.columns.values.tolist()[1:-1]
    LOGGER.info('Cast attributes({}) of OpenCL-{} dataset'.format(attrs, lang))
    for attr in attrs:
        dst_file = os.path.join(dst_dir, f"{mode}.{attr}")
        data = getattr(data_frame, attr).values.tolist()
        with file_io.open(dst_file, 'w') as writer:
            for line in data:
                print(json_io.json_dumps(line), file=writer)
Example no. 28
    def make_dataset(vocab, input_prefix, output_prefix, lang, num_workers=1):
        if args['preprocess']['dataset_impl'] == "raw":
            in_file = file_name(input_prefix, lang)
            out_dir = args['preprocess']['destdir']
            os.makedirs(out_dir, exist_ok=True)
            LOGGER.info('Copying {} into {}'.format(in_file, out_dir))
            shutil.copy(src=in_file, dst=args['preprocess']['destdir'])
        else:
            in_file = file_name(input_prefix, lang)
            out_file = dest_path(output_prefix, lang)
            os.makedirs(os.path.dirname(out_file), exist_ok=True)
            make_binary_dataset(vocab, in_file, out_file, num_workers)
Example no. 29
def download(name):
    if name in TREE_SITTER_SO_FILE_ARCHIVE_MAP:
        url = TREE_SITTER_SO_FILE_ARCHIVE_MAP[name]
        LOGGER.info(f"Download {name}.so from {url}")
        gdown.download(url=url,
                       output=os.path.join(__TREE_SITTER_LIBS_DIR__,
                                           f"{name}.so"))
    else:
        raise FileExistsError(
            f"{name}.so has not been uploaded to the server. Please build {name}.so with "
            f"{os.path.dirname(__file__)}/build_so.py"
        )
Example no. 30
def flatten(raw_dir, lang, mode, flatten_dir, attrs, num_cores):
    """flatten attributes of raw data"""
    LOGGER.info('Cast attributes({}) of {}-{} dataset'.format(
        attrs, lang, mode))
    with Pool(num_cores) as mpool:
        result = [
            mpool.apply_async(flatten_attrs,
                              (raw_file, flatten_dir, lang, mode, set(attrs)))
            for raw_file in PathManager.ls(
                os.path.join(raw_dir, lang, mode, '*.jsonl.gz'))
        ]
        result = [res.get() for res in result]
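The Pool / apply_async / get pattern used throughout these preprocessing helpers, reduced to a self-contained toy (the worker function here is made up):

from multiprocessing import Pool

def square(x):
    # stand-in for flatten_attrs: any picklable function works
    return x * x

if __name__ == "__main__":
    with Pool(4) as mpool:
        result = [mpool.apply_async(square, (n,)) for n in range(8)]
        result = [res.get() for res in result]  # block until all workers finish
    print(result)  # [0, 1, 4, 9, 16, 25, 36, 49]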