예제 #1
0
def main(args):
    """Evaluate an ensemble of checkpoints on one split and write a submission CSV.

    Every ``args.load_path_*`` option that is set adds one model to the
    ensemble. For each batch the per-model start/end probabilities are
    averaged and passed to ``util.discretize`` to pick the answer span.

    Args:
        args: Namespace-like object. Attributes read here include ``save_dir``,
            ``name``, ``batch_size``, ``word_emb_file``, ``char_emb_file``,
            the seven ``load_path_*`` options, ``split``, ``use_squad_v2``,
            ``num_workers``, ``max_ans_len``, ``save_probabilities``,
            ``num_visuals`` and ``sub_file``. ``save_dir`` and ``batch_size``
            are overwritten in place.

    Raises:
        ValueError: If none of the ``args.load_path_*`` options is set.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info(f'Args: {dumps(vars(args), indent=4, sort_keys=True)}')
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # The keyword sets are built lazily so that options which only matter for
    # one model family are read only when such a model is actually requested.
    def char_model_kwargs():
        return dict(word_vectors=word_vectors,
                    char_vectors=char_vectors,
                    char_emb_dim=args.char_emb_dim,
                    hidden_size=args.hidden_size)

    def qanet_kwargs(with_divisor=True):
        kwargs = char_model_kwargs()
        kwargs.update(n_heads=args.n_heads,
                      n_conv_emb_enc=args.n_conv_emb,
                      n_conv_mod_enc=args.n_conv_mod,
                      n_emb_enc_blocks=args.n_emb_blocks,
                      n_mod_enc_blocks=args.n_mod_blocks)
        if with_divisor:
            kwargs['divisor_dim_kqv'] = args.divisor_dim_kqv
        return kwargs

    # (tag, checkpoint path, factory), listed in the order the models are loaded.
    candidates = (
        ('baseline', args.load_path_baseline,
         lambda: Baseline(word_vectors=word_vectors, hidden_size=100)),
        ('bidaf', args.load_path_bidaf,
         lambda: BiDAF(**char_model_kwargs())),
        ('bidaf_fu', args.load_path_bidaf_fusion,
         lambda: BiDAF_fus(**char_model_kwargs())),
        ('qanet', args.load_path_qanet,
         lambda: QANet(**qanet_kwargs())),
        ('qanet_old', args.load_path_qanet_old,
         lambda: QANet_old(device=device,
                           **qanet_kwargs(with_divisor=False))),
        ('qanet_inde', args.load_path_qanet_inde,
         lambda: QANet_independant_encoder(**qanet_kwargs())),
        ('qanet_s_e', args.load_path_qanet_s_e,
         lambda: QANet_S_E(**qanet_kwargs())),
    )

    # Get model
    log.info('Building model...')
    models = {}
    nll_meters = {}
    saved_probs = {}  # tag -> (list of start arrays, list of end arrays)
    for tag, checkpoint, build in candidates:
        if not checkpoint:
            continue
        model = nn.DataParallel(build(), gpu_ids)
        log.info(f'Loading checkpoint from {checkpoint}...')
        model = util.load_model(model, checkpoint, gpu_ids, return_step=False)
        model = model.to(device)
        model.eval()
        models[tag] = model
        nll_meters[tag] = util.AverageMeter()
        saved_probs[tag] = ([], [])

    if not models:
        # Fail early with a clear message instead of an IndexError in the loop.
        raise ValueError('No model checkpoint was provided; set at least one '
                         'args.load_path_* option.')

    # Order in which the models are run (and their probabilities summed).
    forward_order = [
        tag for tag in ('baseline', 'qanet', 'qanet_old', 'qanet_inde',
                        'qanet_s_e', 'bidaf', 'bidaf_fu') if tag in models
    ]
    # Single model whose NLL is reported, both on the progress bar and in the
    # final results, so the two always agree.
    report_tag = next(
        tag for tag in ('qanet', 'bidaf', 'bidaf_fu', 'qanet_inde',
                        'qanet_s_e', 'qanet_old', 'baseline') if tag in models)

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)[f'{args.split}_record_file']
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info(f'Evaluating on {args.split} split...')
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}  # Predictions for submission
    eval_file = vars(args)[f'{args.split}_eval_file']
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            cc_idxs = cc_idxs.to(device)
            qc_idxs = qc_idxs.to(device)
            batch_size = cw_idxs.size(0)

            y1, y2 = y1.to(device), y2.to(device)
            l_p1, l_p2 = [], []

            # Forward
            for tag in forward_order:
                model = models[tag]
                if tag == 'baseline':
                    # NOTE(review): the baseline is called with
                    # (cw_idxs, cc_idxs) while every other model also receives
                    # the question tensors; kept as-is, confirm against
                    # Baseline.forward.
                    log_p1, log_p2 = model(cw_idxs, cc_idxs)
                else:
                    log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                nll_meters[tag].update(loss.item(), batch_size)

                prob1, prob2 = log_p1.exp(), log_p2.exp()
                l_p1.append(prob1)
                l_p2.append(prob2)
                if args.save_probabilities:
                    start_store, end_store = saved_probs[tag]
                    start_store.append(prob1.detach().cpu().numpy())
                    end_store.append(prob2.detach().cpu().numpy())

            # Average out of place: on CPU the saved numpy arrays share memory
            # with prob1/prob2, so those tensors must not be modified.
            p1 = sum(l_p1[1:], l_p1[0]) / len(l_p1)
            p2 = sum(l_p2[1:], l_p2[0]) / len(l_p2)

            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meters[report_tag].avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict, ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    if args.save_probabilities:
        # A single model keeps the historical file names. With several models
        # each gets its own suffix so the dumps no longer overwrite each other.
        single_model = len(models) == 1
        for tag, (start_store, end_store) in saved_probs.items():
            suffix = '' if single_model else f'_{tag}'
            with open(join(args.save_dir, 'probs_start' + suffix),
                      'wb') as fp:  # Pickling
                pickle.dump(start_store, fp)
            with open(join(args.save_dir, 'probs_end' + suffix),
                      'wb') as fp:  # Pickling
                pickle.dump(end_store, fp)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meters[report_tag].avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join(f'{k}: {v:05.2f}' for k, v in results.items())
        log.info(f'{args.split.title()} {results_str}')

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        util.visualize(tbx,
                       pred_dict=pred_dict,
                       eval_path=eval_file,
                       step=0,
                       split=args.split,
                       num_visuals=args.num_visuals)
        # Flush pending events and release the writer.
        tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info(f'Writing submission file to {sub_path}...')
    with open(sub_path, 'w', newline='', encoding='utf-8') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
예제 #2
0
파일: test.py 프로젝트: wcAlex/QA-XL
def main(args):
    """Evaluate a single QANet checkpoint on one split and write a submission CSV.

    Loads the checkpoint at ``args.load_path``, runs it over the split named
    by ``args.split`` and accumulates the per-batch loss in an
    ``util.AverageMeter``. Unless the split is ``'test'``, the NLL/F1/EM
    results (plus AvNA when ``args.use_squad_v2`` is set) are logged and the
    predictions are sent to TensorBoard. The CSV is written to
    ``<save_dir>/<split>_<sub_file>``.

    Args:
        args: Namespace-like object. ``save_dir`` and ``batch_size`` are
            overwritten in place.
    """
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Get model
    log.info('Building model...')
    model = QANet(word_vectors,
                  char_vectors,
                  args.test_para_limit,
                  args.test_ques_limit,
                  args.f_model,
                  num_head=args.num_head,
                  train_cemb=(not args.pretrained_char))
    model = nn.DataParallel(model, gpu_ids)
    log.info('Loading checkpoint from {}...'.format(args.load_path))
    model = util.load_model(model, args.load_path, gpu_ids, return_step=False)
    model = model.to(device)
    model.eval()

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)['{}_record_file'.format(args.split)]
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Evaluate
    log.info('Evaluating on {} split...'.format(args.split))
    loss_f = torch.nn.CrossEntropyLoss()
    nll_meter = util.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}  # Predictions for submission
    eval_file = vars(args)['{}_eval_file'.format(args.split)]
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)

            cc_idxs = cc_idxs.to(device)
            qc_idxs = qc_idxs.to(device)

            batch_size = cw_idxs.size(0)

            # Forward
            log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
            y1, y2 = y1.to(device), y2.to(device)
            loss = torch.mean(loss_f(log_p1, y1) + loss_f(log_p2, y2))
            loss_val = loss.item()
            # Record the batch loss; without this the meter is never fed and
            # the NLL shown below stays at the meter's initial value.
            nll_meter.update(loss_val, batch_size)

            # Get F1 and EM scores
            p1, p2 = log_p1.exp(), log_p2.exp()
            starts, ends = util.discretize(p1, p2, args.max_ans_len,
                                           args.use_squad_v2)

            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict, ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)

    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meter.avg), ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                for k, v in results.items())
        log.info('{} {}'.format(args.split.title(), results_str))

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        util.visualize(tbx,
                       pred_dict=pred_dict,
                       eval_path=eval_file,
                       step=0,
                       split=args.split,
                       num_visuals=args.num_visuals)
        # Flush pending events and release the writer.
        tbx.close()

    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info('Writing submission file to {}...'.format(sub_path))
    # newline='' lets the csv module control line endings itself, which
    # avoids blank rows on platforms that translate '\n'.
    with open(sub_path, 'w', newline='') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])
예제 #3
0
def main(args):
    """Train a QANet model, periodically evaluating on dev and checkpointing.

    Builds the model (optionally resuming from ``args.load_path``), an Adam
    optimizer and a ``LambdaLR`` schedule that scales the learning rate
    logarithmically for the first ``args.lr_warm_up_num`` scheduler steps.
    Each time ``args.eval_steps`` examples have been processed, the EMA
    weights are swapped in, the dev set is evaluated, a checkpoint is saved
    and the results are logged to the console and TensorBoard.

    Args:
        args: Namespace-like object. ``save_dir``, ``gpu_ids`` and
            ``batch_size`` are overwritten in place.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info('Using random seed {}...'.format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Get Model
    log.info('Building Model...')
    model = QANet(word_vectors,
                  char_vectors,
                  args.para_limit,
                  args.ques_limit,
                  args.f_model,
                  num_head=args.num_head,
                  train_cemb=(not args.pretrained_char))
    model = nn.DataParallel(model, args.gpu_ids)

    if args.load_path:
        log.info('Loading checkpoint from {}...'.format(args.load_path))
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(
        params=parameters,
        lr=args.lr,
        betas=(args.beta1, args.beta2),
        eps=1e-8,
        weight_decay=3e-7)
    warm_up_num = args.lr_warm_up_num
    if warm_up_num > 1:
        # Factor grows as log(ee + 1) / log(warm_up_num) until it reaches 1.
        cr = 1.0 / math.log(warm_up_num)

        def lr_lambda(ee):
            return cr * math.log(ee + 1) if ee < warm_up_num else 1
    else:
        # log(1) == 0 and log(<=0) is undefined, so a warm-up of one step or
        # less is treated as "no warm-up" instead of raising.
        def lr_lambda(ee):
            return 1
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    loss_f = torch.nn.CrossEntropyLoss()

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    log.info('Training...')
    steps_till_eval = args.eval_steps
    # `step` counts examples, so this is the number of completed epochs.
    epoch = step // len(train_dataset)
    # A negative num_epochs keeps the original run-until-interrupted
    # behaviour. Otherwise stop once the target is reached, even when a
    # resumed checkpoint is already past it (the former `!=` test would
    # never terminate in that case).
    while args.num_epochs < 0 or epoch < args.num_epochs:
        epoch += 1
        log.info('Starting epoch {}...'.format(epoch))
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)

                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)

                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = torch.mean(loss_f(log_p1, y1) + loss_f(log_p2, y2))
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                # step // batch_size approximates the number of batches seen
                # so far; it drives both the LR schedule and the EMA update.
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch,
                                         NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR',
                               optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint
                    log.info('Evaluating at step {}...'.format(step))
                    # Swap in the EMA weights for evaluation/saving, then
                    # restore the raw training weights afterwards.
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                            for k, v in results.items())
                    log.info('Dev {}'.format(results_str))

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar('dev/{}'.format(k), v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)

    # Flush pending events and release the writer once training is done.
    tbx.close()
예제 #4
0
파일: train.py 프로젝트: devansh287/squad
def main(args):
    """Train QANet on SQuAD with periodic dev-set evaluation and checkpoints.

    Args:
        args: Parsed command-line options. ``args.save_dir``,
            ``args.gpu_ids`` and ``args.batch_size`` are overwritten in
            place by this function; the remaining attributes are only read.
    """
    # Resolve the run directory, then create the logger and TensorBoard writer.
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    args_json = dumps(vars(args), indent=4, sort_keys=True)
    log.info(f'Args: {args_json}')
    # The configured batch size is multiplied by the GPU count (minimum 1).
    args.batch_size *= max(1, len(args.gpu_ids))

    # Seed every random number generator used below with the same value.
    log.info(f'Using random seed {args.seed}...')
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(args.seed)

    # Word and character embeddings.
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)
    char_vectors = util.torch_from_json(args.char_emb_file)

    # Build the network and wrap it for (multi-)GPU execution.
    log.info('Building model...')
    qanet = QANet(word_vectors=word_vectors,
                  char_vectors=char_vectors,
                  emb_size=char_vectors.size(1),
                  hidden_size=args.hidden_size,
                  drop_prob=args.drop_prob)
    print(
        "YOU ARE ENTERING THE QA NET MODEL TRAINING_____________________________________________________________________"
    )
    model = nn.DataParallel(qanet, args.gpu_ids)

    # Start from step 0 unless a checkpoint supplies the step to resume from.
    step = 0
    if args.load_path:
        log.info(f'Loading checkpoint from {args.load_path}...')
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Checkpoint saver, ranked by the configured dev metric.
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Adadelta with a scheduler whose multiplier is always 1 (constant LR).
    optimizer = optim.Adadelta(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda _: 1.)

    # Data loaders: shuffled training data, ordered dev data.
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=0,
                                 collate_fn=collate_fn)

    # Training loop. `step` counts examples seen (it grows by the batch
    # size), so the starting epoch is derived from the dataset length.
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info(f'Starting epoch {epoch}...')
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for batch in train_loader:
                cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids = batch

                # Move the inputs to the training device.
                cw_idxs, cc_idxs, qw_idxs, qc_idxs = (
                    tensor.to(device)
                    for tensor in (cw_idxs, cc_idxs, qw_idxs, qc_idxs))
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward pass: sum of the start and end NLL losses.
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1 = y1.to(device)
                y2 = y2.to(device)
                start_loss = F.nll_loss(log_p1, y1)
                end_loss = F.nll_loss(log_p2, y2)
                loss = start_loss + end_loss
                loss_val = loss.item()

                # Backward pass with gradient clipping.
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()

                # Scheduler and EMA are both driven by the update count.
                num_updates = step // batch_size
                scheduler.step(num_updates)
                ema(model, num_updates)

                # Progress bar and TensorBoard training curves.
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                current_lr = optimizer.param_groups[0]['lr']
                tbx.add_scalar('train/LR', current_lr, step)

                # Only evaluate once enough examples have been processed.
                steps_till_eval -= batch_size
                if steps_till_eval > 0:
                    continue
                steps_till_eval = args.eval_steps

                # Evaluate and checkpoint with the EMA weights swapped in,
                # then restore the raw training weights.
                log.info(f'Evaluating at step {step}...')
                ema.assign(model)
                results, pred_dict = evaluate(model, dev_loader, device,
                                              args.dev_eval_file,
                                              args.max_ans_len,
                                              args.use_squad_v2)
                saver.save(step, model, results[args.metric_name], device)
                ema.resume(model)

                # Console summary of the dev metrics.
                results_str = ', '.join(
                    f'{metric}: {value:05.2f}'
                    for metric, value in results.items())
                log.info(f'Dev {results_str}')

                # TensorBoard: dev metrics plus example predictions.
                log.info('Visualizing in TensorBoard...')
                for metric, value in results.items():
                    tbx.add_scalar(f'dev/{metric}', value, step)
                util.visualize(tbx,
                               pred_dict=pred_dict,
                               eval_path=args.dev_eval_file,
                               step=step,
                               split='dev',
                               num_visuals=args.num_visuals)
예제 #5
0
def main(args):
    """Train QANet on a truncated copy of the SQuAD training records.

    The first ``n_row`` (75000) rows of ``args.train_record_file`` are written
    to a fixed ``.npz`` path and training runs on that smaller file. The dev
    set is evaluated every ``args.eval_steps`` examples, and a checkpoint is
    saved after each evaluation.

    Args:
        args: Parsed command-line options. ``args.save_dir``,
            ``args.gpu_ids`` and ``args.batch_size`` are overwritten in
            place by this function; the remaining attributes are only read.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    # The configured batch size is multiplied by the GPU count (minimum 1).
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info('Using random seed {}...'.format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)

    # This QANet variant receives the character-to-index mapping itself.
    with open(args.char2idx_file, "r") as f:
        char2idx = json_load(f)

    # Get model
    log.info('Building model...')
    model = QANet(word_vectors=word_vectors, char2idx=char2idx)
    model = nn.DataParallel(model, args.gpu_ids)
    if args.load_path:
        log.info('Loading checkpoint from {}...'.format(args.load_path))
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    optimizer = optim.Adadelta(model.parameters(),
                               args.lr,
                               weight_decay=args.l2_wd)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Get data loader
    log.info('Building dataset...')

    # Write a lighter training file holding only the first `n_row` rows.
    outfile = '/content/dataset/train_light.npz'
    n_row = 75000
    record_keys = ('context_idxs', 'context_char_idxs', 'ques_idxs',
                   'ques_char_idxs', 'y1s', 'y2s', 'ids')
    # np.load on an .npz archive keeps the file open until it is closed, so
    # use it as a context manager instead of leaking the handle.
    with np.load(args.train_record_file) as dataset1:
        subset = {key: dataset1[key][:n_row] for key in record_keys}
        np.savez(outfile, **subset)

    # Train on the truncated file rather than args.train_record_file.
    train_dataset = SQuAD(outfile, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train. `step` counts examples seen (it grows by the batch size), so the
    # starting epoch is derived from the dataset length.
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info('Starting epoch {}...'.format(epoch))
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:
                # Setup for forward
                # NOTE(review): only the word-index tensors are moved to
                # `device`; cc_idxs / qc_idxs are passed to the model as
                # produced by the loader - confirm that is intended.
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)
                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                log_p1, log_p2 = model(cw_idxs, cc_idxs, qw_idxs, qc_idxs)
                y1, y2 = y1.to(device), y2.to(device)
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()
                # NOTE(review): `step` advances by batch_size, so this prints
                # only when it lands exactly on a multiple of 10000.
                if step % 10000 == 0:
                    print('loss val: {}'.format(loss_val))

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)
                optimizer.step()
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR', optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint with the EMA weights
                    # swapped in, then restore the raw training weights.
                    log.info('Evaluating at step {}...'.format(step))
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                            for k, v in results.items())
                    log.info('Dev {}'.format(results_str))

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar('dev/{}'.format(k), v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)
예제 #6
0
def main(args):
    """Train QANet on SQuAD with Adam, evaluating on the dev set periodically.

    The dev set is evaluated every ``args.eval_steps`` examples, and a
    checkpoint is saved after each evaluation.

    Args:
        args: Parsed command-line options. ``args.save_dir``,
            ``args.gpu_ids`` and ``args.batch_size`` are overwritten in
            place by this function; the remaining attributes are only read.
    """
    # Set up logging and devices
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=True)
    log = util.get_logger(args.save_dir, args.name)
    tbx = SummaryWriter(args.save_dir)
    device, args.gpu_ids = util.get_available_devices()
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    # The configured batch size is multiplied by the GPU count (minimum 1).
    args.batch_size *= max(1, len(args.gpu_ids))

    # Set random seed
    log.info('Using random seed {}...'.format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Get word embeddings
    log.info('Loading word embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)

    ### start our code:
    # Get char-embeddings
    log.info('Loading char-embeddings...')
    char_vectors = util.torch_from_json(args.char_emb_file)
    ### end our code

    # Get model
    log.info('Building model...')
    # model = BiDAF(word_vectors=word_vectors,
    #               hidden_size=args.hidden_size,
    #               char_vectors = char_vectors,
    #               drop_prob=args.drop_prob)

    ### start our code:   (QANet)
    # kernel_size and filters are fixed here rather than read from args.
    model = QANet(word_vectors=word_vectors,
                  char_vectors=char_vectors,
                  hidden_size=args.hidden_size,
                  kernel_size=7,
                  filters=128,
                  drop_prob=args.drop_prob)
    ### end our code

    model = nn.DataParallel(model, args.gpu_ids)
    # Resume the example counter from the checkpoint when one is given.
    if args.load_path:
        log.info('Loading checkpoint from {}...'.format(args.load_path))
        model, step = util.load_model(model, args.load_path, args.gpu_ids)
    else:
        step = 0
    model = model.to(device)
    model.train()
    ema = util.EMA(model, args.ema_decay)

    # Get saver
    saver = util.CheckpointSaver(args.save_dir,
                                 max_checkpoints=args.max_checkpoints,
                                 metric_name=args.metric_name,
                                 maximize_metric=args.maximize_metric,
                                 log=log)

    # Get optimizer and scheduler
    # https://pytorch.org/docs/stable/optim.html
    # Original:
    # optimizer = optim.Adadelta(model.parameters(), args.lr,
    #                           weight_decay=args.l2_wd)
    #                   # default: lr=0.5, rho=0.9, eps=1e-06, weight_decay=0
    # scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR
    # A function which computes a multiplicative factor given an integer parameter epoch,
    # or a list of such functions, one for each group in optimizer.param_groups

    ### start our code:
    # Adam hyper-parameters are hard-coded; args.lr and args.l2_wd are not
    # used by this function.
    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.8, 0.999),
                           eps=1e-7,
                           weight_decay=3e-7,
                           amsgrad=False)
    scheduler = sched.LambdaLR(optimizer, lambda s: 1.)  # Constant LR

    # Adabound
    # optimizer = adabound.AdaBound(model.parameters(), lr=1e-3, final_lr=0.1)
    # scheduler = sched.LambdaLR(optimizer, lambda s: 1.)
    ### end our code

    # Get data loader
    log.info('Building dataset...')
    train_dataset = SQuAD(args.train_record_file, args.use_squad_v2)
    train_loader = data.DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=True,
                                   num_workers=args.num_workers,
                                   collate_fn=collate_fn)
    dev_dataset = SQuAD(args.dev_record_file, args.use_squad_v2)
    dev_loader = data.DataLoader(dev_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 num_workers=args.num_workers,
                                 collate_fn=collate_fn)

    # Train
    # `step` counts examples seen (it grows by the batch size), so the
    # starting epoch is derived from the dataset length.
    log.info('Training...')
    steps_till_eval = args.eval_steps
    epoch = step // len(train_dataset)
    while epoch != args.num_epochs:
        epoch += 1
        log.info('Starting epoch {}...'.format(epoch))
        with torch.enable_grad(), \
                tqdm(total=len(train_loader.dataset)) as progress_bar:
            for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in train_loader:

                # print("cc_idxs: ", cc_idxs)
                # print("qc_idxs: ", qc_idxs)

                # Setup for forward
                cw_idxs = cw_idxs.to(device)
                qw_idxs = qw_idxs.to(device)

                ### start our code:
                cc_idxs = cc_idxs.to(device)
                qc_idxs = qc_idxs.to(device)

                ### end our code

                batch_size = cw_idxs.size(0)
                optimizer.zero_grad()

                # Forward
                # log_p1, log_p2 = model(cw_idxs, qw_idxs)   # original

                ### start our code:
                # Argument order is word indices first, then char indices.
                log_p1, log_p2 = model(cw_idxs, qw_idxs, cc_idxs, qc_idxs)

                ### end our code

                y1, y2 = y1.to(device), y2.to(device)
                # Total loss is the sum of the start and end NLL losses.
                loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
                loss_val = loss.item()

                # Backward
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(),
                                         args.max_grad_norm)

                ### start our code:
                # optimizer = LR_decay(optimizer, epoch, lr = 0.8)   # initial learning rate 0.8
                # `step` counts examples, so this branch covers the first
                # 1000 examples, not the first 1000 batches.
                # NOTE(review): LR_warmup is defined outside this block;
                # presumably it returns the optimizer with an adjusted
                # learning rate - confirm, including the meaning of lr=0.
                if step < 1000:
                    optimizer = LR_warmup(optimizer, step, lr=0.)
                ### end our code

                optimizer.step()
                # Scheduler and EMA both receive the update count.
                scheduler.step(step // batch_size)
                ema(model, step // batch_size)

                # Log info
                step += batch_size
                progress_bar.update(batch_size)
                progress_bar.set_postfix(epoch=epoch, NLL=loss_val)
                tbx.add_scalar('train/NLL', loss_val, step)
                tbx.add_scalar('train/LR', optimizer.param_groups[0]['lr'],
                               step)

                steps_till_eval -= batch_size
                if steps_till_eval <= 0:
                    steps_till_eval = args.eval_steps

                    # Evaluate and save checkpoint. ema.assign is called
                    # before evaluation and ema.resume after saving.
                    log.info('Evaluating at step {}...'.format(step))
                    ema.assign(model)
                    results, pred_dict = evaluate(model, dev_loader, device,
                                                  args.dev_eval_file,
                                                  args.max_ans_len,
                                                  args.use_squad_v2)
                    saver.save(step, model, results[args.metric_name], device)
                    ema.resume(model)

                    # Log to console
                    results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                            for k, v in results.items())
                    log.info('Dev {}'.format(results_str))

                    # Log to TensorBoard
                    log.info('Visualizing in TensorBoard...')
                    for k, v in results.items():
                        tbx.add_scalar('dev/{}'.format(k), v, step)
                    util.visualize(tbx,
                                   pred_dict=pred_dict,
                                   eval_path=args.dev_eval_file,
                                   step=step,
                                   split='dev',
                                   num_visuals=args.num_visuals)