# NOTE: project-local helpers used below (get_jsonl_data, write_jsonl_data,
# write_txt_data, BERTEncoder, RelationNet, BaselineNet, device, logger, now)
# are assumed to be importable from the surrounding project.
import collections
import json
import os
import random

import numpy as np
import torch
# SummaryWriter(logdir=..., flush_secs=..., max_queue=...) matches the
# tensorboardX API.
from tensorboardX import SummaryWriter


def process_dataset(name):
    full = get_jsonl_data(f"data/{name}/full.jsonl")
    random.seed(42)  # seed before any shuffling so the split is reproducible
    random.shuffle(full)

    labels = sorted(set(d['label'] for d in full))
    n_labels = len(labels)
    random.shuffle(labels)

    # Few-shot setup: split the label set into three disjoint thirds, so the
    # valid/test classes are never seen during training.
    cut1, cut2 = n_labels // 3, 2 * n_labels // 3
    train_labels = labels[:cut1]
    valid_labels = labels[cut1:cut2]
    test_labels = labels[cut2:]

    write_jsonl_data([d for d in full if d['label'] in train_labels],
                     f"data/{name}/train.jsonl",
                     force=True)
    write_txt_data(train_labels, f"data/{name}/labels.train.txt")

    write_jsonl_data([d for d in full if d['label'] in valid_labels],
                     f"data/{name}/valid.jsonl",
                     force=True)
    write_txt_data(valid_labels, f"data/{name}/labels.valid.txt")

    write_jsonl_data([d for d in full if d['label'] in test_labels],
                     f"data/{name}/test.jsonl",
                     force=True)
    write_txt_data(test_labels, f"data/{name}/labels.test.txt")
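
# Usage sketch (hypothetical dataset name; assumes data/<name>/full.jsonl holds
# one {"label": ..., "sentence": ...} JSON object per line):
#
#     process_dataset("my_dataset")
#
# This writes train/valid/test .jsonl splits with disjoint label sets, plus a
# labels.<split>.txt file per split, under data/my_dataset/.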
Example #2
def run_relation(train_path: str,
                 model_name_or_path: str,
                 n_support: int,
                 n_query: int,
                 n_classes: int,
                 valid_path: str = None,
                 test_path: str = None,
                 output_path: str = f"runs/{now()}",
                 max_iter: int = 10000,
                 evaluate_every: int = 100,
                 early_stop: int = None,
                 n_test_episodes: int = 1000,
                 log_every: int = 10,
                 relation_module_type: str = "base",
                 ntl_n_slices: int = 100,
                 arsc_format: bool = False,
                 data_path: str = None):
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(
                f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    os.makedirs(output_path, exist_ok=True)  # the guard above allows an existing empty dir
    os.makedirs(os.path.join(output_path, "logs/train"))
    train_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/train"),
                                 flush_secs=1, max_queue=1)
    valid_writer = None  # created below if valid_path is given
    test_writer = None  # created below if test_path is given
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"))
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"),
                                     flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"))
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"),
                                    flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        """Group sentences by label into {label: [sentence, ...]}."""
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for val in labels_dict.values():
                random.shuffle(val)
        return labels_dict

    # Load model
    bert = BERTEncoder(model_name_or_path).to(device)
    # Move the whole net (relation module included), not just the encoder.
    relation_net = RelationNet(encoder=bert,
                               relation_module_type=relation_module_type,
                               ntl_n_slices=ntl_n_slices).to(device)
    optimizer = torch.optim.Adam(relation_net.parameters(), lr=2e-5)

    # Load data
    if not arsc_format:
        train_data = get_jsonl_data(train_path)
        train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        if valid_path:
            valid_data = get_jsonl_data(valid_path)
            valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")
        else:
            valid_data_dict = None

        if test_path:
            test_data = get_jsonl_data(test_path)
            test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")
        else:
            test_data_dict = None
    else:
        train_data_dict = None
        test_data_dict = None
        valid_data_dict = None

    train_accuracies = list()
    train_losses = list()
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    for step in range(max_iter):
        # Each step is one few-shot episode: sample n_classes labels, with
        # n_support support and n_query query sentences per sampled label.
        if not arsc_format:
            loss, loss_dict = relation_net.train_step(optimizer=optimizer,
                                                      data_dict=train_data_dict,
                                                      n_support=n_support,
                                                      n_query=n_query,
                                                      n_classes=n_classes)
        else:
            loss, loss_dict = relation_net.train_step_ARSC(optimizer=optimizer,
                                                           data_path=data_path)

        train_accuracies.append(loss_dict["acc"])
        train_losses.append(loss_dict["loss"])

        # Logging
        if (step + 1) % log_every == 0:
            train_writer.add_scalar(tag="loss",
                                    scalar_value=np.mean(train_losses),
                                    global_step=step)
            train_writer.add_scalar(tag="accuracy",
                                    scalar_value=np.mean(train_accuracies),
                                    global_step=step)
            logger.info(
                f"train | loss: {np.mean(train_losses):.4f} | acc: {np.mean(train_accuracies):.4f}"
            )
            log_dict["train"].append({
                "metrics": [{
                    "tag": "accuracy",
                    "value": np.mean(train_accuracies)
                }, {
                    "tag": "loss",
                    "value": np.mean(train_losses)
                }],
                "global_step":
                step
            })

            train_accuracies = list()
            train_losses = list()

        if valid_path or test_path:
            if (step + 1) % evaluate_every == 0:
                for path, writer, set_type, set_data in zip(
                    [valid_path, test_path], [valid_writer, test_writer],
                    ["valid", "test"], [valid_data_dict, test_data_dict]):
                    if path:
                        if not arsc_format:
                            set_results = relation_net.test_step(
                                data_dict=set_data,
                                n_support=n_support,
                                n_query=n_query,
                                n_classes=n_classes,
                                n_episodes=n_test_episodes)
                        else:
                            set_results = relation_net.test_step_ARSC(
                                data_path=data_path,
                                n_episodes=n_test_episodes,
                                set_type={
                                    "valid": "dev",
                                    "test": "test"
                                }[set_type])
                        writer.add_scalar(tag="loss",
                                          scalar_value=set_results["loss"],
                                          global_step=step)
                        writer.add_scalar(tag="accuracy",
                                          scalar_value=set_results["acc"],
                                          global_step=step)
                        log_dict[set_type].append({
                            "metrics": [
                                {"tag": "accuracy", "value": set_results["acc"]},
                                {"tag": "loss", "value": set_results["loss"]},
                            ],
                            "global_step": step,
                        })

                        logger.info(
                            f"{set_type} | loss: {set_results['loss']:.4f} | acc: {set_results['acc']:.4f}"
                        )
                        if set_type == "valid":
                            if set_results["acc"] > best_valid_acc:
                                best_valid_acc = set_results["acc"]
                                n_eval_since_last_best = 0
                                logger.info(f"Better eval results!")
                            else:
                                n_eval_since_last_best += 1
                                logger.info(
                                    f"Worse eval results ({n_eval_since_last_best}/{early_stop})"
                                )

                if early_stop and n_eval_since_last_best >= early_stop:
                    logger.warning(f"Early-stopping.")
                    break
    with open(os.path.join(output_path, "metrics.json"), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)
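
# Usage sketch for a 5-way 5-shot setup (hypothetical paths; expects the
# .jsonl splits produced by process_dataset above):
#
#     run_relation(train_path="data/my_dataset/train.jsonl",
#                  model_name_or_path="bert-base-uncased",
#                  n_support=5, n_query=5, n_classes=5,
#                  valid_path="data/my_dataset/valid.jsonl",
#                  test_path="data/my_dataset/test.jsonl",
#                  relation_module_type="base",
#                  early_stop=5)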
Example #3
def run_baseline(train_path: str,
                 model_name_or_path: str,
                 n_support: int,
                 n_classes: int,
                 valid_path: str = None,
                 test_path: str = None,
                 output_path: str = f"runs/{now()}",
                 n_test_episodes: int = 600,
                 log_every: int = 10,
                 n_train_epoch: int = 400,
                 train_batch_size: int = 16,
                 is_pp: bool = False,
                 test_batch_size: int = 4,
                 n_test_iter: int = 100,
                 metric: str = "cosine",
                 arsc_format: bool = False,
                 data_path: str = None):
    if output_path:
        if os.path.exists(output_path) and len(os.listdir(output_path)):
            raise FileExistsError(
                f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    os.makedirs(output_path, exist_ok=True)  # the guard above allows an existing empty dir
    os.makedirs(os.path.join(output_path, "logs/train"))
    train_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/train"),
                                 flush_secs=1, max_queue=1)
    valid_writer = None  # created below if valid_path is given
    test_writer = None  # created below if test_path is given
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"))
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"),
                                     flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"))
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"),
                                    flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        """Group sentences by label into {label: [sentence, ...]}."""
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for val in labels_dict.values():
                random.shuffle(val)
        return labels_dict

    # Load model
    bert = BERTEncoder(model_name_or_path).to(device)
    baseline_net = BaselineNet(encoder=bert, is_pp=is_pp,
                               metric=metric).to(device)

    # Load data
    if not arsc_format:
        train_data = get_jsonl_data(train_path)
        train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        if valid_path:
            valid_data = get_jsonl_data(valid_path)
            valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")
        else:
            valid_data_dict = None

        if test_path:
            test_data = get_jsonl_data(test_path)
            test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")
        else:
            test_data_dict = None

        baseline_net.train_model(data_dict=train_data_dict,
                                 summary_writer=train_writer,
                                 n_epoch=n_train_epoch,
                                 batch_size=train_batch_size,
                                 log_every=log_every)

        # Validation
        if valid_path:
            validation_metrics = baseline_net.test_model(
                data_dict=valid_data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_episodes=n_test_episodes,
                summary_writer=valid_writer,
                n_test_iter=n_test_iter,
                test_batch_size=test_batch_size)
            with open(os.path.join(output_path, 'validation_metrics.json'),
                      "w") as file:
                json.dump(validation_metrics, file, ensure_ascii=False)
        # Test
        if test_path:
            test_metrics = baseline_net.test_model(data_dict=test_data_dict,
                                                   n_support=n_support,
                                                   n_classes=n_classes,
                                                   n_episodes=n_test_episodes,
                                                   summary_writer=test_writer)

            with open(os.path.join(output_path, 'test_metrics.json'),
                      "w") as file:
                json.dump(test_metrics, file, ensure_ascii=False)

    else:
        metrics = baseline_net.run_ARSC(train_summary_writer=train_writer,
                                        valid_summary_writer=valid_writer,
                                        test_summary_writer=test_writer,
                                        n_episodes=1000,
                                        train_eval_every=50,
                                        n_train_iter=50,
                                        n_test_iter=200,
                                        test_eval_every=25,
                                        data_path=data_path)
        with open(os.path.join(output_path, 'baseline_metrics.json'),
                  "w") as file:
            json.dump(metrics, file, ensure_ascii=False)
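
# Usage sketch (hypothetical paths; mirrors the run_relation call above). The
# baseline trains conventionally on the train classes (train_model), then
# evaluates episodically on the held-out classes (test_model); is_pp
# presumably toggles the Baseline++ variant:
#
#     run_baseline(train_path="data/my_dataset/train.jsonl",
#                  model_name_or_path="bert-base-uncased",
#                  n_support=5, n_classes=5,
#                  valid_path="data/my_dataset/valid.jsonl",
#                  test_path="data/my_dataset/test.jsonl",
#                  is_pp=False, metric="cosine")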