Beispiel #1
0
def train(args):
    """Load a pretrained BART checkpoint and run ``predict`` on ``test.pt``.

    Despite its name, this function performs no parameter updates: it builds
    the data loaders, the knowledge base, the model/tokenizer and the rule
    executor, then hands everything to ``predict``.

    Args:
        args: Parsed options. Reads ``input_dir`` (directory holding
            ``vocab.json``, ``train.pt``, ``test.pt`` and ``kb.json``),
            ``batch_size`` and ``ckpt`` (passed to ``from_pretrained``).
    """
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    def input_path(name):
        # Resolve a file name relative to the configured input directory.
        return os.path.join(args.input_dir, name)

    logging.info("Create train_loader and val_loader.........")
    vocab_path = input_path('vocab.json')
    train_path = input_path('train.pt')
    eval_path = input_path('test.pt')
    # The training loader is only built to obtain its vocabulary.
    train_loader = DataLoader(vocab_path, train_path, args.batch_size,
                              training=True)
    val_loader = DataLoader(vocab_path, eval_path, args.batch_size)
    vocab = train_loader.vocab
    kb = DataForSPARQL(input_path('kb.json'))

    logging.info("Create model.........")
    bart_classes = (BartConfig, BartForConditionalGeneration, BartTokenizer)
    _, model_cls, tokenizer_cls = bart_classes
    tokenizer = tokenizer_cls.from_pretrained(args.ckpt)
    model = model_cls.from_pretrained(args.ckpt).to(device)
    logging.info(model)

    rule_executor = RuleExecutor(vocab, input_path('kb.json'))
    # ``validate`` accepts the same argument list and can be swapped in here.
    predict(args, kb, model, val_loader, device, tokenizer, rule_executor)
Beispiel #2
0
def test(args):
    """Predict SPARQL for the test split and write the executed answers.

    Loads ``model.pt`` from ``args.save_dir`` into a ``SPARQLParser``, decodes
    a SPARQL query for every question in ``test.pt``, executes it against the
    knowledge base and writes one answer per line to
    ``<save_dir>/predict.txt``.

    Args:
        args: Parsed options. Reads ``input_dir``, ``save_dir``, ``dim_word``,
            ``dim_hidden`` and ``max_dec_len``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print('load test data')
    vocab_json = os.path.join(args.input_dir, 'vocab.json')
    test_pt = os.path.join(args.input_dir, 'test.pt')
    data = DataLoader(vocab_json, test_pt, 128, training=False)
    vocab = data.vocab
    kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))

    print('load model')
    model = SPARQLParser(vocab, args.dim_word, args.dim_hidden,
                         args.max_dec_len)
    model = model.to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state_dict = torch.load(os.path.join(args.save_dir, 'model.pt'),
                            map_location=device)
    model.load_state_dict(state_dict)

    idx_to_token = vocab['sparql_idx_to_token']
    # The context manager closes the file even if a batch raises midway.
    with open(os.path.join(args.save_dir, 'predict.txt'), 'w') as f:
        for batch in tqdm(data, total=len(data)):
            # Only the question is needed; the other fields are ignored.
            question, _choices, _sparql, _answer = batch
            question = question.to(device)
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                pred_sparql = model(question)

            pred_sparql = pred_sparql.cpu().numpy().tolist()
            for token_ids in pred_sparql:
                tokens = [idx_to_token[i] for i in token_ids]
                # Keep tokens after the first one (presumably a start marker
                # -- TODO confirm) and before the first '<END>'.
                end_idx = len(tokens)
                if '<END>' in tokens:
                    end_idx = tokens.index('<END>')
                query = ' '.join(tokens[1:end_idx])
                query = postprocess_sparql_tokens(query)
                pred_answer = str(get_sparql_answer(query, kb))
                f.write(pred_answer + '\n')
Beispiel #3
0
def test_sparql(args):
    """Check the SPARQL engine against the gold queries of the training set.

    Every gold SPARQL program in ``train.pt`` is decoded to text, executed
    with ``get_sparql_answer`` and compared to the gold answer using
    ``whether_equal``. Mismatches are printed as ``given predicted`` pairs,
    followed by a final accuracy summary.

    Args:
        args: Parsed options. Reads ``input_dir`` and ``batch_size``.

    Returns:
        float: Fraction of examples whose executed answer matches the gold
        answer, or ``0.0`` when the loader yields no examples.
    """
    vocab_json = os.path.join(args.input_dir, 'vocab.json')
    train_pt = os.path.join(args.input_dir, 'train.pt')
    data = DataLoader(vocab_json, train_pt, args.batch_size, training=False)
    kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))

    count, correct = 0, 0
    for batch in tqdm(data, total=len(data)):
        _question, _choices, sparql, answer = batch

        answers = answer.cpu().numpy().tolist()
        gold_sparqls = sparql.cpu().numpy().tolist()
        for a, token_ids in zip(answers, gold_sparqls):
            given_answer = data.vocab['answer_idx_to_token'][a]
            tokens = [data.vocab['sparql_idx_to_token'][i] for i in token_ids]
            # Keep tokens after the first one (presumably a start marker
            # -- TODO confirm) and before the first '<END>'.
            end_idx = len(tokens)
            if '<END>' in tokens:
                end_idx = tokens.index('<END>')
            query = ' '.join(tokens[1:end_idx])
            query = postprocess_sparql_tokens(query)
            pred_answer = get_sparql_answer(query, kb)
            count += 1
            if whether_equal(given_answer, pred_answer):
                correct += 1
            else:
                print(given_answer, pred_answer)

    # Report the tally; previously the counters were computed but never used.
    accuracy = correct / count if count else 0.0
    print('SPARQL engine accuracy: {}/{} = {:.4f}'.format(
        correct, count, accuracy))
    return accuracy
Beispiel #4
0
                    node, SparqlEngine.PRED_YEAR)
                res = query_virtuoso(sp)
                res = [[binding[v] for v in res.vars]
                       for binding in res.bindings]
            if v_type == 'quantity':
                value = float(res[0][2].value)
                unit = res[0][1].value
            else:
                value = res[0][0].value
            value = ValueClass(v_type, value, unit)
            parsed_answer = str(value)
        elif parse_type == 'bool':
            parsed_answer = 'yes' if res else 'no'
        elif parse_type == 'pred':
            parsed_answer = str(res[0][0])
            parsed_answer = parsed_answer.replace('_', ' ')
        return parsed_answer
    except Exception:
        return None


if __name__ == '__main__':
    # Script entry point: build the SPARQL engine from a knowledge-base file
    # and a TTL file, both given as required command-line options.
    parser = argparse.ArgumentParser()
    # input and output
    for flag in ('--kb_path', '--ttl_path'):
        parser.add_argument(flag, required=True)
    args = parser.parse_args()

    data = DataForSPARQL(args.kb_path)
    engine = SparqlEngine(data, args.ttl_path)
Beispiel #5
0
def train(args):
    """Fine-tune ``BartForConditionalGeneration`` with periodic validation.

    Builds the train/validation loaders, the knowledge base and the rule
    executor, loads model and tokenizer from ``args.model_name_or_path``,
    sets up AdamW with a linear warmup/decay schedule, optionally restores
    optimizer/scheduler state and the global step from a checkpoint
    directory, then runs the training loop. Validation and checkpointing
    happen right after an optimizer step, every ``args.logging_steps`` and
    ``args.save_steps`` global steps respectively.

    Args:
        args: Parsed options. Reads ``input_dir``, ``batch_size``,
            ``model_name_or_path``, ``gradient_accumulation_steps``,
            ``num_train_epochs``, ``weight_decay``, ``learning_rate``,
            ``warmup_proportion``, ``adam_epsilon``, ``max_grad_norm``,
            ``logging_steps``, ``save_steps`` and ``output_dir``.
            ``args.warmup_steps`` is overwritten in place.

    Returns:
        tuple: ``(global_step, tr_loss / global_step)``. The second element
        is ``0.0`` when no optimizer step was performed.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logging.info("Create train_loader and val_loader.........")
    vocab_json = os.path.join(args.input_dir, 'vocab.json')
    train_pt = os.path.join(args.input_dir, 'train.pt')
    val_pt = os.path.join(args.input_dir, 'val.pt')
    train_loader = DataLoader(vocab_json,
                              train_pt,
                              args.batch_size,
                              training=True)
    val_loader = DataLoader(vocab_json, val_pt, 64)

    vocab = train_loader.vocab
    kb_json = os.path.join(args.input_dir, 'kb.json')
    kb = DataForSPARQL(kb_json)
    rule_executor = RuleExecutor(vocab, kb_json)

    logging.info("Create model.........")
    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(
        args.model_name_or_path)
    model = model.to(device)
    logging.info(model)

    # Number of optimizer steps per epoch and over the whole run.
    steps_per_epoch = len(train_loader) // args.gradient_accumulation_steps
    t_total = steps_per_epoch * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay). Parameters
    # whose name contains "bias" or "LayerNorm.weight" get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    named_params = list(model.named_parameters())
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in named_params
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay,
        'lr': args.learning_rate,
    }, {
        'params': [
            p for n, p in named_params
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0,
        'lr': args.learning_rate,
    }]
    args.warmup_steps = int(t_total * args.warmup_proportion)
    optimizer = optim.AdamW(optimizer_grouped_parameters,
                            lr=args.learning_rate,
                            eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=t_total)

    # Restore optimizer and scheduler states when both files are present.
    optimizer_pt = os.path.join(args.model_name_or_path, "optimizer.pt")
    scheduler_pt = os.path.join(args.model_name_or_path, "scheduler.pt")
    if os.path.isfile(optimizer_pt) and os.path.isfile(scheduler_pt):
        optimizer.load_state_dict(torch.load(optimizer_pt))
        scheduler.load_state_dict(torch.load(scheduler_pt))

    # Train! (Logged unconditionally; this block used to sit inside the
    # resume branch above and was silent on a fresh run.)
    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", len(train_loader.dataset))
    logging.info("  Num Epochs = %d", args.num_train_epochs)
    logging.info("  Gradient Accumulation steps = %d",
                 args.gradient_accumulation_steps)
    logging.info("  Total optimization steps = %d", t_total)

    global_step = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path
                      ) and "checkpoint" in args.model_name_or_path:
        # Set global_step to the step encoded in the checkpoint path suffix.
        global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        epochs_trained = global_step // steps_per_epoch
        steps_trained_in_current_epoch = global_step % steps_per_epoch
        logging.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logging.info("  Continuing training from epoch %d", epochs_trained)
        logging.info("  Continuing training from global step %d", global_step)
        logging.info("  Will skip the first %d steps in the first epoch",
                     steps_trained_in_current_epoch)
        # NOTE(review): epochs_trained is only logged; the epoch loop below
        # still runs all num_train_epochs epochs -- confirm this is intended.

    logging.info('Checking...')
    logging.info("===================Dev==================")
    validate(args, kb, model, val_loader, device, tokenizer, rule_executor)

    tr_loss = 0.0
    model.zero_grad()
    # NOTE(review): hard-coded offset added to checkpoint directory names;
    # presumably the global step of an earlier run -- confirm.
    prefix = 25984
    pad_token_id = tokenizer.pad_token_id
    for _ in range(int(args.num_train_epochs)):
        pbar = ProgressBar(n_total=len(train_loader), desc='Training')
        for step, batch in enumerate(train_loader):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            # Re-enter train mode on every step; validate() runs in between.
            model.train()
            batch = tuple(t.to(device) for t in batch)
            source_ids, source_mask, y = batch[0], batch[1], batch[-2]
            # Teacher forcing: the decoder reads y[:, :-1] and is trained to
            # produce y[:, 1:]; padded targets are set to -100.
            y_ids = y[:, :-1].contiguous()
            lm_labels = y[:, 1:].clone()
            lm_labels[y[:, 1:] == pad_token_id] = -100

            inputs = {
                "input_ids": source_ids,
                "attention_mask": source_mask,
                "decoder_input_ids": y_ids,
                # NOTE(review): newer transformers releases name this
                # argument "labels" -- confirm against the pinned version.
                "lm_labels": lm_labels,
            }
            outputs = model(**inputs)
            loss = outputs[0]
            loss.backward()
            loss_value = loss.item()
            pbar(step, {'loss': loss_value})
            tr_loss += loss_value

            # Everything below runs only when an optimizer step is due, so
            # validation/checkpointing fire once per global step instead of
            # once per accumulated micro-batch.
            if (step + 1) % args.gradient_accumulation_steps != 0:
                continue

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logging.info("===================Dev==================")
                validate(args, kb, model, val_loader, device, tokenizer,
                         rule_executor)

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(
                    args.output_dir,
                    "checkpoint-{}".format(global_step + prefix))
                os.makedirs(output_dir, exist_ok=True)
                # Take care of distributed/parallel training
                model_to_save = (
                    model.module if hasattr(model, "module") else model)
                model_to_save.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logging.info("Saving model checkpoint to %s", output_dir)
                tokenizer.save_vocabulary(output_dir)
                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logging.info("Saving optimizer and scheduler states to %s",
                             output_dir)
        logging.info("\n")
        if 'cuda' in str(device):
            torch.cuda.empty_cache()

    # Avoid ZeroDivisionError when no optimizer step was ever taken.
    mean_loss = tr_loss / global_step if global_step else 0.0
    return global_step, mean_loss
Beispiel #6
0
def train(args):
    """Train a ``SPARQLParser`` and keep the best checkpoint by validation.

    Each epoch runs one pass over ``train.pt`` with Adam, logs progress about
    100 times per epoch, validates on ``val.pt`` and steps a ``MultiStepLR``
    schedule (milestones 5 and 50, gamma 0.1). Whenever the validation
    accuracy beats the best seen so far, the model state dict is saved to
    ``<save_dir>/model.pt``.

    Args:
        args: Parsed options. Reads ``input_dir``, ``batch_size``,
            ``dim_word``, ``dim_hidden``, ``max_dec_len``, ``lr``,
            ``weight_decay``, ``num_epoch`` and ``save_dir``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    logging.info("Create train_loader and val_loader.........")
    vocab_json = os.path.join(args.input_dir, 'vocab.json')
    train_pt = os.path.join(args.input_dir, 'train.pt')
    val_pt = os.path.join(args.input_dir, 'val.pt')
    train_loader = DataLoader(vocab_json,
                              train_pt,
                              args.batch_size,
                              training=True)
    val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
    vocab = train_loader.vocab
    kb = DataForSPARQL(os.path.join(args.input_dir, 'kb.json'))

    logging.info("Create model.........")
    model = SPARQLParser(vocab, args.dim_word, args.dim_hidden,
                         args.max_dec_len)
    model = model.to(device)
    logging.info(model)

    optimizer = optim.Adam(model.parameters(),
                           args.lr,
                           weight_decay=args.weight_decay)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                               milestones=[5, 50],
                                               gamma=0.1)

    meters = MetricLogger(delimiter="  ")
    best_acc = 0
    num_batches = len(train_loader)
    # Clamp to 1 so loaders with fewer than 100 batches do not trigger a
    # modulo-by-zero in the progress check below.
    log_interval = max(1, num_batches // 100)
    log_template = meters.delimiter.join([
        "progress: {progress:.3f}",
        "{meters}",
        "lr: {lr:.6f}",
    ])
    logging.info("Start training........")
    for epoch in range(args.num_epoch):
        # validate() runs between epochs, so re-enter train mode each time.
        model.train()
        for iteration, batch in enumerate(train_loader, start=1):
            question, choices, sparql, answer = [x.to(device) for x in batch]
            # Called with the target program, the model returns the loss.
            loss = model(question, sparql)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            meters.update(loss=loss.item())

            if iteration % log_interval == 0:
                logging.info(
                    log_template.format(
                        progress=epoch + iteration / num_batches,
                        meters=str(meters),
                        lr=optimizer.param_groups[0]["lr"],
                    ))

        acc = validate(args, kb, model, val_loader, device)
        scheduler.step()
        # ``acc`` may be falsy (None or 0); only a real improvement saves.
        if acc and acc > best_acc:
            best_acc = acc
            logging.info(
                "\nupdate best ckpt with acc: {:.4f}".format(best_acc))
            torch.save(model.state_dict(),
                       os.path.join(args.save_dir, 'model.pt'))