Example #1
def train():
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.dev_path)

    model = bertMRC(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)
    #model.load_state_dict(torch.load(args.checkpoints))
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    # torch.optim.AdamW expects the key 'weight_decay'; bias and LayerNorm
    # (gamma/beta) parameters are excluded from decay
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)

    schedule = WarmUp_LinearDecay(
        optimizer=optimizer,
        init_rate=args.learning_rate,
        warm_up_epoch=args.warm_up_epoch,
        decay_epoch=args.decay_epoch
    )
    
    loss_func = cross_entropy_loss.cross_entropy().to(device)
    acc_func = metrics.metrics_func().to(device)

    step = 0
    best_f1 = 0
    for epoch in range(args.epoch):
        for item in train_data:
            step+=1
            input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
            labels,flag = item["labels"],item["flag"]
    
            optimizer.zero_grad()
            logits = model( 
                input_ids=input_ids.to(device), 
                input_mask=input_mask.to(device),
                input_seg=input_seg.to(device),
                flag=flag.to(device),
                is_training=True
            )

            loss= loss_func(logits,labels.to(device))
            loss = loss.float().mean().type_as(loss)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
            schedule.step()

            if step%100 == 0:
                logits=torch.nn.functional.softmax(logits, dim=-1)
                recall, precise, f1=acc_func(logits,labels.to(device))

                logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, f1 %.4f'% (
                    epoch,step,loss,recall, precise, f1))

        with torch.no_grad():

            recall, precise, f1=0,0,0
            count=0

            for item in test_data:
                count+=1
                input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
                labels,flag = item["labels"],item["flag"]
        
                logits = model( 
                    input_ids=input_ids.to(device), 
                    input_mask=input_mask.to(device),
                    input_seg=input_seg.to(device),
                    flag=flag.to(device),
                    is_training=False
                )
                tmp_recall, tmp_precise, tmp_f1=acc_func(logits,labels.to(device))
                f1+=tmp_f1
                recall+=tmp_recall
                precise+=tmp_precise

            f1 /= count
            recall /= count
            precise /= count

            logger.info('-----eval----')
            logger.info('epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, f1 %.4f' % (
                    epoch, step, loss, recall, precise, f1))
            logger.info('-----eval----')
            # only keep the checkpoint when the eval F1 improves
            if best_f1 < f1:
                best_f1 = f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')
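
# Example #1 never calls optimizer.step() directly (and Example #2 below leaves it
# commented out), so the WarmUp_LinearDecay wrapper is assumed to both set the
# learning rate and step the wrapped optimizer on every schedule.step() call.
# A minimal sketch of such a wrapper, using the constructor arguments seen above;
# steps_per_epoch and min_rate_ratio are hypothetical additions used to turn the
# epoch-based warm-up/decay settings into step counts, not parameters from the
# original project.
class WarmUp_LinearDecay:
    def __init__(self, optimizer, init_rate, warm_up_epoch, decay_epoch,
                 steps_per_epoch=1000, min_rate_ratio=0.05):
        self.optimizer = optimizer
        self.init_rate = init_rate
        self.warm_up_steps = warm_up_epoch * steps_per_epoch
        self.decay_steps = decay_epoch * steps_per_epoch
        self.min_rate = init_rate * min_rate_ratio
        self.optimizer_step = 0

    def step(self):
        self.optimizer_step += 1
        if self.optimizer_step <= self.warm_up_steps:
            # linear warm-up from 0 to init_rate
            rate = self.init_rate * self.optimizer_step / self.warm_up_steps
        elif self.optimizer_step <= self.decay_steps:
            # linear decay from init_rate towards min_rate
            progress = (self.optimizer_step - self.warm_up_steps) / max(
                1, self.decay_steps - self.warm_up_steps)
            rate = self.init_rate - (self.init_rate - self.min_rate) * progress
        else:
            rate = self.min_rate
        for group in self.optimizer.param_groups:
            group['lr'] = rate
        # step the wrapped optimizer, since the training loops above do not
        self.optimizer.step()
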
Example #2
def train():
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.test_path)

    model = myModel(pre_train_dir=args.pretrained_model_path, dropout_rate=0.5).to(device)
    # model.load_state_dict(torch.load(args.checkpoints))
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0}
    ]
    optimizer = optim.AdamW(params=optimizer_grouped_parameters, lr=args.learning_rate)

    schedule = WarmUp_LinearDecay(
        optimizer=optimizer,
        init_rate=args.learning_rate,
        warm_up_epoch=args.warm_up_epoch,
        decay_epoch=args.decay_epoch
    )
    
    loss_func = cross_entropy_loss.cross_entropy().to(device)

    acc_func = metrics.metrics_func().to(device)
    # start_acc = metrics.metrics_start().to(device)
    # end_acc = metrics.metrics_end().to(device)

    step=0
    best=0
    for epoch in range(args.epoch):
        for item in train_data:
            step+=1
            input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
            start_label ,end_label, span_label, seq_mask = item["start_label"],item["end_label"],item['span_label'],item["seq_mask"]
            seq_id = item["seq_id"]
            optimizer.zero_grad()
            start_logits,end_logits = model( 
                input_ids=input_ids.to(device), 
                input_mask=input_mask.to(device),
                input_seg=input_seg.to(device),
                is_training=True
            )
            start_end_loss = loss_func(start_logits,end_logits,start_label.to(device),end_label.to(device),seq_mask.to(device))
            loss = start_end_loss
            loss = loss.float().mean().type_as(loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_norm)
            schedule.step()
            # optimizer.step()
            if step%50 == 0:
                start_logits = torch.nn.functional.softmax(start_logits, dim=-1)
                end_logits = torch.nn.functional.softmax(end_logits, dim=-1)
                _,_,start_f1=acc_func(start_logits,start_label.to(device),seq_mask.to(device))
                _,_,end_f1=acc_func(end_logits,end_label.to(device),seq_mask.to(device))

                logger.info('epoch %d, step %d, loss %.4f, start_f1 %.4f, end_f1 %.4f'% (
                    epoch,step,loss,start_f1,end_f1))
        
        with torch.no_grad():

            start_f1 = 0
            end_f1 = 0
            count = 0

            for item in test_data:
                input_ids, input_mask, input_seg = item["input_ids"], item["input_mask"], item["input_seg"]
                start_label,end_label,span_label,seq_mask = item["start_label"],item["end_label"],item['span_label'],item["seq_mask"]
                seq_id = item["seq_id"]
                start_logits,end_logits = model( 
                    input_ids=input_ids.to(device), 
                    input_mask=input_mask.to(device),
                    input_seg=input_seg.to(device),
                    is_training=False
                    ) 
                _,_,tmp_f1_start=acc_func(start_logits,start_label.to(device),seq_mask.to(device))
                start_f1+=tmp_f1_start

                _,_,tmp_f1_end=acc_func(end_logits,end_label.to(device),seq_mask.to(device))
                end_f1+=tmp_f1_end
                count+=1

            start_f1=start_f1/count
            end_f1=end_f1/count

            logger.info('-----eval----')
            logger.info('epoch %d, step %d, loss %.4f, start_f1 %.4f, end_f1 %.4f'% (
                    epoch,step,loss,start_f1,end_f1))
            logger.info('-----eval----')
            if best < start_f1+end_f1:
                best=start_f1+end_f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')
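
# metrics.metrics_func is not defined in these examples; it is called as
# acc_func(logits, labels) in Example #1 and acc_func(logits, labels, seq_mask)
# in Example #2, and unpacked as (recall, precise, f1).  A minimal token-level
# sketch matching that interface; treating label id 0 as the negative class and
# averaging over all unmasked positions are assumptions, not behaviour confirmed
# by the original code.
import torch
from torch import nn


class metrics_func(nn.Module):
    def forward(self, logits, labels, seq_mask=None):
        # class prediction per position; argmax is unaffected by softmax
        preds = torch.argmax(logits, dim=-1)
        if seq_mask is None:
            seq_mask = torch.ones_like(labels)
        seq_mask = seq_mask.float()

        # positive = any non-zero label id, restricted to unmasked positions
        pred_pos = (preds != 0).float() * seq_mask
        gold_pos = (labels != 0).float() * seq_mask
        true_pos = pred_pos * gold_pos

        precise = true_pos.sum() / (pred_pos.sum() + 1e-8)
        recall = true_pos.sum() / (gold_pos.sum() + 1e-8)
        f1 = 2 * precise * recall / (precise + recall + 1e-8)
        return recall, precise, f1
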
Example #3
def train():
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.dev_path)

    model = bertMRC(pre_train_dir=args.pretrained_model_path,
                    dropout_rate=0.5).to(device)
    # model.load_state_dict(torch.load(args.checkpoints))
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'gamma', 'beta']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]
    optimizer = optim.AdamW(params=optimizer_grouped_parameters,
                            lr=args.learning_rate)

    schedule = WarmUp_LinearDecay(optimizer=optimizer,
                                  init_rate=args.learning_rate,
                                  warm_up_epoch=args.warm_up_epoch,
                                  decay_epoch=args.decay_epoch)

    loss_func = binary_cross_entropy.binary_cross_entropy().to(device)
    loss_cross_func = cross_entropy_loss.cross_entropy().to(device)
    acc_func = metrics.metrics_func().to(device)
    span_acc_func = metrics.metrics_span_func().to(device)

    step = 0
    best_f1 = 0
    for epoch in range(args.epoch):
        for item in train_data:
            step += 1
            input_ids, input_mask, input_seg = item["input_ids"], item[
                "input_mask"], item["input_seg"]
            s_startlabel, s_endlabel, o_startlabel, o_endlabel = item[
                "s_startlabel"], item["s_endlabel"], item[
                    "o_startlabel"], item["o_endlabel"]

            optimizer.zero_grad()
            s_startlogits, s_endlogits, o_startlogits, o_endlogits = model(
                input_ids=input_ids.to(device),
                input_mask=input_mask.to(device),
                input_seg=input_seg.to(device),
                is_training=True)

            s_startloss = loss_func(s_startlogits, s_startlabel.to(device))
            s_endloss = loss_func(s_endlogits, s_endlabel.to(device))

            o_startloss = loss_func(o_startlogits, o_startlabel.to(device))
            o_endloss = loss_func(o_endlogits, o_endlabel.to(device))

            loss = s_startloss + s_endloss + o_startloss + o_endloss
            loss = loss.float().mean().type_as(loss)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=args.clip_norm)
            schedule.step()

            if step % 100 == 0:
                _, _, s_start_f1 = acc_func(s_startlogits,
                                            s_startlabel.to(device))
                _, _, s_end_f1 = acc_func(s_endlogits, s_endlabel.to(device))

                _, _, o_start_f1 = acc_func(o_startlogits,
                                            o_startlabel.to(device))
                _, _, o_end_f1 = acc_func(o_endlogits, o_endlabel.to(device))

                logger.info(
                    'epoch %d, step %d, loss %.4f, s_start_f1 %.4f, s_end_f1 %.4f, o_start_f1 %.4f, o_end_f1 %.4f'
                    % (epoch, step, loss, s_start_f1, s_end_f1, o_start_f1,
                       o_end_f1))

        with torch.no_grad():

            s_start_f1, s_end_f1, o_start_f1, o_end_f1 = 0, 0, 0, 0
            count = 0

            for item in test_data:
                count += 1
                input_ids, input_mask, input_seg = item["input_ids"], item[
                    "input_mask"], item["input_seg"]
                s_startlabel, s_endlabel, o_startlabel, o_endlabel = item[
                    "s_startlabel"], item["s_endlabel"], item[
                        "o_startlabel"], item["o_endlabel"]

                optimizer.zero_grad()
                s_startlogits, s_endlogits, o_startlogits, o_endlogits = model(
                    input_ids=input_ids.to(device),
                    input_mask=input_mask.to(device),
                    input_seg=input_seg.to(device),
                    is_training=True)
                _, _, tmp_s_start_f1 = acc_func(s_startlogits,
                                                s_startlabel.to(device))
                _, _, tmp_s_end_f1 = acc_func(s_endlogits,
                                              s_endlabel.to(device))

                _, _, tmp_o_start_f1 = acc_func(o_startlogits,
                                                o_startlabel.to(device))
                _, _, tmp_o_end_f1 = acc_func(o_endlogits,
                                              o_endlabel.to(device))

                s_start_f1 += tmp_s_start_f1
                s_end_f1 += tmp_s_end_f1

                o_start_f1 += tmp_o_start_f1
                o_end_f1 += tmp_o_end_f1

            s_start_f1 /= count
            s_end_f1 /= count
            o_start_f1 /= count
            o_end_f1 /= count

            logger.info('-----eval----')
            logger.info(
                'epoch %d, step %d, loss %.4f, s_start_f1 %.4f, s_end_f1 %.4f, o_start_f1 %.4f, o_end_f1 %.4f'
                % (epoch, step, loss, s_start_f1, s_end_f1, o_start_f1,
                   o_end_f1))
            logger.info('-----eval----')
            # only keep the checkpoint when the summed eval F1 improves
            if best_f1 < s_start_f1 + s_end_f1 + o_start_f1 + o_end_f1:
                best_f1 = s_start_f1 + s_end_f1 + o_start_f1 + o_end_f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')

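# The first variant of Example #3 scores subject/object start and end positions
# with binary_cross_entropy.binary_cross_entropy, called as loss_func(logits, label)
# without a mask.  A minimal sketch under the assumption that the model emits one
# unnormalized score per token and the labels are 0/1 tensors of the same shape;
# neither shape is confirmed by the code above.
import torch
from torch import nn


class binary_cross_entropy(nn.Module):
    def forward(self, logits, labels):
        # per-token binary cross-entropy on raw (pre-sigmoid) scores
        return nn.functional.binary_cross_entropy_with_logits(
            logits, labels.float(), reduction='mean')
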
def train():
    train_data = data_prepro.yield_data(args.train_path)
    test_data = data_prepro.yield_data(args.test_path)

    model = myModel(pre_train_dir=args.pretrained_model_path,
                    dropout_rate=0.5).to(device)
    # model.load_state_dict(torch.load(args.checkpoints),False)

    no_decay = ["bias", "LayerNorm.weight"]
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = optim.AdamW(params=optimizer_grouped_parameters,
                            lr=args.learning_rate)
    schedule = WarmUp_LinearDecay(optimizer=optimizer,
                                  init_rate=args.learning_rate,
                                  warm_up_epoch=args.warm_up_epoch,
                                  decay_epoch=args.decay_epoch)

    acc_func = metrics.metrics_func().to(device)
    step = 0
    best_f1 = 0
    for epoch in range(args.epoch):
        for item in train_data:
            step += 1
            input_ids, input_mask, input_seg = item["input_ids"], item[
                "input_mask"], item["input_seg"]
            labels = item["labels"]
            token_word = item["token_word"]

            optimizer.zero_grad()
            loss, logits = model(input_ids=input_ids.to(device),
                                 input_mask=input_mask.to(device),
                                 input_seg=input_seg.to(device),
                                 token_word=token_word.to(device),
                                 labels=labels.to(device))

            loss = loss.float().mean().type_as(loss)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           max_norm=args.clip_norm)
            schedule.step()

            if step % 100 == 0:
                recall, precise, f1 = acc_func(logits, labels.to(device))
                logger.info(
                    'epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, f1 %.4f'
                    % (epoch, step, loss, recall, precise, f1))

        with torch.no_grad():
            recall = 0
            precise = 0
            f1 = 0
            count = 0
            for item in test_data:
                count += 1
                input_ids, input_mask, input_seg = item["input_ids"], item[
                    "input_mask"], item["input_seg"]
                labels = item["labels"]
                token_word = item["token_word"]

                loss, logits = model(input_ids=input_ids.to(device),
                                     input_mask=input_mask.to(device),
                                     input_seg=input_seg.to(device),
                                     token_word=token_word.to(device),
                                     labels=labels.to(device))
                tmp_recall, tmp_precise, tmp_f1 = acc_func(
                    logits, labels.to(device))
                recall += tmp_recall
                precise += tmp_precise
                f1 += tmp_f1

            recall = recall / count
            precise = precise / count
            f1 = f1 / count

            logger.info('-----eval----')
            logger.info(
                'epoch %d, step %d, loss %.4f, recall %.4f, precise %.4f, f1 %.4f'
                % (epoch, step, loss, recall, precise, f1))
            logger.info('-----eval----')
            if best_f1 < f1:
                best_f1 = f1
                torch.save(model.state_dict(), f=args.checkpoints)
                logger.info('-----save the best model----')
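
# All of the variants above read hyperparameters from a module-level args object
# and rely on a shared device and logger.  A minimal driver sketch using only the
# attribute names the functions actually reference; the default values are
# placeholders rather than settings from the original project, and it assumes one
# of the train() definitions above is in scope.
import argparse
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument('--train_path', type=str, required=True)
parser.add_argument('--dev_path', type=str, default=None)
parser.add_argument('--test_path', type=str, default=None)
parser.add_argument('--pretrained_model_path', type=str, required=True)
parser.add_argument('--checkpoints', type=str, default='model.pth')
parser.add_argument('--learning_rate', type=float, default=2e-5)
parser.add_argument('--warm_up_epoch', type=int, default=1)
parser.add_argument('--decay_epoch', type=int, default=10)
parser.add_argument('--epoch', type=int, default=10)
parser.add_argument('--clip_norm', type=float, default=1.0)
args = parser.parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if __name__ == '__main__':
    train()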