Example #1
0
    generate(model, val_iter, persona)


p = Personas()
# Build model, optimizer, and set states
print("Test model", config.model)
# is_eval=False in a test script — presumably so the model can still be
# fine-tuned per persona in the loop below; TODO confirm.
model = Transformer(p.vocab, model_file_path=config.save_path, is_eval=False)
# get persona map
# NOTE(review): pickle.load executes arbitrary code on untrusted files;
# assumed safe here because this is a local preprocessed artifact.
filename = 'data/ConvAI2/test_persona_map'
with open(filename, 'rb') as f:
    persona_map = pickle.load(f)

#generate
# Fine-tuning iterations per dialog (passed to do_learning below).
iterations = 11
# Snapshot of the pre-fine-tuning parameters; presumably restored between
# personas — the restore itself is not visible in this excerpt.
weights_original = deepcopy(model.state_dict())
tasks = p.get_personas('test')
# For each test persona, fine-tune and evaluate on every dialog fold.
for per in tqdm(tasks):
    num_of_dialog = p.get_num_of_dialog(persona=per, split='test')
    # Each dialog index in turn serves as the held-out validation fold.
    for val_dial_index in range(num_of_dialog):
        train_iter, val_iter = p.get_data_loader(persona=per,
                                                 batch_size=config.batch_size,
                                                 split='test',
                                                 fold=val_dial_index)
        # Flatten this persona's sentence groups and deduplicate
        # (list(set(...)) does not preserve order).
        persona = []
        for ppp in persona_map[per]:
            persona += ppp
        persona = list(set(persona))
        do_learning(model,
                    train_iter,
                    val_iter,
                    iterations=iterations,
Example #2
0
# Build model, optimizer, and set states
# Dispatch tables mapping config strings to their constructors.
build_model_func_map = {
    'bert2bert': Bert2Bert,
    # 'gpt2gpt': GPT2GPT,
    'bart': Bart,
}
optimizer_map = {
    'sgd': torch.optim.SGD,
    'adam': torch.optim.Adam,
}
# Instantiate the meta model, then an optimizer over all of its parameters.
meta_net = build_model_func_map[config.model_type]()
meta_optimizer = optimizer_map[config.meta_optimizer](meta_net.parameters(),
                                                      lr=config.meta_lr)

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
tasks_iter = make_infinite_list(tasks)

# meta early stop
# More patience when the number of training dialogs is fixed.
patience = 100 if config.fix_dialnum_train else 50
# Sentinel "infinite" initial losses for early stopping.
best_before_loss = 10000000
best_meta_loss = 10000000
stop_count = 0
# Main loop
# Meta-training outer loop over epochs (body continues beyond this excerpt).
for meta_iteration in range(config.epochs):
    # save original weights to make the update
    # NOTE theta = weights_original
    # deepcopy so later task-level updates cannot mutate the snapshot in place.
    weights_original = deepcopy(meta_net.state_dict())
    # Loss accumulators for this epoch; names suggest losses before task
    # adaptation vs. after the meta update — TODO confirm downstream use.
    train_loss_before = []
    train_loss_meta = []
Example #3
0
# Select the meta optimizer from config.
if config.meta_optimizer == 'sgd':
    meta_optimizer = torch.optim.SGD(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'adam':
    meta_optimizer = torch.optim.Adam(meta_net.parameters(), lr=config.meta_lr)
elif config.meta_optimizer == 'noam':
    # NoamOpt presumably applies the warmup learning-rate schedule, so the
    # wrapped Adam starts at lr=0 and the schedule drives the rate.
    meta_optimizer = NoamOpt(
        config.hidden_dim, 1, 4000,
        torch.optim.Adam(meta_net.parameters(),
                         lr=0,
                         betas=(0.9, 0.98),
                         eps=1e-9))
else:
    # Fail loudly with the offending value instead of a bare ValueError.
    raise ValueError(
        "unsupported meta_optimizer: {!r}".format(config.meta_optimizer))

meta_batch_size = config.meta_batch_size
tasks = p.get_personas('train')
# Meta-batches per epoch: ceiling division of task count by batch size.
full_batches, remainder = divmod(len(tasks), meta_batch_size)
steps = full_batches + int(remainder != 0)

# meta early stop
# Longer patience when the number of training dialogs is fixed.
patience = 100 if config.fix_dialnum_train else 10
# Sentinel "infinite" initial loss for early stopping.
best_loss = 10000000
stop_count = 0
# Meta-training outer loop over epochs (body continues beyond this excerpt).
for meta_iteration in range(config.epochs):
    ## save original weights to make the update
    # Loss accumulators for this epoch; names suggest losses before task
    # adaptation vs. after the meta update — TODO confirm downstream use.
    train_loss_before = []
    train_loss_meta = []
    # Presumably a mode label for logging/checkpointing: the first 10 epochs
    # are treated as a pretraining phase — verify against later uses of m.
    if meta_iteration < 10:
        m = "pretrain"