def selective_generate(model, data_feed, config, selected_clusters):
    """Generate responses and keep only those whose latent code was selected.

    The model is run in ``GEN`` mode over one full, unshuffled pass of
    ``data_feed``. For every example, the latent ids are joined with ``'-'``
    into a code string. Only examples whose code appears among
    ``selected_clusters`` are de-tokenized and collected.

    Args:
        model: called as ``model(batch, mode=GEN, gen_type=config.gen_type)``;
            must return ``(outputs, labels)``. ``outputs`` is indexed by
            ``DecoderRNN.KEY_SEQUENCE`` and ``DecoderRNN.KEY_LATENT``.
        data_feed: provides ``epoch_init``, ``next_batch`` and ``num_batch``.
        config: read for ``gen_type``, ``gmm`` and ``latent_size``.
        selected_clusters: iterable of dicts, each with a ``'code'`` key.

    Returns:
        list of dicts with keys ``'context'``, ``'target'``, ``'predict'``
        and ``'code'``, one per example whose code was selected.
    """
    model.eval()
    de_tknize = utils.get_dekenize()
    data_feed.epoch_init(config, shuffle=False, verbose=False)

    # get all code
    codes = {d['code'] for d in selected_clusters}
    logger.info("Generation: {} batches".format(data_feed.num_batch))

    data = []
    total_cnt = 0.0
    in_cnt = 0.0
    while True:
        batch = data_feed.next_batch()
        if batch is None:
            break
        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)

        # move from GPU to CPU; after squeeze/swapaxes rows are indexed by b_id
        pred_labels = [t.cpu().data.numpy()
                       for t in outputs[DecoderRNN.KEY_SEQUENCE]]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)
        true_labels = labels.cpu().data.numpy()
        y_ids = outputs[DecoderRNN.KEY_LATENT].cpu().data.numpy()
        if not config.gmm:
            # one row of latent ids per example
            y_ids = y_ids.reshape(-1, config.latent_size)

        ctx = batch.get('contexts')
        ctx_size = ctx.shape[1]
        for b_id in range(pred_labels.shape[0]):
            code = '-'.join(map(str, y_ids[b_id]))
            total_cnt += 1
            if code not in codes:
                continue

            pred_str, _ = engine.get_sent(model, de_tknize, pred_labels, b_id,
                                          attn=None)
            ctx_str = []
            for i in range(ctx_size):
                # [1:] drops the first token of each context turn
                temp, _ = engine.get_sent(model, de_tknize, ctx[:, i, 1:], b_id)
                ctx_str.append(temp)
            ctx_str = '<t>'.join(ctx_str)
            true_str, _ = engine.get_sent(model, de_tknize, true_labels, b_id)
            in_cnt += 1
            data.append({
                'context': ctx_str,
                'target': true_str,
                'predict': pred_str,
                'code': code,
            })

    # Guard against an empty feed, which previously raised ZeroDivisionError.
    in_rate = in_cnt / total_cnt if total_cnt else 0.0
    logger.info("In rate {}".format(in_rate))
    return data
def dump_latent(model, data_feed, config, dest_f, num_batch=1):
    """Pickle latent variables and decoded targets from teacher-forced passes.

    Each processed batch contributes one ``(log_qy, dec_init, y_ids)`` tuple
    of numpy arrays to the ``'z'`` list. Each example in the batch contributes
    its de-tokenized target to ``'labels'`` and its ``batch.metas`` entry to
    ``'metas'``. The resulting dict is written to ``dest_f`` via ``pickle.dump``.

    Args:
        model: called as ``model(batch, mode=TEACH_FORCE, return_latent=True)``;
            its result must expose ``log_qy``, ``y_ids`` and ``dec_init_state``.
        data_feed: provides ``epoch_init``, ``next_batch``, ``num_batch``, ``ptr``.
        config: passed through to ``data_feed.epoch_init``.
        dest_f: file object handed to ``pickle.dump``.
        num_batch: stop once ``data_feed.ptr`` exceeds this; ``None`` = all.
    """
    model.eval()
    de_tknize = utils.get_dekenize()
    data_feed.epoch_init(config, verbose=False, shuffle=False)
    planned = num_batch if num_batch is not None else data_feed.num_batch
    logger.info("Dumping: {} batches".format(planned))

    latents = []
    decoded = []
    meta_rows = []
    while True:
        batch = data_feed.next_batch()
        if batch is None:
            break
        if num_batch is not None and data_feed.ptr > num_batch:
            break

        results = model(batch, mode=TEACH_FORCE, return_latent=True)
        targets = batch.outputs
        batch_metas = batch.metas
        log_qy = results.log_qy.cpu().squeeze(0).data
        y_ids = results.y_ids.cpu().data
        dec_init = results.dec_init_state.cpu().squeeze().data

        for row in range(targets.shape[0]):
            sentence, _ = engine.get_sent(model, de_tknize, targets, row)
            decoded.append(sentence)
            meta_rows.append(batch_metas[row])

        # one latent tuple per batch (not per example)
        latents.append((log_qy.numpy(), dec_init.numpy(), y_ids.numpy()))

    payload = {'z': latents, 'labels': decoded, "metas": meta_rows}
    pickle.dump(payload, dest_f)
    logger.info("Dumping Done")
def generate(model, data_feed, config, evaluator, num_batch=1, dest_f=None):
    """Decode the previous and next utterance around each response and score them.

    The model is run in ``GEN`` mode and must return a pair of decoder outputs
    ``(prev_outputs, next_outputs)`` together with a pair of label tensors
    ``(prev_labels, next_labels)``. For every example, the response
    (``batch['outputs']``), both targets and both predictions are written out.
    Each (target, prediction) pair is also added to ``evaluator``.

    NOTE(review): a second ``def generate`` with the same signature appears
    later in this module. It rebinds the name when the module is imported, so
    callers of ``generate`` get that later version, not this one.

    Args:
        model: called as ``model(batch, mode=GEN, gen_type=config.gen_type)``.
        data_feed: provides ``epoch_init``, ``next_batch``, ``num_batch``, ``ptr``.
        config: read for ``batch_size`` and ``gen_type``. When ``num_batch`` is
            given, ``batch_size`` is temporarily set to 5 for ``epoch_init``;
            it is always restored afterwards.
        evaluator: re-initialised, fed every example, then asked for a report.
        num_batch: stop once ``data_feed.ptr`` exceeds this; ``None`` = all.
        dest_f: writable text stream. When ``None``, output goes to ``logger``.
    """
    model.eval()
    de_tknize = utils.get_dekenize()

    old_batch_size = config.batch_size
    if num_batch is not None:
        config.batch_size = 5
    try:
        data_feed.epoch_init(config, shuffle=False, verbose=False)
    finally:
        # Restore even if epoch_init raises, so the caller's config is untouched.
        config.batch_size = old_batch_size

    evaluator.initialize()
    logger.info("Generation: {} batches".format(
        data_feed.num_batch if num_batch is None else num_batch))

    def write(msg):
        if dest_f is None:
            logger.info(msg)
        else:
            dest_f.write(msg + '\n')

    while True:
        batch = data_feed.next_batch()
        if batch is None or (num_batch is not None
                             and data_feed.ptr > num_batch):
            break
        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)
        prev_outputs, next_outputs = outputs
        prev_labels, next_labels = labels
        cur_labels = batch.get('outputs')

        # move from GPU to CPU; after squeeze/swapaxes rows are indexed by b_id
        prev_labels = prev_labels.cpu().data.numpy()
        next_labels = next_labels.cpu().data.numpy()
        prev_pred = [t.cpu().data.numpy()
                     for t in prev_outputs[DecoderRNN.KEY_SEQUENCE]]
        prev_pred = np.array(prev_pred, dtype=int).squeeze(-1).swapaxes(0, 1)
        next_pred = [t.cpu().data.numpy()
                     for t in next_outputs[DecoderRNN.KEY_SEQUENCE]]
        next_pred = np.array(next_pred, dtype=int).squeeze(-1).swapaxes(0, 1)

        for b_id in range(cur_labels.shape[0]):
            ctx_str, _ = engine.get_sent(model, de_tknize, cur_labels, b_id)
            prev_true_str, _ = engine.get_sent(model, de_tknize, prev_labels, b_id)
            next_true_str, _ = engine.get_sent(model, de_tknize, next_labels, b_id)
            pred_prev_str, _ = engine.get_sent(model, de_tknize, prev_pred, b_id)
            pred_next_str, _ = engine.get_sent(model, de_tknize, next_pred, b_id)

            evaluator.add_example(prev_true_str, pred_prev_str)
            evaluator.add_example(next_true_str, pred_next_str)

            write("Response: {}".format(ctx_str))
            write("Prev Target: {}".format(prev_true_str))
            write("Prev Predict: {}".format(pred_prev_str))
            write("Next Target: {}".format(next_true_str))
            write("Next Predict: {}\n".format(pred_next_str))

    # Use the module logger like the rest of this file (was root `logging`).
    report = evaluator.get_report(include_error=dest_f is not None)
    if dest_f is None:
        logger.info(report)
    else:
        dest_f.write(report)
    logger.info("Generation Done")
def gen_with_source(model, data_feed, config, num_batch=1, dest_f=None):
    """Show greedy-decoded and sampled outputs next to each source sentence.

    Each batch is decoded twice, with ``gen_type="sample"`` and
    ``gen_type="greedy"``, both using ``sample_n=5``. For every example, the
    source and target are written first. Then ``sample_n`` greedy outputs
    ("Sample Z") and ``sample_n`` sampled outputs ("Sample W") are written,
    each prefixed by its latent ids.

    Args:
        model: called as ``model(batch, mode=GEN, gen_type=..., sample_n=...)``.
        data_feed: provides ``epoch_init``, ``next_batch``, ``num_batch``, ``ptr``.
        config: when ``num_batch`` is given, ``batch_size`` is set to 3 for the
            whole run. It is restored on exit, including when an error is raised.
        num_batch: stop once ``data_feed.ptr`` exceeds this; ``None`` = all.
        dest_f: writable text stream. When ``None``, output goes to ``logger``.
    """
    model.eval()
    de_tknize = utils.get_dekenize()
    sample_n = 5

    def write(msg):
        if dest_f is None:
            logger.info(msg)
        else:
            dest_f.write(msg + '\n')

    old_batch_size = config.batch_size
    if num_batch is not None:
        config.batch_size = 3
    try:
        data_feed.epoch_init(config, shuffle=False, verbose=False)
        logger.info("Generation: {} batches".format(
            data_feed.num_batch if num_batch is None else num_batch))

        while True:
            batch = data_feed.next_batch()
            if batch is None or (num_batch is not None
                                 and data_feed.ptr > num_batch):
                break
            sample_outputs, _ = model(batch, mode=GEN, gen_type="sample",
                                      sample_n=sample_n)
            greedy_outputs, labels = model(batch, mode=GEN, gen_type="greedy",
                                           sample_n=sample_n)

            # move from GPU to CPU; after squeeze/swapaxes rows are indexed by row id
            sample_labels = [t.cpu().data.numpy()
                             for t in sample_outputs[DecoderRNN.KEY_SEQUENCE]]
            greedy_labels = [t.cpu().data.numpy()
                             for t in greedy_outputs[DecoderRNN.KEY_SEQUENCE]]
            greedy_y_ids = greedy_outputs[DecoderRNN.KEY_LATENT].cpu().data.numpy()
            sample_y_ids = sample_outputs[DecoderRNN.KEY_LATENT].cpu().data.numpy()
            sample_labels = np.array(sample_labels, dtype=int).squeeze(-1).swapaxes(0, 1)
            greedy_labels = np.array(greedy_labels, dtype=int).squeeze(-1).swapaxes(0, 1)
            true_labels = labels.cpu().data.numpy()

            # The row for (example b_id, draw n_id) is b_id + stride * n_id, so
            # outputs are laid out draw-major. The stride must be this batch's
            # real size: config.batch_size is wrong for a shorter final batch.
            stride = true_labels.shape[0]
            for b_id in range(stride):
                ctx_str, _ = engine.get_sent(model, de_tknize,
                                             batch.get('source'), b_id)
                true_str, _ = engine.get_sent(model, de_tknize, true_labels, b_id)
                write("Source: {}".format(ctx_str))
                write("Target: {}".format(true_str))
                for n_id in range(sample_n):
                    row = b_id + stride * n_id
                    pred_str, _ = engine.get_sent(model, de_tknize,
                                                  greedy_labels, row)
                    code = " ".join(map(str, greedy_y_ids[row]))
                    write("Sample Z ({}): {}".format(code, pred_str))
                for n_id in range(sample_n):
                    row = b_id + stride * n_id
                    pred_str, _ = engine.get_sent(model, de_tknize,
                                                  sample_labels, row)
                    code = " ".join(map(str, sample_y_ids[row]))
                    write("Sample W ({}): {}".format(code, pred_str))
                write('\n')
    finally:
        config.batch_size = old_batch_size

    logger.info("Generation Done\n")
def generate(model, data_feed, config, evaluator, num_batch=1, dest_f=None):
    """Decode responses for a data feed, write them out, and score them.

    The model is run in ``GEN`` mode on each batch. Every example's context,
    target and prediction are logged, or written to ``dest_f``, and each
    (target, prediction) pair is added to ``evaluator``. The evaluator's
    report is emitted at the end.

    Args:
        model: called as ``model(batch, mode=GEN, gen_type=config.gen_type)``.
        data_feed: provides ``epoch_init``, ``next_batch``, ``num_batch``, ``ptr``.
        config: read for ``batch_size``, ``gen_type`` and ``use_attn``. When
            ``num_batch`` is given, ``batch_size`` is temporarily set to 5 for
            ``epoch_init``; it is always restored afterwards.
        evaluator: re-initialised, fed every example, then asked for a report.
        num_batch: stop once ``data_feed.ptr`` exceeds this; ``None`` = all.
        dest_f: writable text stream. When ``None``, output goes to ``logger``.

    Raises:
        ValueError: if a batch has neither a ``"contexts"`` nor a ``"source"`` entry.
    """
    model.eval()
    de_tknize = utils.get_dekenize()

    old_batch_size = config.batch_size
    if num_batch is not None:
        config.batch_size = 5
    try:
        data_feed.epoch_init(config, shuffle=False, verbose=False)
    finally:
        # Restore even if epoch_init raises, so the caller's config is untouched.
        config.batch_size = old_batch_size

    evaluator.initialize()
    logger.info("Generation: {} batches".format(
        data_feed.num_batch if num_batch is None else num_batch))

    while True:
        batch = data_feed.next_batch()
        if batch is None or (num_batch is not None
                             and data_feed.ptr > num_batch):
            break
        outputs, labels = model(batch, mode=GEN, gen_type=config.gen_type)

        # move from GPU to CPU; after squeeze/swapaxes rows are indexed by b_id
        pred_labels = [t.cpu().data.numpy()
                       for t in outputs[DecoderRNN.KEY_SEQUENCE]]
        pred_labels = np.array(pred_labels, dtype=int).squeeze(-1).swapaxes(0, 1)
        true_labels = labels.cpu().data.numpy()

        # get attention if possible
        if config.use_attn:
            pred_attns = [t.cpu().data.numpy()
                          for t in outputs[DecoderRNN.KEY_ATTN_SCORE]]
            pred_attns = np.array(pred_attns, dtype=float).squeeze(2).swapaxes(0, 1)
        else:
            pred_attns = None

        # The batch kind does not change per example, so decide it once.
        has_contexts = "contexts" in batch
        if has_contexts:
            ctx = batch.get('contexts')
            ctx_size = ctx.shape[1]
        elif "source" in batch:
            ctx = batch.get('source')
        else:
            raise ValueError("Not support source class. (contexts / source)")

        for b_id in range(pred_labels.shape[0]):
            pred_str, _ = engine.get_sent(model, de_tknize, pred_labels, b_id,
                                          attn=pred_attns)
            if has_contexts:
                ctx_str = []
                for i in range(ctx_size):
                    # [1:] drops the first token; empty turns are skipped
                    temp, _ = engine.get_sent(model, de_tknize, ctx[:, i, 1:], b_id)
                    if temp:
                        ctx_str.append(temp)
                ctx_str = '<t>'.join(ctx_str)
            else:
                ctx_str, _ = engine.get_sent(model, de_tknize, ctx[:, 1:], b_id)
            true_str, _ = engine.get_sent(model, de_tknize, true_labels, b_id)

            evaluator.add_example(true_str, pred_str)

            if dest_f is None:
                logger.info("Source: {}".format(ctx_str))
                logger.info("Target: {}".format(true_str))
                logger.info("Predict: {}\n".format(pred_str))
            else:
                dest_f.write("Source: {}\n".format(ctx_str))
                dest_f.write("Target: {}\n".format(true_str))
                dest_f.write("Predict: {}\n\n".format(pred_str))

    # Use the module logger like the rest of this file (was root `logging`).
    report = evaluator.get_report(include_error=dest_f is not None)
    if dest_f is None:
        logger.info(report)
    else:
        dest_f.write(report)
    logger.info("Generation Done")