def main(params):
    # Load the checkpoint and recover the vocabulary and author index from it
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    cp_params = saved_model['arch']
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Set to eval mode, this disables dropout
    model.eval()

    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # Tensor holding the start symbol, used to seed the generator
    append_tensor = np.zeros((1, 1), dtype=np.int)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint
    model.load_state_dict(saved_model['state_dict'])
    hidden = model.init_hidden(1)
    # Join characters directly, words with spaces
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    for i in xrange(params['num_samples']):
        c_aid = np.random.choice(auth_to_ix.values())
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(batch, char_to_ix, auth_to_ix,
                                                   maxlen=cp_params['max_seq_len'])
        # Optionally flip the target author to translate to the other style
        auths_inp = 1 - auths if params['flip'] else auths
        outs = adv_forward_pass(model, inps, lens, end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'], auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor)
        # char_outs = model.forward_gen(inps, hidden, auths_inp,
        #                               n_max=cp_params['max_len'], end_c=char_to_ix['.'])
        print '--------------------------------------------'
        # print 'Translate from %s to %s' % (batch[0]['author'], ix_to_auth[auths_inp[0]])
        print colored('Inp %6s: ' % ix_to_auth[auths[0]], 'green') + \
            colored(jc.join([ix_to_char[c[0]] for c in inps[1:]]),
                    auth_colors[auths[0]])
        print colored('Out %6s: ' % ix_to_auth[auths_inp[0]], 'grey') + \
            colored(jc.join([ix_to_char[c.data.cpu()[0]] for c in outs[0]
                             if c.data.cpu()[0] in ix_to_char]),
                    auth_colors[auths_inp[0]])
        if params['show_rev']:
            print colored('Rev %6s: ' % ix_to_auth[auths[0]], 'green') + \
                colored(jc.join([ix_to_char[c.data.cpu()[0]] for c in outs[-1]
                                 if c.data.cpu()[0] in ix_to_char]),
                        auth_colors[auths[0]])
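
# --- Illustrative driver (not from the original script) ---
# A minimal sketch of the command-line entry point the sampling main() above
# expects, assuming argparse. The flag names mirror the params keys read in
# the function; the defaults and help strings here are guesses.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True,
                        help='checkpoint to sample from')
    parser.add_argument('--m_type', type=str, default='translator',
                        help='"generative" uses CharLstm, anything else CharTranslator')
    parser.add_argument('--split', type=str, default='val')
    parser.add_argument('--num_samples', type=int, default=10)
    parser.add_argument('--seed_length', type=int, default=100)
    parser.add_argument('--softmax_scale', type=float, default=None)
    parser.add_argument('--flip', type=int, default=1,
                        help='translate to the other author')
    parser.add_argument('--show_rev', type=int, default=0,
                        help='also show the cycled reconstruction')
    main(vars(parser.parse_args()))
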
def main(params):
    # Load the generator checkpoint and recover the vocabulary and author index
    saved_model = torch.load(params['genmodel'])
    cp_params = saved_model['arch']
    if params['evalmodel']:
        eval_model = torch.load(params['evalmodel'])
        eval_params = eval_model['arch']
        eval_state = eval_model['state_dict']
    else:
        print 'Error: an evaluator checkpoint is required (pass it via evalmodel)'
        return

    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        if 'ix_to_auth' in saved_model:
            ix_to_auth = saved_model['ix_to_auth']
        else:
            ix_to_auth = {auth_to_ix[a]: a for a in auth_to_ix}

    dp = DataProvider(cp_params)
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    modelGen = CharTranslator(cp_params)
    modelEval = CharLstm(eval_params)

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # Set to eval mode, this disables dropout
    modelGen.eval()
    modelEval.eval()

    # Restore saved checkpoints
    modelGen.load_state_dict(saved_model['state_dict'])
    state = modelEval.state_dict()
    state.update(eval_state)
    modelEval.load_state_dict(state)

    # Tensor holding the start symbol, used to seed the generator
    append_tensor = np.zeros((1, 1), dtype=np.int)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Accumulators for discriminator fooling rates and precision/recall stats,
    # indexed by target author
    accum_diff_eval = [[], []]
    accum_err_eval = np.zeros(len(auth_to_ix))
    accum_err_real = np.zeros(len(auth_to_ix))
    accum_count_gen = np.zeros(len(auth_to_ix))
    accum_recall_forward = np.zeros(len(auth_to_ix))
    accum_prec_forward = np.zeros(len(auth_to_ix))
    accum_recall_rev = np.zeros(len(auth_to_ix))
    accum_prec_rev = np.zeros(len(auth_to_ix))

    # Join characters directly, words with spaces
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    result = {'docs': [],
              'misc': {'auth_to_ix': auth_to_ix, 'ix_to_auth': ix_to_auth},
              'cp_params': cp_params, 'params': params}
    id_to_ix = {}
    for i, iid in enumerate(dp.splits[params['split']]):
        result['docs'].append({'sents': [],
                               'author': dp.data['docs'][iid][dp.athstr],
                               'id': iid})
        if 'attrib' in dp.data['docs'][iid]:
            result['docs'][-1]['attrib'] = dp.data['docs'][iid]['attrib']
        id_to_ix[iid] = i

    n_samp = params['n_samples']
    for i, b_data in tqdm(enumerate(dp.iter_sentences_bylen(
            split=params['split'], atoms=cp_params.get('atoms', 'word'),
            batch_size=params['batch_size'], auths=auth_to_ix.keys()))):
        if i > params['num_batches'] and params['num_batches'] > 0:
            break
        c_bsz = len(b_data[0])
        done = b_data[1]
        inps, targs, auths, lens = dp.prepare_data(b_data[0], char_to_ix, auth_to_ix,
                                                   maxlen=cp_params['max_seq_len'])
        # Optionally flip the target author to translate to the other style
        auths_inp = 1 - auths if params['flip'] else auths
        # outs are organized as (inferred from the usage below):
        # evaluator scores on the samples, generated characters, generated
        # lengths, ..., reverse-cycle characters, reverse-cycle lengths
        outs = adv_forward_pass(modelGen, modelEval, inps, lens,
                                end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'], auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor,
                                n_samples=params['n_samples'])
        # Evaluator scores on the ground-truth input text
        eval_out_gt = modelEval.forward_classify(targs, lens=lens,
                                                 compute_softmax=True)
        auths_inp = auths_inp.numpy()
        i_bsz = np.arange(c_bsz)
        real_aid_out = eval_out_gt[0].data.cpu().numpy()[i_bsz, auths_inp]
        gen_scores = outs[0].view(n_samp, c_bsz, -1)
        gen_aid_out = gen_scores.cpu().numpy()[:, i_bsz, auths_inp]
        gen_char = [v.view(n_samp, c_bsz) for v in outs[1]]
        gen_lens = outs[2].view(n_samp, c_bsz)
        np.add.at(accum_err_eval, auths_inp, gen_aid_out[0, :] >= 0.5)
        np.add.at(accum_err_real, auths_inp, real_aid_out >= 0.5)
        np.add.at(accum_count_gen, auths_inp, 1)

        for b in xrange(inps.size()[1]):
            inpset = set(inps[:, b].tolist()[:lens[b]])
            samples = []
            accum_diff_eval[auths_inp[b]].append(gen_aid_out[0, b] - real_aid_out[b])
            for si in xrange(n_samp):
                # Unigram overlap between the generated and input token sets
                genset = set([c[si, b] for c in gen_char[:gen_lens[si, b]]])
                accum_recall_forward[auths_inp[b]] += float(len(genset & inpset)) / float(len(inpset))
                accum_prec_forward[auths_inp[b]] += float(len(genset & inpset)) / float(len(genset))
                if params['show_rev']:
                    revgenset = set([c[b] for c in outs[-2][:outs[-1][b]]])
                    accum_recall_rev[auths_inp[b]] += float(len(revgenset & inpset)) / float(len(inpset))
                    accum_prec_rev[auths_inp[b]] += float(len(revgenset & inpset)) / float(len(revgenset))
                inp_text = jc.join([ix_to_char[c] for c in targs[:, b]
                                    if c in ix_to_char])
                trans_text = jc.join([ix_to_char[c.cpu()[si, b]]
                                      for c in gen_char[:gen_lens[si, b]]
                                      if c.cpu()[si, b] in ix_to_char])
                samples.append({'sent': inp_text,
                                'score': eval_out_gt[0][b].data.cpu().tolist(),
                                'trans': trans_text,
                                'trans_score': gen_scores[si, b].cpu().tolist(),
                                'sid': b_data[0][b]['sid']})
            result['docs'][id_to_ix[b_data[0][b]['id']]]['sents'].append(samples)

        if params['print']:
            print '--------------------------------------------'
            print 'Author: %s' % (b_data[0][0]['author'])
            print 'Inp text %s: %s (%.2f)' % (
                ix_to_auth[auths[0]],
                jc.join([ix_to_char[c[0]] for c in inps[1:]]), real_aid_out[0])
            print 'Out text %s: %s (%.2f)' % (
                ix_to_auth[auths_inp[0]],
                jc.join([ix_to_char[c.cpu()[0]] for c in outs[1]
                         if c.cpu()[0] in ix_to_char]), gen_aid_out[0, 0])
            if params['show_rev']:
                print 'Rev text %s: %s' % (
                    ix_to_auth[auths[0]],
                    jc.join([ix_to_char[c.cpu()[0]] for c in outs[-2]
                             if c.cpu()[0] in ix_to_char]))

    err_a1 = accum_err_eval[0] / (1e-5 + accum_count_gen[0])
    err_a2 = accum_err_eval[1] / (1e-5 + accum_count_gen[1])
    err_real_a1 = accum_err_real[0] / (1e-5 + accum_count_gen[0])
    err_real_a2 = accum_err_real[1] / (1e-5 + accum_count_gen[1])
    print '--------------------------------------------'
    print 'Efficiency in fooling discriminator'
    print '--------------------------------------------'
    print(' erra1 {:3.2f} - erra2 {:3.2f}'.format(100. * err_a1, 100. * err_a2))
    print(' err_real_a1 {:3.2f} - err_real_a2 {:3.2f}'.format(
        100. * err_real_a1, 100. * err_real_a2))
    print(' count %d - %d' % (accum_count_gen[0], accum_count_gen[1]))

    diff_arr0, diff_arr1 = np.array(accum_diff_eval[0]), np.array(accum_diff_eval[1])
    print 'Mean difference : translation to %s = %.2f , translation to %s = %.2f' % (
        ix_to_auth[0], diff_arr0.mean(), ix_to_auth[1], diff_arr1.mean())
    print 'Difference > 0 : translation to %s = %.2f%%, translation to %s = %.2f%%' % (
        ix_to_auth[0], 100. * (diff_arr0 > 0).sum() / (1e-5 + diff_arr0.shape[0]),
        ix_to_auth[1], 100. * (diff_arr1 > 0).sum() / (1e-5 + diff_arr1.shape[0]))
    print 'Difference < 0 : translation to %s = %.2f%%, translation to %s = %.2f%%' % (
        ix_to_auth[0], 100. * (diff_arr0 < 0).sum() / (1e-5 + diff_arr0.shape[0]),
        ix_to_auth[1], 100. * (diff_arr1 < 0).sum() / (1e-5 + diff_arr1.shape[0]))

    print '\n--------------------------------------------'
    print 'Consistency with the input text'
    print '--------------------------------------------'
    print 'Generated text A0 - Precision = %.2f, Recall = %.2f' % (
        accum_prec_forward[0] / accum_count_gen[0],
        accum_recall_forward[0] / accum_count_gen[0])
    print 'Generated text A1 - Precision = %.2f, Recall = %.2f' % (
        accum_prec_forward[1] / accum_count_gen[1],
        accum_recall_forward[1] / accum_count_gen[1])
    if params['show_rev']:
        print '\n'
        print 'Reconstr text A0 - Precision = %.2f, Recall = %.2f' % (
            accum_prec_rev[0] / accum_count_gen[0],
            accum_recall_rev[0] / accum_count_gen[0])
        print 'Reconstr text A1 - Precision = %.2f, Recall = %.2f' % (
            accum_prec_rev[1] / accum_count_gen[1],
            accum_recall_rev[1] / accum_count_gen[1])

    print '\n--------------------------------------------'
    print 'Document Level Scores'
    print '--------------------------------------------'
    doc_accuracy = np.zeros(len(auth_to_ix))
    doc_accuracy_trans = np.zeros(len(auth_to_ix))
    doc_count = np.zeros(len(auth_to_ix))
    for doc in result['docs']:
        # Sum per-sentence log-probabilities to score the whole document
        doc_score_orig = np.array([0., 0.])
        doc_score_trans = np.array([0., 0.])
        for st in doc['sents']:
            doc_score_orig += np.log(st[0]['score'])
            doc_score_trans += np.log(st[0]['trans_score'])
        doc_accuracy[auth_to_ix[doc['author']]] += float(
            doc_score_orig.argmax() == auth_to_ix[doc['author']])
        doc_accuracy_trans[auth_to_ix[doc['author']]] += float(
            doc_score_trans.argmax() == auth_to_ix[doc['author']])
        doc_count[auth_to_ix[doc['author']]] += 1.

    print 'Original data'
    print '-------------'
    print 'Doc accuracy is %s : %.2f , %s : %.2f' % (
        ix_to_auth[0], doc_accuracy[0] / doc_count[0],
        ix_to_auth[1], doc_accuracy[1] / doc_count[1])
    # Treat author 0 as the positive class for precision/recall
    fp = doc_count[1] - doc_accuracy[1]
    recall = doc_accuracy[0] / doc_count[0]
    precision = doc_accuracy[0] / (doc_accuracy[0] + fp)
    f1score = 2. * (precision * recall) / (precision + recall)
    print 'Precision is %.2f : Recall is %.2f , F1-score is %.2f' % (
        precision, recall, f1score)

    print '\nTranslated data'
    print '-----------------'
    print 'Doc accuracy is %s : %.2f , %s : %.2f' % (
        ix_to_auth[0], doc_accuracy_trans[0] / doc_count[0],
        ix_to_auth[1], doc_accuracy_trans[1] / doc_count[1])
    fp = doc_count[1] - doc_accuracy_trans[1]
    recall = doc_accuracy_trans[0] / doc_count[0]
    precision = doc_accuracy_trans[0] / (doc_accuracy_trans[0] + fp)
    f1score = 2. * (precision * recall) / (precision + recall)
    print 'Precision is %.2f : Recall is %.2f , F1-score is %.2f' % (
        precision, recall, f1score)

    if params['dumpjson']:
        json.dump(result, open(params['dumpjson'], 'w'))
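
# --- Illustrative helper (not from the original script) ---
# The Document Level Scores block above aggregates sentence-level evaluator
# probabilities into a document prediction by summing log-probabilities and
# taking the argmax. A standalone sketch of that aggregation, assuming each
# entry of a doc's 'sents' list is structured as in result['docs'] above:
def doc_author_prediction(sents):
    # sents[k][0]['score'] is [P(author 0), P(author 1)] for sentence k
    doc_score = np.array([0., 0.])
    for st in sents:
        # Sum of logs = naive independence assumption across sentences
        doc_score += np.log(st[0]['score'])
    return doc_score.argmax()
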
def main(params):
    dp = DataProvider(params)

    # Create vocabulary and author index
    if params['resume'] == None:
        if params['atoms'] == 'char':
            char_to_ix, ix_to_char = dp.createCharVocab(params['vocab_threshold'])
        else:
            char_to_ix, ix_to_char = dp.createWordVocab(params['vocab_threshold'])
        auth_to_ix, ix_to_auth = dp.createAuthorIdx()
    else:
        saved_model = torch.load(params['resume'])
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_auth = saved_model['ix_to_auth']
        ix_to_char = saved_model['ix_to_char']

    params['vocabulary_size'] = len(char_to_ix)
    params['num_output_layers'] = len(auth_to_ix)

    model = CharTranslator(params)
    # Set to train mode, this activates dropout
    model.train()

    # Initialize the optimizer
    if params['use_sgd']:
        optim = torch.optim.SGD(model.parameters(), lr=params['learning_rate'],
                                momentum=params['decay_rate'])
    else:
        optim = torch.optim.RMSprop(model.parameters(), lr=params['learning_rate'],
                                    alpha=params['decay_rate'],
                                    eps=params['smooth_eps'])

    # Loss function
    if params['mode'] == 'generative':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.NLLLoss()

    # Restore saved checkpoint
    if params['resume'] != None:
        model.load_state_dict(saved_model['state_dict'])
        optim.load_state_dict(saved_model['optimizer'])

    total_loss = 0.
    start_time = time.time()
    hidden = model.init_hidden(params['batch_size'])
    hidden_zeros = model.init_hidden(params['batch_size'])

    # Initialize the hidden-state cache
    if params['randomize_batches']:
        dp.set_hid_cache(range(len(dp.data['docs'])), hidden_zeros)

    # Compute the iteration parameters
    epochs = params['max_epochs']
    total_seqs = dp.get_num_sents(split='train')
    iter_per_epoch = total_seqs // params['batch_size']
    total_iters = iter_per_epoch * epochs
    best_loss = 1000000.
    best_val = 1000.
    eval_every = int(iter_per_epoch * params['eval_interval'])

    val_score = 0.
    val_rank = 1000
    eval_function = eval_translator if params['mode'] == 'generative' else eval_classify
    leakage = 0.  # params['leakage']

    print 'Total iterations: %d' % total_iters
    for i in xrange(total_iters):
        if params['split_generators']:
            c_aid = ix_to_auth[np.random.choice(auth_to_ix.values())]
        else:
            c_aid = None
        batch = dp.get_sentence_batch(params['batch_size'], split='train',
                                      atoms=params['atoms'], aid=c_aid,
                                      sample_by_len=params['sample_by_len'])
        inps, targs, auths, lens = dp.prepare_data(batch, char_to_ix, auth_to_ix,
                                                   maxlen=params['max_seq_len'])

        # Starting each batch, we detach the hidden state from how it was
        # previously produced. If we didn't, the model would try
        # backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optim.zero_grad()

        if params['mode'] == 'generative':
            output, _ = model.forward_mltrain(inps, lens, inps, lens,
                                              hidden_zeros, auths=auths)
            targets = pack_padded_sequence(Variable(targs).cuda(), lens)
            loss = criterion(pack_padded_sequence(output, lens)[0], targets[0])
        else:
            # For the classifier, auths is the target
            output, hidden = model.forward_classify(inps, hidden,
                                                    compute_softmax=True)
            targets = Variable(auths).cuda()
            loss = criterion(output, targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), params['grad_clip'])

        # Take an optimization step
        optim.step()
        total_loss += loss.data.cpu().numpy()[0]

        if i % eval_every == 0 and i > 0:
            val_rank, val_score = eval_function(dp, model, params, char_to_ix,
                                                auth_to_ix, split='val')

        # if i % iter_per_epoch == 0 and i > 0 and leakage > params['leakage_min']:
        #     leakage = leakage * params['leakage_decay']

        if i % params['log_interval'] == 0 and i > 0:
            cur_loss = total_loss / params['log_interval']
            elapsed = time.time() - start_time
            print('| epoch {:2.2f} | {:5d}/{:5d} batches | lr {:02.2e} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                      float(i) / iter_per_epoch, i, total_iters,
                      params['learning_rate'],
                      elapsed * 1000 / params['log_interval'], cur_loss,
                      math.exp(cur_loss)))
            total_loss = 0.

            if val_rank <= best_val:
                save_checkpoint({
                    'iter': i,
                    'arch': params,
                    'val_loss': val_rank,
                    'val_pplx': val_score,
                    'char_to_ix': char_to_ix,
                    'ix_to_char': ix_to_char,
                    'auth_to_ix': auth_to_ix,
                    'ix_to_auth': ix_to_auth,
                    'state_dict': model.state_dict(),
                    'loss': cur_loss,
                    'optimizer': optim.state_dict(),
                }, fappend=params['fappend'],
                    outdir=params['checkpoint_output_directory'])
                best_val = val_rank
            start_time = time.time()
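
# --- Illustrative helper (not from the original script) ---
# repackage_hidden() used in the training loop above is imported from
# elsewhere in this repo. A minimal sketch of what it typically does in
# Variable-era PyTorch (the standard word_language_model recipe): re-wrap the
# hidden-state tensors in fresh Variables so the graph built by previous
# batches can be freed instead of being backpropagated through.
def repackage_hidden_sketch(h):
    # Variable is torch.autograd.Variable, as used throughout this file
    if type(h) == Variable:
        return Variable(h.data)
    else:
        return tuple(repackage_hidden_sketch(v) for v in h)
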
def main(params):
    dp = DataProvider(params)

    # Create vocabulary and author index
    if params['resume'] == None:
        if params['atoms'] == 'char':
            char_to_ix, ix_to_char = dp.create_char_vocab(params['vocab_threshold'])
        else:
            char_to_ix, ix_to_char = dp.create_word_vocab(params['vocab_threshold'])
        auth_to_ix, ix_to_auth = dp.create_author_idx()
    else:
        saved_model = torch.load(params['resume'])
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']

    params['vocabulary_size'] = len(char_to_ix)
    params['num_output_layers'] = len(auth_to_ix)
    print params['vocabulary_size'], params['num_output_layers']

    model = get_classifier(params)
    # Set to train mode, this activates dropout
    model.train()

    # Initialize the optimizer
    if params['use_sgd']:
        optim = torch.optim.SGD(model.parameters(), lr=params['learning_rate'],
                                momentum=params['decay_rate'])
    else:
        # Separate parameter group so decoder_W gets no weight decay
        optim = torch.optim.RMSprop(
            [{'params': [p[1] for p in model.named_parameters()
                         if p[0] != 'decoder_W']},
             {'params': model.decoder_W, 'weight_decay': 0.000}],
            lr=params['learning_rate'], alpha=params['decay_rate'],
            eps=params['smooth_eps'])

    # Loss function, optionally with per-class weights to balance the classes
    if len(params['balance_loss']) == 0:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.CrossEntropyLoss(
            torch.FloatTensor(params['balance_loss']).cuda())

    # Restore saved checkpoint
    if params['resume'] != None:
        model.load_state_dict(saved_model['state_dict'])
        # optim.load_state_dict(saved_model['optimizer'])

    total_loss = 0.
    class_loss = 0.
    start_time = time.time()
    hidden = model.init_hidden(params['batch_size'])
    hidden_zeros = model.init_hidden(params['batch_size'])

    # Initialize the hidden-state cache
    if params['randomize_batches']:
        dp.set_hid_cache(range(len(dp.data['docs'])), hidden_zeros)

    # Compute the iteration parameters
    epochs = params['max_epochs']
    total_seqs = dp.get_num_sents(split='train')
    iter_per_epoch = total_seqs // params['batch_size']
    total_iters = iter_per_epoch * epochs
    best_loss = 0.
    best_val = 1000.
    eval_every = int(iter_per_epoch * params['eval_interval'])

    val_score = 0.
    val_rank = 0
    eval_function = eval_model if params['mode'] == 'generative' else eval_classify
    leakage = params['leakage']

    for i in xrange(total_iters):
        if params['randomize_batches']:
            batch, reset_next = dp.get_rand_doc_batch(params['batch_size'],
                                                      split='train')
            b_ids = [b['id'] for b in batch]
            hidden = dp.get_hid_cache(b_ids, hidden)
        elif params['use_sentences']:
            c_aid = None
            batch = dp.get_sentence_batch(params['batch_size'], split='train',
                                          aid=c_aid, atoms=params['atoms'],
                                          sample_by_len=params['sample_by_len'])
            hidden = hidden_zeros
        else:
            batch, reset_h = dp.get_doc_batch(split='train')
            if len(reset_h) > 0:
                # Zero the hidden state for streams that started a new document
                hidden[0].data.index_fill_(1, torch.LongTensor(reset_h).cuda(), 0.)
                hidden[1].data.index_fill_(1, torch.LongTensor(reset_h).cuda(), 0.)

        inps, targs, auths, lens = dp.prepare_data(batch, char_to_ix, auth_to_ix,
                                                   leakage=leakage)

        # Starting each batch, we detach the hidden state from how it was
        # previously produced. If we didn't, the model would try
        # backpropagating all the way to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optim.zero_grad()

        if params['mode'] == 'generative':
            output, hidden = model.forward(inps, lens, hidden, auths)
            targets = pack_padded_sequence(Variable(targs).cuda(), lens)
            loss = criterion(pack_padded_sequence(output, lens)[0], targets[0])
            # No separate classification term in generative mode
            lossClass = loss
        else:
            # For the classifier, auths is the target
            output, _ = model.forward_classify(targs, hidden,
                                               compute_softmax=False, lens=lens)
            targets = Variable(auths).cuda()
            lossClass = criterion(output, targets)
            if params['compression_layer']:
                # L1 penalty on the compression layer rows encourages sparsity
                loss = lossClass + (model.compression_W.weight.norm(p=1, dim=1)).mean()
            else:
                loss = lossClass
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm(model.parameters(), params['grad_clip'])

        # Take an optimization step
        optim.step()

        total_loss += loss.data.cpu().numpy()[0]
        class_loss += lossClass.data.cpu().numpy()[0]

        # Save the hidden states in the cache for later use
        if params['randomize_batches']:
            if len(reset_next) > 0:
                hidden[0].data.index_fill_(1, torch.LongTensor(reset_next).cuda(), 0.)
                hidden[1].data.index_fill_(1, torch.LongTensor(reset_next).cuda(), 0.)
            dp.set_hid_cache(b_ids, hidden)

        if i % eval_every == 0 and i > 0:
            val_rank, val_score = eval_function(dp, model, params, char_to_ix,
                                                auth_to_ix, split='val',
                                                max_docs=params['num_eval'])

        if i % iter_per_epoch == 0 and i > 0 and leakage > params['leakage_min']:
            leakage = leakage * params['leakage_decay']

        # if (i % iter_per_epoch == 0) and ((i // iter_per_epoch) >= params['lr_decay_st']):

        if i % params['log_interval'] == 0 and i > 0:
            cur_loss = total_loss / params['log_interval']
            class_loss = class_loss / params['log_interval']
            elapsed = time.time() - start_time
            print('| epoch {:3.2f} | {:5d}/{:5d} batches | lr {:02.2e} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                      float(i) / iter_per_epoch, i, total_iters,
                      params['learning_rate'],
                      elapsed * 1000 / params['log_interval'], cur_loss,
                      math.exp(class_loss)))

            if val_rank >= best_loss:
                best_loss = val_rank
                save_checkpoint({
                    'iter': i,
                    'arch': params,
                    'val_mean_rank': val_rank,
                    'val_auc': val_score,
                    'char_to_ix': char_to_ix,
                    'ix_to_char': ix_to_char,
                    'auth_to_ix': auth_to_ix,
                    'state_dict': model.state_dict(),
                    'loss': cur_loss,
                    'optimizer': optim.state_dict(),
                }, fappend=params['fappend'],
                    outdir=params['checkpoint_output_directory'])
                best_val = val_rank
            start_time = time.time()
            total_loss = 0.
            class_loss = 0.
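
# --- Illustrative helper (not from the original script) ---
# save_checkpoint() called by both training loops is defined elsewhere in this
# repo. A plausible minimal version, assuming it simply torch.save()s the
# state dict under a name derived from fappend; the actual filename scheme and
# any extra bookkeeping may differ.
import os

def save_checkpoint_sketch(state, fappend='dummy', outdir='cv'):
    # e.g. cv/checkpoint_<fappend>.pth.tar (hypothetical naming)
    filename = os.path.join(outdir, 'checkpoint_%s.pth.tar' % fappend)
    torch.save(state, filename)
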
def main(params):
    # Load the checkpoint and recover the vocabulary and author index from it
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']

    cp_params = saved_model['arch']
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Set to eval mode, this disables dropout
    model.eval()

    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # Tensor holding the start symbol, used to seed the generator
    append_tensor = np.zeros((1, 1), dtype=np.int)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint
    model.load_state_dict(saved_model['state_dict'])
    hidden = model.init_hidden(1)
    # Join characters directly, words with spaces
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    # --- Text cleanup helpers ---

    # Clears whitespace around a match but retains the matched characters,
    # for use as a re.sub replacement function
    def strip_match(match):
        return match.group(0).strip()

    # Joins together decimals split apart by tokenization, e.g. '3. 14' -> '3.14'
    def fix_decimals(match):
        return re.sub('\s', '', match.group(0))

    # Cleans text by removing unnecessary whitespace and substituting back
    # some symbols the tokenizer escaped
    def clean_text(text):
        text = re.sub('-lrb-', '(', text)
        text = re.sub('-rrb-', ')', text)
        text = re.sub('-lsb-', '[', text)
        text = re.sub('-rsb-', ']', text)
        text = re.sub('-lcb-', '{', text)
        text = re.sub('-rcb-', '}', text)
        text = re.sub('\'\'', '\"', text)
        text = re.sub('\si\s', ' I ', text)
        text = re.sub('^i\s', 'I ', text)
        text = re.sub('\sna\s', 'na ', text)
        text = re.sub('\$\s', strip_match, text)
        text = re.sub('[-#]\s|\s([-.!,\':;?]|n\'t)', strip_match, text)
        text = re.sub('\d+. \d+', fix_decimals, text)
        return text

    for i in range(params['num_samples']):
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        # Optionally flip the target author to translate to the other style
        auths_inp = 1 - auths if params['flip'] else auths
        forward, backward = adv_forward_pass(
            model, inps, lens, end_c=char_to_ix[endc],
            maxlen=cp_params['max_seq_len'], auths=auths_inp,
            cycle_compute=params['show_rev'], append_symb=append_tensor)
        # char_outs = model.forward_gen(inps, hidden, auths_inp,
        #                               n_max=cp_params['max_len'], end_c=char_to_ix['.'])
        print('--------------------------------------------')
        print('Translate from %s to %s' % (batch[0]['author'],
                                           ix_to_auth[auths_inp.item()]))

        # Get the original sentence and clean it up a bit
        input_list = [ix_to_char[c.item()] for c in inps[1:]]
        input_string = clean_text(jc.join(input_list))

        # Get the translated sentence and clean it up a bit
        output_list = [ix_to_char[c.item()] for c in forward
                       if c.item() in ix_to_char]
        if output_list[-1] == 'END':
            output_list = output_list[:-1]
        output_string = clean_text(jc.join(output_list))

        print(colored('Inp %6s: ' % ix_to_auth[auths.item()], 'green') +
              colored(input_string, auth_colors[auths.item()]))
        print(colored('Out %6s: ' % ix_to_auth[auths_inp.item()], 'grey') +
              colored(output_string, auth_colors[auths_inp.item()]))
        if params['show_rev']:
            rev_string = jc.join(
                [ix_to_char[c.item()] for c in backward
                 if c.item() in ix_to_char and ix_to_char[c.item()] != 'END'])
            print(colored('Rev %6s: ' % ix_to_auth[auths.item()], 'green') +
                  colored(rev_string, auth_colors[auths.item()]))
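
# --- Illustrative usage (not from the original script) ---
# Roughly what clean_text() does to detokenized model output, traced by hand
# from the regexes above; treat the exact spacing as approximate:
#   clean_text("-lrb- hello -rrb- , i do n't know")
#   -> "( hello ), I don't know"
#   clean_text("it costs $ 3. 50 !")
#   -> "it costs $3.50!"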