def main(params):
    """Load a saved character-level model and print sampled translations.

    Rebuilds the model described by the checkpoint's ``'arch'`` entry, then for
    ``params['num_samples']`` randomly drawn inputs prints the input ("Inp"),
    the model output ("Out") and, when ``params['show_rev']`` is set, the
    output stored in the last element returned by ``adv_forward_pass``
    ("Rev"). Each line is coloured by author index using ``auth_colors``.

    Written to run unchanged on Python 2 and Python 3.

    Args:
        params: dict of options. Keys read here: 'model', 'softmax_scale',
            'm_type', 'num_samples', 'seed_length', 'split', 'flip',
            'show_rev'.
    """
    saved_model = torch.load(params['model'])
    # The vocabulary / author index is stored either under 'misc' or at the
    # top level of the checkpoint, depending on how it was saved.
    vocab = saved_model['misc'] if 'misc' in saved_model else saved_model
    char_to_ix = vocab['char_to_ix']
    auth_to_ix = vocab['auth_to_ix']
    ix_to_char = vocab['ix_to_char']
    ix_to_auth = vocab['ix_to_auth']

    cp_params = saved_model['arch']
    if params['softmax_scale']:
        # A non-empty command-line value overrides the checkpoint's setting.
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Evaluation mode: dropout is disabled for inference.
    model.eval()
    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # 1x1 start-symbol tensor passed to adv_forward_pass as `append_symb`.
    # np.int64 instead of np.int: the np.int alias was removed in NumPy 1.24.
    append_tensor = np.zeros((1, 1), dtype=np.int64)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint
    model.load_state_dict(saved_model['state_dict'])
    # Char-level atoms are concatenated directly; other atom types get spaces.
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    for _ in range(params['num_samples']):
        # list() so np.random.choice also accepts Python 3 dict views.
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1, split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])

        inps, _, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        auths_inp = 1 - auths if params['flip'] else auths
        outs = adv_forward_pass(model, inps, lens, end_c=char_to_ix[endc],
                                maxlen=cp_params['max_seq_len'],
                                auths=auths_inp,
                                cycle_compute=params['show_rev'],
                                append_symb=append_tensor)
        print('--------------------------------------------')

        # int() turns tensor/array elements into plain ints: newer PyTorch
        # indexing yields 0-d tensors, which hash by identity and would
        # break the dict lookups below.
        src_auth = int(auths[0])
        dst_auth = int(auths_inp[0])
        inp_str = jc.join([ix_to_char[int(c[0])] for c in inps[1:]])
        out_ids = [int(c.data.cpu()[0]) for c in outs[0]]
        out_str = jc.join([ix_to_char[ix] for ix in out_ids if ix in ix_to_char])

        print(colored('Inp %6s: ' % (ix_to_auth[src_auth]), 'green') +
              colored(inp_str, auth_colors[src_auth]))
        print(colored('Out %6s: ' % (ix_to_auth[dst_auth]), 'grey') +
              colored(out_str, auth_colors[dst_auth]))

        if params['show_rev']:
            rev_ids = [int(c.data.cpu()[0]) for c in outs[-1]]
            rev_str = jc.join([ix_to_char[ix] for ix in rev_ids if ix in ix_to_char])
            print(colored('Rev %6s: ' % (ix_to_auth[src_auth]), 'green') +
                  colored(rev_str, auth_colors[src_auth]))
# Example no. 2
# 0
def _strip_match(match):
    """re.sub callback: return the matched text with outer whitespace removed."""
    return match.group(0).strip()


def _fix_decimals(match):
    """re.sub callback: return the matched number with all whitespace removed."""
    return re.sub(r'\s', '', match.group(0))


def _clean_text(text):
    """Undo tokenizer artifacts so decoded text reads naturally.

    Restores bracket tokens (-lrb-, -rsb-, ...), turns '' into a double quote,
    capitalises a standalone "i", re-attaches "na" (e.g. "gon na"), removes
    the space after "$", "-" and "#" and before punctuation / "n't", and
    re-joins decimals split as "3. 14".
    """
    for token, symbol in (('-lrb-', '('), ('-rrb-', ')'),
                          ('-lsb-', '['), ('-rsb-', ']'),
                          ('-lcb-', '{'), ('-rcb-', '}')):
        text = text.replace(token, symbol)
    text = re.sub(r"''", '"', text)
    text = re.sub(r'\si\s', ' I ', text)
    text = re.sub(r'^i\s', 'I ', text)
    text = re.sub(r'\sna\s', 'na ', text)
    text = re.sub(r'\$\s', _strip_match, text)
    text = re.sub(r"[-#]\s|\s([-.!,':;?]|n't)", _strip_match, text)
    # The dot must be escaped: an unescaped "." matches any character, which
    # glued unrelated numbers together ("1990 2000" -> "19902000").
    text = re.sub(r'\d+\. \d+', _fix_decimals, text)
    return text


def main(params):
    """Load a saved character-level model and print cleaned-up translations.

    Rebuilds the model described by the checkpoint's ``'arch'`` entry, then for
    ``params['num_samples']`` randomly drawn inputs prints the source/target
    authors, the cleaned input ("Inp"), the cleaned model output ("Out") and,
    when ``params['show_rev']`` is set, the second output of
    ``adv_forward_pass`` ("Rev"). Lines are coloured by author index.

    Args:
        params: dict of options. Keys read here: 'model', 'softmax_scale',
            'm_type', 'num_samples', 'seed_length', 'split', 'flip',
            'show_rev'.
    """
    saved_model = torch.load(params['model'])
    # The vocabulary / author index is stored either under 'misc' or at the
    # top level of the checkpoint, depending on how it was saved.
    vocab = saved_model['misc'] if 'misc' in saved_model else saved_model
    char_to_ix = vocab['char_to_ix']
    auth_to_ix = vocab['auth_to_ix']
    ix_to_char = vocab['ix_to_char']
    ix_to_auth = vocab['ix_to_auth']

    cp_params = saved_model['arch']
    if params['softmax_scale']:
        # A non-empty command-line value overrides the checkpoint's setting.
        cp_params['softmax_scale'] = params['softmax_scale']

    dp = DataProvider(cp_params)

    if params['m_type'] == 'generative':
        model = CharLstm(cp_params)
    else:
        model = CharTranslator(cp_params)
    # Evaluation mode: dropout is disabled for inference.
    model.eval()
    auth_colors = ['red', 'blue']

    startc = dp.data['configs']['start']
    endc = dp.data['configs']['end']

    # 1x1 start-symbol tensor passed to adv_forward_pass as `append_symb`.
    # np.int64 instead of np.int: the np.int alias was removed in NumPy 1.24.
    append_tensor = np.zeros((1, 1), dtype=np.int64)
    append_tensor[0, 0] = char_to_ix[startc]
    append_tensor = torch.LongTensor(append_tensor).cuda()

    # Restore saved checkpoint
    model.load_state_dict(saved_model['state_dict'])
    # Char-level atoms are concatenated directly; other atom types get spaces.
    jc = '' if cp_params.get('atoms', 'char') == 'char' else ' '

    for _ in range(params['num_samples']):
        c_aid = np.random.choice(list(auth_to_ix.values()))
        if params['m_type'] == 'generative':
            batch = dp.get_random_string(slen=params['seed_length'],
                                         split=params['split'])
        else:
            batch = dp.get_sentence_batch(1,
                                          split=params['split'],
                                          atoms=cp_params.get('atoms', 'char'),
                                          aid=ix_to_auth[c_aid])

        inps, _, auths, lens = dp.prepare_data(
            batch, char_to_ix, auth_to_ix, maxlen=cp_params['max_seq_len'])
        auths_inp = 1 - auths if params['flip'] else auths
        forward, backward = adv_forward_pass(model,
                                             inps,
                                             lens,
                                             end_c=char_to_ix[endc],
                                             maxlen=cp_params['max_seq_len'],
                                             auths=auths_inp,
                                             cycle_compute=params['show_rev'],
                                             append_symb=append_tensor)
        print('--------------------------------------------')
        print('Translate from %s to %s' %
              (batch[0]['author'], ix_to_auth[auths_inp.item()]))

        src_auth = auths.item()
        dst_auth = auths_inp.item()

        # Original sentence, skipping the first input position.
        input_string = _clean_text(
            jc.join([ix_to_char[c.item()] for c in inps[1:]]))

        # Translated sentence; ids outside the vocabulary are skipped.
        output_list = [
            ix_to_char[c.item()] for c in forward if c.item() in ix_to_char
        ]
        # Drop one trailing END marker; the guard avoids IndexError when
        # nothing decoded.
        if output_list and output_list[-1] == 'END':
            output_list = output_list[:-1]
        output_string = _clean_text(jc.join(output_list))

        print(colored('Inp %6s: ' % (ix_to_auth[src_auth]), 'green') +
              colored(input_string, auth_colors[src_auth]))
        print(colored('Out %6s: ' % (ix_to_auth[dst_auth]), 'grey') +
              colored(output_string, auth_colors[dst_auth]))

        if params['show_rev']:
            rev_list = [
                ix_to_char[c.item()] for c in backward
                if c.item() in ix_to_char and ix_to_char[c.item()] != 'END'
            ]
            print(colored('Rev %6s: ' % (ix_to_auth[src_auth]), 'green') +
                  colored(jc.join(rev_list), auth_colors[src_auth]))