def train_model(train_stage,
                save_model_path,
                model,
                optimizer,
                epochs,
                train_dataset,
                eval_dataset,
                batch_size=1,
                gradient_accumulation_steps=1,
                num_workers=1):
    print('Start Training'.center(60, '*'))
    training_dataloader = data.DataLoader(
        dataset=train_dataset,
        collate_fn=TextCollate(train_dataset),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True)
    for epoch in range(1, epochs + 1):
        print('train epoch: ' + str(epoch))
        avg_loss = train_epoch(train_stage, model, optimizer,
                               training_dataloader,
                               gradient_accumulation_steps, epoch)
        print('average_loss:{}'.format(avg_loss))

        eval_model(train_stage,
                   model,
                   eval_dataset,
                   batch_size=batch_size,
                   num_workers=num_workers)
        save_model(save_model_path, model, epoch)
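
These snippets pass a TextCollate(dataset) callable to every DataLoader, and it must return batches with the keys the loops read ('tokens', 'segment_ids', 'attn_masks', 'labels', 'texts'). The original class is not shown here; the following is a minimal sketch of what such a collate could look like, assuming each dataset sample is a dict of token-id lists plus an integer label.

import torch

class TextCollate:
    # Hypothetical stand-in for the project's TextCollate (an assumption, not the
    # original implementation): pad every sequence in the batch to the longest
    # length and stack the results into the tensors the loops above expect.
    def __init__(self, dataset, pad_id=0):
        self.pad_id = pad_id

    def __call__(self, samples):
        max_len = max(len(s['tokens']) for s in samples)

        def pad(seq, value):
            return seq + [value] * (max_len - len(seq))

        return {
            'texts': [s.get('text', '') for s in samples],
            'tokens': torch.tensor([pad(s['tokens'], self.pad_id) for s in samples]),
            'segment_ids': torch.tensor([pad(s['segment_ids'], 0) for s in samples]),
            'attn_masks': torch.tensor([pad([1] * len(s['tokens']), 0) for s in samples]),
            'labels': torch.tensor([s['label'] for s in samples]),
        }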
Example #2
def eval_model(train_stage, model, dataset, batch_size=1, num_workers=1):
    global global_step
    global debug_break
    model.eval()
    dataloader = data.DataLoader(dataset=dataset, collate_fn=TextCollate(dataset),
                                 batch_size=batch_size, num_workers=num_workers, shuffle=False)
    total_loss = 0.0
    correct_sum = 0
    proc_sum = 0
    num_batch = len(dataloader)
    print('Evaluating Model...')
    for step, batch in enumerate(dataloader):
        tokens = batch['tokens'].to(device)
        segment_ids = batch['segment_ids'].to(device)
        attn_masks = batch['attn_masks'].to(device)
        labels = batch['labels'].to(device)

        with torch.no_grad():
            loss, logits = model(tokens, token_type_ids=segment_ids, attention_mask=attn_masks, labels=labels,
                                 training_stage=train_stage, inference=False)
        loss = loss.mean()
        loss_val = loss.item()
        total_loss += loss_val
        if debug_break and step > 50:
            break
        if train_stage == 0:
            _, top_index = logits.topk(1)
            correct_sum += (top_index.view(-1) == labels).sum().item()
            proc_sum += labels.shape[0]
    print('eval total avg loss: {}'.format(total_loss / num_batch))
    if train_stage == 0:
        print("Correct Prediction: " + str(correct_sum))
        print("Accuracy Rate: " + format(correct_sum / proc_sum, "0.4f"))
Example #3
def train_model(train_stage,
                save_model_path,
                master_gpu_id,
                model,
                optimizer,
                epochs,
                train_dataset,
                eval_dataset,
                batch_size=1,
                gradient_accumulation_steps=1,
                use_cuda=False,
                num_workers=1):
    logging.info("Start Training".center(60, "="))
    training_dataloader = data.DataLoader(
        dataset=train_dataset,
        collate_fn=TextCollate(train_dataset),
        pin_memory=use_cuda,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True)
    for epoch in range(1, epochs + 1):
        logging.info("Training Epoch: " + str(epoch))
        avg_loss = train_epoch(train_stage, master_gpu_id, model, optimizer,
                               training_dataloader,
                               gradient_accumulation_steps, use_cuda)
        logging.info("Average Loss: " + format(avg_loss, "0.4f"))
        eval_model(train_stage,
                   master_gpu_id,
                   model,
                   eval_dataset,
                   batch_size=batch_size,
                   use_cuda=use_cuda,
                   num_workers=num_workers)
        save_model(save_model_path, model, epoch)
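
Both train_model variants end each epoch with save_model(save_model_path, model, epoch), which is not included in these snippets. A minimal sketch, assuming one state_dict checkpoint is written per epoch:

import os
import torch

def save_model(save_model_path, model, epoch):
    # Hypothetical helper matching the call sites above (an assumption, not the
    # original implementation): one checkpoint file per epoch under save_model_path.
    os.makedirs(save_model_path, exist_ok=True)
    # Unwrap DataParallel, if present, so the checkpoint loads without the wrapper.
    to_save = model.module if hasattr(model, 'module') else model
    torch.save(to_save.state_dict(),
               os.path.join(save_model_path, 'model_epoch_{}.bin'.format(epoch)))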
Example #4
def eval_model(train_stage, master_gpu_id, model, dataset, batch_size=1,
               use_cuda=False, num_workers=1):
    global global_step
    global debug_break
    model.eval()
    dataloader = data.DataLoader(dataset=dataset,
                                 collate_fn=TextCollate(dataset),
                                 pin_memory=use_cuda,
                                 batch_size=batch_size,
                                 num_workers=num_workers,
                                 shuffle=False)
    total_loss = 0.0
    correct_sum = 0
    proc_sum = 0
    num_sample = len(dataloader.dataset)
    num_batch = len(dataloader)
    predicted_probs = []
    true_labels = []
    logging.info("Evaluating Model...")
    infos = []
    for step, batch in enumerate(tqdm(dataloader, unit="batch", ncols=100, desc="Evaluating process: ")):
        texts = batch["texts"]
        tokens = batch["tokens"].cuda(master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            loss, logits = model(tokens, token_type_ids=segment_ids, attention_mask=attn_masks, labels=labels,
                                 training_stage=train_stage, inference=False)
        loss = loss.mean()
        loss_val = loss.item()
        total_loss += loss_val
        #writer.add_scalar('eval/loss', total_loss/num_batch, global_step)
        if debug_break and step > 50:
            break
        if train_stage == 0:
            _, top_index = logits.topk(1)
            correct_sum += (top_index.view(-1) == labels).sum().item()
            proc_sum += labels.shape[0]
    logging.info('eval total avg loss:%s', format(total_loss/num_batch, "0.4f"))
    if train_stage == 0:
        logging.info("Correct Prediction: " + str(correct_sum))
        logging.info("Accuracy Rate: " + format(correct_sum / proc_sum, "0.4f"))
Example #5
def infer_model(master_gpu_id,
                model,
                dataset,
                use_cuda=False,
                num_workers=1,
                inference_speed=None,
                dump_info_file=None):
    global global_step
    global debug_break
    model.eval()
    infer_dataloader = data.DataLoader(dataset=dataset,
                                       collate_fn=TextCollate(dataset),
                                       pin_memory=use_cuda,
                                       batch_size=1,
                                       num_workers=num_workers,
                                       shuffle=False)
    correct_sum = 0
    num_sample = len(infer_dataloader.dataset)
    predicted_probs = []
    true_labels = []
    infos = []
    logging.info("Inference Model...")
    cnt = 0
    stime_all = time.time()
    for step, batch in enumerate(
            tqdm(infer_dataloader,
                 unit="batch",
                 ncols=100,
                 desc="Inference process: ")):
        texts = batch["texts"]
        tokens = batch["tokens"].cuda(
            master_gpu_id) if use_cuda else batch["tokens"]
        segment_ids = batch["segment_ids"].cuda(
            master_gpu_id) if use_cuda else batch["segment_ids"]
        attn_masks = batch["attn_masks"].cuda(
            master_gpu_id) if use_cuda else batch["attn_masks"]
        labels = batch["labels"].cuda(
            master_gpu_id) if use_cuda else batch["labels"]
        with torch.no_grad():
            probs, layer_idxes, uncertain_infos = model(
                tokens,
                token_type_ids=segment_ids,
                attention_mask=attn_masks,
                inference=True,
                inference_speed=inference_speed)
        _, top_index = probs.topk(1)

        correct_sum += (top_index.view(-1) == labels).sum().item()
        cnt += 1
        if cnt == 1:
            stime = time.time()
        if dump_info_file is not None:
            for label, pred, prob, layer_i, text in zip(
                    labels, top_index.view(-1), probs, [layer_idxes], texts):
                infos.append((label.item(), pred.item(), prob.cpu().numpy(),
                              layer_i, text))
        if debug_break and step > 50:
            break

    # The first batch is treated as warm-up (stime is set only after it finishes),
    # so per-record time is averaged over the remaining cnt - 1 batches.
    time_per = (time.time() - stime) / (cnt - 1) if cnt > 1 else 0.0
    time_all = time.time() - stime_all
    acc = format(correct_sum / num_sample, "0.4f")
    logging.info("speed_arg:%s, time_per_record:%s, acc:%s, total_time:%s",
                 inference_speed, format(time_per, '0.4f'), acc,
                 format(time_all, '0.4f'))
    if dump_info_file:
        with open(dump_info_file, 'w') as fw:
            for label, pred, prob, layer_i, text in infos:
                fw.write('\t'.join([str(label),
                                    str(pred),
                                    str(layer_i), text]) + '\n')

    # Precision/recall is only meaningful when per-sample info was collected above.
    if probs.shape[1] == 2 and infos:
        labels_pr = [info[0] for info in infos]
        preds_pr = [info[1] for info in infos]
        precise, recall = eval_pr(labels_pr, preds_pr)
        logging.info("precise:%s, recall:%s", format(precise, '0.4f'),
                     format(recall, '0.4f'))
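
infer_model calls eval_pr(labels_pr, preds_pr) for the binary case, but that helper is not part of these snippets. A minimal sketch, assuming label 1 is the positive class and that precision and recall default to 0.0 when undefined:

def eval_pr(labels, preds):
    # Hypothetical precision/recall helper matching the call above (an assumption,
    # not the original implementation).
    tp = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 1)
    fp = sum(1 for l, p in zip(labels, preds) if l == 0 and p == 1)
    fn = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 0)
    precise = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precise, recall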