def main(params):
    """Sample strings from a saved style-transfer model and print the
    input, the translated output, and optionally the cycled-back
    reconstruction, color-coded per author.

    NOTE(review): this legacy definition is shadowed by a second
    ``def main`` later in the file; ported here from Python 2 to
    Python 3 (print statements -> print(), xrange -> range,
    dict.values() wrapped in list() for np.random.choice, and the
    NumPy-1.24-removed ``np.int`` alias -> ``np.int64``).
    """
    # Restore vocabulary and author index from the checkpoint. Newer
    # checkpoints nest them under 'misc'; older ones store them flat.
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    cp_params = saved_model['arch']
    # Allow overriding the softmax temperature stored in the checkpoint.
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Inference only: eval mode disables dropout.
    model.eval()

    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # One-element tensor holding the start symbol, appended by the
    # decoder. np.int was removed in NumPy 1.24; use a concrete dtype.
    append_tensor = np.zeros((1, 1), dtype=np.int64)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint weights.
    model.load_state_dict(saved_model['state_dict'])

    hidden = model.init_hidden(1)
    # Character atoms are joined directly; word atoms with spaces.
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '
    for i in range(params['num_samples']):
        # Pick a random author id (list() required on Python 3, where
        # dict.values() is a view np.random.choice cannot index).
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        # Optionally flip the author label so the model translates style.
        auths_inp = 1 - auths if params['flip'] else auths
        outs = adv_forward_pass(model, inps, lens, end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'],
                                auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor)
        # char_outs = model.forward_gen(inps, hidden, auths_inp,
        #                               n_max=cp_params['max_len'], end_c=char_to_ix['.'])
        print('--------------------------------------------')
        # print('Translate from %s to %s' % (batch[0]['author'], ix_to_auth[auths_inp[0]]))
        print(colored('Inp %6s: ' % (ix_to_auth[auths[0]]), 'green') +
              colored('%s' % (jc.join([ix_to_char[c[0]] for c in inps[1:]])),
                      auth_colors[auths[0]]))
        print(colored('Out %6s: ' % (ix_to_auth[auths_inp[0]]), 'grey') +
              colored('%s' % (jc.join([ix_to_char[c.data.cpu()[0]]
                                       for c in outs[0]
                                       if c.data.cpu()[0] in ix_to_char])),
                      auth_colors[auths_inp[0]]))
        if params['show_rev']:
            print(colored('Rev %6s: ' % (ix_to_auth[auths[0]]), 'green') +
                  colored('%s' % (jc.join([ix_to_char[c.data.cpu()[0]]
                                           for c in outs[-1]
                                           if c.data.cpu()[0] in ix_to_char])),
                          auth_colors[auths[0]]))
# --- Helpers for post-processing generated text -------------------------
# Hoisted to module level: the original defined them inside the sampling
# loop, re-creating all three function objects on every iteration. They
# close over nothing from the loop, so behavior is unchanged.


def strip_match(match):
    """Drop surrounding whitespace from a regex match, keep its characters."""
    return match.group(0).strip()


def fix_decimals(match):
    """Re-join a decimal number that was tokenized with internal spaces."""
    return re.sub(r'\s', '', match.group(0))


def clean_text(text):
    """Undo PTB-style tokenization artifacts in *text*.

    Restores bracket tokens (-lrb- etc.) and double quotes, capitalizes
    standalone 'i', re-attaches punctuation and the n't clitic, and
    re-joins decimals split like '3. 14'.
    """
    # Penn-Treebank bracket tokens back to literal brackets.
    text = re.sub('-lrb-', '(', text)
    text = re.sub('-rrb-', ')', text)
    text = re.sub('-lsb-', '[', text)
    text = re.sub('-rsb-', ']', text)
    text = re.sub('-lcb-', '{', text)
    text = re.sub('-rcb-', '}', text)
    text = re.sub('\'\'', '\"', text)
    # Standalone 'i' -> 'I', mid-sentence and sentence-initial.
    text = re.sub(r'\si\s', ' I ', text)
    text = re.sub(r'^i\s', 'I ', text)
    text = re.sub(r'\sna\s', 'na ', text)
    # Pull '$' onto the following token; pull punctuation and the n't
    # clitic onto the preceding token.
    text = re.sub(r'\$\s', strip_match, text)
    text = re.sub(r'[-#]\s|\s([-.!,\':;?]|n\'t)', strip_match, text)
    # NOTE(review): the '.' here is unescaped, so it matches ANY char
    # between the digit groups — presumably a literal '.' was intended;
    # kept as-is to preserve behavior. Confirm before tightening.
    text = re.sub(r'\d+. \d+', fix_decimals, text)
    return text


def main(params):
    """Sample sentences from a saved style-transfer model, translate them
    to the (optionally flipped) target author, clean up the text, and
    print input/output — plus the cycled-back reconstruction when
    ``params['show_rev']`` is set.
    """
    # Restore vocabulary and author index from the checkpoint. Newer
    # checkpoints nest them under 'misc'; older ones store them flat.
    saved_model = torch.load(params['model'])
    if 'misc' in saved_model:
        misc = saved_model['misc']
        char_to_ix = misc['char_to_ix']
        auth_to_ix = misc['auth_to_ix']
        ix_to_char = misc['ix_to_char']
        ix_to_auth = misc['ix_to_auth']
    else:
        char_to_ix = saved_model['char_to_ix']
        auth_to_ix = saved_model['auth_to_ix']
        ix_to_char = saved_model['ix_to_char']
        ix_to_auth = saved_model['ix_to_auth']
    cp_params = saved_model['arch']
    # Allow overriding the softmax temperature stored in the checkpoint.
    if params['softmax_scale']:
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Inference only: eval mode disables dropout.
    model.eval()

    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # One-element tensor holding the start symbol, appended by the
    # decoder. np.int was removed in NumPy 1.24; use a concrete dtype.
    append_tensor = np.zeros((1, 1), dtype=np.int64)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint weights.
    model.load_state_dict(saved_model['state_dict'])

    hidden = model.init_hidden(1)
    # Character atoms are joined directly; word atoms with spaces.
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '
    for i in range(params['num_samples']):
        # Pick a random target author id.
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])
        inps, targs, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        # Optionally flip the author label so the model translates style.
        auths_inp = 1 - auths if params['flip'] else auths
        forward, backward = adv_forward_pass(model, inps, lens,
                                             end_c=char_to_ix[endc],
                                             maxlen=cp_params['max_seq_len'],
                                             auths=auths_inp,
                                             cycle_compute=params['show_rev'],
                                             append_symb=append_tensor)
        # char_outs = model.forward_gen(inps, hidden, auths_inp,
        #                               n_max=cp_params['max_len'], end_c=char_to_ix['.'])
        print('--------------------------------------------')
        print('Translate from %s to %s' %
              (batch[0]['author'], ix_to_auth[auths_inp.item()]))

        # Original sentence (skip the leading start symbol), cleaned up.
        input_list = [ix_to_char[c.item()] for c in inps[1:]]
        input_string = clean_text(jc.join(input_list))

        # Translated sentence; drop ids not in the vocab and a trailing
        # END marker. Guard indexing: the translation can be empty.
        output_list = [ix_to_char[c.item()] for c in forward
                       if c.item() in ix_to_char]
        if output_list and output_list[-1] == 'END':
            output_list = output_list[:-1]
        output_string = clean_text(jc.join(output_list))

        print(colored('Inp %6s: ' % (ix_to_auth[auths.item()]), 'green') +
              colored('%s' % input_string, auth_colors[auths.item()]))
        print(colored('Out %6s: ' % (ix_to_auth[auths_inp.item()]), 'grey') +
              colored('%s' % output_string, auth_colors[auths_inp.item()]))
        if params['show_rev']:
            rev_list = [ix_to_char[c.item()] for c in backward
                        if c.item() in ix_to_char and ix_to_char[c.item()] != 'END']
            print(colored('Rev %6s: ' % (ix_to_auth[auths.item()]), 'green') +
                  colored('%s' % (jc.join(rev_list)),
                          auth_colors[auths.item()]))