コード例 #1
0
ファイル: utils.py プロジェクト: CGCL-codes/naturalcc
def initialize_from_checkpoint(args, model):
    """Copy parameters from an initial checkpoint into ``model`` in place.

    The checkpoint path is read from ``args['checkpoint']['init_checkpoint']``.
    When the entry is set and the path exists, the checkpoint's ``'model'``
    mapping is loaded and every parameter whose name also appears in
    ``model.state_dict()`` is copied over. Otherwise only a log message is
    written and the model is left untouched.

    Args:
        args: nested config mapping containing a ``'checkpoint'`` sub-mapping.
        model: object exposing ``state_dict()``.
    """
    # Read the entry once with ``.get`` so a missing key cannot raise KeyError,
    # neither in the condition nor while formatting the log message below.
    init_checkpoint = args['checkpoint'].get('init_checkpoint', False)
    if not (init_checkpoint and PathManager.exists(init_checkpoint)):
        LOGGER.info(f"{init_checkpoint} does not exist.")
        return

    with open(init_checkpoint, 'rb') as reader:
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        state = torch.load(reader)
        pretrained_params = state['model']
        del state

    init_params = model.state_dict()
    for module_name, module_param in pretrained_params.items():
        if module_name not in init_params:
            # Parameter unknown to the current model: nothing to restore.
            continue
        if init_params[module_name].data.size() == module_param.data.size():
            init_params[module_name].data.copy_(module_param.data)
        else:
            # Shapes differ (presumably an embedding table whose vocabulary
            # size changed -- verify against callers): copy only the leading
            # rows along dim 0 that the checkpoint provides.
            token_num = module_param.size(0)
            init_params[module_name].data[:token_num, ...].copy_(module_param.data[:token_num, ...])
    LOGGER.info(f"Restore parameters from {init_checkpoint}.")
コード例 #2
0
    def __init__(self, SO_FILE, LANGUAGE, to_lower=False, operators_file=None):
        """Set up a tree-sitter parser for ``LANGUAGE`` and load the operators table.

        Args:
            SO_FILE: path of the compiled tree-sitter library to load.
            LANGUAGE: language name; ``'csharp'`` is mapped to ``'c_sharp'``
                before it is handed to tree-sitter.
            to_lower: stored on the instance as ``self.to_lower``.
            operators_file: path of a JSON file with operators; defaults to
                ``operators.json`` next to this module.
        """
        self.parser = Parser()
        # An explicit check replaces the former ``assert ..., FileExistsError(...)``:
        # a failing assert raises AssertionError (and is stripped under -O), so
        # the ``except FileExistsError`` fallback could never trigger the download.
        if not PathManager.exists(SO_FILE):
            LOGGER.warning(
                f"{SO_FILE} does not exist, automatically download TreeSitter parse file {LANGUAGE}.so."
            )
            from ncc.hub.tree_sitter.download import download
            download(LANGUAGE)

        # The download above uses the original name; tree-sitter gets the
        # remapped one.
        if LANGUAGE == 'csharp':
            LANGUAGE = 'c_sharp'
        self.parser.set_language(Language(SO_FILE, LANGUAGE))
        self.LANGUAGE = LANGUAGE
        self.to_lower = to_lower

        if operators_file is None:
            operators_file = os.path.join(os.path.dirname(__file__),
                                          'operators.json')
        with open(operators_file, 'r') as reader:
            self.operators = json_io.json_load(reader)
コード例 #3
0
def load_model_ensemble_and_task(filenames,
                                 arg_overrides=None,
                                 task=None,
                                 strict=True,
                                 suffix=''):
    """Load one model per checkpoint file and return them with their args and task.

    Args:
        filenames: iterable of checkpoint paths.
        arg_overrides: forwarded to ``load_checkpoint_to_cpu``.
        task: task used to build the models; when ``None`` it is created from
            the args stored in the first checkpoint.
        strict: forwarded to ``model.load_state_dict``.
        suffix: text inserted in front of ``".pt"`` in every filename.

    Returns:
        Tuple ``(ensemble, args, task)`` where ``ensemble`` is the list of
        loaded models and ``args`` are the args of the last checkpoint
        (``None`` when ``filenames`` is empty).

    Raises:
        IOError: if a (suffixed) checkpoint file does not exist.
    """
    ensemble = []
    # Bound up front so an empty ``filenames`` returns cleanly instead of
    # raising UnboundLocalError at the return statement.
    args = None
    for filename in filenames:
        # NOTE(review): str.replace rewrites every ".pt" occurrence in the
        # path, not only the trailing extension.
        filename = filename.replace(".pt", suffix + ".pt")
        if not PathManager.exists(filename):
            raise IOError("Model file not found: {}".format(filename))
        state = load_checkpoint_to_cpu(filename, arg_overrides)

        args = state["args"]
        if task is None:
            # Imported lazily: only needed when no task was supplied.
            from ncc import tasks
            task = tasks.setup_task(args)

        # build model for ensemble
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=strict, args=args)
        ensemble.append(model)
    return ensemble, args, task
コード例 #4
0
        "--attrs", "-a",
        default=[
            'ast', 'dfs',
            # 'edtree'
        ],
        # default=[ ],
        type=str, nargs='+', help="attrs: raw_ast, ...",
    )
    parser.add_argument(
        "--cores", "-c", default=cpu_count(), type=int, help="cpu cores for flatten raw data attributes",
    )
    args = parser.parse_args()

    # Maps each target attribute to the source attribute file it is built from.
    dest_raw_attrs = {
        'ast': 'code',
        'dfs': 'ast',
        'edtree': 'ast',
    }

    # For every (language, mode) pair, build each requested target attribute
    # file from its source attribute file in the same directory.
    for lang, mode in itertools.product(args.languages, MODES):
        for tgt_attr in args.attrs:
            src_attr = dest_raw_attrs[tgt_attr]
            src_filename = os.path.join(args.attributes_dir, lang, f"{mode}.{src_attr}")
            if PathManager.exists(src_filename):
                tgt_filename = os.path.join(args.attributes_dir, lang, f"{mode}.{tgt_attr}")
                LOGGER.info('Generating {}'.format(tgt_filename))
                process(src_filename, tgt_filename, num_workers=args.cores, lang=lang, so_dir=args.so_dir)
            else:
                # Source file is missing, so this target is skipped. The
                # message previously read "does exist", contradicting the branch.
                LOGGER.info('{} does not exist'.format(src_filename))
コード例 #5
0
 def exists(path):
     """Return whether the index file derived from ``path`` exists."""
     index_path = index_file_path(path)
     return PathManager.exists(index_path)
コード例 #6
0
from ncc import (
    __TREE_SITTER_LIBS_DIR__,
    LOGGER,
)
from ncc.utils.path_manager import PathManager

# define your config
YOUR_LANGUAGE = 'csharp'
TREE_SITTER_LIB_URL = 'https://github.com/tree-sitter/tree-sitter-c-sharp/archive/master.zip'
os.makedirs(__TREE_SITTER_LIBS_DIR__, exist_ok=True)
so_file = os.path.join(__TREE_SITTER_LIBS_DIR__, f'{YOUR_LANGUAGE}.so')

# download: remove any previous archive first so a fresh copy is fetched
lib_filename = os.path.join(__TREE_SITTER_LIBS_DIR__, f'{YOUR_LANGUAGE}.zip')
if PathManager.exists(lib_filename):
    PathManager.rm(lib_filename)
LOGGER.info(
    f"Download TreeSitter-{YOUR_LANGUAGE}-Parser from {TREE_SITTER_LIB_URL}")
wget.download(TREE_SITTER_LIB_URL, lib_filename)

# decompress
decompress_dir = os.path.join(__TREE_SITTER_LIBS_DIR__, 'tmp')
with zipfile.ZipFile(lib_filename, 'r') as zip_file:
    zip_file.extractall(path=decompress_dir)
    # Take the extracted root from the archive's own member list rather than
    # from os.listdir(): the shared 'tmp' directory may still hold entries
    # from earlier runs, and directory listing order is arbitrary.
    archive_roots = sorted({member.split('/', 1)[0] for member in zip_file.namelist()})
lib_dir = os.path.join(decompress_dir, archive_roots[0])

# build
LOGGER.info(
    f"Build {YOUR_LANGUAGE}.so, and save it at {__TREE_SITTER_LIBS_DIR__}")
Language.build_library(