def complete_default_test_parser(args):
    """Fill in derived fields on the parsed CLI namespace and return it.

    Derives the torch device, GNN layer/head counts, the edge-type mask
    list, the encoder hidden size, the experiment output directory and
    the encoder/model checkpoint paths.
    """
    # Pick the first free GPU when CUDA is present, otherwise fall back to CPU.
    if torch.cuda.is_available():
        gpu_ids, _ = single_free_cuda()
        args.device = torch.device('cuda:{}'.format(gpu_ids[0]))
    else:
        args.device = torch.device('cpu')
    # args.gnn is formatted like '<name>:<num_layers>,<num_heads>[,...]'.
    gnn_config = args.gnn.split(':')[1].split(',')
    args.num_gnn_layers = int(gnn_config[0])
    args.num_gnn_heads = int(gnn_config[1])
    # Comma-separated string -> list of ints; an empty string is left alone.
    if args.mask_edge_types:
        args.mask_edge_types = [int(t) for t in args.mask_edge_types.split(',')]
    # TODO: only support albert-xxlarge-v2 now
    if 'base' in args.encoder_name_or_path:
        args.input_dim = 768
    elif 'albert' in args.encoder_name_or_path:
        args.input_dim = 4096
    else:
        args.input_dim = 1024
    # output dir name
    args.exp_name = os.path.join(args.output_dir, args.exp_name)
    os.makedirs(args.exp_name, exist_ok=True)

    ## replace encoder.pkl as encoder
    args.encoder_path = join(args.input_model_path, args.encoder_ckpt)
    args.model_path = join(args.input_model_path, args.model_ckpt)
    return args
# Example no. 2
# 0
def jd_adaptive_threshold_prediction(args, model, feat_dict_file_name):
    """Run `model` over every example in `feat_dict_file_name`.

    Returns a dict mapping example id -> sigmoid score (Python float).
    """
    # Resolve the execution device: first free GPU if any, else CPU.
    if torch.cuda.is_available():
        gpu_ids, _ = single_free_cuda()
        device = torch.device('cuda:{}'.format(gpu_ids[0]))
    else:
        device = torch.device('cpu')
    dataset = RangeDataset(json_file_name=feat_dict_file_name)
    loader = DataLoader(dataset=dataset,
                        shuffle=False,
                        collate_fn=RangeDataset.collate_fn,
                        batch_size=args.test_batch_size)
    model.to(device)
    model.eval()
    pred_score_dict = {}
    total_count = 0
    for batch in loader:
        # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
        # Move every tensor to the device; 'id' is a plain list of example keys.
        for key, value in batch.items():
            if key != 'id':
                batch[key] = value.to(device)
        with torch.no_grad():
            logits = model(batch['x_feat']).squeeze(-1)
            probs = torch.sigmoid(logits)
            score_np = probs.data.cpu().numpy()
            for i in range(score_np.shape[0]):
                total_count += 1
                pred_score_dict[batch['id'][i]] = float(score_np[i])
    return pred_score_dict
# Example no. 3
# 0
# Training data prepared by a helper object defined elsewhere in this script.
train_feature_dict = helper.train_feature_dict
train_dataloader = helper.hotpot_train_dataloader

# #########################################################################
# # Initialize Model
# ##########################################################################
# Look up the (config class, encoder class, tokenizer class) triple for the
# selected backbone and load its pretrained configuration.
config_class, model_encoder, tokenizer_class = MODEL_CLASSES[args.model_type]
config = config_class.from_pretrained(args.encoder_name_or_path)

# Checkpoint locations inside the experiment directory.
encoder_path = join(args.exp_name, args.encoder_name) ## replace encoder.pkl as encoder
model_path = join(args.exp_name, args.model_name) ## replace encoder.pkl as encoder
logger.info("Loading encoder from: {}".format(encoder_path))
logger.info("Loading model from: {}".format(model_path))

# Select the first free GPU when CUDA is available, otherwise run on CPU.
if torch.cuda.is_available():
    device_ids, _ = single_free_cuda()
    device = torch.device('cuda:{}'.format(device_ids[0]))
else:
    device = torch.device('cpu')

encoder, _ = load_encoder_model(args.encoder_name_or_path, args.model_type)
model = HierarchicalGraphNetwork(config=args)

# Restore encoder weights. Keys carrying a 'module.' prefix (presumably
# saved from a DataParallel-wrapped model — confirm against the saving
# code) are renamed before loading. NOTE(review): encoder_path is always
# a non-None join() result, so this guard never skips.
if encoder_path is not None:
    state_dict = torch.load(encoder_path)
    print('loading parameter from {}'.format(encoder_path))
    for key in list(state_dict.keys()):
        if 'module.' in key:
            state_dict[key.replace('module.', '')] = state_dict[key]
            del state_dict[key]
    encoder.load_state_dict(state_dict)
# Example no. 4
# 0
def train(args):
    """Train a RangeSeqModel for span-threshold prediction.

    Loads precomputed train/dev features and dev score/raw data, trains
    for ``args.num_train_epochs`` epochs with periodic in-epoch
    evaluation, and checkpoints the model every time dev F1 improves.

    Returns:
        (best_em_ratio, best_f1, dev_prediction_dict) — the best dev EM
        ratio and F1 seen, and the prediction dict from the best
        checkpoint (None if F1 never improved over 0.0).
    """
    # Resolve feature / score file locations under the experiment directory.
    train_feat_file_name = join(args.output_dir, args.exp_name,
                                args.train_feat_json_name)
    dev_feat_file_name = join(args.output_dir, args.exp_name,
                              args.dev_feat_json_name)
    dev_score_file_name = join(args.output_dir, args.exp_name,
                               args.dev_score_name)
    threshold_category = get_threshold_category(
        interval_num=args.interval_number)
    raw_dev_data_file_name = join(args.input_dir, args.raw_dev_data)

    # Raw dev data, indexed by example '_id' for evaluation lookups.
    raw_data = load_json_score_data(
        json_score_file_name=raw_dev_data_file_name)
    raw_data_dict = dict([(row['_id'], row) for row in raw_data])

    # First free GPU when available, otherwise CPU.
    if torch.cuda.is_available():
        device_ids, _ = single_free_cuda()
        device = torch.device('cuda:{}'.format(device_ids[0]))
    else:
        device = torch.device('cpu')

    # Echo the full configuration for the run log.
    for key, value in vars(args).items():
        print('{}: {}'.format(key, value))
    print('device = {}'.format(device))

    ##+++++++++
    # Seed every RNG source for reproducibility.
    random_seed = args.rand_seed
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    ##+++++++++
    # Dropout-style trimming is applied only to training data (ratio 0.0 on dev).
    train_data = RangeSeqDataset(json_file_name=train_feat_file_name,
                                 span_window_size=args.span_window_size,
                                 trim_drop_ratio=args.trim_drop_ratio)
    dev_data = RangeSeqDataset(json_file_name=dev_feat_file_name,
                               span_window_size=args.span_window_size,
                               trim_drop_ratio=0.0)
    train_data_loader = DataLoader(dataset=train_data,
                                   shuffle=True,
                                   collate_fn=RangeSeqDataset.collate_fn,
                                   num_workers=args.cpu_number,
                                   batch_size=args.train_batch_size)
    dev_data_loader = DataLoader(dataset=dev_data,
                                 shuffle=False,
                                 collate_fn=RangeSeqDataset.collate_fn,
                                 batch_size=args.eval_batch_size)
    dev_score_dict = load_json_score_data(
        json_score_file_name=dev_score_file_name)
    # Total optimizer steps, used to configure the LR schedule.
    t_total_steps = len(train_data_loader) * args.num_train_epochs
    model = RangeSeqModel(args=args)
    # model = RangeSeqScoreModel(args=args)
    #++++++++++++++++++++++++++++++++++++++++++++
    model.to(device)

    model.zero_grad()
    model.train()
    optimizer = get_optimizer(model=model, args=args)
    scheduler = get_scheduler(optimizer=optimizer,
                              args=args,
                              total_steps=t_total_steps)
    # Log every parameter tensor's shape and trainability.
    for name, param in model.named_parameters():
        print('Parameter {}: {}, require_grad = {}'.format(
            name, str(param.size()), str(param.requires_grad)))
    print('*' * 75)

    ###++++++++++++++++++++++++++++++++++++++++++
    total_batch_num = len(train_data_loader)
    print('Total number of batches = {}'.format(total_batch_num))
    # Evaluate several times per epoch; the +1 guards against a zero interval.
    eval_batch_interval_num = int(
        total_batch_num * args.eval_interval_ratio) + 1
    print('Evaluate the model by = {} batches'.format(eval_batch_interval_num))
    ###++++++++++++++++++++++++++++++++++++++++++
    # NOTE(review): early_stop_step is incremented below but never used to
    # actually stop training — confirm whether early stopping was intended.
    early_stop_step = 0

    start_epoch = 0
    best_em_ratio = 0.0
    best_f1 = 0.0
    dev_loss = 0.0
    dev_prediction_dict = None
    for epoch in range(start_epoch, start_epoch + int(args.num_train_epochs)):
        epoch_iterator = train_data_loader
        for step, batch in enumerate(epoch_iterator):
            model.train()
            #+++++++
            # Move all tensors to the device; 'id' is a list of example keys.
            for key, value in batch.items():
                if key not in ['id']:
                    batch[key] = value.to(device)
            #+++++++
            # batch_analysis(batch['x_feat'])
            start_scores, end_scores = model(batch['x_feat'])
            loss = seq_loss_computation(start=start_scores,
                                        end=end_scores,
                                        batch=batch,
                                        weight=args.weighted_loss)
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)
            optimizer.step()
            scheduler.step()
            model.zero_grad()

            # Periodic progress line with the latest dev metrics.
            if step % 10 == 0:
                print(
                    'Epoch={}\tstep={}\tloss={:.5f}\teval_em={:.6f}\teval_f1={:.6f}\teval_loss={:.5f}\n'
                    .format(epoch, step, loss.data.item(), best_em_ratio,
                            best_f1, dev_loss))
            # In-epoch evaluation and best-F1 checkpointing.
            if (step + 1) % eval_batch_interval_num == 0:
                em_ratio, dev_f1, total_count, dev_loss_i, pred_dict = eval_model(
                    model=model,
                    data_loader=dev_data_loader,
                    weighted_loss=args.weighted_loss,
                    device=device,
                    alpha=args.alpha,
                    threshold_category=threshold_category,
                    dev_score_dict=dev_score_dict,
                    raw_dev_dict=raw_data_dict)
                dev_loss = dev_loss_i
                # em_ratio = em_count * 1.0/total_count
                # if em_ratio > best_em_ratio:
                #     best_em_ratio = em_ratio
                #     torch.save({k: v.cpu() for k, v in model.state_dict().items()},
                #                join(args.output_dir, args.exp_name, f'seq_threshold_pred_model.pkl'))
                #     dev_prediction_dict = pred_dict
                if best_f1 < dev_f1:
                    # New best F1: reset early-stop counter, record EM at
                    # this point, and save a CPU copy of the weights.
                    best_f1 = dev_f1
                    early_stop_step = 0
                    best_em_ratio = em_ratio
                    best_f1_em = 'f1_{:.4f}_em_{:.4f}'.format(
                        best_f1, best_em_ratio)
                    torch.save(
                        {k: v.cpu()
                         for k, v in model.state_dict().items()},
                        join(
                            args.output_dir, args.exp_name,
                            f'seq_pred_model_{epoch + 1}.step_{step + 1}.{best_f1_em}.pkl'
                        ))
                    dev_prediction_dict = pred_dict
                else:
                    early_stop_step += 1
    print('Best em ratio = {:.5f}'.format(best_em_ratio))
    print('Best f1 = {:.5f}'.format(best_f1))
    return best_em_ratio, best_f1, dev_prediction_dict
# Example no. 5
# 0
def run(args):
    """Train a RangeModel on precomputed .npz features.

    Reads train/dev feature files from args.pred_dir (optionally the
    'filter_'-prefixed train variant), trains for args.num_train_epochs
    epochs with in-epoch evaluation, and returns the best exact-match
    ratio observed on the dev set.
    """
    # First free GPU when available, otherwise CPU.
    if torch.cuda.is_available():
        device_ids, _ = single_free_cuda()
        device = torch.device('cuda:{}'.format(device_ids[0]))
    else:
        device = torch.device('cpu')

    # Optionally train on the filtered variant of the training features.
    if args.train_filter:
        train_npz_file_name = join(args.pred_dir, args.model_name_or_path, 'filter_' + args.train_feat_name)
    else:
        train_npz_file_name = join(args.pred_dir, args.model_name_or_path, args.train_feat_name)

    ##+++++++++
    # Seed every RNG source for reproducibility.
    random_seed = args.rand_seed
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    ##+++++++++
    train_npz_data = RangeDataset(npz_file_name=train_npz_file_name)
    train_data_loader = DataLoader(dataset=train_npz_data,
                                   shuffle=True,
                                   collate_fn=RangeDataset.collate_fn,
                                   num_workers=args.cpu_number // 2,
                                   batch_size=args.train_batch_size)

    dev_npz_file_name = join(args.pred_dir, args.model_name_or_path, args.dev_feat_name)
    dev_npz_data = RangeDataset(npz_file_name=dev_npz_file_name)
    dev_data_loader = DataLoader(dataset=dev_npz_data,
                                 shuffle=False,
                                 collate_fn=RangeDataset.collate_fn,
                                 num_workers=args.cpu_number // 2,
                                 batch_size=args.eval_batch_size)

    model = RangeModel(args=args)
    model.to(device)

    model.zero_grad()
    model.train()
    optimizer = get_optimizer(model=model, args=args)
    # Log every parameter tensor's shape and trainability.
    for name, param in model.named_parameters():
        print('Parameter {}: {}, require_grad = {}'.format(name, str(param.size()), str(param.requires_grad)))
    print('*' * 75)
    ###++++++++++++++++++++++++++++++++++++++++++
    total_batch_num = len(train_data_loader)
    print('Total number of batches = {}'.format(total_batch_num))
    # Evaluate several times per epoch; the +1 guards against a zero interval.
    eval_batch_interval_num = int(total_batch_num * args.eval_interval_ratio) + 1
    print('Evaluate the model by = {} batches'.format(eval_batch_interval_num))
    ###++++++++++++++++++++++++++++++++++++++++++
    start_epoch = 0
    best_em_ratio = 0.0
    # Fix: the original created an unused tqdm `trange` progress bar here
    # (never iterated — the loop uses plain range), leaving a stray,
    # never-advancing progress display; it has been removed.
    for epoch in range(start_epoch, start_epoch + int(args.num_train_epochs)):
        epoch_iterator = train_data_loader
        for step, batch in enumerate(epoch_iterator):
            model.train()
            #+++++++
            # Unlike the seq trainer, every batch entry here is a tensor.
            for key, value in batch.items():
                batch[key] = value.to(device)
            #+++++++
            scores = model(batch['x_feat']).squeeze(-1)
            loss = loss_computation(scores=scores, y_min=batch['y_min'], y_max=batch['y_max'])
            optimizer.zero_grad()
            loss.backward()
            # Gradient clipping before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            model.zero_grad()

            if step % 10 == 0:
                print('{}\t{}\t{:.5f}\n'.format(epoch, step, loss.data.item()))
            # In-epoch evaluation; track (but do not checkpoint) the best EM.
            if (step + 1) % eval_batch_interval_num == 0:
                em_count, total_count = eval_model(model=model, data_loader=dev_data_loader, device=device)
                em_ratio = em_count * 1.0/total_count
                print('*' * 35)
                print('{}\t{}\t{:.5f}\n'.format(epoch, step, em_ratio))
                print('*' * 35)
                if em_ratio > best_em_ratio:
                    best_em_ratio = em_ratio
    print('Best em ratio = {:.5f}'.format(best_em_ratio))
    return best_em_ratio