Example #1
    def _run_epoch(self, sess, network, inputs, targets, train_op, is_train):
        start_time = time.time()
        y = []
        y_true = []
        total_loss, n_batches = 0.0, 0
        for sub_idx, each_data in enumerate(zip(inputs, targets)):
            each_x, each_y = each_data

            # # Initialize state of LSTM - Unidirectional LSTM
            # state = sess.run(network.initial_state)

            # Initialize state of LSTM - Bidirectional LSTM
            fw_state = sess.run(network.fw_initial_state)
            bw_state = sess.run(network.bw_initial_state)

            for x_batch, y_batch in iterate_batch_seq_minibatches(
                    inputs=each_x,
                    targets=each_y,
                    batch_size=self.batch_size,
                    seq_length=self.seq_length):
                feed_dict = {
                    network.input_var: x_batch,
                    network.target_var: y_batch
                }

                # Unidirectional LSTM
                # for i, (c, h) in enumerate(network.initial_state):
                #     feed_dict[c] = state[i].c
                #     feed_dict[h] = state[i].h

                # _, loss_value, y_pred, state = sess.run(
                #     [train_op, network.loss_op, network.pred_op, network.final_state],
                #     feed_dict=feed_dict
                # )

                # Feed the final LSTM states from the previous minibatch as
                # the initial states for this one, so the state is carried
                # across the whole recording.
                for i, (c, h) in enumerate(network.fw_initial_state):
                    feed_dict[c] = fw_state[i].c
                    feed_dict[h] = fw_state[i].h

                for i, (c, h) in enumerate(network.bw_initial_state):
                    feed_dict[c] = bw_state[i].c
                    feed_dict[h] = bw_state[i].h

                _, loss_value, y_pred, fw_state, bw_state = sess.run(
                    [
                        train_op, network.loss_op, network.pred_op,
                        network.fw_final_state, network.bw_final_state
                    ],
                    feed_dict=feed_dict)

                total_loss += loss_value
                n_batches += 1
                y.append(y_pred)
                y_true.append(y_batch)

                # Check the loss value
                assert not np.isnan(loss_value), \
                    "Model diverged with loss = NaN"

        duration = time.time() - start_time
        total_loss /= n_batches
        total_y_pred = np.hstack(y)
        total_y_true = np.hstack(y_true)

        return total_y_true, total_y_pred, total_loss, duration
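
Both examples depend on module-level imports of time and numpy as np (plus os in Example #2), and on an iterate_batch_seq_minibatches helper that is not shown here. The sketch below is only an illustration of what such a generator might look like, assuming it splits one recording into batch_size contiguous streams and yields flattened (batch_size * seq_length, ...) chunks so that the LSTM state fed back through feed_dict stays aligned across successive minibatches; the project's actual helper may differ.

import numpy as np

def iterate_batch_seq_minibatches(inputs, targets, batch_size, seq_length):
    # Hypothetical sketch, not the project's actual helper: split one
    # recording into batch_size contiguous streams and yield flattened
    # (batch_size * seq_length, ...) minibatches so the LSTM state fed
    # back through feed_dict stays aligned across successive yields.
    assert len(inputs) == len(targets)
    stream_len = len(inputs) // batch_size        # epochs per stream
    n_seqs = stream_len // seq_length             # minibatches yielded

    x_streams = inputs[:batch_size * stream_len].reshape(
        (batch_size, stream_len) + inputs.shape[1:])
    y_streams = targets[:batch_size * stream_len].reshape(
        (batch_size, stream_len) + targets.shape[1:])

    for i in range(n_seqs):
        x = x_streams[:, i * seq_length:(i + 1) * seq_length]
        y = y_streams[:, i * seq_length:(i + 1) * seq_length]
        yield (x.reshape((-1,) + inputs.shape[1:]),
               y.reshape((-1,) + targets.shape[1:]))

Under this assumption, with batch_size set to 1 the truncation drops the last n_all_data % seq_length epochs of each recording, which matches the bookkeeping Example #2 uses when sizing its memory-cell arrays.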
Example #2
def custom_run_epoch(sess, network, inputs, targets, train_op, is_train,
                     output_dir, subject_idx):
    start_time = time.time()
    y = []
    y_true = []
    all_fw_memory_cells = []
    all_bw_memory_cells = []
    total_loss, n_batches = 0.0, 0
    for sub_f_idx, each_data in enumerate(zip(inputs, targets)):
        each_x, each_y = each_data

        # # Initialize state of LSTM - Unidirectional LSTM
        # state = sess.run(network.initial_state)

        # Initialize state of LSTM - Bidirectional LSTM
        fw_state = sess.run(network.fw_initial_state)
        bw_state = sess.run(network.bw_initial_state)

        # Prepare storage for memory cells
        n_all_data = len(each_x)
        extra = n_all_data % network.seq_length
        n_data = n_all_data - extra
        cell_size = 512  # must match the hidden size of the network's LSTM cells
        fw_memory_cells = np.zeros((n_data, network.n_rnn_layers, cell_size))
        bw_memory_cells = np.zeros((n_data, network.n_rnn_layers, cell_size))
        seq_idx = 0

        # Store prediction and actual stages of each patient
        each_y_true = []
        each_y_pred = []

        for x_batch, y_batch in iterate_batch_seq_minibatches(
                inputs=each_x,
                targets=each_y,
                batch_size=network.batch_size,
                seq_length=network.seq_length):
            feed_dict = {
                network.input_var: x_batch,
                network.target_var: y_batch
            }

            # Unidirectional LSTM
            # for i, (c, h) in enumerate(network.initial_state):
            #     feed_dict[c] = state[i].c
            #     feed_dict[h] = state[i].h

            # _, loss_value, y_pred, state = sess.run(
            #     [train_op, network.loss_op, network.pred_op, network.final_state],
            #     feed_dict=feed_dict
            # )

            # Feed the previous minibatch's final states as the initial
            # states for this minibatch, so the state is carried across
            # the whole recording.
            for i, (c, h) in enumerate(network.fw_initial_state):
                feed_dict[c] = fw_state[i].c
                feed_dict[h] = fw_state[i].h

            for i, (c, h) in enumerate(network.bw_initial_state):
                feed_dict[c] = bw_state[i].c
                feed_dict[h] = bw_state[i].h

            _, loss_value, y_pred, fw_state, bw_state = sess.run(
                [
                    train_op, network.loss_op, network.pred_op,
                    network.fw_final_state, network.bw_final_state
                ],
                feed_dict=feed_dict)

            # Extract memory cells from both directions. These extra
            # sess.run calls re-run the forward pass on the same minibatch,
            # and the np.squeeze below assumes batch_size == 1.
            fw_states = sess.run(network.fw_states, feed_dict=feed_dict)
            bw_states = sess.run(network.bw_states, feed_dict=feed_dict)
            offset_idx = seq_idx * network.seq_length
            for s_idx in range(network.seq_length):
                for r_idx in range(network.n_rnn_layers):
                    fw_memory_cells[offset_idx + s_idx][r_idx] = np.squeeze(
                        fw_states[s_idx][r_idx].c)
                    bw_memory_cells[offset_idx + s_idx][r_idx] = np.squeeze(
                        bw_states[s_idx][r_idx].c)
            seq_idx += 1
            each_y_true.extend(y_batch)
            each_y_pred.extend(y_pred)

            total_loss += loss_value
            n_batches += 1

            # Check the loss value
            assert not np.isnan(loss_value), \
                "Model diverged with loss = NaN"

        all_fw_memory_cells.append(fw_memory_cells)
        all_bw_memory_cells.append(bw_memory_cells)
        y.append(each_y_pred)
        y_true.append(each_y_true)

    # Save memory cells and predictions accumulated over all recordings
    # of this subject (not just the last recording's arrays)
    save_dict = {
        "fw_memory_cells": all_fw_memory_cells,
        "bw_memory_cells": all_bw_memory_cells,
        "y_true": y_true,
        "y_pred": y
    }
    save_path = os.path.join(output_dir,
                             "output_subject{}.npz".format(subject_idx))
    np.savez(save_path, **save_dict)
    print("Saved outputs to {}".format(save_path))

    duration = time.time() - start_time
    total_loss /= n_batches
    total_y_pred = np.hstack(y)
    total_y_true = np.hstack(y_true)

    return total_y_true, total_y_pred, total_loss, duration
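
The .npz file written above can be inspected afterwards. A small hedged example follows; the file name is illustrative, and allow_pickle is needed because the memory-cell entries are stored as a Python list of per-recording arrays.

import numpy as np

# Load the outputs saved by custom_run_epoch for one subject
# (path is illustrative; adjust to the subject_idx actually used).
out = np.load("output_subject0.npz", allow_pickle=True)
y_true = np.hstack(out["y_true"])
y_pred = np.hstack(out["y_pred"])
print("subject accuracy: {:.3f}".format(np.mean(y_true == y_pred)))

# Each memory-cell entry has shape (n_data, n_rnn_layers, cell_size)
fw_cells = out["fw_memory_cells"][0]
print("forward memory cells:", fw_cells.shape)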