Пример #1
0
    def plot_memory_all_model_params_sequence(self,
                                              data_dict,
                                              predictions,
                                              sample_number=0):
        """
        Creates list of figures used in interactive visualization, with a
        slider enabling to move forth and back along the time axis (iteration
        in a given episode). The visualization presents input, output and
        target sequences passed as input parameters. Additionally, it utilizes
        state tuples collected during the experiment for displaying the memory
        state, read and write attentions; and gating params.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: False if visualization is disabled, otherwise the
            ``is_closed`` flag of the time-plot window after the update.

        .. note:: The memory and head parameters are always read at index 0
            of the recorded cell states (and read head 0), independently of
            ``sample_number``.
        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        def to_numpy(tensor):
            # Move to CPU first: ``.numpy()`` raises for CUDA tensors, and the
            # recorded cell states may live on the model's device.
            return tensor.cpu().detach().numpy()

        # Create figure template and get the axes that artists will draw on.
        fig = self.generate_memory_all_model_params_figure_layout()
        (ax_memory, ax_write_gate, ax_write_shift, ax_write_attention,
         ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention,
         ax_read_similarity, ax_inputs, ax_targets, ax_predictions) = fig.axes

        # Extract the selected sample from the batch.
        inputs_seq = to_numpy(data_dict["sequences"][sample_number])
        targets_seq = to_numpy(data_dict["targets"][sample_number])
        predictions_seq = to_numpy(predictions[sample_number])

        seq_length = inputs_seq.shape[0]
        num_columns = targets_seq.shape[0]

        # Displayed inputs, targets and predictions start as zeros; column i
        # is filled at step i.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Only the shapes of the initial state are needed here: they size the
        # memory/attention/gate/shift history matrices.
        _, interface_state, memory_state, _ = self.cell_state_initial
        _, (_, _, write_gate, write_shift) = interface_state
        memory_size = memory_state.shape[1]
        gate_size = write_gate.shape[1]
        shift_size = write_shift.shape[1]

        # Read head 0 reuses the write-head gate/shift shapes (heads are
        # assumed to be sized alike).
        read0_attention_displayed = np.zeros((memory_size, num_columns))
        read0_similarity_displayed = np.zeros((memory_size, num_columns))
        read0_gate_displayed = np.zeros((gate_size, num_columns))
        read0_shift_displayed = np.zeros((shift_size, num_columns))

        write_attention_displayed = np.zeros((memory_size, num_columns))
        write_similarity_displayed = np.zeros((memory_size, num_columns))
        write_gate_displayed = np.zeros((gate_size, num_columns))
        write_shift_displayed = np.zeros((shift_size, num_columns))

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Frames: one list of artists per time step.
        frames = []

        for i, (input_element, target_element, prediction_element,
                cell_state) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                logger.info("Generating figure {}/{}".format(i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # Unpack cell state.
            _, interface_state, memory_state, _ = cell_state
            read_state_tuples, write_state_tuple = interface_state
            (write_attention, write_similarity, write_gate,
             write_shift) = write_state_tuple
            # Regroup per-head tuples into per-parameter tuples over heads.
            (read_attentions, read_similarities, read_gates,
             read_shifts) = zip(*read_state_tuples)

            memory_displayed = to_numpy(memory_state[0])

            # Parameters of read head 0 (index 0 of the recorded batch).
            read0_attention_displayed[:, i] = to_numpy(
                read_attentions[0][0][:, 0])
            read0_similarity_displayed[:, i] = to_numpy(
                read_similarities[0][0][:, 0])
            read0_gate_displayed[:, i] = to_numpy(read_gates[0][0][:, 0])
            read0_shift_displayed[:, i] = to_numpy(read_shifts[0][0][:, 0])

            # Parameters of the write head.
            write_attention_displayed[:, i] = to_numpy(
                write_attention[0][:, 0])
            write_similarity_displayed[:, i] = to_numpy(
                write_similarity[0][:, 0])
            write_gate_displayed[:, i] = to_numpy(write_gate[0][:, 0])
            write_shift_displayed[:, i] = to_numpy(write_shift[0][:, 0])

            # (axis, data) pairs in the artist order of a frame.
            panels = (
                (ax_memory, memory_displayed),
                (ax_read_attention, read0_attention_displayed),
                (ax_read_similarity, read0_similarity_displayed),
                (ax_read_gate, read0_gate_displayed),
                (ax_read_shift, read0_shift_displayed),
                (ax_write_attention, write_attention_displayed),
                (ax_write_similarity, write_similarity_displayed),
                (ax_write_gate, write_gate_displayed),
                (ax_write_shift, write_shift_displayed),
                (ax_inputs, inputs_displayed),
                (ax_targets, targets_displayed),
                (ax_predictions, predictions_displayed),
            )
            frames.append([
                ax.imshow(data, interpolation='nearest', aspect='auto')
                for ax, data in panels
            ])

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Пример #2
0
    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        Besides the input, target and predicted sequences, every frame shows
        the memory content and the head/bookmark attentions recorded in
        ``self.cell_state_history``.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: False if visualization is disabled, otherwise the
            ``is_closed`` flag of the time-plot window after the update.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # Detach the selected sample from the batch and copy it to CPU.
        # Predictions come from the same sample as the inputs and targets.
        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Temporary handling of data with an additional (channel) dimension:
        # keep only the first channel.
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template and get the axes artists will draw on.
        fig = self.generate_figure_layout()
        (ax_memory, ax_attention, ax_bookmark, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Displayed sequences start as zeros; column i is filled at step i.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Attention histories: (last dim of the recorded attention) x steps.
        num_columns = targets_seq.shape[0]
        head_attention_displayed = np.zeros(
            (self.cell_state_history[0][1].shape[-1], num_columns))
        bookmark_attention_displayed = np.zeros(
            (self.cell_state_history[0][2].shape[-1], num_columns))

        # Drawing options shared by all pcolormesh artists.
        mesh_params = {
            'edgecolor': 'black',
            'cmap': 'inferno',
            'linewidths': 1.4e-3
        }

        # Log sequence length - so the user can understand what is going on.
        seq_length = inputs_seq.shape[0]
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Frames: one list of artists per time step.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, wt, wt_d)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                logger.info("Generating figure {}/{}".format(i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # NOTE(review): the recorded state is read at index 0 (presumably
            # batch element 0), not at sample_number - confirm the layout of
            # cell_state_history.
            memory_displayed = np.clip(memory[0], -3.0, 3.0)
            head_attention_displayed[:, i] = wt[0, 0, :]
            bookmark_attention_displayed[:, i] = wt_d[0, 0, :]

            # Copies are drawn because the displayed matrices keep being
            # filled in place during later iterations.
            artists = [
                ax_memory.pcolormesh(np.transpose(memory_displayed),
                                     vmin=-3.0,
                                     vmax=3.0,
                                     **mesh_params),
                ax_attention.pcolormesh(np.copy(head_attention_displayed),
                                        vmin=0.0,
                                        vmax=1.0,
                                        **mesh_params),
                ax_bookmark.pcolormesh(np.copy(bookmark_attention_displayed),
                                       vmin=0.0,
                                       vmax=1.0,
                                       **mesh_params),
                ax_inputs.pcolormesh(np.copy(inputs_displayed),
                                     vmin=0.0,
                                     vmax=1.0,
                                     **mesh_params),
                ax_targets.pcolormesh(np.copy(targets_displayed),
                                      vmin=0.0,
                                      vmax=1.0,
                                      **mesh_params),
                ax_predictions.pcolormesh(np.copy(predictions_displayed),
                                          vmin=0.0,
                                          vmax=1.0,
                                          **mesh_params),
            ]

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Пример #3
0
    def plot(self, data_dict, predictions, sample=0):
        """
        Creates a default interactive visualization, with a slider enabling to
        move forth and back along the time axis (iteration over the sequence elements in a given episode).
        The default visualization contains the input, output and target sequences.

        For a more model/problem - dependent visualization, please overwrite this
        method in the derived model class.

        :param data_dict: DataDict containing

           - input sequences ("sequences"): [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE],
           - target sequences ("targets"):  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]


        :param predictions: Predicted sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_SIZE]
        :type predictions: torch.tensor

        :param sample: Number of sample in batch (default: 0)
        :type sample: int

        :return: ``False`` if visualization is disabled, otherwise the
            ``is_closed`` flag of the time-plot window after the update.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        import matplotlib.pylab as pylab

        # Change font sizes globally - for all figures/subplots at once.
        params = {
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        }
        pylab.rcParams.update(params)

        # Create a single "figure layout" for all displayed frames.
        fig = Figure()
        axes = fig.subplots(
            3,
            1,
            sharex=True,
            sharey=False,
            gridspec_kw={'width_ratios': [predictions.shape[0]]})

        # Integer ticks.
        axes[0].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[0].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[1].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        axes[2].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        axes[0].set_title('Inputs')
        axes[0].set_ylabel('Control/Data bits')
        axes[1].set_title('Targets')
        axes[1].set_ylabel('Data bits')
        axes[2].set_title('Predictions')
        axes[2].set_ylabel('Data bits')
        axes[2].set_xlabel('Item number')
        fig.set_tight_layout(True)

        # Detach a sample from batch and copy it to CPU.
        inputs_seq = data_dict['sequences'][sample].cpu().detach().numpy()
        targets_seq = data_dict['targets'][sample].cpu().detach().numpy()
        predictions_seq = predictions[sample].cpu().detach().numpy()

        # Displayed matrices start as zeros; each is sized after the sequence
        # it shows and column i is filled at step i.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Log sequence length - so the user can understand what is going on.
        seq_length = inputs_seq.shape[0]
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Frames: one list of artists per sequence element.
        frames = []

        for i, (input_word, prediction_word, target_word) in enumerate(
                zip(inputs_seq, predictions_seq, targets_seq)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                self.logger.info("Generating figure {}/{}".format(
                    i, seq_length))

            # Add words to adequate positions.
            inputs_displayed[:, i] = input_word
            targets_displayed[:, i] = target_word
            predictions_displayed[:, i] = prediction_word

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)
            artists[0] = axes[0].imshow(inputs_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[1] = axes[1].imshow(targets_displayed,
                                        interpolation='nearest',
                                        aspect='auto')
            artists[2] = axes[2].imshow(predictions_displayed,
                                        interpolation='nearest',
                                        aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames; report whether the user closed it.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
Пример #4
0
    def plot(self, data_dict, logits, sample=0):
        """
        Visualize the attention weights (``ControlUnit`` & ``ReadUnit``) on the \
        question & feature maps. Dynamic visualization throughout the reasoning \
        steps is possible.

        :param data_dict: DataDict({'images','questions', 'questions_length', 'questions_string', 'questions_type', \
        'targets', 'targets_string', 'index','imgfiles', 'predictions_string', 'clevr_dir'})
        :type data_dict: utils.DataDict

        :param logits: Prediction of the model.
        :type logits: torch.tensor

        :param sample: Index of sample in batch (Default: 0)
        :type sample: int

        :return: True when the user closes the window, False if we do not need to visualize.

        """

        # check whether the visualization is required
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # unpack data_dict
        s_questions = data_dict['questions_string']
        question_type = data_dict['questions_type']
        answer_string = data_dict['targets_string']
        imgfiles = data_dict['imgfiles']
        prediction_string = data_dict['predictions_string']
        clevr_dir = data_dict['clevr_dir']

        # nltk.word_tokenize needs the 'punkt' models. Download them only when
        # they are not installed, instead of querying the nltk server on
        # every call.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')
        # tokenize question string using same processing as in the problem
        # class
        words = nltk.word_tokenize(s_questions[sample])

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_image, ax_attention_image, ax_attention_question,
         ax_step) = fig.axes

        # Load the image. The second '_'-separated field of the file name
        # selects the sub-directory of <clevr_dir>/images (presumably the
        # CLEVR split, e.g. 'val' for CLEVR_val_000000.png).
        image_subset = imgfiles[sample].split('_')[1]
        image_path = os.path.join(clevr_dir, 'images', image_subset,
                                  imgfiles[sample])
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        image = self.transform(image)
        image = image.permute(1, 2, 0)  # channels last: [H, W, 3]

        # get most probable answer -> prediction of the network
        proba_answers = torch.nn.functional.softmax(logits, -1)
        proba_answer = proba_answers[sample].detach().cpu().max().numpy()

        # After the permute, dim 0 is the height and dim 1 the width.
        height = image.size(0)
        width = image.size(1)
        # Upsamples attention masks to the image resolution (size is (H, W)).
        upsample = torch.nn.Upsample(size=[height, width],
                                     mode='bilinear',
                                     align_corners=True)

        # Titles and labels are identical for every reasoning step.
        ax_image.set_title('CLEVR image: {}'.format(imgfiles[sample]))
        ax_attention_question.set_xticklabels(['h'] + words,
                                              rotation='vertical',
                                              fontsize=10)
        ax_step.axis('off')
        ax_attention_image.set_title('Predicted Answer: ' +
                                     prediction_string[sample] +
                                     ' [ proba: ' +
                                     str.format("{0:.3f}", proba_answer) +
                                     ']  ' + 'Ground Truth: ' +
                                     answer_string[sample])

        frames = []
        for step, (attention_mask, attention_question) in zip(
                range(self.max_step), self.mac_unit.cell_state_history):
            # attention mask has size [batch_size x 1 x (H*W)] over a square
            # feature map: reshape it to [batch_size x 1 x H x W].
            attention_size = int(np.sqrt(attention_mask.size(-1)))
            attention_mask = attention_mask.view(-1, 1, attention_size,
                                                 attention_size)

            # upsample attention mask and pick the visualized sample
            attention_mask = upsample(attention_mask)[sample, 0]

            # preprocess question, pick one sample number
            attention_question = attention_question[sample]

            # Two artists share the attention-image axis: one extra slot.
            artists = [None] * (len(fig.axes) + 1)

            # Tell artists what to do:
            artists[0] = ax_image.imshow(image,
                                         interpolation='nearest',
                                         aspect='auto')
            artists[1] = ax_attention_image.imshow(image,
                                                   interpolation='nearest',
                                                   aspect='auto')
            artists[2] = ax_attention_image.imshow(attention_mask,
                                                   interpolation='nearest',
                                                   aspect='auto',
                                                   alpha=0.5,
                                                   cmap='Reds')
            artists[3] = ax_attention_question.imshow(
                attention_question.transpose(1, 0),
                interpolation='nearest',
                aspect='auto',
                cmap='Reds')
            artists[4] = ax_step.text(0,
                                      0.5,
                                      'Reasoning step index: ' + str(step) +
                                      ' | Question type: ' +
                                      question_type[sample],
                                      fontsize=15)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Пример #5
0
    def plot(self, data_dict, logits, sample=0):
        """
        Plots specific information on the model's behavior: the recorded
        center and module states, the input sequence and the prediction at
        the last time step, all for the selected sample.

        :param data_dict: DataDict({'sequences', **})
        :type data_dict: utils.DataDict

        :param logits: Predictions of the model
        :type logits: torch.tensor

        :param sample: Index of the sample to visualize. Default to 0.
        :type sample: int

        :return: ``True`` if the user pressed stop, else ``False``.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs = data_dict['sequences'].cpu().detach().numpy()
        predictions_seq = logits.cpu().detach().numpy()

        # Inputs with an extra (channel) dimension: show the first channel.
        input_seq = inputs[sample,
                           0] if len(inputs.shape) == 4 else inputs[sample]

        # Last-step prediction of the same sample as the inputs, kept 2D
        # ([1 x N]) for imshow. It does not change between frames.
        last_prediction = predictions_seq[sample, -1, None]

        # Create figure template.
        fig = self.generate_figure_layout()

        # Assumed state-tuple layout: entries [0, num_modules) are center
        # states and [num_modules, 2 * num_modules) are module states. The
        # figure has one axis per entry, followed by the input axis and the
        # prediction axis.
        num_states = 2 * self.num_modules
        first_state = self.cell_state_history[0]
        # One (state size x sequence length) history matrix per entry.
        states_displayed = [
            np.zeros((first_state[k].shape[-1], input_seq.shape[-2]))
            for k in range(num_states)
        ]

        # Displayed inputs start as zeros; row i is filled at step i.
        inputs_displayed = np.zeros(input_seq.shape)

        # Log sequence length - so the user can understand what is going on.
        seq_length = input_seq.shape[0]
        self.logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Frames: one list of artists per time step.
        frames = []

        for i, (input_element, state_tuple) in enumerate(
                zip(input_seq, self.cell_state_history)):
            # Display information every 10% of figures.
            if seq_length > 10 and i % (seq_length // 10) == 0:
                self.logger.info("Generating figure {}/{}".format(
                    i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[i, :] = input_element

            artists = [None] * len(fig.axes)

            # Center states followed by module states of the selected sample.
            for k, state_displayed in enumerate(states_displayed):
                state_displayed[:, i] = state_tuple[k][sample, :]
                artists[k] = fig.axes[k].imshow(state_displayed,
                                                interpolation='nearest',
                                                aspect='auto')

            artists[num_states] = fig.axes[num_states].imshow(
                inputs_displayed, interpolation='nearest', aspect='auto')
            artists[num_states + 1] = fig.axes[num_states + 1].imshow(
                last_prediction, interpolation='nearest', aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Update time plot with the generated list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
Пример #6
0
    def plot(self, data_dict, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        One frame is generated per time step. Each frame shows the memory
        content at that step, plus everything accumulated up to that step:
        the read/write attention of head 0, the memory usage, and the inputs,
        targets and predictions.

        :param data_dict: DataDict containing at least:
            - "sequences": a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE]
            - "targets": a tensor of targets of size  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]

        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param sample_number: Number of sample in batch (DEFAULT: 0). The same
            sample is selected from the inputs, targets, predictions and the
            recorded cell states.

        :return: False if visualization is disabled, otherwise the
            ``is_closed`` flag of the time plot window after updating it.

        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from miprometheus.utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs_seq = data_dict["sequences"][sample_number].cpu().detach(
        ).numpy()
        targets_seq = data_dict["targets"][sample_number].cpu().detach(
        ).numpy()
        # Select the same sample as inputs/targets (was hard-coded to 0).
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # Temporary workaround for data with an additional channel dimension:
        # keep only the first channel.
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_read, ax_write, ax_usage, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Initial values of displayed inputs, targets and predictions - simply
        # zeros; column i is filled in at time step i.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Each history entry is unpacked below as
        # (memory, usage, links, read_weights, write_weights).
        first_state = self.cell_state_history[0]
        seq_length = targets_seq.shape[0]
        head_attention_read = np.zeros((first_state[3].shape[-1], seq_length))
        head_attention_write = np.zeros((first_state[4].shape[-1], seq_length))
        usage_displayed = np.zeros((first_state[1].shape[-1], seq_length))

        num_steps = inputs_seq.shape[0]

        # Log sequence length - so the user can understand what is going on.
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(num_steps))

        imshow_kwargs = dict(interpolation='nearest', aspect='auto')

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []

        for i, (input_element, target_element, prediction_element,
                (memory, usage, _links, wt_r, wt_w)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if num_steps > 10 and i % (num_steps // 10) == 0:
                logger.info("Generating figure {}/{}".format(i, num_steps))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # NOTE(review): assumes the leading dimension of every cell-state
            # tensor is the batch dimension (as the original [0] indexing
            # implies) - confirm against the cell implementation.
            memory_displayed = memory[sample_number]
            # Get attention of head 0.
            head_attention_read[:, i] = wt_r[sample_number, 0, :]
            head_attention_write[:, i] = wt_w[sample_number, 0, :]
            usage_displayed[:, i] = usage[sample_number, :]

            # Create "Artists" drawing data on "ImageAxes".
            artists = [None] * len(fig.axes)
            artists[0] = ax_memory.imshow(np.transpose(memory_displayed),
                                          **imshow_kwargs)
            artists[1] = ax_read.imshow(head_attention_read, **imshow_kwargs)
            artists[2] = ax_write.imshow(head_attention_write, **imshow_kwargs)
            artists[3] = ax_usage.imshow(usage_displayed, **imshow_kwargs)
            artists[4] = ax_inputs.imshow(inputs_displayed, **imshow_kwargs)
            artists[5] = ax_targets.imshow(targets_displayed, **imshow_kwargs)
            artists[6] = ax_predictions.imshow(predictions_displayed,
                                               **imshow_kwargs)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed