Example #1
    def test_warmup_constant_scheduler(self):
        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
        lrs = unwrap_schedule(scheduler, self.num_steps)
        expected_learning_rates = [
            2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0
        ]
        self.assertEqual(len(lrs[0]), 1)
        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

        scheduler = WarmupConstantSchedule(self.optimizer, warmup_steps=4)
        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
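
The test above pins down what the schedule should produce. The following standalone sketch (a plain torch.optim.lr_scheduler.LambdaLR, not the WarmupConstantSchedule class under test) reproduces the expected numbers, assuming a base learning rate of 10.0 and warmup_steps=4 as in the test: the multiplier climbs linearly to 1.0 over the warmup steps and then stays constant.

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=10.0)
warmup_steps = 4

# Multiplier ramps 0.25, 0.5, 0.75, then 1.0 forever -> LRs 2.5, 5.0, 7.5, 10.0, ...
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

lrs = []
for _ in range(10):
    lrs.append(scheduler.get_last_lr()[0])
    optimizer.step()
    scheduler.step()
print(lrs)  # [2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0]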
Example #2
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    if args.save:
        export_config(args, config_path)
        check_path(model_path)
        with open(log_path, 'w') as fout:
            fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = GconAttnDataLoader(
        train_statement_path=args.train_statements,
        train_concept_jsonl=args.train_concepts,
        dev_statement_path=args.dev_statements,
        dev_concept_jsonl=args.dev_concepts,
        test_statement_path=args.test_statements,
        test_concept_jsonl=args.test_concepts,
        concept2id_path=args.cpnet_vocab_path,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        device=device,
        model_name=args.encoder,
        max_cpt_num=max_cpt_num[args.dataset],
        max_seq_length=args.max_seq_len,
        is_inhouse=args.inhouse,
        inhouse_train_qids_path=args.inhouse_train_qids,
        subsample=args.subsample,
        format=args.format)

    print('len(train_set): {}   len(dev_set): {}   len(test_set): {}'.format(
        dataset.train_size(), dataset.dev_size(), dataset.test_size()))
    print()

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMGconAttn(model_name=args.encoder,
                       concept_num=concept_num,
                       concept_dim=args.cpt_out_dim,
                       concept_in_dim=concept_dim,
                       freeze_ent_emb=args.freeze_ent_emb,
                       pretrained_concept_emb=cp_emb,
                       hidden_dim=args.decoder_hidden_dim,
                       dropout=args.dropoutm,
                       encoder_config=lstm_config)

    if args.freeze_ent_emb:
        freeze_net(model.decoder.concept_emb)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(
                            -1, num_choice - 1).contiguous().view(-1)  # of length (b - a) * (num_choice - 1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length (b - a) * (num_choice - 1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                global_step, dev_acc, test_acc))
            print('-' * 71)
            if args.save:
                with open(log_path, 'a') as fout:
                    fout.write('{},{},{}\n'.format(global_step, dev_acc,
                                                   test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save:
                    torch.save([model, args], model_path)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc,
                                                      best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
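For reference, here is a minimal, self-contained sketch of the margin-ranking pairing used in the training loop above (toy logits and labels, not the repo's data): each question's correct-choice logit is expanded against that question's remaining wrong-choice logits before nn.MarginRankingLoss is applied.

import torch
import torch.nn.functional as F
from torch import nn

logits = torch.tensor([[2.0, 0.5, 1.0], [0.1, 3.0, 0.2]])  # [batch, num_choice]
labels = torch.tensor([0, 1])                               # index of the correct choice
num_choice = logits.size(1)

flat_logits = logits.view(-1)
correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)
correct_logits = flat_logits[correct_mask == 1].view(-1, 1).expand(-1, num_choice - 1).reshape(-1)
wrong_logits = flat_logits[correct_mask == 0]
y = wrong_logits.new_ones(wrong_logits.size(0))
loss = nn.MarginRankingLoss(margin=0.1)(correct_logits, wrong_logits, y)
print(loss)  # tensor(0.) here, since every correct logit beats its wrong logits by more than 0.1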
Example #3
File: lm.py Project: zxlzr/MHGRN
def train(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    model_path = os.path.join(args.save_dir, 'model.pt')
    check_path(model_path)

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMDataLoader(args.train_statements, args.dev_statements, args.test_statements,
                           batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=device,
                           model_name=args.encoder,
                           max_seq_length=args.max_seq_len,
                           is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subsample=args.subsample)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMForMultipleChoice(args.encoder, from_checkpoint=args.from_checkpoint, encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'lr': args.encoder_lr, 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'lr': args.encoder_lr, 'weight_decay': 0.0}
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('***** running training *****')
    print(f'| batch_size: {args.batch_size} | num_epochs: {args.n_epochs} | num_train: {dataset.train_size()} |'
          f' num_dev: {dataset.dev_size()} | num_test: {dataset.test_size()}')

    global_step = 0
    best_dev_acc = 0
    best_dev_epoch = 0
    final_test_acc = 0
    try:
        for epoch in range(int(args.n_epochs)):
            model.train()
            tqdm_bar = tqdm(dataset.train(), desc="Training")
            for qids, labels, *input_data in tqdm_bar:
                optimizer.zero_grad()
                batch_loss = 0
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length (b - a) * (num_choice - 1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length (b - a) * (num_choice - 1)
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    batch_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                tqdm_bar.desc = "loss: {:.2e}  lr: {:.2e}".format(batch_loss, scheduler.get_lr()[0])
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(dataset.test(), model) if dataset.test_size() > 0 else 0.0
            if dev_acc > best_dev_acc:
                final_test_acc = test_acc
                best_dev_acc = dev_acc
                best_dev_epoch = epoch
                torch.save([model, args], model_path)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch, dev_acc, test_acc))
            if epoch - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print('***** training ends *****')
    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc, best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
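The no_decay grouping above is a common Hugging Face-style pattern; the following toy sketch (a made-up TinyEncoder module, not LMForMultipleChoice) shows how the substring match routes biases and LayerNorm weights into the zero-weight-decay group.

import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)

model = TinyEncoder()
no_decay = ['bias', 'LayerNorm.weight']
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'lr': 1e-5, 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'lr': 1e-5, 'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped_parameters)
print([len(g['params']) for g in optimizer.param_groups])  # [1, 3]: only dense.weight gets decay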
Example #4
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################
    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized = True
    else:
        use_contextualized = False
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
    dataset = LMGraphRelationNetDataLoader(args.train_statements, args.train_adj,
                                           args.dev_statements, args.dev_adj,
                                           args.test_statements, args.test_adj,
                                           batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                           device=(device, device),
                                           model_name=args.encoder,
                                           max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                           is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                           use_contextualized=use_contextualized,
                                           train_embs_path=args.train_embs, dev_embs_path=args.dev_embs,
                                           test_embs_path=args.test_embs,
                                           subsample=args.subsample, format=args.format)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_timeline_config(args)
    model = LMGraphRelationNet(args.encoder, k=args.k, n_type=3, n_basis=args.num_basis, n_layer=args.gnn_layer_num,
                               diag_decompose=args.diag_decompose, n_concept=concept_num,
                               n_relation=args.num_relation, concept_dim=args.gnn_dim,
                               concept_in_dim=(
                                   dataset.get_node_feature_dim() if use_contextualized else concept_dim),
                               n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
                               att_dim=args.att_dim, att_layer_num=args.att_layer_num,
                               p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
                               pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
                               ablation=args.ablation, init_range=args.init_range,
                               eps=args.eps, use_contextualized=use_contextualized,
                               do_init_rn=args.init_rn, do_init_identity=args.init_identity,
                               encoder_config=lstm_config)
    model.to(device)

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if args.fix_trans:
        no_decay.append('trans_scores')
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)

    print('encoder parameters:')
    for name, param in model.encoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    print('decoder parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'BCE':
        loss_func = nn.BCEWithLogitsLoss(reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_auc, final_test_auc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    for epoch_id in range(args.n_epochs):
        print('epoch: {:5} '.format(epoch_id))

        model.train()
        for qids, labels, *input_data in dataset.train():
            optimizer.zero_grad()
            bs = labels.size(0)
            for a in range(0, bs, args.mini_batch_size):
                b = min(a + args.mini_batch_size, bs)
                logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                if args.loss == 'margin_rank':
                    num_choice = logits.size(1)
                    flat_logits = logits.view(-1)
                    correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice
                    correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(
                        -1, num_choice - 1).contiguous().view(-1)  # of length (b - a) * (num_choice - 1)
                    wrong_logits = flat_logits[correct_mask == 0]  # of length (b - a) * (num_choice - 1)
                    y = wrong_logits.new_ones((wrong_logits.size(0),))
                    loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                elif args.loss == 'cross_entropy':
                    loss = loss_func(logits, labels[a:b])
                loss = loss * (b - a) / bs
                loss.backward()
                total_loss += loss.item()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            scheduler.step()
            optimizer.step()

            if (global_step + 1) % args.log_interval == 0:
                total_loss /= args.log_interval
                ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step,
                                                                                              scheduler.get_lr()[0],
                                                                                              total_loss,
                                                                                              ms_per_batch))
                total_loss = 0.0
                start_time = time.time()
            global_step += 1

        model.eval()
        dev_acc, d_precision, d_recall, d_f1, d_roc_auc = eval_metric(dataset.dev(), model)
        test_acc, t_precision, t_recall, t_f1, t_roc_auc = eval_metric(dataset.test(), model)
        if global_step % args.log_interval == 0:
            tl = total_loss
        else:
            tl = total_loss / (global_step % args.log_interval)
        print('-' * 71)
        print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} | loss {:7.4f} '.format(global_step,
                                                                                        dev_acc,
                                                                                        test_acc,
                                                                                        tl))
        print(
            '| step {:5} | dev_precision {:7.4f} | test_precision {:7.4f} | loss {:7.4f} '.format(
                global_step,
                d_precision,
                t_precision,
                tl))
        print('| step {:5} | dev_recall {:7.4f} | test_recall {:7.4f} | loss {:7.4f} '.format(
            global_step,
            d_recall,
            t_recall,
            tl))
        print('| step {:5} | dev_f1 {:7.4f} | test_f1 {:7.4f} | loss {:7.4f} '.format(global_step,
                                                                                      d_f1,
                                                                                      t_f1,
                                                                                      tl))
        print('| step {:5} | dev_auc {:7.4f} | test_auc {:7.4f} | loss {:7.4f} '.format(global_step,
                                                                                        d_roc_auc,
                                                                                        t_roc_auc,
                                                                                        tl))
        print('-' * 71)
        with open(log_path, 'a') as fout:
            fout.write('{},{},{}\n'.format(global_step, d_roc_auc, t_roc_auc))
        if d_roc_auc >= best_dev_auc:
            best_dev_auc = d_roc_auc
            final_test_auc = t_roc_auc
            best_dev_epoch = epoch_id
            torch.save([model, args], model_path)
            print(f'model saved to {model_path}')
        model.train()
        start_time = time.time()
        if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
            break

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev auc: {:.4f} (at epoch {})'.format(best_dev_auc, best_dev_epoch))
    print('final test auc: {:.4f}'.format(final_test_auc))
    print()
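The 'warmup_linear' branch above sizes the schedule from n_epochs * train_size / batch_size. Below is a minimal sketch of that shape with a plain LambdaLR (not the pytorch-transformers WarmupLinearSchedule class itself; the step counts are illustrative):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=1e-3)
warmup_steps, t_total = 100, 1000  # t_total ~ int(n_epochs * (train_size / batch_size))

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear ramp up
    return max(0.0, (t_total - step) / max(1, t_total - warmup_steps))  # linear decay to 0

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
for _ in range(t_total):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # [0.0] once t_total steps have been taken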
Example #5
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,dev_acc,test_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))

    # try:
    if True:
        if torch.cuda.device_count() >= 2 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:1")
        elif torch.cuda.device_count() == 1 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:0")
        else:
            device0 = torch.device("cpu")
            device1 = torch.device("cpu")
        dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
                                               args.dev_statements, args.dev_adj,
                                               args.test_statements, args.test_adj,
                                               batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                               device=(device0, device1),
                                               model_name=args.encoder,
                                               max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                               is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                               subsample=args.subsample, use_cache=args.use_cache)

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################

        model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num,
                                   concept_dim=args.gnn_dim,
                                   concept_in_dim=concept_dim,
                                   n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
                                   p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
                                   pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
                                   init_range=args.init_range,
                                   encoder_config={})
        model.encoder.to(device0)
        model.decoder.to(device1)


    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        try:
            scheduler = ConstantLRSchedule(optimizer)
        except:
            scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        try:
            scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
        except:
            scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        try:
            scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)
        except:
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
        else:
            print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    if True:
    # try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length (b - a) * (num_choice - 1)
                        wrong_logits = flat_logits[correct_mask == 0]
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            save_test_preds = args.save_model
            if not save_test_preds:
                test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
            else:
                eval_set = dataset.test()
                total_acc = []
                count = 0
                preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id))
                with open(preds_path, 'w') as f_preds:
                    with torch.no_grad():
                        for qids, labels, *input_data in tqdm(eval_set):
                            count += 1
                            logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
                            predictions = logits.argmax(1) #[bsize, ]
                            preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
                            for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
                                acc = int(pred.item()==label.item())
                                print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
                                f_preds.flush()
                                total_acc.append(acc)
                test_acc = float(sum(total_acc))/len(total_acc)

            print('-' * 71)
            print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            else:
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
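When save_test_preds is on, the loop above writes one answer letter per question. A tiny sketch of that conversion (made-up qid and toy logits, not the repo's dataloader output):

import torch

logits = torch.tensor([[0.2, 1.5, 0.1, 0.0, -0.3]])  # [bsize, n_choices]
qids = ['q-hypothetical-001']
predictions = logits.argmax(1)        # [bsize], chosen option index
preds_ranked = (-logits).argsort(1)   # [bsize, n_choices], best option first
for qid, pred in zip(qids, predictions):
    print('{},{}'.format(qid, chr(ord('A') + pred.item())))  # q-hypothetical-001,B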
Example #6
File: rn.py Project: zxlzr/MHGRN
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized, cp_emb = True, None
    else:
        use_contextualized = False
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = cal_2hop_rel_emb(rel_emb)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    # print('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMRelationNetDataLoader(
        args.train_statements,
        args.train_rel_paths,
        args.dev_statements,
        args.dev_rel_paths,
        args.test_statements,
        args.test_rel_paths,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        device=device,
        model_name=args.encoder,
        max_tuple_num=args.max_tuple_num,
        max_seq_length=args.max_seq_len,
        is_inhouse=args.inhouse,
        inhouse_train_qids_path=args.inhouse_train_qids,
        use_contextualized=use_contextualized,
        train_adj_path=args.train_adj,
        dev_adj_path=args.dev_adj,
        test_adj_path=args.test_adj,
        train_node_features_path=args.train_node_features,
        dev_node_features_path=args.dev_node_features,
        test_node_features_path=args.test_node_features,
        node_feature_type=args.node_feature_type)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMRelationNet(model_name=args.encoder,
                          concept_num=concept_num,
                          concept_dim=relation_dim,
                          relation_num=relation_num,
                          relation_dim=relation_dim,
                          concept_in_dim=(dataset.get_node_feature_dim() if
                                          use_contextualized else concept_dim),
                          hidden_size=args.mlp_dim,
                          num_hidden_layers=args.mlp_layer_num,
                          num_attention_heads=args.att_head_num,
                          fc_size=args.fc_dim,
                          num_fc_layers=args.fc_layer_num,
                          dropout=args.dropoutm,
                          pretrained_concept_emb=cp_emb,
                          pretrained_relation_emb=rel_emb,
                          freeze_ent_emb=args.freeze_ent_emb,
                          init_range=args.init_range,
                          ablation=args.ablation,
                          use_contextualized=use_contextualized,
                          emb_scale=args.emb_scale,
                          encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        rel_grad = []
        linear_grad = []
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                print('encoder unfreezed')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                print('encoder refreezed')
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels[a:b], num_classes=num_choice).view(-1)  # of length (b - a) * num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(
                            -1, num_choice - 1).contiguous().view(-1)  # of length (b - a) * (num_choice - 1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length (b - a) * (num_choice - 1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                rel_grad.append(model.decoder.rel_emb.weight.grad.abs().mean().item())
                linear_grad.append(model.decoder.mlp.layers[8].weight.grad.abs().mean().item())
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    # print('| rel_grad: {:1.2e} | linear_grad: {:1.2e} |'.format(sum(rel_grad) / len(rel_grad), sum(linear_grad) / len(linear_grad)))
                    total_loss = 0
                    rel_grad = []
                    linear_grad = []
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                epoch_id, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                torch.save([model, args], model_path)
                print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc,
                                                      best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
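Example #6 also records rel_grad and linear_grad as a simple health check on the gradients. A standalone sketch of that probe on a toy model (the layer index here is illustrative, not the repo's decoder layout):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
x, y = torch.randn(4, 8), torch.randint(0, 2, (4,))

loss = nn.CrossEntropyLoss()(model(x), y)
loss.backward()

# Mean absolute gradient of a specific layer, logged to see whether it is actually learning.
probe = model[0].weight.grad.abs().mean().item()
print('mean |grad| of first layer: {:.2e}'.format(probe))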
Example #7
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      betas=(0.9, 0.98),
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    if args.linear_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=int(args.warmup_ratio *
                                                          t_total),
                                         t_total=t_total)
    else:
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=int(args.warmup_ratio *
                                                            t_total))
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_steps = 0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        max_step = len(epoch_iterator)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids':
                batch[0][:, 0, :],
                'output_ids':
                batch[1][:, 0, :],
                'attention_mask':
                batch[2][:, 0, :],
                'token_type_ids':
                batch[3] if args.model_type in ['xlnet'] else
                None,  # only XLNet gets segment_ids here; XLM does not use them
            }
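            # mc_inputs is prepared for the multiple-choice loss path but is unused here (the mc_loss line below is commented out)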
            mc_inputs = {
                'input_ids': batch[0].view(-1, batch[0].size(2)),
                'output_ids': batch[1].view(-1, batch[1].size(2)),
                'attention_mask': batch[2].view(-1, batch[2].size(2)),
                'token_type_ids': batch[3] if args.model_type in ['xlnet'] else
                None,  # only XLNet gets segment_ids here; XLM does not use them
                'labels': batch[4]
            }

            outputs = model(**inputs)

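            # average the summed token losses over the number of non-zero entries, i.e. positions that actually contribute to the loss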
            lm_loss = outputs.sum() / ((outputs != 0).float().sum())
            loss = lm_loss

            #loss = mc_loss
            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    step + 1) == max_step:
                if args.max_grad_norm > 0:
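                    # under apex amp, clip the FP32 master gradients via amp.master_params instead of model.parameters()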
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well

                        results = evaluate(args, model, tokenizer)

                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
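                        # NOTE: despite the names, best_dev_acc / best_dev_loss store eval_ppl / eval_avg_ppl here, and '>' keeps the largest value seen so far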
                        if results["eval_ppl"] > best_dev_acc:
                            best_dev_acc = results["eval_ppl"]
                            best_dev_loss = results["eval_avg_ppl"]
                            best_steps = global_step
                            if args.do_test:
                                results_test = evaluate(args,
                                                        model,
                                                        tokenizer,
                                                        test=True)
                                for key, value in results_test.items():
                                    tb_writer.add_scalar(
                                        'test_{}'.format(key), value,
                                        global_step)
                                logger.info(
                                    "test acc: %s, loss: %s, global steps: %s",
                                    str(results_test['eval_acc']),
                                    str(results_test['eval_loss']),
                                    str(global_step))
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logger.info(
                        "Average loss: %s at global step: %s",
                        str((tr_loss - logging_loss) / args.logging_steps),
                        str(global_step))
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_vocabulary(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, best_steps
Exemplo n.º 8
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    if args.optimizer == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          betas=(0.9, 0.98),
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    elif args.optimizer == "adam":
        optimizer = Adam(optimizer_grouped_parameters,
                         betas=(0.9, 0.98),
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))

    if args.linear_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=int(args.warmup_ratio *
                                                          t_total),
                                         t_total=t_total)
    else:
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=int(args.warmup_ratio *
                                                            t_total))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_dev_acc, best_dev_loss = 0.0, 99999999999.0
    best_steps = 0
    if_stop = False
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        max_step = len(epoch_iterator)
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids':
                batch[0],
                'attention_mask':
                batch[1],
                'token_type_ids':
                batch[2] if args.model_type in ['bert', 'xlnet'] else
                None,  # only BERT and XLNet use segment_ids; XLM does not
                'labels':
                batch[3]
            }
            outputs = model(**inputs)
            loss = outputs[
                0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    step + 1) == max_step:
                if args.max_grad_norm > 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
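                        # early stopping: save whenever dev accuracy improves, and stop at the first evaluation that does not improve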
                        if args.early_stop:
                            if results["eval_acc"] > best_dev_acc:
                                best_dev_acc = results["eval_acc"]
                                logger.info("Saving model checkpoint to %s",
                                            args.output_dir)
                                # Save a trained model, configuration and tokenizer using `save_pretrained()`.
                                # They can then be reloaded using `from_pretrained()`
                                model_to_save = model.module if hasattr(
                                    model, 'module'
                                ) else model  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(args.output_dir)
                                tokenizer.save_pretrained(args.output_dir)

                                # Good practice: save your training arguments together with the trained model
                                torch.save(
                                    args,
                                    os.path.join(args.output_dir,
                                                 'training_args.bin'))
                            else:
                                epoch_iterator.close()

                                if_stop = True
                                break

                    logger.info(
                        "Average loss: %s at global step: %s",
                        str((tr_loss - logging_loss) / args.logging_steps),
                        str(global_step))
                    logging_loss = tr_loss

        if if_stop:
            train_iterator.close()
            break
    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / (global_step + 1e-8), best_steps
Exemplo n.º 9
0
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    dic = {'transe': 0, 'numberbatch': 1}
    cp_emb, rel_emb = [
        np.load(args.ent_emb_paths[dic[source]]) for source in args.ent_emb
    ], np.load(args.rel_emb_path)
    cp_emb = np.concatenate(cp_emb, axis=1)
    cp_emb = torch.tensor(cp_emb)
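    # double the relation set: the negated vectors act as embeddings for the reverse direction of each relation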
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = torch.tensor(rel_emb)
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    print('num_relations: {}, relation_dim: {}'.format(relation_num,
                                                       relation_dim))

    try:

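        # the text encoder and the graph decoder live on separate devices (cuda:0 / cuda:1) as a simple form of model parallelism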
        device0 = torch.device(
            "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
        device1 = torch.device(
            "cuda:1" if torch.cuda.is_available() and args.cuda else "cpu")
        dataset = KagNetDataLoader(
            args.train_statements,
            args.train_paths,
            args.train_graphs,
            args.dev_statements,
            args.dev_paths,
            args.dev_graphs,
            args.test_statements,
            args.test_paths,
            args.test_graphs,
            batch_size=args.mini_batch_size,
            eval_batch_size=args.eval_batch_size,
            device=(device0, device1),
            model_name=args.encoder,
            max_seq_length=args.max_seq_len,
            max_path_len=args.max_path_len,
            is_inhouse=args.inhouse,
            inhouse_train_qids_path=args.inhouse_train_qids,
            use_cache=args.use_cache)
        print('dataset done')

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################
        lstm_config = get_lstm_config_from_args(args)

        model = LMKagNet(model_name=args.encoder,
                         concept_dim=concept_dim,
                         relation_dim=relation_dim,
                         concept_num=concept_num,
                         relation_num=relation_num,
                         qas_encoded_dim=args.qas_encoded_dim,
                         pretrained_concept_emb=cp_emb,
                         pretrained_relation_emb=rel_emb,
                         lstm_dim=args.lstm_dim,
                         lstm_layer_num=args.lstm_layer_num,
                         graph_hidden_dim=args.graph_hidden_dim,
                         graph_output_dim=args.graph_output_dim,
                         dropout=args.dropout,
                         bidirect=args.bidirect,
                         num_random_paths=args.num_random_paths,
                         path_attention=args.path_attention,
                         qa_attention=args.qa_attention,
                         encoder_config=lstm_config)
        print('model done')
        if args.freeze_ent_emb:
            freeze_net(model.decoder.concept_emb)
        print('frozen')
        model.encoder.to(device0)
        print('encoder done')
        model.decoder.to(device1)
        print('decoder done')
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.decoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.decoder_lr
        },
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    print()
    print('-' * 71)
    global_step, last_best_step = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

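                    # margin ranking: pair each correct-answer logit with every wrong-answer logit of the same question and require it to be larger by at least the margin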
                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels[a:b], num_classes=num_choice).view(
                                -1)  # of length minibatch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length minibatch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length minibatch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
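                    # weight each minibatch loss by its share of the batch so the accumulated gradient matches the full-batch average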
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    total_loss = 0
                    start_time = time.time()

                if (global_step + 1) % args.eval_interval == 0:
                    model.eval()
                    dev_acc = evaluate_accuracy(dataset.dev(), model)
                    test_acc = evaluate_accuracy(
                        dataset.test(), model) if args.test_statements else 0.0
                    print('-' * 71)
                    print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.
                          format(global_step, dev_acc, test_acc))
                    print('-' * 71)
                    with open(log_path, 'a') as fout:
                        fout.write('{},{},{}\n'.format(global_step, dev_acc,
                                                       test_acc))
                    if dev_acc >= best_dev_acc:
                        best_dev_acc = dev_acc
                        final_test_acc = test_acc
                        last_best_step = global_step
                        torch.save([model, args], model_path)
                        print(f'model saved to {model_path}')
                    model.train()
                    start_time = time.time()

                global_step += 1
                # if global_step >= args.max_steps or global_step - last_best_step >= args.max_steps_before_stop:
                #     end_flag = True
                #     break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at step {})'.format(best_dev_acc,
                                                     last_best_step))
    print('final test acc: {:.4f}'.format(final_test_acc))
Exemplo n.º 10
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriterP(args.output_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    if args.lr_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)
    else:
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

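    # resume the optimizer step counter from step.txt if the checkpoint directory provides one; otherwise start at 0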
    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    tr_loss, logging_loss = 0.0, 0.0
    moving_loss = MovingLoss(10000)
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    try:
        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader,
                                  desc="Iteration",
                                  disable=args.local_rank not in [-1, 0])
            for step, batch in enumerate(epoch_iterator):
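                # MLM: corrupt the inputs and predict the masked tokens; otherwise a causal LM objective where the labels are simply the inputs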
                inputs, labels = mask_tokens(
                    batch, tokenizer, args) if args.mlm else (batch, batch)
                inputs = inputs.to(args.device)
                labels = labels.to(args.device)
                model.train()
                outputs = model(
                    inputs, masked_lm_labels=labels) if args.mlm else model(
                        inputs, labels=labels)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()
                moving_loss.add(loss.item())
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       args.max_grad_norm)
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    model.zero_grad()
                    global_step += 1

                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training and global_step % args.eval_steps == 0:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer,
                                           f"step {global_step}")
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)

                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                             args.logging_steps, global_step)
                        logging_loss = tr_loss
                        logger.info(
                            f"Moving loss {moving_loss.loss:.2f}, perplexity {torch.exp(torch.tensor(moving_loss.loss)):.2f}"
                        )

                    if args.local_rank in [
                            -1, 0
                    ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        save_state(args, model, tokenizer, global_step)

                if args.max_steps > 0 and global_step > args.max_steps:
                    epoch_iterator.close()
                    break
            print_sample(model, tokenizer, args.device)
            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break
    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Exemplo n.º 11
0
def train(args, model, tokenizer):
    """ Train the model """
    if xm.is_master_ordinal():
        tb_writer = SummaryWriterP(args.output_dir)

    def summary_write(*args, **kwargs):
        if xm.is_master_ordinal():
            tb_writer.add_scalar(*args, **kwargs)

    args.train_batch_size = args.per_gpu_train_batch_size  #* max(1, args.n_gpu)

    train_dataloader = build_dataloader(args, tokenizer)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader)) + 1
    else:
        t_total = len(train_dataloader) * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if p.requires_grad and any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    # Scale learning rate to num cores
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate * xm.xrt_world_size(),
                      eps=args.adam_epsilon)
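    # convert the warm-up budget from samples to optimizer steps: one step consumes per-core batch size * number of TPU cores samples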
    warmup_steps = args.warmup_samples // (args.train_batch_size *
                                           xm.xrt_world_size())
    if args.lr_decay:
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=warmup_steps,
                                         t_total=t_total)
    elif args.lr_cosine:
        scheduler = WarmupCosineWithHardRestartsSchedule(
            optimizer,
            warmup_steps=warmup_steps,
            t_total=t_total,
            cycles=args.num_train_epochs)
    else:
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=warmup_steps)

    # Train!
    tracker = xm.RateTracker()
    log_info("***** Running training *****")
    log_info("  Num Epochs = %d", args.num_train_epochs)
    log_info("  Instantaneous batch size per GPU = %d",
             args.per_gpu_train_batch_size)
    log_info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size *
        (xm.xrt_world_size() if args.local_rank != -1 else 1))
    log_info("  Total optimization steps = %d", t_total)

    try:
        with open(os.path.join(args.model_name_or_path, 'step.txt'), 'r') as c:
            global_step = int(c.readline())
    except OSError as e:
        global_step = 0

    moving_loss = MovingLoss(1000 // args.logging_steps)

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=not xm.is_master_ordinal())
    try:
        for epoch in train_iterator:
            p_train_dataloader = pl.ParallelLoader(train_dataloader,
                                                   [args.device])
            epoch_iterator = tqdm(p_train_dataloader.per_device_loader(
                args.device),
                                  total=len(train_dataloader),
                                  desc="Iteration",
                                  disable=not xm.is_master_ordinal())

            model.train()
            for step, batch in enumerate(epoch_iterator):
                optimizer.zero_grad()
                outputs = model(batch, labels=batch)
                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers (see doc)

                loss.backward()
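                # xm.optimizer_step all-reduces the gradients across TPU cores before applying the update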
                xm.optimizer_step(optimizer)
                scheduler.step()

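                # NOTE: each epoch is cut short after roughly 100 steps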
                if step > 100:
                    epoch_iterator.close()
                    break

            # evaluate once in an epoch
            if args.evaluate_during_training:
                log_info(f"Eval {evaluate(args, model, tokenizer)}")

    except (KeyboardInterrupt, SystemExit):
        save_state(args, model, tokenizer, global_step)
        raise

    save_state(args, model, tokenizer, global_step)

    return global_step, moving_loss.loss