Example #1
def main(_run: Run, _config):
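    # Sacred injects `_run` and `_config`; Params (presumably a typed config
    # wrapper) turns the raw config dict into an attribute-style object.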
    p = Params(**_config)

    agent = create_agent(p)
    # load_agent(761, 200, agent)
    task = create_task(p)
    inner_loop = LearningLoop()

    n_experts = p.n_experts

    for epoch in tqdm(range(1, p.epochs + 1)):
        # with torch.autograd.detect_anomaly():
        should_save_images = epoch % p.image_save_period == 0
        observer = MultiObserver() if should_save_images else None

        # if epoch == 10000:
        #     task = create_task(p, 0.5)
        #
        # if epoch == 20000:
        #     task = create_task(p, 0.01)

        # if epoch % 100 == 0:
        #     task_size = random.randint(n_experts, n_experts + 10)
        #     p.task_size = task_size
        #     task = create_task(p)

        # reset_ids = epoch > 10000 or (epoch - 1) % 2000 == 0
        reset_ids = True
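        # One optimizer step per epoch: zero grads, roll out the agent on the
        # task, backprop the rollout loss, then step.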
        agent.optim.zero_grad()
        agent.init_rollout(p.batch_size, n_experts, SearchAgentInitRolloutParams(reset_ids))

        rollout_size = p.rollout_size
        # if epoch > 10000:
        #     rollout_size = rollout_size + random.randint(-5, 5)

        err = inner_loop.train_fixed_steps(agent, task, rollout_size, p.learning_exp_decay,
                                           p.learning_rollout_steps_clip, observer)
        err.backward()

        _run.log_scalar('loss', err.item())

        # save grads
        # if observer is not None:
        #     params = [p for p in agent.parameters() if p.grad is not None]
        #     for i, param in enumerate(params):
        #         observer.current.add_tensor(f'params_{i}', param.detach().cpu())
        #         observer.current.add_tensor(f'params-grads_{i}', param.grad.detach().cpu())

        agent.optim.step()

        if observer is not None:
            # print(f'Saving plots ep: {epoch}')
            tensors = [o.tensors_as_dict() for o in observer.observers]
            sacred_writer.save_tensor(tensors, 'tensors', epoch)

        # log targets
        if epoch % p.save_period == 0:
            sacred_writer.save_model(agent, 'agent', epoch)
Example #2
def test(model_filename: str,
         test_corpus: str,
         window_size: int = 5,
         _run: Run = None,
         _log: logger = None):
    _run.add_resource(test_corpus)
    _run.add_resource(f'{model_filename}.pkl')
    test_sents, _ = get_tagged_sents_and_words(test_corpus)

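    # sent2features presumably builds one feature dict per token from a
    # window_size-token context; sent2labels extracts the gold tag column.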
    X_test = [sent2features(s, window_size) for s in test_sents]
    y_test = [sent2labels(s) for s in test_sents]

    # sklearn-crfsuite reopens the binary model file written by crfsuite at
    # train time (the .pkl added above holds only the pickled Python wrapper)
    _log.info(f'load model from: {model_filename}')
    crf = sklearn_crfsuite.CRF(model_filename=model_filename)

    y_pred = crf.predict(X_test)
    overall, by_type = evaluate(y_test, y_pred)
    _run.info['overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info['overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info['overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for key in sorted(by_type.keys()):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')
Example #3
def test(test_corpus: str, model_output: str,
         col_ref: int = 0, col_hyp: int = 0,
         _run: Run = None, _log: logger = None):
    test_sents, _ = get_tagged_sents_and_words(test_corpus)
    _log.info(f'num reference sentences: {len(test_sents)}')
    y_test = [sent2labels_colmap(s, col=int(col_ref)) for s in test_sents]

    yout_sents, _ = get_tagged_sents_and_words(model_output)
    _log.info(f'num hypothesis sentences: {len(yout_sents)}')
    y_pred = [sent2labels_colmap(s, col=int(col_hyp)) for s in yout_sents]

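    # Sentence counts must match for evaluate(); on a mismatch, dump aligned
    # pairs so the offending sentence can be located.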
    if len(y_test) != len(y_pred):
        for i, j in zip_longest(y_test, y_pred):
            print(i, j)

    overall, by_type = evaluate(y_test, y_pred)
    _log.info(f'overall: {overall}')
    _log.info(f'by_type: {by_type}')

    _run.info['overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info['overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info['overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for key in sorted(by_type.keys()):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')
Example #4
def train(train_corpus: str,
          dev_corpus: str,
          c1: float = 0.0,
          c2: float = 0.0,
          algorithm: str = 'lbfgs',
          max_iterations: int = 100,
          all_possible_transitions: bool = False,
          window_size: int = 1,
          model_filename: str = None,
          _run: Run = None,
          _log: logger = None):
    """
    running crf experiment
    """
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    train_sents, _ = get_tagged_sents_and_words(train_corpus)
    dev_sents, _ = get_tagged_sents_and_words(dev_corpus)

    X_train = [sent2features(s, window_size) for s in train_sents]
    y_train = [sent2labels(s) for s in train_sents]

    X_dev = [sent2features(s, window_size) for s in dev_sents]
    y_dev = [sent2labels(s) for s in dev_sents]

    crf = sklearn_crfsuite.CRF(
        algorithm=algorithm,
        c1=c1,
        c2=c2,
        max_iterations=max_iterations,
        all_possible_transitions=all_possible_transitions,
        model_filename=model_filename,
    )

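    # crfsuite writes the trained binary model to `model_filename` during fit;
    # the joblib dump below additionally pickles the Python wrapper object.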
    crf.fit(X_train, y_train)
    y_pred = crf.predict(X_dev)
    overall, by_type = evaluate(y_dev, y_pred)
    _run.info['overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info['overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info['overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for key in sorted(by_type.keys()):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')
    if model_filename is not None:
        _log.info(f'saving to: {model_filename}.pkl')
        joblib.dump(crf, f'{model_filename}.pkl')
        _run.add_artifact(f'{model_filename}.pkl')
Example #5
def eval_link_prediction(model,
                         triples_loader,
                         text_dataset,
                         entities,
                         epoch,
                         emb_batch_size,
                         _run: Run,
                         _log: Logger,
                         prefix='',
                         max_num_batches=None,
                         filtering_graph=None,
                         new_entities=None,
                         return_embeddings=False):
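    """Rank every known entity as a head/tail replacement for each evaluation
    triple and report MRR and hits@k, optionally with filtered metrics and
    per-position / per-relation-category breakdowns."""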
    compute_filtered = filtering_graph is not None
    mrr_by_position = torch.zeros(3, dtype=torch.float).to(device)
    mrr_pos_counts = torch.zeros_like(mrr_by_position)

    rel_categories = triples_loader.dataset.rel_categories.to(device)
    mrr_by_category = torch.zeros([2, 4], dtype=torch.float).to(device)
    mrr_cat_count = torch.zeros([1, 4], dtype=torch.float).to(device)

    hit_positions = [1, 2, 3, 5, 8]
    hits_at_k = {pos: 0.0 for pos in hit_positions}
    mrr = 0.0
    mrr_filt = 0.0
    hits_at_k_filt = {pos: 0.0 for pos in hit_positions}

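    # Multi-GPU models are wrapped in DataParallel; unwrap to reach attributes
    # like rel_emb, score_fn, and dim directly.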
    if device != torch.device('cpu'):
        model = model.module

    if isinstance(model, models.InductiveLinkPrediction):
        num_entities = entities.shape[0]
        if compute_filtered:
            max_ent_id = max(filtering_graph.nodes)
        else:
            max_ent_id = entities.max()
        ent2idx = utils.make_ent2idx(entities, max_ent_id)
    else:
        # Transductive models have a lookup table of embeddings
        num_entities = model.ent_emb.num_embeddings
        ent2idx = torch.arange(num_entities)
        entities = ent2idx

    # Create embedding lookup table for evaluation
    ent_emb = torch.zeros((num_entities, model.dim),
                          dtype=torch.float,
                          device=device)
    idx = 0
    num_iters = np.ceil(num_entities / emb_batch_size)
    iters_count = 0
    while idx < num_entities:
        # Get a batch of entity IDs and encode them
        batch_ents = entities[idx:idx + emb_batch_size]

        if isinstance(model, models.InductiveLinkPrediction):
            # Encode with entity descriptions
            data = text_dataset.get_entity_description(batch_ents)
            text_tok, text_mask, text_len = data
            batch_emb = model(
                text_tok.unsqueeze(1).to(device),
                text_mask.unsqueeze(1).to(device))
        else:
            # Encode from lookup table
            batch_emb = model(batch_ents)

        ent_emb[idx:idx + batch_ents.shape[0]] = batch_emb

        iters_count += 1
        if iters_count % np.ceil(0.2 * num_iters) == 0:
            _log.info(f'[{idx + batch_ents.shape[0]:,}/{num_entities:,}]')

        idx += emb_batch_size

    ent_emb = ent_emb.unsqueeze(0)
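    # ent_emb is now (1, num_entities, dim); the leading dim lets score_fn
    # broadcast a batch of queries against all entities at once.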

    batch_count = 0
    _log.info('Computing metrics on set of triples')
    total = len(triples_loader) if max_num_batches is None else max_num_batches
    for i, triples in enumerate(triples_loader):
        if max_num_batches is not None and i == max_num_batches:
            break

        heads, tails, rels = torch.chunk(triples, chunks=3, dim=1)
        # Map entity IDs to positions in ent_emb
        heads = ent2idx[heads].to(device)
        tails = ent2idx[tails].to(device)

        assert heads.min() >= 0
        assert tails.min() >= 0

        # Embed triple
        head_embs = ent_emb.squeeze()[heads]
        tail_embs = ent_emb.squeeze()[tails]
        rel_embs = model.rel_emb(rels.to(device))

        # Score all possible heads and tails
        heads_predictions = model.score_fn(ent_emb, tail_embs, rel_embs)
        tails_predictions = model.score_fn(head_embs, ent_emb, rel_embs)

        pred_ents = torch.cat((heads_predictions, tails_predictions))
        true_ents = torch.cat((heads, tails))

        hits = utils.hit_at_k(pred_ents, true_ents, hit_positions)
        for j, h in enumerate(hits):
            hits_at_k[hit_positions[j]] += h
        mrr += utils.mrr(pred_ents, true_ents).mean().item()

        if compute_filtered:
            filters = utils.get_triple_filters(triples, filtering_graph,
                                               num_entities, ent2idx)
            heads_filter, tails_filter = filters
            # Filter entities by assigning them the lowest score in the batch
            filter_mask = torch.cat((heads_filter, tails_filter)).to(device)
            pred_ents[filter_mask] = pred_ents.min() - 1.0
            hits_filt = utils.hit_at_k(pred_ents, true_ents, hit_positions)
            for j, h in enumerate(hits_filt):
                hits_at_k_filt[hit_positions[j]] += h
            mrr_filt_per_triple = utils.mrr(pred_ents, true_ents)
            mrr_filt += mrr_filt_per_triple.mean().item()

            if new_entities is not None:
                by_position = utils.split_by_new_position(
                    triples, mrr_filt_per_triple, new_entities)
                batch_mrr_by_position, batch_mrr_pos_counts = by_position
                mrr_by_position += batch_mrr_by_position
                mrr_pos_counts += batch_mrr_pos_counts

            if triples_loader.dataset.has_rel_categories:
                by_category = utils.split_by_category(triples,
                                                      mrr_filt_per_triple,
                                                      rel_categories)
                batch_mrr_by_cat, batch_mrr_cat_count = by_category
                mrr_by_category += batch_mrr_by_cat
                mrr_cat_count += batch_mrr_cat_count

        batch_count += 1
        if (i + 1) % max(1, int(0.2 * total)) == 0:
            _log.info(f'[{i + 1:,}/{total:,}]')

    for hits_dict in (hits_at_k, hits_at_k_filt):
        for k in hits_dict:
            hits_dict[k] /= batch_count

    mrr = mrr / batch_count
    mrr_filt = mrr_filt / batch_count

    log_str = f'{prefix} mrr: {mrr:.4f}  '
    _run.log_scalar(f'{prefix}_mrr', mrr, epoch)
    for k, value in hits_at_k.items():
        log_str += f'hits@{k}: {value:.4f}  '
        _run.log_scalar(f'{prefix}_hits@{k}', value, epoch)

    if compute_filtered:
        log_str += f'mrr_filt: {mrr_filt:.4f}  '
        _run.log_scalar(f'{prefix}_mrr_filt', mrr_filt, epoch)
        for k, value in hits_at_k_filt.items():
            log_str += f'hits@{k}_filt: {value:.4f}  '
            _run.log_scalar(f'{prefix}_hits@{k}_filt', value, epoch)

    _log.info(log_str)

    if new_entities is not None and compute_filtered:
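        # A zero count would divide by zero; clamp to 1 so empty positions
        # keep an MRR of 0.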
        mrr_pos_counts[mrr_pos_counts < 1.0] = 1.0
        mrr_by_position = mrr_by_position / mrr_pos_counts
        log_str = ''
        for i, t in enumerate(
            (f'{prefix}_mrr_filt_both_new', f'{prefix}_mrr_filt_head_new',
             f'{prefix}_mrr_filt_tail_new')):
            value = mrr_by_position[i].item()
            log_str += f'{t}: {value:.4f}  '
            _run.log_scalar(t, value, epoch)
        _log.info(log_str)

    if compute_filtered and triples_loader.dataset.has_rel_categories:
        mrr_cat_count[mrr_cat_count < 1.0] = 1.0
        mrr_by_category = mrr_by_category / mrr_cat_count

        for i, case in enumerate(['pred_head', 'pred_tail']):
            log_str = f'{case} '
            for cat, cat_id in CATEGORY_IDS.items():
                log_str += f'{cat}_mrr: {mrr_by_category[i, cat_id]:.4f}  '
            _log.info(log_str)

    if return_embeddings:
        out = (mrr, ent_emb)
    else:
        out = (mrr, None)

    return out
Example #6
def embedding_generation(dataset, dim, model, rel_model, loss_fn, encoder_name,
                         regularizer, max_len, num_negatives, batch_size,
                         emb_batch_size, max_epochs, checkpoint,
                         use_cached_text, _run: Run, _log: Logger):
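    """Encode all entities of the dataset with a pre-trained text encoder and
    save the resulting embedding table to OUT_PATH."""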
    drop_stopwords = model in {
        'bert-bow', 'bert-dkrl', 'glove-bow', 'glove-dkrl'
    }
    # the full converted KG (all triples) is used as input
    triples_file = f'data/{dataset}/all-triples.tsv'

    if device != torch.device('cpu'):
        num_devices = torch.cuda.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(f'Batch size ({batch_size}) must be a multiple of'
                             f' the number of CUDA devices ({num_devices})')
        _log.info(f'CUDA devices used: {num_devices}')
    else:
        num_devices = 1
        _log.info('Training on CPU')

    if model == 'transductive':
        train_data = GraphDataset(triples_file,
                                  num_negatives,
                                  write_maps_file=True,
                                  num_devices=num_devices)
    else:
        if model.startswith('bert') or model == 'blp':
            tokenizer = BertTokenizer.from_pretrained(encoder_name)
        else:
            tokenizer = GloVeTokenizer('data/glove/glove.6B.300d-maps.pt')

        train_data = TextGraphDataset(triples_file,
                                      num_negatives,
                                      max_len,
                                      tokenizer,
                                      drop_stopwords,
                                      write_maps_file=True,
                                      use_cached_text=use_cached_text,
                                      num_devices=num_devices)

    # No DataLoader or filtering graph is needed here: the encoder is run once
    # over all entities, with no training and no filtered metrics.
    train_ent = set(train_data.entities.tolist())

    _run.log_scalar('num_train_entities', len(train_ent))

    train_ent = torch.tensor(list(train_ent))

    model = utils.get_model(model, dim, rel_model, loss_fn, len(train_ent),
                            train_data.num_rels, encoder_name, regularizer)

    if device != torch.device('cpu'):
        model = torch.nn.DataParallel(model).to(device)

    dataset_type = str(dataset).split("_")[-1]
    _log.info(f'dataset type: {dataset_type}')

    # Pick the encoder checkpoint matching the dataset type
    if dataset_type == "questions":
        model.load_state_dict(torch.load("models/model-questions.pt"))
    else:
        model.load_state_dict(torch.load("models/model-entities.pt"))
    _log.info('Generating embeddings for all training entities')
    ent_emb = embedding(model,
                        train_data,
                        train_ent,
                        emb_batch_size,
                        filtering_graph=None)

    # Save final entity embeddings obtained with trained encoder
    torch.save(ent_emb, osp.join(OUT_PATH, f'ent_emb-{_run._id}.pt'))
    torch.save(train_ent, osp.join(OUT_PATH, f'ents-{_run._id}.pt'))
Example #7
def link_prediction(dataset, inductive, dim, model, rel_model, loss_fn,
                    encoder_name, regularizer, max_len, num_negatives, lr,
                    use_scheduler, batch_size, emb_batch_size, eval_batch_size,
                    max_epochs, checkpoint, use_cached_text, _run: Run,
                    _log: Logger):
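    """Train a link prediction model, validate once per epoch on raw MRR, keep
    the best checkpoint, and evaluate it on the test set with filtered
    metrics."""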
    drop_stopwords = model in {
        'bert-bow', 'bert-dkrl', 'glove-bow', 'glove-dkrl'
    }

    prefix = 'ind-' if inductive and model != 'transductive' else ''
    triples_file = f'data/{dataset}/{prefix}train.tsv'

    if device != torch.device('cpu'):
        num_devices = torch.cuda.device_count()
        if batch_size % num_devices != 0:
            raise ValueError(f'Batch size ({batch_size}) must be a multiple of'
                             f' the number of CUDA devices ({num_devices})')
        _log.info(f'CUDA devices used: {num_devices}')
    else:
        num_devices = 1
        _log.info('Training on CPU')

    if model == 'transductive':
        train_data = GraphDataset(triples_file,
                                  num_negatives,
                                  write_maps_file=True,
                                  num_devices=num_devices)
    else:
        if model.startswith('bert') or model == 'blp':
            tokenizer = BertTokenizer.from_pretrained(encoder_name)
        else:
            tokenizer = GloVeTokenizer('data/glove/glove.6B.300d-maps.pt')

        train_data = TextGraphDataset(triples_file,
                                      num_negatives,
                                      max_len,
                                      tokenizer,
                                      drop_stopwords,
                                      write_maps_file=True,
                                      use_cached_text=use_cached_text,
                                      num_devices=num_devices)

    train_loader = DataLoader(train_data,
                              batch_size,
                              shuffle=True,
                              collate_fn=train_data.collate_fn,
                              num_workers=0,
                              drop_last=True)

    train_eval_loader = DataLoader(train_data, eval_batch_size)

    valid_data = GraphDataset(f'data/{dataset}/{prefix}dev.tsv')
    valid_loader = DataLoader(valid_data, eval_batch_size)

    test_data = GraphDataset(f'data/{dataset}/{prefix}test.tsv')
    test_loader = DataLoader(test_data, eval_batch_size)

    # Build graph with all triples to compute filtered metrics
    if dataset != 'Wikidata5M':
        graph = nx.MultiDiGraph()
        all_triples = torch.cat(
            (train_data.triples, valid_data.triples, test_data.triples))
        graph.add_weighted_edges_from(all_triples.tolist())

        train_ent = set(train_data.entities.tolist())
        train_val_ent = set(valid_data.entities.tolist()).union(train_ent)
        train_val_test_ent = set(
            test_data.entities.tolist()).union(train_val_ent)
        val_new_ents = train_val_ent.difference(train_ent)
        test_new_ents = train_val_test_ent.difference(train_val_ent)
    else:
        graph = None

        train_ent = set(train_data.entities.tolist())
        train_val_ent = set(valid_data.entities.tolist())
        train_val_test_ent = set(test_data.entities.tolist())
        val_new_ents = test_new_ents = None

    _run.log_scalar('num_train_entities', len(train_ent))

    train_ent = torch.tensor(list(train_ent))
    train_val_ent = torch.tensor(list(train_val_ent))
    train_val_test_ent = torch.tensor(list(train_val_test_ent))

    model = utils.get_model(model, dim, rel_model, loss_fn,
                            len(train_val_test_ent), train_data.num_rels,
                            encoder_name, regularizer)
    if checkpoint is not None:
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'))

    if device != torch.device('cpu'):
        model = torch.nn.DataParallel(model).to(device)

    optimizer = Adam(model.parameters(), lr=lr)
    total_steps = len(train_loader) * max_epochs
    if use_scheduler:
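        # Linear LR warmup over the first 20% of updates, then linear decay.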
        warmup = int(0.2 * total_steps)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=warmup, num_training_steps=total_steps)
    best_valid_mrr = 0.0
    checkpoint_file = osp.join(OUT_PATH, f'model-{_run._id}.pt')
    for epoch in range(1, max_epochs + 1):
        train_loss = 0
        for step, data in enumerate(train_loader):
            loss = model(*data).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if use_scheduler:
                scheduler.step()

            train_loss += loss.item()

            if step % max(1, int(0.05 * len(train_loader))) == 0:
                _log.info(f'Epoch {epoch}/{max_epochs} '
                          f'[{step}/{len(train_loader)}]: {loss.item():.6f}')
                _run.log_scalar('batch_loss', loss.item())

        _run.log_scalar('train_loss', train_loss / len(train_loader), epoch)

        if dataset != 'Wikidata5M':
            _log.info('Evaluating on sample of training set')
            eval_link_prediction(model,
                                 train_eval_loader,
                                 train_data,
                                 train_ent,
                                 epoch,
                                 emb_batch_size,
                                 prefix='train',
                                 max_num_batches=len(valid_loader))

        _log.info('Evaluating on validation set')
        val_mrr, _ = eval_link_prediction(model,
                                          valid_loader,
                                          train_data,
                                          train_val_ent,
                                          epoch,
                                          emb_batch_size,
                                          prefix='valid')

        # Keep checkpoint of best performing model (based on raw MRR)
        if val_mrr > best_valid_mrr:
            best_valid_mrr = val_mrr
            torch.save(model.state_dict(), checkpoint_file)

    # Evaluate with best performing checkpoint
    if max_epochs > 0:
        model.load_state_dict(torch.load(checkpoint_file))

    if dataset == 'Wikidata5M':
        # The full Wikidata5M graph was skipped above; rebuild the filtering
        # graph from the test triples only.
        graph = nx.MultiDiGraph()
        graph.add_weighted_edges_from(test_data.triples.tolist())

    _log.info('Evaluating on test set')
    _, ent_emb = eval_link_prediction(model,
                                      test_loader,
                                      train_data,
                                      train_val_test_ent,
                                      max_epochs + 1,
                                      emb_batch_size,
                                      prefix='test',
                                      filtering_graph=graph,
                                      new_entities=test_new_ents,
                                      return_embeddings=True)

    # Save final entity embeddings obtained with trained encoder
    torch.save(ent_emb, osp.join(OUT_PATH, f'ent_emb-{_run._id}.pt'))
    torch.save(train_val_test_ent, osp.join(OUT_PATH, f'ents-{_run._id}.pt'))
Example #8
def train(train_corpus: str,
          dev_corpus: str,
          pacrf: str,
          model_filename: str,
          labels: List,
          c1: float = 0.0,
          c2: float = 1.0,
          algorithm: str = 'lbfgs',
          max_iterations: int = None,
          all_possible_transitions: bool = False,
          window_size: int = 0,
          _run: Run = None,
          _log: logger = None):
    """
    running crf experiment
    """
    _run.add_resource(train_corpus)
    _run.add_resource(dev_corpus)
    train_sents, _ = get_tagged_sents_and_words(train_corpus)
    dev_sents, _ = get_tagged_sents_and_words(dev_corpus)

    tmp_train = tempfile.NamedTemporaryFile(mode='w+')
    print_corpus(train_sents, labels, tmp_train, window_size=window_size)
    tmp_train.flush()  # ensure features hit disk before crfsuite reads the file

    y_dev = [sent2labels_colmap(s, col=1) for s in dev_sents]

    tmp_dev = tempfile.NamedTemporaryFile(mode='w+')
    print_corpus(dev_sents, labels, tmp_dev, window_size=window_size)
    tmp_dev.flush()

    # Train by shelling out to the partial-crf binary; capturing its output
    # assumes PIPE is imported alongside Popen (from subprocess import Popen, PIPE)
    crfsuite_proc = Popen([pacrf, "learn", "-m", model_filename, "-a", algorithm,
                           "-p", f"c1={c1}", "-p", f"c2={c2}", tmp_train.name],
                          stdout=PIPE, stderr=PIPE)
    out, err = crfsuite_proc.communicate()
    _log.info(out.decode())
    if err:
        _log.warning(err.decode())

    tmp_train.close()

    tmp_pred = tempfile.NamedTemporaryFile(mode='w+')

    # Tag the dev set with the trained model, redirecting stdout to tmp_pred
    # (a shell-style ">" redirect does not work inside a Popen argument list)
    with open(tmp_pred.name, 'w') as pred_file:
        tag_proc = Popen([pacrf, "tag", "-m", model_filename, tmp_dev.name],
                         stdout=pred_file)
        tag_proc.communicate()

    _run.add_artifact(model_filename)

    y_pred_sents, _ = get_tagged_sents_and_words(tmp_pred.name)
    y_pred = [sent2labels_colmap(s, col=0) for s in y_pred_sents]

    overall, by_type = evaluate(y_dev, y_pred)
    tmp_pred.close()
    tmp_dev.close()

    _run.info['overall_f1'] = overall.f1_score
    _run.log_scalar('overall_f1', overall.f1_score)
    _run.info['overall_precision'] = overall.precision
    _run.log_scalar('overall_precision', overall.precision)
    _run.info['overall_recall'] = overall.recall
    _run.log_scalar('overall_recall', overall.recall)
    _log.info(f'Overall F1 score: {overall.f1_score}')
    for key in sorted(by_type.keys()):
        for metric_key in by_type[key]._fields:
            metric_val = getattr(by_type[key], metric_key)
            _run.info[f'{key}-{metric_key}'] = metric_val
            _run.log_scalar(f'{key}-{metric_key}', metric_val)
            _log.info(f'{key}-{metric_key}: {metric_val}')