Example #1
    def test_constant_scheduler(self):
        scheduler = ConstantLRSchedule(self.optimizer)
        lrs = unwrap_schedule(scheduler, self.num_steps)
        expected_learning_rates = [10.] * self.num_steps
        self.assertEqual(len(lrs[0]), 1)
        self.assertListEqual([l[0] for l in lrs], expected_learning_rates)

        scheduler = ConstantLRSchedule(self.optimizer)
        lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
        self.assertListEqual([l[0] for l in lrs], [l[0] for l in lrs_2])
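
For context: ConstantLRSchedule comes from the old class-based optimization API in pytorch_transformers / early transformers releases and simply keeps the learning rate at its initial value (10.0 in the test above). As a rough sketch, assuming a recent transformers and torch release where the class was replaced by the get_constant_schedule factory, the same behavior looks like this (the toy model and learning rate are illustrative):

# Minimal sketch, assuming a recent transformers release where the class-based
# ConstantLRSchedule was replaced by the get_constant_schedule factory.
import torch
from transformers import get_constant_schedule

model = torch.nn.Linear(4, 1)                    # toy model, for illustration only
optimizer = torch.optim.AdamW(model.parameters(), lr=10.0)
scheduler = get_constant_schedule(optimizer)     # constant LR, like ConstantLRSchedule

for _ in range(5):
    optimizer.step()
    scheduler.step()
    assert scheduler.get_last_lr()[0] == 10.0    # the learning rate never changes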
Example #2
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download()
        self.train_steps = 0
        self.checkpoint_steps = 500
        self.model_ckpt = str(self.model_dir.name)
        self.distilbert = 'distilbert' in self.model_ckpt

        if os.path.exists(os.path.join(self.model_ckpt, 'config.json')):
            self.logger.info('Loading from checkpoint %s' % self.model_ckpt)
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
        elif os.path.exists(os.path.join(self.data_dir, 'config.json')):
            self.logger.info('Loading from trained model in %s' %
                             self.data_dir)
            self.model_ckpt = self.data_dir
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
        else:
            self.logger.info(
                'Initializing new model with pretrained weights %s' %
                self.model_ckpt)
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
            self.model_config.num_labels = 1  # set up for regression

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        if self.device == torch.device("cpu"):
            self.logger.info("RUNNING ON CPU")
        else:
            self.logger.info("RUNNING ON CUDA")
            torch.cuda.synchronize(self.device)

        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_ckpt, config=self.model_config)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_ckpt)
        self.rerank_model.to(self.device, non_blocking=True)

        self.optimizer = AdamW(self.rerank_model.parameters(),
                               lr=self.lr,
                               correct_bias=False)
        self.scheduler = ConstantLRSchedule(self.optimizer)

        self.weight = 1.0
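
The constructor above (shown in full in Example #9) resolves a checkpoint, builds a regression-style reranker (num_labels = 1), and pairs AdamW with a constant schedule. A condensed, hedged sketch of just the Hugging Face loading pattern, with an illustrative checkpoint name:

# Condensed sketch of the checkpoint-loading pattern above; the checkpoint
# name is illustrative, and a regression head is requested via num_labels=1.
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

ckpt = 'distilbert-base-uncased'                 # illustrative model name
config = AutoConfig.from_pretrained(ckpt)
config.num_labels = 1                            # single logit -> regression-style score

model = AutoModelForSequenceClassification.from_pretrained(ckpt, config=config)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)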
Example #3
    def init_optimizer(self, model, lr, t_total, fixed=None):
        args = self.args
        no_decay = ['bias', 'LayerNorm.weight']
        if fixed is None: fixed = []
        optimizer_grouped_parameters = [{
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n
                           for nd in no_decay) and not any(f in n
                                                           for f in fixed)
            ],
            "weight_decay":
            args.weight_decay
        }, {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay) and not any(f in n
                                                               for f in fixed)
            ],
            "weight_decay":
            0.0
        }]
        # TODO calculate t_total
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=lr,
                          eps=args.adam_epsilon)

        if args.scheduler == "linear":
            warmup_steps = t_total * args.warmup_ratio if args.warmup_steps == -1 else args.warmup_steps
            logger.info(
                "Setting scheduler, warmups=%d, lr=%.7f, total_updates=%d" %
                (warmup_steps, lr, t_total))
            scheduler = WarmupLinearSchedule(optimizer,
                                             warmup_steps=warmup_steps,
                                             t_total=t_total)
        elif args.scheduler == "constant":
            logger.info("Setting scheduler, ConstantLRSchedule")
            scheduler = ConstantLRSchedule(optimizer)
        else:
            raise ValueError('unsupported scheduler: %s' % args.scheduler)
        return optimizer_grouped_parameters, optimizer, scheduler
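
The grouping above follows the common convention of excluding biases and LayerNorm weights from weight decay. A self-contained sketch of the same pattern on a toy module, using torch.optim.AdamW; the module and attribute names are chosen so the substring match behaves as it does for BERT-style models:

# Standalone sketch of the no-decay grouping above (toy module, torch AdamW).
import torch
from torch import nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.LayerNorm = nn.LayerNorm(8)   # attribute name mirrors transformer modules

model = TinyEncoder()
no_decay = ['bias', 'LayerNorm.weight']
grouped_parameters = [
    {'params': [p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0},
]
optimizer = torch.optim.AdamW(grouped_parameters, lr=2e-5)
# decayed: dense.weight; exempt: dense.bias, LayerNorm.weight, LayerNorm.bias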
Example #4
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    if args.save:
        export_config(args, config_path)
        check_path(model_path)
        with open(log_path, 'w') as fout:
            fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = GconAttnDataLoader(
        train_statement_path=args.train_statements,
        train_concept_jsonl=args.train_concepts,
        dev_statement_path=args.dev_statements,
        dev_concept_jsonl=args.dev_concepts,
        test_statement_path=args.test_statements,
        test_concept_jsonl=args.test_concepts,
        concept2id_path=args.cpnet_vocab_path,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        device=device,
        model_name=args.encoder,
        max_cpt_num=max_cpt_num[args.dataset],
        max_seq_length=args.max_seq_len,
        is_inhouse=args.inhouse,
        inhouse_train_qids_path=args.inhouse_train_qids,
        subsample=args.subsample,
        format=args.format)

    print('len(train_set): {}   len(dev_set): {}   len(test_set): {}'.format(
        dataset.train_size(), dataset.dev_size(), dataset.test_size()))
    print()

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMGconAttn(model_name=args.encoder,
                       concept_num=concept_num,
                       concept_dim=args.cpt_out_dim,
                       concept_in_dim=concept_dim,
                       freeze_ent_emb=args.freeze_ent_emb,
                       pretrained_concept_emb=cp_emb,
                       hidden_dim=args.decoder_hidden_dim,
                       dropout=args.dropoutm,
                       encoder_config=lstm_config)

    if args.freeze_ent_emb:
        freeze_net(model.decoder.concept_emb)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.decoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.decoder_lr
        },
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels, num_classes=num_choice).view(
                                -1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                global_step, dev_acc, test_acc))
            print('-' * 71)
            if args.save:
                with open(log_path, 'a') as fout:
                    fout.write('{},{},{}\n'.format(global_step, dev_acc,
                                                   test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save:
                    torch.save([model, args], model_path)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc,
                                                      best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
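
Several of these training loops build the margin-ranking loss by expanding the correct choice's logit against every wrong choice of the same question. A standalone sketch with illustrative values makes the tensor bookkeeping explicit:

# Standalone sketch of the margin-ranking construction used in the loop above
# (shapes chosen for illustration: batch of 2 questions, 3 answer choices each).
import torch
import torch.nn.functional as F
from torch import nn

logits = torch.tensor([[0.2, 1.5, -0.3],
                       [0.9, 0.1, 0.4]])          # [batch, num_choice]
labels = torch.tensor([1, 0])                      # index of the correct choice
num_choice = logits.size(1)

flat_logits = logits.view(-1)
correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)
correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1) \
                     .expand(-1, num_choice - 1).contiguous().view(-1)
wrong_logits = flat_logits[correct_mask == 0]      # batch * (num_choice - 1) scores
y = wrong_logits.new_ones(wrong_logits.size(0))    # "correct should outrank wrong"

loss = nn.MarginRankingLoss(margin=0.1)(correct_logits, wrong_logits, y)
print(loss.item())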
Example #5
File: lm.py Project: zxlzr/MHGRN
def train(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    model_path = os.path.join(args.save_dir, 'model.pt')
    check_path(model_path)

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMDataLoader(args.train_statements, args.dev_statements, args.test_statements,
                           batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, device=device,
                           model_name=args.encoder,
                           max_seq_length=args.max_seq_len,
                           is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, subsample=args.subsample)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMForMultipleChoice(args.encoder, from_checkpoint=args.from_checkpoint, encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'lr': args.encoder_lr, 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'lr': args.encoder_lr, 'weight_decay': 0.0}
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('***** running training *****')
    print(f'| batch_size: {args.batch_size} | num_epochs: {args.n_epochs} | num_train: {dataset.train_size()} |'
          f' num_dev: {dataset.dev_size()} | num_test: {dataset.test_size()}')

    global_step = 0
    best_dev_acc = 0
    best_dev_epoch = 0
    final_test_acc = 0
    try:
        for epoch in range(int(args.n_epochs)):
            model.train()
            tqdm_bar = tqdm(dataset.train(), desc="Training")
            for qids, labels, *input_data in tqdm_bar:
                optimizer.zero_grad()
                batch_loss = 0
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    batch_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                tqdm_bar.desc = "loss: {:.2e}  lr: {:.2e}".format(batch_loss, scheduler.get_lr()[0])
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(dataset.test(), model) if dataset.test_size() > 0 else 0.0
            if dev_acc > best_dev_acc:
                final_test_acc = test_acc
                best_dev_acc = dev_acc
                best_dev_epoch = epoch
                torch.save([model, args], model_path)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch, dev_acc, test_acc))
            if epoch - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print('***** training ends *****')
    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc, best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
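
The inner loop above accumulates gradients over mini-batches of a larger batch, rescaling each slice's loss by its share of the batch so that a single optimizer step sees the full-batch gradient. A minimal sketch of just that pattern, with a toy model and data:

# Sketch of the mini-batch gradient-accumulation pattern used in train() above:
# one optimizer step per full batch, gradients accumulated over slices.
import torch
from torch import nn

model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_func = nn.MSELoss(reduction='mean')

inputs = torch.randn(8, 10)       # full batch (bs = 8)
targets = torch.randn(8, 1)
mini_batch_size = 3

optimizer.zero_grad()
bs = inputs.size(0)
for a in range(0, bs, mini_batch_size):
    b = min(a + mini_batch_size, bs)
    loss = loss_func(model(inputs[a:b]), targets[a:b])
    loss = loss * (b - a) / bs    # weight each slice by its share of the batch
    loss.backward()               # gradients accumulate across slices
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()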
Example #6
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################
    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized = True
    else:
        use_contextualized = False
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))



    device = torch.device("cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
    dataset = LMGraphRelationNetDataLoader(args.train_statements, args.train_adj,
                                           args.dev_statements, args.dev_adj,
                                           args.test_statements, args.test_adj,
                                           batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                           device=(device, device),
                                           model_name=args.encoder,
                                           max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                           is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                           use_contextualized=use_contextualized,
                                           train_embs_path=args.train_embs, dev_embs_path=args.dev_embs,
                                           test_embs_path=args.test_embs,
                                           subsample=args.subsample, format=args.format)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_timeline_config(args)
    model = LMGraphRelationNet(args.encoder, k=args.k, n_type=3, n_basis=args.num_basis, n_layer=args.gnn_layer_num,
                               diag_decompose=args.diag_decompose, n_concept=concept_num,
                               n_relation=args.num_relation, concept_dim=args.gnn_dim,
                               concept_in_dim=(
                                   dataset.get_node_feature_dim() if use_contextualized else concept_dim),
                               n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
                               att_dim=args.att_dim, att_layer_num=args.att_layer_num,
                               p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
                               pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
                               ablation=args.ablation, init_range=args.init_range,
                               eps=args.eps, use_contextualized=use_contextualized,
                               do_init_rn=args.init_rn, do_init_identity=args.init_identity,
                               encoder_config=lstm_config)
    model.to(device)

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    if args.fix_trans:
        no_decay.append('trans_scores')
    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)

    print('encoder parameters:')
    for name, param in model.encoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    print('decoder parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'BCE':
        loss_func = nn.BCEWithLogitsLoss(reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_auc, final_test_auc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    for epoch_id in range(args.n_epochs):
        print('epoch: {:5} '.format(epoch_id))

        model.train()
        for qids, labels, *input_data in dataset.train():
            optimizer.zero_grad()
            bs = labels.size(0)
            for a in range(0, bs, args.mini_batch_size):
                b = min(a + args.mini_batch_size, bs)
                logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                if args.loss == 'margin_rank':
                    num_choice = logits.size(1)
                    flat_logits = logits.view(-1)
                    correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
                    correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                    wrong_logits = flat_logits[correct_mask == 0]  # of length batch_size*(num_choice-1)
                    y = wrong_logits.new_ones((wrong_logits.size(0),))
                    loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                elif args.loss == 'cross_entropy':
                    loss = loss_func(logits, labels[a:b])
                loss = loss * (b - a) / bs
                loss.backward()
                total_loss += loss.item()
            if args.max_grad_norm > 0:
                nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            scheduler.step()
            optimizer.step()

            if (global_step + 1) % args.log_interval == 0:
                total_loss /= args.log_interval
                ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step,
                                                                                              scheduler.get_lr()[0],
                                                                                              total_loss,
                                                                                              ms_per_batch))
                total_loss = 0.0
                start_time = time.time()
            global_step += 1

        model.eval()
        dev_acc, d_precision, d_recall, d_f1, d_roc_auc = eval_metric(dataset.dev(), model)
        test_acc, t_precision, t_recall, t_f1, t_roc_auc = eval_metric(dataset.test(), model)
        if global_step % args.log_interval == 0:
            tl = total_loss
        else:
            tl = total_loss / (global_step % args.log_interval)
        print('-' * 71)
        print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} | loss {:7.4f} '.format(global_step, dev_acc, test_acc, tl))
        print('| step {:5} | dev_precision {:7.4f} | test_precision {:7.4f} | loss {:7.4f} '.format(global_step, d_precision, t_precision, tl))
        print('| step {:5} | dev_recall {:7.4f} | test_recall {:7.4f} | loss {:7.4f} '.format(global_step, d_recall, t_recall, tl))
        print('| step {:5} | dev_f1 {:7.4f} | test_f1 {:7.4f} | loss {:7.4f} '.format(global_step, d_f1, t_f1, tl))
        print('| step {:5} | dev_auc {:7.4f} | test_auc {:7.4f} | loss {:7.4f} '.format(global_step, d_roc_auc, t_roc_auc, tl))
        print('-' * 71)
        with open(log_path, 'a') as fout:
            fout.write('{},{},{}\n'.format(global_step, d_roc_auc, t_roc_auc))
        if d_roc_auc >= best_dev_auc:
            best_dev_auc = d_roc_auc
            final_test_auc = t_roc_auc
            best_dev_epoch = epoch_id
            torch.save([model, args], model_path)
            print(f'model saved to {model_path}')
        model.train()
        start_time = time.time()
        if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
            break


    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev auc: {:.4f} (at epoch {})'.format(best_dev_auc, best_dev_epoch))
    print('final test auc: {:.4f}'.format(final_test_auc))
    print()
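
As in the other examples, the best checkpoint is chosen by the dev metric (here dev AUC) and training stops once it has not improved for max_epochs_before_stop epochs. A stripped-down sketch of that selection and early-stopping logic, with illustrative numbers:

# Sketch of the "keep the checkpoint with the best dev metric, stop after
# `patience` epochs without improvement" pattern used above (values illustrative).
best_dev, best_epoch, patience = 0.0, 0, 3
dev_history = [0.61, 0.64, 0.63, 0.66, 0.65, 0.65, 0.65]   # e.g. dev AUC per epoch

for epoch, dev_metric in enumerate(dev_history):
    if dev_metric >= best_dev:
        best_dev, best_epoch = dev_metric, epoch
        # torch.save([model, args], model_path) would go here
    if epoch - best_epoch >= patience:
        break

print(f'best dev metric: {best_dev:.4f} (at epoch {best_epoch})')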
Example #7
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,dev_acc,test_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('| num_concepts: {} |'.format(concept_num))

    # try:
    if True:
        if torch.cuda.device_count() >= 2 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:1")
        elif torch.cuda.device_count() == 1 and args.cuda:
            device0 = torch.device("cuda:0")
            device1 = torch.device("cuda:0")
        else:
            device0 = torch.device("cpu")
            device1 = torch.device("cpu")
        dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
                                               args.dev_statements, args.dev_adj,
                                               args.test_statements, args.test_adj,
                                               batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
                                               device=(device0, device1),
                                               model_name=args.encoder,
                                               max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
                                               is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
                                               subsample=args.subsample, use_cache=args.use_cache)

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################

        model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num,
                                   concept_dim=args.gnn_dim,
                                   concept_in_dim=concept_dim,
                                   n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
                                   p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
                                   pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
                                   init_range=args.init_range,
                                   encoder_config={})
        model.encoder.to(device0)
        model.decoder.to(device1)


    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    grouped_parameters = [
        {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
        {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        try:
            scheduler = ConstantLRSchedule(optimizer)
        except Exception:
            scheduler = get_constant_schedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        try:
            scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
        except Exception:
            scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
        try:
            scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)
        except Exception:
            scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
        else:
            print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
    num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    if True:
    # try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[correct_mask == 0]
                        y = wrong_logits.new_ones((wrong_logits.size(0),))
                        loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
                    print('| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
                    total_loss = 0
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            save_test_preds = args.save_model
            if not save_test_preds:
                test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
            else:
                eval_set = dataset.test()
                total_acc = []
                count = 0
                preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id))
                with open(preds_path, 'w') as f_preds:
                    with torch.no_grad():
                        for qids, labels, *input_data in tqdm(eval_set):
                            count += 1
                            logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
                            predictions = logits.argmax(1) #[bsize, ]
                            preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
                            for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
                                acc = int(pred.item()==label.item())
                                print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
                                f_preds.flush()
                                total_acc.append(acc)
                test_acc = float(sum(total_acc))/len(total_acc)

            print('-' * 71)
            print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            else:
                if args.save_model:
                    torch.save([model, args], model_path +".{}".format(epoch_id))
                    with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
                        for p in model.named_parameters():
                            print (p, file=f)
                    print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
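
Example #7 wraps each scheduler constructor in try/except so it runs against both the old class-based API and the newer get_*_schedule functions. A hedged sketch that factors the same three branches into one helper and keeps only the newer function-based API (the helper name is illustrative):

# Hedged sketch: the scheduler selection from Example #7 with only the newer
# function-based transformers API (assumes a recent transformers release).
def build_lr_scheduler(optimizer, lr_schedule, warmup_steps=0, max_steps=0):
    from transformers import (get_constant_schedule,
                              get_constant_schedule_with_warmup,
                              get_linear_schedule_with_warmup)
    if lr_schedule == 'fixed':
        return get_constant_schedule(optimizer)
    if lr_schedule == 'warmup_constant':
        return get_constant_schedule_with_warmup(optimizer,
                                                 num_warmup_steps=warmup_steps)
    if lr_schedule == 'warmup_linear':
        return get_linear_schedule_with_warmup(optimizer,
                                               num_warmup_steps=warmup_steps,
                                               num_training_steps=max_steps)
    raise ValueError('unsupported lr_schedule: %s' % lr_schedule)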
Example #8
File: rn.py Project: zxlzr/MHGRN
def train(args):
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,train_acc,dev_acc\n')

    ###################################################################################################
    #   Load data                                                                                     #
    ###################################################################################################

    if 'lm' in args.ent_emb:
        print('Using contextualized embeddings for concepts')
        use_contextualized, cp_emb = True, None
    else:
        use_contextualized = False
    cp_emb = [np.load(path) for path in args.ent_emb_paths]
    cp_emb = torch.tensor(np.concatenate(cp_emb, 1))

    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)

    rel_emb = np.load(args.rel_emb_path)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = cal_2hop_rel_emb(rel_emb)
    rel_emb = torch.tensor(rel_emb)
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    # print('| num_concepts: {} | num_relations: {} |'.format(concept_num, relation_num))

    device = torch.device(
        "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")

    dataset = LMRelationNetDataLoader(
        args.train_statements,
        args.train_rel_paths,
        args.dev_statements,
        args.dev_rel_paths,
        args.test_statements,
        args.test_rel_paths,
        batch_size=args.batch_size,
        eval_batch_size=args.eval_batch_size,
        device=device,
        model_name=args.encoder,
        max_tuple_num=args.max_tuple_num,
        max_seq_length=args.max_seq_len,
        is_inhouse=args.inhouse,
        inhouse_train_qids_path=args.inhouse_train_qids,
        use_contextualized=use_contextualized,
        train_adj_path=args.train_adj,
        dev_adj_path=args.dev_adj,
        test_adj_path=args.test_adj,
        train_node_features_path=args.train_node_features,
        dev_node_features_path=args.dev_node_features,
        test_node_features_path=args.test_node_features,
        node_feature_type=args.node_feature_type)

    ###################################################################################################
    #   Build model                                                                                   #
    ###################################################################################################

    lstm_config = get_lstm_config_from_args(args)
    model = LMRelationNet(model_name=args.encoder,
                          concept_num=concept_num,
                          concept_dim=relation_dim,
                          relation_num=relation_num,
                          relation_dim=relation_dim,
                          concept_in_dim=(dataset.get_node_feature_dim() if
                                          use_contextualized else concept_dim),
                          hidden_size=args.mlp_dim,
                          num_hidden_layers=args.mlp_layer_num,
                          num_attention_heads=args.att_head_num,
                          fc_size=args.fc_dim,
                          num_fc_layers=args.fc_layer_num,
                          dropout=args.dropoutm,
                          pretrained_concept_emb=cp_emb,
                          pretrained_relation_emb=rel_emb,
                          freeze_ent_emb=args.freeze_ent_emb,
                          init_range=args.init_range,
                          ablation=args.ablation,
                          use_contextualized=use_contextualized,
                          emb_scale=args.emb_scale,
                          encoder_config=lstm_config)

    try:
        model.to(device)
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.decoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.decoder_lr
        },
    ]
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    ###################################################################################################
    #   Training                                                                                      #
    ###################################################################################################

    print()
    print('-' * 71)
    global_step, best_dev_epoch = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        rel_grad = []
        linear_grad = []
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                print('encoder unfrozen')
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                print('encoder refrozen')
                freeze_net(model.encoder)
            model.train()
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels, num_classes=num_choice).view(
                                -1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                rel_grad.append(
                    model.decoder.rel_emb.weight.grad.abs().mean().item())
                linear_grad.append(model.decoder.mlp.layers[8].weight.grad.abs().mean().item())
                scheduler.step()
                optimizer.step()

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    # print('| rel_grad: {:1.2e} | linear_grad: {:1.2e} |'.format(sum(rel_grad) / len(rel_grad), sum(linear_grad) / len(linear_grad)))
                    total_loss = 0
                    rel_grad = []
                    linear_grad = []
                    start_time = time.time()
                global_step += 1

            model.eval()
            dev_acc = evaluate_accuracy(dataset.dev(), model)
            test_acc = evaluate_accuracy(
                dataset.test(), model) if args.test_statements else 0.0
            print('-' * 71)
            print('| epoch {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(
                epoch_id, dev_acc, test_acc))
            print('-' * 71)
            with open(log_path, 'a') as fout:
                fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
            if dev_acc >= best_dev_acc:
                best_dev_acc = dev_acc
                final_test_acc = test_acc
                best_dev_epoch = epoch_id
                torch.save([model, args], model_path)
                print(f'model saved to {model_path}')
            model.train()
            start_time = time.time()
            if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
                break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at epoch {})'.format(best_dev_acc,
                                                      best_dev_epoch))
    print('final test acc: {:.4f}'.format(final_test_acc))
    print()
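
Example #8 additionally tracks the mean absolute gradient of selected decoder parameters (rel_grad / linear_grad) between optimizer steps. A standalone sketch of that kind of gradient monitoring on a toy model:

# Standalone sketch of the gradient-magnitude monitoring used above:
# after backward(), inspect the mean absolute gradient of chosen parameters.
import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
out = model(torch.randn(16, 4)).mean()
out.backward()

grad_stats = {name: p.grad.abs().mean().item()
              for name, p in model.named_parameters() if p.grad is not None}
for name, g in grad_stats.items():
    print(f'{name:12s} mean |grad| = {g:1.2e}')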
Example #9
0
class TransformersModel(BaseModel):
    max_grad_norm = 1.0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.download()
        self.train_steps = 0
        self.checkpoint_steps = 500
        self.model_ckpt = str(self.model_dir.name)
        self.distilbert = 'distilbert' in self.model_ckpt

        if os.path.exists(os.path.join(self.model_ckpt, 'config.json')):
            self.logger.info('Loading from checkpoint %s' % self.model_ckpt)
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
        elif os.path.exists(os.path.join(self.data_dir, 'config.json')):
            self.logger.info('Loading from trained model in %s' %
                             self.data_dir)
            self.model_ckpt = self.data_dir
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
        else:
            self.logger.info(
                'Initializing new model with pretrained weights %s' %
                self.model_ckpt)
            self.model_config = AutoConfig.from_pretrained(self.model_ckpt)
            self.model_config.num_labels = 1  # set up for regression

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        if self.device == torch.device("cpu"):
            self.logger.info("RUNNING ON CPU")
        else:
            self.logger.info("RUNNING ON CUDA")
            torch.cuda.synchronize(self.device)

        self.rerank_model = AutoModelForSequenceClassification.from_pretrained(
            self.model_ckpt, config=self.model_config)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_ckpt)
        self.rerank_model.to(self.device, non_blocking=True)

        self.optimizer = AdamW(self.rerank_model.parameters(),
                               lr=self.lr,
                               correct_bias=False)
        self.scheduler = ConstantLRSchedule(self.optimizer)

        self.weight = 1.0

    def train(self, query, choices, labels):
        # labels: per-choice relevance targets (floats for regression,
        # ints for classification), matching model_config.num_labels below
        input_ids, attention_mask, token_type_ids = self.encode(query, choices)

        if self.model_config.num_labels == 1:
            labels = torch.tensor(labels,
                                  dtype=torch.float).to(self.device,
                                                        non_blocking=True)
        else:
            labels = torch.tensor(labels,
                                  dtype=torch.long).to(self.device,
                                                       non_blocking=True)

        if self.distilbert:
            loss = self.rerank_model(input_ids,
                                     labels=labels,
                                     attention_mask=attention_mask)[0]
        else:
            loss = self.rerank_model(input_ids,
                                     labels=labels,
                                     attention_mask=attention_mask,
                                     token_type_ids=token_type_ids)[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.rerank_model.parameters(),
                                       self.max_grad_norm)
        self.optimizer.step()
        self.scheduler.step()
        self.rerank_model.zero_grad()
        self.train_steps += 1
        if self.weight < 1.0:
            self.weight += self.lr * 0.1
        if self.train_steps % self.checkpoint_steps == 0:
            self.save()

    def rank(self, query, choices):
        input_ids, attention_mask, token_type_ids = self.encode(query, choices)

        with torch.no_grad():
            if self.distilbert:
                logits = self.rerank_model(input_ids,
                                           attention_mask=attention_mask)[0]
            else:
                logits = self.rerank_model(input_ids,
                                           attention_mask=attention_mask,
                                           token_type_ids=token_type_ids)[0]
            scores = np.squeeze(logits.detach().cpu().numpy())
            if len(scores.shape) > 1 and scores.shape[1] == 2:
                scores = np.squeeze(scores[:, 1])
            if len(logits) == 1:
                scores = [scores]
            return np.argsort(scores)[::-1]

    def encode(self, query, choices):
        inputs = [
            self.tokenizer.encode_plus(query, choice, add_special_tokens=True)
            for choice in choices
        ]

        max_len = min(max(len(t['input_ids']) for t in inputs),
                      self.max_seq_len)
        input_ids = [
            t['input_ids'][:max_len] + [0] *
            (max_len - len(t['input_ids'][:max_len])) for t in inputs
        ]
        attention_mask = [[1] * len(t['input_ids'][:max_len]) + [0] *
                          (max_len - len(t['input_ids'][:max_len]))
                          for t in inputs]
        token_type_ids = [
            t['token_type_ids'][:max_len] + [0] *
            (max_len - len(t['token_type_ids'][:max_len])) for t in inputs
        ]

        input_ids = torch.tensor(input_ids).to(self.device, non_blocking=True)
        attention_mask = torch.tensor(attention_mask).to(self.device,
                                                         non_blocking=True)
        token_type_ids = torch.tensor(token_type_ids).to(self.device,
                                                         non_blocking=True)

        return input_ids, attention_mask, token_type_ids

    def save(self):
        self.logger.info('Saving model')
        os.makedirs(self.data_dir, exist_ok=True)
        self.rerank_model.save_pretrained(self.data_dir)
        self.tokenizer.save_pretrained(self.data_dir)
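
For orientation, a minimal usage sketch of the class above. The constructor arguments (model_dir, data_dir, lr, max_seq_len) are assumptions inferred from the attributes referenced in __init__, not a documented BaseModel signature:

# Hypothetical usage; constructor arguments are inferred from the attributes
# used above (model_dir, data_dir, lr, max_seq_len) and may differ in BaseModel.
model = TransformersModel(model_dir='distilbert-base-uncased',
                          data_dir='./reranker_data',
                          lr=3e-5,
                          max_seq_len=128)
query = 'what is the capital of france'
choices = ['Paris is the capital of France.',
           'Berlin is the capital of Germany.']
ranking = model.rank(query, choices)            # choice indices, best first
model.train(query, choices, labels=[1.0, 0.0])  # one gradient step on this pair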
Ejemplo n.º 10
0
def train(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available() and args.cuda:
        torch.cuda.manual_seed(args.seed)

    print('configuration:')
    print('\n'.join('\t{:15} {}'.format(k + ':', str(v))
                    for k, v in sorted(dict(vars(args)).items())))
    print()

    config_path = os.path.join(args.save_dir, 'config.json')
    model_path = os.path.join(args.save_dir, 'model.pt')
    log_path = os.path.join(args.save_dir, 'log.csv')
    export_config(args, config_path)
    check_path(model_path)
    with open(log_path, 'w') as fout:
        fout.write('step,dev_acc,test_acc\n')

    dic = {'transe': 0, 'numberbatch': 1}
    cp_emb, rel_emb = [
        np.load(args.ent_emb_paths[dic[source]]) for source in args.ent_emb
    ], np.load(args.rel_emb_path)
    cp_emb = np.concatenate(cp_emb, axis=1)
    cp_emb = torch.tensor(cp_emb)
    rel_emb = np.concatenate((rel_emb, -rel_emb), 0)
    rel_emb = torch.tensor(rel_emb)
    concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
    print('num_concepts: {}, concept_dim: {}'.format(concept_num, concept_dim))
    relation_num, relation_dim = rel_emb.size(0), rel_emb.size(1)
    print('num_relations: {}, relation_dim: {}'.format(relation_num,
                                                       relation_dim))

    try:

        device0 = torch.device(
            "cuda:0" if torch.cuda.is_available() and args.cuda else "cpu")
        device1 = torch.device(
            "cuda:1" if torch.cuda.is_available() and args.cuda else "cpu")
        dataset = KagNetDataLoader(
            args.train_statements,
            args.train_paths,
            args.train_graphs,
            args.dev_statements,
            args.dev_paths,
            args.dev_graphs,
            args.test_statements,
            args.test_paths,
            args.test_graphs,
            batch_size=args.mini_batch_size,
            eval_batch_size=args.eval_batch_size,
            device=(device0, device1),
            model_name=args.encoder,
            max_seq_length=args.max_seq_len,
            max_path_len=args.max_path_len,
            is_inhouse=args.inhouse,
            inhouse_train_qids_path=args.inhouse_train_qids,
            use_cache=args.use_cache)
        print('dataset done')

        ###################################################################################################
        #   Build model                                                                                   #
        ###################################################################################################
        lstm_config = get_lstm_config_from_args(args)

        model = LMKagNet(model_name=args.encoder,
                         concept_dim=concept_dim,
                         relation_dim=relation_dim,
                         concept_num=concept_num,
                         relation_num=relation_num,
                         qas_encoded_dim=args.qas_encoded_dim,
                         pretrained_concept_emb=cp_emb,
                         pretrained_relation_emb=rel_emb,
                         lstm_dim=args.lstm_dim,
                         lstm_layer_num=args.lstm_layer_num,
                         graph_hidden_dim=args.graph_hidden_dim,
                         graph_output_dim=args.graph_output_dim,
                         dropout=args.dropout,
                         bidirect=args.bidirect,
                         num_random_paths=args.num_random_paths,
                         path_attention=args.path_attention,
                         qa_attention=args.qa_attention,
                         encoder_config=lstm_config)
        print('model done')
        if args.freeze_ent_emb:
            freeze_net(model.decoder.concept_emb)
        print('freezed')
        model.encoder.to(device0)
        print('encoder done')
        model.decoder.to(device1)
        print('decoder done')
    except RuntimeError as e:
        print(e)
        print('best dev acc: 0.0 (at epoch 0)')
        print('final test acc: 0.0')
        print()
        return

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    grouped_parameters = [
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.encoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.encoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay,
            'lr':
            args.decoder_lr
        },
        {
            'params': [
                p for n, p in model.decoder.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0,
            'lr':
            args.decoder_lr
        },
    ]
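    # Four parameter groups: encoder vs. decoder parameters, each split into a
    # weight-decayed group and a no-decay group (biases and LayerNorm weights),
    # with separate learning rates (args.encoder_lr / args.decoder_lr).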
    optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)

    if args.lr_schedule == 'fixed':
        scheduler = ConstantLRSchedule(optimizer)
    elif args.lr_schedule == 'warmup_constant':
        scheduler = WarmupConstantSchedule(optimizer,
                                           warmup_steps=args.warmup_steps)
    elif args.lr_schedule == 'warmup_linear':
        max_steps = int(args.n_epochs *
                        (dataset.train_size() / args.batch_size))
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=max_steps)

    print('parameters:')
    for name, param in model.decoder.named_parameters():
        if param.requires_grad:
            print('\t{:45}\ttrainable\t{}'.format(name, param.size()))
        else:
            print('\t{:45}\tfixed\t{}'.format(name, param.size()))
    num_params = sum(p.numel() for p in model.decoder.parameters()
                     if p.requires_grad)
    print('\ttotal:', num_params)

    if args.loss == 'margin_rank':
        loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
    elif args.loss == 'cross_entropy':
        loss_func = nn.CrossEntropyLoss(reduction='mean')

    print()
    print('-' * 71)
    global_step, last_best_step = 0, 0
    best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
    start_time = time.time()
    model.train()
    freeze_net(model.encoder)
    try:
        for epoch_id in range(args.n_epochs):
            if epoch_id == args.unfreeze_epoch:
                unfreeze_net(model.encoder)
            if epoch_id == args.refreeze_epoch:
                freeze_net(model.encoder)
            for qids, labels, *input_data in dataset.train():
                optimizer.zero_grad()
                bs = labels.size(0)
                for a in range(0, bs, args.mini_batch_size):
                    b = min(a + args.mini_batch_size, bs)
                    logits, _ = model(*[x[a:b] for x in input_data],
                                      layer_id=args.encoder_layer)  # forward only the current mini-batch slice

                    if args.loss == 'margin_rank':
                        num_choice = logits.size(1)
                        flat_logits = logits.view(-1)
                        correct_mask = F.one_hot(
                            labels[a:b], num_classes=num_choice).view(
                                -1)  # of length batch_size*num_choice
                        correct_logits = flat_logits[
                            correct_mask == 1].contiguous().view(-1, 1).expand(
                                -1, num_choice - 1).contiguous().view(
                                    -1)  # of length batch_size*(num_choice-1)
                        wrong_logits = flat_logits[
                            correct_mask ==
                            0]  # of length batch_size*(num_choice-1)
                        y = wrong_logits.new_ones((wrong_logits.size(0), ))
                        loss = loss_func(correct_logits, wrong_logits,
                                         y)  # margin ranking loss
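                        # Each correct logit is repeated (num_choice - 1) times
                        # and paired element-wise with the wrong logits; with
                        # target y = 1, MarginRankingLoss pushes every correct
                        # logit to exceed each wrong logit by the margin.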
                    elif args.loss == 'cross_entropy':
                        loss = loss_func(logits, labels[a:b])
                    loss = loss * (b - a) / bs
                    loss.backward()
                    total_loss += loss.item()
                if args.max_grad_norm > 0:
                    nn.utils.clip_grad_norm_(model.parameters(),
                                             args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # step the schedule after the optimizer update

                if (global_step + 1) % args.log_interval == 0:
                    total_loss /= args.log_interval
                    ms_per_batch = 1000 * (time.time() -
                                           start_time) / args.log_interval
                    print(
                        '| step {:5} |  lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'
                        .format(global_step,
                                scheduler.get_lr()[0], total_loss,
                                ms_per_batch))
                    total_loss = 0
                    start_time = time.time()

                if (global_step + 1) % args.eval_interval == 0:
                    model.eval()
                    dev_acc = evaluate_accuracy(dataset.dev(), model)
                    test_acc = evaluate_accuracy(
                        dataset.test(), model) if args.test_statements else 0.0
                    print('-' * 71)
                    print('| step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.
                          format(global_step, dev_acc, test_acc))
                    print('-' * 71)
                    with open(log_path, 'a') as fout:
                        fout.write('{},{},{}\n'.format(global_step, dev_acc,
                                                       test_acc))
                    if dev_acc >= best_dev_acc:
                        best_dev_acc = dev_acc
                        final_test_acc = test_acc
                        last_best_step = global_step
                        torch.save([model, args], model_path)
                        print(f'model saved to {model_path}')
                    model.train()
                    start_time = time.time()

                global_step += 1
                # if global_step >= args.max_steps or global_step - last_best_step >= args.max_steps_before_stop:
                #     end_flag = True
                #     break
    except (KeyboardInterrupt, RuntimeError) as e:
        print(e)

    print()
    print('training ends in {} steps'.format(global_step))
    print('best dev acc: {:.4f} (at step {})'.format(best_dev_acc,
                                                     last_best_step))
    print('final test acc: {:.4f}'.format(final_test_acc))
Ejemplo n.º 11
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        # CHANGE: Change log dir to be in output_dir
        # We want the logs to be saved in GDrive for persistence
        # Note however that this dir can be overwritten with overwrite_output_dir
        if args.log_dir is not None:
            import socket
            from datetime import datetime
            current_time = datetime.now().strftime('%b%d_%H-%M-%S')
            log_dir = os.path.join(args.log_dir, current_time + '_' + socket.gethostname())
            tb_writer = SummaryWriter(log_dir=log_dir)
        else:
            tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.lr_finder:
        # Causes to go on a mock training from `start_lr` to `end_lr` for `num_it` iterations.
        # https://github.com/fastai/fastai/blob/e5d0aeb69d195f135608318094745e497e2d713f/fastai/callbacks/lr_finder.py
        start_lr = 1e-7
        end_lr = 1
        num_it = 100
        annealing_exp = lambda x: start_lr * (end_lr / start_lr) ** (x / num_it)
        optimizer = AdamW(optimizer_grouped_parameters, lr=1, eps=args.adam_epsilon)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, annealing_exp)
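        # With the base lr set to 1, LambdaLR multiplies it by annealing_exp(step),
        # so the effective lr sweeps exponentially from start_lr (1e-7) to
        # end_lr (1) over num_it (100) optimizer steps.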
    else:
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = ConstantLRSchedule(optimizer)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
                                                          output_device=args.local_rank,
                                                          find_unused_parameters=True)

    # Train!
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    
    # Capture LR finder output
    lr_finder_lr = []
    lr_finder_loss = []
    lr_finder_best_loss = None

    # CHANGE: One progress bar for all iterations
    def create_pbar():
        return tqdm(total=args.max_steps if args.max_steps > -1 else len(train_dataloader)*int(args.num_train_epochs), 
            desc="Training", disable=args.local_rank not in [-1, 0])
    pbar = create_pbar()
    for epoch in range(int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            inputs, labels = mask_tokens(batch, tokenizer, args) if args.mlm else (batch, batch)
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, masked_lm_labels=labels) if args.mlm else model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.lr_finder:
                    # CHANGE: Capture LR and loss
                    # https://towardsdatascience.com/adaptive-and-cyclical-learning-rates-using-pytorch-2bf904d18dee
                    lr_finder_lr.append(scheduler.get_lr()[0])
                    if global_step == 1:
                        lr_finder_loss.append(loss)
                        lr_finder_best_loss = loss
                    else:
                        # Smooth loss
                        smoothing = 0.05
                        smooth_loss = smoothing * loss + (1 - smoothing) * lr_finder_loss[-1]
                        lr_finder_loss.append(smooth_loss)
                        if smooth_loss < lr_finder_best_loss:
                            lr_finder_best_loss = smooth_loss
                        # Determine if loss has runaway and we should stop
                        if smooth_loss > 4 * lr_finder_best_loss or torch.isnan(smooth_loss):
                            break
                    # Append to a file to be visualized later
                    with open(os.path.join(args.output_dir, 'lr_finder.txt'), "a") as myfile:
                        myfile.write("{},{}\n".format(lr_finder_lr[-1], float(lr_finder_loss[-1])))  # float() so a parseable number is written, not the tensor repr
                else:
                    # CHANGE: Log loss
                    if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        avg_loss = (tr_loss - logging_loss)/args.logging_steps
                        tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                        tb_writer.add_scalar('loss', avg_loss, global_step)
                        logging_loss = tr_loss

                    if args.local_rank in [-1, 0] and args.eval_steps > 0 and global_step % args.eval_steps == 0:
                        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                            pbar.close() # one bar at a time to prevent issues with nested bars
                            results = evaluate(args, model, tokenizer)
                            pbar = create_pbar()
                            pbar.n = global_step-1
                            pbar.refresh()

                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key), value, global_step)

                    if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        checkpoint_prefix = 'checkpoint'
                        # Save model checkpoint
                        output_dir = os.path.join(args.output_dir, '{}-{}'.format(checkpoint_prefix, global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                        logger.info("Saving model checkpoint to %s", output_dir)

                        _rotate_checkpoints(args, checkpoint_prefix)

            pbar.update()

            if args.max_steps > 0 and global_step > args.max_steps:
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            break
    pbar.close()

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
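
As a rough companion to the lr_finder branch above, a minimal sketch for inspecting the lr_finder.txt sweep it writes. The file path and the matplotlib plot are assumptions for illustration (the script only writes the file); picking a learning rate somewhat below the loss minimum is the usual LR-finder heuristic, not something the script enforces:

# Hypothetical inspection of the sweep written by the lr_finder branch above.
# Assumes each line is "lr,loss" with both fields numeric.
import matplotlib.pyplot as plt

lrs, losses = [], []
with open('output_dir/lr_finder.txt') as f:  # path is an assumption; see args.output_dir
    for line in f:
        lr, loss = line.strip().split(',')
        lrs.append(float(lr))
        losses.append(float(loss))

plt.plot(lrs, losses)
plt.xscale('log')                 # learning rates span several orders of magnitude
plt.xlabel('learning rate')
plt.ylabel('smoothed loss')
plt.title('LR finder sweep')
plt.show()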