예제 #1
0
파일: MAML.py 프로젝트: GBLin5566/PAML
            break
    return val_loss / len(val_ppl), np.mean(val_ppl)


def do_evaluation(model, test_iter):
    """Run *model* over *test_iter* in eval mode and average the metrics.

    Returns a ``(mean_loss, mean_ppl)`` tuple computed over every batch.
    """
    model.eval()
    with torch.no_grad():
        # Keep only (loss, ppl) from each forward pass; the third return
        # value is unused here.
        results = [model(batch)[:2] for batch in test_iter]
    losses = [pair[0] for pair in results]
    ppls = [pair[1] for pair in results]
    return np.mean(losses), np.mean(ppls)


p = Personas()
# Make save_path: decorate the final path component with the run's
# hyper-parameters so each configuration logs to its own directory.
path_split = config.save_path.split(os.sep)
if path_split[-1] == '':
    # A trailing separator leaves an empty last component; drop it.
    path_split.pop()
suffix = {
    "model": config.model_type,
    "lr": config.lr,
    "meta_lr": config.meta_lr,
    "iter_as_step": config.iter_as_step,
}
# Append each setting as "_key_value" to the final component.
tag = "".join(f"_{key}_{value}" for key, value in suffix.items())
path_split[-1] += tag
save_path = os.sep.join(path_split)
writer = SummaryWriter(log_dir=save_path)
# Build model, optimizer, and set states
예제 #2
0
파일: interact.py 프로젝트: zequnl/PAML
import os
import time
import ast


def make_batch(inp, vacab):
    """Wrap a single raw input in a one-element DataLoader batch.

    Parameters
    ----------
    inp : the raw input example (e.g. the running persona/dialogue list).
    vacab : vocabulary object handed through to ``Dataset``.

    Returns the first (and only) collated batch from the loader.
    """
    # Dummy target/meta fields; only the input matters at inference time.
    temp = [[inp, ['', ''], 0]]
    d = Dataset(temp, vacab)
    loader = torch.utils.data.DataLoader(dataset=d,
                                         batch_size=1,
                                         shuffle=False,
                                         collate_fn=collate_fn)
    # Bug fix: iterator `.next()` was removed in Python 3; use the
    # builtin next() on the loader's iterator instead.
    return next(iter(loader))


p = Personas()
# Seed the conversation context with a persona drawn from the train split.
persona = ast.literal_eval(p.get_task('train'))
print(persona)
model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=True)
t = Translator(model, p.vocab)
print('Start to chat')
while True:
    msg = input(">>> ")
    text = str(msg).rstrip().lstrip()
    if len(text) != 0:
        # Append the user's turn, run beam translation, then append the
        # model's reply so the context grows with the dialogue.
        persona += [text]
        batch = make_batch(persona, p.vocab)
        sent_b, batch_scores = t.translate_batch(batch)
        words = [p.vocab.index2word[idx] for idx in sent_b[0][0]]
        ris = ' '.join(words).replace('EOS', '').rstrip().lstrip()
        print(">>>", ris)
        persona += [ris]
예제 #3
0
from utils.data_reader import Personas
from model.transformer import Transformer
from model.common_layer import evaluate
from utils import config
from tqdm import tqdm
import numpy as np


def count_parameters(model):
    """Return the number of trainable (requires_grad) scalars in *model*."""
    trainable = (param.numel() for param in model.parameters()
                 if param.requires_grad)
    return sum(trainable)


p = Personas()

# Train/validation/test iterators are built in a single call.
(data_loader_tr,
 data_loader_val,
 data_loader_test) = p.get_all_data(batch_size=config.batch_size)

if config.test:
    # Evaluation-only mode: load the checkpoint, score the test split, exit.
    print("Test model", config.model)
    model = Transformer(p.vocab,
                        model_file_path=config.save_path,
                        is_eval=True)
    evaluate(model, data_loader_test, model_name=config.model, ty='test')
    exit(0)

model = Transformer(p.vocab)
print("MODEL USED", config.model)
print("TRAINABLE PARAMETERS", count_parameters(model))

# Best validation perplexity seen so far, plus a patience counter.
best_ppl = 1000
cnt = 0
예제 #4
0
            print(
                "----------------------------------------------------------------------"
            )
            print(
                "----------------------------------------------------------------------"
            )


def do_learning(model, train_iter, val_iter, iterations, persona):
    """Fine-tune *model* on *train_iter*, then generate on *val_iter*.

    Runs ``iterations - 1`` full passes over the training iterator (the
    range starts at 1), discarding per-batch outputs, and finally calls
    ``generate`` for qualitative inspection of the adapted model.
    """
    for _pass in range(1, iterations):
        for batch in train_iter:
            _, _, _ = model.train_one_batch(batch)
    generate(model, val_iter, persona)


p = Personas()
# Build model, optimizer, and set states
print("Test model", config.model)
# is_eval=False: the loaded checkpoint must stay fine-tunable per persona.
model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False)
# get persona map
filename = 'data/ConvAI2/test_persona_map'
with open(filename, 'rb') as f:
    # NOTE(review): pickle.load on a repo-local file — safe only while the
    # file is trusted; never unpickle untrusted data.
    persona_map = pickle.load(f)

#generate
# Passed to do_learning, whose loop runs iterations - 1 (= 10) passes.
iterations = 11
# Snapshot the meta-trained weights so they can be restored between personas.
weights_original = deepcopy(model.state_dict())
tasks = p.get_personas('test')
for per in tqdm(tasks):
    num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
    for val_dial_index in range(num_of_dialog):
예제 #5
0
파일: CMAML.py 프로젝트: zyDotwei/CMAML
            val_p.append(t_ppl)
        return val_loss, np.mean(val_p)


def do_evaluation(model, test_iter):
    """Evaluate *model* over *test_iter* without updating its weights.

    Returns ``(mean_loss, mean_ppl)`` averaged over all batches.
    """
    # train=False keeps the batch pass evaluation-only; drop the third
    # return value of each call.
    pairs = [model.train_one_batch(batch, train=False)[:2]
             for batch in test_iter]
    losses = [loss for loss, _ in pairs]
    ppls = [ppl for _, ppl in pairs]
    return np.mean(losses), np.mean(ppls)


#=================================main=================================

p = Personas()
writer = SummaryWriter(log_dir=config.save_path)
# Build model, optimizer, and set states
# NOTE: load_frompretrain is compared to the *string* 'None', so the
# config default is expected to be the literal text "None".
if not (config.load_frompretrain == 'None'):
    meta_net = Seq2SPG(p.vocab,
                       model_file_path=config.load_frompretrain,
                       is_eval=False)
else:
    meta_net = Seq2SPG(p.vocab)
# Pick the meta-optimizer from the config: plain SGD, Adam, or the
# NoamOpt warm-up schedule (its call continues past this excerpt).
if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    meta_optimizer = NoamOpt(
        config.hidden_dim, 1, 4000,
예제 #6
0
from utils.data_reader import Personas
from model.transformer import Transformer
from model.common_layer import evaluate
from utils import config
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import os
import time 
import numpy as np 

def count_parameters(model):
    """Return the total element count of all trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

p = Personas()

# Train/validation/test iterators built in one call, then unpacked.
loaders = p.get_all_data(batch_size=config.batch_size)
data_loader_tr, data_loader_val, data_loader_test = loaders

if config.test:
    # Evaluation-only run: load the checkpoint, score the test split, exit.
    print("Test model", config.model)
    model = Transformer(p.vocab, model_file_path=config.save_path,
                        is_eval=True)
    evaluate(model, data_loader_test, model_name=config.model, ty='test')
    exit(0)

model = Transformer(p.vocab)
print("MODEL USED", config.model)
print("TRAINABLE PARAMETERS", count_parameters(model))

# Best perplexity seen so far and a patience/step counter.
best_ppl = 1000
cnt = 0