def main(params):
    """Evaluate a saved checkpoint (translator or classifier) on a split.

    Loads the vocabulary/author indices and architecture params from the
    checkpoint itself, rebuilds the matching model, restores the weights and
    dispatches to the appropriate eval_* routine.
    """
    # Create vocabulary and author index from the checkpoint contents
    saved_model = torch.load(params['model'])
    char_to_ix = saved_model['char_to_ix']
    auth_to_ix = saved_model['auth_to_ix']
    ix_to_char = saved_model['ix_to_char']
    cp_params = saved_model['arch']
    dp = DataProvider(cp_params)

    if params['m_type'] == 'translator':
        model = CharTranslator(cp_params)
    else:
        model = get_classifier(cp_params)

    # FIX: this is an evaluation script, so dropout must be disabled; the
    # original left model.eval() commented out (with a stale "train mode"
    # note copied from the training script), which keeps dropout active and
    # makes evaluation stochastic. NOTE(review): if stochastic evaluation
    # was deliberate, revert this single line.
    model.eval()

    # Restore saved checkpoint weights
    model.load_state_dict(saved_model['state_dict'])

    # Pick the eval routine: translator > generative LM > classifier
    eval_function = (eval_translator if params['m_type'] == 'translator'
                     else eval_model if cp_params['mode'] == 'generative'
                     else eval_classify)
    score = eval_function(dp, model, cp_params, char_to_ix, auth_to_ix,
                          split=params['split'],
                          max_docs=params['num_eval'],
                          dump_scores=params['dump_scores'])
def main(params):
    """Sample generations/translations from a saved checkpoint and print
    them colorized per author (Python 2 script: print statements, xrange).
    """
    # Create vocabulary and author index. Newer checkpoints nest the lookup
    # tables under 'misc'; older ones keep them at the top level.
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    cp_params = saved_model['arch']
    # Optional override of the checkpoint's softmax temperature at sample time
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Inference mode: disables dropout
    model.eval()

    auth_colors = ['red', 'blue']
    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # 1x1 tensor holding the start symbol, fed to the decoder as the first
    # input token (CUDA required here).
    append_tensor = np.zeros((1, 1), dtype=np.int)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint weights
    model.load_state_dict(saved_model['state_dict'])
    hidden = model.init_hidden(1)

    # Join string for rendering: characters concatenate directly, words
    # need a separating space.
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    for i in xrange(params['num_samples']):
        # Pick a random source author for each sample
        c_aid = np.random.choice(auth_to_ix.values())
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(batch, char_to_ix, auth_to_ix,
                                                   maxlen=cp_params['max_seq_len'])
        # Optionally flip the author id so we translate into the *other* style
        auths_inp = 1 - auths if params['flip'] else auths
        # outs[0]: generated symbols; outs[-1]: reverse/cycle reconstruction
        # when cycle_compute is on — assumed from usage below; TODO confirm
        # against adv_forward_pass.
        outs = adv_forward_pass(model, inps, lens, end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'], auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor)
        #char_outs = model.forward_gen(inps, hidden, auths_inp, n_max = cp_params['max_len'],end_c=char_to_ix['.'])
        print '--------------------------------------------'
        #print 'Translate from %s to %s'%(batch[0]['author'], ix_to_auth[auths_inp[0]])
        # Input sentence, colored by source author (inps[0] is the start
        # symbol, hence the [1:] slice).
        print colored('Inp %6s: '%(ix_to_auth[auths[0]]),'green') + colored('%s'%(jc.join([ix_to_char[c[0]] for c in inps[1:]])),auth_colors[auths[0]])
        # Generated output, colored by target author; symbols outside the
        # vocab (e.g. padding) are skipped.
        print colored('Out %6s: '%(ix_to_auth[auths_inp[0]]),'grey')+ colored('%s'%(jc.join([ix_to_char[c.data.cpu()[0]] for c in outs[0] if c.data.cpu()[0] in ix_to_char])),auth_colors[auths_inp[0]])
        if params['show_rev']:
            # Cycle reconstruction back into the source author's style
            print colored('Rev %6s: '%(ix_to_auth[auths[0]]),'green')+ colored('%s'%(jc.join([ix_to_char[c.data.cpu()[0]] for c in outs[-1] if c.data.cpu()[0] in ix_to_char])),auth_colors[auths[0]])
def main(params):
    """Load a checkpoint and rebuild the matching model with its weights.

    NOTE(review): most of the surrounding setup (vocab lookups, DataProvider,
    eval mode) is commented out; as written the function only restores the
    state dict and then returns — presumably a stripped-down/WIP variant of
    the fuller eval mains in this project.
    """
    # Create vocabulary and author index
    saved_model = torch.load(params['model'])
    # char_to_ix = saved_model['char_to_ix']
    # auth_to_ix = saved_model['auth_to_ix']
    # ix_to_char = saved_model['ix_to_char']
    cp_params = saved_model['arch']
    # dp = DataProvider(cp_params)
    if params['m_type'] == 'translator':
        model = CharTranslator(cp_params)
    else:
        model = get_classifier(cp_params)
    # set to train mode, this activates dropout
    # model.eval()
    # Restore saved checkpoint
    model.load_state_dict(saved_model['state_dict'])
def main(params):
    """Encode every sentence (original and translated) in a results json
    with a sentence encoder, dump the features to an .npy file and write the
    per-sentence feature row indices back into the json.
    """
    saved_model = torch.load(params['checkpoint'])
    cp_params = saved_model['arch']
    dp = DataProvider(cp_params)
    # Newer checkpoints nest the lookup tables under 'misc'
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    # Free the (potentially large) checkpoint before allocating features
    del saved_model

    resf = params['resfile']
    res = json.load(open(resf, 'r'))
    bsz = params['batch_size']

    # FIX: count sentences with an int, not a float — numpy shapes must be
    # integral, and `2 * total_sents` below is used as a shape.
    total_sents = 0
    for doc in res['docs']:
        for st in doc['sents']:
            total_sents += 1

    # First total_sents rows hold original-text encodings, the rest the
    # translations. NOTE(review): 4096 is the hard-coded encoder output
    # width — confirm it matches the loaded encoder.
    all_feats = np.zeros((2 * total_sents, 4096), dtype='float16')
    c_idx = 0

    def process_batch(batch, c_idx, featstr='sent_enc'):
        # Encode one batch, store rows into all_feats, and record each
        # sentence's row index in the json under `featstr`.
        inps, _, _, lens = dp.prepare_data(batch, char_to_ix, auth_to_ix,
                                           maxlen=cp_params['max_seq_len'])
        enc_out = modelGenEncoder.forward_encode(inps, lens)
        enc_out = enc_out.data.cpu().numpy().astype('float16')
        all_feats[c_idx:c_idx + enc_out.shape[0]] = enc_out
        for i, b in enumerate(batch):
            res['docs'][b['id']]['sents'][b['sid']][featstr] = c_idx + i
        c_idx += enc_out.shape[0]
        return c_idx

    if params['use_semantic_encoder']:
        # External semantic (InferSent-style) encoder with GloVe vectors
        modelGenEncoder = BLSTMEncoder(char_to_ix, ix_to_char,
                                       params['glove_path'])
        encoderState = torch.load(params['use_semantic_encoder'])
    else:
        modelGenEncoder = CharTranslator(cp_params, encoder_only=True)
        # NOTE(review): `model_gen_state` is not defined anywhere in this
        # function, so this branch raises NameError as written. It most
        # likely should be the generator checkpoint's state_dict (loaded
        # before `del saved_model`) — fix before relying on this path.
        encoderState = model_gen_state

    # Partial restore: copy only the keys the encoder actually has
    state = modelGenEncoder.state_dict()
    for k in encoderState:
        if k in state:
            state[k] = encoderState[k]
    modelGenEncoder.load_state_dict(state)
    modelGenEncoder.eval()
    del encoderState

    batch = []
    print(' Processing original text')
    for i in tqdm(xrange(len(res['docs']))):
        # Lookup also validates the author exists in the index
        ix = auth_to_ix[res['docs'][i]['author']]
        for j in xrange(len(res['docs'][i]['sents'])):
            st = res['docs'][i]['sents'][j]['sent'].split()
            if len(st) > 0:
                batch.append({'in': st, 'targ': st,
                              'author': res['docs'][i]['author'],
                              'id': i, 'sid': j})
                if len(batch) == bsz:
                    c_idx = process_batch(batch, c_idx, featstr='sent_enc')
                    batch = []
    # Flush the final partial batch
    if batch:
        c_idx = process_batch(batch, c_idx, featstr='sent_enc')
        batch = []

    print('Processing translated text')
    for i in tqdm(xrange(len(res['docs']))):
        ix = auth_to_ix[res['docs'][i]['author']]
        for j in xrange(len(res['docs'][i]['sents'])):
            st = res['docs'][i]['sents'][j]['trans'].split()
            if len(st) > 0:
                batch.append({'in': st, 'targ': st,
                              'author': res['docs'][i]['author'],
                              'id': i, 'sid': j})
                if len(batch) == bsz:
                    c_idx = process_batch(batch, c_idx, featstr='trans_enc')
                    batch = []
    if batch:
        c_idx = process_batch(batch, c_idx, featstr='trans_enc')
        batch = []

    # Write back the json (now with feature indices) and dump the features.
    # NOTE(review): the output name concatenates without a separator
    # ('foosememb.npy' for 'foo.json') — presumably intended; confirm.
    json.dump(res, open(resf, 'w'))
    np.save('.'.join(resf.split('.')[:-1]) + 'sememb.npy', all_feats)
def main(params):
    """Run the generator against the evaluator (discriminator) over a split,
    accumulating fooling rates, input/output token precision/recall, and
    document-level attribution accuracy; optionally dump results as json.
    (Python 2 script: print statements, xrange.)
    """
    # Create vocabulary and author index
    saved_model = torch.load(params['genmodel'])
    cp_params = saved_model['arch']
    if params['evalmodel']:
        eval_model = torch.load(params['evalmodel'])
        eval_params = eval_model['arch']
        eval_state = eval_model['state_dict']
    else:
        # No standalone evaluator checkpoint supported yet
        print "FIX THIS"
        return

    # Lookup tables: nested under 'misc' in newer checkpoints
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        if 'ix_to_auth' in saved_model:
            ix_to_auth = saved_model['ix_to_auth']
        else:
            # Oldest checkpoints lack the inverse map; rebuild it
            ix_to_auth = {auth_to_ix[a]: a for a in auth_to_ix}

    dp = DataProvider(cp_params)
    # Optional softmax temperature override
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    modelGen = CharTranslator(cp_params)
    modelEval = CharLstm(eval_params)

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # Inference mode for both networks
    modelGen.eval()
    modelEval.eval()

    # Restore saved checkpoints (evaluator via partial state update)
    modelGen.load_state_dict(saved_model['state_dict'])
    state = modelEval.state_dict()
    state.update(eval_state)
    modelEval.load_state_dict(state)

    # 1x1 tensor with the start symbol, fed as first decoder input
    append_tensor = np.zeros((1, 1), dtype=np.int)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Per-author accumulators (index = author id)
    accum_diff_eval = [[], []]                       # score(gen) - score(real)
    accum_err_eval = np.zeros(len(auth_to_ix))       # evaluator fooled on generated
    accum_err_real = np.zeros(len(auth_to_ix))       # evaluator fooled on real
    accum_count_gen = np.zeros(len(auth_to_ix))
    accum_recall_forward = np.zeros(len(auth_to_ix))
    accum_prec_forward = np.zeros(len(auth_to_ix))
    accum_recall_rev = np.zeros(len(auth_to_ix))
    accum_prec_rev = np.zeros(len(auth_to_ix))

    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '
    result = {'docs': [],
              'misc': {'auth_to_ix': auth_to_ix, 'ix_to_auth': ix_to_auth},
              'cp_params': cp_params, 'params': params}
    # Map document id -> position in result['docs']
    id_to_ix = {}
    for i, iid in enumerate(dp.splits[params['split']]):
        result['docs'].append({'sents': [],
                               'author': dp.data['docs'][iid][dp.athstr],
                               'id': iid})
        if 'attrib' in dp.data['docs'][iid]:
            result['docs'][-1]['attrib'] = dp.data['docs'][iid]['attrib']
        id_to_ix[iid] = i

    n_samp = params['n_samples']
    for i, b_data in tqdm(enumerate(dp.iter_sentences_bylen(split=params['split'],
            atoms=cp_params.get('atoms', 'word'),
            batch_size=params['batch_size'],
            auths=auth_to_ix.keys()))):
        # num_batches <= 0 means "no limit"
        if i > params['num_batches'] and params['num_batches'] > 0:
            break;
        #for i in xrange(params['num_batches']):
        #c_aid = np.random.choice(auth_to_ix.values())
        #batch = dp.get_sentence_batch(1,split=params['split'], atoms=cp_params.get('atoms','char'), aid=ix_to_auth[c_aid])
        c_bsz = len(b_data[0])
        done = b_data[1]
        inps, targs, auths, lens = dp.prepare_data(b_data[0], char_to_ix, auth_to_ix,
                                                   maxlen=cp_params['max_seq_len'])
        # outs are organized as (scores, generated symbols, lengths, ...,
        # [reverse symbols, reverse lengths]) — assumed from the indexing
        # below; TODO confirm against adv_forward_pass.
        auths_inp = 1 - auths if params['flip'] else auths
        outs = adv_forward_pass(modelGen, modelEval, inps, lens,
                                end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'], auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor,
                                n_samples=params['n_samples'])
        # Evaluator's scores on the real (ground-truth) text
        eval_out_gt = modelEval.forward_classify(targs, lens=lens,
                                                 compute_softmax=True)
        auths_inp = auths_inp.numpy()
        i_bsz = np.arange(c_bsz)
        # Probability the evaluator assigns to the *target* author
        real_aid_out = eval_out_gt[0].data.cpu().numpy()[i_bsz, auths_inp]
        gen_scores = outs[0].view(n_samp, c_bsz, -1)
        gen_aid_out = gen_scores.cpu().numpy()[:, i_bsz, auths_inp]
        gen_char = [v.view(n_samp, c_bsz) for v in outs[1]]
        gen_lens = outs[2].view(n_samp, c_bsz)
        # "Fooled" = evaluator assigns >= 0.5 to the target author
        np.add.at(accum_err_eval, auths_inp, gen_aid_out[0, :] >= 0.5)
        np.add.at(accum_err_real, auths_inp, real_aid_out >= 0.5)
        np.add.at(accum_count_gen, auths_inp, 1)

        for b in xrange(inps.size()[1]):
            # Token-id set of the input sentence (up to its true length)
            inpset = set(inps[:, b].tolist()[:lens[b]])
            samples = []
            accum_diff_eval[auths_inp[b]].append(gen_aid_out[0, b] - real_aid_out[b])
            for si in xrange(n_samp):
                # Token-id set of sample si for batch item b
                genset = set([c[si, b] for c in gen_char[:gen_lens[si, b]]])
                accum_recall_forward[auths_inp[b]] += (float(len(genset & inpset)) / float(len(inpset)))
                accum_prec_forward[auths_inp[b]] += (float(len(genset & inpset)) / float(len(genset)))
                if params['show_rev']:
                    # Same overlap stats for the cycle reconstruction
                    revgenset = set([c[b] for c in outs[-2][:outs[-1][b]]])
                    accum_recall_rev[auths_inp[b]] += (float(len(revgenset & inpset)) / float(len(inpset)))
                    accum_prec_rev[auths_inp[b]] += (float(len(revgenset & inpset)) / float(len(revgenset)))
                inp_text = jc.join([ix_to_char[c] for c in targs[:, b] if c in ix_to_char])
                trans_text = jc.join([ix_to_char[c.cpu()[si, b]] for c in gen_char[:gen_lens[si, b]] if c.cpu()[si, b] in ix_to_char])
                samples.append({'sent': inp_text,
                                'score': eval_out_gt[0][b].data.cpu().tolist(),
                                'trans': trans_text,
                                'trans_score': gen_scores[si, b].cpu().tolist(),
                                'sid': b_data[0][b]['sid']})
            result['docs'][id_to_ix[b_data[0][b]['id']]]['sents'].append(samples)

        if params['print']:
            print '--------------------------------------------'
            print 'Author: %s'%(b_data[0][0]['author'])
            print 'Inp text %s: %s (%.2f)'%(ix_to_auth[auths[0]], jc.join([ix_to_char[c[0]] for c in inps[1:]]), real_aid_out[0])
            print 'Out text %s: %s (%.2f)'%(ix_to_auth[auths_inp[0]],jc.join([ix_to_char[c.cpu()[0]] for c in outs[1] if c.cpu()[0] in ix_to_char]), gen_aid_out[0])
            if params['show_rev']:
                print 'Rev text %s: '%(ix_to_auth[auths[0]])+ '%s'%(jc.join([ix_to_char[c.cpu()[0]] for c in outs[-2] if c.cpu()[0] in ix_to_char]))
        #else:
        #    print '%d/%d\r'%(i, params['num_batches']),

    # -------- Aggregate report: fooling rates --------
    # 1e-5 guards against division by zero for empty author buckets
    err_a1, err_a2 = accum_err_eval[0]/(1e-5+accum_count_gen[0]), accum_err_eval[1]/(1e-5+accum_count_gen[1])
    err_real_a1, err_real_a2 = accum_err_real[0]/(1e-5+accum_count_gen[0]), accum_err_real[1]/(1e-5+accum_count_gen[1])
    print '--------------------------------------------'
    print 'Efficiency in fooling discriminator'
    print '--------------------------------------------'
    print(' erra1 {:3.2f} - erra2 {:3.2f}'.format(100.*err_a1, 100.*err_a2))
    print(' err_real_a1 {:3.2f} - err_real_a2 {:3.2f}'.format(100.*err_real_a1, 100.*err_real_a2))
    print(' count %d - %d'%(accum_count_gen[0], accum_count_gen[1]))
    diff_arr0, diff_arr1 = np.array(accum_diff_eval[0]), np.array(accum_diff_eval[1])
    print 'Mean difference : translation to %s = %.2f , translation to %s = %.2f '%(ix_to_auth[0], diff_arr0.mean(), ix_to_auth[1], diff_arr1.mean())
    print 'Difference > 0 : translation to %s = %.2f%%, translation to %s = %.2f%% '%(ix_to_auth[0], 100.*(diff_arr0>0).sum()/(1e-5+diff_arr0.shape[0]), ix_to_auth[1], 100.*(diff_arr1>0).sum()/(1e-5+diff_arr1.shape[0]))
    print 'Difference < 0 : translation to %s = %.2f%%, translation to %s = %.2f%% '%(ix_to_auth[0], 100.*(diff_arr0<0).sum()/(1e-5+diff_arr0.shape[0]), ix_to_auth[1], 100.*(diff_arr1<0).sum()/(1e-5+diff_arr1.shape[0]))

    # -------- Aggregate report: content preservation --------
    print '\n--------------------------------------------'
    print 'Consistencey with the input text'
    print '--------------------------------------------'
    print 'Generated text A0- Precision = %.2f, Recall = %.2f'%(accum_prec_forward[0]/accum_count_gen[0], accum_recall_forward[0]/accum_count_gen[0] )
    print 'Generated text A1- Precision = %.2f, Recall = %.2f'%(accum_prec_forward[1]/accum_count_gen[1], accum_recall_forward[1]/accum_count_gen[1] )
    if params['show_rev']:
        print '\n'
        print 'Reconstr text A0- Precision = %.2f, Recall = %.2f'%(accum_prec_rev[0]/accum_count_gen[0], accum_recall_rev[0]/accum_count_gen[0] )
        print 'Reconstr text A1- Precision = %.2f, Recall = %.2f'%(accum_prec_rev[1]/accum_count_gen[1], accum_recall_rev[1]/accum_count_gen[1] )

    # -------- Aggregate report: document-level attribution --------
    print '\n--------------------------------------------'
    print 'Document Level Scores'
    print '--------------------------------------------'
    doc_accuracy = np.zeros(len(auth_to_ix))
    doc_accuracy_trans = np.zeros(len(auth_to_ix))
    doc_count = np.zeros(len(auth_to_ix))
    for doc in result['docs']:
        # Sum per-sentence log-probs (sample 0 only) to score each document
        doc_score_orig = np.array([0.,0.])
        doc_score_trans = np.array([0.,0.])
        for st in doc['sents']:
            doc_score_orig += np.log(st[0]['score'])
            doc_score_trans += np.log(st[0]['trans_score'])
        doc_accuracy[auth_to_ix[doc['author']]] += float(doc_score_orig.argmax() == auth_to_ix[doc['author']])
        doc_accuracy_trans[auth_to_ix[doc['author']]] += float(doc_score_trans.argmax() == auth_to_ix[doc['author']])
        doc_count[auth_to_ix[doc['author']]] += 1.
    print 'Original data'
    print '-------------'
    print 'Doc accuracy is %s : %.2f , %s : %.2f'%(ix_to_auth[0], (doc_accuracy[0]/doc_count[0]),ix_to_auth[1], (doc_accuracy[1]/doc_count[1]) )
    # Precision/recall treating author 0 as the positive class
    fp = doc_count[1]- doc_accuracy[1]
    recall = doc_accuracy[0]/doc_count[0]
    precision = doc_accuracy[0]/(doc_accuracy[0]+fp)
    f1score = 2.*(precision*recall)/(precision+recall)
    print 'Precision is %.2f : Recall is %.2f , F1-score is %.2f'%(precision, recall, f1score)
    print '\nTranslated data'
    print '-----------------'
    print 'Doc accuracy is %s : %.2f , %s : %.2f'%(ix_to_auth[0], (doc_accuracy_trans[0]/doc_count[0]),ix_to_auth[1], (doc_accuracy_trans[1]/doc_count[1]) )
    fp = doc_count[1]- doc_accuracy_trans[1]
    recall = doc_accuracy_trans[0]/doc_count[0]
    precision = doc_accuracy_trans[0]/(doc_accuracy_trans[0]+fp)
    f1score = 2.*(precision*recall)/(precision+recall)
    print 'Precision is %.2f : Recall is %.2f , F1-score is %.2f'%(precision, recall, f1score)

    if params['dumpjson']:
        json.dump(result, open(params['dumpjson'],'w'))
def main(params):
    """Train a CharTranslator (generative) or classifier model.

    Builds/restores the vocabulary and author indices, runs the training
    loop with periodic validation, and checkpoints whenever the validation
    rank improves. (Python 2 / old-torch script: xrange, Variable,
    loss.data[...] access.)
    """
    dp = DataProvider(params)

    # Create vocabulary and author index, or restore them from a checkpoint
    if params['resume'] is None:
        if params['atoms'] == 'char':
            char_to_ix, ix_to_char = dp.createCharVocab(
                params['vocab_threshold'])
        else:
            char_to_ix, ix_to_char = dp.createWordVocab(
                params['vocab_threshold'])
        auth_to_ix, ix_to_auth = dp.createAuthorIdx()
    else:
        saved_model = torch.load(params['resume'])
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_auth = saved_model['ix_to_auth']
        ix_to_char = saved_model['ix_to_char']

    params['vocabulary_size'] = len(char_to_ix)
    params['num_output_layers'] = len(auth_to_ix)

    model = CharTranslator(params)
    # Train mode: activates dropout
    model.train()

    # Optimizer: SGD with momentum, or RMSprop
    if params['use_sgd']:
        optim = torch.optim.SGD(model.parameters(),
                                lr=params['learning_rate'],
                                momentum=params['decay_rate'])
    else:
        optim = torch.optim.RMSprop(model.parameters(),
                                    lr=params['learning_rate'],
                                    alpha=params['decay_rate'],
                                    eps=params['smooth_eps'])

    # Loss function: raw logits for the generative LM, log-probs otherwise
    if params['mode'] == 'generative':
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.NLLLoss()

    # Restore model/optimizer state when resuming
    if params['resume'] is not None:
        model.load_state_dict(saved_model['state_dict'])
        optim.load_state_dict(saved_model['optimizer'])

    total_loss = 0.
    # Defined up-front so it is always bound when referenced below
    cur_loss = 0.
    start_time = time.time()
    hidden = model.init_hidden(params['batch_size'])
    hidden_zeros = model.init_hidden(params['batch_size'])
    # Initialize the hidden-state cache
    if params['randomize_batches']:
        dp.set_hid_cache(range(len(dp.data['docs'])), hidden_zeros)

    # Iteration bookkeeping
    epochs = params['max_epochs']
    total_seqs = dp.get_num_sents(split='train')
    iter_per_epoch = total_seqs // params['batch_size']
    total_iters = iter_per_epoch * epochs
    best_val = 1000.
    eval_every = int(iter_per_epoch * params['eval_interval'])

    val_score = 0.
    val_rank = 1000
    eval_function = eval_translator if params[
        'mode'] == 'generative' else eval_classify

    print(total_iters)
    for i in xrange(total_iters):
        # Optionally restrict each batch to a single randomly chosen author
        if params['split_generators']:
            c_aid = ix_to_auth[np.random.choice(auth_to_ix.values())]
        else:
            c_aid = None

        batch = dp.get_sentence_batch(params['batch_size'], split='train',
                                      atoms=params['atoms'], aid=c_aid,
                                      sample_by_len=params['sample_by_len'])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=params['max_seq_len'])

        # Detach the hidden state from the previous batch's graph so
        # backprop does not reach back to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optim.zero_grad()

        if params['mode'] == 'generative':
            output, _ = model.forward_mltrain(inps, lens, inps, lens,
                                              hidden_zeros, auths=auths)
            targets = pack_padded_sequence(Variable(targs).cuda(), lens)
            loss = criterion(pack_padded_sequence(output, lens)[0], targets[0])
        else:
            # For the classifier the author ids are the targets
            output, hidden = model.forward_classify(inps, hidden,
                                                    compute_softmax=True)
            targets = Variable(auths).cuda()
            loss = criterion(output, targets)
        loss.backward()

        # clip_grad_norm guards against exploding gradients in RNNs/LSTMs
        torch.nn.utils.clip_grad_norm(model.parameters(), params['grad_clip'])
        optim.step()
        total_loss += loss.data.cpu().numpy()[0]

        # Periodic validation
        if i % eval_every == 0 and i > 0:
            val_rank, val_score = eval_function(dp, model, params, char_to_ix,
                                                auth_to_ix, split='val')

        # Periodic logging + checkpointing on improvement
        if i % params['log_interval'] == 0 and i > 0:
            cur_loss = total_loss / params['log_interval']
            elapsed = time.time() - start_time
            # FIX: original read `args.log_interval`, but no `args` exists
            # in this scope — that raised NameError on the first log.
            print(
                '| epoch {:2.2f} | {:5d}/{:5d} batches | lr {:02.2e} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    float(i) / iter_per_epoch, i, total_iters,
                    params['learning_rate'],
                    elapsed * 1000 / params['log_interval'], cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0.
            if val_rank <= best_val:
                save_checkpoint(
                    {
                        'iter': i,
                        'arch': params,
                        'val_loss': val_rank,
                        'val_pplx': val_score,
                        'char_to_ix': char_to_ix,
                        'ix_to_char': ix_to_char,
                        'auth_to_ix': auth_to_ix,
                        'ix_to_auth': ix_to_auth,
                        'state_dict': model.state_dict(),
                        'loss': cur_loss,
                        'optimizer': optim.state_dict(),
                    },
                    fappend=params['fappend'],
                    outdir=params['checkpoint_output_directory'])
                best_val = val_rank
            start_time = time.time()
def main(params):
    """Sample translations from a saved checkpoint and pretty-print the
    cleaned-up input/output (and optional cycle reconstruction) per author.
    """
    # Create vocabulary and author index; newer checkpoints nest the lookup
    # tables under 'misc'.
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    cp_params = saved_model['arch']
    # Optional softmax temperature override at sampling time
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)
    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Inference mode: disables dropout
    model.eval()

    auth_colors = ['red', 'blue']
    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # 1x1 tensor holding the start symbol, fed as first decoder input.
    # FIX: np.int was removed from NumPy (1.24+); int64 matches LongTensor.
    append_tensor = np.zeros((1, 1), dtype=np.int64)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint weights
    model.load_state_dict(saved_model['state_dict'])
    hidden = model.init_hidden(1)
    # Join string: chars concatenate directly, words need spaces
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    # --- text clean-up helpers (hoisted out of the sampling loop; the
    # original re-defined them on every iteration) ---
    def strip_match(match):
        # Clears whitespace inside a regex match but retains the characters
        return match.group(0).strip()

    def fix_decimals(match):
        # Joins together decimals split by whitespace ('3. 14' -> '3.14')
        return re.sub(r'\s', '', match.group(0))

    def clean_text(text):
        # Cleans text by removing unnecessary whitespace and substituting
        # back in some symbols (undoes PTB-style bracket escapes).
        text = re.sub('-lrb-', '(', text)
        text = re.sub('-rrb-', ')', text)
        text = re.sub('-lsb-', '[', text)
        text = re.sub('-rsb-', ']', text)
        text = re.sub('-lcb-', '{', text)
        text = re.sub('-rcb-', '}', text)
        text = re.sub('\'\'', '\"', text)
        text = re.sub(r'\si\s', ' I ', text)
        text = re.sub(r'^i\s', 'I ', text)
        text = re.sub(r'\sna\s', 'na ', text)
        text = re.sub(r'\$\s', strip_match, text)
        text = re.sub(r'[-#]\s|\s([-.!,\':;?]|n\'t)', strip_match, text)
        # FIX: escape the dot — the original unescaped '.' also matched
        # e.g. '3, 14' and merged it; only true decimals should be joined.
        text = re.sub(r'\d+\. \d+', fix_decimals, text)
        return text

    for i in range(params['num_samples']):
        # Random source author for each sample
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        # Optionally flip the author id to translate into the other style
        auths_inp = 1 - auths if params['flip'] else auths
        forward, backward = adv_forward_pass(model, inps, lens,
                                             end_c=char_to_ix[endc],
                                             maxlen=cp_params['max_seq_len'],
                                             auths=auths_inp,
                                             cycle_compute=params['show_rev'],
                                             append_symb=append_tensor)
        print('--------------------------------------------')
        print('Translate from %s to %s' % (batch[0]['author'],
                                           ix_to_auth[auths_inp.item()]))

        # Original sentence, cleaned up (inps[0] is the start symbol)
        input_list = [ix_to_char[c.item()] for c in inps[1:]]
        input_string = clean_text(jc.join(input_list))

        # Translated sentence, cleaned up; drop a trailing END marker.
        # FIX: guard against an empty output before peeking at [-1].
        output_list = [
            ix_to_char[c.item()] for c in forward if c.item() in ix_to_char
        ]
        if output_list and output_list[-1] == 'END':
            output_list = output_list[:-1]
        output_string = clean_text(jc.join(output_list))

        print(
            colored('Inp %6s: ' % (ix_to_auth[auths.item()]), 'green') +
            colored('%s' % input_string, auth_colors[auths.item()]))
        print(
            colored('Out %6s: ' % (ix_to_auth[auths_inp.item()]), 'grey') +
            colored('%s' % output_string, auth_colors[auths_inp.item()]))
        if params['show_rev']:
            # Cycle reconstruction back into the source author's style
            print(
                colored('Rev %6s: ' % (ix_to_auth[auths.item()]), 'green') +
                colored(
                    '%s' % (jc.join([
                        ix_to_char[c.item()] for c in backward
                        if c.item() in ix_to_char
                        and ix_to_char[c.item()] != 'END'
                    ])), auth_colors[auths.item()]))