Exemple #1
0
def train(epoch, model, train_data, valid_data, train_loader, valid_loader, optimizer, criterion, dictionary, bpe, args):
    """Run one pass over ``train_loader``, taking an optimizer step per batch.

    For every batch of indices the raw sentences (``train_data[0]``) and raw
    syntax strings (``train_data[1]``) are converted to zero-padded index
    matrices, moved to the GPU and fed to ``model(sents, synts, targs)``.
    The loss is ``criterion`` applied to the flattened outputs against the
    targets shifted by one position (the leading ``<sos>`` is dropped).

    Args:
        epoch: Epoch number; used in the log line and the checkpoint name.
        model: Network being trained. ``model.grad_norm`` is read for logging.
        train_data: Indexable pair ``(sentences, syntax strings)``, each
            indexable by a sorted NumPy index array.
        valid_data: Forwarded to ``evaluate`` and ``generate``.
        train_loader: Iterable yielding tensors of sample indices.
        valid_loader: Forwarded to ``evaluate`` and ``generate``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(flat_outputs, flat_targets)``.
        dictionary: Object exposing a ``word2idx`` mapping that contains
            ``<unk>``, ``<sos>`` and ``<eos>``.
        bpe: Object whose ``segment(str)`` returns a whitespace-separated
            string of sub-word tokens.
        args: Namespace providing ``max_sent_len``, ``max_synt_len``,
            ``log_interval``, ``gen_interval``, ``save_interval``,
            ``n_epoch`` and ``model_dir``.

    Returns:
        None. Side effects: parameter updates, console logging, generation
        output and checkpoints written under ``args.model_dir``.
    """
    timer = Timer()
    n_it = len(train_loader)

    for it, data_idxs in enumerate(train_loader):
        # Re-enter training mode on every iteration.
        # NOTE(review): presumably evaluate()/generate() switch the model to
        # eval mode -- confirm against their definitions.
        model.train()

        data_idxs = np.sort(data_idxs.numpy())

        # get batch of raw sentences and raw syntax
        sents_ = train_data[0][data_idxs]
        synts_ = train_data[1][data_idxs]

        batch_size = len(sents_)

        # Initialize zero-padded index buffers.
        # np.int64 replaces the original np.long, an alias that was deprecated
        # in NumPy 1.20 and removed in 1.24; int64 is what torch.from_numpy
        # turns into a torch.long tensor.
        # NOTE(review): a sequence longer than its buffer makes the slice
        # assignments below raise ValueError; presumably the data is filtered
        # by length upstream -- confirm.
        sents = np.zeros((batch_size, args.max_sent_len), dtype=np.int64)    # words without position
        synts = np.zeros((batch_size, args.max_synt_len+2), dtype=np.int64)  # syntax
        targs = np.zeros((batch_size, args.max_sent_len+2), dtype=np.int64)  # target output

        for i in range(batch_size):

            # bpe segment and map tokens to indices (unknown tokens -> <unk>)
            sent_ = sents_[i]
            sent_ = bpe.segment(sent_).split()
            sent_ = [dictionary.word2idx[w] if w in dictionary.word2idx else dictionary.word2idx["<unk>"] for w in sent_]
            sents[i, :len(sent_)] = sent_

            # add <sos> and <eos> for target output
            targ_ = [dictionary.word2idx["<sos>"]] + sent_ + [dictionary.word2idx["<eos>"]]
            targs[i, :len(targ_)] = targ_

            # parse syntax, strip the leaves and map the remaining symbols to
            # indices; symbols missing from the dictionary are dropped
            synt_ = synts_[i]
            synt_ = ParentedTree.fromstring(synt_)
            synt_ = deleaf(synt_)
            synt_ = [dictionary.word2idx[f"<{w}>"] for w in synt_ if f"<{w}>" in dictionary.word2idx]
            synt_ = [dictionary.word2idx["<sos>"]] + synt_ + [dictionary.word2idx["<eos>"]]
            synts[i, :len(synt_)] = synt_

        sents = torch.from_numpy(sents).cuda()
        synts = torch.from_numpy(synts).cuda()
        targs = torch.from_numpy(targs).cuda()

        # forward
        outputs = model(sents, synts, targs)

        # calculate loss against the targets shifted left by one (<sos> removed)
        targs_ = targs[:, 1:].contiguous().view(-1)
        outputs_ = outputs.contiguous().view(-1, outputs.size(-1))
        optimizer.zero_grad()
        loss = criterion(outputs_, targs_)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if it % args.log_interval == 0:
            # print current loss
            valid_loss = evaluate(model, valid_data, valid_loader, criterion, dictionary, bpe, args)
            print("| ep {:2d}/{} | it {:3d}/{} | {:5.2f} s | loss {:.4f} | g_norm {:.6f} | valid loss {:.4f} |".format(
                epoch, args.n_epoch, it, n_it, timer.get_time_from_last(), loss.item(), model.grad_norm, valid_loss))

        if it % args.gen_interval == 0:
            # generate output to args.output_dir
            generate(epoch, it, model, valid_data, valid_loader, dictionary, bpe, args)

        if it % args.save_interval == 0:
            # save model to args.model_dir
            torch.save(model.state_dict(), os.path.join(args.model_dir, "synpg_epoch{:02d}.pt".format(epoch)))
Exemple #2
0
def train(epoch, dataset, model, tokenizer, optimizer, args):
    """Run one training pass with alternating adversary and model updates.

    Per batch, gradients are first accumulated for the adversary head
    (``model.forward_adv``) on both sentences, then for the paraphrase model
    in both directions (sent1 + synt2 -> sent2 and sent2 + synt1 -> sent1).
    Optimizer steps happen only every ``args.accumulation_steps`` batches.
    From the second epoch on (``epoch > 1``) the adversary is actually
    stepped and ``0.1 *`` the adversarial loss is subtracted from each
    paraphrase loss.

    ``train_loader``, ``adv_optimizer``, ``adv_criterion``, ``para_criterion``,
    ``Timer`` and ``evaluate`` are not parameters; they are looked up in the
    enclosing scope.

    Args:
        epoch: Epoch number; gates the adversarial terms and is logged.
        dataset: Mapping with keys ``sent1``, ``synt1``, ``sent2``, ``synt2``,
            ``synt1bow`` and ``synt2bow``; each value is indexed by the batch
            indices and moved to the GPU with ``.cuda()``.
        model: Network exposing ``forward_adv(ids)`` and a call returning
            ``(outputs, adv_outputs)``.
        tokenizer: Only forwarded to ``evaluate``.
        optimizer: Optimizer for the paraphrase model.
        args: Namespace providing ``accumulation_steps``, ``log_interval``
            and ``n_epoch``.
    """
    timer = Timer()
    n_it = len(train_loader)
    optimizer.zero_grad()

    batch_keys = ('sent1', 'synt1', 'sent2', 'synt2', 'synt1bow', 'synt2bow')

    for it, idxs in enumerate(train_loader):
        # Per-iteration running sums; they are only reported in the log line.
        total_loss = 0.0
        adv_total_loss = 0.0
        model.train()

        # Fetch each field of the batch and move it to the GPU.
        sent1_ids, synt1_ids, sent2_ids, synt2_ids, bow1, bow2 = [
            dataset[key][idxs].cuda() for key in batch_keys
        ]

        # --- adversary: predict the syntax bag-of-words from each sentence ---
        for sent_ids, bow_target in ((sent1_ids, bow1), (sent2_ids, bow2)):
            adv_pred = model.forward_adv(sent_ids)
            adv_loss = adv_criterion(adv_pred, bow_target)
            adv_loss.backward()
            adv_total_loss += adv_loss.item()

        step_now = (it + 1) % args.accumulation_steps == 0
        if step_now:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            # The adversary is only updated after the first epoch, but its
            # gradients are cleared either way.
            if epoch > 1:
                adv_optimizer.step()
            adv_optimizer.zero_grad()

        # --- model: paraphrase in both directions, minus the adversarial term ---
        directions = (
            (sent1_ids, synt2_ids, sent2_ids, bow1),  # sent1 -> sent2, sent1 adv
            (sent2_ids, synt1_ids, sent1_ids, bow2),  # sent2 -> sent1, sent2 adv
        )
        for src_ids, tgt_synt_ids, tgt_ids, bow_target in directions:
            logits, adv_pred = model(torch.cat((src_ids, tgt_synt_ids), 1), tgt_ids)
            flat_targets = tgt_ids[:, 1:].contiguous().view(-1)
            flat_logits = logits.contiguous().view(-1, logits.size(-1))
            loss = para_criterion(flat_logits, flat_targets)
            if epoch > 1:
                loss -= 0.1 * adv_criterion(adv_pred, bow_target)
            loss.backward()
            total_loss += loss.item()

        if step_now:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()

        if (it + 1) % args.log_interval == 0 or it == 0:
            para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss = evaluate(model, tokenizer, args)
            valid_loss = para_1_2_loss + para_2_1_loss - 0.1 * adv_1_loss - 0.1 * adv_2_loss
            log_template = (
                "| ep {:2d}/{} | it {:3d}/{} | {:5.2f} s "
                "| adv loss {:.4f} | loss {:.4f} "
                "| para 1-2 loss {:.4f} | para 2-1 loss {:.4f} "
                "| adv 1 loss {:.4f} | adv 2 loss {:.4f} "
                "| valid loss {:.4f} |"
            )
            print(log_template.format(
                epoch, args.n_epoch, it + 1, n_it, timer.get_time_from_last(),
                adv_total_loss, total_loss,
                para_1_2_loss, para_2_1_loss, adv_1_loss, adv_2_loss, valid_loss))