Example 1
    def __init__(self, dataset, model):
        self.dataset = dataset
        self.model = model
        self.args = get_args()
        # TODO: refactor this?
        # Per-unit sparsity target of 0.05, shaped (1, hidden_size) so it
        # broadcasts over the batch dimension.
        self.rho = torch.FloatTensor([0.05 for _ in range(self.model.hidden_size)]).unsqueeze(0).to(device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.args.lr)

        if self.args.resume_snapshot and os.path.isfile(self.args.resume_path):
            self.model, self.optimizer = load_checkpoint(self.args.resume_path, self.model, self.optimizer)

        # self.loss_func = nn.BCELoss()
        # self.loss_func.size_average = False
        self.loss_func = nn.MSELoss()

        self.train_loader, self.valid_loader = get_train_valid_loaders(self.dataset)
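The per-unit target of 0.05 stored in self.rho looks like the sparsity parameter of a sparse autoencoder. As a point of reference, here is a minimal sketch of the KL-divergence penalty such a target is usually plugged into; kl_sparsity_penalty and rho_hat are illustrative names, not taken from this repo:

import torch

def kl_sparsity_penalty(rho, rho_hat, eps=1e-8):
    # KL divergence between the target sparsity rho and the batch-averaged
    # hidden activations rho_hat; both broadcast to shape (1, hidden_size).
    rho_hat = rho_hat.clamp(eps, 1 - eps)
    kl = rho * torch.log(rho / rho_hat) \
        + (1 - rho) * torch.log((1 - rho) / (1 - rho_hat))
    return kl.sum()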
Example 2
    def __init__(self, args: Dict):
        super(BudnoiseAgent, self).__init__()
        # BudAgent.__init__(self, args)
        self.args = args
        print('\nInitializing agent setting')
        #=========  loading model and FastText =============
        self.embeddertype = self.args['embeddertype']
        self.batch_size = self.args['batch_size']
        self.device = torch.device(
            'cuda') if self.args['cuda'] else torch.device('cpu')
        self.policy = self.args['strategy']

        self.position = 0
        self.seed = 0

        if args['bert']:
            from data.berttokenizer import BTTokenizer, BertBatcher
            from module.bert_trainer import run_epoch
            from module.bertmodel import BertRetrieval as Retrieval
            from transformers import AdamW, WarmupLinearSchedule
            print('loading bert tokenizer')
            self.preprocessor = BTTokenizer(args)  #SplitPreprocessor()
            PAD_ID = self.preprocessor.padding_idx()
            self.batch_embedder = BertBatcher(cuda=args['cuda'], pad=PAD_ID)
            self.embeddertype = 'bpe'
        else:
            from data.tokenizer import Tokenizer_nltk
            from module.trainer import run_epoch
            from bdmodule.BUDmodel import Retrieval
            self.preprocessor = Tokenizer_nltk()  #SplitPreprocessor()

        print('\n\nloading model from {}'.format(args['ckpt']))
        if self.args['ckpt'].endswith('newtk_pretrianed.pt'):
            state = torch.load(args['ckpt'],
                               map_location=lambda storage, location: storage)
            model = state['model'].eval()
        else:
            model_state, _, _, _, _, savedconfig = load_checkpoint(
                self.args['ckpt'])
            savedconfig['cuda'] = self.args['cuda']
            model = Retrieval(savedconfig)
            model.load_state_dict(model_state)
        self.model = model.to(self.device)
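For context, a minimal sketch of constructing the agent. The concrete values below are hypothetical; only the key names come from the snippet above:

args = {
    'embeddertype': 'word',   # overridden to 'bpe' when 'bert' is True
    'batch_size': 32,
    'cuda': torch.cuda.is_available(),
    'strategy': 'greedy',              # hypothetical strategy name
    'bert': False,
    'ckpt': 'checkpoints/model.pt',    # hypothetical checkpoint path
}
agent = BudnoiseAgent(args)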
Example 3
def train_model(train_loader,
                valid_loader,
                head_trans,
                body_trans,
                classifier,
                load_model=LOAD_MODEL,
                save_model=SAVE_MODEL,
                model_path=MODEL_PATH + 'model.pth.tar',
                num_epochs=NUM_EPOCHS,
                lr=LR):

    torch.backends.cudnn.benchmark = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # device = 'cpu'
    print(device)

    learning_rate = lr

    # For tensorboard
    writer = SummaryWriter('runs/vpp')
    step = 0
  
    # Initialize Model
    model = BertClassifier(head_trans, body_trans, classifier).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr = learning_rate)

    if load_model:
        model, optimizer, step = load_checkpoint(torch.load(model_path), model, optimizer)
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        for mode in ['train', 'eval']:

            # Setting Necessary Data Loader
            if mode == 'train':
                dataloader = train_loader
            else:
                dataloader = valid_loader
                all_preds, all_actual = np.array([]), np.array([])

            # Creating Loop
            loop = tqdm(enumerate(dataloader), total=len(dataloader), leave=False)
            if mode == 'train':
                model.train()
            else:
                model.eval()
                eval_loss = 0.0

            # Iterating over batches
            for batch, (head, body, stance) in loop:

                outputs = model(head.to(device), body.to(device))
                loss = criterion(outputs.float(), stance.to(device).long())

                if mode == 'train':
                    writer.add_scalar('Training Epoch Loss', loss.item(), step)
                    step += 1

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    # Update progress bar
                    loop.set_description(f'Epoch [{epoch+1}/{num_epochs}]')
                    loop.set_postfix(loss=loss.item())

                    running_loss += loss.item()
                    # Optionally log running loss/accuracy every 10 batches:
                    # if batch % 10 == 0:
                    #     writer.add_scalar('Running Loss', running_loss / 10,
                    #                       epoch * len(dataloader) + batch)
                else:  # for evaluation mode
                    eval_loss += loss.item()
                    predictions = torch.argmax(outputs, dim=1)
                    all_preds = np.append(all_preds, predictions.cpu().detach().numpy())
                    all_actual = np.append(all_actual, stance.cpu().detach().numpy())

            if mode == 'train':
                if save_model:
                    checkpoint = {
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'step': step
                    }
                    save_checkpoint(checkpoint, model_path)
                epoch_loss = running_loss / len(dataloader)
                writer.add_scalar('Epoch Loss', epoch_loss, epoch+1)
            else:  # for evaluation mode
                epoch_eval_loss = eval_loss / len(dataloader)
                print(f'Evaluation Loss for epoch {epoch+1}: {epoch_eval_loss}')
                writer.add_scalar('Evaluation Loss', epoch_eval_loss, epoch+1)
                get_scores(all_preds, all_actual)
    print("\n")
Example 4
    def __init__(self, args: Dict):
        super(AskingAgent, self).__init__()

        self.args = args
        self.embeddertype = args['embeddertype']
        self.batch_size = args['batch_size']
        self.device = torch.device('cuda') if args['cuda'] else torch.device(
            'cpu')
        self.policy = args['strategy']

        #==========loading data =============

        if args['bert']:
            from data.berttokenizer import BTTokenizer, BertBatcher
            from module.bert_trainer import run_epoch
            from module.bertmodel import BertRetrieval as Retrieval
            from transformers import AdamW, WarmupLinearSchedule
            print('loading bert tokenizer')
            self.preprocessor = BTTokenizer(args)  #SplitPreprocessor()
            PAD_ID = self.preprocessor.padding_idx()
            self.batch_embedder = BertBatcher(cuda=args['cuda'], pad=PAD_ID)
            self.embeddertype = 'bpe'
        else:
            from data.tokenizer import Tokenizer_nltk
            from module.trainer import run_epoch
            from module.model import FAQRetrieval as Retrieval
            self.preprocessor = Tokenizer_nltk()  #SplitPreprocessor()

        #data_records = read_data(args)

        self.queryfile, self.faq_pool = data_loading.read_queries(args)
        data_train, data_val, data_test = data_loading.query_cv(
            self.queryfile, fold_n=args['cv_n'])
        self.pd_train, self.pd_val, self.pd_test = data_train, data_val, data_test

        self.gold_table, self.tags, self.faqs = self._parse_faq(
            self.queryfile.to_dict('records'))
        self.faqtag_table = self.reload_fromprob()

        #=========  preprocessing and precompute =============
        self.iqs = np.array(self._preprocess(self.iqs_text))

        # if args['datasplit'] == 'query':
        #     self.iqs_train, self.tgt_train, self.iqs_eval, self.tgt_eval= split_srctgtdata(self.iqs, self.tgt_ids, r = 0.2)
        # elif args['datasplit'] == 'faq':
        #data_train, data_val, data_test = dataloader1129.split_data_tdt(allquery, 0.2, 0.2)
        # data_train, data_val, data_test = data_loading.query_cv(self.queryfile, fold_n= args['cv_n'])
        # self.pd_train, self.pd_val, self.pd_test = data_train, data_val, data_test
        data_train, data_val, data_test = data_train.to_dict(
            'records'), data_val.to_dict('records'), data_test.to_dict(
                'records')
        self.iqs_eval, self.tgt_eval = self.get_data(data_val)
        self.iqs_train, self.tgt_train = self.get_data(data_train + data_val)
        self.iqs_test, self.tgt_test = self.get_data(data_test)
        # else:
        #     print('train test splitting error')
        # #self.num_examples = len(self.iqs_eval)
        print('There are {} initial queries from {} faqs to test'.format(
            len(self.tgt_eval), len(set(self.tgt_eval))))

        self.num_batches = math.ceil(len(self.iqs_train) / self.batch_size)
        self.position = 0

        #=========  loading encoding model and FastText =============
        if not args['bert']:
            print('Loading FastText')
            self.embedder = FastTextEmbedder(path=args['embedding_path'])
            print('Loading embeddings')
            if args['embeddertype'] == 'index':
                self.batch_embedder = IndexBatchEmbedder(self.embedder,
                                                         self.word_to_index,
                                                         cuda=args['cuda'])
            elif args['embeddertype'] == 'word':
                self.batch_embedder = WordBatchEmbedder(
                    self.embedder,
                    set(self.word_to_index.keys()),
                    cuda=args['cuda'])
            else:
                print('batch embedder method not implemented')
            print('\nVocabulary size = {:,}'.format(len(self.word_to_index)))

        print('\n\nloading model from {}'.format(args['ckpt']))
        if self.args['ckpt'].endswith('newtk_pretrianed.pt'):
            state = torch.load(args['ckpt'],
                               map_location=lambda storage, location: storage)
            model = state['model'].eval()
        else:
            model_state, _, _, _, _, savedconfig = load_checkpoint(
                self.args['ckpt'])
            savedconfig['cuda'] = self.args['cuda']
            model = Retrieval(savedconfig)
            model.load_state_dict(model_state)
        model.change_device(self.device)
        self.model = model.to(self.device)

        self.faqs_index = self._preprocess(self.faqs)
        print('total faqs: {}'.format(len(self.faqs)))
        with torch.no_grad():
            self.faqs_mat = utils.encode_candidates(
                self.faqs_index, self.batch_embedder, self.model,
                self.batch_size)  #self.encode_candidates( self.faqs_index)

        self.embedweight = nn.Parameter(torch.Tensor([0.8]).to(self.device))

        #=========  set up tag inference module=============
        if args['taginfer']:
            self.faqtag_belief0 = self.tag_input()
            self.faqtag_belief = self.faqtag_belief0.sigmoid()
            if args['ft_tag'] or args['tag_pretrain']:
                if args['tag_model'] == 'scalar':
                    # Note: the first pair is dead code, immediately
                    # overwritten by the second pair below.
                    w_ = np.array([0.4747236])  #np.random.rand(1)
                    b_ = np.array([-26.986095])  #np.random.rand(1)
                    w_ = np.array([0.312])
                    b_ = np.array([1.0])
                elif args['tag_model'] == 'vector':
                    w_ = np.load('w_813_813_linear.npy')
                    b_ = np.load('b_813_linear.npy')
                elif args['tag_model'] == 'bl':
                    nd = self.model.output_size
                    print(nd)
                    w_ = np.zeros((nd, nd))  #np.random.rand(1)
                    b_ = np.array([0.4])  #np.random.rand(1)
                else:
                    print('unknown tag_model: {}'.format(args['tag_model']))
                if not args['tag_pretrain']:
                    w_ = np.ones(w_.shape) * 0.1
                    b_ = np.ones(b_.shape) * 0.1
                self.tagweight = nn.Parameter(torch.Tensor(w_).to(self.device))
                self.tagbias = nn.Parameter(torch.Tensor(b_).to(self.device))

                ld_ = args['aa_ld']  #0.5 #np.random.rand(1)
                self.lmda = nn.Parameter(torch.Tensor([ld_]).to(self.device))
                self.tag_inference()
        else:
            self.faqtag_belief = torch.Tensor(self.faqtag_table).to(
                self.device)
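tag_inference and tag_input are not shown in this snippet. For the 'scalar' tag model, a plausible minimal sketch is an affine recalibration of the raw tag logits followed by a sigmoid, with lmda blending in the prior tag table; this is an assumption about the approach, not the repo's confirmed implementation:

    def tag_inference(self):
        # Hypothetical sketch: recalibrate raw tag logits with the learned
        # scalar weight/bias, then blend with the prior tag table via lmda.
        calibrated = (self.tagweight * self.faqtag_belief0 + self.tagbias).sigmoid()
        prior = torch.Tensor(self.faqtag_table).to(self.device)
        self.faqtag_belief = self.lmda * calibrated + (1 - self.lmda) * prior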
Example 5
def main(args):

    device = torch.device('cuda') if args['cuda'] else torch.device('cpu')
    print(device)

    aa = AskingAgent(args)
    aa.load_state_dict(
        torch.load('checkpoints/tagft_tagmodelTrue_fttagTrue_ftembTrue_aa.pt'))
    aa = aa.to(device)

    if args['user_type'] == 'oracle':
        user = NoisyUser(args)
    elif args['user_type'] == 'persona':
        user = PersonaUser(aa, args)
    else:
        raise ValueError('user type {} not implemented'.format(args['user_type']))

    writer = SummaryWriter(
        os.path.join(args['tensorboard_dir'],
                     args['comment'] + '_' + args['flavor']))
    writer.add_text('Args', args['comment'] + ' ' + str(args) + '\n')
    save_path = args['checkpoint_path']

    #==========loading data =============
    policynet = Policy(args)
    #policynet = load_checkpoint(args['ckpt'])[0]
    policynet.model = load_checkpoint(
        'checkpoints/tagft_tagmodelTrue_fttagTrue_ftembTrue_auc05_0.742_e4900.pt'
    )[0]
    policynet = policynet.to(device)
    print('policy network model: ')
    print(policynet.model)
    writer.add_text('model', str(policynet.model))

    optimizer = optim.Adam(policynet.model.parameters(), lr=args['lr'])

    ftparams = []
    if args['ft_tag']:
        ftparams += [aa.tagweight, aa.tagbias, aa.lmda]
    if args['ft_emb']:
        if args['ft_rnn']:
            for m in aa.model.modules():
                if isinstance(m, nn.Dropout):
                    m.p = args['dropout']
                if isinstance(m, SRU):
                    m.dropout = args['dropout']
            ftparams += get_params(aa.model)
        else:
            ftparams += [aa.embedweight]
    if args['ft_emb'] or args['ft_tag']:
        print('Finetuning turned on ')
        nnoptimizer = optim.Adam(ftparams, lr=args['ft_lr'])
    else:
        nnoptimizer = None

    for episode in range(1, args['episodes']):
        if episode % (args['test_every']) == 0:
            batch = aa.testdata()
            mode = 'test'
            policynet.model.eval()
            aa.model.eval()

        elif episode % args['eval_every'] == 0:
            batch = aa.valdata()
            mode = 'val'
            policynet.model.eval()
            aa.model.eval()

        else:
            batch = aa.sampletrain(args['batch_size'])
            mode = 'train'
            policynet.model.train()
            aa.model.train()

        batch_s = len(batch[0])
        rank_batch, p_fx_batch, _ = infogain_rollout(batch, aa, user, args,
                                                     mode)

        action_batch = []
        logp_batch = []
        for cnt in range(1, len(p_fx_batch) + 1):
            p_f_x = p_fx_batch[cnt - 1]

            if not args['ft_tag'] and not args['ft_emb']:
                p_f_x = p_f_x.detach()

            if cnt == args['max_step']:
                action = np.zeros(batch_s)
                log_pact = torch.zeros(batch_s).to(device)
            else:
                state = policynet.get_state(p_f_x, cnt)
                action, log_pact, _ = policynet.select_action(state)

            action_batch.append(action)
            logp_batch.append(log_pact)

        rewards, logp_bs, scalars = reprocess_withmask(action_batch,
                                                       rank_batch, logp_batch,
                                                       device, args)

        if mode == 'train':
            if nnoptimizer:
                nnoptimizer.zero_grad()

            scalars = policynet.update_policy(optimizer, rewards, logp_bs,
                                              scalars)

            if nnoptimizer:
                print('finetuning')
                clip_grad_norm_(
                    [p for p in aa.model.parameters() if p.requires_grad], 3.0)
                nnoptimizer.step()

            if args['ft_tag']:
                aa.tag_inference()
                #print('w: {:.3f}, b: {:.3f}, lmd: {:.3f}'.format(aa.tagweight.item(), aa.tagbias.item(), aa.lmda.item()))
                #writer.add_scalar('tagmodel/weight', aa.tagweight.item(), episode) #*args['batch_size'])
                #writer.add_scalar('tagmodel/bias', aa.tagbias.item(), episode) #*args['batch_size'])
                writer.add_scalar('tagmodel/lmda', aa.lmda.item(),
                                  episode)  #*args['batch_size'])
                writer.add_scalar('tagmodel/weight', aa.tagweight.data.norm(),
                                  episode)  #*args['batch_size'])
                writer.add_scalar('tagmodel/bias', aa.tagbias.data.norm(),
                                  episode)  #*args['batch_size'])
            if args['ft_emb']:
                writer.add_scalar('tagmodel/embweight',
                                  aa.embedweight.data.norm(),
                                  episode)  #*args['batch_size'])
                if args['ft_rnn']:
                    writer.add_scalar('rnn-parameter/rnn_param_norm',
                                      compute_param_norm(aa.model), episode)
                    writer.add_scalar('rnn-parameter/rnn_grad_norm',
                                      compute_grad_norm(aa.model), episode)

        if writer is not None:
            for name, value in scalars:
                writer.add_scalar(mode + name, value,
                                  episode)  #*args['batch_size'])

        if episode % args['print_every'] == 0:
            print(mode)
            print('Step: {:,} '.format(episode * args['batch_size']) +
                  ' '.join([
                      '{} = {:.3f}'.format(name, value)
                      for name, value in scalars
                  ]))

        if episode % args['save_every'] == 0:
            torch.save(
                aa.state_dict(),
                args['checkpoint_dir'] + '/' + args['flavor'] + '_aa.pt')
            save_path = save_checkpoint(policynet.model,
                                        optimizer,
                                        episode,
                                        episode * args['batch_size'],
                                        dict(scalars)['/suc_rate'],
                                        args,
                                        prev_save_path=save_path)
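update_policy is not defined in this snippet. Given that the loop above collects per-step log-probabilities and masked rewards, a minimal REINFORCE-style sketch (assumed, not the repo's confirmed objective) could look like:

    def update_policy(self, optimizer, rewards, logp_bs, scalars):
        # Hypothetical REINFORCE update: minimize -E[R * log pi(a|s)].
        loss = -(rewards * logp_bs).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scalars.append(('/policy_loss', loss.item()))
        return scalars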
Example 6
def main(config: Dict):
    pprint(config)
    print('Loading Preprocessor')
    if config['bert']:
        from data.berttokenizer import BTTokenizer, BertBatcher
        from module.bert_trainer import run_epoch
        from module.bertmodel import BertRetrieval
        from transformers import AdamW, WarmupLinearSchedule
        print('loading bert tokenizer')
        preprocessor = BTTokenizer(config)  #SplitPreprocessor()
        PAD_ID = preprocessor.padding_idx()
        batch_embedder = BertBatcher(cuda=config['cuda'], pad=PAD_ID)
    else:
        from data.tokenizer import Tokenizer_nltk
        from module.trainer import run_epoch
        from module.model import FAQRetrieval as Retrieval
        preprocessor = Tokenizer_nltk()  #SplitPreprocessor()

    # ================================ Load data ================================
    bird_domain = False
    if config['domain'] == 'faq':
        from data.faq_loader import load_data
    elif config['domain'] == 'health':
        from data.health_loader import load_data
    elif config['domain'] == 'bird':
        from data.bird_loader import load_data
        # For the bird domain, the BUD model backs both names regardless of
        # the bert flag.
        from bdmodule.BUDmodel import Retrieval as Retrieval
        from bdmodule.BUDmodel import Retrieval as BertRetrieval
        bird_domain = True

    train_data, val_data, test_datalist, word_to_index = load_data(
        config, preprocessor)
    aucresult = defaultdict()
    recallresult = defaultdict()
    # ================================ setting up training environment ================================
    # Set up Tensorboard writer
    writer = SummaryWriter(
        os.path.join(config['tensorboard_dir'], config['flavor']))
    writer.add_text('Config', pformat(config))

    if not config['bert']:
        print('Loading FastText')
        embedder = FastTextEmbedder(path=config['embedding_path'])
        print('Loading embeddings')
        batch_embedder = IndexBatchEmbedder(embedder,
                                            word_to_index,
                                            cuda=config['cuda'])

    # Load or build model
    if config['checkpoint_path']:
        print('Loading model from {}'.format(config['checkpoint_path']))
        model, optimizer, init_epoch, iter_count, best_auc05 = load_checkpoint(
            config['checkpoint_path'])
        save_path = config['checkpoint_path']
    else:
        print('Building model')
        if config['bert']:
            model = BertRetrieval(config)
            # Prepare optimizer and schedule (linear warmup and decay)
            no_decay = ['bias', 'LayerNorm.weight']
            optimizer_grouped_parameters = [{
                'params': [
                    p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                config['weight_decay']
            }, {
                'params': [
                    p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                'weight_decay':
                0.0
            }]
            # optimizer_grouped_parameters = get_params(model)
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=config['lr'],
                              eps=1e-8)
            num_batch_per_epoch = min(train_data.num_batches,
                                      config['max_batches_per_epoch'])
            t_total = int(num_batch_per_epoch //
                          config['gradient_accumulation_steps'] *
                          config['max_epoch'])
            scheduler = WarmupLinearSchedule(
                optimizer,
                warmup_steps=config['warmup_steps'],
                t_total=t_total)
        else:
            model = Retrieval(config)
            optimizer = Adam(
                get_params(model),
                lr=config['lr'],
                betas=[config['adam_beta1'], config['adam_beta2']],
                eps=config['adam_eps'])
            scheduler = None
        init_epoch = iter_count = best_auc05 = 0
        save_path = None

    # Print model details
    # print(model)
    print('Number of trainable parameters = {:,}'.format(
        parameter_count(model)))

    model = model.to(model.device)

    # ================================ Training  ================================
    # Run training
    for epoch in trange(init_epoch, config['max_epoch']):
        print('=' * 20 + ' Epoch {} '.format(epoch) + '=' * 20)

        # Train
        iter_count = run_epoch(train_data,
                               batch_embedder,
                               model,
                               config,
                               train=True,
                               iter_count=iter_count,
                               writer=writer,
                               scheduler=scheduler,
                               bird=bird_domain,
                               optimizer=optimizer)
        torch.cuda.empty_cache()

        auc05 = run_epoch(val_data,
                          batch_embedder,
                          model,
                          config,
                          train=False,
                          iter_count=iter_count,
                          writer=writer,
                          bird=bird_domain)

        torch.cuda.empty_cache()

        # Save if validation auc@0.05 improved (the active condition below saves every 4 epochs instead)
        # if epoch == 0  or auc05 > best_auc05:
        if epoch % 4 == 0:
            best_auc05 = auc05
            save_path = save_checkpoint(model,
                                        optimizer,
                                        epoch,
                                        iter_count,
                                        auc05,
                                        config,
                                        prev_save_path=save_path)
            for key, test_data in test_datalist.items():
                print('Testing:')
                auc05_test = run_epoch(test_data,
                                       batch_embedder,
                                       model,
                                       config,
                                       train=False,
                                       test=True,
                                       iter_count=iter_count,
                                       writer=writer,
                                       bird=bird_domain)
                aucresult[key] = auc05_test

            for key, test_data in test_datalist.items():
                print('test dataset: {}'.format(key))
                ks = [1, 3, 5]
                recalls = recall(test_data,
                                 batch_embedder,
                                 model,
                                 config['eval_batch_size'],
                                 ks=ks)
                recallresult[key] = recalls
                writer.add_scalar('auc05s/' + key, aucresult[key],
                                  iter_count * train_data.batch_size)
                for i in range(len(ks)):
                    writer.add_scalar('recall/R' + str(ks[i]) + '_' + key,
                                      recalls[i],
                                      iter_count * train_data.batch_size)

    # Wrap up
    fout = open('results/' + config['flavor'] + '.txt', 'w')
    fout.write(str(config))
    print('Training complete. Best model saved to {}'.format(save_path))
    print('\nauc result: ')
    for k, v in aucresult.items():
        print('{} , {} '.format(k, v))
        fout.write('{} , {}\n'.format(k, v))
    print('\nrecall:')
    for k, v in recallresult.items():
        print(','.join([k] + [str(x) for x in v]))
        fout.write(','.join([k] + [str(x) for x in v]))
        fout.write('\n')
    writer.close()

    # Move model to cpu and prepare for inference
    model.cpu()
    model.eval()
    fout.close()
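The recall helper used in the test loop above is not shown. A minimal sketch of a recall@k computation, assuming each query's gold candidate rank is already known (recall_at_k and ranks are illustrative names):

import numpy as np

def recall_at_k(ranks, ks=(1, 3, 5)):
    # ranks[i] is the 0-based rank of the gold candidate for query i;
    # recall@k is the fraction of queries whose gold rank falls below k.
    ranks = np.asarray(ranks)
    return [float((ranks < k).mean()) for k in ks]

# Example: recall_at_k([0, 2, 4, 1]) -> [0.25, 0.75, 1.0]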