def process_dataset(name):
    # Fix the seed before any shuffling so both shuffles below are reproducible
    # (previously the seed was set after the first shuffle, making the sample
    # order non-deterministic).
    random.seed(42)

    full = get_jsonl_data(f"data/{name}/full.jsonl")
    random.shuffle(full)

    # Split *labels* (not samples) into thirds, so train/valid/test contain
    # disjoint classes, as required for few-shot evaluation.
    labels = sorted({d["label"] for d in full})
    n_labels = len(labels)
    random.shuffle(labels)
    train_labels = labels[: n_labels // 3]
    valid_labels = labels[n_labels // 3 : 2 * n_labels // 3]
    test_labels = labels[2 * n_labels // 3 :]

    write_jsonl_data([d for d in full if d["label"] in train_labels], f"data/{name}/train.jsonl", force=True)
    write_txt_data(train_labels, f"data/{name}/labels.train.txt")

    write_jsonl_data([d for d in full if d["label"] in valid_labels], f"data/{name}/valid.jsonl", force=True)
    write_txt_data(valid_labels, f"data/{name}/labels.valid.txt")

    write_jsonl_data([d for d in full if d["label"] in test_labels], f"data/{name}/test.jsonl", force=True)
    write_txt_data(test_labels, f"data/{name}/labels.test.txt")
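
# Usage sketch for process_dataset (the dataset name "OOS" below is an
# illustrative assumption, not a name guaranteed by this repo): given
# data/OOS/full.jsonl with one JSON object per line of the form
#   {"label": "balance", "sentence": "what is my account balance"}
# the call below writes class-disjoint train/valid/test JSONL splits plus
# the matching labels.{train,valid,test}.txt files under data/OOS/:
#
#     process_dataset("OOS")
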
def run_relation(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_query: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f"runs/{now()}",
        max_iter: int = 10000,
        evaluate_every: int = 100,
        early_stop: int = None,
        n_test_episodes: int = 1000,
        log_every: int = 10,
        relation_module_type: str = "base",
        ntl_n_slices: int = 100,
        arsc_format: bool = False,
        data_path: str = None):
    if output_path:
        if os.path.exists(output_path) and os.listdir(output_path):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    os.makedirs(output_path)
    os.makedirs(os.path.join(output_path, "logs/train"))
    train_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer = None
    test_writer = None
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"))
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"))
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        # Group sentences by label: {label: [sentence, ...]}
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for val in labels_dict.values():
                random.shuffle(val)
        return labels_dict

    # Load model (move the relation network to the same device as its encoder)
    bert = BERTEncoder(model_name_or_path).to(device)
    relation_net = RelationNet(encoder=bert, relation_module_type=relation_module_type, ntl_n_slices=ntl_n_slices).to(device)
    optimizer = torch.optim.Adam(relation_net.parameters(), lr=2e-5)

    # Load data
    if not arsc_format:
        train_data = get_jsonl_data(train_path)
        train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        if valid_path:
            valid_data = get_jsonl_data(valid_path)
            valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")
        else:
            valid_data_dict = None

        if test_path:
            test_data = get_jsonl_data(test_path)
            test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")
        else:
            test_data_dict = None
    else:
        # In ARSC format, episodes are sampled directly from data_path.
        train_data_dict = None
        valid_data_dict = None
        test_data_dict = None

    train_accuracies = list()
    train_losses = list()
    n_eval_since_last_best = 0
    best_valid_acc = 0.0

    for step in range(max_iter):
        if not arsc_format:
            loss, loss_dict = relation_net.train_step(
                optimizer=optimizer,
                data_dict=train_data_dict,
                n_support=n_support,
                n_query=n_query,
                n_classes=n_classes)
        else:
            loss, loss_dict = relation_net.train_step_ARSC(optimizer=optimizer, data_path=data_path)

        train_accuracies.append(loss_dict["acc"])
        train_losses.append(loss_dict["loss"])

        # Logging
        if (step + 1) % log_every == 0:
            train_writer.add_scalar(tag="loss", scalar_value=np.mean(train_losses), global_step=step)
            train_writer.add_scalar(tag="accuracy", scalar_value=np.mean(train_accuracies), global_step=step)
            logger.info(f"train | loss: {np.mean(train_losses):.4f} | acc: {np.mean(train_accuracies):.4f}")
            log_dict["train"].append({
                "metrics": [
                    {"tag": "accuracy", "value": np.mean(train_accuracies)},
                    {"tag": "loss", "value": np.mean(train_losses)},
                ],
                "global_step": step
            })
            train_accuracies = list()
            train_losses = list()

        if valid_path or test_path:
            if (step + 1) % evaluate_every == 0:
                for path, writer, set_type, set_data in zip(
                        [valid_path, test_path],
                        [valid_writer, test_writer],
                        ["valid", "test"],
                        [valid_data_dict, test_data_dict]):
                    if path:
                        if not arsc_format:
                            set_results = relation_net.test_step(
                                data_dict=set_data,
                                n_support=n_support,
                                n_query=n_query,
                                n_classes=n_classes,
                                n_episodes=n_test_episodes)
                        else:
                            set_results = relation_net.test_step_ARSC(
                                data_path=data_path,
                                n_episodes=n_test_episodes,
                                set_type={"valid": "dev", "test": "test"}[set_type])

                        writer.add_scalar(tag="loss", scalar_value=set_results["loss"], global_step=step)
                        writer.add_scalar(tag="accuracy", scalar_value=set_results["acc"], global_step=step)
                        log_dict[set_type].append({
                            "metrics": [
                                {"tag": "accuracy", "value": set_results["acc"]},
                                {"tag": "loss", "value": set_results["loss"]},
                            ],
                            "global_step": step
                        })
                        logger.info(f"{set_type} | loss: {set_results['loss']:.4f} | acc: {set_results['acc']:.4f}")

                        # Early stopping on validation accuracy
                        if set_type == "valid":
                            if set_results["acc"] > best_valid_acc:
                                best_valid_acc = set_results["acc"]
                                n_eval_since_last_best = 0
                                logger.info("Better eval results!")
                            else:
                                n_eval_since_last_best += 1
                                logger.info(f"Worse eval results ({n_eval_since_last_best}/{early_stop})")

        if early_stop and n_eval_since_last_best >= early_stop:
            logger.warning("Early-stopping.")
            break

    with open(os.path.join(output_path, "metrics.json"), "w") as file:
        json.dump(log_dict, file, ensure_ascii=False)
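
# Usage sketch for run_relation (paths and hyper-parameter values below are
# illustrative assumptions): trains a RelationNet on 5-way episodes with 5
# support and 5 query examples per class, evaluating on the validation and
# test splits every 100 steps and early-stopping after 5 stale evaluations:
#
#     run_relation(
#         train_path="data/OOS/train.jsonl",
#         valid_path="data/OOS/valid.jsonl",
#         test_path="data/OOS/test.jsonl",
#         model_name_or_path="bert-base-cased",
#         n_support=5,
#         n_query=5,
#         n_classes=5,
#         evaluate_every=100,
#         early_stop=5,
#         relation_module_type="base")
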
def run_baseline(
        train_path: str,
        model_name_or_path: str,
        n_support: int,
        n_classes: int,
        valid_path: str = None,
        test_path: str = None,
        output_path: str = f"runs/{now()}",
        n_test_episodes: int = 600,
        log_every: int = 10,
        n_train_epoch: int = 400,
        train_batch_size: int = 16,
        is_pp: bool = False,
        test_batch_size: int = 4,
        n_test_iter: int = 100,
        metric: str = "cosine",
        arsc_format: bool = False,
        data_path: str = None):
    if output_path:
        if os.path.exists(output_path) and os.listdir(output_path):
            raise FileExistsError(f"Output path {output_path} already exists. Exiting.")

    # --------------------
    # Creating Log Writers
    # --------------------
    os.makedirs(output_path)
    os.makedirs(os.path.join(output_path, "logs/train"))
    train_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/train"), flush_secs=1, max_queue=1)
    valid_writer = None
    test_writer = None
    log_dict = dict(train=list())

    if valid_path:
        os.makedirs(os.path.join(output_path, "logs/valid"))
        valid_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/valid"), flush_secs=1, max_queue=1)
        log_dict["valid"] = list()
    if test_path:
        os.makedirs(os.path.join(output_path, "logs/test"))
        test_writer = SummaryWriter(logdir=os.path.join(output_path, "logs/test"), flush_secs=1, max_queue=1)
        log_dict["test"] = list()

    def raw_data_to_labels_dict(data, shuffle=True):
        # Group sentences by label: {label: [sentence, ...]}
        labels_dict = collections.defaultdict(list)
        for item in data:
            labels_dict[item["label"]].append(item["sentence"])
        labels_dict = dict(labels_dict)
        if shuffle:
            for val in labels_dict.values():
                random.shuffle(val)
        return labels_dict

    # Load model
    bert = BERTEncoder(model_name_or_path).to(device)
    baseline_net = BaselineNet(encoder=bert, is_pp=is_pp, metric=metric).to(device)

    # Load data
    if not arsc_format:
        train_data = get_jsonl_data(train_path)
        train_data_dict = raw_data_to_labels_dict(train_data, shuffle=True)
        logger.info(f"train labels: {train_data_dict.keys()}")

        if valid_path:
            valid_data = get_jsonl_data(valid_path)
            valid_data_dict = raw_data_to_labels_dict(valid_data, shuffle=True)
            logger.info(f"valid labels: {valid_data_dict.keys()}")
        else:
            valid_data_dict = None

        if test_path:
            test_data = get_jsonl_data(test_path)
            test_data_dict = raw_data_to_labels_dict(test_data, shuffle=True)
            logger.info(f"test labels: {test_data_dict.keys()}")
        else:
            test_data_dict = None

        baseline_net.train_model(
            data_dict=train_data_dict,
            summary_writer=train_writer,
            n_epoch=n_train_epoch,
            batch_size=train_batch_size,
            log_every=log_every)

        # Validation
        if valid_path:
            validation_metrics = baseline_net.test_model(
                data_dict=valid_data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_episodes=n_test_episodes,
                summary_writer=valid_writer,
                n_test_iter=n_test_iter,
                test_batch_size=test_batch_size)
            with open(os.path.join(output_path, "validation_metrics.json"), "w") as file:
                json.dump(validation_metrics, file, ensure_ascii=False)

        # Test
        if test_path:
            test_metrics = baseline_net.test_model(
                data_dict=test_data_dict,
                n_support=n_support,
                n_classes=n_classes,
                n_episodes=n_test_episodes,
                summary_writer=test_writer)
            with open(os.path.join(output_path, "test_metrics.json"), "w") as file:
                json.dump(test_metrics, file, ensure_ascii=False)
    else:
        # baseline_net.train_model_ARSC(
        #     train_summary_writer=train_writer,
        #     n_episodes=10,
        #     n_train_iter=20
        # )
        # metrics = baseline_net.test_model_ARSC(
        #     n_iter=n_test_iter,
        #     valid_summary_writer=valid_writer,
        #     test_summary_writer=test_writer
        # )
        metrics = baseline_net.run_ARSC(
            train_summary_writer=train_writer,
            valid_summary_writer=valid_writer,
            test_summary_writer=test_writer,
            n_episodes=1000,
            train_eval_every=50,
            n_train_iter=50,
            n_test_iter=200,
            test_eval_every=25,
            data_path=data_path)
        with open(os.path.join(output_path, "baseline_metrics.json"), "w") as file:
            json.dump(metrics, file, ensure_ascii=False)
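
# Usage sketch for run_baseline (paths and values below are illustrative
# assumptions): trains the baseline on the training classes in standard
# mini-batches, then evaluates 5-way few-shot episodes on the held-out
# splits. Setting is_pp=True selects what appears to be the Baseline++
# variant, scored with the chosen similarity metric:
#
#     run_baseline(
#         train_path="data/OOS/train.jsonl",
#         valid_path="data/OOS/valid.jsonl",
#         test_path="data/OOS/test.jsonl",
#         model_name_or_path="bert-base-cased",
#         n_support=5,
#         n_classes=5,
#         is_pp=True,
#         metric="cosine")
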