Example #1
def torch_persistent_save(*args, **kwargs):
    # retry torch.save up to three times; log the traceback if the last attempt fails
    for i in range(3):
        try:
            return torch.save(*args, **kwargs)
        except Exception:
            if i == 2:
                LOGGER.error(traceback.format_exc())
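A minimal usage sketch for the retrying helper above, assuming it is defined in the current module together with its LOGGER and traceback imports; the model and checkpoint path are placeholders.

import torch

model = torch.nn.Linear(4, 2)                             # placeholder model
checkpoint = {"model": model.state_dict(), "step": 1000}
torch_persistent_save(checkpoint, "checkpoint_last.pt")   # retries torch.save up to 3 times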
Example #2
def __collect_sample(ast, MAX_PATH: int):
    def _tokenize(s):
        pattern = re.compile(r"(?<!^)(?=[A-Z])")
        tokenized = pattern.sub("_", s).lower().split("_")
        return list(filter(None, tokenized))[:MAX_SUB_TOKEN_LEN]

    tree_paths = __raw_tree_paths(ast)
    contexts = []
    for tree_path in tree_paths:
        start, connector, finish = tree_path

        start = _tokenize(start)
        finish = _tokenize(finish)

        try:
            connector = [ast[connector[0]]['value']] + \
                        [ast[v]['type'] for v in connector[1:-1]] + \
                        [ast[connector[-1]]['value']]
        except Exception:
            # error path, skip it
            continue

        contexts.append([start, connector, finish])
    try:
        assert len(contexts) > 0, Exception('ast\'s path is None')
        if len(contexts) > MAX_PATH:
            shuffle(contexts)
            contexts = contexts[:MAX_PATH]
        return contexts
    except Exception as err:
        LOGGER.error(err)
        LOGGER.error(ast)
        return None
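The inner _tokenize splits identifiers into sub-tokens with a lookahead regex. A standalone sketch of that step, where a local max_len parameter stands in for MAX_SUB_TOKEN_LEN:

import re

pattern = re.compile(r"(?<!^)(?=[A-Z])")

def split_subtokens(name, max_len=5):      # max_len stands in for MAX_SUB_TOKEN_LEN
    parts = pattern.sub("_", name).lower().split("_")
    return list(filter(None, parts))[:max_len]

print(split_subtokens("getFileName"))      # ['get', 'file', 'name']
print(split_subtokens("HTTPServer"))       # ['h', 't', 't', 'p', 'server']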
Example #3
    def binary_ast_fn(filename, dest_filename, idx, start=0, end=-1, *args):
        kwargs = args[0][0]  # cannot pass dict parameters through multi-processing

        dest_filename = dest_filename + str(idx)
        with file_io.open(filename, "r") as reader, file_io.open(dest_filename, 'w') as writer:
            reader.seek(start)
            line = safe_readline(reader)
            while line:
                if end > 0 and reader.tell() > end:
                    break
                ast = json_io.json_loads(line)
                if ast:
                    try:
                        ast = util_ast.value2children(ast)
                        ast = util_ast.remove_root_with_uni_child(ast)
                        root_idx = util_ast.get_root_idx(ast)
                        ast = util_ast.delete_node_with_uni_child(ast, idx=root_idx)
                        root_idx = util_ast.get_root_idx(ast)
                        bin_ast = util_ast.binarize_tree(ast, idx=root_idx)  # to binary ast tree
                        root_idx = util_ast.get_root_idx(ast)
                        bin_ast = util_ast.reset_indices(bin_ast, root_idx)  # reset node indices
                        bin_ast = util_ast.pad_leaf_node(bin_ast, MAX_SUB_TOKEN_LEN)
                    except RecursionError:
                        LOGGER.error('RecursionError, ignore this tree')
                        bin_ast = None
                    except Exception as err:
                        LOGGER.error(err)
                        bin_ast = None
                else:
                    bin_ast = None
                print(json_io.json_dumps(bin_ast), file=writer)
                line = safe_readline(reader)
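binary_ast_fn processes one byte-offset chunk of a JSON-lines file so several workers can share it. A minimal sketch of that chunking pattern, assuming safe_readline simply aligns the reader to the next complete line; handle is a hypothetical per-line callback:

def process_chunk(path, start, end, handle):
    with open(path, "rb") as reader:       # binary mode gives true byte offsets
        reader.seek(start)
        if start > 0:
            reader.readline()              # discard the partial line at the seek point
        line = reader.readline()
        while line:
            if end > 0 and reader.tell() > end:
                break
            handle(line.decode("utf-8"))
            line = reader.readline()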
Example #4
def normalize_program(fn, **kwargs):
    if not isinstance(fn, (str, bytes)):
        LOGGER.error(f"normalize_program got non-str: {type(fn)}, {fn}")
    fn = NEWLINE_REGEX.sub(rf" {constants.EOL}", fn)
    if kwargs.get('remove_eol', False):
        fn = str.replace(fn, constants.EOL, ' ')
    fn = WHITESPACE_REGEX.sub(" ", fn)
    return fn
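A standalone sketch of the normalization above; NEWLINE_REGEX, WHITESPACE_REGEX, and constants.EOL are defined elsewhere in the codebase, so the definitions below are assumptions for illustration only.

import re

NEWLINE_REGEX = re.compile(r"\r?\n")       # assumed definition
WHITESPACE_REGEX = re.compile(r"\s+")      # assumed definition
EOL = "<EOL>"                              # assumed end-of-line token

def normalize(src, remove_eol=False):
    src = NEWLINE_REGEX.sub(f" {EOL}", src)
    if remove_eol:
        src = src.replace(EOL, " ")
    return WHITESPACE_REGEX.sub(" ", src)

print(normalize("def f(x):\n    return x"))    # 'def f(x): <EOL> return x'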
Example #5
def download(name):
    if name in BPE_MODEL_ARCHIVE_MAP:
        url = BPE_MODEL_ARCHIVE_MAP[name]
        LOGGER.info(f"Download {name} BPE model from {url}")
        out_file = os.path.join(__BPE_DIR__, f"{name}.tar.gz")
        gdown.download(url=url, output=out_file)
        try:
            with tarfile.open(out_file) as reader:
                reader.extractall(__BPE_DIR__)
            os.remove(out_file)
        except tarfile.ExtractError as err:
            LOGGER.error(__BPE_DIR__)
            LOGGER.warning(f"{name}.tar.gz is corrupted, please contact us.")
    else:
        raise FileExistsError(f"No {name}.tar.gz on the server. Please build your own BPE models. " \
                              f"Once they are built, you can upload them to the server.")
Example #6
    def _check(library: Library):
        lowest_version = parse_version(library.version)
        if library.name == 'Python':
            version = platform.python_version()
        else:
            try:
                lib = importlib.import_module(library.key)
            except ImportError:
                LOGGER.error(
                    f"[{library.name}] is not installed, please install it.")
                return

            try:
                version = lib.__version__
            except Exception as err:
                LOGGER.error(
                    f"Cannot get version of [{library.key}], please check it via \"pip list\""
                )
                return
        current_version = parse_version(version)
        try:
            assert current_version >= lowest_version
            print(
                f"[{library.name}] version({version}) >= required version({library.version})."
            )
        except AssertionError as err:
            LOGGER.error(
                f"{'!' * 5} [{library.name}] version({version}) < required version({library.version}). {'!' * 5}"
            )
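The check relies on parse_version comparing version strings component-wise rather than lexically. A minimal stand-in using packaging, one plausible provider of parse_version (the snippet does not say where it is imported from):

from packaging.version import parse as parse_version

installed = parse_version("1.10.2")
required = parse_version("1.9.0")
print(installed >= required)    # True: 1.10.2 > 1.9.0 component-wise, unlike plain string comparison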
Example #7
def main(args):
    task = tasks.get_task(args['preprocess']['task'])
    LOGGER.info('mkdir for {} task'.format(args['preprocess']['task']))
    os.makedirs(args['preprocess']['destdir'], exist_ok=True)

    def file_name(prefix, lang):
        fname = prefix
        if lang is not None:
            fname += ".{lang}".format(lang=lang)
        return fname

    def dest_path(prefix, lang):
        return os.path.join(args['preprocess']['destdir'],
                            file_name(prefix, lang))

    def dict_path(lang):
        return dest_path("dict", lang) + ".txt"

    target = not args['preprocess']['only_source']

    # 1. build vocabulary from bpe directory
    if not args['preprocess']['srcdict'] and os.path.exists(
            dict_path(args['preprocess']['source_lang'])):
        raise FileExistsError(dict_path(args['preprocess']['source_lang']))
    if target and not args['preprocess']['tgtdict'] and os.path.exists(
            dict_path(args['preprocess']['target_lang'])):
        raise FileExistsError(dict_path(args['preprocess']['target_lang']))

    if args['preprocess']['joined_dictionary']:
        assert not args['preprocess']['srcdict'] or not args['preprocess']['tgtdict'], \
            "cannot use both --srcdict and --tgtdict with --joined-dictionary"
        if args['preprocess']['srcdict']:
            src_dict = task.load_dictionary(args['preprocess']['srcdict'])
        elif args['preprocess']['tgtdict']:
            src_dict = task.load_dictionary(args['preprocess']['tgtdict'])
        else:
            LOGGER.error(
                'Please run sentencepiece to generate the model and vocab files first.'
            )
            exit()

        tgt_dict = src_dict

        # Load sentencepiece (sp) module
        if args['preprocess']['src_sp']:
            src_sp = spm.SentencePieceProcessor()
            src_sp.load(args['preprocess']['src_sp'])
        elif args['preprocess']['tgt_sp']:
            src_sp = spm.SentencePieceProcessor()
            src_sp.load(args['preprocess']['tgt_sp'])
        else:
            LOGGER.error('Please assign the sentencepiece model path.')
            exit()
        tgt_sp = src_sp

    else:
        if args['preprocess']['srcdict'] and args['preprocess']['src_sp']:
            src_dict = task.load_dictionary(args['preprocess']['srcdict'])
            src_sp = spm.SentencePieceProcessor()
            src_sp.load(args['preprocess']['src_sp'])
        else:
            LOGGER.error(
                'Please run sentencepiece to generate the model and vocab files first.'
            )
            exit()

        if target:
            if args['preprocess']['tgtdict'] and args['preprocess']['tgt_sp']:
                tgt_dict = task.load_dictionary(args['preprocess']['tgtdict'])
                tgt_sp = spm.SentencePieceProcessor()
                tgt_sp.load(args['preprocess']['tgt_sp'])
            else:
                # assert args['preprocess']['trainpref'], "--trainpref must be set if --tgtdict is not specified"
                # tgt_dict = build_dictionary([train_path(args['preprocess']['target_lang'])], tgt=True)
                LOGGER.error(
                    'Please run sentencepiece to generate the model and vocab files first.'
                )
                exit()
        else:
            tgt_dict = None
            tgt_sp = None
    # exit()
    # 2. ***************build dataset********************
    def make_binary_dataset(vocab: Dictionary, input_file, output_file,
                            attr: str, num_workers: int):
        """make binary dataset"""
        LOGGER.info("[{}] Dictionary: {} types".format(attr, len(vocab) - 1))
        n_seq_tok = [0, 0]
        replaced = Counter()  # save un-recorded tokens

        def merge_result(worker_result):
            replaced.update(worker_result["replaced"])
            n_seq_tok[0] += worker_result["nseq"]
            n_seq_tok[1] += worker_result["ntok"]

        # split the input file into byte-offset chunks;
        # with multi-processing, workers 1..N-1 handle the 2nd to the last chunk,
        # e.g. 10 workers: p0 gets bytes 0-99, p1 gets 100-199, ...
        offsets = Binarizer.find_offsets(input_file, num_workers)
        pool = None
        if num_workers > 1:
            # p1-pN -> (1 bin-txt, 1 idx), (N bin-txt, N idx)
            pool = Pool(processes=num_workers - 1)
            for worker_id in range(1, num_workers):
                prefix = "{}{}".format(output_file, worker_id)
                pool.apply_async(binarize,
                                 (args, input_file, vocab, prefix, attr,
                                  offsets[worker_id], offsets[worker_id + 1]),
                                 callback=merge_result)
            pool.close()
        # the main process handles the first chunk (p0: offset 0 to offsets[1]);
        # without multi-processing it handles the whole file
        ds_file = '{}.mmap'.format(output_file)
        ds = indexed_dataset.make_builder(
            ds_file,
            impl=args['preprocess']['dataset_impl'],
            vocab_size=len(vocab))
        merge_result(
            Binarizer.binarize_bpe(input_file,
                                   vocab,
                                   lambda t: ds.add_item(t),
                                   offset=0,
                                   end=offsets[1]))
        if num_workers > 1:
            # p1-pN
            pool.join()
            # merge sub-processors' index and data files into final files and delete them.
            for worker_id in range(1, num_workers):
                temp_file_path = "{}{}".format(output_file, worker_id)
                ds.merge_file_(temp_file_path)
                # idx, txt
                os.remove(indexed_dataset.data_file_path(temp_file_path))
                os.remove(indexed_dataset.index_file_path(temp_file_path))
        ds.finalize('{}.idx'.format(output_file))

        LOGGER.info(
            "[{}] {}: {} sents, {} tokens, BPE no replaced token".format(
                attr,
                input_file,
                n_seq_tok[0],
                n_seq_tok[1],
            ))

    def make_dataset(vocab,
                     sp,
                     input_prefix,
                     output_prefix,
                     lang,
                     num_workers=1):
        if args['preprocess']['dataset_impl'] == 'raw':
            with open(file_name(input_prefix, lang), 'rb') as input_file, open(
                    dest_path(output_prefix,
                              lang), 'w', encoding="utf-8") as output_file:
                for line in input_file.readlines(
                )[0:100]:  # TODO only for debug
                    line = ujson.loads(line)
                    line = normalize_program(line)
                    line = sp.EncodeAsPieces(line)
                    output_file.write(ujson.dumps(line) + '\n')
        else:
            in_file = file_name(input_prefix, lang)
            out_file = dest_path(output_prefix, lang)
            os.makedirs(os.path.dirname(out_file), exist_ok=True)
            make_binary_dataset(vocab, in_file, out_file, lang, num_workers)

    def make_all(lang, vocab, sp):
        if args['preprocess']['trainpref']:
            make_dataset(vocab,
                         sp,
                         args['preprocess']['trainpref'],
                         "train",
                         lang,
                         num_workers=args['preprocess']['workers'])
        if args['preprocess']['validpref']:
            for k, validpref in enumerate(
                    args['preprocess']['validpref'].split(",")):
                outprefix = "valid{}".format(k) if k > 0 else "valid"
                make_dataset(vocab,
                             sp,
                             validpref,
                             outprefix,
                             lang,
                             num_workers=args['preprocess']['workers'])
        if args['preprocess']['testpref']:
            for k, testpref in enumerate(
                    args['preprocess']['testpref'].split(",")):
                outprefix = "test{}".format(k) if k > 0 else "test"
                make_dataset(vocab,
                             sp,
                             testpref,
                             outprefix,
                             lang,
                             num_workers=args['preprocess']['workers'])

    # 2. build dataset
    make_all(args['preprocess']['source_lang'], src_dict, src_sp)
    if target:
        make_all(args['preprocess']['target_lang'], tgt_dict, tgt_sp)
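For the 'raw' dataset path, make_dataset just runs each program line through a SentencePiece model. A minimal sketch of that step; the model path is a placeholder for the file passed as args['preprocess']['src_sp']:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("bpe/python.model")                        # hypothetical model file
pieces = sp.EncodeAsPieces("def add(a, b): return a + b")
print(pieces)                                      # list of BPE sub-word pieces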
Example #8
    def train_step(self, samples, raise_oom=False):
        """Do forward, backward and parameter update."""
        if self._dummy_batch == "DUMMY":
            self._dummy_batch = samples[0]

        self._set_seed()
        self.model.train()
        self.criterion.train()
        self.zero_grad()

        metrics.log_start_time("train_wall", priority=800, round=0)

        # forward and backward pass
        logging_outputs, sample_size, ooms = [], 0, 0
        for i, sample in enumerate(samples):
            sample, is_dummy_batch = self._prepare_sample(sample)

            def maybe_no_sync():
                """
                Whenever *samples* contains more than one mini-batch, we
                want to accumulate gradients locally and only call
                all-reduce in the last backwards pass.
                """
                if (self.args['distributed_training']['distributed_world_size']
                        > 1 and hasattr(self.model, "no_sync")
                        and i < len(samples) - 1):
                    return self.model.no_sync()
                else:
                    return contextlib.ExitStack()  # dummy contextmanager

            try:
                with maybe_no_sync():
                    # forward and backward
                    loss, sample_size_i, logging_output = self.task.train_step(
                        sample=sample,
                        model=self.model,
                        criterion=self.criterion,
                        optimizer=self.optimizer,
                        update_num=self.get_num_updates(),
                        ignore_grad=is_dummy_batch,
                    )
                    del loss

                logging_outputs.append(logging_output)
                sample_size += sample_size_i

                # emptying the CUDA cache after the first step can
                # reduce the chance of OOM
                if self.cuda and self.get_num_updates() == 0:
                    torch.cuda.empty_cache()
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self._log_oom(e)
                    if raise_oom:
                        raise e
                    LOGGER.warning(
                        "attempting to recover from OOM in forward/backward pass"
                    )
                    ooms += 1
                    self.zero_grad()
                    if self.cuda:
                        torch.cuda.empty_cache()
                    if self.args['distributed_training'][
                            'distributed_world_size'] == 1:
                        return None
                else:
                    raise e

        if is_dummy_batch:
            if torch.is_tensor(sample_size):
                sample_size.zero_()
            else:
                sample_size *= 0.0  # multiply by 0 to preserve device

        if torch.is_tensor(sample_size):
            sample_size = sample_size.float()
        else:
            sample_size = float(sample_size)

        # gather logging outputs from all replicas
        if self._sync_stats():
            logging_outputs, (sample_size,
                              ooms) = self._aggregate_logging_outputs(
                                  logging_outputs,
                                  sample_size,
                                  ooms,
                                  ignore=is_dummy_batch,
                              )

        overflow = False
        try:
            with torch.autograd.profiler.record_function("reduce-grads"):
                # reduce gradients across workers
                self.optimizer.all_reduce_grads(self.model)
                if utils.has_parameters(self.criterion):
                    self.optimizer.all_reduce_grads(self.criterion)

            with torch.autograd.profiler.record_function("multiply-grads"):
                # multiply gradients by (data_parallel_size / sample_size) since
                # DDP normalizes by the number of data parallel workers for
                # improved fp16 precision.
                # Thus we get (sum_of_gradients / sample_size) at the end.
                # In case of fp16, this step also undoes loss scaling.
                # (Debugging note: Some optimizers perform this scaling on the
                # fly, so inspecting model.parameters() or optimizer.params may
                # still show the original, unscaled gradients.)
                num = (
                    self.args['distributed_training']['distributed_world_size']
                    if not self.args['optimization']['use_bmuf']
                    or self._sync_stats() else 1)
                self.optimizer.multiply_grads(num / (sample_size or 1.0))

            with torch.autograd.profiler.record_function("clip-grads"):
                # clip grads
                grad_norm = self.clip_grad_norm(
                    self.args['optimization']['clip_norm'])

            # check that grad norms are consistent across workers
            if not self.args['optimization']['use_bmuf']:
                self._check_grad_norms(grad_norm)
            if not torch.isfinite(grad_norm).all():
                if self.args['common'].get('amp', False):
                    overflow = True
                else:
                    raise FloatingPointError("gradients are Nan/Inf")

            with torch.autograd.profiler.record_function("optimizer"):
                # take an optimization step
                self.optimizer.step()
                if self.args['common'].get('amp', False) and overflow:
                    if self._amp_retries == self.args['common'][
                            'amp_batch_retries']:
                        LOGGER.info("AMP: skipping this batch.")
                        self._amp_retries = 0
                    else:
                        self._amp_retries += 1
                        return self.train_step(
                            samples,
                            raise_oom)  # recursion to feed in same batch

        except FloatingPointError:
            # re-run the forward and backward pass with hooks attached to print
            # out where it fails
            self.zero_grad()
            with NanDetector(self._model):
                for _, sample in enumerate(samples):
                    sample, _ = self._prepare_sample(sample)
                    self.task.train_step(sample,
                                         self.model,
                                         self.criterion,
                                         self.optimizer,
                                         self.get_num_updates(),
                                         ignore_grad=False)
            raise
        except OverflowError as e:
            overflow = True
            LOGGER.info(
                f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}"
            )
            grad_norm = torch.tensor(0.0).cuda()
            self.zero_grad()
        except RuntimeError as e:
            if "out of memory" in str(e):
                self._log_oom(e)
                LOGGER.error("OOM during optimization, irrecoverable")
            raise e

        logging_output = None
        if not overflow:
            self.set_num_updates(self.get_num_updates() + 1)

            if self.cuda and self.cuda_env is not None:
                # log minimum free memory over the iteration
                gb_used = torch.cuda.max_memory_allocated(
                ) / 1024 / 1024 / 1024
                torch.cuda.reset_peak_memory_stats()
                gb_free = self.cuda_env.total_memory_in_GB - gb_used
                metrics.log_scalar("gb_free",
                                   gb_free,
                                   priority=1500,
                                   round=1,
                                   weight=0)

            # log stats
            logging_output = self._reduce_and_log_stats(
                logging_outputs,
                sample_size,
                grad_norm,
            )

            # clear CUDA cache to reduce memory fragmentation
            if (self.cuda and self.args['common']['empty_cache_freq'] > 0
                    and ((self.get_num_updates() +
                          self.args['common']['empty_cache_freq'] - 1) %
                         self.args['common']['empty_cache_freq']) == 0):
                torch.cuda.empty_cache()

        if self.args['common']['fp16'] or self.args['common'].get(
                'amp', False):
            metrics.log_scalar(
                "loss_scale",
                (self.optimizer.scaler.loss_scale
                 if self.args['common']['fp16'] else
                 self.optimizer.scaler.get_scale()),
                priority=700,
                round=4,
                weight=0,
            )

        metrics.log_stop_time("train_wall")
        return logging_output
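maybe_no_sync above is the standard DDP gradient-accumulation idiom: skip the all-reduce (model.no_sync()) for every mini-batch except the last one of the update. A condensed sketch, where model, samples, and compute_loss are hypothetical placeholders:

import contextlib

def accumulate(model, samples, compute_loss, world_size):
    for i, sample in enumerate(samples):
        if world_size > 1 and hasattr(model, "no_sync") and i < len(samples) - 1:
            ctx = model.no_sync()              # defer gradient synchronization
        else:
            ctx = contextlib.ExitStack()       # no-op context for the final mini-batch
        with ctx:
            loss = compute_loss(model, sample)
            loss.backward()                    # gradients accumulate locally until the last pass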