Example #1
def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    total = 0
    model.train()
    total_loss = 0.
    with tqdm('training',
              total=BATCH_SIZE * BATCHES_PER_EPOCH,
              ncols=80,
              desc='train') as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs, qrels,
                                            GRAD_ACC_SIZE):
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. -
                              scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
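
The loss above reshapes the flat score vector into (count, 2) pairs, with column 0 holding the positive document's score, and minimizes 1 minus the softmax probability of that column. A minimal, self-contained sketch of just this pairwise softmax loss on dummy scores (the two-column (pos, neg) layout is the only thing carried over from the example; everything else is illustrative):

import torch

# Dummy scores for 3 (positive, negative) training pairs:
# column 0 = positive document, column 1 = negative document.
scores = torch.tensor([[2.0, 0.5],
                       [0.1, 1.3],
                       [1.0, 1.0]], requires_grad=True)

# Pairwise softmax loss: 1 - P(positive beats negative), averaged over pairs.
loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])
loss.backward()

print(loss.item())   # shrinks as positive scores pull ahead of negative ones
print(scores.grad)   # gradients push column 0 up and column 1 down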
Example #2
def train_iteration(model, optimizer, dataset, train_pairs, contentid2entity,
                    embed):
    GRAD_ACC_SIZE = 2
    total = 0
    model.train()
    total_loss = 0.
    with tqdm('training',
              total=BATCH_SIZE * BATCHES_PER_EPOCH,
              ncols=80,
              desc='train',
              leave=False) as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs,
                                            GRAD_ACC_SIZE, contentid2entity):
            query_entity = embed(record['query_entity'].cpu() + 1).cuda()
            doc_entity = embed(record['doc_entity'].cpu() + 1).cuda()
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'], query_entity,
                           doc_entity)
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            loss = torch.mean(1. -
                              scores.softmax(dim=1)[:, 0])  # pairwise softmax
            loss.backward()
            total_loss += loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
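
Example #2 adds entity features to the same loop: entity ids from the record are shifted by +1 and looked up in an embedding table on the CPU before the resulting vectors are moved to the GPU. A small sketch of that lookup pattern with a hypothetical nn.Embedding table (the assumption here is that the +1 offset reserves row 0, e.g. for missing or padded entities; the sizes and -1 sentinel are illustrative):

import torch
import torch.nn as nn

num_entities, dim = 1000, 128
embed = nn.Embedding(num_entities + 1, dim)   # row 0 kept free by the +1 shift

# Hypothetical ids from the data iterator; -1 would map to the reserved row 0.
query_entity_ids = torch.tensor([[3, 41, -1],
                                 [7, -1, -1]])

query_entity = embed(query_entity_ids + 1)    # lookup on CPU, as in the example
if torch.cuda.is_available():
    query_entity = query_entity.cuda()        # then move the dense vectors to GPU
print(query_entity.shape)                     # torch.Size([2, 3, 128])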
Example #3
def duet_train_iteration(model, optimizer, dataset, train_pairs, qrels,
                         warmup_epoch, epoch):
    BATCH_SIZE = 16
    BATCHES_PER_EPOCH = 64
    GRAD_ACC_SIZE = 2
    total = 0
    model.train()
    total_loss = 0.
    total_vloss = 0.
    total_closs = 0.
    cq_sum = 0.
    cd_sum = 0.
    with tqdm('training',
              total=BATCH_SIZE * BATCHES_PER_EPOCH,
              ncols=80,
              desc='train',
              leave=False) as pbar:
        for record in data.iter_train_pairs(model, dataset, train_pairs, qrels,
                                            GRAD_ACC_SIZE):
            scores, v_scores, c_scores = model(record['query_tok'],
                                               record['query_mask'],
                                               record['doc_tok'],
                                               record['doc_mask'])
            count = len(record['query_id']) // 2
            v_scores = v_scores.reshape(count, 2)
            c_scores = c_scores.reshape(count, 2)
            scores = scores.reshape(count, 2)
            #print(v_scores)
            #print(c_scores)
            #print(scores)

            ## independent learning
            v_loss = torch.mean(
                1. - v_scores.softmax(dim=1)[:, 0])  # pairwise softmax
            c_loss = torch.mean(
                1. - c_scores.softmax(dim=1)[:, 0])  # pairwise softmax
            #loss = v_loss + c_loss

            ## joint learning
            loss = torch.mean(1. -
                              scores.softmax(dim=1)[:, 0])  # pairwise softmax

            if (epoch < warmup_epoch):  ## initial warm-up: train each sub-ranker on its own loss
                v_loss.backward()
                c_loss.backward()
            else:  ## after warm-up: learn jointly through the combined score
                loss.backward()

            total_loss += loss.item()
            total_vloss += v_loss.item()
            total_closs += c_loss.item()
            total += count
            if total % BATCH_SIZE == 0:
                optimizer.step()
                optimizer.zero_grad()
            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss, total_vloss, total_closs
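
The duet variant trains two sub-rankers whose scores are also combined into a joint score: for the first warmup_epoch epochs their pairwise losses are back-propagated independently, and only afterwards does the joint loss drive the update. A toy, self-contained sketch of that warm-up switch (the two linear stand-ins and the summed joint score are assumptions for illustration, not the actual duet architecture):

import torch
import torch.nn as nn

v_model, c_model = nn.Linear(4, 1), nn.Linear(4, 1)   # stand-ins for the two sub-rankers
params = list(v_model.parameters()) + list(c_model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

warmup_epoch = 1
x = torch.randn(6, 4)                                  # 3 (pos, neg) pairs of features

for epoch in range(3):
    v_scores = v_model(x).reshape(3, 2)
    c_scores = c_model(x).reshape(3, 2)
    scores = v_scores + c_scores                       # assumed joint score for the sketch

    v_loss = torch.mean(1. - v_scores.softmax(dim=1)[:, 0])
    c_loss = torch.mean(1. - c_scores.softmax(dim=1)[:, 0])
    loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])

    optimizer.zero_grad()
    if epoch < warmup_epoch:          # warm-up: each sub-ranker learns from its own loss
        v_loss.backward()
        c_loss.backward()
    else:                             # afterwards: optimize the joint loss only
        loss.backward()
    optimizer.step()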
Example #4
def main(model, dataset, train_pairs, qrels, valid_run, qrelf, model_out_dir):
    params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
    non_bert_params = {
        'params': [v for k, v in params if not k.startswith('bert.')]
    }
    bert_params = {
        'params': [v for k, v in params if k.startswith('bert.')],
        'lr': BERT_LR
    }
    optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR)
    # optimizer = torch.optim.SGD([non_bert_params, bert_params], lr=LR, momentum=0.9)

    # model.to(device)
    model_parallel = dp.DataParallel(model, device_ids=devices)

    epoch = 0
    top_valid_score = None
    for epoch in range(MAX_EPOCH):

        # loss = train_iteration(model, optimizer, dataset, train_pairs, qrels)
        # print(f'train epoch={epoch} loss={loss}')
        # # return
        train_set = TrainDataset(it=data.iter_train_pairs(
            model, dataset, train_pairs, qrels, 1),
                                 length=BATCH_SIZE * BATCHES_PER_EPOCH)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=GRAD_ACC_SIZE,
        )
        # for i, tr in enumerate(train_loader):
        #     for tt in tr:
        #         print(tt, tr[tt].size())
        #     break
        # print('finished')
        # return
        model_parallel(train_iteration_multi, train_loader)
        '''
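
Example #4 builds the optimizer with two parameter groups so the BERT encoder can run at a smaller learning rate (BERT_LR) than the rest of the model (LR). A self-contained sketch of that grouping with a placeholder model (the 'bert.' name-prefix split mirrors the example; the module and learning rates are illustrative):

import torch
import torch.nn as nn

class Ranker(nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = nn.Linear(8, 8)    # stands in for the BERT encoder
        self.head = nn.Linear(8, 1)    # non-BERT ranking head

model = Ranker()
LR, BERT_LR = 1e-3, 2e-5

params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
non_bert_params = {'params': [v for k, v in params if not k.startswith('bert.')]}
bert_params = {'params': [v for k, v in params if k.startswith('bert.')], 'lr': BERT_LR}

# Groups without an explicit 'lr' fall back to the optimizer-level default (LR).
optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR)
print([g['lr'] for g in optimizer.param_groups])   # [0.001, 2e-05]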
Example #5
def train_iteration(model, optimizer, dataset, train_pairs, qrels):
    model.train()
    total = 0
    total_loss = 0.
    with tqdm('training',
              total=BATCH_SIZE * BATCHES_PER_EPOCH,
              ncols=80,
              desc='train') as pbar:
        for n_iter, record in enumerate(
                data.iter_train_pairs(model, dataset, train_pairs, qrels,
                                      GRAD_ACC_SIZE)):
            # if n_iter > 15:
            # return
            scores = model(record['query_tok'], record['query_mask'],
                           record['doc_tok'], record['doc_mask'])
            count = len(record['query_id']) // 2
            # scores = scores.reshape(count, 2)

            # loss = torch.mean(1. - scores.softmax(dim=1)[:, 0]) # pairwise softmax
            # loss.backward()
            # total_loss += loss.item()
            # total_loss += loss
            total += count

            # if n_iter > 0:
            # print(n_iter, [(record[x].size(), record[x].device) for x in ['query_tok', 'query_mask', 'doc_tok', 'doc_mask']])
            # import torch_xla.debug.metrics as met
            # print(met.metrics_report())

            if total % BATCH_SIZE == 0:
                xm.optimizer_step(optimizer, barrier=True)
                optimizer.zero_grad()

            pbar.update(count)
            if total >= BATCH_SIZE * BATCHES_PER_EPOCH:
                return total_loss
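
Example #5 is the TPU variant: with torch_xla, optimizer.step() is replaced by xm.optimizer_step(optimizer, barrier=True), which applies the update and, with barrier=True, marks the XLA step so the lazily traced graph actually executes. A minimal sketch of that substitution, assuming torch_xla is installed and an XLA device is available (the tiny model and data are placeholders):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()                       # TPU (or other XLA) device
model = nn.Linear(4, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x, y = torch.randn(8, 4, device=device), torch.randn(8, 1, device=device)
loss = nn.functional.mse_loss(model(x), y)
loss.backward()

# barrier=True also marks the step, so the pending XLA graph is executed here.
xm.optimizer_step(optimizer, barrier=True)
optimizer.zero_grad()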