def get_concat_dataset(data_dir, has_target_edges=False):
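    """Concatenate all .h5 graph files under data_dir into one dataset.

    Files are sorted so the file order (and thus example order) is
    deterministic across runs.
    """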
    files = sorted([
        os.path.join(data_dir, f) for f in os.listdir(data_dir)
        if f.endswith('.h5')
    ])
    datasets = [NarrativeGraphDataset(f, has_target_edges) for f in files]
    return ConcatDataset(datasets)
def _task_pp_discourse_link_v2(model, test_data, results, **kwargs):
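    """Answer one discourse-link question for a single graph.

    The gold head and tail nodes are fixed while every relation type in
    disc_ridx_list is enumerated as the candidate edge label (rows 0-2 of
    the target-edge tensor hold head nid, rtype, tail nid). The
    highest-scoring candidate wins; results record real rtype indices.
    """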
    if kwargs['q'] is None:
        return
    input_edges, correct, choices = kwargs['q']
    _, g_binputs, g_target_idxs, g_nid2rows, _, _ = \
        get_ng_inputs(test_data)
    input_edges = torch.from_numpy(input_edges.astype('float32'))
    disc_ridx_list = kwargs['disc_ridx_list']
    gold_target_edge = choices[correct]

    # create target edge choices
    correct_idx = None
    target_edges = torch.ones((4, len(disc_ridx_list)), dtype=torch.int64)
    target_edges[0, :] = gold_target_edge[0]
    target_edges[2, :] = gold_target_edge[2]
    for i, ridx in enumerate(disc_ridx_list):
        if ridx == gold_target_edge[1]:
            correct_idx = i
        target_edges[1, i] = ridx
    # fail fast if the gold rtype is missing from the candidate list
    assert correct_idx is not None
    target_edges = target_edges.unsqueeze(0)

    n_nodes = g_nid2rows.shape[0]
    gs = get_dgl_graph_list([input_edges], n_nodes)

    batch = {
        'bg': [gs],
        'input_ids': g_binputs[0],
        'input_masks': g_binputs[1],
        'token_type_ids': g_binputs[2],
        'target_idxs': g_target_idxs,
        'nid2rows': [g_nid2rows],
        'n_instances': [g_binputs.shape[1]],
        'target_edges': [target_edges]
    }
    if 'predict' not in results:
        results['predict'] = []
    if 'y' not in results:
        results['y'] = []
    if 'scores' not in results:
        results['scores'] = []
    with torch.no_grad():
        if args.gpu_id != -1:
            batch = NarrativeGraphDataset.to_gpu(batch, args.gpu_id)
        pred_scores, y = model(mode='predict', **batch)
        pred_scores, y = pred_scores.cpu(), y.cpu()

        pred_idx = torch.argmax(pred_scores).item()

        logger.debug('scores={}, pred={}, y={}'.format(pred_scores, pred_idx, correct_idx))

        # record the real rtype index, not the position within disc_ridx_list
        results['predict'].append(disc_ridx_list[pred_idx])
        results['y'].append(disc_ridx_list[correct_idx])
        results['scores'].append(pred_scores.tolist())
def _task_pp_any_next_v2(model, test_data, results, **kwargs):
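    """Answer one any-next question: build one candidate target edge per
    choice (head, rtype, tail) and predict the argmax-scoring choice.
    """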
    if kwargs['q'] is None:
        return
    input_edges, correct, choices = kwargs['q']
    _, g_binputs, g_target_idxs, g_nid2rows, _, _ = \
        get_ng_inputs(test_data)
    input_edges = torch.from_numpy(input_edges.astype('float32'))

    n_choices = len(choices)
    target_edges = torch.ones((4, n_choices), dtype=torch.int64)
    for i in range(n_choices):
        for j in range(3):
            target_edges[j, i] = choices[i][j]
    target_edges = target_edges.unsqueeze(0)

    n_nodes = g_nid2rows.shape[0]
    gs = get_dgl_graph_list([input_edges], n_nodes)

    batch = {
        'bg': [gs],
        'input_ids': g_binputs[0],
        'input_masks': g_binputs[1],
        'token_type_ids': g_binputs[2],
        'target_idxs': g_target_idxs,
        'nid2rows': [g_nid2rows],
        'n_instances': [g_binputs.shape[1]],
        'target_edges': [target_edges]
    }
    with torch.no_grad():
        if args.gpu_id != -1:
            batch = NarrativeGraphDataset.to_gpu(batch, args.gpu_id)
        pred_scores, y = model(mode='predict', **batch)
        pred_scores, y = pred_scores.cpu(), y.cpu()

        pred = torch.argmax(pred_scores)
        if 'predict' not in results:
            results['predict'] = []
        if 'y' not in results:
            results['y'] = []

        logger.debug('scores={}, pred={}, y={}'.format(pred_scores, pred, correct))
        results['predict'].append(pred.item())
        results['y'].append(correct)
def coref_evaluate(local_rank,
                   model,
                   dataloader,
                   gpu,
                   get_prec_recall_f1=False,
                   logger=None):
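    """Binary evaluation for coreference edge prediction.

    Accumulates a single 2x2 confusion matrix over the dataloader at a
    fixed 0.5 threshold and reports precision/recall/F1 on the positive
    class.
    """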
    model.eval()
    cm = np.zeros((2, 2), dtype=np.int64)
    n_examples = np.zeros((2, ), dtype=np.int64)
    with torch.no_grad():
        for batch in tqdm(dataloader,
                          desc='evaluating',
                          disable=(local_rank not in [-1, 0])):
            # to GPU
            if gpu != -1:
                batch = NarrativeGraphDataset.to_gpu(batch, gpu)

            # forward pass
            pred_scores, y = model(mode='predict', **batch)
            y_pred = (pred_scores.cpu() >= 0.5).long()
            y = y.cpu()

            # record
            # pin labels so single-class batches still yield a 2x2 matrix
            cm += confusion_matrix(y, y_pred, labels=[0, 1])

            n_pos = y.sum().item()
            n_examples[1] += n_pos
            n_examples[0] += (y.shape[0] - n_pos)

    # F1 on positive class
    tn, fp, fn, tp = cm.ravel()
    prec = tp / (tp + fp) if tp + fp != 0 else 0.0
    recall = tp / (tp + fn) if tp + fn != 0 else 0.0
    f1 = (2.0 * prec * recall) / (prec + recall) if prec + recall != 0 else 0.0
    if logger:
        logger.info('tp={}, tn={}, fp={}, fn={}'.format(tp, tn, fp, fn))
        logger.info('#pos={}, #neg={}, prec={}, recall={}, f1={}'.format(
            n_examples[1], n_examples[0], prec, recall, f1))

    if get_prec_recall_f1:
        return prec, recall, f1
    return f1
def _task_triplet_classification(model, test_data, results, **kwargs):
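    """Binary triplet classification over one graph's target edges: each
    edge score is thresholded at 0.5, and (prediction, gold) pairs are
    recorded per relation type.
    """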
    # predict for one graph
    count_correct = 0

    # some tasks need target edges read from the pkl file, so we do not
    # use NarrativeGraphDataset here; input preparation is rewritten
    # per task instead
    _, binputs, target_idxs, nid2rows, target_edges, input_edges = \
        get_ng_inputs(test_data)
    n_nodes = nid2rows.shape[0]
    gs = get_dgl_graph_list(input_edges, n_nodes)

    gold_rtypes = []
    for te in target_edges:
        gold_rtypes.append(te[1])
    gold_rtypes = torch.cat(gold_rtypes, dim=0).tolist()

    with torch.no_grad():
        batch = {
            'bg': [gs],
            'input_ids': binputs[0],
            'input_masks': binputs[1],
            'token_type_ids': binputs[2],
            'target_idxs': target_idxs,
            'nid2rows': [nid2rows],
            'n_instances': [binputs.shape[1]],
            'target_edges': [target_edges]
        }
        if args.gpu_id != -1:
            batch = NarrativeGraphDataset.to_gpu(batch, args.gpu_id)
        pred_scores, y = model(mode='predict', **batch)
        pred_scores, y = pred_scores.cpu(), y.cpu()

        thr = 0.5
        y_pred = (pred_scores >= thr).long().tolist()
        y = y.tolist()
        for r, yp, gold in zip(gold_rtypes, y_pred, y):
            if r not in results:
                results[r] = []
            results[r].append((yp, gold))
def prepare_train_dataset(f, local_rank, args):
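    """Build the dataset, sampler, and dataloader for one training .h5 file.

    Multi-GPU training uses a DistributedSampler; single-GPU training uses
    a plain shuffled DataLoader. num_workers stays at 1 because HDF5 file
    handles are not safe to share across worker processes.
    """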
    train_dataset = NarrativeGraphDataset(f)
    if args.n_gpus > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=args.n_gpus,
                                           rank=local_rank)
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,  # fixed to 1 for sampling
            shuffle=False,
            num_workers=1,  # 1 is safe for hdf5
            collate_fn=train_collate,
            sampler=train_sampler)
    else:
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=args.train_batch_size,  # fixed to 1 for sampling
            shuffle=True,
            num_workers=1,  # 1 is safe for hdf5
            collate_fn=train_collate)
        train_sampler = None
    return train_dataset, train_sampler, train_dataloader
def eval_one_mcnc_question(model, test_data, results, q, coref_ridx,
                           score_pooling):
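    """Score one multiple-choice narrative cloze (MCNC) question.

    The held-out target edge is removed from the input graph, then each
    candidate event is scored against the coreference chain: with
    score_pooling='last' only a single chain node serves as the head;
    with 'mean' or 'max' every chain node is used and the per-node scores
    are pooled.
    """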
    correct, choices, target_edge_idx, chain_info = q
    ng_edges, bert_inputs, target_idxs, nid2rows, coref_nids = \
        get_ng_inputs(test_data)

    # prepare input edges
    input_edges = torch.cat(
        (ng_edges[:, :target_edge_idx], ng_edges[:, target_edge_idx + 1:]),
        dim=1)
    n_nodes = nid2rows.shape[0]
    gs = get_dgl_graph_list([input_edges], n_nodes)

    # prepare target edges
    chain_len = coref_nids.shape[0]
    n_choices = len(choices)

    if score_pooling == 'last':
        d = n_choices
    else:
        d = (chain_len - 1) * n_choices

    target_edges = torch.ones((4, d), dtype=torch.int64)
    target_edges[1, :] = coref_ridx
    if score_pooling != 'last':
        for i in range(n_choices):
            ch = choices[i]
            for j in range(chain_len - 1):
                head = coref_nids[j]

                k = i * (chain_len - 1) + j
                target_edges[0, k] = head
                target_edges[2, k] = ch[2]
    else:
        head = coref_nids[-2]
        for i in range(n_choices):
            ch = choices[i]
            target_edges[0, i] = head
            target_edges[2, i] = ch[2]
    target_edges = target_edges.unsqueeze(0)

    batch = {
        'bg': [gs],
        'input_ids': bert_inputs[0],
        'input_masks': bert_inputs[1],
        'token_type_ids': bert_inputs[2],
        'target_idxs': target_idxs,
        'nid2rows': [nid2rows],
        'n_instances': [bert_inputs.shape[1]],
        'target_edges': [target_edges]
    }
    with torch.no_grad():
        if args.gpu_id != -1:
            batch = NarrativeGraphDataset.to_gpu(batch, args.gpu_id)
        pred_scores, y = model(mode='predict', **batch)

        if score_pooling == 'mean':
            cand_scores = pred_scores.view(-1, chain_len - 1).mean(1)
        elif score_pooling == 'max':
            cand_scores = pred_scores.view(-1, chain_len - 1).max(1)[0]
        else:  # last
            cand_scores = pred_scores
        pred = torch.argmax(cand_scores).cpu()

        if 'predict' not in results:
            results['predict'] = []
        if 'y' not in results:
            results['y'] = []

        logger.debug('scores={}, pred={}, y={}'.format(cand_scores, pred,
                                                       correct))
        results['predict'].append(pred.item())
        results['y'].append(correct)
def train(rank, args):
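    """Training entry point for one process (rank).

    Rank 0 (or the single-GPU process) owns the dev dataloader, the
    TensorBoard writer, periodic evaluation, checkpointing, and the final
    test run. All ranks seed identically so distributed sampling stays in
    sync.
    """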
    logger = utils.get_root_logger(args, log_fname='log_rank{}'.format(rank))
    if args.n_gpus > 1:
        local_rank = rank
        args.gpu_id = rank
    else:
        local_rank = -1

    if args.n_gpus > 1:
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.n_gpus,
                                rank=local_rank)

    set_seed(
        args.gpu_id, args.seed
    )  # in distributed training, this has to be same for all processes

    logger.info('local_rank = {}, n_gpus = {}'.format(local_rank, args.n_gpus))
    logger.info('n_epochs = {}'.format(args.n_epochs))

    if args.gpu_id != -1:
        torch.cuda.set_device(args.gpu_id)

    # initialize training essentials
    if local_rank in [-1, 0]:
        dev_dataset = get_concat_dataset(args.dev_dir, has_target_edges=True)
        dev_dataloader = DataLoader(dev_dataset,
                                    batch_size=args.eval_batch_size,
                                    shuffle=False,
                                    num_workers=1,
                                    collate_fn=test_collate,
                                    pin_memory=False)

        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))
        tb_writer = SummaryWriter(comment='_{}'.format(args.output_dir))

    train_files = [
        os.path.join(args.train_dir, f) for f in os.listdir(args.train_dir)
        if f.endswith('.h5')
    ]
    t1 = time.time()
    if args.n_train_instances != -1:  # save loading time
        n_instances = args.n_train_instances
    else:
        n_instances = get_n_instances(args.train_dir)
    logger.info('get_n_instances = {}: {}s'.format(n_instances,
                                                   time.time() - t1))

    # NG config
    with open(args.ng_config) as fp:
        ng_config = json.load(fp)
    assert ng_config["config_target"] == "narrative_graph"
    rtype2idx = ng_config['rtype2idx']
    if ng_config['no_entity']:
        ep_rtype_rev = {}
        ent_pred_ridxs = set()
    else:
        ep_rtype_rev = {
            rtype2idx[v]: rtype2idx[k]
            for k, v in ng_config['entity_predicate_rtypes'].items()
        }
        ent_pred_ridxs = set(ep_rtype_rev.keys())
    coref_ridx = rtype2idx['cnext']

    n_rtypes = len(rtype2idx)
    pred_pred_ridxs = set(range(n_rtypes)) - ent_pred_ridxs
    if args.sample_coref_only:
        pp_ridx2distr = get_pp_ridx2distr_coref(ng_config)
    else:
        pp_ridx2distr = get_pp_ridx2distr(ng_config)

    # model config
    with open(args.model_config) as fp:
        model_config = json.load(fp)

    model = get_init_model(args, logger)
    optimizer = get_optimizer(args, model, logger)
    scheduler = get_scheduler(args, n_instances, optimizer, logger)

    # training
    dev_ridx = coref_ridx if args.dev_coref else -1
    if local_rank in [-1, 0]:
        logger.info("***** Running training *****")
        logger.info("  Num Epochs = %d", args.n_epochs)
        logger.info("  Training batch size = %d", args.train_batch_size)
        logger.info("  Evaluation batch size = %d", args.eval_batch_size)
        logger.info("  Accu. train batch size = %d",
                    args.train_batch_size * args.gradient_accumulation_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Weight Decay = {}".format(args.weight_decay))
        logger.info("  Learning Rate = {}".format(args.lr))

        # first eval result
        if args.no_first_eval:
            best_metric = 0.0
        else:
            best_metric = evaluate(local_rank,
                                   model,
                                   dev_dataloader,
                                   args.gpu_id,
                                   model_name=args.model_name,
                                   logger=logger,
                                   dev_ridx=dev_ridx)
        logger.info('start dev_metric = {}'.format(best_metric))
    else:
        best_metric = 0.0

    step = 0
    n_batches = 0  # per-batch counter for gradient accumulation
    prev_acc_loss, acc_loss = 0.0, 0.0
    t1 = time.time()
    model.zero_grad()
    for i_epoch in range(args.n_epochs):
        logger.info('========== Epoch {} =========='.format(i_epoch))

        t2 = time.time()
        random.shuffle(train_files)  # shuffle files
        for i_file, f in enumerate(train_files):
            logger.debug('file = {}'.format(f))
            logger.info('{} / {} files completed'.format(
                i_file, len(train_files)))

            # load one dataset (file) in memory
            t3 = time.time()
            train_dataset, train_sampler, train_dataloader = \
                prepare_train_dataset(f, local_rank, args)

            # training on batches
            if args.n_gpus > 1:
                train_sampler.set_epoch(i_epoch)
            for train_batch in tqdm(train_dataloader,
                                    desc='training on one file',
                                    disable=True):
                model.train()
                # truncate graph
                train_batch = train_sample_truncated_ng(
                    **train_batch,
                    ent_pred_ridxs=ent_pred_ridxs,
                    pred_pred_ridxs=pred_pred_ridxs,
                    ep_rtype_rev=ep_rtype_rev,
                    n_truncated_ng=args.n_truncated_ng,
                    edge_sample_rate=args.edge_sample_rate,
                    n_neg_per_pos=args.n_neg_per_pos,
                    pp_ridx2distr=pp_ridx2distr,
                    coref_ridx=coref_ridx,
                    sample_entity_only=args.sample_entity_only)

                # to GPU
                if args.gpu_id != -1:
                    train_batch = NarrativeGraphDataset.to_gpu(
                        train_batch, args.gpu_id)

                # forward pass:
                loss = model(mode='loss', **train_batch)

                if args.n_gpus > 1:
                    # average over GPUs in multi-GPU parallel training
                    loss = loss.mean()
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                # backward pass
                loss.backward()

                acc_loss += loss.item()

                # gradient accumulation: update parameters every
                # gradient_accumulation_steps batches (a trailing partial
                # accumulation is dropped)
                n_batches += 1
                if n_batches % args.gradient_accumulation_steps == 0:
                    # update params
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    step += 1

                    # loss
                    if local_rank in [-1, 0]:
                        if args.logging_steps > 0 and step % args.logging_steps == 0:
                            cur_loss = (acc_loss -
                                        prev_acc_loss) / args.logging_steps
                            logger.info(
                                'train_loss={}, step={}, time={}s'.format(
                                    cur_loss, step,
                                    time.time() - t1))
                            tb_writer.add_scalar('train_loss', cur_loss, step)
                            tb_writer.add_scalar(
                                'lr', scheduler.get_last_lr()[0], step)

                            # evaluate
                            if not args.no_eval:
                                dev_metric = evaluate(
                                    local_rank,
                                    model,
                                    dev_dataloader,
                                    args.gpu_id,
                                    model_name=args.model_name,
                                    logger=logger,
                                    dev_ridx=dev_ridx)
                                logger.info('dev_metric={}'.format(dev_metric))
                                if best_metric < dev_metric:
                                    best_metric = dev_metric

                                    # save
                                    save_model(model, optimizer, scheduler,
                                               args.output_dir, step, logger)
                            else:
                                # simply save model
                                save_model(model, optimizer, scheduler,
                                           args.output_dir, step, logger)
                            prev_acc_loss = acc_loss

            logger.info('done file: {} s'.format(time.time() - t3))
        logger.info('done epoch: {} s'.format(time.time() - t2))

    logger.info('done training: {} s'.format(time.time() - t1))

    t1 = time.time()
    if local_rank in [-1, 0]:
        tb_writer.close()

        del model, dev_dataloader

        # test
        if args.test_dir is not None:
            test_metric = test(local_rank, args.output_dir, args.test_dir,
                               logger, args)
            logger.info('test_metric = {}'.format(test_metric))
    logger.info('done testing: {} s'.format(time.time() - t1))
def basic_evaluate(local_rank,
                   model,
                   dataloader,
                   gpu,
                   get_prec_recall_f1=False,
                   logger=None,
                   dev_ridx=-1):
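    """Per-relation binary evaluation at a fixed 0.5 threshold.

    Keeps one 2x2 confusion matrix per relation type. Returns the
    macro-averaged metrics over all relation types, or the metrics of the
    single relation dev_ridx when one is given.
    """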
    model.eval()
    n_output_rels = model.module.num_output_rels if \
        isinstance(model, DistributedDataParallel) else model.num_output_rels
    class_cm = np.zeros((n_output_rels, 2, 2), dtype=np.int64)
    n_examples = np.zeros((n_output_rels, 2), dtype=np.int64)
    with torch.no_grad():
        for batch in tqdm(dataloader,
                          desc='evaluating',
                          disable=(local_rank not in [-1, 0])):
            rels = get_batch_relations(batch['target_edges'])

            # to GPU
            if gpu != -1:
                batch = NarrativeGraphDataset.to_gpu(batch, gpu)

            # forward pass
            pred_scores, y = model(mode='predict', **batch)
            y_pred = (pred_scores.cpu() >= 0.5).long()
            y = y.cpu()
            class_measure(class_cm, n_examples, rels, y_pred, y, n_output_rels)

    # macro-averaged
    precisions, recalls = [], []
    for i in range(n_output_rels):
        tn, fp, fn, tp = class_cm[i].ravel()
        c_prec = tp / (tp + fp) if tp + fp != 0 else 0.0
        c_recall = tp / (tp + fn) if tp + fn != 0 else 0.0
        c_f1 = (2.0 * c_prec * c_recall) / (
            c_prec + c_recall) if c_prec + c_recall != 0 else 0.0

        precisions.append(c_prec)
        recalls.append(c_recall)
        if logger:
            logger.info(
                'class={}, #pos={}, #neg={}, prec={}, recall={}, f1={}'.format(
                    i, n_examples[i][1], n_examples[i][0], c_prec, c_recall,
                    c_f1))

    if dev_ridx == -1:
        prec_macro = sum(precisions) / len(precisions)
        recall_macro = sum(recalls) / len(recalls)
        f1_macro = (2 * prec_macro * recall_macro) / (prec_macro + recall_macro) if \
            (prec_macro + recall_macro) != 0 else 0.0
        if get_prec_recall_f1:
            return prec_macro, recall_macro, f1_macro
        return f1_macro
    else:
        prec = precisions[dev_ridx]
        recall = recalls[dev_ridx]
        f1 = (2 * prec * recall) / (prec + recall) if (prec +
                                                       recall) != 0 else 0.0
        if get_prec_recall_f1:
            return prec, recall, f1
        return f1
def transe_evaluate(local_rank,
                    model,
                    dataloader,
                    gpu,
                    get_prec_recall_f1=False,
                    logger=None,
                    dev_ridx=-1):
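    """Evaluation for score-based (TransE-style) models that lack a natural
    probability threshold: collect all prediction scores, search candidate
    thresholds within +/-20% of the mean score, and report metrics at the
    best threshold found.
    """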
    model.eval()
    n_output_rels = model.module.num_output_rels if \
        isinstance(model, DistributedDataParallel) else model.num_output_rels
    all_rels, all_pred_scores, all_ys = [], [], []
    with torch.no_grad():
        for batch in tqdm(dataloader,
                          desc='evaluating',
                          disable=(local_rank not in [-1, 0])):
            rels = get_batch_relations(batch['target_edges'])

            # to GPU
            if gpu != -1:
                batch = NarrativeGraphDataset.to_gpu(batch, gpu)

            # forward pass
            pred_scores, y = model(mode='predict', **batch)

            all_pred_scores.append(pred_scores.cpu())
            all_ys.append(y.cpu())
            all_rels.append(rels)

    all_pred_scores = torch.cat(all_pred_scores, dim=0)
    all_ys = torch.cat(all_ys, dim=0)
    all_rels = torch.cat(all_rels, dim=0)

    # search for a decision threshold around the mean prediction score;
    # the mean is a fast, reasonable starting point
    avg_thr = all_pred_scores.mean()
    # candidate thresholds span +/-20% of the mean in 1% increments
    thr_candidates = [
        avg_thr + i * 0.01 * avg_thr for i in range(-20, 21)
    ]

    best_thr = None
    best_metric = float('-inf')
    for thr in thr_candidates:
        prec_macro, recall_macro, f1_macro = _eval_thr(thr,
                                                       all_pred_scores,
                                                       all_rels,
                                                       all_ys,
                                                       n_output_rels,
                                                       dev_ridx=dev_ridx,
                                                       logger=None)
        if f1_macro > best_metric:
            best_metric = f1_macro
            best_thr = thr

    if logger:
        logger.info('BEST_THRESHOLD = {}'.format(best_thr))
    prec_macro, recall_macro, f1_macro = _eval_thr(best_thr,
                                                   all_pred_scores,
                                                   all_rels,
                                                   all_ys,
                                                   n_output_rels,
                                                   dev_ridx=dev_ridx,
                                                   logger=logger)
    if get_prec_recall_f1:
        return prec_macro, recall_macro, f1_macro
    return f1_macro