Example #1
def analyse(model,
            dataset,
            batch_size,
            device,
            analyzer,
            path_name,
            buckted=False):
    pred_holder = []
    golden_holder = []
    length_holder = []
    sentence_holder = []
    model.eval()

    with torch.no_grad():
        for step, data in enumerate(
                iterate_data(dataset, batch_size, bucketed=buckted)):
            words, labels, masks = data['WORD'].squeeze().to(device), \
                                   data['POS'].squeeze().to(device), \
                                   data['MASK'].squeeze().to(device)
            corr, preds = model.get_acc(words, labels, masks)
            if batch_size == 1:
                pred_holder.append(preds.tolist())
                golden_holder.append(labels.cpu().numpy().tolist())
                length_holder.append(np.sum(masks.cpu().numpy()))
                sentence_holder.append(words.cpu().numpy().tolist())
            else:
                pred_holder += preds.astype(int).tolist()
                golden_holder += labels.cpu().numpy().astype(int).tolist()
                length_holder += np.sum(masks.cpu().numpy(),
                                        axis=-1).astype(int).tolist()
                sentence_holder += words.cpu().numpy().astype(int).tolist()
    analyzer.error_rate(sentence_holder, pred_holder, golden_holder,
                        length_holder, path_name)
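
A minimal usage sketch, not part of the original example; it assumes dev_dataset, analyzer, log_dir, and epoch are defined in the surrounding training script, as they are in Example #9 below:

analyse(model,
        dev_dataset,
        batch_size,
        device,
        analyzer,
        log_dir + '/dev_' + str(epoch),
        buckted=False)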
Example #2
    def train(best_epoch, thread=6, aim_epoch=0):
        epoch = best_epoch[0] + 1
        while epoch - best_epoch[0] <= thread:
            epoch_loss = 0
            model.train()
            for step, data in enumerate(
                    iterate_data(train_dataset,
                                 batch_size,
                                 bucketed=True,
                                 unk_replace=unk_replace,
                                 shuffle=True)):
                optimizer.zero_grad()
                words, labels, masks = data['WORD'].to(device), data['LAB'].to(
                    device), data['MASK'].to(device)
                loss = 0
                if threshold >= 1.0:
                    loss = model.get_loss(words,
                                          labels,
                                          masks,
                                          normalize_weight=normalize_weight,
                                          sep_normalize=sep_normalize)
                else:
                    for i in range(batch_size):
                        loss += model.get_loss(
                            words[i],
                            labels[i],
                            masks[i],
                            normalize_weight=normalize_weight,
                            sep_normalize=sep_normalize)
                # loss = model.get_loss(words, labels, masks)
                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                epoch_loss += loss.item() * words.size(0)
            logger.info('Epoch ' + str(epoch) + ' Loss: ' +
                        str(round(epoch_loss / num_data, 4)))
            if threshold >= 1.0:
                acc, _ = evaluate(dev_dataset, batch_size, model, device)
            else:
                acc, _ = evaluate(dev_dataset, 1, model, device)
            logger.info('\t Dev Acc: ' + str(round(acc * 100, 3)))
            if best_epoch[1] < acc:
                if threshold >= 1.0:
                    test_acc, _ = evaluate(test_dataset, batch_size, model,
                                           device)
                else:
                    test_acc, _ = evaluate(test_dataset, 1, model, device)
                logger.info('\t Test Acc: ' + str(round(test_acc * 100, 3)))
                best_epoch = (epoch, acc, test_acc)
            epoch += 1

            if aim_epoch != 0 and epoch >= aim_epoch:
                break

        logger.info("Best Epoch: " + str(best_epoch[0]) + " Dev ACC: " +
                    str(round(best_epoch[1] * 100, 3)) + "Test ACC: " +
                    str(round(best_epoch[2] * 100, 3)))
        return best_epoch
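
A possible driver for this closure, not shown in the original snippet; it assumes best_epoch is the (epoch, dev_acc, test_acc) tuple that the loop itself builds:

    best_epoch = (0, 0.0, 0.0)
    best_epoch = train(best_epoch, thread=6, aim_epoch=0)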
Example #3
    def train(best_epoch, thread=6):
        epoch = 0
        patient = 0  # initialize here; otherwise the first non-improving epoch hits an unbound local
        while epoch <= thread:
            epoch_loss = 0
            num_back = 0
            num_words = 0
            num_insts = 0
            model.train()
            for step, data in enumerate(iterate_data(train_dataset, batch_size, bucketed=True, unk_replace=unk_replace, shuffle=True)):
                # for j in tqdm(range(math.ceil(len(train_dataset) / batch_size))):
                optimizer.zero_grad()
                # samples = train_dataset[j * batch_size: (j + 1) * batch_size]
                words, masks = data['WORD'].to(device), data['MASK'].to(device)

                loss = model(words, masks)
                if torch.cuda.device_count() > 1:
                    loss = loss.mean()
                loss.backward()
                optimizer.step()
                scheduler.step()
                epoch_loss += loss.item() * words.size(0)
                num_words += torch.sum(masks).item()
                num_insts += words.size()[0]

                if step % 10 == 0:
                    torch.cuda.empty_cache()
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    curr_lr = scheduler.get_lr()[0]
                    log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches,
                        curr_lr, epoch_loss / num_insts, epoch_loss / num_words)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)
            logger.info('Epoch ' + str(epoch) + ' Loss: ' + str(round(epoch_loss / num_insts, 4)))

            ppl = evaluate(dev_dataset, batch_size, model, device)

            logger.info('\t Dev PPL: ' + str(round(ppl, 3)))

            if best_epoch[1] > ppl:
                test_ppl = evaluate(test_dataset, batch_size, model, device)
                logger.info('\t Test PPL: ' + str(round(test_ppl, 3)))
                best_epoch = (epoch, ppl, test_ppl)
                patient = 0
            else:
                patient += 1
            epoch += 1
            if patient > 4:
                print('reset optimizer momentums')
                scheduler.reset_state()
                patient = 0

        logger.info("Best Epoch: " + str(best_epoch[0]) + " Dev ACC: " + str(round(best_epoch[1], 3)) +
                    "Test ACC: " + str(round(best_epoch[2], 3)))
        return best_epoch
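
Because this variant treats a lower dev perplexity as an improvement (best_epoch[1] > ppl), the initial tuple has to start from a large value rather than zero. A minimal driver sketch, assuming the same (epoch, dev_ppl, test_ppl) layout the loop builds:

    best_epoch = (0, float('inf'), float('inf'))
    best_epoch = train(best_epoch, thread=6)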
Example #4
    def train(best_epoch):
        epoch = 0
        while epoch - best_epoch[0] <= 6:
            epoch_loss = 0
            num_back = 0
            num_words = 0
            num_insts = 0
            model.train()
            for step, data in enumerate(
                    iterate_data(train_dataset,
                                 batch_size,
                                 bucketed=True,
                                 unk_replace=unk_replace,
                                 shuffle=True)):
                # for j in tqdm(range(math.ceil(len(train_dataset) / batch_size))):
                optimizer.zero_grad()
                # samples = train_dataset[j * batch_size: (j + 1) * batch_size]
                words, labels, masks = data['WORD'].to(device), data['NER'].to(
                    device), data['MASK'].to(device)

                # sentences, labels, masks, revert_order = standardize_batch(samples)
                loss = model.get_loss(words, labels, masks)
                loss.backward()
                optimizer.step()
                epoch_loss += (loss.item()) * words.size(0)
                num_words += torch.sum(masks).item()
                num_insts += words.size()[0]
                if step % 10 == 0:
                    torch.cuda.empty_cache()
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches, lr,
                        epoch_loss / num_insts, epoch_loss / num_words)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)

            logger.info('Epoch ' + str(epoch) + ' Loss: ' +
                        str(round(epoch_loss / num_insts, 4)))
            acc, corr = evaluate(dev_dataset, batch_size, model, device)
            logger.info('\t Dev Acc: ' + str(round(acc * 100, 3)))
            if best_epoch[1] < acc:
                test_acc, _ = evaluate(test_dataset, batch_size, model, device)
                logger.info('\t Test Acc: ' + str(round(test_acc * 100, 3)))
                best_epoch = (epoch, acc, test_acc)
            epoch += 1

        logger.info("Best Epoch: " + str(best_epoch[0]) + " Dev ACC: " +
                    str(round(best_epoch[1] * 100, 3)) + "Test ACC: " +
                    str(round(best_epoch[2] * 100, 3)))
        return best_epoch
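
This variant logs a fixed learning rate and never steps a scheduler, so a plain optimizer is enough. A setup sketch only; the optimizer class and lr value are assumptions rather than the original configuration, and model is taken from the surrounding script:

    lr = 0.015
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    best_epoch = train((0, 0.0, 0.0))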
Example #5
def evaluate(data, batch, model, device):
    model.eval()
    total_token_num = 0
    corr_token_num = 0
    with torch.no_grad():
        for batch_data in iterate_data(data, batch):
            # sentences, labels, masks, revert_order = standardize_batch(data[i * batch: (i + 1) * batch])
            words = batch_data['WORD'].to(device)
            labels = batch_data['NER'].to(device)
            masks = batch_data['MASK'].to(device)
            acc, corr = model.get_acc(words, labels, masks)
            corr_token_num += corr
            total_token_num += torch.sum(masks).item()
    return corr_token_num / total_token_num, corr_token_num
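
A minimal calling sketch, not from the original snippet; it assumes dev_dataset, batch_size, model, and device come from the surrounding script, mirroring the calls made inside the train loops above:

dev_acc, dev_corr = evaluate(dev_dataset, batch_size, model, device)
print('Dev token accuracy:', round(dev_acc * 100, 3))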
Example #6
def evaluate(data, batch, model, device):
    model.eval()

    total_ppl = 0
    word_cnt = 0
    with torch.no_grad():
        for batch_data in iterate_data(data, batch):
            # sentences, labels, masks, revert_order = standardize_batch(data[i * batch: (i + 1) * batch])
            words = batch_data['WORD'].to(device)
            masks = batch_data['MASK'].to(device)
            lengths = batch_data['LENGTH']
            ppl = model.get_loss(words, masks)
            total_ppl += ppl.item() * words.size(0)
            word_cnt += torch.sum(lengths).item()
    return total_ppl / word_cnt
Example #7
def evaluate(data, batch, model, device):
    model.eval()
    total_token_num = 0
    corr_token_num = 0
    total_pred = []
    with torch.no_grad():
        for batch_data in iterate_data(data, batch):
            # sentences, labels, masks, revert_order = standardize_batch(data[i * batch: (i + 1) * batch])
            words = batch_data['WORD'].squeeze().to(device)
            labels = batch_data['POS'].squeeze().to(device)
            masks = batch_data['MASK'].squeeze().to(device)
            corr, preds = model.get_acc(words, labels, masks)
            corr_token_num += corr
            total_token_num += torch.sum(masks).item()
            for pred in preds.tolist():
                total_pred.append(pred)
    return corr_token_num / total_token_num, total_pred
Example #8
def evaluate(data, batch, model, device):
    model.eval()
    corr_token_num = 0
    total_pred = []
    with torch.no_grad():
        for batch_data in iterate_data(data, batch):
            sentences = batch_data['WORD'].to(device)
            labels = batch_data['LAB'].to(device)
            masks = batch_data['MASK'].to(device)
            corr, preds = model.get_acc(sentences, labels, masks)
            corr_token_num += corr
            preds = preds.tolist()
            if isinstance(preds, int):
                total_pred.append(preds)
            else:
                for pred in preds:
                    total_pred.append(pred)
    # data[1] is assumed here to hold the total token count for the dataset.
    return corr_token_num / data[1], total_pred
Example #9
    def train(best_epoch, thread=6):
        epoch = 0
        patient = 0  # initialize here; otherwise the first non-improving epoch hits an unbound local
        while epoch - best_epoch[0] <= thread:
            epoch_loss = 0
            num_back = 0
            num_words = 0
            num_insts = 0
            model.train()
            for step, data in enumerate(
                    iterate_data(train_dataset,
                                 batch_size,
                                 bucketed=True,
                                 unk_replace=unk_replace,
                                 shuffle=True)):
                # for j in tqdm(range(math.ceil(len(train_dataset) / batch_size))):
                optimizer.zero_grad()
                # samples = train_dataset[j * batch_size: (j + 1) * batch_size]
                words, labels, masks = data['WORD'].to(device), data['POS'].to(
                    device), data['MASK'].to(device)
                loss = 0.0
                if threshold >= 1.0:
                    # sentences, labels, masks, revert_order = standardize_batch(samples)
                    loss = model.get_loss(words,
                                          labels,
                                          masks,
                                          normalize_weight=normalize_weight,
                                          sep_normalize=sep_normalize)
                else:
                    for i in range(batch_size):
                        loss += model.get_loss(
                            words[i],
                            labels[i],
                            masks[i],
                            normalize_weight=normalize_weight,
                            sep_normalize=sep_normalize)
                # loss = model.get_loss(words, labels, masks)
                loss.backward()
                optimizer.step()
                scheduler.step()
                epoch_loss += (loss.item()) * words.size(0)
                num_words += torch.sum(masks).item()
                num_insts += words.size()[0]
                if step % 10 == 0:
                    torch.cuda.empty_cache()
                    sys.stdout.write("\b" * num_back)
                    sys.stdout.write(" " * num_back)
                    sys.stdout.write("\b" * num_back)
                    curr_lr = scheduler.get_lr()[0]
                    log_info = '[%d/%d (%.0f%%) lr=%.6f] loss: %.4f (%.4f)' % (
                        step, num_batches, 100. * step / num_batches, curr_lr,
                        epoch_loss / num_insts, epoch_loss / num_words)
                    sys.stdout.write(log_info)
                    sys.stdout.flush()
                    num_back = len(log_info)
            logger.info('Epoch ' + str(epoch) + ' Loss: ' +
                        str(round(epoch_loss / num_insts, 4)))
            if threshold >= 1.0:
                acc, _ = evaluate(dev_dataset, batch_size, model, device)
            else:
                acc, _ = evaluate(dev_dataset, 1, model, device)
            logger.info('\t Dev Acc: ' + str(round(acc * 100, 3)))
            if analysis:
                analyse(model,
                        dev_dataset,
                        batch_size,
                        device,
                        analyzer,
                        log_dir + '/dev_' + str(epoch),
                        buckted=False)
                analyse(model,
                        test_dataset,
                        batch_size,
                        device,
                        analyzer,
                        log_dir + '/test_' + str(epoch),
                        buckted=False)

            if best_epoch[1] < acc:
                test_acc, _ = evaluate(test_dataset, batch_size, model, device)
                logger.info('\t Test Acc: ' + str(round(test_acc * 100, 3)))
                best_epoch = (epoch, acc, test_acc)
                patient = 0
            else:
                patient += 1
            epoch += 1
            if patient > 4:
                print('reset optimizer momentums')
                scheduler.reset_state()
                patient = 0

        logger.info("Best Epoch: " + str(best_epoch[0]) + " Dev ACC: " +
                    str(round(best_epoch[1] * 100, 3)) + "Test ACC: " +
                    str(round(best_epoch[2] * 100, 3)))
        return best_epoch