def train(log_dir, n_epochs, network_dict, index2token, **kwargs):
    onehot_words = kwargs['onehot_words']
    word_pos = kwargs['word_pos']
    sentence_lens_nchars = kwargs['sentence_lens_nchars']
    sentence_lens_nwords = kwargs['sentence_lens_nwords']
    vocabulary_size = kwargs['vocabulary_size']
    max_char_len = kwargs['max_char_len']
    onehot_words_val = kwargs['onehot_words_val']
    word_pos_val = kwargs['word_pos_val']
    sentence_lens_nchars_val = kwargs['sentence_lens_nchars_val']
    sentence_lens_nwords_val = kwargs['sentence_lens_nwords_val']

    batch_size = kwargs['batch_size']
    input_size = vocabulary_size
    hidden_size = kwargs['hidden_size']
    decoder_dim = kwargs['decoder_dim']
    decoder_units_p3 = kwargs['decoder_units_p3']
    num_batches = len(onehot_words) // batch_size
    network_dict['input_size'] = input_size

    max_word_len = np.max(sentence_lens_nwords)

    encoder_k = encoder.Encoder(**network_dict)

    #onehot_words,word_pos,vocabulary_size = encoder_k.run_preprocess()
    # prepping the permutation matrix for all instances separately
    perm_mat, max_lat_word_len, lat_sent_len_list = prep_perm_matrix(
        batch_size=batch_size,
        word_pos_matrix=word_pos,
        max_char_len=max_char_len)
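    # prep_perm_matrix (a helper defined elsewhere) presumably converts the
    # per-character word-position markers into an integer index matrix: for each
    # sentence, the character positions at which its words end, zero-padded up to
    # the longest sentence (in words). E.g. words ending at characters 3, 9 and 14
    # with max_lat_word_len = 5 would give the row [3, 9, 14, 0, 0] (an
    # illustrative assumption; the exact convention lives in prep_perm_matrix).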

    #placeholders
    mask_kl_pl = tf.placeholder(name='kl_pl_mask',
                                dtype=tf.float32,
                                shape=[batch_size, max_lat_word_len])
    sent_word_len_list_pl = tf.placeholder(name='word_lens',
                                           dtype=tf.int32,
                                           shape=[batch_size])
    perm_mat_pl = tf.placeholder(name='perm_mat_pl',
                                 dtype=tf.int32,
                                 shape=[batch_size, max_lat_word_len])
    onehot_words_pl = tf.placeholder(
        name='onehot_words',
        dtype=tf.float32,
        shape=[batch_size, max_char_len, vocabulary_size])
    word_pos_pl = tf.placeholder(name='word_pos',
                                 dtype=tf.float32,
                                 shape=[batch_size, max_char_len])
    sent_char_len_list_pl = tf.placeholder(name='sent_char_len_list',
                                           dtype=tf.float32,
                                           shape=[batch_size])
    #decoder
    arg_dict = {
        'decoder_p3_units': decoder_units_p3,
        'encoder_dim': hidden_size,
        'lat_word_dim': hidden_size,
        'sentence_lens': None,
        'global_lat_dim': hidden_size,
        'batch_size': batch_size,
        'max_num_lat_words': max_lat_word_len,
        'decoder_units': decoder_dim,
        'num_sentence_characters': max_char_len,
        'dict_length': vocabulary_size
    }
    decoder = Decoder(**arg_dict)

    #step counter
    global_step = tf.Variable(0, name='global_step', trainable=False)

    word_state_out, mean_state_out, logsig_state_out = encoder_k.run_encoder(
        sentence_lens=sent_char_len_list_pl,
        train=True,
        inputs=onehot_words_pl,
        word_pos=word_pos_pl,
        reuse=None)

    #picking out our words
    #why do these all start at 0?
    # replace 0's possibly with len+1
    ## RELYING ON THERE BEING NOTHING AT ZEROS
    # index 0 problem?
    word_state_out.set_shape([max_char_len, batch_size, hidden_size])
    mean_state_out.set_shape([max_char_len, batch_size, hidden_size])
    logsig_state_out.set_shape([max_char_len, batch_size, hidden_size])
    word_state_out_p = permute_encoder_output(encoder_out=word_state_out,
                                              perm_mat=perm_mat_pl,
                                              batch_size=batch_size,
                                              max_word_len=max_lat_word_len)
    mean_state_out_p = permute_encoder_output(encoder_out=mean_state_out,
                                              perm_mat=perm_mat_pl,
                                              batch_size=batch_size,
                                              max_word_len=max_lat_word_len)
    logsig_state_out_p = permute_encoder_output(encoder_out=logsig_state_out,
                                                perm_mat=perm_mat_pl,
                                                batch_size=batch_size,
                                                max_word_len=max_lat_word_len)
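    # permute_encoder_output presumably gathers the encoder states at the
    # word-end positions listed in perm_mat, turning the per-character tensors
    # ([max_char_len, batch_size, hidden_size]) into per-word tensors
    # ([batch_size, max_lat_word_len, hidden_size]). A rough sketch, not the
    # actual helper:
    #   out = tf.transpose(encoder_out, [1, 0, 2])    # [batch, char, hidden]
    #   out = tf.gather(out, perm_mat, batch_dims=1)  # [batch, word, hidden]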
    #Initialize decoder
    ## Note to self: need to pass in the sentence-lengths vector, and check that all placeholders flow cleanly into the class and the TensorFlow graph

    out_o, global_latent_o, global_logsig_o, global_mu_o = decoder.run_decoder(
        word_sequence_length=sent_word_len_list_pl,
        train=True,
        reuse=None,
        units_lstm_decoder=decoder_dim,
        lat_words=word_state_out_p,
        units_dense_global=decoder_dim,
        char_sequence_length=tf.cast(sent_char_len_list_pl, dtype=tf.int32))
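    # out_o presumably carries the per-character reconstruction logits
    # ([batch_size, max_char_len, vocabulary_size]); global_mu_o / global_logsig_o
    # are the parameters of the sentence-level (global) latent and global_latent_o
    # a sample from it (names inferred from how they are fed to calc_cost below).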

    # shaping for batching
    #reshape problem

    onehot_words = np.reshape(
        onehot_words, newshape=[-1, batch_size, max_char_len, vocabulary_size])
    word_pos = np.reshape(word_pos, newshape=[-1, batch_size, max_char_len])
    # making word masks for kl term
    kl_mask = []
    print(sentence_lens_nwords)
    for word_len in np.reshape(lat_sent_len_list, -1):
        vec = np.zeros([max_lat_word_len], dtype=np.float32)
        vec[0:word_len] = np.ones(shape=word_len, dtype=np.float32)
        kl_mask.append(vec)
    kl_mask = np.asarray(kl_mask)
    kl_mask = np.reshape(kl_mask, newshape=[-1, batch_size, max_lat_word_len])
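    # e.g. with max_lat_word_len = 6 and word_len = 4 the mask row is
    # [1, 1, 1, 1, 0, 0], so the KL contribution of padded word slots can be
    # zeroed out inside calc_cost (assumed use of mask_kl).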
    sentence_lens_nwords = np.reshape(sentence_lens_nwords,
                                      newshape=[-1, batch_size])
    sentence_lens_nchars = np.reshape(sentence_lens_nchars,
                                      newshape=[-1, batch_size])
    lat_sent_len_list = np.reshape(lat_sent_len_list, [-1, batch_size])

    #shaping for validation set
    batch_size_val = batch_size
    n_valid = np.shape(onehot_words_val)[0]

    r = n_valid % batch_size_val
    n_valid_use = n_valid - r
    #might have to fix this before reporting results
    onehot_words_val = np.reshape(
        onehot_words_val[0:n_valid_use, ...],
        newshape=[-1, batch_size_val, max_char_len, vocabulary_size])
    word_pos_val = np.reshape(word_pos_val[0:n_valid_use, ...],
                              newshape=[-1, batch_size_val, max_char_len])
    #sentence_lens_nwords_val = np.reshape(sentence_lens_nwords_val[0:n_valid_use],newshape=[-1,batch_size_val])
    sentence_lens_nchars_val = np.reshape(
        sentence_lens_nchars_val[0:n_valid_use], newshape=[-1, batch_size_val])

    ###KL annealing parameters
    shift = 5000
    total_steps = np.round(np.true_divide(n_epochs, 16) *
                           np.shape(onehot_words)[0],
                           decimals=0)
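    # calc_cost presumably ramps the KL weight from 0 to 1 as a function of
    # global_step, e.g. something like
    #   anneal = clip((global_step - shift) / total_steps, 0, 1)
    # where `shift` delays the ramp and `total_steps` controls its length
    # (an assumed schedule; the exact form lives in the decoder).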

    ####
    cost, reconstruction, kl_p3, kl_p1, kl_global, kl_p2, anneal, _ = decoder.calc_cost(
        eow_mask=None,
        mask_kl=mask_kl_pl,
        kl=True,
        sentence_word_lens=sent_word_len_list_pl,
        shift=shift,
        total_steps=total_steps,
        global_step=global_step,
        global_latent_sample=global_latent_o,
        global_logsig=global_logsig_o,
        global_mu=global_mu_o,
        predictions=out_o,
        true_input=onehot_words_pl,
        posterior_logsig=logsig_state_out_p,
        posterior_mu=mean_state_out_p,
        post_samples=word_state_out_p,
        reuse=None)

    ######
    # Train Step

    # clipping gradients
    ######
    lr = 1e-4
    opt = tf.train.AdamOptimizer(lr)
    grads_t, vars_t = zip(*opt.compute_gradients(cost))
    clipped_grads_t, grad_norm_t = tf.clip_by_global_norm(grads_t,
                                                          clip_norm=5.0)
    train_step = opt.apply_gradients(zip(clipped_grads_t, vars_t),
                                     global_step=global_step)
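    # Global-norm clipping: if the L2 norm of all gradients exceeds 5.0 they are
    # rescaled by 5.0 / norm, otherwise left untouched (e.g. a global norm of 20
    # shrinks every gradient by a factor of 4).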
    regex = re.compile('[^a-zA-Z]')
    #sum_grad_hist = [tf.summary.histogram(name=regex.sub('',str(j)),values=i) for i,j in zip(clipped_grads_t,vars_t)]
    norm_grad = tf.summary.scalar(name='grad_norm', tensor=grad_norm_t)

    ######
    # validation ("testing") placeholders and graph
    sent_word_len_list_pl_val = tf.placeholder(name='word_lens_val',
                                               dtype=tf.int32,
                                               shape=[batch_size])
    perm_mat_pl_val = tf.placeholder(name='perm_mat_val',
                                     dtype=tf.int32,
                                     shape=[batch_size, max_lat_word_len])
    onehot_words_pl_val = tf.placeholder(
        name='onehot_words_val',
        dtype=tf.float32,
        shape=[batch_size, max_char_len, vocabulary_size])
    word_pos_pl_val = tf.placeholder(name='word_pos_val',
                                     dtype=tf.float32,
                                     shape=[batch_size, max_char_len])
    sent_char_len_list_pl_val = tf.placeholder(name='sent_char_len_list_val',
                                               dtype=tf.float32,
                                               shape=[batch_size])

    #testing graph
    word_state_out_val, mean_state_out_val, logsig_state_out_val = encoder_k.run_encoder(
        sentence_lens=sent_char_len_list_pl_val,
        train=False,
        inputs=onehot_words_pl_val,
        word_pos=word_pos_pl_val,
        reuse=True)
    perm_mat_val, _, lat_sent_len_list_val = prep_perm_matrix(
        batch_size=batch_size_val,
        word_pos_matrix=word_pos_val,
        max_char_len=max_char_len,
        max_word_len=max_lat_word_len)
    kl_mask_val = []
    for word_len in np.reshape(lat_sent_len_list_val, -1):
        vec = np.zeros([max_lat_word_len], dtype=np.float32)
        vec[0:word_len] = np.ones(shape=word_len, dtype=np.float32)
        kl_mask_val.append(vec)
    kl_mask_val = np.asarray(kl_mask_val)
    kl_mask_val = np.reshape(kl_mask_val,
                             newshape=[-1, batch_size, max_lat_word_len])

    lat_sent_len_list_val = np.reshape(np.reshape(lat_sent_len_list_val,
                                                  -1)[0:n_valid_use],
                                       newshape=[-1, batch_size_val])
    word_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    mean_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    logsig_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    word_state_out_p_val = permute_encoder_output(
        encoder_out=word_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_lat_word_len)
    mean_state_out_p_val = permute_encoder_output(
        encoder_out=mean_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_lat_word_len)
    logsig_state_out_p_val = permute_encoder_output(
        encoder_out=logsig_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_lat_word_len)
    out_o_val, global_latent_o_val, global_logsig_o_val, global_mu_o_val = decoder.run_decoder(
        word_sequence_length=sent_word_len_list_pl_val,
        train=False,
        reuse=True,
        units_lstm_decoder=decoder_dim,
        lat_words=mean_state_out_p_val,
        units_dense_global=decoder.global_lat_dim,
        char_sequence_length=tf.cast(sent_char_len_list_pl_val,
                                     dtype=tf.int32))
    #test cost
    test_cost = decoder.test_calc_cost(
        mask_kl=mask_kl_pl,
        sentence_word_lens=sent_word_len_list_pl_val,
        posterior_logsig=logsig_state_out_p_val,
        post_samples=word_state_out_p_val,
        global_mu=global_mu_o_val,
        global_logsig=global_logsig_o_val,
        global_latent_sample=global_latent_o_val,
        posterior_mu=mean_state_out_p_val,
        true_input=onehot_words_pl_val,
        predictions=out_o_val)

    ######

    ######
    #prior sampling
    samples = np.random.normal(size=[batch_size, decoder.global_lat_dim])
    gen_samples = decoder.generation(samples=samples)
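    # Unconditional generation: a fixed batch of standard-normal samples for the
    # global latent is passed to decoder.generation, which presumably decodes
    # them into character logits; decoded samples are logged every 1000 steps
    # in the training loop below.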

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    ### IW eval
    NLL, bpc = n_samples_IW(n_samples=10,
                            encoder=encoder_k,
                            decoder=decoder,
                            decoder_dim=decoder_dim,
                            sent_char_len_list_pl=sent_char_len_list_pl_val,
                            true_output=onehot_words_pl_val,
                            onehot_words_pl=onehot_words_pl_val,
                            word_pos_pl=word_pos_pl_val,
                            perm_mat_pl=perm_mat_pl_val,
                            batch_size=batch_size,
                            max_lat_word_len=max_lat_word_len,
                            sent_word_len_list_pl=sent_word_len_list_pl_val)
    sum_NLL = tf.summary.scalar(tensor=NLL, name='10sample_IWAE_LL')
    sum_bpc = tf.summary.scalar(tensor=bpc, name='10sample_IWAE_BPC')
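    # n_samples_IW presumably evaluates a K = 10 importance-weighted bound,
    #   log p(x) >= E[ log (1/K) * sum_k p(x, z_k) / q(z_k | x) ],
    # reported both as a negative log-likelihood (NLL) and as bits per
    # character (bpc).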
    ###

    ######
    #tensorboard stuff
    summary_inf_train = tf.summary.merge([
        norm_grad, decoder.kls_hist, decoder.global_kl_scalar,
        decoder.rec_scalar, decoder.cost_scalar, decoder.full_kl_scalar,
        decoder.sum_all_activ_hist, decoder.sum_global_activ_hist
    ])
    summary_inf_test = tf.summary.merge(
        [sum_NLL, sum_bpc, decoder.sum_rec_val, decoder.sum_kl_val])
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    ######
    log_file = log_dir + "vaelog.txt"
    logger = logging.getLogger('mVAE_log')
    hdlr = logging.FileHandler(log_file)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.DEBUG)
    for epoch in range(n_epochs):
        inds = np.arange(np.shape(onehot_words)[0])
        np.random.shuffle(inds)
        for count, batch in enumerate(inds):
            anneal_c_o, train_predictions_o_np, train_cost_o_np, _, global_step_o_np, train_rec_cost_o_np, _, _, _, _, summary_inf_train_o = sess.run(
                [
                    anneal, out_o, cost, train_step, global_step,
                    reconstruction, kl_p3, kl_p1, kl_global, kl_p2,
                    summary_inf_train
                ],
                feed_dict={
                    mask_kl_pl: kl_mask[batch],
                    onehot_words_pl: onehot_words[batch],
                    word_pos_pl: word_pos[batch],
                    perm_mat_pl: perm_mat[batch],
                    sent_word_len_list_pl: lat_sent_len_list[batch],
                    sent_char_len_list_pl: sentence_lens_nchars[batch]
                })
            #logger.debug('anneal const {}'.format(anneal_c))
            #logger.debug('ground truth {}'.format(get_output_sentences(index2token, ground_truth[0:10])))
            if global_step_o_np % 400 == 0:
                rind = np.random.randint(low=0,
                                         high=np.shape(onehot_words_val)[0])
                val_predictions_o_np, val_cost_o_np, summary_inf_test_o = sess.run(
                    [out_o_val, test_cost, summary_inf_test],
                    feed_dict={
                        mask_kl_pl: kl_mask_val[rind],
                        onehot_words_pl_val: onehot_words_val[rind],
                        word_pos_pl_val: word_pos_val[rind],
                        perm_mat_pl_val: perm_mat_val[rind],
                        sent_word_len_list_pl_val: lat_sent_len_list_val[rind],
                        sent_char_len_list_pl_val:
                        sentence_lens_nchars_val[rind]
                    })

                predictions = np.argmax(train_predictions_o_np[0:10], axis=-1)
                ground_truth = np.argmax(onehot_words[batch][0:10], axis=-1)
                val_predictions = np.argmax(val_predictions_o_np, axis=-1)
                true = np.argmax(onehot_words_val[rind], -1)
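                # character-level accuracy on this validation batch: count the
                # positions where the argmax prediction matches the ground
                # truth, restricted to each sentence's true character length,
                # and divide by the total number of characters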
                num = np.sum([
                    np.sum(val_predictions[j][0:i] == true[j][0:i])
                    for j, i in enumerate(sentence_lens_nchars_val[rind])
                ])

                denom = np.sum(sentence_lens_nchars_val[rind])
                accuracy = np.true_divide(num, denom) * 100
                logger.debug(
                    'accuracy on random val batch {}'.format(accuracy))
                logger.debug('predictions {}'.format(
                    [[index2token[j] for j in i]
                     for i in predictions[0:10, 0:50]]))
                logger.debug('ground truth {}'.format(
                    [[index2token[j] for j in i]
                     for i in ground_truth[0:10, 0:50]]))
                logger.debug(
                    'global step: {} Epoch: {} count: {} anneal:{}'.format(
                        global_step_o_np, epoch, count, anneal_c_o))
                logger.debug('train cost: {}'.format(train_cost_o_np))
                logger.debug('validation cost {}'.format(val_cost_o_np))
                logger.debug('validation predictions {}'.format(
                    [[index2token[j] for j in i]
                     for i in val_predictions[0:10, 0:50]]))
                summary_writer.add_summary(summary_inf_test_o,
                                           global_step_o_np)
                summary_writer.flush()
            if global_step_o_np % 1000 == 0:
                # testing on the generative model
                gen_o_np = sess.run(gen_samples)
                gen_pred = np.argmax(gen_o_np, axis=-1)
                logger.debug('GEN predictions {}'.format(
                    [[index2token[j] for j in i]
                     for i in gen_pred[0:10, 0:50]]))

            summary_writer.add_summary(summary_inf_train_o, global_step_o_np)
            summary_writer.flush()


# Example #2
def train(log_dir, n_epochs, network_dict, index2token, **kwargs):
    onehot_words = kwargs['onehot_words']
    word_pos = kwargs['word_pos']
    sentence_lens_nchars = kwargs['sentence_lens_nchars']
    sentence_lens_nwords = kwargs['sentence_lens_nwords']
    word_loc = kwargs['word_loc']
    vocabulary_size = kwargs['vocabulary_size']
    max_char_len = kwargs['max_char_len']
    onehot_words_val = kwargs['onehot_words_val']
    word_pos_val = kwargs['word_pos_val']
    sentence_lens_nchars_val = kwargs['sentence_lens_nchars_val']
    sentence_lens_nwords_val = kwargs['sentence_lens_nwords_val']
    word_loc_val = kwargs['word_loc_val']

    batch_size = kwargs['batch_size']
    input_size = vocabulary_size
    hidden_size = kwargs['hidden_size']
    decoder_dim = kwargs['decoder_dim']
    decoder_units_p3 = kwargs['decoder_units_p3']
    num_batches = len(onehot_words) // batch_size
    network_dict['input_size'] = input_size

    encoder_k = encoder.Encoder(**network_dict)

    # onehot_words,word_pos,vocabulary_size = encoder_k.run_preprocess()
    # prepping the permutation matrix for all instances separately
    perm_mat, max_word_len, sent_len_list = prep_perm_matrix(
        batch_size=batch_size,
        word_pos_matrix=word_pos,
        max_char_len=max_char_len)

    # placeholders
    sent_word_len_list_pl = tf.placeholder(name='word_lens',
                                           dtype=tf.int32,
                                           shape=[batch_size])
    perm_mat_pl = tf.placeholder(name='perm_mat_pl',
                                 dtype=tf.int32,
                                 shape=[batch_size, max_word_len])
    onehot_words_pl = tf.placeholder(
        name='onehot_words',
        dtype=tf.float32,
        shape=[batch_size, max_char_len, vocabulary_size])
    word_pos_pl = tf.placeholder(name='word_pos',
                                 dtype=tf.float32,
                                 shape=[batch_size, max_char_len])
    sent_char_len_list_pl = tf.placeholder(name='sent_char_len_list',
                                           dtype=tf.float32,
                                           shape=[batch_size])
    word_loc_pl = tf.placeholder(name='word_loc',
                                 dtype=tf.float32,
                                 shape=[batch_size])

    # decoder
    arg_dict = {
        'decoder_p3_units': decoder_units_p3,
        'encoder_dim': hidden_size,
        'lat_word_dim': hidden_size,
        'sentence_lens': None,
        'global_lat_dim': hidden_size,
        'batch_size': batch_size,
        'max_num_words': max_word_len,
        'decoder_units': decoder_dim,
        'num_sentence_characters': max_char_len,
        'dict_length': vocabulary_size
    }
    decoder = Decoder(**arg_dict)

    # step counter
    global_step = tf.Variable(0, name='global_step', trainable=False)

    word_state_out, mean_state_out, logsig_state_out = encoder_k.run_encoder(
        inputs=onehot_words_pl, word_pos=word_pos_pl, reuse=None)

    # picking out our words
    # why do these all start at 0?
    # replace 0's possibly with len+1
    ## RELYING ON THERE BEING NOTHING AT ZEROS
    word_state_out.set_shape([max_char_len, batch_size, hidden_size])
    mean_state_out.set_shape([max_char_len, batch_size, hidden_size])
    logsig_state_out.set_shape([max_char_len, batch_size, hidden_size])
    word_state_out_p = permute_encoder_output(encoder_out=word_state_out,
                                              perm_mat=perm_mat_pl,
                                              batch_size=batch_size,
                                              max_word_len=max_word_len)
    mean_state_out_p = permute_encoder_output(encoder_out=mean_state_out,
                                              perm_mat=perm_mat_pl,
                                              batch_size=batch_size,
                                              max_word_len=max_word_len)
    logsig_state_out_p = permute_encoder_output(encoder_out=logsig_state_out,
                                                perm_mat=perm_mat_pl,
                                                batch_size=batch_size,
                                                max_word_len=max_word_len)
    # Initialize decoder
    ## Note to self: need to pass in the sentence-lengths vector, and check that all placeholders flow cleanly into the class and the TensorFlow graph

    out_o, global_latent_o, global_logsig_o, global_mu_o = decoder.run_decoder(
        reuse=None,
        units_lstm_decoder=decoder_dim,
        lat_words=word_state_out_p,
        units_dense_global=decoder_dim,
        sequence_length=tf.cast(sent_char_len_list_pl, dtype=tf.int32))

    # shaping for batching
    onehot_words = np.reshape(
        onehot_words, newshape=[-1, batch_size, max_char_len, vocabulary_size])
    word_pos = np.reshape(word_pos, newshape=[-1, batch_size, max_char_len])
    sentence_lens_nwords = np.reshape(sentence_lens_nwords,
                                      newshape=[-1, batch_size])
    sentence_lens_nchars = np.reshape(sentence_lens_nchars,
                                      newshape=[-1, batch_size])
    word_loc = np.reshape(word_loc, newshape=[-1, batch_size])

    # shaping for validation set
    batch_size_val = batch_size
    n_valid = np.shape(onehot_words_val)[0]

    r = n_valid % batch_size_val
    n_valid_use = n_valid - r
    # might have to fix this before reporting results
    onehot_words_val = np.reshape(
        onehot_words_val[0:n_valid_use, ...],
        newshape=[-1, batch_size_val, max_char_len, vocabulary_size])
    word_pos_val = np.reshape(word_pos_val[0:n_valid_use, ...],
                              newshape=[-1, batch_size_val, max_char_len])
    sentence_lens_nwords_val = np.reshape(
        sentence_lens_nwords_val[0:n_valid_use], newshape=[-1, batch_size_val])
    sentence_lens_nchars_val = np.reshape(
        sentence_lens_nchars_val[0:n_valid_use], newshape=[-1, batch_size_val])
    word_loc_val = np.reshape(word_loc_val[0:n_valid_use],
                              newshape=[-1, batch_size_val])

    ###KL annealing parameters
    shift = 10000
    total_steps = np.round(np.true_divide(n_epochs, 5) *
                           np.shape(onehot_words)[0],
                           decimals=0)

    ####
    cost, reconstruction, kl_p3, kl_p1, kl_global, kl_p2, anneal = decoder.calc_cost(
        kl=True,
        sentence_word_lens=sent_word_len_list_pl,
        shift=shift,
        total_steps=total_steps,
        global_step=global_step,
        global_latent_sample=global_latent_o,
        global_logsig=global_logsig_o,
        global_mu=global_mu_o,
        predictions=out_o,
        true_input=onehot_words_pl,
        posterior_logsig=logsig_state_out_p,
        posterior_mu=mean_state_out_p,
        post_samples=word_state_out_p,
        reuse=None)

    ######
    # Train Step

    # clipping gradients
    ######
    lr = 1e-3
    opt = tf.train.AdamOptimizer(lr)
    grads_t, vars_t = zip(*opt.compute_gradients(cost))
    clipped_grads_t, grad_norm_t = tf.clip_by_global_norm(grads_t,
                                                          clip_norm=5.0)
    train_step = opt.apply_gradients(zip(clipped_grads_t, vars_t),
                                     global_step=global_step)
    ######
    # validation ("testing") placeholders and graph
    sent_word_len_list_pl_val = tf.placeholder(name='word_lens_val',
                                               dtype=tf.int32,
                                               shape=[batch_size])
    perm_mat_pl_val = tf.placeholder(name='perm_mat_val',
                                     dtype=tf.int32,
                                     shape=[batch_size, max_word_len])
    onehot_words_pl_val = tf.placeholder(
        name='onehot_words_val',
        dtype=tf.float32,
        shape=[batch_size, max_char_len, vocabulary_size])
    word_pos_pl_val = tf.placeholder(name='word_pos_val',
                                     dtype=tf.float32,
                                     shape=[batch_size, max_char_len])
    sent_char_len_list_pl_val = tf.placeholder(name='sent_char_len_list_val',
                                               dtype=tf.float32,
                                               shape=[batch_size])
    word_loc_pl_val = tf.placeholder(name='word_loc_val',
                                     dtype=tf.float32,
                                     shape=[batch_size])

    # testing graph
    word_state_out_val, mean_state_out_val, logsig_state_out_val = encoder_k.run_encoder(
        inputs=onehot_words_pl_val, word_pos=word_pos_pl_val, reuse=True)
    perm_mat_val, _, sent_len_list_val = prep_perm_matrix(
        batch_size=batch_size_val,
        word_pos_matrix=word_pos_val,
        max_char_len=max_char_len,
        max_word_len=max_word_len)

    word_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    mean_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    logsig_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
    word_state_out_p_val = permute_encoder_output(
        encoder_out=word_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_word_len)
    mean_state_out_p_val = permute_encoder_output(
        encoder_out=mean_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_word_len)
    logsig_state_out_p_val = permute_encoder_output(
        encoder_out=logsig_state_out_val,
        perm_mat=perm_mat_pl_val,
        batch_size=batch_size_val,
        max_word_len=max_word_len)
    out_o_val, global_latent_o_val, global_logsig_o_val, global_mu_o_val = decoder.run_decoder(
        reuse=True,
        units_lstm_decoder=decoder_dim,
        lat_words=word_state_out_p_val,
        units_dense_global=decoder.global_lat_dim,
        sequence_length=tf.cast(sent_char_len_list_pl_val, dtype=tf.int32))
    # test cost
    test_cost = decoder.test_calc_cost(
        sentence_word_lens=sent_word_len_list_pl_val,
        posterior_logsig=logsig_state_out_p_val,
        post_samples=word_state_out_p_val,
        global_mu=global_mu_o_val,
        global_logsig=global_logsig_o_val,
        global_latent_sample=global_latent_o_val,
        posterior_mu=mean_state_out_p_val,
        true_input=onehot_words_pl_val,
        predictions=out_o_val)

    ######

    ######
    # prior sampling
    samples = np.random.normal(size=[batch_size, decoder.global_lat_dim])
    gen_samples = decoder.generation(samples=samples)

    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())

    ######
    # tensorboard stuff
    summary_inf_train = tf.summary.merge([
        decoder.kls_hist, decoder.global_kl_scalar, decoder.rec_scalar,
        decoder.cost_scalar, decoder.full_kl_scalar,
        decoder.sum_all_activ_hist, decoder.sum_global_activ_hist
    ])
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
    ######
    log_file = log_dir + "vaelog.txt"
    logger = logging.getLogger('mVAE_log')
    hdlr = logging.FileHandler(log_file)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.DEBUG)
    for epoch in range(n_epochs):
        inds = np.arange(np.shape(onehot_words)[0])
        np.random.shuffle(inds)
        for count, batch in enumerate(inds):
            anneal_c_o, train_predictions_o_np, train_cost_o_np, _, global_step_o_np, train_rec_cost_o_np, _, _, _, _, summary_inf_train_o = sess.run(
                [
                    anneal, out_o, cost, train_step, global_step,
                    reconstruction, kl_p3, kl_p1, kl_global, kl_p2,
                    summary_inf_train
                ],
                feed_dict={
                    onehot_words_pl: onehot_words[batch],
                    word_pos_pl: word_pos[batch],
                    perm_mat_pl: perm_mat[batch],
                    sent_word_len_list_pl: sentence_lens_nwords[batch],
                    sent_char_len_list_pl: sentence_lens_nchars[batch],
                    word_loc_pl: word_loc[batch]
                })
            # logger.debug('anneal const {}'.format(anneal_c))
            # logger.debug('ground truth {}'.format(get_output_sentences(index2token, ground_truth[0:10])))
            if count % 1000 == 0:
                # testing on the validation set
                val_predictions_o_np, val_cost_o_np = sess.run(
                    [out_o_val, test_cost],
                    feed_dict={
                        onehot_words_pl_val: onehot_words_val[0],
                        word_pos_pl_val: word_pos_val[0],
                        perm_mat_pl_val: perm_mat_val[0],
                        sent_word_len_list_pl_val: sentence_lens_nwords_val[0],
                        sent_char_len_list_pl_val: sentence_lens_nchars_val[0],
                        word_loc_pl_val: word_loc_val[0]
                    })

                predictions = np.argmax(train_predictions_o_np[0:10], axis=-1)
                ground_truth = np.argmax(onehot_words[batch][0:10], axis=-1)

                logger.debug('predictions {}'.format(
                    [[index2token[j] for j in i]
                     for i in predictions[0:10, 0:50]]))
                logger.debug('ground truth {}'.format(
                    [[index2token[j] for j in i]
                     for i in ground_truth[0:10, 0:50]]))
                logger.debug(
                    'global step: {} Epoch: {} count: {} anneal:{}'.format(
                        global_step_o_np, epoch, count, anneal_c_o))
                logger.debug('train cost: {}'.format(train_cost_o_np))
                logger.debug('validation cost {}'.format(val_cost_o_np))
            if count % 10000 == 0:
                # testing on the generative model
                gen_o_np = sess.run(gen_samples)
                gen_pred = np.argmax(gen_o_np, axis=-1)
                logger.debug('GEN predictions {}'.format(
                    [[index2token[j] for j in i]
                     for i in gen_pred[0:10, 0:50]]))

        summary_writer.add_summary(summary_inf_train_o, global_step_o_np)
        summary_writer.flush()


# Example #3
def train(log_dir, n_epochs, network_dict, index2token, mode, **kwargs):
    onehot_words = kwargs['onehot_words']
    word_pos = kwargs['word_pos']
    sentence_lens_nchars = kwargs['sentence_lens_nchars']
    sentence_lens_nwords = kwargs['sentence_lens_nwords']
    vocabulary_size = kwargs['vocabulary_size']
    max_char_len = kwargs['max_char_len']
    onehot_words_val = kwargs['onehot_words_val']
    word_pos_val = kwargs['word_pos_val']
    sentence_lens_nchars_val = kwargs['sentence_lens_nchars_val']
    batch_size = kwargs['batch_size']
    input_size = vocabulary_size
    hidden_size = kwargs['hidden_size']
    decoder_dim = kwargs['decoder_dim']
    decoder_units_p3 = kwargs['decoder_units_p3']
    network_dict['input_size'] = input_size

    # prepping the permutation matrix for all instances separately
    perm_mat, max_lat_word_len, lat_sent_len_list = train_helper.prep_perm_matrix(batch_size=batch_size,
                                                                                  word_pos_matrix=word_pos,
                                                                                  max_char_len=max_char_len)
    onehot_words = np.reshape(onehot_words, newshape=[-1, batch_size, max_char_len, vocabulary_size])
    word_pos = np.reshape(word_pos, newshape=[-1, batch_size, max_char_len])
    sentence_lens_nchars = np.reshape(sentence_lens_nchars, newshape=[-1, batch_size])
    lat_sent_len_list = np.reshape(lat_sent_len_list, [-1, batch_size])

    # shaping for validation set
    batch_size_val = batch_size
    n_valid = np.shape(onehot_words_val)[0]
    r = n_valid % batch_size_val
    n_valid_use = n_valid - r

    onehot_words_val = np.reshape(onehot_words_val[0:n_valid_use, ...],
                                  newshape=[-1, batch_size_val, max_char_len, vocabulary_size])
    word_pos_val = np.reshape(word_pos_val[0:n_valid_use, ...], newshape=[-1, batch_size_val, max_char_len])
    sentence_lens_nchars_val = np.reshape(sentence_lens_nchars_val[0:n_valid_use], newshape=[-1, batch_size_val])

    perm_mat_val, _, lat_sent_len_list_val = train_helper.prep_perm_matrix(batch_size=batch_size_val,
                                                                           word_pos_matrix=word_pos_val,
                                                                           max_char_len=max_char_len,
                                                                           max_word_len=max_lat_word_len)
    lat_sent_len_list_val = np.reshape(np.reshape(lat_sent_len_list_val, -1)[0:n_valid_use],
                                       newshape=[-1, batch_size_val])

    # kl_mask
    kl_mask, kl_mask_val = train_helper.kl_mask_prep(lat_sent_len_list, lat_sent_len_list_val, max_lat_word_len,
                                                     batch_size)
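    # kl_mask_prep presumably builds the same binary masks as the explicit loops
    # in the first example: ones for real word slots, zeros for padding,
    # reshaped to [num_batches, batch_size, max_lat_word_len].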

    # logging
    log_file = log_dir + "vaelog.txt"
    logger = logging.getLogger('mVAE_log')
    hdlr = logging.FileHandler(log_file)
    formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
    hdlr.setFormatter(formatter)
    logger.addHandler(hdlr)
    logger.setLevel(logging.DEBUG)

    # gaussian samples seed
    global_sample = np.random.normal(size=[batch_size, decoder_dim])
    word_samples = np.random.normal(size=[batch_size, max_lat_word_len, decoder_dim])

    if mode == "new_model":
        # placeholders
        mask_kl_pl = tf.placeholder(name='mask_kl_pl', dtype=tf.float32, shape=[batch_size, max_lat_word_len])
        sent_word_len_list_pl = tf.placeholder(name='sent_word_len_list_pl', dtype=tf.int32, shape=[batch_size])
        perm_mat_pl = tf.placeholder(name='perm_mat_pl', dtype=tf.int32, shape=[batch_size, max_lat_word_len])
        onehot_words_pl = tf.placeholder(name='onehot_words_pl', dtype=tf.float32,
                                         shape=[batch_size, max_char_len, vocabulary_size])
        word_pos_pl = tf.placeholder(name='word_pos_pl', dtype=tf.float32, shape=[batch_size, max_char_len])
        sent_char_len_list_pl = tf.placeholder(name='sent_char_len_list_pl', dtype=tf.float32, shape=[batch_size])

        mask_kl_pl_val = tf.placeholder(name='mask_kl_pl_val', dtype=tf.float32, shape=[batch_size, max_lat_word_len])
        sent_word_len_list_pl_val = tf.placeholder(name='sent_word_len_list_pl_val', dtype=tf.int32,
                                                   shape=[batch_size])
        perm_mat_pl_val = tf.placeholder(name='perm_mat_pl_val', dtype=tf.int32, shape=[batch_size, max_lat_word_len])
        onehot_words_pl_val = tf.placeholder(name='onehot_words_pl_val', dtype=tf.float32,
                                             shape=[batch_size, max_char_len, vocabulary_size])
        word_pos_pl_val = tf.placeholder(name='word_pos_pl_val', dtype=tf.float32, shape=[batch_size, max_char_len])
        sent_char_len_list_pl_val = tf.placeholder(name='sent_char_len_list_pl_val', dtype=tf.float32,
                                                   shape=[batch_size])

        # step counter
        global_step = tf.Variable(0, name='global_step', trainable=False)

        encoder_k = encoder.Encoder(**network_dict)
        word_state_out, mean_state_out, logsig_state_out = encoder_k.run_encoder(sentence_lens=sent_char_len_list_pl,
                                                                                 train=True, inputs=onehot_words_pl,
                                                                                 word_pos=word_pos_pl, reuse=None)

        word_state_out.set_shape([max_char_len, batch_size, hidden_size])
        mean_state_out.set_shape([max_char_len, batch_size, hidden_size])
        logsig_state_out.set_shape([max_char_len, batch_size, hidden_size])
        word_state_out_p = train_helper.permute_encoder_output(encoder_out=word_state_out, perm_mat=perm_mat_pl,
                                                               batch_size=batch_size, max_word_len=max_lat_word_len)
        mean_state_out_p = train_helper.permute_encoder_output(encoder_out=mean_state_out, perm_mat=perm_mat_pl,
                                                               batch_size=batch_size, max_word_len=max_lat_word_len)
        logsig_state_out_p = train_helper.permute_encoder_output(encoder_out=logsig_state_out, perm_mat=perm_mat_pl,
                                                                 batch_size=batch_size, max_word_len=max_lat_word_len)

        # Initialize decoder
        arg_dict = {'decoder_p3_units': decoder_units_p3, 'encoder_dim': hidden_size, 'lat_word_dim': hidden_size,
                    'sentence_lens': None, 'global_lat_dim': hidden_size, 'batch_size': batch_size,
                    'max_num_lat_words': max_lat_word_len, 'decoder_units': decoder_dim,
                    'num_sentence_characters': max_char_len, 'dict_length': vocabulary_size}
        decoder = Decoder(**arg_dict)

        train_logits, global_latent_o, global_logsig_o, global_mu_o = decoder.run_decoder(
            word_sequence_length=sent_word_len_list_pl, train=True, reuse=None, units_lstm_decoder=decoder_dim,
            lat_words=word_state_out_p, units_dense_global=decoder_dim,
            char_sequence_length=tf.cast(sent_char_len_list_pl, dtype=tf.int32))
        train_logits = tf.identity(train_logits, name="train_logits")
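        # tf.identity presumably just pins a stable name on the logits so the
        # tensor can be fetched by name later (e.g.
        # graph.get_tensor_by_name("train_logits:0")) when the saved graph is
        # restored under a different `mode`.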

        # KL annealing parameters
        shift = 5000
        total_steps = np.round(np.true_divide(n_epochs, 16) * np.shape(onehot_words)[0], decimals=0)

        # calculate cost
        train_cost, reconstruction, kl_p3, kl_p1, kl_global, kl_p2, anneal_value, _ = decoder.calc_cost(eow_mask=None,
                                                                                                        mask_kl=mask_kl_pl,
                                                                                                        kl=True,
                                                                                                        sentence_word_lens=sent_word_len_list_pl,
                                                                                                        shift=shift,
                                                                                                        total_steps=total_steps,
                                                                                                        global_step=global_step,
                                                                                                        global_latent_sample=global_latent_o,
                                                                                                        global_logsig=global_logsig_o,
                                                                                                        global_mu=global_mu_o,
                                                                                                        predictions=train_logits,
                                                                                                        true_input=onehot_words_pl,
                                                                                                        posterior_logsig=logsig_state_out_p,
                                                                                                        posterior_mu=mean_state_out_p,
                                                                                                        post_samples=word_state_out_p,
                                                                                                        reuse=None)

        # clipping gradients
        lr = 1e-4
        opt = tf.train.AdamOptimizer(lr)
        grads_t, vars_t = zip(*opt.compute_gradients(train_cost))
        clipped_grads_t, grad_norm_t = tf.clip_by_global_norm(grads_t, clip_norm=5.0)
        train_step = opt.apply_gradients(zip(clipped_grads_t, vars_t), global_step=global_step, name="train_step")
        # regex = re.compile('[^a-zA-Z]')
        # sum_grad_hist = [tf.summary.histogram(name=regex.sub('',str(j)),values=i) for i,j in zip(clipped_grads_t,vars_t) if i is not None]
        # norm_grad = tf.summary.scalar(name='grad_norm',tensor=grad_norm_t)

        # testing graph
        word_state_out_val, mean_state_out_val, logsig_state_out_val = encoder_k.run_encoder(
            sentence_lens=sent_char_len_list_pl_val, train=False, inputs=onehot_words_pl_val, word_pos=word_pos_pl_val,
            reuse=True)
        word_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
        mean_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
        logsig_state_out_val.set_shape([max_char_len, batch_size_val, hidden_size])
        word_state_out_p_val = train_helper.permute_encoder_output(encoder_out=word_state_out_val,
                                                                   perm_mat=perm_mat_pl_val, batch_size=batch_size_val,
                                                                   max_word_len=max_lat_word_len)
        mean_state_out_p_val = train_helper.permute_encoder_output(encoder_out=mean_state_out_val,
                                                                   perm_mat=perm_mat_pl_val, batch_size=batch_size_val,
                                                                   max_word_len=max_lat_word_len)
        logsig_state_out_p_val = train_helper.permute_encoder_output(encoder_out=logsig_state_out_val,
                                                                     perm_mat=perm_mat_pl_val,
                                                                     batch_size=batch_size_val,
                                                                     max_word_len=max_lat_word_len)

        # decode test
        test_logits, global_latent_o_val, global_logsig_o_val, global_mu_o_val = decoder.run_decoder(
            word_sequence_length=sent_word_len_list_pl_val, train=False, reuse=True, units_lstm_decoder=decoder_dim,
            lat_words=mean_state_out_p_val, units_dense_global=decoder.global_lat_dim,
            char_sequence_length=tf.cast(sent_char_len_list_pl_val, dtype=tf.int32))
        test_logits = tf.identity(test_logits, name="test_logits")

        # test cost
        test_cost = decoder.test_calc_cost(mask_kl=mask_kl_pl_val, sentence_word_lens=sent_word_len_list_pl_val,
                                           posterior_logsig=logsig_state_out_p_val, post_samples=word_state_out_p_val,
                                           global_mu=global_mu_o_val, global_logsig=global_logsig_o_val,
                                           global_latent_sample=global_latent_o_val, posterior_mu=mean_state_out_p_val,
                                           true_input=onehot_words_pl_val, predictions=test_logits)