def batch_run(sess, model, fetch, src, tgt=None, batch=None):
    """Evaluate `fetch` over `src` slice by slice, yielding each `sess.run` result.

    Slices come from `partition(len(src), size, discard=False)`. With
    `discard=False` the trailing partial slice is presumably kept; confirm
    this against `partition`.

    Args:
        sess: session used to run `fetch`.
        model: object exposing the `src_` (and, if `tgt` is given, `tgt_`)
            feed keys.
        fetch: whatever `sess.run` should evaluate for each slice.
        src: sliceable inputs fed to `model.src_`.
        tgt: optional sliceable targets fed to `model.tgt_` in lockstep
            with `src`.
        batch: slice size. When None, all of `src` is used as one slice.

    Yields:
        The return value of `sess.run(fetch, ...)` for each slice, in order.
    """
    size = len(src) if batch is None else batch
    for lo, hi in partition(len(src), size, discard=False):
        feed_dict = {model.src_: src[lo:hi]}
        if tgt is not None:
            feed_dict[model.tgt_] = tgt[lo:hi]
        yield sess.run(fetch, feed_dict)
def summ_discr(step):
    """Log the discriminator losses on the test set for training step `step`.

    `model.d_loss` is evaluated on `x_test`/`y_test` in `batch_size` slices.
    Each fetched loss is then averaged with `np.mean` across the per-batch
    results, so every batch counts equally. When there is more than one
    discriminator, each writer `d_wrtr[i]` receives `summary_discr[i]`,
    evaluated with all averaged losses fed in, and is then flushed.
    Nothing is written when `n_dis <= 1`.
    """
    fetches = model.d_loss
    per_batch = [
        sess.run(fetches, {model['x']: x_test[lo:hi], model['y']: y_test[lo:hi]})
        for lo, hi in partition(len(x_test), batch_size, discard=False)
    ]
    averaged = [np.mean(column) for column in zip(*per_batch)]
    if n_dis <= 1:
        return
    # Every discriminator's writer logs under its own run, so the losses
    # overlay in one plot.
    loss_feed = dict(zip(fetches, averaged))
    for idx in range(n_dis):
        d_wrtr[idx].add_summary(sess.run(summary_discr[idx], loss_feed), step)
        d_wrtr[idx].flush()
def log(step,
        wtr=tf.summary.FileWriter("../log/{}".format(trial)),
        log=tf.summary.merge((tf.summary.scalar('step_xid', dummy[0]),
                              tf.summary.scalar('step_err', dummy[1]))),
        fet=(valid.xid, valid.err),
        inp=(valid.src_idx, valid.len_src, valid.tgt_idx, valid.len_tgt),
        src=src_valid,
        tgt=tgt_valid,
        bat=256):
    """Write validation statistics for training step `step`.

    The default arguments are evaluated once, when this function is
    defined. The FileWriter and the merged summary op are therefore built
    a single time and shared by every call that relies on the defaults.

    Each `bat`-sized slice of (`src`, `tgt`) is converted by `feed` and
    zipped with the placeholders in `inp`, and `fet` is run on it. For
    each fetched quantity, the per-batch arrays are concatenated and
    averaged. The averages are fed into `dummy` to evaluate `log`, and the
    result is written to `wtr` at `step`.
    """
    def run_batch(lo, hi):
        # Map every input placeholder to the matching array produced by `feed`.
        return sess.run(fet, dict(zip(inp, feed(src[lo:hi], tgt[lo:hi]))))

    per_batch = [run_batch(lo, hi) for lo, hi in partition(len(tgt), bat)]
    # Concatenate before averaging so that every example is weighted equally.
    averages = [np.mean(np.concatenate(column)) for column in zip(*per_batch)]
    wtr.add_summary(sess.run(log, dict(zip(dummy, averages))), step)
    wtr.flush()
def summ(step, model=model_valid):
    """Write validation error and loss summaries for training step `step`.

    The per-sample tensors (`errt_samp`, `loss_gen_samp`, `loss_kld_samp`)
    are evaluated over `valid` in `T.batch_valid` slices, with the same
    slice fed as both `model.src` and `model.tgt`. Each quantity's batches
    are combined with `comp(np.mean, np.concatenate)`. The results are fed
    into (`errt`, `loss_gen`, `loss_kld`) to evaluate `summary`, which is
    then written to `wtr` and flushed.
    """
    sample_fetches = (model.errt_samp, model.loss_gen_samp, model.loss_kld_samp)
    per_batch = [
        sess.run(sample_fetches, {model.src: valid[lo:hi], model.tgt: valid[lo:hi]})
        for lo, hi in partition(len(valid), T.batch_valid, discard=False)
    ]
    reduce_column = comp(np.mean, np.concatenate)
    aggregated = map(reduce_column, zip(*per_batch))
    summary_feed = dict(zip((model.errt, model.loss_gen, model.loss_kld), aggregated))
    wtr.add_summary(sess.run(summary, summary_feed), step)
    wtr.flush()
def summ(step):
    """Write test-set scalar and image summaries for training step `step`.

    `g_loss`, `lam`, `d_loss_mean` and `auc_gx` are evaluated on
    `x_test`/`y_test` in `batch_size` slices and averaged across batches
    with `np.mean`. The averages are fed back into the same tensors to
    evaluate `summary_test`. `summary_images` is then evaluated either on a
    fixed set of eight hand-picked frames (ucsd1 only) or on all of
    `x_test`.
    """
    fetches = model.g_loss, model.lam, model.d_loss_mean, model.auc_gx
    per_batch = [
        sess.run(fetches, {model['x']: x_test[lo:hi], model['y']: y_test[lo:hi]})
        for lo, hi in partition(len(x_test), batch_size, discard=False)
    ]
    averaged = [np.mean(column) for column in zip(*per_batch)]
    wrtr.add_summary(sess.run(summary_test, dict(zip(fetches, averaged))), step)
    if dataset == "ucsd1":
        # bike, skateboard, grasswalk, shopping cart, car, normal, normal, grass
        # (a list, not a tuple: this must stay a fancy index along axis 0)
        shown = x_test[[990, 1851, 2140, 2500, 2780, 2880, 3380, 3580]]
    else:
        shown = x_test
    wrtr.add_summary(sess.run(summary_images, {model.x: shown}), step)
    wrtr.flush()
def translate(src, mode):
    """Lazily translate `src` in slices of 256, yielding trimmed output strings.

    Each slice is indexed with `cws` and decoded with `infer` in the given
    `mode`. Only the predicted indices (the second element of `infer`'s
    result) are turned into strings via `trim_str` and `cwt`.
    """
    for lo, hi in partition(len(src), 256):
        idx, lengths = cws(src[lo:hi], ret_img=False, ret_idx=True)
        _, predicted_idx = infer(mode, m, sess, cwt, idx, lengths)
        yield from trim_str(predicted_idx, cwt)
# Load the model model = vAe('infer') # Restore the session sess = tf.InteractiveSession() tf.train.Saver().restore(sess, path_ckpt) ################################ # deterministic representation # ################################ # encode text with sentence piece model data = list(map(partial(sp.encode_capped, vocab), text)) data = vpack(data, (len(data), max(map(len, data))), vocab.eos_id(), np.int32) # calculate z for the test data in batches inpt = [model.z.eval({model.src: data[i:j]}) for i, j in partition(len(data), 128)] inpt = np.concatenate(inpt, axis=0) np.save(path_emb, inpt) ####################################################### # averaged representation with sentencepiece sampling # ####################################################### def infer_avg(sent, samples=128): bat = [sp.encode_capped_sample(vocab, sent) for _ in range(samples)] bat = vpack(bat, (len(bat), max(map(len, bat))), vocab.eos_id(), np.int32) z = model.z.eval({model.src: bat}) return np.mean(z, axis=0) from tqdm import tqdm
def translate(src):
    """Lazily translate `src` in slices of 256, yielding trimmed output strings.

    Each slice is converted to index/length arrays with `cws`. These are
    decoded by `infer`, and the result is turned into strings via
    `trim_str` and `cwt`.
    """
    for lo, hi in partition(len(src), 256):
        batch_idx, batch_len = cws(src[lo:hi], ret_img=False, ret_idx=True)
        predicted = infer(batch_idx, batch_len)
        yield from trim_str(predicted, cwt)