def _run_epoch(self, sess, network, inputs, targets, train_op, is_train):
    """Run ``train_op`` over every minibatch of every (input, target) pair.

    Each ``(inputs[i], targets[i])`` pair is split into minibatches by
    ``iterate_batch_seq_minibatches`` using ``self.batch_size`` and
    ``self.seq_length``.  The forward and backward LSTM states are reset to
    the network's initial states at the start of each pair and then carried
    from one minibatch to the next within that pair.

    Args:
        sess: TensorFlow session used for every ``run`` call.
        network: model exposing ``input_var``, ``target_var``, ``loss_op``,
            ``pred_op``, ``fw_initial_state``/``bw_initial_state`` (per-layer
            ``(c, h)`` placeholders) and ``fw_final_state``/``bw_final_state``.
        inputs: sequence of input arrays, one independent sequence per item.
        targets: sequence of target arrays, parallel to ``inputs``.
        train_op: op executed together with every minibatch (presumably a
            no-op when only evaluating).
        is_train: not read by this method.

    Returns:
        Tuple ``(y_true, y_pred, mean_loss, duration)``. ``y_true`` and
        ``y_pred`` are the per-minibatch targets/predictions stacked with
        ``np.hstack``; ``mean_loss`` is the average minibatch loss; and
        ``duration`` is the wall-clock time in seconds.

    Raises:
        ValueError: if no minibatch was produced from ``inputs``.
        AssertionError: if a minibatch loss is NaN.
    """
    start_time = time.time()
    y_pred_chunks = []
    y_true_chunks = []
    total_loss, n_batches = 0.0, 0
    # Builtin zip instead of the Python-2-only itertools.izip.
    for each_x, each_y in zip(inputs, targets):
        # Reset the bidirectional LSTM state for every new input sequence.
        fw_state = sess.run(network.fw_initial_state)
        bw_state = sess.run(network.bw_initial_state)
        for x_batch, y_batch in iterate_batch_seq_minibatches(
                inputs=each_x,
                targets=each_y,
                batch_size=self.batch_size,
                seq_length=self.seq_length):
            feed_dict = {
                network.input_var: x_batch,
                network.target_var: y_batch,
            }
            # Feed the state returned by the previous minibatch so the LSTM
            # continues where it left off.
            for i, (c, h) in enumerate(network.fw_initial_state):
                feed_dict[c] = fw_state[i].c
                feed_dict[h] = fw_state[i].h
            for i, (c, h) in enumerate(network.bw_initial_state):
                feed_dict[c] = bw_state[i].c
                feed_dict[h] = bw_state[i].h

            _, loss_value, y_pred, fw_state, bw_state = sess.run(
                [
                    train_op,
                    network.loss_op,
                    network.pred_op,
                    network.fw_final_state,
                    network.bw_final_state,
                ],
                feed_dict=feed_dict)

            total_loss += loss_value
            n_batches += 1
            y_pred_chunks.append(y_pred)
            y_true_chunks.append(y_batch)

            assert not np.isnan(loss_value), \
                "Model diverged with loss = NaN"

    if n_batches == 0:
        raise ValueError(
            "No minibatches were produced; inputs are empty or shorter "
            "than batch_size * seq_length")

    duration = time.time() - start_time
    total_loss /= n_batches
    total_y_pred = np.hstack(y_pred_chunks)
    total_y_true = np.hstack(y_true_chunks)
    return total_y_true, total_y_pred, total_loss, duration
def _fill_lstm_state_feed(feed_dict, state_placeholders, state_values):
    """Map each layer's ``(c, h)`` placeholder to the matching value in ``state_values``."""
    for layer_idx, (c, h) in enumerate(state_placeholders):
        feed_dict[c] = state_values[layer_idx].c
        feed_dict[h] = state_values[layer_idx].h


def _stack_memory_cells(states, seq_length, n_layers):
    """Collect the ``c`` memory cells of one minibatch into a float64 array.

    ``states[s_idx][r_idx].c`` is read for every step ``s_idx`` and layer
    ``r_idx``. Each value is squeezed, so the result has shape
    ``(seq_length, n_layers) + squeezed_c.shape``. With a batch size of 1,
    that is presumably ``(seq_length, n_layers, cell_size)``; TODO confirm
    against the network definition.
    """
    return np.asarray(
        [[np.squeeze(states[s_idx][r_idx].c) for r_idx in range(n_layers)]
         for s_idx in range(seq_length)],
        dtype=np.float64)


def custom_run_epoch(sess, network, inputs, targets, train_op, is_train,
                     output_dir, subject_idx):
    """Run the network over all inputs, save predictions and LSTM memory cells.

    Each ``(inputs[i], targets[i])`` pair is split into minibatches by
    ``iterate_batch_seq_minibatches`` using ``network.batch_size`` and
    ``network.seq_length``. The forward and backward LSTM states are reset at
    the start of each pair and carried across its minibatches.

    The results are written to ``<output_dir>/output_subject<subject_idx>.npz``
    with these keys:

    * ``fw_memory_cells`` / ``bw_memory_cells``: the per-step ``c`` cells of
      every minibatch of every input, concatenated along axis 0 in processing
      order.
    * ``y_true`` / ``y_pred``: the targets and predictions of every
      minibatch, stacked with ``np.hstack`` (the same arrays that are
      returned).

    When ``network.batch_size`` is 1, the memory-cell rows presumably line up
    one-to-one with ``y_true``; verify this against
    ``iterate_batch_seq_minibatches`` for other batch sizes.

    Args:
        sess: TensorFlow session used for every ``run`` call.
        network: model exposing the input/target placeholders, ``loss_op``,
            ``pred_op``, initial/final state tensors, ``fw_states``/
            ``bw_states``, ``batch_size``, ``seq_length`` and ``n_rnn_layers``.
        inputs: sequence of input arrays, one independent sequence per item.
        targets: sequence of target arrays, parallel to ``inputs``.
        train_op: op executed together with every minibatch (presumably a
            no-op when only predicting).
        is_train: not read by this function.
        output_dir: directory where the ``.npz`` file is written.
        subject_idx: value formatted into the output file name.

    Returns:
        Tuple ``(y_true, y_pred, mean_loss, duration)``, where ``duration``
        is the wall-clock time in seconds, including the save.

    Raises:
        ValueError: if no minibatch was produced (nothing is saved then).
        AssertionError: if a minibatch loss is NaN.
    """
    start_time = time.time()
    y_true_chunks = []
    y_pred_chunks = []
    fw_cell_chunks = []
    bw_cell_chunks = []
    total_loss, n_batches = 0.0, 0
    for each_x, each_y in zip(inputs, targets):
        # Reset the bidirectional LSTM state for every new input sequence.
        fw_state = sess.run(network.fw_initial_state)
        bw_state = sess.run(network.bw_initial_state)
        for x_batch, y_batch in iterate_batch_seq_minibatches(
                inputs=each_x,
                targets=each_y,
                batch_size=network.batch_size,
                seq_length=network.seq_length):
            feed_dict = {
                network.input_var: x_batch,
                network.target_var: y_batch,
            }
            # Continue from the state returned by the previous minibatch.
            _fill_lstm_state_feed(feed_dict, network.fw_initial_state, fw_state)
            _fill_lstm_state_feed(feed_dict, network.bw_initial_state, bw_state)

            # A single run fetches predictions and per-step states together.
            # The stored cells therefore come from the same forward pass (same
            # weights, same random ops) as y_pred, rather than from two extra
            # passes executed after train_op.
            (_, loss_value, y_pred, fw_state, bw_state,
             fw_states, bw_states) = sess.run(
                [
                    train_op,
                    network.loss_op,
                    network.pred_op,
                    network.fw_final_state,
                    network.bw_final_state,
                    network.fw_states,
                    network.bw_states,
                ],
                feed_dict=feed_dict)

            fw_cell_chunks.append(_stack_memory_cells(
                fw_states, network.seq_length, network.n_rnn_layers))
            bw_cell_chunks.append(_stack_memory_cells(
                bw_states, network.seq_length, network.n_rnn_layers))
            y_true_chunks.append(y_batch)
            y_pred_chunks.append(y_pred)
            total_loss += loss_value
            n_batches += 1

            assert not np.isnan(loss_value), \
                "Model diverged with loss = NaN"

    if n_batches == 0:
        raise ValueError(
            "No minibatches were produced for subject {}; inputs are empty "
            "or shorter than batch_size * seq_length".format(subject_idx))

    total_y_true = np.hstack(y_true_chunks)
    total_y_pred = np.hstack(y_pred_chunks)

    # Save flat, row-aligned arrays covering ALL inputs. Ragged per-input
    # lists cannot be converted by np.savez on NumPy >= 1.24.
    save_dict = {
        "fw_memory_cells": np.concatenate(fw_cell_chunks, axis=0),
        "bw_memory_cells": np.concatenate(bw_cell_chunks, axis=0),
        "y_true": total_y_true,
        "y_pred": total_y_pred,
    }
    save_path = os.path.join(output_dir,
                             "output_subject{}.npz".format(subject_idx))
    np.savez(save_path, **save_dict)
    print("Saved outputs to {}".format(save_path))

    duration = time.time() - start_time
    total_loss /= n_batches
    return total_y_true, total_y_pred, total_loss, duration