Ejemplo n.º 1
0
    def plot_memory_all_model_params_sequence(self,
                                              data_dict,
                                              predictions,
                                              sample_number=0):
        """
        Creates list of figures used in interactive visualization, with a
        slider enabling to move forth and back along the time axis (iteration
        in a given episode). The visualization presents input, output and
        target sequences passed as input parameters. Additionally, it utilizes
        state tuples collected during the experiment for displaying the memory
        state, read and write attentions; and gating params.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :returns: False when visualization is disabled; otherwise the
            ``is_closed`` flag of the plot window after it was updated.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        def to_numpy(tensor):
            # Detach from the autograd graph and move to host memory first:
            # ``.numpy()`` raises for tensors that live on a GPU.
            return tensor.detach().cpu().numpy()

        # Create figure template.
        fig = self.generate_memory_all_model_params_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_write_gate, ax_write_shift, ax_write_attention,
         ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention,
         ax_read_similarity, ax_inputs, ax_targets, ax_predictions) = fig.axes

        # Unpack data of the selected sample.
        inputs_seq = to_numpy(data_dict["sequences"][sample_number])
        targets_seq = to_numpy(data_dict["targets"][sample_number])
        predictions_seq = to_numpy(predictions[sample_number])

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros (transposed: one column per iteration).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # The initial state is only used to size the display buffers.
        (_, interface_state, memory_state, _) = self.cell_state_initial
        (_, write_state_tuple) = interface_state
        (_, _, write_gate, write_shift) = write_state_tuple

        num_addresses = memory_state.shape[1]
        num_gates = write_gate.shape[1]
        num_shifts = write_shift.shape[1]
        num_steps = targets_seq.shape[0]

        # Initialize "empty" matrices: one column per iteration.
        read0_attention_displayed = np.zeros((num_addresses, num_steps))
        read0_similarity_displayed = np.zeros((num_addresses, num_steps))
        # Read head gate/shift buffers are sized from the write head shapes,
        # which are assumed to be identical.
        read0_gate_displayed = np.zeros((num_gates, num_steps))
        read0_shift_displayed = np.zeros((num_shifts, num_steps))

        write_attention_displayed = np.zeros((num_addresses, num_steps))
        write_similarity_displayed = np.zeros((num_addresses, num_steps))
        write_gate_displayed = np.zeros((num_gates, num_steps))
        write_shift_displayed = np.zeros((num_shifts, num_steps))

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                cell_state) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # Unpack cell state.
            (_, interface_state, memory_state, _) = cell_state
            (read_state_tuples, write_state_tuple) = interface_state
            (write_attention, write_similarity, write_gate,
             write_shift) = write_state_tuple
            (read_attentions, read_similarities, read_gates,
             read_shifts) = zip(*read_state_tuples)

            # NOTE(review): memory and head parameters are taken from batch
            # element 0 regardless of sample_number - confirm this is intended.
            memory_displayed = to_numpy(memory_state[0])
            # Params of read head 0 (first element of each per-head tuple).
            read0_attention_displayed[:, i] = to_numpy(read_attentions[0][0][:, 0])
            read0_similarity_displayed[:, i] = to_numpy(read_similarities[0][0][:, 0])
            read0_gate_displayed[:, i] = to_numpy(read_gates[0][0][:, 0])
            read0_shift_displayed[:, i] = to_numpy(read_shifts[0][0][:, 0])

            # Params of the write head.
            write_attention_displayed[:, i] = to_numpy(write_attention[0][:, 0])
            write_similarity_displayed[:, i] = to_numpy(write_similarity[0][:, 0])
            write_gate_displayed[:, i] = to_numpy(write_gate[0][:, 0])
            write_shift_displayed[:, i] = to_numpy(write_shift[0][:, 0])

            # (axis, data) pairs drawn for this frame: memory, read head 0,
            # write head, then the "default" input/target/prediction data.
            panels = (
                (ax_memory, memory_displayed),
                (ax_read_attention, read0_attention_displayed),
                (ax_read_similarity, read0_similarity_displayed),
                (ax_read_gate, read0_gate_displayed),
                (ax_read_shift, read0_shift_displayed),
                (ax_write_attention, write_attention_displayed),
                (ax_write_similarity, write_similarity_displayed),
                (ax_write_gate, write_gate_displayed),
                (ax_write_shift, write_shift_displayed),
                (ax_inputs, inputs_displayed),
                (ax_targets, targets_displayed),
                (ax_predictions, predictions_displayed),
            )

            # Create "Artists" drawing data on "ImageAxes" and add the "frame".
            frames.append([
                ax.imshow(data, interpolation='nearest', aspect='auto')
                for ax, data in panels
            ])

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Ejemplo n.º 2
0
class DWM(SequentialModel):
    """
    Differentiable Working Memory (DWM), is a memory augmented neural network
    which emulates the human working memory.

    The DWM shows the same functional characteristics of working memory
    and robustly learns psychology-inspired tasks and converges faster
    than comparable state-of-the-art models.

    """
    def __init__(self, params, problem_default_values_={}):
        """
        Constructor. Initializes parameters on the basis of dictionary passed
        as argument.

        :param params: Local view to the Parameter Registry ''model'' section.

        :param problem_default_values_: Dictionary containing key-values received from problem.
            Must provide ``input_item_size`` and ``output_item_size``.

        """
        # Call base constructor. Sets up default values etc.
        super(DWM, self).__init__(params, problem_default_values_)
        # Model name.
        self.name = "Differentiable Working Memory (DWM)"

        # Parse default values received from problem and add them to registry.
        self.params.add_default_params({
            'input_item_size':
            problem_default_values_['input_item_size'],
            'output_item_size':
            problem_default_values_['output_item_size']
        })

        self.in_dim = params["input_item_size"]
        self.output_units = params['output_item_size']

        self.state_units = params["hidden_state_size"]

        self.num_heads = params["num_heads"]
        self.is_cam = params["use_content_addressing"]
        self.num_shift = params["shift_size"]
        self.M = params["memory_content_size"]
        # -1 means "use the input sequence length" (resolved in forward()).
        self.memory_addresses_size = params["memory_addresses_size"]

        # Per-step (memory, head_weight, snapshot_weight) numpy arrays,
        # filled by forward() when visualization is enabled.
        self.cell_state_history = None

        # Create the DWM components
        self.DWMCell = DWMCell(self.in_dim, self.output_units,
                               self.state_units, self.num_heads, self.is_cam,
                               self.num_shift, self.M)

    def forward(self, data_dict):
        """
        Forward function requires that the data_dict will contain at least "sequences".
        The sequence is processed step by step by the DWM cell.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]

        :returns: output: logits which represent the prediction of DWM [batch, sequence_length, output_size],
            or ``None`` if the cell produced no output at any step.

        Example:

        >>> dwm = DWM(params, problem_default_values)
        >>> inputs = torch.randn(5, 3, 10)
        >>> output = dwm({'sequences': inputs})

        """
        # Unpack dict.
        inputs = data_dict['sequences']

        # Get batch size and seq length.
        batch_size = inputs.size(0)
        seq_length = inputs.size(-2)

        if self.app_state.visualize:
            self.cell_state_history = []

        # NOTE(review): 4-D inputs are reduced to 3-D by keeping index 0 of
        # dim 1 (presumably an extra channel axis) - confirm with the problem.
        if len(inputs.size()) == 4:
            inputs = inputs[:, 0, :, :]

        # The length of the memory is set to be equal to the input length in
        # case ``self.memory_addresses_size == -1``.
        if self.memory_addresses_size == -1:
            if seq_length < self.num_shift:
                # memory size can't be smaller than num_shift (see
                # circular_convolution implementation)
                memory_addresses_size = self.num_shift
            else:
                memory_addresses_size = seq_length  # a hack for now
        else:
            memory_addresses_size = self.memory_addresses_size

        # Init state
        cell_state = self.DWMCell.init_state(memory_addresses_size, batch_size)

        # Collect per-step outputs and concatenate once at the end instead of
        # re-copying a growing tensor at every step.
        outputs = []

        # loop over the different sequences
        for j in range(seq_length):
            output_cell, cell_state = self.DWMCell(inputs[..., j, :],
                                                   cell_state)

            if output_cell is None:
                continue

            # Insert a new (time) axis just before the last one.
            outputs.append(output_cell[..., None, :])

            # This is for the time plot. Move to CPU first: ``.numpy()``
            # raises for tensors that live on a GPU.
            if self.app_state.visualize:
                self.cell_state_history.append(
                    (cell_state.memory_state.detach().cpu().numpy(),
                     cell_state.interface_state.head_weight.detach().cpu().numpy(),
                     cell_state.interface_state.snapshot_weight.detach().cpu().numpy()))

        if not outputs:
            return None
        # Concatenate along the time axis.
        return torch.cat(outputs, dim=-2)

    # Method to change memory size
    def set_memory_size(self, mem_size):
        """
        Sets the number of memory addresses used by subsequent forward() calls.

        :param mem_size: Number of addresses; -1 means "use input sequence length".
        """
        self.memory_addresses_size = mem_size

    @staticmethod
    def generate_figure_layout():
        """
        Creates the figure template for the DWM visualization: memory, head
        attention and bookmark attention panels on the left, and inputs,
        targets and predictions stacked on the right.

        :returns: Matplotlib figure whose ``axes`` are, in order: memory,
            head attention, bookmark attention, inputs, targets, predictions.

        """
        # Import the submodules explicitly: ``import matplotlib`` alone does
        # not guarantee that pyplot, gridspec and ticker are loaded.
        import matplotlib
        import matplotlib.gridspec
        import matplotlib.pyplot
        import matplotlib.ticker

        # Prepare "generic figure template".
        # Create figure object.
        fig = matplotlib.pyplot.figure(figsize=(16, 9))
        fig.subplots_adjust(left=0.07, right=0.96, top=0.88, bottom=0.15)

        gs0 = matplotlib.gridspec.GridSpec(1, 2, width_ratios=[5.0, 3.0])

        # Create a specific grid for DWM.
        gs00 = matplotlib.gridspec.GridSpecFromSubplotSpec(
            1, 3, subplot_spec=gs0[0], width_ratios=[3.0, 4.0, 4.0])

        ax_memory = fig.add_subplot(gs00[:, 0])  # all rows, col 0
        ax_attention = fig.add_subplot(gs00[:, 1])  # all rows, col 1
        ax_bookmark = fig.add_subplot(gs00[:, 2])  # all rows, col 2

        gs01 = matplotlib.gridspec.GridSpecFromSubplotSpec(
            3,
            1,
            subplot_spec=gs0[1],
            hspace=0.5,
            height_ratios=[1.0, 0.8, 0.8])
        ax_inputs = fig.add_subplot(gs01[0, :])  # row 0
        ax_targets = fig.add_subplot(gs01[1, :])  # row 1
        ax_predictions = fig.add_subplot(gs01[2, :])  # row 2

        # Integer ticks on both axes of every panel.
        for ax in (ax_inputs, ax_targets, ax_predictions, ax_memory,
                   ax_bookmark, ax_attention):
            ax.xaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')
        ax_attention.set_title('Head Attention')
        ax_attention.set_xlabel('Iteration')
        ax_bookmark.set_title('Bookmark Attention')
        ax_bookmark.set_xlabel('Iteration')

        return fig

    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :returns: False when visualization is disabled; otherwise the
            ``is_closed`` flag of the plot window after it was updated.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs_seq = data_dict["sequences"][sample_number].cpu().detach().numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach().numpy()
        # Use the same sample as inputs/targets (was hard-coded to sample 0).
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Temporary: data with an additional leading channel axis - keep
        # only channel 0.
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_attention, ax_bookmark, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros (transposed: one column per iteration).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # History entries are (memory, head_weight, snapshot_weight) as
        # recorded by forward(); their last axis sizes the attention buffers.
        head_attention_displayed = np.zeros(
            (self.cell_state_history[0][1].shape[-1], targets_seq.shape[0]))
        bookmark_attention_displayed = np.zeros(
            (self.cell_state_history[0][2].shape[-1], targets_seq.shape[0]))

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Shared pcolormesh styling - identical for every frame.
        mesh_kwargs = {
            'edgecolor': 'black',
            'cmap': 'inferno',
            'linewidths': 1.4e-3
        }

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, wt, wt_d)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # NOTE(review): memory and attentions use index 0 of the recorded
            # states regardless of sample_number - confirm whether that first
            # axis is the batch axis and should follow sample_number.
            memory_displayed = np.clip(memory[0], -3.0, 3.0)
            head_attention_displayed[:, i] = wt[0, 0, :]
            bookmark_attention_displayed[:, i] = wt_d[0, 0, :]

            # Buffers are mutated in place every iteration, so each frame
            # draws a copy of the current state.
            artists = [
                ax_memory.pcolormesh(np.transpose(memory_displayed),
                                     vmin=-3.0,
                                     vmax=3.0,
                                     **mesh_kwargs),
                ax_attention.pcolormesh(np.copy(head_attention_displayed),
                                        vmin=0.0,
                                        vmax=1.0,
                                        **mesh_kwargs),
                ax_bookmark.pcolormesh(np.copy(bookmark_attention_displayed),
                                       vmin=0.0,
                                       vmax=1.0,
                                       **mesh_kwargs),
                ax_inputs.pcolormesh(np.copy(inputs_displayed),
                                     vmin=0.0,
                                     vmax=1.0,
                                     **mesh_kwargs),
                ax_targets.pcolormesh(np.copy(targets_displayed),
                                      vmin=0.0,
                                      vmax=1.0,
                                      **mesh_kwargs),
                ax_predictions.pcolormesh(np.copy(predictions_displayed),
                                          vmin=0.0,
                                          vmax=1.0,
                                          **mesh_kwargs),
            ]

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Ejemplo n.º 3
0
class NTM(SequentialModel):
    """
    Class representing the Neural Turing Machine module.
    """
    def __init__(self, params, problem_default_values_={}):
        """
        Builds the NTM: registers the problem-provided item sizes as default
        parameters, reads the memory geometry and creates the recurrent cell.
        Optionally switches ``self.plot`` to one of the richer visualizations.

        :param params: Local view to the Parameter Registry ''model'' section.

        :param problem_default_values_: Dictionary containing key-values received from problem
            (must contain ``input_item_size`` and ``output_item_size``).

        """
        # Base constructor sets up default values etc.
        super(NTM, self).__init__(params, problem_default_values_)
        self.name = 'NTM'

        # Item sizes provided by the problem become registry defaults.
        self.params.add_default_params({
            key: problem_default_values_[key]
            for key in ('input_item_size', 'output_item_size')
        })

        # Memory geometry. Used only once, to initialize memory in forward();
        # -1 addresses means "data-driven" (input sequence length).
        memory_params = params['memory']
        self.num_memory_addresses = memory_params['num_addresses']
        self.num_memory_content_bits = memory_params['num_content_bits']

        # Recurrent NTM cell.
        self.ntm_cell = NTMCell(params)

        # Pick the visualization: 1 -> memory/attention, 2 -> all model
        # params; anything else (or a missing key) keeps the default plot.
        try:
            mode = params['visualization_mode']
        except KeyError:
            mode = None
        if mode == 1:
            self.plot = self.plot_memory_attention_sequence
        elif mode == 2:
            self.plot = self.plot_memory_all_model_params_sequence

    def forward(self, data_dict):
        """
        Runs the NTM cell over the input sequence, one time step at a time.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]

        :returns: Predictions (logits) being a tensor of size  [BATCH_SIZE x LENGTH_SIZE x OUTPUT_SIZE].

        """
        dtype = self.app_state.dtype
        inputs_BxSxI = data_dict['sequences']
        batch_size = inputs_BxSxI.size(0)

        # "Data-driven memory size" is resolved into a local so that
        # self.num_memory_addresses stays -1 for the next batch.
        if self.num_memory_addresses != -1:
            num_memory_addresses = self.num_memory_addresses
        else:
            num_memory_addresses = inputs_BxSxI.size(1)

        # Zero memory [BATCH_SIZE x MEMORY_ADDRESSES x CONTENT_BITS] and the
        # matching initial cell state.
        init_memory_BxAxC = torch.zeros(batch_size, num_memory_addresses,
                                        self.num_memory_content_bits).type(dtype)
        cell_state = self.ntm_cell.init_state(init_memory_BxAxC)

        # Keep the initial state and per-step states for visualization.
        if self.app_state.visualize:
            self.cell_state_history = []
            self.cell_state_initial = cell_state

        # One [BATCH_SIZE x OUTPUT_SIZE] logit tensor per time step.
        step_logits = []
        for input_t_BxI in inputs_BxSxI.unbind(dim=1):
            output_BxO, cell_state = self.ntm_cell(input_t_BxI, cell_state)
            step_logits.append(output_BxO)

            if self.app_state.visualize:
                self.cell_state_history.append(cell_state)

        # Stack logits along the time axis (1).
        return torch.stack(step_logits, 1)

    def generate_memory_attention_figure_layout(self):
        """
        Creates a figure template for showing basic NTM attributes (write &
        read attentions), memory and sequence (inputs, predictions and
        targets).

        Side effect: sets the global Matplotlib font family to
        Times New Roman.

        :returns: Matplotlib figure object whose ``axes`` are, in order:
            memory, write attention, read attention, inputs, targets,
            predictions.

        """
        # Import the submodules explicitly instead of relying on them being
        # loaded as side effects of other imports.
        import matplotlib
        import matplotlib.gridspec
        import matplotlib.ticker
        from matplotlib.figure import Figure

        # Change fonts globally - for all figures/subplots at once.
        matplotlib.rc('font', **{'family': 'Times New Roman'})

        # Prepare "generic figure template".
        fig = Figure()

        # Create a specific grid for NTM: 3 rows x 7 columns.
        gs = matplotlib.gridspec.GridSpec(3, 7)

        # Memory and attentions span all rows.
        ax_memory = fig.add_subplot(gs[:, 0])  # all rows, col 0
        ax_write_attention = fig.add_subplot(gs[:, 1:3])  # all rows, cols 1-2
        ax_read_attention = fig.add_subplot(gs[:, 3:5])  # all rows, cols 3-4

        # Sequences are stacked in the two rightmost columns.
        ax_inputs = fig.add_subplot(gs[0, 5:])  # row 0, cols 5-6
        ax_targets = fig.add_subplot(gs[1, 5:])  # row 1, cols 5-6
        ax_predictions = fig.add_subplot(gs[2, 5:])  # row 2, cols 5-6

        # Integer ticks on both axes of every panel.
        for ax in fig.axes:
            ax.xaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')

        ax_write_attention.set_title('Write Attention')
        ax_write_attention.set_xlabel('Iteration')
        ax_read_attention.set_title('Read Attention')
        ax_read_attention.set_xlabel('Iteration')

        fig.set_tight_layout(True)
        return fig

    def plot_memory_attention_sequence(self,
                                       data_dict,
                                       predictions,
                                       sample_number=0):
        """
        Creates list of figures used in interactive visualization, with a
        slider enabling to move forth and back along the time axis (iteration
        in a given episode). The visualization presents input, output and
        target sequences passed as input parameters. Additionally, it utilizes
        state tuples collected during the experiment for displaying the memory
        state, read and write attentions.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # import time
        # start_time = time.time()
        # Create figure template.
        fig = self.generate_memory_attention_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_write_attention, ax_read_attention, ax_inputs,
         ax_targets, ax_predictions) = fig.axes

        # Unpack data tuple.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach().numpy(
        )
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Set intial values of displayed  inputs, targets and predictions -
        # simply zeros.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Set initial values of memory and attentions.
        # Unpack initial state.
        (ctrl_state, interface_state, memory_state,
         read_vectors) = self.cell_state_initial

        # Initialize "empty" matrices.

        read0_attention_displayed = np.zeros(
            (memory_state.shape[1], targets_seq.shape[0]))
        write_attention_displayed = np.zeros(
            (memory_state.shape[1], targets_seq.shape[0]))

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                cell_state) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # Unpack cell state.
            (ctrl_state, interface_state, memory_state,
             read_vectors) = cell_state
            (read_state_tuples, write_state_tuple) = interface_state
            (write_attention, write_similarity, write_gate,
             write_shift) = write_state_tuple
            (read_attentions, read_similarities, read_gates,
             read_shifts) = zip(*read_state_tuples)

            # Set variables.
            memory_displayed = memory_state[0].detach().numpy()
            # Get attention of head 0.
            read0_attention_displayed[:, i] = read_attentions[0][0][:,
                                                                    0].detach(
                                                                    ).numpy()
            write_attention_displayed[:, i] = write_attention[0][:, 0].detach(
            ).numpy()

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # "Show" data on "axes".
            artists[0] = ax_memory.imshow(memory_displayed,
                                          interpolation='nearest',
                                          aspect='auto')
            artists[1] = ax_read_attention.imshow(read0_attention_displayed,
                                                  interpolation='nearest',
                                                  aspect='auto')
            artists[2] = ax_write_attention.imshow(write_attention_displayed,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[3] = ax_inputs.imshow(inputs_displayed,
                                          interpolation='nearest',
                                          aspect='auto')
            artists[4] = ax_targets.imshow(targets_displayed,
                                           interpolation='nearest',
                                           aspect='auto')
            artists[5] = ax_predictions.imshow(predictions_displayed,
                                               interpolation='nearest',
                                               aspect='auto')

            # Add "frame".
            frames.append(artists)

        # print("--- %s seconds ---" % (time.time() - start_time))
        # Plot figure and list of frames.

        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed

    def generate_memory_all_model_params_figure_layout(self):
        """
        Creates a figure template for showing all NTM attributes (read & write
        attentions, similarities, gates and shifts), along with memory and
        sequences (inputs, targets and predictions).

        Layout of the 4 x 7 grid:

            - column 0, rows 1-3: memory,
            - columns 1-2: write head (gate / shift in row 0, attention /
              similarity in rows 1-3),
            - columns 3-4: read head (gate / shift in row 0, attention /
              similarity in rows 1-3),
            - columns 5-6: inputs (row 1), targets (row 2), predictions (row 3).

        The axes are created in a fixed order (memory, write gate, write shift,
        write attention, write similarity, read gate, read shift, read
        attention, read similarity, inputs, targets, predictions), which is the
        order of ``fig.axes`` - callers unpack it positionally.

        Side effect: sets the global matplotlib font family to Times New Roman.

        :returns: Matplot figure object.

        """
        import matplotlib
        import matplotlib.gridspec
        import matplotlib.ticker
        from matplotlib.figure import Figure

        # Change fonts globally - for all figures/subplots at once.
        matplotlib.rc('font', **{'family': 'Times New Roman'})

        # Create figure object.
        fig = Figure()

        # Create a specific 4 x 7 grid for the NTM.
        gs = matplotlib.gridspec.GridSpec(4, 7)

        # NOTE: creation order defines the order of fig.axes - do not reorder.
        # Memory: rows 1-3, column 0.
        ax_memory = fig.add_subplot(gs[1:, 0])
        # Write head: gate / shift in row 0, attention / similarity in rows 1-3.
        ax_write_gate = fig.add_subplot(gs[0, 1:2])
        ax_write_shift = fig.add_subplot(gs[0, 2:3])
        ax_write_attention = fig.add_subplot(gs[1:, 1:2])
        ax_write_similarity = fig.add_subplot(gs[1:, 2:3])
        # Read head: gate / shift in row 0, attention / similarity in rows 1-3.
        ax_read_gate = fig.add_subplot(gs[0, 3:4])
        ax_read_shift = fig.add_subplot(gs[0, 4:5])
        ax_read_attention = fig.add_subplot(gs[1:, 3:4])
        ax_read_similarity = fig.add_subplot(gs[1:, 4:5])

        # Sequences: one row each (rows 1, 2, 3), spanning columns 5-6.
        ax_inputs = fig.add_subplot(gs[1, 5:])
        ax_targets = fig.add_subplot(gs[2, 5:])
        ax_predictions = fig.add_subplot(gs[3, 5:])

        # Integer ticks on every axis...
        for ax in fig.axes:
            ax.xaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))
        # ... except the y axis of the gates - a single bit, so no ticks.
        ax_write_gate.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
        ax_read_gate.yaxis.set_major_locator(matplotlib.ticker.NullLocator())

        # Sequence labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        # Memory labels.
        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')

        # All head params are plotted against the iteration.
        for ax in [
                ax_write_gate, ax_write_shift, ax_write_attention,
                ax_write_similarity, ax_read_gate, ax_read_shift,
                ax_read_attention, ax_read_similarity
        ]:
            ax.set_xlabel('Iteration')

        # Write head.
        ax_write_gate.set_title('Write Gate')
        ax_write_shift.set_title('Write Shift')
        ax_write_attention.set_title('Write Attention')
        ax_write_similarity.set_title('Write Similarity')

        # Read head.
        ax_read_gate.set_title('Read Gate')
        ax_read_shift.set_title('Read Shift')
        ax_read_attention.set_title('Read Attention')
        ax_read_similarity.set_title('Read Similarity')

        fig.set_tight_layout(True)
        return fig

    def plot_memory_all_model_params_sequence(self,
                                              data_dict,
                                              predictions,
                                              sample_number=0):
        """
        Creates list of figures used in interactive visualization, with a
        slider enabling to move forth and back along the time axis (iteration
        in a given episode). The visualization presents input, output and
        target sequences passed as input parameters. Additionally, it utilizes
        state tuples collected during the experiment for displaying the memory
        state, read (head 0) and write attentions and similarities, and the
        gate / shift params of both heads.

        All displayed values - the sequences as well as memory and head
        params - are taken from the same sample ``sample_number`` of the batch.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: False if visualization is disabled, otherwise the
            ``is_closed`` flag of the plot window after updating it.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        def column(tensor):
            # Column 0 of the displayed sample, as a CPU numpy vector.
            return tensor[sample_number][:, 0].detach().cpu().numpy()

        # Create figure template.
        fig = self.generate_memory_all_model_params_figure_layout()
        # Get axes that artists will draw on (order fixed by the layout).
        (ax_memory, ax_write_gate, ax_write_shift, ax_write_attention,
         ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention,
         ax_read_similarity, ax_inputs, ax_targets, ax_predictions) = fig.axes

        # Detach the displayed sample and copy it to CPU.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros ([features x time]).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # The initial state is only used to size the displayed matrices.
        (_, interface_state, memory_state, _) = self.cell_state_initial
        (_, write_state_tuple) = interface_state
        (_, _, write_gate, write_shift) = write_state_tuple

        num_steps = targets_seq.shape[0]
        num_addresses = memory_state.shape[1]

        # Read head 0 - gate / shift rows are sized from the write head
        # (assumed to have the same shapes).
        read0_attention_displayed = np.zeros((num_addresses, num_steps))
        read0_similarity_displayed = np.zeros((num_addresses, num_steps))
        read0_gate_displayed = np.zeros((write_gate.shape[1], num_steps))
        read0_shift_displayed = np.zeros((write_shift.shape[1], num_steps))

        # Write head.
        write_attention_displayed = np.zeros((num_addresses, num_steps))
        write_similarity_displayed = np.zeros((num_addresses, num_steps))
        write_gate_displayed = np.zeros((write_gate.shape[1], num_steps))
        write_shift_displayed = np.zeros((write_shift.shape[1], num_steps))

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                cell_state) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # Unpack cell state.
            (_, interface_state, memory_state, _) = cell_state
            (read_state_tuples, write_state_tuple) = interface_state
            (write_attention, write_similarity, write_gate,
             write_shift) = write_state_tuple
            # Regroup per-head tuples into per-parameter tuples (index = head).
            (read_attentions, read_similarities, read_gates,
             read_shifts) = zip(*read_state_tuples)

            # Memory of the displayed sample.
            memory_displayed = memory_state[sample_number].detach().cpu(
            ).numpy()

            # Params of read head 0.
            read0_attention_displayed[:, i] = column(read_attentions[0])
            read0_similarity_displayed[:, i] = column(read_similarities[0])
            read0_gate_displayed[:, i] = column(read_gates[0])
            read0_shift_displayed[:, i] = column(read_shifts[0])

            # Params of the write head.
            write_attention_displayed[:, i] = column(write_attention)
            write_similarity_displayed[:, i] = column(write_similarity)
            write_gate_displayed[:, i] = column(write_gate)
            write_shift_displayed[:, i] = column(write_shift)

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # "Show" data on "axes".
            artists[0] = ax_memory.imshow(memory_displayed,
                                          interpolation='nearest',
                                          aspect='auto')

            # Read head.
            artists[1] = ax_read_attention.imshow(read0_attention_displayed,
                                                  interpolation='nearest',
                                                  aspect='auto')
            artists[2] = ax_read_similarity.imshow(read0_similarity_displayed,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[3] = ax_read_gate.imshow(read0_gate_displayed,
                                             interpolation='nearest',
                                             aspect='auto')
            artists[4] = ax_read_shift.imshow(read0_shift_displayed,
                                              interpolation='nearest',
                                              aspect='auto')

            # Write head.
            artists[5] = ax_write_attention.imshow(write_attention_displayed,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[6] = ax_write_similarity.imshow(write_similarity_displayed,
                                                    interpolation='nearest',
                                                    aspect='auto')
            artists[7] = ax_write_gate.imshow(write_gate_displayed,
                                              interpolation='nearest',
                                              aspect='auto')
            artists[8] = ax_write_shift.imshow(write_shift_displayed,
                                               interpolation='nearest',
                                               aspect='auto')

            # "Default data".
            artists[9] = ax_inputs.imshow(inputs_displayed,
                                          interpolation='nearest',
                                          aspect='auto')
            artists[10] = ax_targets.imshow(targets_displayed,
                                            interpolation='nearest',
                                            aspect='auto')
            artists[11] = ax_predictions.imshow(predictions_displayed,
                                                interpolation='nearest',
                                                aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Ejemplo n.º 4
0
    def plot(self, data_dict, predictions, sample=0):
        """
        Creates a default interactive visualization, with a slider enabling to
        move forth and back along the time axis (iteration over the sequence elements in a given episode).
        The default visualization contains the input, output and target sequences.

        For a more model/problem - dependent visualization, please overwrite this
        method in the derived model class.

        Side effect: updates global matplotlib ``rcParams`` (title, label and
        tick font sizes).

        :param data_dict: DataDict containing

           - input sequences: [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE],
           - target sequences:  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]


        :param predictions: Predicted sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]
        :type predictions: torch.tensor

        :param sample: Number of sample in batch (default: 0)
        :type sample: int

        :return: None.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        import matplotlib.pylab as pylab

        # Adjust font sizes globally - for all figures/subplots at once.
        pylab.rcParams.update({
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        })

        # Create a single "figure layout" for all displayed frames: three
        # stacked axes (inputs, targets, predictions) sharing the time axis.
        fig = Figure()
        axes = fig.subplots(3, 1, sharex=True, sharey=False)

        # Set ticks.
        axes[0].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[0].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[1].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[2].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        axes[0].set_title('Inputs')
        axes[0].set_ylabel('Control/Data bits')
        axes[1].set_title('Targets')
        axes[1].set_ylabel('Data bits')
        axes[2].set_title('Predictions')
        axes[2].set_ylabel('Data bits')
        axes[2].set_xlabel('Item number')
        fig.set_tight_layout(True)

        # Detach a sample from batch and copy it to CPU.
        inputs_seq = data_dict['sequences'][sample].cpu().detach().numpy()
        targets_seq = data_dict['targets'][sample].cpu().detach().numpy()
        predictions_seq = predictions[sample].cpu().detach().numpy()

        # Create empty [features x time] matrices, each sized from the
        # sequence it will display.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Log sequence length - so the user can understand what is going on.
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_word, prediction_word, target_word) in enumerate(
                zip(inputs_seq, predictions_seq, targets_seq)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                self.logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Add words to adequate positions.
            inputs_displayed[:, i] = input_word
            targets_displayed[:, i] = target_word
            predictions_displayed[:, i] = prediction_word

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # Tell artists what to do.
            artists[0] = axes[0].imshow(inputs_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[1] = axes[1].imshow(targets_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[2] = axes[2].imshow(predictions_displayed,
                                        interpolation='nearest',
                                        aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
Ejemplo n.º 5
0
    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        Displays the memory, head and bookmark attentions stored in
        ``self.cell_state_history`` together with the inputs, targets and
        predictions of sample ``sample_number``.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: False if visualization is disabled, otherwise the
            ``is_closed`` flag of the plot window after updating it.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # Detach the displayed sample (the same one for inputs, targets and
        # predictions) and copy it to CPU. Predictions are displayed as-is.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Temporary for data with an additional channel: keep channel 0.
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_attention, ax_bookmark, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros ([features x time]).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Attention matrices: one row per address (last dim of the stored
        # attention), one column per step.
        head_attention_displayed = np.zeros(
            (self.cell_state_history[0][1].shape[-1], targets_seq.shape[0]))
        bookmark_attention_displayed = np.zeros(
            (self.cell_state_history[0][2].shape[-1], targets_seq.shape[0]))

        # Common pcolormesh styling, shared by every panel.
        mesh_style = {
            'edgecolor': 'black',
            'cmap': 'inferno',
            'linewidths': 1.4e-3
        }

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(inputs_seq.shape[0]))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, wt, wt_d)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10)
                                               == 0):
                logger.info("Generating figure {}/{}".format(
                    i, inputs_seq.shape[0]))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # NOTE(review): the cell state is indexed with a fixed 0
            # (memory[0], wt[0, 0, :]) rather than sample_number - confirm
            # whether dim 0 of the stored state is the batch dimension.
            memory_displayed = np.clip(memory[0], -3.0, 3.0)
            head_attention_displayed[:, i] = wt[0, 0, :]
            bookmark_attention_displayed[:, i] = wt_d[0, 0, :]

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # Copies are passed so later in-place updates of the displayed
            # matrices do not alter already-created frames.
            artists[0] = ax_memory.pcolormesh(np.transpose(memory_displayed),
                                              vmin=-3.0,
                                              vmax=3.0,
                                              **mesh_style)
            artists[1] = ax_attention.pcolormesh(
                np.copy(head_attention_displayed),
                vmin=0.0,
                vmax=1.0,
                **mesh_style)
            artists[2] = ax_bookmark.pcolormesh(
                np.copy(bookmark_attention_displayed),
                vmin=0.0,
                vmax=1.0,
                **mesh_style)
            artists[3] = ax_inputs.pcolormesh(np.copy(inputs_displayed),
                                              vmin=0.0,
                                              vmax=1.0,
                                              **mesh_style)
            artists[4] = ax_targets.pcolormesh(np.copy(targets_displayed),
                                               vmin=0.0,
                                               vmax=1.0,
                                               **mesh_style)
            artists[5] = ax_predictions.pcolormesh(
                np.copy(predictions_displayed), vmin=0.0, vmax=1.0,
                **mesh_style)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Ejemplo n.º 6
0
    def plot(self, data_dict, logits, sample=0):
        """
        Visualize the attention weights (``ControlUnit`` & ``ReadUnit``) on the \
        question & feature maps. Dynamic visualization throughout the reasoning \
        steps is possible.

        :param data_dict: DataDict({'images','questions', 'questions_length', 'questions_string', 'questions_type', \
        'targets', 'targets_string', 'index','imgfiles', 'prediction_string'})
        :type data_dict: utils.DataDict

        :param logits: Prediction of the model.
        :type logits: torch.tensor

        :param sample: Index of sample in batch (Default: 0)
        :type sample: int

        :return: True when the user closes the window, False if we do not need to visualize.

        """

        # check whether the visualization is required
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # unpack data_dict
        s_questions = data_dict['questions_string']
        question_type = data_dict['questions_type']
        answer_string = data_dict['targets_string']
        imgfiles = data_dict['imgfiles']
        prediction_string = data_dict['predictions_string']
        clevr_dir = data_dict['clevr_dir']

        # nltk.word_tokenize needs the 'punkt' models; download them only when
        # they are not installed yet (avoids a network call on every plot).
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')
        # tokenize question string using same processing as in the problem
        # class
        words = nltk.word_tokenize(s_questions[sample])

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_image, ax_attention_image, ax_attention_question,
         ax_step) = fig.axes

        # Load the image; the second '_'-separated token of the file name
        # selects the sub-directory of <clevr_dir>/images.
        img_name = imgfiles[sample]
        split_name = img_name.split('_')[1]
        image_path = os.path.join(clevr_dir, 'images', split_name, img_name)
        with Image.open(image_path) as img_file:
            image = img_file.convert('RGB')
        image = self.transform(image)
        image = image.permute(1, 2, 0)  # [H, W, 3], e.g. [300, 300, 3]

        # get most probable answer -> prediction of the network
        proba_answers = torch.nn.functional.softmax(logits, -1)
        proba_answer = proba_answers[sample].detach().cpu()
        proba_answer = proba_answer.max().numpy()

        # Image size; Upsample expects the target size as (height, width).
        height = image.size(0)
        width = image.size(1)
        upsample = torch.nn.Upsample(size=[height, width],
                                     mode='bilinear',
                                     align_corners=True)

        # Static decorations - identical for every step, so set them once.
        ax_image.set_title('CLEVR image: {}'.format(img_name))
        ax_step.axis('off')
        ax_attention_image.set_title('Predicted Answer: ' +
                                     prediction_string[sample] +
                                     ' [ proba: ' +
                                     "{0:.3f}".format(proba_answer) +
                                     ']  ' + 'Ground Truth: ' +
                                     answer_string[sample])

        frames = []
        for step, (attention_mask, attention_question) in zip(
                range(self.max_step), self.mac_unit.cell_state_history):
            # attention mask has size [batch_size x 1 x (H*W)]; reshape it to
            # a square map.
            attention_size = int(np.sqrt(attention_mask.size(-1)))
            attention_mask = attention_mask.view(-1, 1, attention_size,
                                                 attention_size)

            # Upsample only the displayed sample (bilinear upsampling is
            # independent per sample) and bring it to CPU for plotting.
            attention_mask = upsample(attention_mask[sample:sample + 1])[0, 0]
            attention_mask = attention_mask.detach().cpu()

            # preprocess question, pick one sample number
            attention_question = attention_question[sample].detach().cpu()

            # Create "Artists" drawing data on "ImageAxes"; the attention
            # image axis holds two artists (image + mask overlay).
            num_artists = len(fig.axes) + 1
            artists = [None] * num_artists

            # Tick labels are re-applied at every step, after the previous
            # step's imshow.
            ax_attention_question.set_xticklabels(['h'] + words,
                                                  rotation='vertical',
                                                  fontsize=10)

            # Tell artists what to do:
            artists[0] = ax_image.imshow(image,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[1] = ax_attention_image.imshow(image,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[2] = ax_attention_image.imshow(attention_mask,
                                                   interpolation='nearest',
                                                   aspect='auto',
                                                   alpha=0.5,
                                                   cmap='Reds')
            artists[3] = ax_attention_question.imshow(
                attention_question.transpose(1, 0),
                interpolation='nearest',
                aspect='auto',
                cmap='Reds')
            artists[4] = ax_step.text(0,
                                      0.5,
                                      'Reasoning step index: ' + str(step) +
                                      ' | Question type: ' +
                                      question_type[sample],
                                      fontsize=15)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Ejemplo n.º 7
0
class MACNetwork(Model):
    """
    Implementation of the entire ``MAC`` network.

    ``forward()`` chains an ``InputUnit``, the recurrent ``MACUnit`` cells and an
    ``OutputUnit``; ``plot()`` visualizes the control/read attentions recorded by
    the ``MACUnit`` during the last forward pass.
    """
    def __init__(self, params, problem_default_values_=None):
        """
        Constructor for the ``MAC`` network.

        :param params: dict of parameters (read from configuration ``.yaml`` file). \
        The keys ``dim``, ``embed_hidden``, ``max_step``, ``self_attention``, \
        ``memory_gate`` and ``dropout`` are read.
        :type params: utils.ParamInterface

        :param problem_default_values_: default values coming from the ``Problem`` class \
        (``nb_classes`` is read from it). ``None`` is treated as an empty dict.
        :type problem_default_values_: dict
        """
        # Fresh dict per call instead of a shared mutable default argument.
        if problem_default_values_ is None:
            problem_default_values_ = {}

        # call base constructor
        super(MACNetwork, self).__init__(params, problem_default_values_)

        # parse params dict
        self.dim = params['dim']
        self.embed_hidden = params['embed_hidden']  # embedding dimension
        self.max_step = params['max_step']
        self.self_attention = params['self_attention']
        self.memory_gate = params['memory_gate']
        self.dropout = params['dropout']

        try:
            self.nb_classes = problem_default_values_['nb_classes']
        except KeyError:
            # NOTE(review): if the key is missing, self.nb_classes only exists if
            # the base class set it; otherwise building the OutputUnit below fails.
            self.logger.warning(
                "Couldn't retrieve one or more value(s) from problem_default_values_."
            )

        self.name = 'MAC'

        # instantiate units
        self.input_unit = InputUnit(dim=self.dim,
                                    embedded_dim=self.embed_hidden)

        self.mac_unit = MACUnit(dim=self.dim,
                                max_step=self.max_step,
                                self_attention=self.self_attention,
                                memory_gate=self.memory_gate,
                                dropout=self.dropout)

        self.output_unit = OutputUnit(dim=self.dim, nb_classes=self.nb_classes)

        self.data_definitions = {
            'images': {
                'size': [-1, 1024, 14, 14],
                'type': [np.ndarray]
            },
            'questions': {
                'size': [-1, -1, -1],
                'type': [torch.Tensor]
            },
            'questions_length': {
                'size': [-1],
                'type': [list, int]
            },
            'targets': {
                'size': [-1, self.nb_classes],
                'type': [torch.Tensor]
            }
        }

        # transform for the image plotting: resize to 224x224, convert to a tensor
        self.transform = transforms.Compose(
            [transforms.Resize([224, 224]),
             transforms.ToTensor()])

    def forward(self, data_dict, dropout=0.15):
        """
        Forward pass of the ``MAC`` network. Calls first the ``InputUnit``, then the recurrent \
        MAC cells and finally the ``OutputUnit``.

        :param data_dict: input data batch (``images``, ``questions`` and \
        ``questions_length`` are read).
        :type data_dict: utils.DataDict

        :param dropout: not used by this method (kept for interface compatibility); \
        the dropout rate actually applied is ``params['dropout']``, given to the ``MACUnit``.
        :type dropout: float

        :return: Predictions (logits) of the model, as returned by the ``OutputUnit``.
        """

        # reset cell state history for visualization
        if self.app_state.visualize:
            self.mac_unit.cell_state_history = []

        # unpack data_dict
        images = data_dict['images']
        questions = data_dict['questions']
        questions_length = data_dict['questions_length']

        # input unit
        img, kb_proj, lstm_out, h = self.input_unit(questions,
                                                    questions_length, images)

        # recurrent MAC cells
        memory = self.mac_unit(lstm_out, h, img, kb_proj)

        # output unit
        logits = self.output_unit(memory, h)

        return logits

    @staticmethod
    def generate_figure_layout():
        """
        Generate a figure layout for the attention visualization (done in \
        ``MACNetwork.plot()``).

        Axes are created in this order: original image, attention on image, \
        attention on question, step index.

        :return: figure layout.

        """
        import matplotlib
        import matplotlib.gridspec
        import matplotlib.ticker
        from matplotlib.figure import Figure

        # ``matplotlib.rcParams`` is the object also exposed as
        # ``matplotlib.pylab.rcParams``; using it directly does not depend on
        # ``matplotlib.pylab`` having been imported (which would pull in pyplot).
        matplotlib.rcParams.update({
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        })

        # Prepare "generic figure template".
        # Create figure object.
        fig = Figure()

        # Create a specific grid for MAC.
        gs = matplotlib.gridspec.GridSpec(6, 2)

        # subplots: original image, attention on image & question, step index
        ax_image = fig.add_subplot(gs[2:6, 0])
        ax_attention_image = fig.add_subplot(gs[2:6, 1])
        ax_attention_question = fig.add_subplot(gs[0, :])
        ax_step = fig.add_subplot(gs[1, 0])

        # Integer ticks on the image axes and on the step axis.
        for axis in (ax_image, ax_attention_image, ax_step):
            axis.xaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))
            axis.yaxis.set_major_locator(
                matplotlib.ticker.MaxNLocator(integer=True))

        # question ticks
        ax_attention_question.xaxis.set_major_locator(
            matplotlib.ticker.MaxNLocator(nbins=40))

        fig.set_tight_layout(True)

        return fig

    def plot(self, data_dict, logits, sample=0):
        """
        Visualize the attention weights (``ControlUnit`` & ``ReadUnit``) on the \
        question & feature maps. Dynamic visualization throughout the reasoning \
        steps is possible.

        :param data_dict: DataDict({'images','questions', 'questions_length', 'questions_string', 'questions_type', \
        'targets', 'targets_string', 'index', 'imgfiles', 'predictions_string', 'clevr_dir'})
        :type data_dict: utils.DataDict

        :param logits: Prediction of the model.
        :type logits: torch.tensor

        :param sample: Index of sample in batch (Default: 0)
        :type sample: int

        :return: ``False`` if we do not need to visualize, otherwise \
        ``self.plotWindow.is_closed`` (True when the user closes the window).

        """

        # check whether the visualization is required
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # unpack data_dict
        s_questions = data_dict['questions_string']
        question_type = data_dict['questions_type']
        answer_string = data_dict['targets_string']
        imgfiles = data_dict['imgfiles']
        prediction_string = data_dict['predictions_string']
        clevr_dir = data_dict['clevr_dir']

        # nltk.word_tokenize needs the 'punkt' models: download them only when
        # they are not available locally (avoids a network call on every plot).
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')
        # tokenize question string using same processing as in the problem class
        words = nltk.word_tokenize(s_questions[sample])

        # Create figure template and get the axes that artists will draw on.
        fig = self.generate_figure_layout()
        (ax_image, ax_attention_image, ax_attention_question,
         ax_step) = fig.axes

        # Load the image; its sub-directory is the second '_'-separated token
        # of the file name. The file handle is closed once converted.
        img_file = imgfiles[sample]
        img_subdir = img_file.split('_')[1]
        img_path = os.path.join(clevr_dir, 'images', img_subdir, img_file)
        with Image.open(img_path) as pil_image:
            rgb_image = pil_image.convert('RGB')
        image = self.transform(rgb_image)
        image = image.permute(1, 2, 0)  # [224, 224, 3] (C,H,W -> H,W,C)

        # probability of the most probable answer -> prediction of the network
        proba_answers = torch.nn.functional.softmax(logits, -1)
        proba_answer = proba_answers[sample].detach().cpu().max().numpy()

        # image size as (height, width) after the permute above
        height, width = image.size(0), image.size(1)
        # Bilinear upsampler bringing attention maps to the image resolution;
        # built once since it does not depend on the reasoning step.
        upsample = torch.nn.Upsample(size=[height, width],
                                     mode='bilinear',
                                     align_corners=True)

        # Decorations identical for every frame: set them once.
        ax_image.set_title('CLEVR image: {}'.format(img_file))
        ax_attention_question.set_xticklabels(['h'] + words,
                                              rotation='vertical',
                                              fontsize=10)
        ax_step.axis('off')
        ax_attention_image.set_title('Predicted Answer: ' +
                                     prediction_string[sample] +
                                     ' [ proba: ' +
                                     str.format("{0:.3f}", proba_answer) +
                                     ']  ' + 'Ground Truth: ' +
                                     answer_string[sample])

        # One artist per axis, plus one: ax_attention_image holds both the
        # image and the attention overlay.
        num_artists = len(fig.axes) + 1

        frames = []
        for step, (attention_mask, attention_question) in zip(
                range(self.max_step), self.mac_unit.cell_state_history):
            # attention mask has size [batch_size x 1 x (H*W)] with H == W;
            # reshape it to [batch_size x 1 x H x W].
            attention_size = int(np.sqrt(attention_mask.size(-1)))
            attention_mask = attention_mask.view(-1, 1, attention_size,
                                                 attention_size)
            # Upsample only the visualized sample (bilinear upsampling acts
            # independently on each batch element).
            attention_mask = upsample(attention_mask[sample:sample + 1])[0, 0]

            # preprocess question, pick one sample number
            attention_question = attention_question[sample]

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * num_artists
            artists[0] = ax_image.imshow(image,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[1] = ax_attention_image.imshow(image,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[2] = ax_attention_image.imshow(attention_mask,
                                                   interpolation='nearest',
                                                   aspect='auto',
                                                   alpha=0.5,
                                                   cmap='Reds')
            artists[3] = ax_attention_question.imshow(
                attention_question.transpose(1, 0),
                interpolation='nearest',
                aspect='auto',
                cmap='Reds')
            artists[4] = ax_step.text(0,
                                      0.5,
                                      'Reasoning step index: ' + str(step) +
                                      ' | Question type: ' +
                                      question_type[sample],
                                      fontsize=15)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Ejemplo n.º 8
0
class ThalNetModel(SequentialModel):
    """
    ``ThalNet`` is a deep learning model inspired by neocortical communication \
    via the thalamus. This model consists of recurrent neural modules that send features \
    through a routing center, endowing the modules with the flexibility to share features \
    over multiple time steps.

    See the reference paper here: https://arxiv.org/pdf/1706.05744.pdf.

    .. warning::

        The reference paper indicates that the ``Thalnet`` model works on the Sequential MNIST problem. \
        This implementation does not work on it yet, and has only been tested on the SerialRecall task so far.

        This should be addressed in a future release.

    """
    def __init__(self, params, problem_default_values_=None):
        """
        Constructor of the ``ThalNetModel``. Instantiates the ``ThalNetCell``.

        :param params: dictionary of parameters (read from the ``.yaml`` configuration file). \
        The keys ``context_input_size``, ``input_size``, ``output_size``, ``num_modules`` \
        and ``center_size_per_module`` are read.

        :param problem_default_values_: default values coming from the ``Problem`` class. \
        ``None`` is treated as an empty dict.
        :type problem_default_values_: dict

        """
        # Fresh dict per call instead of a shared mutable default argument.
        if problem_default_values_ is None:
            problem_default_values_ = {}

        # Call base class initialization.
        super(ThalNetModel, self).__init__(params, problem_default_values_)

        # get the parameters values
        self.context_input_size = params['context_input_size']
        self.input_size = params['input_size']
        self.output_size = params['output_size']
        self.center_size = params['num_modules'] * params[
            'center_size_per_module']
        self.center_size_per_module = params['center_size_per_module']
        self.num_modules = params['num_modules']
        self.output_center_size = self.output_size + self.center_size_per_module

        # Per-step snapshots of the cell state, filled by forward() when
        # visualization is enabled and consumed by plot().
        self.cell_state_history = None

        # Create the ThalNet recurrent cell.
        self.ThalnetCell = ThalNetCell(self.input_size, self.output_size,
                                       self.context_input_size,
                                       self.center_size_per_module,
                                       self.num_modules)

        # model name
        self.name = 'ThalNetModel'

        # Expected content of the inputs
        self.data_definitions = {
            'sequences': {
                'size': [-1, -1, -1],
                'type': [torch.Tensor]
            },
            'targets': {
                'size': [-1, -1, -1],
                'type': [torch.Tensor]
            }
        }

    def forward(self, data_dict):  # x : batch_size, seq_len, input_size
        """
        Forward run of the ThalNetModel model.

        The ``ThalNetCell`` is applied step by step along the sequence axis (``-2``); \
        steps for which the cell returns ``None`` contribute no output.

        :param data_dict: DataDict({'sequences', **}) where 'sequences' is of shape \
         [batch_size, sequence_length, input_size]
        :type data_dict: utils.DataDict

        :returns: Predictions [batch_size, sequence_length, output_size], or ``None`` \
        if the cell produced no output at any step.

        """
        inputs = data_dict['sequences']

        if self.app_state.visualize:
            self.cell_state_history = []

        batch_size = inputs.size(0)
        seq_length = inputs.size(-2)

        # Collect the per-step outputs and concatenate once at the end
        # (concatenating inside the loop is quadratic in seq_length).
        outputs = []

        # init state
        cell_state = self.ThalnetCell.init_state(batch_size)
        for j in range(seq_length):
            output_cell, cell_state = self.ThalnetCell(inputs[..., j, :],
                                                       cell_state)

            # No output for this step: it is skipped (and not recorded for
            # the time plot either).
            if output_cell is None:
                continue

            # Add a time axis: [..., output_size] -> [..., 1, output_size].
            outputs.append(output_cell[..., None, :])

            # This is for the time plot: first num_modules entries come from
            # cell_state[m][0], the next num_modules from the modules' hidden
            # states. .cpu() is needed before .numpy() for GPU tensors.
            if self.app_state.visualize:
                self.cell_state_history.append([
                    cell_state[m][0].detach().cpu().numpy()
                    for m in range(self.num_modules)
                ] + [
                    cell_state[m][1].hidden_state.detach().cpu().numpy()
                    for m in range(self.num_modules)
                ])

        if not outputs:
            return None
        return torch.cat(outputs, dim=-2)

    def generate_figure_layout(self):
        """
        Generate a figure layout which will be used in ``self.plot()``.

        Axes are created in this order: ``num_modules`` center-state axes, \
        ``num_modules`` module-state axes, the inputs axis and the prediction axis.

        :return: figure layout.

        """
        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        from matplotlib import rc
        import matplotlib.gridspec as gridspec

        # Change fonts globally - for all figures/subplots at once.
        rc('font', **{'family': 'Times New Roman'})

        # Prepare "generic figure template".
        # Create figure object.
        fig = Figure()

        # One row per module; at least the 4 rows needed by the right-hand
        # column (inputs in row 0, prediction in row 2).
        num_rows = max(self.num_modules, 4)
        gs = gridspec.GridSpec(num_rows, 3)

        # modules & centers subplots
        ax_center = [
            fig.add_subplot(gs[i, 0]) for i in range(self.num_modules)
        ]
        ax_module = [
            fig.add_subplot(gs[i, 1]) for i in range(self.num_modules)
        ]

        # inputs & prediction subplot
        ax_inputs = fig.add_subplot(gs[0, 2])
        ax_pred = fig.add_subplot(gs[2, 2])

        # Set ticks - for bit axes only (for now).
        for axis in (ax_inputs, ax_pred):
            axis.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            axis.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('num_row')
        ax_inputs.set_xlabel('num_columns')

        ax_pred.set_title('Prediction')
        ax_pred.set_xlabel('num_classes')

        # centers: title on the first row, x label on the last one
        ax_center[0].set_title('center states')
        ax_center[-1].set_xlabel('iteration')
        for axis in ax_center:
            axis.set_ylabel('center size')

        # modules
        ax_module[0].set_title('module states')
        ax_module[-1].set_xlabel('iteration')
        for axis in ax_module:
            axis.set_ylabel('module state size')

        # Return figure.
        return fig

    def plot(self, data_dict, logits, sample=0):
        """
        Plots specific information on the model's behavior: for the selected \
        sample, the center and module states recorded by ``forward()``, the \
        inputs seen so far and the prediction at the last step, one frame per \
        recorded step.

        :param data_dict: DataDict({'sequences', **})
        :type data_dict: utils.DataDict

        :param logits: Predictions of the model
        :type logits: torch.tensor

        :param sample: Index of the sample to visualize. Default to 0.
        :type sample: int

        :return: ``False`` if visualization is disabled or no cell state was \
        recorded, otherwise ``self.plotWindow.is_closed`` (``True`` if the user pressed stop).

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # The frames are built from the states recorded by forward().
        if not self.cell_state_history:
            self.logger.warning(
                "No cell states recorded - run forward() with visualization "
                "enabled before calling plot().")
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs = data_dict['sequences'].cpu().detach().numpy()
        predictions_seq = logits.cpu().detach().numpy()

        # 4D inputs: only the first channel of the selected sample is shown.
        input_seq = inputs[sample, 0] if inputs.ndim == 4 else inputs[sample]
        num_figures = input_seq.shape[0]

        # Create figure template.
        fig = self.generate_figure_layout()

        # Displayed inputs start as zeros and are filled row by row.
        inputs_displayed = np.zeros(input_seq.shape)

        # One display buffer per recorded state entry (num_modules centers,
        # then num_modules modules), each sized from its own entry.
        num_states = 2 * self.num_modules
        first_states = self.cell_state_history[0]
        state_displayed = [
            np.zeros((first_states[k].shape[-1], input_seq.shape[-2]))
            for k in range(num_states)
        ]

        # Prediction of the selected sample at the last step (same every frame).
        prediction_displayed = predictions_seq[sample, -1, None]

        # Log sequence length - so the user can understand what is going on.
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(num_figures))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, state_tuple) in enumerate(
                zip(input_seq, self.cell_state_history)):
            # Display information every 10% of figures.
            if (num_figures > 10) and (i % (num_figures // 10) == 0):
                self.logger.info("Generating figure {}/{}".format(
                    i, num_figures))

            # Update displayed values on adequate positions.
            inputs_displayed[i, :] = input_element

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # center states (axes 0..n-1) and module states (axes n..2n-1)
            for k, displayed in enumerate(state_displayed):
                displayed[:, i] = state_tuple[k][sample, :]
                artists[k] = fig.axes[k].imshow(displayed,
                                                interpolation='nearest',
                                                aspect='auto')

            artists[num_states] = fig.axes[num_states].imshow(
                inputs_displayed, interpolation='nearest', aspect='auto')

            artists[num_states + 1] = fig.axes[num_states + 1].imshow(
                prediction_displayed,
                interpolation='nearest',
                aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Update time plot for generated list of figures.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Ejemplo n.º 9
0
    def plot(self, data_dict, logits, sample=0):
        """
        Plots specific information on the model's behavior: for the selected \
        sample, the center and module states stored in ``self.cell_state_history``, \
        the inputs seen so far and the prediction at the last step, one frame \
        per recorded step.

        :param data_dict: DataDict({'sequences', **})
        :type data_dict: utils.DataDict

        :param logits: Predictions of the model
        :type logits: torch.tensor

        :param sample: Index of the sample to visualize. Default to 0.
        :type sample: int

        :return: ``False`` if visualization is disabled or no cell state was \
        recorded, otherwise ``self.plotWindow.is_closed`` (``True`` if the user pressed stop).

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # The frames are built from the recorded cell states.
        if not self.cell_state_history:
            self.logger.warning(
                "No cell states recorded - run forward() with visualization "
                "enabled before calling plot().")
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs = data_dict['sequences'].cpu().detach().numpy()
        predictions_seq = logits.cpu().detach().numpy()

        # 4D inputs: only the first channel of the selected sample is shown.
        input_seq = inputs[sample, 0] if inputs.ndim == 4 else inputs[sample]
        num_figures = input_seq.shape[0]

        # Create figure template.
        fig = self.generate_figure_layout()

        # Displayed inputs start as zeros and are filled row by row.
        inputs_displayed = np.zeros(input_seq.shape)

        # One display buffer per recorded state entry (num_modules centers,
        # then num_modules modules), each sized from its own entry.
        num_states = 2 * self.num_modules
        first_states = self.cell_state_history[0]
        state_displayed = [
            np.zeros((first_states[k].shape[-1], input_seq.shape[-2]))
            for k in range(num_states)
        ]

        # Prediction of the selected sample at the last step (same every frame).
        prediction_displayed = predictions_seq[sample, -1, None]

        # Log sequence length - so the user can understand what is going on.
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(num_figures))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, state_tuple) in enumerate(
                zip(input_seq, self.cell_state_history)):
            # Display information every 10% of figures.
            if (num_figures > 10) and (i % (num_figures // 10) == 0):
                self.logger.info("Generating figure {}/{}".format(
                    i, num_figures))

            # Update displayed values on adequate positions.
            inputs_displayed[i, :] = input_element

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)

            # center states (axes 0..n-1) and module states (axes n..2n-1)
            for k, displayed in enumerate(state_displayed):
                displayed[:, i] = state_tuple[k][sample, :]
                artists[k] = fig.axes[k].imshow(displayed,
                                                interpolation='nearest',
                                                aspect='auto')

            artists[num_states] = fig.axes[num_states].imshow(
                inputs_displayed, interpolation='nearest', aspect='auto')

            artists[num_states + 1] = fig.axes[num_states + 1].imshow(
                prediction_displayed,
                interpolation='nearest',
                aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Update time plot for generated list of figures.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Ejemplo n.º 10
0
class SequentialModel(Model):
    """
    Class representing base class for all Sequential Models.

    Inherits from models.model.Model as most features are the same.

    Should be derived by all sequential models.

    """
    def __init__(self, params, problem_default_values_=None):
        """
        Mostly calls the base ``models.model.Model`` constructor.

        Specifies a better structure for ``self.data_definitions``.

        :param params: Parameters read from configuration ``.yaml`` file.

        :param problem_default_values_: dict of parameters values coming from the problem class. One example of such \
        parameter value is the size of the vocabulary set in a translation problem. When omitted, an empty dict \
        is used.
        :type problem_default_values_: dict

        """
        # Use a fresh dict per call rather than a mutable default argument
        # shared between every instance.
        if problem_default_values_ is None:
            problem_default_values_ = {}

        super(SequentialModel,
              self).__init__(params,
                             problem_default_values_=problem_default_values_)

        # "Default" model name.
        self.name = 'SequentialModel'

        # Description of the expected (and mandatory) inputs for this model:
        # both entries are 3-dimensional torch tensors of any size (-1 = any).
        self.data_definitions = {
            'sequences': {
                'size': [-1, -1, -1],
                'type': [torch.Tensor]
            },
            'targets': {
                'size': [-1, -1, -1],
                'type': [torch.Tensor]
            }
        }

    def plot(self, data_dict, predictions, sample=0):
        """
        Creates a default interactive visualization, with a slider enabling to
        move forth and back along the time axis (iteration over the sequence elements in a given episode).
        The default visualization contains the input, output and target sequences.

        For a more model/problem - dependent visualization, please overwrite this
        method in the derived model class.

        :param data_dict: DataDict containing

           - input sequences: [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE],
           - target sequences:  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]


        :param predictions: Predicted sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]
        :type predictions: torch.tensor

        :param sample: Number of sample in batch (default: 0)
        :type sample: int

        :return: ``False`` when visualization is disabled, otherwise the value of
            ``self.plotWindow.is_closed`` after the window was updated.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        import matplotlib
        from matplotlib.figure import Figure
        from matplotlib.ticker import MaxNLocator

        # Change fonts globally - for all figures/subplots at once.
        # (matplotlib.rcParams is the same object pylab re-exports, but does
        # not require pylab/pyplot to be imported.)
        matplotlib.rcParams.update({
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        })

        # Create a single "figure layout" for all displayed frames:
        # three stacked rows sharing the x (item) axis.
        fig = Figure()
        axes = fig.subplots(3, 1, sharex=True, sharey=False)

        # Integer ticks (each axis gets its own locator instance).
        axes[0].xaxis.set_major_locator(MaxNLocator(integer=True))
        for ax in axes:
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        # Set labels.
        axes[0].set_title('Inputs')
        axes[0].set_ylabel('Control/Data bits')
        axes[1].set_title('Targets')
        axes[1].set_ylabel('Data bits')
        axes[2].set_title('Predictions')
        axes[2].set_ylabel('Data bits')
        axes[2].set_xlabel('Item number')
        fig.set_tight_layout(True)

        # Detach a sample from batch and copy it to CPU.
        inputs_seq = data_dict['sequences'][sample].cpu().detach().numpy()
        targets_seq = data_dict['targets'][sample].cpu().detach().numpy()
        predictions_seq = predictions[sample].cpu().detach().numpy()

        # Empty (transposed) buffers, each sized from the sequence it will
        # hold: rows = item features, columns = sequence positions.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        seq_length = inputs_seq.shape[0]

        # Log sequence length - so the user can understand what is going on.
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_word, target_word, prediction_word) in enumerate(
                zip(inputs_seq, targets_seq, predictions_seq)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                self.logger.info("Generating figure {}/{}".format(
                    i, seq_length))

            # Reveal the i-th item of every sequence.
            inputs_displayed[:, i] = input_word
            targets_displayed[:, i] = target_word
            predictions_displayed[:, i] = prediction_word

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)
            artists[0] = axes[0].imshow(inputs_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[1] = axes[1].imshow(targets_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[2] = axes[2].imshow(predictions_displayed,
                                        interpolation='nearest',
                                        aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Ejemplo n.º 11
0
class DNC(SequentialModel):
    """
    Implementation of Differentiable Neural Computer (DNC).

    Graves, Alex, et al. "Hybrid computing using a neural network with dynamic external memory."
    Nature 538.7626 (2016): 471. doi:10.1038/nature20101
    """
    def __init__(self, params, problem_default_values_=None):
        """
        Constructor. Initializes parameters on the basis of dictionary passed
        as argument.

        :param params: Local view to the Parameter Registry ''model'' section. Keys read here:
            ``output_item_size``, ``memory_addresses_size`` (-1 means "use the input sequence
            length", see ``forward``), ``name``, ``num_reads`` and ``num_writes``.

        :param problem_default_values_: Dictionary containing key-values received from problem.
            Must contain ``input_item_size`` and ``output_item_size``.

        """
        # Use a fresh dict per call rather than a shared mutable default.
        if problem_default_values_ is None:
            problem_default_values_ = {}

        # Call base constructor. Sets up default values etc.
        super(DNC, self).__init__(params, problem_default_values_)
        # Model name.
        self.name = 'DNC'

        # Parse default values received from problem and add them to registry.
        self.params.add_default_params({
            'input_item_size':
            problem_default_values_['input_item_size'],
            'output_item_size':
            problem_default_values_['output_item_size']
        })

        self.output_units = params['output_item_size']

        self.memory_addresses_size = params["memory_addresses_size"]
        self.label = params["name"]
        # Filled by forward() when visualization is enabled; consumed by plot().
        self.cell_state_history = None

        # Number of read and write heads
        self._num_reads = params["num_reads"]
        self._num_writes = params["num_writes"]

        # Create the DNC components
        self.DNCCell = DNCCell(self.output_units, params)

    def forward(self, data_dict):
        """
        Forward function requires that the data_dict will contain at least "sequences"

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]

        :returns: Predictions (logits) being a tensor of size  [BATCH_SIZE x LENGTH_SIZE x OUTPUT_SIZE],
            or ``None`` if the cell produced no output at any step.

        When visualization is enabled, ``self.cell_state_history`` is reset and
        receives, for every step that produced an output, a tuple of numpy arrays:
        (memory_state, usage, precedence_weights, read_weights, write_weights).

        """
        # Unpack dict.
        inputs = data_dict['sequences']

        # Get batch size and seq length.
        batch_size = inputs.size(0)
        seq_length = inputs.size(1)

        if self.app_state.visualize:
            self.cell_state_history = []

        # -1 means the memory size is not fixed: use one address per
        # sequence item.
        memory_addresses_size = self.memory_addresses_size
        if memory_addresses_size == -1:
            memory_addresses_size = seq_length

        # init state
        cell_state = self.DNCCell.init_state(memory_addresses_size, batch_size)

        # Collect per-step outputs and concatenate once at the end, instead of
        # re-concatenating (and re-copying) the growing output every step.
        outputs = []
        for j in range(seq_length):
            output_cell, cell_state = self.DNCCell(inputs[..., j, :],
                                                   cell_state)

            if output_cell is None:
                continue

            # Add a sequence axis so steps can be concatenated along dim -2.
            outputs.append(output_cell[..., None, :])

            # This is for the time plot
            if self.app_state.visualize:
                int_state = cell_state.int_init_state
                self.cell_state_history.append(
                    (cell_state.memory_state.detach().cpu().numpy(),
                     int_state.usage.detach().cpu().numpy(),
                     int_state.links.precedence_weights.detach().cpu().numpy(),
                     int_state.read_weights.detach().cpu().numpy(),
                     int_state.write_weights.detach().cpu().numpy()))

        if not outputs:
            return None
        return torch.cat(outputs, dim=-2)

    def plot_memory_attention(self, data_dict, predictions, sample_number=0):
        """
        Plots memory and attention. Not implemented yet: only logs a warning.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        """
        # TODO: implement plotting of memory/attention.
        self.logger.warning(
            "DNC 'plot_memory_attention' method not implemented!")

    def generate_figure_layout(self):
        """
        Creates the figure template used by ``plot``.

        The figure uses a 3x9 grid:
            - column 0 (all rows): memory,
            - columns 1-2 (all rows): read attention,
            - columns 3-4 (all rows): write attention,
            - columns 5-6 (all rows): usage,
            - columns 7-8: inputs (row 0), targets (row 1), predictions (row 2).

        Note: also updates the global matplotlib ``rcParams`` font sizes.

        :return: matplotlib ``Figure`` whose ``fig.axes`` are ordered
            (memory, read, write, usage, inputs, targets, predictions).
        """
        import matplotlib
        from matplotlib.figure import Figure
        from matplotlib.gridspec import GridSpec
        from matplotlib.ticker import MaxNLocator

        # Change fonts globally - for all figures/subplots at once.
        # (matplotlib.rcParams is the object pylab re-exports.)
        matplotlib.rcParams.update({
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        })

        # Prepare "generic figure template".
        fig = Figure()
        gs = GridSpec(3, 9)

        # Creation order defines the order of fig.axes - plot() relies on it.
        ax_memory = fig.add_subplot(gs[:, 0])  # all rows, col 0
        ax_read = fig.add_subplot(gs[:, 1:3])  # all rows, cols 1-2
        ax_write = fig.add_subplot(gs[:, 3:5])  # all rows, cols 3-4
        ax_usage = fig.add_subplot(gs[:, 5:7])  # all rows, cols 5-6

        ax_inputs = fig.add_subplot(gs[0, 7:])  # row 0, cols 7-8
        ax_targets = fig.add_subplot(gs[1, 7:])  # row 1, cols 7-8
        ax_predictions = fig.add_subplot(gs[2, 7:])  # row 2, cols 7-8

        # Integer ticks on both axes of every subplot (one locator instance
        # per axis).
        for ax in (ax_inputs, ax_targets, ax_predictions, ax_memory, ax_read,
                   ax_write, ax_usage):
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')
        ax_read.set_title('Read attention')
        ax_read.set_xlabel('Iteration')
        ax_write.set_title('Write Attention')
        ax_write.set_xlabel('Iteration')
        ax_usage.set_title('Usage')
        ax_usage.set_xlabel('Iteration')

        fig.set_tight_layout(True)
        return fig

    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        Uses ``self.cell_state_history`` recorded by ``forward``, so ``forward``
        must have run with visualization enabled on the same batch.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: ``False`` if nothing was plotted, otherwise the value of
            ``self.plotWindow.is_closed`` after the window was updated.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # The state history is only recorded by forward() while visualizing.
        if not self.cell_state_history:
            self.logger.warning(
                "DNC 'plot': no cell state history recorded, nothing to plot!")
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # Detach the selected sample from the batch and copy it to CPU.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # temporary for data with additional channel
        if inputs_seq.ndim == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_read, ax_write, ax_usage, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros (rows = features, columns = steps).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # History tuple layout: (memory, usage, links, read w., write w.).
        first_state = self.cell_state_history[0]
        num_steps = targets_seq.shape[0]
        head_attention_read = np.zeros((first_state[3].shape[-1], num_steps))
        head_attention_write = np.zeros((first_state[4].shape[-1], num_steps))
        usage_displayed = np.zeros((first_state[1].shape[-1], num_steps))

        seq_length = inputs_seq.shape[0]

        # Log sequence length - so the user can understand what is going on.
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, usage, _links, wt_r, wt_w)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                self.logger.info("Generating figure {}/{}".format(
                    i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # State arrays hold the whole batch: select the plotted sample,
            # and (for attentions) head 0.
            memory_displayed = memory[sample_number]
            head_attention_read[:, i] = wt_r[sample_number, 0, :]
            head_attention_write[:, i] = wt_w[sample_number, 0, :]
            usage_displayed[:, i] = usage[sample_number, :]

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)
            artists[0] = ax_memory.imshow(np.transpose(memory_displayed),
                                          interpolation='nearest',
                                          aspect='auto')
            artists[1] = ax_read.imshow(head_attention_read,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[2] = ax_write.imshow(head_attention_write,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[3] = ax_usage.imshow(usage_displayed,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[4] = ax_inputs.imshow(inputs_displayed,
                                          interpolation='nearest',
                                          aspect='auto')
            artists[5] = ax_targets.imshow(targets_displayed,
                                           interpolation='nearest',
                                           aspect='auto')
            artists[6] = ax_predictions.imshow(predictions_displayed,
                                               interpolation='nearest',
                                               aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Ejemplo n.º 12
0
    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        Uses ``self.cell_state_history``. Each entry is expected to be a tuple
        (memory, usage, links, read_weights, write_weights) of batch-first
        numpy arrays. NOTE(review): that layout is inferred from the indexing
        below; confirm it against this class's forward().

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: ``False`` if nothing was plotted, otherwise the value of
            ``self.plotWindow.is_closed`` after the window was updated.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        logger = logging.getLogger('ModelBase')

        # Nothing to show if no state history was recorded.
        if not self.cell_state_history:
            logger.warning("'plot': no cell state history recorded, nothing to plot!")
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # Detach the selected sample from the batch and copy it to CPU.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # temporary for data with additional channel
        if inputs_seq.ndim == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_read, ax_write, ax_usage, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Set initial values of displayed inputs, targets and predictions -
        # simply zeros (rows = features, columns = steps).
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        first_state = self.cell_state_history[0]
        num_steps = targets_seq.shape[0]
        head_attention_read = np.zeros((first_state[3].shape[-1], num_steps))
        head_attention_write = np.zeros((first_state[4].shape[-1], num_steps))
        usage_displayed = np.zeros((first_state[1].shape[-1], num_steps))

        seq_length = inputs_seq.shape[0]

        # Log sequence length - so the user can understand what is going on.
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, usage, _links, wt_r, wt_w)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                logger.info("Generating figure {}/{}".format(i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # Select the plotted sample from the batch-first state arrays
            # (and head 0 for the attentions).
            memory_displayed = memory[sample_number]
            head_attention_read[:, i] = wt_r[sample_number, 0, :]
            head_attention_write[:, i] = wt_w[sample_number, 0, :]
            usage_displayed[:, i] = usage[sample_number, :]

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)
            artists[0] = ax_memory.imshow(np.transpose(memory_displayed),
                                          interpolation='nearest',
                                          aspect='auto')
            artists[1] = ax_read.imshow(head_attention_read,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[2] = ax_write.imshow(head_attention_write,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[3] = ax_usage.imshow(usage_displayed,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[4] = ax_inputs.imshow(inputs_displayed,
                                          interpolation='nearest',
                                          aspect='auto')
            artists[5] = ax_targets.imshow(targets_displayed,
                                           interpolation='nearest',
                                           aspect='auto')
            artists[6] = ax_predictions.imshow(predictions_displayed,
                                               interpolation='nearest',
                                               aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed