def test():
    if args.param_file is None:
        print('please specify the saved param file.')
        exit(-1)
    else:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']
        voc = checkpoint['voc']
        print('done')

    print('arguments for train:')
    print(train_args)

    print('rebuilding model...')
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    model.load_state_dict(checkpoint['model'])
    param_optimizer = train_args.optimiser(model.parameters(), lr=args.learning_rate)
    decoder_optimizer = train_args.optimiser(model.speaker.decoder.parameters(),
                                             lr=args.learning_rate * args.decoder_ratio)
    param_optimizer.load_state_dict(checkpoint['opt'])
    decoder_optimizer.load_state_dict(checkpoint['de_opt'])
    print('done')

    print('loading test data...')
    test_set = FruitSeqDataset(voc, dataset_file_path=args.test_file)
    print('done')

    test_seq_acc, test_tok_acc, test_loss = eval_model(model, test_set)
    print("[TEST]Loss: {:.4f}; Seq-level Accuracy: {:.4f}; Tok-level Accuracy: {:.4f}".format(
        test_loss, test_seq_acc * 100, test_tok_acc * 100))
def test():
    print('building model...')
    voc = Voc()
    seq2seq = Seq2Seq(voc.num_words).to(args.device)
    param_optimizer = args.optimiser(seq2seq.parameters(), lr=args.learning_rate)
    decoder_optimizer = args.optimiser(seq2seq.decoder.parameters(),
                                       lr=args.learning_rate * args.decoder_ratio)
    print('done')

    if args.param_file is None:
        print('please specify the saved param file.')
        exit(-1)
    else:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file)
        seq2seq.load_state_dict(checkpoint['model'])
        param_optimizer.load_state_dict(checkpoint['opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])
        voc = checkpoint['voc']
        print('done')

    print('loading test data...')
    test_set = FruitSeqDataset(voc, dataset_file_path=args.test_file)
    print('done')

    test_seq_acc, test_tok_acc, test_loss = eval_model(seq2seq, test_set)
    print("[TEST]Loss: {:.4f}; Seq-level Accuracy: {:.4f}; Tok-level Accuracy: {:.4f}".format(
        test_loss, test_seq_acc * 100, test_tok_acc * 100))
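# The two test() entry points above rely on an eval_model(model, dataset) helper
# that is not shown in this section. Below is a minimal hypothetical sketch,
# assuming the model's forward call on a data batch returns (loss, tok_acc,
# seq_acc); the real forward signature of Seq2Seq / Set2Seq2Seq may differ, so
# treat this as a reference sketch rather than the repository's implementation.
import torch

def eval_model(model, dataset):
    """Average loss, token-level accuracy and sequence-level accuracy over a dataset."""
    model.eval()
    total_loss, total_tok_acc, total_seq_acc = 0., 0., 0.
    with torch.no_grad():
        for data_batch in dataset:
            loss, tok_acc, seq_acc = model(data_batch)  # assumed forward interface
            total_loss += loss.item() if torch.is_tensor(loss) else float(loss)
            total_tok_acc += tok_acc
            total_seq_acc += seq_acc
    model.train()
    n_batches = len(dataset)
    return total_seq_acc / n_batches, total_tok_acc / n_batches, total_loss / n_batches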
def main():
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    data_set = FruitSeqDataset(voc, dataset_file_path=DATA_FILE, batch_size=1)
    str_set = data_set.load_stringset(DATA_FILE)
    print('done')

    print('rebuilding model from saved parameters in ' + args.param_file + '...')
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    checkpoint = torch.load(args.param_file, map_location=args.device)
    train_args = checkpoint['args']
    model.load_state_dict(checkpoint['model'])
    voc = checkpoint['voc']
    print('done')

    model.eval()

    print('iterating data set...')
    out_file = open(OUT_FILE, mode='a')
    iterate_dataset(model, voc, str_set, data_set, out_file, train_args)
def reproduce_input_hidden_pairs(model, voc, dataset_file_path):
    repro_dataset = FruitSeqDataset(voc, dataset_file_path=dataset_file_path, batch_size=1)
    repro_strset = load_stringset(dataset_file_path)

    def _instr2coordinate_(in_str, voc):
        # Map an input string to a count vector: one coordinate per non-special
        # vocabulary symbol ('A', 'B', ...), counting its occurrences in the string.
        coordinate = []
        for i in range(voc.num_words - 3):
            coordinate.append(in_str.count(chr(65 + i)))
        return np.asarray(coordinate)

    pair_set = []
    for idx, data_batch in enumerate(repro_dataset):
        input_var = data_batch['input']
        input_mask = data_batch['input_mask']
        # Encoder hidden state for this input (batch_size=1, so squeeze to a vector).
        hidden, _ = model.encoder(model.embedding(input_var.t()), input_mask)
        hidden = hidden.squeeze().detach().cpu().numpy()
        in_vec = _instr2coordinate_(repro_strset[idx], voc)
        pair_set.append([in_vec, hidden])

    return pair_set
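# A hedged usage example: the (input count vector, encoder hidden state) pairs
# produced above can be used to estimate topographic similarity, i.e. how well
# pairwise distances between inputs correlate with pairwise distances between
# hidden states. The distance metrics and correlation measure chosen below
# (Euclidean/cosine, Spearman) are illustrative assumptions; the sim_check used
# elsewhere in this section may use different ones.
import numpy as np
from scipy.spatial.distance import pdist
from scipy.stats import spearmanr

def topo_similarity_from_pairs(pair_set):
    in_vecs = np.stack([pair[0] for pair in pair_set]).astype(float)
    hiddens = np.stack([pair[1] for pair in pair_set])
    in_dists = pdist(in_vecs, metric='euclidean')   # pairwise distances between inputs
    hid_dists = pdist(hiddens, metric='cosine')     # pairwise distances between hidden states
    corr, _ = spearmanr(in_dists, hid_dists)
    return corr

# e.g. pairs = reproduce_input_hidden_pairs(model, voc, args.data_file)
#      print('In-Hidden topographic similarity: {:.4f}'.format(topo_similarity_from_pairs(pairs)))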
def get_batches4sim_check(voc, dataset_file_path=args.data_file):
    in_set = FruitSeqDataset.load_stringset(dataset_file_path)
    batch_set = ChooseDataset(voc, batch_size=1, dataset_file_path=dataset_file_path)
    return in_set, batch_set
def train():
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    # test_set = FruitSeqDataset(voc, dataset_file_path=TEST_FILE_PATH)
    print('done')

    print('building model...')
    seq2seq = Seq2Seq(voc.num_words).to(args.device)
    param_optimizer = args.optimiser(seq2seq.parameters(), lr=args.learning_rate)
    decoder_optimizer = args.optimiser(seq2seq.decoder.parameters(),
                                       lr=args.learning_rate * args.speaker_ratio)
    if args.param_file is not None:
        print('\tloading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file)
        seq2seq.load_state_dict(checkpoint['model'])
        param_optimizer.load_state_dict(checkpoint['opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])
        voc = checkpoint['voc']
        print('\tdone')
    print('done')

    print('initialising...')
    start_iteration = 1
    print_loss = 0.
    print_seq_acc = 0.
    print_tok_acc = 0.
    max_dev_seq_acc = 0.
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_sim = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    print('training...')
    for iter in range(start_iteration, args.iter_num + 1):
        for idx, data_batch in enumerate(train_set):
            seq_acc, tok_acc, loss = train_epoch(seq2seq, data_batch,
                                                 param_optimizer, decoder_optimizer)
            print_loss += loss
            print_seq_acc += seq_acc
            print_tok_acc += tok_acc

        if iter % args.print_freq == 0:
            print_loss_avg = print_loss / (args.print_freq * len(train_set))
            print_seq_acc_avg = print_seq_acc / (args.print_freq * len(train_set))
            print_tok_acc_avg = print_tok_acc / (args.print_freq * len(train_set))
            print("Iteration: {}; Percent complete: {:.1f}%; Avg loss: {:.4f}; Avg seq acc: {:.4f}; Avg tok acc: {:.4f}".format(
                iter, iter / args.iter_num * 100, print_loss_avg,
                print_seq_acc_avg, print_tok_acc_avg))
            training_seq_acc.append(print_seq_acc_avg)
            training_tok_acc.append(print_tok_acc_avg)
            training_losses.append(print_loss_avg)
            print_seq_acc = 0.
            print_tok_acc = 0.
            print_loss = 0.

        if iter % args.eval_freq == 0:
            dev_seq_acc, dev_tok_acc, dev_loss = eval_model(seq2seq, dev_set)
            if dev_seq_acc > max_dev_seq_acc:
                max_dev_seq_acc = dev_seq_acc
            eval_seq_acc.append(dev_seq_acc)
            eval_tok_acc.append(dev_tok_acc)
            print("[EVAL]Iteration: {}; Loss: {:.4f}; Avg Seq Acc: {:.4f}; Avg Tok Acc: {:.4f}; Best Seq Acc: {:.4f}".format(
                iter, dev_loss, dev_seq_acc, dev_tok_acc, max_dev_seq_acc))

        if iter % args.save_freq == 0:
            directory = os.path.join(args.save_dir, 'seq2seq')
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iter,
                'model': seq2seq.state_dict(),
                'opt': param_optimizer.state_dict(),
                'de_opt': decoder_optimizer.state_dict(),
                'loss': loss,
                'voc': voc,
                'args': args,
                'records': {
                    'training_loss': training_losses,
                    'training_tok_acc': training_tok_acc,
                    'training_seq_acc': training_seq_acc,
                    'training_sim': training_sim,
                    'eval_tok_acc': eval_tok_acc,
                    'eval_seq_acc': eval_seq_acc
                }
            }, os.path.join(directory, '{}_{}_{}.tar'.format(args.seed, iter, 'checkpoint')))
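# train_epoch() for the Seq2Seq baseline is not shown in this section. Below is a
# minimal hypothetical sketch of one optimisation step, assuming the model's
# forward call returns (loss, tok_acc, seq_acc) for a data batch; the clip value
# and the exact forward signature are assumptions, not the repository's code.
import torch

def train_epoch(seq2seq, data_batch, param_optimizer, decoder_optimizer, clip=None):
    seq2seq.train()
    param_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    loss, tok_acc, seq_acc = seq2seq(data_batch)  # assumed forward interface
    loss.backward()
    if clip is not None:
        # optionally clip gradients to stabilise RNN training
        torch.nn.utils.clip_grad_norm_(seq2seq.parameters(), clip)
    param_optimizer.step()
    decoder_optimizer.step()  # decoder parameters get an extra, larger-lr update
    return seq_acc, tok_acc, loss.item()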
def train():
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    learn_set = FruitSeqDataset(voc, dataset_file_path=args.train_file, batch_size=1)
    print('done')

    if args.param_file is not None:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']
        voc = checkpoint['voc']
        print('done')
        print('arguments for training:')
        print(train_args)
        print('rebuilding model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        print('\tdone')
    else:
        print('building model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        print('done')

    print('preparing data for testing topological similarity...')
    sim_chk_inset, sim_chk_batchset = get_batches4sim_check(voc, args.data_file)
    print('done')

    print('initialising...')
    start_iteration = 1
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_in_spkh_sim = []
    training_in_msg_sim = []
    training_in_lish_sim = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    in_spk_sim, in_msg_sim, in_lis_sim = sim_check(model, sim_chk_inset, sim_chk_batchset)
    print('[SIM]Iteration: {}; In-SpkHidden Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisHidden Sim: {:.4f}'.format(
        0, in_spk_sim, in_msg_sim, in_lis_sim))

    print('training...')
    for iter in range(start_iteration, args.num_generation + 1):
        training_records = train_generation(
            model, train_set, dev_set, learn_set,
            sim_chk_inset, sim_chk_batchset, generation_idx=iter
        )
        training_losses += training_records[0]
        training_tok_acc += training_records[1]
        training_seq_acc += training_records[2]
        training_in_spkh_sim += training_records[3]
        training_in_msg_sim += training_records[4]
        training_in_lish_sim += training_records[5]
        eval_tok_acc += training_records[6]
        eval_seq_acc += training_records[7]

        if iter % args.save_freq == 0:
            path_join = 'set2seq2seq_3phases_' + str(args.num_words) + '_' + args.msg_mode
            path_join += '_hard' if not args.soft else '_soft'
            directory = os.path.join(args.save_dir, path_join)
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'generation': iter,
                'model': model.state_dict(),
                'voc': voc,
                'args': args,
                'records': {
                    'training_loss': training_losses,
                    'training_tok_acc': training_tok_acc,
                    'training_seq_acc': training_seq_acc,
                    'training_in_spkh_sim': training_in_spkh_sim,
                    'training_in_msg_sim': training_in_msg_sim,
                    'training_in_lish_sim': training_in_lish_sim,
                    'eval_tok_acc': eval_tok_acc,
                    'eval_seq_acc': eval_seq_acc
                }
            }, os.path.join(directory, '{}_{:.4f}_{}.tar'.format(iter, eval_seq_acc[-1], 'checkpoint')))
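# train_generation() is not defined in this section. The index-based consumption
# of training_records above implies it returns eight lists per generation, in
# this order: losses, token accuracies, sequence accuracies, input-speaker-hidden
# similarities, input-message similarities, input-listener-hidden similarities,
# eval token accuracies, eval sequence accuracies. The hypothetical skeleton
# below only documents that assumed contract; the actual three-phase
# iterated-learning logic is not reproduced here.
def train_generation(model, train_set, dev_set, learn_set,
                     sim_chk_inset, sim_chk_batchset, generation_idx=0):
    losses, tok_accs, seq_accs = [], [], []
    in_spkh_sims, in_msg_sims, in_lish_sims = [], [], []
    eval_tok_accs, eval_seq_accs = [], []
    # ... (learning phase on learn_set, interaction phase on train_set,
    #      periodic evaluation on dev_set and similarity checks would go here) ...
    return (losses, tok_accs, seq_accs,
            in_spkh_sims, in_msg_sims, in_lish_sims,
            eval_tok_accs, eval_seq_accs)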
def train():
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    # test_set = FruitSeqDataset(voc, dataset_file_path=TEST_FILE_PATH)
    print('done')

    if args.param_file is not None:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']
        voc = checkpoint['voc']
        print('done')
        print('arguments for training:')
        print(train_args)
        print('rebuilding model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        model_optimiser = train_args.optimiser(model.parameters(), lr=train_args.learning_rate)
        speaker_optimiser = train_args.optimiser(model.speaker.parameters(),
                                                 lr=train_args.learning_rate * train_args.speaker_ratio)
        listener_optimiser = train_args.optimiser(model.listener.parameters(),
                                                  lr=train_args.learning_rate * train_args.speaker_ratio)
        print('\tdone')
    else:
        print('building model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model_optimiser = args.optimiser(model.parameters(), lr=args.learning_rate)
        speaker_optimiser = args.optimiser(model.speaker.decoder.parameters(),
                                           lr=args.learning_rate * args.speaker_ratio)
        listener_optimiser = args.optimiser(model.listener.parameters(),
                                            lr=args.learning_rate * args.listener_ratio)
        print('done')

    print('preparing data for testing topological similarity...')
    sim_chk_inset, sim_chk_batchset = get_batches4sim_check(voc, args.data_file)
    print('done')

    print('initialising...')
    start_iteration = 1
    print_loss = 0.
    print_seq_acc = 0.
    print_tok_acc = 0.
    max_dev_seq_acc = 0.
    max_dev_tok_acc = 0.
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_in_spkh_sim = []
    training_in_msg_sim = []
    training_in_lish_sim = []
    training_spkh_lish_sim = []
    training_mi = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim = sim_check(model, sim_chk_inset, sim_chk_batchset)
    mi_sim = mi_check(model, sim_chk_batchset)
    print('[SIM]Iteration: {}; In-SpkH Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisH Sim: {:.4f}; SpkH-LisH Sim: {:.4f}; In-Msg-MI: {:.4f}'.format(
        0, in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim, mi_sim))

    print('training...')
    for iter in range(start_iteration, args.iter_num + 1):
        for idx, data_batch in enumerate(train_set):
            # Schedule the message-channel temperature on recent dev sequence accuracy.
            if len(eval_seq_acc) > 10:
                tau = tau_scheduler(sum(eval_seq_acc[-10:]) / 10.)
            else:
                tau = tau_scheduler(0.)
            seq_acc, tok_acc, loss = train_epoch(model, data_batch, tau,
                                                 model_optimiser, speaker_optimiser, listener_optimiser)
            print_loss += loss
            print_seq_acc += seq_acc
            print_tok_acc += tok_acc

        if iter % args.print_freq == 0:
            print_loss_avg = print_loss / (args.print_freq * len(train_set))
            print_seq_acc_avg = print_seq_acc / (args.print_freq * len(train_set))
            print_tok_acc_avg = print_tok_acc / (args.print_freq * len(train_set))
            print("Iteration: {}; Percent complete: {:.1f}%; Avg loss: {:.4f}; Avg seq acc: {:.4f}; Avg tok acc: {:.4f}".format(
                iter, iter / args.iter_num * 100, print_loss_avg,
                print_seq_acc_avg, print_tok_acc_avg))
            training_seq_acc.append(print_seq_acc_avg)
            training_tok_acc.append(print_tok_acc_avg)
            training_losses.append(print_loss_avg)
            print_seq_acc = 0.
            print_tok_acc = 0.
            print_loss = 0.
        if iter % args.eval_freq == 0:
            dev_seq_acc, dev_tok_acc, dev_loss = eval_model(model, dev_set)
            if dev_seq_acc > max_dev_seq_acc:
                max_dev_seq_acc = dev_seq_acc
            if dev_tok_acc > max_dev_tok_acc:
                max_dev_tok_acc = dev_tok_acc
            eval_seq_acc.append(dev_seq_acc)
            eval_tok_acc.append(dev_tok_acc)
            print("[EVAL]Iteration: {}; Loss: {:.4f}; Avg Seq Acc: {:.4f}; Avg Tok Acc: {:.4f}; Best Seq Acc: {:.4f}".format(
                iter, dev_loss, dev_seq_acc, dev_tok_acc, max_dev_seq_acc))

        if iter % args.sim_chk_freq == 0:
            in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim = sim_check(model, sim_chk_inset, sim_chk_batchset)
            mi_sim = mi_check(model, sim_chk_batchset)
            training_in_spkh_sim.append(in_spk_sim)
            training_in_msg_sim.append(in_msg_sim)
            training_in_lish_sim.append(in_lis_sim)
            training_spkh_lish_sim.append(spk_lis_sim)
            training_mi.append(mi_sim)
            print('[SIM]Iteration: {}; In-SpkH Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisH Sim: {:.4f}; SpkH-LisH Sim: {:.4f}; In-Msg-MI: {:.4f}'.format(
                iter, in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim, mi_sim))

        if args.l_reset_freq != -1 and iter % args.l_reset_freq == 0:
            model.listener.reset_params()
            print('[RESET] reset listener')

        if iter % args.save_freq == 0:
            path_join = 'set2seq2seq_' + str(args.num_words) + '_' + args.msg_mode
            path_join += '_hard' if not args.soft else '_soft'
            directory = os.path.join(args.save_dir, path_join)
            if not os.path.exists(directory):
                os.makedirs(directory)
            torch.save({
                'iteration': iter,
                'model': model.state_dict(),
                'opt': [
                    model_optimiser.state_dict(),
                    speaker_optimiser.state_dict(),
                    listener_optimiser.state_dict()
                ],
                'loss': loss,
                'voc': voc,
                'args': args,
                'records': {
                    'training_loss': training_losses,
                    'training_tok_acc': training_tok_acc,
                    'training_seq_acc': training_seq_acc,
                    'training_in_spkh_sim': training_in_spkh_sim,
                    'training_in_msg_sim': training_in_msg_sim,
                    'training_in_lish_sim': training_in_lish_sim,
                    'training_spkh_lish_sim': training_spkh_lish_sim,
                    'training_mi': training_mi,
                    'eval_tok_acc': eval_tok_acc,
                    'eval_seq_acc': eval_seq_acc
                }
            }, os.path.join(directory, '{}_{:.4f}_{}.tar'.format(iter, dev_seq_acc, 'checkpoint')))
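# tau_scheduler() used in the training loop above is not shown in this section.
# A minimal hypothetical sketch: anneal the message-channel temperature
# (presumably the Gumbel-softmax tau, given the msg_mode and soft/hard options
# above) from a high value (more exploration) towards a low value (near-discrete
# messages) as the recent dev sequence accuracy improves. The bounds tau_max /
# tau_min and the linear schedule are illustrative assumptions, not the
# repository's actual schedule.
def tau_scheduler(recent_dev_seq_acc, tau_max=2.0, tau_min=0.5):
    acc = max(0.0, min(1.0, recent_dev_seq_acc))
    return tau_max - (tau_max - tau_min) * acc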
def main(
    model_name='Img2Seq2Choice',
    dataset_name='ImgChooseDataset',
    out_file_path='data/tmp.txt',
):
    if args.param_file is not None:
        checkpoint = torch.load(args.param_file, map_location=torch.device('cpu'))
    else:
        raise ValueError

    print('rebuilding vocabulary and model...')
    voc = checkpoint['voc'] if model_name in ('Set2Seq2Seq', 'Set2Seq2Choice') else None
    train_args = checkpoint['args']
    print(train_args)
    if model_name == 'Img2Seq2Choice':
        model = Img2Seq2Choice(msg_length=train_args.max_msg_len,
                               msg_vocsize=train_args.msg_vocsize,
                               hidden_size=train_args.hidden_size,
                               dropout=train_args.dropout_ratio,
                               msg_mode=train_args.msg_mode).to(torch.device('cpu'))
    elif model_name == 'Set2Seq2Seq':
        model = Set2Seq2Seq(voc.num_words,
                            msg_length=train_args.max_msg_len,
                            msg_vocsize=train_args.msg_vocsize,
                            hidden_size=train_args.hidden_size,
                            dropout=train_args.dropout_ratio,
                            msg_mode=train_args.msg_mode).to(torch.device('cpu'))
    elif model_name == 'Set2Seq2Choice':
        model = Set2Seq2Choice(voc.num_words,
                               msg_length=train_args.max_msg_len,
                               msg_vocsize=train_args.msg_vocsize,
                               hidden_size=train_args.hidden_size,
                               dropout=train_args.dropout_ratio,
                               msg_mode=train_args.msg_mode).to(torch.device('cpu'))
    else:
        raise NotImplementedError
    model.load_state_dict(checkpoint['model'])
    model.eval()
    print('done')

    print('loading and building batch dataset...')
    if dataset_name == 'ImgChooseDataset':
        batch_set = ImgChooseDataset(dataset_dir_path=args.data_file, batch_size=1,
                                     device=torch.device('cpu'))
        in_set = [batch['correct']['label'][0] for batch in batch_set]
    elif dataset_name == 'FruitSeqDataset':
        batch_set = FruitSeqDataset(voc, dataset_file_path=args.data_file, batch_size=1,
                                    device=torch.device('cpu'))
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    elif dataset_name == 'ChooseDataset':
        batch_set = ChooseDataset(voc, dataset_file_path=args.data_file, batch_size=1,
                                  device=torch.device('cpu'))
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    print('done')

    build_listener_training_file(model, in_set, batch_set, out_file_path)
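# build_listener_training_file() is not defined in this section. As a hedged
# sketch only: it presumably pairs each input label/string with the message the
# frozen speaker emits for it, writing one "<label> <tab> <message>" line per
# sample so a fresh listener can later be trained on the emergent language. The
# model.speaker call signature and the batch keys below are assumptions carried
# over from the other functions in this section, not a confirmed interface.
import torch

def build_listener_training_file(model, in_set, batch_set, out_file_path):
    with open(out_file_path, mode='w') as out_file, torch.no_grad():
        for label, data_batch in zip(in_set, batch_set):
            message, _, _ = model.speaker(data_batch['input'], data_batch['input_mask'])  # assumed signature
            msg_tokens = message.argmax(dim=-1).squeeze().tolist()  # adjust if messages are already index-valued
            out_file.write('{}\t{}\n'.format(label, ' '.join(str(t) for t in msg_tokens)))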