def __init__(self, vocab_path, bpe_model_path, model_dir, max_prediction_length):
    bpe_processor = BpeProcessor(bpe_model_path)
    vocab = Vocab(vocab_path)
    sequentialization_client = AstSequentializationApiClient('localhost', 5555)
    # Run inference on the GPU when one is available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Build one prediction pipeline per checkpoint in model_dir, keyed by the
    # checkpoint file name without its '.ckpt' suffix.
    self._predictors = {
        model_file.split('.ckpt')[0]: PredictionPipeline(
            ThenSectionPredictor(
                GwtSectionPredictionTransformer.load_from_checkpoint(
                    f'{model_dir}/{model_file}',
                    strict=False,
                ).to(device).eval(),
                vocab.get_index(vocab.SOS_TOKEN),
                vocab.get_index(vocab.EOS_TOKEN),
                max_prediction_length,
            ),
            AstSequenceProcessor(sequentialization_client),
            bpe_processor,
            vocab,
        )
        for model_file in os.listdir(model_dir)
        if model_file.endswith('.ckpt')
    }
    self._sampler_loader = sampling.Loader(vocab)
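
# Usage sketch (hedged): the enclosing class is not shown in this excerpt, so the
# name `Predictor` below is assumed, as are the paths and the prediction length.
#
#     predictor = Predictor(
#         vocab_path='data/vocab.txt',
#         bpe_model_path='data/bpe.model',
#         model_dir='models',
#         max_prediction_length=256,
#     )
#
# This would load one PredictionPipeline per *.ckpt file in `models/`, keyed by
# the checkpoint file name without its extension.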
def encode_predefined_dataset_split(
        data_split_dir_path,
        bpe_dataset_path,
        vocab_path,
        remove_context_declarations=False,
        target_format=TargetFormat.AST,
):
    # Encode the BPE dataset into train/validate/test JSONL files according to the
    # id lists already stored in data_split_dir_path.
    with \
            open(f'{data_split_dir_path}/train_ids.txt') as train_data_file, \
            open(f'{data_split_dir_path}/validate_ids.txt') as validate_data_file, \
            open(f'{data_split_dir_path}/test_ids.txt') as test_data_file:
        # One example id per line; strip the trailing newline.
        train_ids = [line[:-1] for line in train_data_file.readlines()]
        validate_ids = [line[:-1] for line in validate_data_file.readlines()]
        test_ids = [line[:-1] for line in test_data_file.readlines()]

    vocab = Vocab(vocab_path)
    sos_index = vocab.get_index(Vocab.SOS_TOKEN)
    eos_index = vocab.get_index(Vocab.EOS_TOKEN)

    if remove_context_declarations:
        ctx_open_id = vocab.get_index(ast_sequence.Token.CONTEXT_OPEN)
        ctx_close_id = vocab.get_index(ast_sequence.Token.CONTEXT_CLOSE)
        test_ctx_open_id = vocab.get_index(
            ast_sequence.Token.TEST_CONTEXT_OPEN)
        test_ctx_close_id = vocab.get_index(
            ast_sequence.Token.TEST_CONTEXT_CLOSE)

    with \
            open(bpe_dataset_path) as dataset_file, \
            open(f'{data_split_dir_path}/train.jsonl', 'w+') as train_data_file, \
            open(f'{data_split_dir_path}/validate.jsonl', 'w+') as validate_data_file, \
            open(f'{data_split_dir_path}/test.jsonl', 'w+') as test_data_file, \
            open(f'{data_split_dir_path}/validate_code_tokens.jsonl', 'w+') as validate_code_tokens_file, \
            open(f'{data_split_dir_path}/test_code_tokens.jsonl', 'w+') as test_code_tokens_file:
        for json_data in iterate_jsonl(dataset_file):
            source_data = json_data[TargetFormat.get_source_key(target_format)]
            target_data = json_data[TargetFormat.get_target_key(target_format)]
            if source_data is None or target_data is None:
                continue

            src_data = [vocab.get_index(token) for token in source_data]
            trg_data = [vocab.get_index(token) for token in target_data]
            if remove_context_declarations:
                if target_format == TargetFormat.CODE:
                    raise NotImplementedError(
                        'removing context declarations is not supported for target format CODE yet'
                    )
                src_data = remove_context_declarations_from_ast_sequence(
                    src_data,
                    ctx_open_id,
                    ctx_close_id,
                    test_ctx_open_id,
                    test_ctx_close_id,
                )

            data = dump_jsonl([
                src_data,
                [sos_index] + trg_data + [eos_index],
            ])
            if json_data['id'] in train_ids:
                train_data_file.write(data)
            else:
                # Validation and test examples additionally keep the raw target
                # code tokens when they are available.
                code_tokens = dump_jsonl([src_data, json_data['trgCode']]) \
                    if json_data['trgCode'] is not None else None
                if json_data['id'] in validate_ids:
                    validate_data_file.write(data)
                    if code_tokens:
                        validate_code_tokens_file.write(code_tokens)
                elif json_data['id'] in test_ids:
                    test_data_file.write(data)
                    if code_tokens:
                        test_code_tokens_file.write(code_tokens)
                else:
                    print(
                        f'- id {json_data["id"]} is not part of any of the splits'
                    )
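
# Usage sketch (hedged): the paths below are placeholders, not files from this
# repository. The call re-encodes the BPE dataset into train/validate/test JSONL
# files using the id lists that already exist in `data/split/`:
#
#     encode_predefined_dataset_split(
#         data_split_dir_path='data/split',
#         bpe_dataset_path='data/dataset.bpe.jsonl',
#         vocab_path='data/vocab.txt',
#         remove_context_declarations=True,
#     )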
def create_encoded_dataset_split(
        data_split_dir_path,
        bpe_dataset_path,
        vocab_path,
        data_split,
        target_format=TargetFormat.AST,
):
    # Randomly partition the BPE dataset into train/validate/test according to the
    # fractions in data_split, then encode and write each example to its split.
    if not os.path.exists(data_split_dir_path):
        os.makedirs(data_split_dir_path)

    total_line_count = get_file_length(bpe_dataset_path)
    train_line_count = math.floor(data_split[0] * total_line_count)
    validation_line_count = math.floor(data_split[1] * total_line_count)
    test_line_count = math.floor(data_split[2] * total_line_count)
    # Give any lines lost to flooring back to the training split so that every
    # line of the dataset is assigned to exactly one split.
    train_line_count += total_line_count - (
        train_line_count + validation_line_count + test_line_count)
    train_lines, validation_lines, test_lines = [
        set(split) for split in random_split(range(total_line_count), (
            train_line_count, validation_line_count, test_line_count))
    ]

    vocab = Vocab(vocab_path)
    sos_index = vocab.get_index(Vocab.SOS_TOKEN)
    eos_index = vocab.get_index(Vocab.EOS_TOKEN)

    with \
            open(bpe_dataset_path) as dataset_file, \
            open(f'{data_split_dir_path}/train.jsonl', 'w+') as train_data_file, \
            open(f'{data_split_dir_path}/validate.jsonl', 'w+') as validate_data_file, \
            open(f'{data_split_dir_path}/test.jsonl', 'w+') as test_data_file, \
            open(f'{data_split_dir_path}/train_ids.txt', 'w+') as train_data_ids_file, \
            open(f'{data_split_dir_path}/validate_ids.txt', 'w+') as validate_data_ids_file, \
            open(f'{data_split_dir_path}/test_ids.txt', 'w+') as test_data_ids_file, \
            open(f'{data_split_dir_path}/validate_code_tokens.jsonl', 'w+') as validate_code_tokens_file, \
            open(f'{data_split_dir_path}/test_code_tokens.jsonl', 'w+') as test_code_tokens_file:
        line_counter = 0
        for json_data in iterate_jsonl(dataset_file):
            source_data = json_data[TargetFormat.get_source_key(target_format)]
            target_data = json_data[TargetFormat.get_target_key(target_format)]
            if source_data is None or target_data is None:
                continue

            src_data = [vocab.get_index(token) for token in source_data]
            data = dump_jsonl([
                src_data,
                [sos_index]
                + [vocab.get_index(token) for token in target_data]
                + [eos_index],
            ])
            if line_counter in train_lines:
                train_data_file.write(data)
                train_data_ids_file.write(json_data['id'] + '\n')
            else:
                # Validation and test examples additionally keep the raw target
                # code tokens when they are available.
                code_tokens = dump_jsonl([src_data, json_data['trgCode']]) \
                    if json_data['trgCode'] is not None else None
                if line_counter in validation_lines:
                    validate_data_file.write(data)
                    validate_data_ids_file.write(json_data['id'] + '\n')
                    if code_tokens:
                        validate_code_tokens_file.write(code_tokens)
                elif line_counter in test_lines:
                    test_data_file.write(data)
                    test_data_ids_file.write(json_data['id'] + '\n')
                    if code_tokens:
                        test_code_tokens_file.write(code_tokens)
                else:
                    raise ValueError(
                        f'id {json_data["id"]} is not part of any of the splits'
                    )
            line_counter += 1
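
# Usage sketch (hedged): the paths and the 80/10/10 ratio are illustrative
# assumptions. `data_split` holds the (train, validation, test) fractions; any
# lines lost to rounding are added back to the training split, and the chosen
# example ids are written to *_ids.txt so the split can later be reproduced with
# encode_predefined_dataset_split.
#
#     create_encoded_dataset_split(
#         data_split_dir_path='data/split',
#         bpe_dataset_path='data/dataset.bpe.jsonl',
#         vocab_path='data/vocab.txt',
#         data_split=(0.8, 0.1, 0.1),
#     )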