def eval_model(ds, states, f_log_prob, batch_size=None):
    """Return the per-frame average cost of ``f_log_prob`` over one epoch of ``ds``.

    Batches with fewer than ``batch_size`` examples are skipped, so a trailing
    partial batch does not contribute to the result.

    Args:
        ds: Data stream exposing ``get_epoch_iterator()``. Each batch unpacks into
            six arrays: input, input mask, ivector, ivector mask, target and
            target mask.
        states: Passed to ``reset_state`` before every batch.
        f_log_prob: Callable taking ``(input_data, target_data, target_mask)``
            with the first two axes swapped relative to the stream, and returning
            ``(cost, cost_len)``. Both are summed across batches.
        batch_size: Minimum batch size to evaluate. Defaults to ``args.batch_size``
            (a global defined outside this function).

    Returns:
        Total of ``cost`` divided by total of ``cost_len`` over evaluated batches.

    Raises:
        ZeroDivisionError: If no frames were evaluated. This happens when the
            stream is empty, every batch is smaller than ``batch_size``, or every
            ``cost_len`` sums to zero.
    """
    if batch_size is None:
        batch_size = args.batch_size

    total_ce_sum = 0.0
    total_frame_count = 0
    for batch in ds.get_epoch_iterator():
        reset_state(states)
        input_data, _, _, _, target_data, target_mask = batch
        # The 3-way unpack also enforces a (batch, seq, feat)-shaped input.
        n_batch, _, _ = input_data.shape
        if n_batch < batch_size:
            continue
        # Swap batch and sequence axes before calling f_log_prob.
        input_data = numpy.transpose(input_data, (1, 0, 2))
        target_data = numpy.transpose(target_data, (1, 0))
        target_mask = numpy.transpose(target_mask, (1, 0))
        cost, cost_len = f_log_prob(input_data, target_data, target_mask)
        total_ce_sum += cost.sum()
        total_frame_count += cost_len.sum()

    # Without this guard an empty evaluation raises an opaque "float division by
    # zero", and an all-masked one silently returns NaN.
    if total_frame_count == 0:
        raise ZeroDivisionError(
            "eval_model: no frames evaluated (empty stream, every batch smaller "
            "than batch_size={}, or all frames masked)".format(batch_size))
    return total_ce_sum / total_frame_count
def avg_z_1_3d(ds, states, f_debug, batch_size=None):
    """Collect the ``z_1_3d`` output of ``f_debug`` per batch and reduce it with ``batch_mean``.

    Batches with fewer than ``batch_size`` examples are skipped.

    Args:
        ds: Data stream exposing ``get_epoch_iterator()``. Each batch unpacks into
            six arrays: input, input mask, ivector, ivector mask, target and
            target mask.
        states: Passed to ``reset_state`` before every batch.
        f_debug: Callable taking ``(input_data, target_data, target_mask)`` with
            the first two axes swapped relative to the stream. Must return exactly
            five values; the fourth is ``z_1_3d``.
        batch_size: Minimum batch size to evaluate. Defaults to ``args.batch_size``
            (a global defined outside this function).

    Returns:
        Whatever ``batch_mean`` returns for the list of collected ``z_1_3d`` values.
    """
    if batch_size is None:
        batch_size = args.batch_size

    z_1_3d_list = []
    for batch in ds.get_epoch_iterator():
        reset_state(states)
        input_data, _, _, _, target_data, target_mask = batch
        n_batch, _, _ = input_data.shape
        if n_batch < batch_size:
            continue
        # Swap batch and sequence axes before calling f_debug.
        input_data = numpy.transpose(input_data, (1, 0, 2))
        target_data = numpy.transpose(target_data, (1, 0))
        target_mask = numpy.transpose(target_mask, (1, 0))
        # Only the fourth of the five outputs is needed here.
        _, _, _, z_1_3d, _ = f_debug(input_data, target_data, target_mask)
        z_1_3d_list.append(z_1_3d)

    # NOTE(review): z_1_3d_list is empty when no batch reaches batch_size; how
    # batch_mean handles [] is not visible here -- confirm.
    return batch_mean(z_1_3d_list)
# Script section: restore parameters, build the ivector data stream, then, for
# each full batch, run f_debug and pass each batch array together with the
# transposed z_1_3d output to compress_batch.
# NOTE(review): args, tparams, states, f_debug and the helper functions used here
# are defined outside this chunk.
tparams = init_tparams_with_restored_value(tparams, args.model)
print('Loading data streams from {}'.format(args.data_path))
sync_data(args)
ds = create_ivector_datastream(args.data_path, args.dataset, args.batch_size)
# NOTE(review): epoch_sw and z_1_3d_list are not used in the visible lines;
# presumably they are used further down -- confirm.
epoch_sw = StopWatch()
status_sw = StopWatch()
status_sw.reset()
z_1_3d_list = []
for b_idx, batch in enumerate(ds.get_epoch_iterator(), start=1):
    reset_state(states)
    input_data, input_mask, ivector_data, ivector_mask, target_data, target_mask = batch
    n_batch, n_seq, n_feat = input_data.shape
    # Skip batches smaller than the configured batch size.
    if n_batch < args.batch_size:
        continue
    # f_debug receives the input with batch and sequence axes swapped.
    # NOTE(review): avg_z_1_3d calls f_debug with (input, target, target_mask);
    # here it gets only the input -- confirm that this f_debug is a different
    # compiled function.
    input_data_trans = numpy.transpose(input_data, (1, 0, 2))
    _, _, _, z_1_3d, _ = f_debug(input_data_trans)
    # Swap z_1_3d's two axes back; presumably this aligns it with the original
    # (batch, seq) layout of the arrays below -- TODO confirm.
    z_1_3d_trans = numpy.transpose(z_1_3d, (1, 0))
    compressed_input_data = compress_batch(input_data, z_1_3d_trans)
    compressed_input_mask = compress_batch(input_mask, z_1_3d_trans)
    compressed_target_data = compress_batch(target_data, z_1_3d_trans)
    compressed_target_mask = compress_batch(target_mask, z_1_3d_trans)
    # NOTE(review): the compressed arrays are not consumed within this chunk;
    # the loop body presumably continues beyond it.