def plot(self, data_tuple, predictions, sample_number=0):
    """
    Plot function to visualize the attention weights on the input sequence
    as the model is generating the output sequence.

    :param data_tuple: Data tuple containing input [BATCH_SIZE x
        SEQUENCE_LENGTH] and target sequences [BATCH_SIZE x SEQUENCE_LENGTH].
    :param predictions: dict {'inputs_text', 'logits_text'} created by
        Translation.plot_processing(). ``inputs_text[i]`` is a
        whitespace-separated sentence; ``logits_text[i]`` may be either a
        sentence string or a sequence of tokens.
    :param sample_number: Currently unused - a random sample of the batch is
        visualized instead.
    :return: False if visualization is disabled, otherwise whether the user
        closed the plot window.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    # Select 1 random sample in the batch and retrieve the corresponding
    # input text, predicted text and attention weights.
    batch_size = data_tuple.targets.shape[0]
    sample = random.choice(range(batch_size))

    input_sentence = predictions['inputs_text'][sample]
    input_text = input_sentence.split()
    print('input sentence: ', input_sentence)

    target_text = predictions['logits_text'][sample]
    print('predicted translation:', target_text)
    # Tick labels need a list of tokens: ``[''] + <str>`` raises TypeError,
    # so a plain sentence string is split into words first.
    if isinstance(target_text, str):
        target_tokens = target_text.split()
    else:
        target_tokens = list(target_text)

    attn_weights = self.decoder_attentions[sample].cpu().detach().numpy()

    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(attn_weights)

    # Set up axes. NOTE(review): the leading '' presumably offsets the label
    # list against the first tick position - confirm with the matplotlib
    # version in use.
    ax.set_xticklabels([''] + input_text, rotation=90)
    ax.set_yticklabels([''] + target_tokens)

    # Show label at every tick.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames=[[cax]])

    # Return True if user closed the window.
    return self.plotWindow.is_closed
class DWM(SequentialModel):
    """
    Differentiable Working Memory (DWM), is a memory augmented neural network
    which emulates the human working memory.

    The DWM shows the same functional characteristics of working memory and
    robustly learns psychology-inspired tasks and converges faster than
    comparable state-of-the-art models.
    """

    def __init__(self, params):
        """
        Constructor. Initializes parameters on the basis of dictionary of
        parameters passed as argument.

        :param params: Dictionary of parameters. Keys read here:
            ``control_bits``, ``data_bits``, ``hidden_state_dim``,
            ``num_heads``, ``use_content_addressing``, ``shift_size``,
            ``memory_content_size``, ``memory_addresses_size`` and the
            optional ``output_bits`` (falls back to ``data_bits``).
        """
        # Call base class initialization.
        super(DWM, self).__init__(params)

        self.in_dim = params["control_bits"] + params["data_bits"]
        # Output width defaults to the data width when not given explicitly.
        try:
            self.output_units = params['output_bits']
        except KeyError:
            self.output_units = params['data_bits']

        self.state_units = params["hidden_state_dim"]
        self.num_heads = params["num_heads"]
        self.is_cam = params["use_content_addressing"]
        self.num_shift = params["shift_size"]
        self.M = params["memory_content_size"]
        self.memory_addresses_size = params["memory_addresses_size"]
        self.name = "Differentiable Working Memory (DWM)"  # params["name"]

        # Per-step (memory, head weights, snapshot weights) recorded by
        # forward() when visualization is enabled - used by plot().
        self.cell_state_history = None

        # Create the DWM components
        self.DWMCell = DWMCell(self.in_dim, self.output_units,
                               self.state_units, self.num_heads, self.is_cam,
                               self.num_shift, self.M)

    def forward(self, data_tuple):
        """
        Forward function of the DWM model.

        :param data_tuple: contains (inputs, targets)
        :param data_tuple.inputs: tensor containing the data sequences of the
            batch [batch, sequence_length, input_size] (an extra channel
            dimension at position 1 is tolerated - only channel 0 is used).
        :param data_tuple.targets: tensor containing the target sequences of
            the batch [batch, sequence_length, output_size] (unused here).
        :returns: output: logits which represent the prediction of DWM
            [batch, sequence_length, output_size], or None if the cell
            produced no output at any step.

        Example:

        >>> dwm = DWM(params)
        >>> inputs = torch.randn(5, 3, 10)
        >>> targets = torch.randn(5, 3, 20)
        >>> data_tuple = (inputs, targets)
        >>> output = dwm(data_tuple)
        """
        # Unpack tuple.
        (inputs, _) = data_tuple

        if self.app_state.visualize:
            self.cell_state_history = []

        # TODO: data with an additional channel - only channel 0 is used.
        if len(inputs.size()) == 4:
            inputs = inputs[:, 0, :, :]

        batch_size = inputs.size(0)
        seq_length = inputs.size(-2)

        # The length of the memory is set to be equal to the input length in
        # case ``self.memory_addresses_size == -1``; it can't be smaller than
        # num_shift though (see circular_convolution implementation).
        if self.memory_addresses_size == -1:
            memory_addresses_size = max(seq_length, self.num_shift)
        else:
            memory_addresses_size = self.memory_addresses_size

        # Init state
        cell_state = self.DWMCell.init_state(memory_addresses_size, batch_size)

        # Collect per-step outputs and concatenate once at the end: calling
        # torch.cat inside the loop re-copies the whole output every step.
        step_outputs = []
        for j in range(seq_length):
            output_cell, cell_state = self.DWMCell(inputs[..., j, :],
                                                   cell_state)

            if output_cell is None:
                continue

            step_outputs.append(output_cell[..., None, :])

            # This is for the time plot. Move to CPU first so that this also
            # works when the model runs on GPU.
            if self.app_state.visualize:
                interface = cell_state.interface_state
                self.cell_state_history.append(
                    (cell_state.memory_state.detach().cpu().numpy(),
                     interface.head_weight.detach().cpu().numpy(),
                     interface.snapshot_weight.detach().cpu().numpy()))

        if not step_outputs:
            return None
        return torch.cat(step_outputs, dim=-2)

    def set_memory_size(self, mem_size):
        """
        Change the number of memory addresses used by subsequent forward
        calls (-1 means "match the input sequence length").

        :param mem_size: New number of memory addresses.
        """
        self.memory_addresses_size = mem_size

    @staticmethod
    def generate_figure_layout():
        """
        Create the figure template used by plot(): memory, head attention and
        bookmark attention panels on the left; inputs, targets and
        predictions stacked on the right.

        :return: matplotlib figure whose ``axes`` are, in order: memory,
            head attention, bookmark attention, inputs, targets, predictions.
        """
        import matplotlib.pyplot as plt
        import matplotlib.ticker as ticker
        import matplotlib.gridspec as gridspec

        # Prepare "generic figure template".
        # Create figure object.
        fig = plt.figure(figsize=(16, 9))
        fig.subplots_adjust(left=0.07, right=0.96, top=0.88, bottom=0.15)

        gs0 = gridspec.GridSpec(1, 2, figure=fig, width_ratios=[5.0, 3.0])

        # Left part: memory and attention panels side by side.
        gs00 = gridspec.GridSpecFromSubplotSpec(
            1, 3, subplot_spec=gs0[0], width_ratios=[3.0, 4.0, 4.0])
        ax_memory = fig.add_subplot(gs00[:, 0])  # column 0
        ax_attention = fig.add_subplot(gs00[:, 1])  # column 1
        ax_bookmark = fig.add_subplot(gs00[:, 2])  # column 2

        # Right part: inputs, targets and predictions stacked vertically.
        gs01 = gridspec.GridSpecFromSubplotSpec(
            3, 1, subplot_spec=gs0[1], hspace=0.5,
            height_ratios=[1.0, 0.8, 0.8])
        ax_inputs = fig.add_subplot(gs01[0, :])  # row 0
        ax_targets = fig.add_subplot(gs01[1, :])  # row 1
        ax_predictions = fig.add_subplot(gs01[2, :])  # row 2

        # Integer ticks on every axis.
        for axis in (ax_inputs, ax_targets, ax_predictions, ax_memory,
                     ax_bookmark, ax_attention):
            axis.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            axis.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')
        ax_attention.set_title('Head Attention')
        ax_attention.set_xlabel('Iteration')
        ax_bookmark.set_title('Bookmark Attention')
        ax_bookmark.set_xlabel('Iteration')

        return fig

    def plot(self, data_tuple, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        :param data_tuple: Data tuple containing
           - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
           - target sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH
            x OUTPUT_DATA_SIZE] (logits - a sigmoid is applied for display).
        :param sample_number: Number of sample in batch (DEFAULT: 0). Selects
            the inputs, targets, predictions and recorded memory/attention
            history that are displayed.
        :return: False if visualization is disabled, otherwise whether the
            user closed the plot window.
        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy()
        targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy()
        predictions_seq = torch.sigmoid(
            predictions[sample_number].cpu().detach()).numpy()

        # temporary for data with additional channel
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_attention, ax_bookmark, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Displayed values start as zeros; column i is filled in at step i.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))
        head_attention_displayed = np.zeros(
            (self.cell_state_history[0][1].shape[-1], targets_seq.shape[0]))
        bookmark_attention_displayed = np.zeros(
            (self.cell_state_history[0][2].shape[-1], targets_seq.shape[0]))

        # Log sequence length - so the user can understand what is going on.
        num_frames = inputs_seq.shape[0]
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(num_frames))

        mesh_params = {
            'edgecolor': 'black',
            'cmap': 'inferno',
            'linewidths': 1.4e-3
        }

        # Create frames - a list of lists, where each row is a list of
        # artists used to draw a given frame.
        frames = []
        for i, (input_element, target_element, prediction_element,
                (memory, wt, wt_d)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (num_frames > 10) and (i % (num_frames // 10) == 0):
                logger.info("Generating figure {}/{}".format(i, num_frames))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            memory_displayed = np.clip(memory[sample_number], -3.0, 3.0)
            # Second index fixed to 0 - presumably head 0 (confirm against
            # DWMCell's weight layout).
            head_attention_displayed[:, i] = wt[sample_number, 0, :]
            bookmark_attention_displayed[:, i] = wt_d[sample_number, 0, :]

            # Create "Artists" drawing data on "ImageAxes"; copies freeze the
            # current state of the arrays that keep being updated.
            artists = [None] * len(fig.axes)
            artists[0] = ax_memory.pcolormesh(
                np.transpose(memory_displayed), vmin=-3.0, vmax=3.0,
                **mesh_params)
            artists[1] = ax_attention.pcolormesh(
                np.copy(head_attention_displayed), vmin=0.0, vmax=1.0,
                **mesh_params)
            artists[2] = ax_bookmark.pcolormesh(
                np.copy(bookmark_attention_displayed), vmin=0.0, vmax=1.0,
                **mesh_params)
            artists[3] = ax_inputs.pcolormesh(
                np.copy(inputs_displayed), vmin=0.0, vmax=1.0, **mesh_params)
            artists[4] = ax_targets.pcolormesh(
                np.copy(targets_displayed), vmin=0.0, vmax=1.0, **mesh_params)
            artists[5] = ax_predictions.pcolormesh(
                np.copy(predictions_displayed), vmin=0.0, vmax=1.0,
                **mesh_params)

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
def plot(self, data_tuple, predictions, sample_number=0):
    """
    Interactive visualization, with a slider enabling to move forth and back
    along the time axis (iteration in a given episode).

    :param data_tuple: Data tuple containing
       - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
       - target sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x
        OUTPUT_DATA_SIZE] (logits - a sigmoid is applied for display).
    :param sample_number: Number of sample in batch (DEFAULT: 0). Selects the
        inputs, targets, predictions and recorded memory/attention history
        that are displayed.
    :return: False if visualization is disabled, otherwise whether the user
        closed the plot window.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy()
    targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy()
    predictions_seq = torch.sigmoid(
        predictions[sample_number].cpu().detach()).numpy()

    # temporary for data with additional channel
    if len(inputs_seq.shape) == 3:
        inputs_seq = inputs_seq[0, :, :]

    # Create figure template.
    fig = self.generate_figure_layout()
    # Get axes that artists will draw on.
    (ax_memory, ax_attention, ax_bookmark, ax_inputs, ax_targets,
     ax_predictions) = fig.axes

    # Displayed values start as zeros; column i is filled in at step i.
    inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
    targets_displayed = np.transpose(np.zeros(targets_seq.shape))
    predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))
    head_attention_displayed = np.zeros(
        (self.cell_state_history[0][1].shape[-1], targets_seq.shape[0]))
    bookmark_attention_displayed = np.zeros(
        (self.cell_state_history[0][2].shape[-1], targets_seq.shape[0]))

    # Log sequence length - so the user can understand what is going on.
    num_frames = inputs_seq.shape[0]
    logger = logging.getLogger('ModelBase')
    logger.info(
        "Generating dynamic visualization of {} figures, please wait...".
        format(num_frames))

    mesh_params = {
        'edgecolor': 'black',
        'cmap': 'inferno',
        'linewidths': 1.4e-3
    }

    # Create frames - a list of lists, where each row is a list of artists
    # used to draw a given frame.
    frames = []
    for i, (input_element, target_element, prediction_element,
            (memory, wt, wt_d)) in enumerate(
                zip(inputs_seq, targets_seq, predictions_seq,
                    self.cell_state_history)):
        # Display information every 10% of figures.
        if (num_frames > 10) and (i % (num_frames // 10) == 0):
            logger.info("Generating figure {}/{}".format(i, num_frames))

        # Update displayed values on adequate positions.
        inputs_displayed[:, i] = input_element
        targets_displayed[:, i] = target_element
        predictions_displayed[:, i] = prediction_element

        memory_displayed = np.clip(memory[sample_number], -3.0, 3.0)
        # Second index fixed to 0 - presumably head 0 (confirm against the
        # cell's weight layout).
        head_attention_displayed[:, i] = wt[sample_number, 0, :]
        bookmark_attention_displayed[:, i] = wt_d[sample_number, 0, :]

        # Create "Artists" drawing data on "ImageAxes"; copies freeze the
        # current state of the arrays that keep being updated.
        artists = [None] * len(fig.axes)
        artists[0] = ax_memory.pcolormesh(
            np.transpose(memory_displayed), vmin=-3.0, vmax=3.0,
            **mesh_params)
        artists[1] = ax_attention.pcolormesh(
            np.copy(head_attention_displayed), vmin=0.0, vmax=1.0,
            **mesh_params)
        artists[2] = ax_bookmark.pcolormesh(
            np.copy(bookmark_attention_displayed), vmin=0.0, vmax=1.0,
            **mesh_params)
        artists[3] = ax_inputs.pcolormesh(
            np.copy(inputs_displayed), vmin=0.0, vmax=1.0, **mesh_params)
        artists[4] = ax_targets.pcolormesh(
            np.copy(targets_displayed), vmin=0.0, vmax=1.0, **mesh_params)
        artists[5] = ax_predictions.pcolormesh(
            np.copy(predictions_displayed), vmin=0.0, vmax=1.0,
            **mesh_params)

        # Add "frame".
        frames.append(artists)

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames)
    return self.plotWindow.is_closed
class ThalNetModel(SequentialModel):
    """
    ThalNet model consists of recurrent neural modules that send features
    through a routing center, it was proposed in the following paper
    https://arxiv.org/pdf/1706.05744.pdf.
    """

    def __init__(self, params):
        """
        Constructor of the ThalNetModel.

        :param params: Parameters read from configuration file. Keys read
            here: ``context_input_size``, ``input_size``, ``output_size``,
            ``num_modules``, ``center_size_per_module``.
        """
        # Call base class initialization.
        super(ThalNetModel, self).__init__(params)

        self.context_input_size = params['context_input_size']
        self.input_size = params['input_size']
        self.output_size = params['output_size']
        self.center_size = params['num_modules'] * \
            params['center_size_per_module']
        self.center_size_per_module = params['center_size_per_module']
        self.num_modules = params['num_modules']
        self.output_center_size = self.output_size + \
            self.center_size_per_module

        # Per-step list of num_modules center states followed by num_modules
        # module hidden states, recorded by forward() for plot().
        self.cell_state_history = None

        # Create the ThalNet components.
        self.ThalnetCell = ThalNetCell(self.input_size, self.output_size,
                                       self.context_input_size,
                                       self.center_size_per_module,
                                       self.num_modules)

    def forward(self, data_tuple):
        """
        Forward run of the ThalNetModel model.

        :param data_tuple: (inputs [batch_size, sequence_length, input_size],
            targets [batch_size, sequence_length, output_size]); targets are
            unused here.
        :returns: output: prediction [batch_size, sequence_length,
            output_size], or None if the cell produced no output at any step.
        """
        (inputs, _) = data_tuple

        if self.app_state.visualize:
            self.cell_state_history = []

        batch_size = inputs.size(0)
        seq_length = inputs.size(-2)

        # init state
        cell_state = self.ThalnetCell.init_state(batch_size)

        # Collect per-step outputs and concatenate once at the end: calling
        # torch.cat inside the loop re-copies the whole output every step.
        step_outputs = []
        for j in range(seq_length):
            output_cell, cell_state = self.ThalnetCell(inputs[..., j, :],
                                                       cell_state)

            if output_cell is None:
                continue

            step_outputs.append(output_cell[..., None, :])

            # This is for the time plot. Move to CPU first so that this also
            # works when the model runs on GPU.
            if self.app_state.visualize:
                centers = [
                    cell_state[m][0].detach().cpu().numpy()
                    for m in range(self.num_modules)
                ]
                modules = [
                    cell_state[m][1].hidden_state.detach().cpu().numpy()
                    for m in range(self.num_modules)
                ]
                self.cell_state_history.append(centers + modules)

        if not step_outputs:
            return None
        return torch.cat(step_outputs, dim=-2)

    def generate_figure_layout(self):
        """
        Create the figure template used by plot().

        Column 0 holds one center-state panel per module, column 1 one
        module-state panel per module, column 2 the inputs (row 0) and the
        prediction (row 2). Note: changes the global matplotlib font.

        :return: matplotlib Figure whose ``axes`` are, in order: the
            num_modules center panels, the num_modules module panels, the
            inputs panel and the prediction panel.
        """
        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        from matplotlib import rc
        import matplotlib.gridspec as gridspec

        # Change fonts globally - for all figures/subplots at once.
        rc('font', **{'family': 'Times New Roman'})

        # Prepare "generic figure template".
        fig = Figure()

        # One grid row per module; at least 4 so that the inputs (row 0) and
        # prediction (row 2) panels always fit.
        num_rows = max(4, self.num_modules)
        gs = gridspec.GridSpec(num_rows, 3)

        # Modules & centers subplots.
        ax_center = [
            fig.add_subplot(gs[i, 0]) for i in range(self.num_modules)
        ]
        ax_module = [
            fig.add_subplot(gs[i, 1]) for i in range(self.num_modules)
        ]

        # Inputs & prediction subplots.
        ax_inputs = fig.add_subplot(gs[0, 2])
        ax_pred = fig.add_subplot(gs[2, 2])

        # Set ticks - for bit axes only (for now).
        for axis in (ax_inputs, ax_pred):
            axis.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            axis.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('num_row')
        ax_inputs.set_xlabel('num_columns')
        ax_pred.set_title('Prediction')
        ax_pred.set_xlabel('num_classes')

        # Centers.
        ax_center[0].set_title('center states')
        ax_center[-1].set_xlabel('iteration')
        for axis in ax_center:
            axis.set_ylabel('center size')

        # Modules.
        ax_module[0].set_title('module states')
        ax_module[-1].set_xlabel('iteration')
        for axis in ax_module:
            axis.set_ylabel('module state size')

        return fig

    def plot(self, data_tuple, logits, sample_number=0):
        """
        Interactive visualization of the center and module states, the
        inputs and the final prediction, with one frame per recorded step.

        :param data_tuple: (inputs, targets); inputs are indexed as
            [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE], optionally with an
            extra channel dimension at position 1 (only channel 0 is shown).
        :param logits: Model predictions; ``logits[sample_number, -1]`` (the
            last step of the selected sample) is displayed.
        :param sample_number: Number of sample in batch (DEFAULT: 0).
        :return: False if visualization is disabled, otherwise whether the
            user closed the plot window.
        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        (inputs, _) = data_tuple
        inputs = inputs.cpu().detach().numpy()
        predictions_seq = logits.cpu().detach().numpy()

        # Data with an additional channel: show channel 0 only.
        if len(inputs.shape) == 4:
            input_seq = inputs[sample_number, 0]
        else:
            input_seq = inputs[sample_number]

        # Create figure template.
        fig = self.generate_figure_layout()

        num_modules = self.num_modules
        num_columns = input_seq.shape[-2]
        first_state = self.cell_state_history[0]

        # Displayed values start as zeros; column/row i is filled at step i.
        inputs_displayed = np.zeros(input_seq.shape)
        # One (state size x steps) canvas per center, then per module - the
        # same order as the history entries and the first 2 * num_modules
        # figure axes.
        state_canvases = [
            np.zeros((first_state[k].shape[-1], num_columns))
            for k in range(2 * num_modules)
        ]

        # Log sequence length - so the user can understand what is going on.
        num_frames = input_seq.shape[0]
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(num_frames))

        inputs_axis = fig.axes[2 * num_modules]
        prediction_axis = fig.axes[2 * num_modules + 1]
        # Final prediction of the selected sample, as a single-row image.
        final_prediction = predictions_seq[sample_number, -1, None]

        # Create frames - a list of lists, where each row is a list of
        # artists used to draw a given frame.
        frames = []
        for i, (input_element, state_tuple) in enumerate(
                zip(input_seq, self.cell_state_history)):
            # Display information every 10% of figures.
            if (num_frames > 10) and (i % (num_frames // 10) == 0):
                logger.info("Generating figure {}/{}".format(i, num_frames))

            inputs_displayed[i, :] = input_element

            artists = [None] * len(fig.axes)
            # Center states, then module states.
            for k, canvas in enumerate(state_canvases):
                canvas[:, i] = state_tuple[k][sample_number, :]
                artists[k] = fig.axes[k].imshow(
                    canvas, interpolation='nearest', aspect='auto')

            artists[2 * num_modules] = inputs_axis.imshow(
                inputs_displayed, interpolation='nearest', aspect='auto')
            artists[2 * num_modules + 1] = prediction_axis.imshow(
                final_prediction, interpolation='nearest', aspect='auto')

            # Add "frame".
            frames.append(artists)

        # Update time plot for generated list of figures.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
def plot(self, data_tuple, logits, sample_number=0):
    """
    Interactive visualization of the center and module states, the inputs
    and the final prediction, with one frame per recorded step.

    Expects ``self.cell_state_history`` entries to hold num_modules center
    states followed by num_modules module states, and the figure from
    ``self.generate_figure_layout()`` to have its axes in that same order,
    followed by the inputs and prediction axes.

    :param data_tuple: (inputs, targets); inputs are indexed as
        [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_SIZE], optionally with an extra
        channel dimension at position 1 (only channel 0 is shown).
    :param logits: Model predictions; ``logits[sample_number, -1]`` (the last
        step of the selected sample) is displayed.
    :param sample_number: Number of sample in batch (DEFAULT: 0).
    :return: False if visualization is disabled, otherwise whether the user
        closed the plot window.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    (inputs, _) = data_tuple
    inputs = inputs.cpu().detach().numpy()
    predictions_seq = logits.cpu().detach().numpy()

    # Data with an additional channel: show channel 0 only.
    if len(inputs.shape) == 4:
        input_seq = inputs[sample_number, 0]
    else:
        input_seq = inputs[sample_number]

    # Create figure template.
    fig = self.generate_figure_layout()

    num_modules = self.num_modules
    num_columns = input_seq.shape[-2]
    first_state = self.cell_state_history[0]

    # Displayed values start as zeros; column/row i is filled at step i.
    inputs_displayed = np.zeros(input_seq.shape)
    # One (state size x steps) canvas per center, then per module.
    state_canvases = [
        np.zeros((first_state[k].shape[-1], num_columns))
        for k in range(2 * num_modules)
    ]

    # Log sequence length - so the user can understand what is going on.
    num_frames = input_seq.shape[0]
    logger = logging.getLogger('ModelBase')
    logger.info(
        "Generating dynamic visualization of {} figures, please wait...".
        format(num_frames))

    inputs_axis = fig.axes[2 * num_modules]
    prediction_axis = fig.axes[2 * num_modules + 1]
    # Final prediction of the selected sample, as a single-row image.
    final_prediction = predictions_seq[sample_number, -1, None]

    # Create frames - a list of lists, where each row is a list of artists
    # used to draw a given frame.
    frames = []
    for i, (input_element, state_tuple) in enumerate(
            zip(input_seq, self.cell_state_history)):
        # Display information every 10% of figures.
        if (num_frames > 10) and (i % (num_frames // 10) == 0):
            logger.info("Generating figure {}/{}".format(i, num_frames))

        inputs_displayed[i, :] = input_element

        artists = [None] * len(fig.axes)
        # Center states, then module states.
        for k, canvas in enumerate(state_canvases):
            canvas[:, i] = state_tuple[k][sample_number, :]
            artists[k] = fig.axes[k].imshow(
                canvas, interpolation='nearest', aspect='auto')

        artists[2 * num_modules] = inputs_axis.imshow(
            inputs_displayed, interpolation='nearest', aspect='auto')
        artists[2 * num_modules + 1] = prediction_axis.imshow(
            final_prediction, interpolation='nearest', aspect='auto')

        # Add "frame".
        frames.append(artists)

    # Update time plot for generated list of figures.
    self.plotWindow.update(fig, frames)
    return self.plotWindow.is_closed
class MACNetwork(Model):
    """
    Implementation of the entire MAC network.
    """

    def __init__(self, params):
        """
        Constructor for the MAC network.

        :param params: dict of parameters. Keys read here: 'dim',
            'embed_hidden', 'max_step', 'self_attention', 'memory_gate',
            'nb_classes', 'dropout'.
        """
        # call base constructor
        super(MACNetwork, self).__init__(params)

        # parse params dict
        self.dim = params['dim']
        self.embed_hidden = params['embed_hidden']
        self.max_step = params['max_step']
        self.self_attention = params['self_attention']
        self.memory_gate = params['memory_gate']
        self.nb_classes = params['nb_classes']
        self.dropout = params['dropout']

        # Holds the knowledge-base projection from the last forward pass.
        self.image = []

        # instantiate units
        self.input_unit = InputUnit(
            dim=self.dim, embedded_dim=self.embed_hidden)

        self.mac_unit = MACUnit(dim=self.dim, max_step=self.max_step,
                                self_attention=self.self_attention,
                                memory_gate=self.memory_gate,
                                dropout=self.dropout)

        self.output_unit = OutputUnit(dim=self.dim, nb_classes=self.nb_classes)

        # transform for the image plotting
        self.transform = transforms.Compose(
            [transforms.Resize([224, 224]), transforms.ToTensor()])

    def forward(self, data_tuple, dropout=0.15):
        """
        Run the MAC network on a batch.

        :param data_tuple: nested tuple
            ``(((images, questions), questions_len), _)``; the second
            element is not used here.
        :param dropout: unused; kept for backward compatibility. The dropout
            rate actually used is ``params['dropout']``, passed to the MAC
            unit in the constructor.
        :return: logits produced by the output unit.
        """
        # reset cell state history for visualization
        if self.app_state.visualize:
            self.mac_unit.cell_state_history = []

        # unpack data_tuple
        inner_tuple, _ = data_tuple
        image_questions_tuple, questions_len = inner_tuple
        images, questions = image_questions_tuple

        # input unit
        img, kb_proj, lstm_out, h = self.input_unit(
            questions, questions_len, images)
        self.image = kb_proj

        # recurrent MAC cells
        memory = self.mac_unit(lstm_out, h, img, kb_proj)

        # output unit
        logits = self.output_unit(memory, h)

        return logits

    def generate_figure_layout(self):
        """
        Generate a figure layout for the attention visualization (done in
        MACNetwork.plot()).

        Note: updates the global ``matplotlib.pylab.rcParams``.

        :return: figure layout with 4 axes, in this order: image, attention
            on image, attention on question, step index.
        """
        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        import matplotlib.gridspec as gridspec
        import matplotlib.pylab as pylab

        params = {
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        }
        pylab.rcParams.update(params)

        # Prepare "generic figure template".
        # Create figure object.
        fig = Figure()

        # Create a specific grid for MAC.
        gs = gridspec.GridSpec(6, 2)

        # subplots: original image, attention on image & question, step index
        ax_image = fig.add_subplot(gs[2:6, 0])
        ax_attention_image = fig.add_subplot(gs[2:6, 1])
        ax_attention_question = fig.add_subplot(gs[0, :])
        ax_step = fig.add_subplot(gs[1, 0])

        # Set axis ticks
        ax_image.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        ax_image.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        ax_attention_image.xaxis.set_major_locator(
            ticker.MaxNLocator(integer=True))
        ax_attention_image.yaxis.set_major_locator(
            ticker.MaxNLocator(integer=True))

        # question ticks
        ax_attention_question.xaxis.set_major_locator(
            ticker.MaxNLocator(nbins=40))

        ax_step.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
        ax_step.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        fig.set_tight_layout(True)

        return fig

    def plot(self, aux_tuple, logits, sample_number=0):
        """
        Visualize the attention weights (Control Unit & Read Unit) on the
        question & feature maps. Dynamic visualization throughout the
        reasoning steps is possible.

        :param aux_tuple: aux_tuple (), transformed by
            CLEVR.plot_preprocessing() -> (s_questions, answer_string,
            imgfiles, set, prediction_string, clevr_dir)
        :param logits: prediction of the network
        :param sample_number: Number of sample in batch (DEFAULT: 0)

        :return: True when the user closes the window, False if we do not
            need to visualize.
        """
        # check whether the visualization is required
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # unpack aux_tuple; the 4th element (``set`` above) is bound to
        # ``data_set`` so the builtin ``set`` is not shadowed.
        (s_questions, answer_string, imgfiles, data_set, prediction_string,
         clevr_dir) = aux_tuple

        # nltk.word_tokenize needs the 'punkt' resource: only download it
        # when it is not installed yet, instead of contacting the server on
        # every call.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt')

        # tokenize question string using same processing as in the problem
        # class
        words = nltk.word_tokenize(s_questions[sample_number])

        # Create figure template.
        fig = self.generate_figure_layout()

        # Get axes that artists will draw on.
        (ax_image, ax_attention_image, ax_attention_question,
         ax_step) = fig.axes

        # Load the image; the context manager closes the underlying file.
        image_path = os.path.join(clevr_dir, 'images', data_set,
                                  imgfiles[sample_number])
        with Image.open(image_path) as pil_image:
            image = self.transform(pil_image.convert('RGB'))
        image = image.permute(1, 2, 0)  # [224, 224, 3] after self.transform

        # probability of the most probable answer -> prediction of the network
        proba_answers = F.softmax(logits, -1)
        proba_answer = proba_answers[sample_number].detach().cpu()
        proba_answer = proba_answer.max().numpy()

        # The attention maps are upsampled to the image grid [H, W]; the
        # target size is the same for every step, so build the layer once.
        height, width = image.size(0), image.size(1)
        upsample = torch.nn.Upsample(size=[height, width], mode='bilinear',
                                     align_corners=True)

        # Titles / labels do not depend on the step: set them once.
        ax_image.set_title('CLEVR image: {}'.format(
            imgfiles[sample_number]))
        ax_attention_question.set_xticklabels(
            ['h'] + words, rotation='vertical', fontsize=10)
        ax_step.axis('off')
        ax_attention_image.set_title(
            'Predicted Answer: ' + prediction_string[sample_number] +
            ' [ proba: ' + str.format("{0:.3f}", proba_answer) + '] ' +
            'Ground Truth: ' + answer_string[sample_number])

        frames = []
        for step, (attention_mask, attention_question) in zip(
                range(self.max_step), self.mac_unit.cell_state_history):
            # reshape the flat attention mask into a square map, upsample it
            # and pick the requested sample
            attention_size = int(np.sqrt(attention_mask.size(-1)))
            attention_mask = attention_mask.view(-1, 1, attention_size,
                                                 attention_size)
            attention_mask = upsample(attention_mask)[sample_number, 0]
            # detach / move to CPU so matplotlib can convert to numpy
            attention_mask = attention_mask.detach().cpu()

            # preprocess question, pick one sample number
            attention_question = attention_question[sample_number]
            attention_question = attention_question.detach().cpu()

            # Create "Artists" drawing data on "ImageAxes".
            artists = [
                ax_image.imshow(image, interpolation='nearest',
                                aspect='auto'),
                ax_attention_image.imshow(image, interpolation='nearest',
                                          aspect='auto'),
                ax_attention_image.imshow(attention_mask,
                                          interpolation='nearest',
                                          aspect='auto', alpha=0.5,
                                          cmap='Reds'),
                ax_attention_question.imshow(
                    attention_question.transpose(1, 0),
                    interpolation='nearest', aspect='auto', cmap='Reds'),
                ax_step.text(0, 0.5, 'Reasoning step index: ' + str(step),
                             fontsize=15),
            ]

            # Add "frame".
            frames.append(artists)

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)

        return self.plotWindow.is_closed
def plot(self, aux_tuple, logits, sample_number=0):
    """
    Visualize the attention weights (Control Unit & Read Unit) on the
    question & feature maps. Dynamic visualization throughout the reasoning
    steps is possible.

    :param aux_tuple: aux_tuple (), transformed by
        CLEVR.plot_preprocessing() -> (s_questions, answer_string, imgfiles,
        set, prediction_string, clevr_dir)
    :param logits: prediction of the network
    :param sample_number: Number of sample in batch (DEFAULT: 0)

    :return: True when the user closes the window, False if we do not need
        to visualize.
    """
    # check whether the visualization is required
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    # unpack aux_tuple; the 4th element (``set`` above) is bound to
    # ``data_set`` so the builtin ``set`` is not shadowed.
    (s_questions, answer_string, imgfiles, data_set, prediction_string,
     clevr_dir) = aux_tuple

    # nltk.word_tokenize needs the 'punkt' resource: only download it when
    # it is not installed yet, instead of contacting the server every call.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    # tokenize question string using same processing as in the problem class
    words = nltk.word_tokenize(s_questions[sample_number])

    # Create figure template.
    fig = self.generate_figure_layout()

    # Get axes that artists will draw on.
    (ax_image, ax_attention_image, ax_attention_question, ax_step) = fig.axes

    # Load the image; the context manager closes the underlying file.
    image_path = os.path.join(clevr_dir, 'images', data_set,
                              imgfiles[sample_number])
    with Image.open(image_path) as pil_image:
        image = self.transform(pil_image.convert('RGB'))
    image = image.permute(1, 2, 0)  # [H, W, 3]

    # probability of the most probable answer -> prediction of the network
    proba_answers = F.softmax(logits, -1)
    proba_answer = proba_answers[sample_number].detach().cpu()
    proba_answer = proba_answer.max().numpy()

    # The attention maps are upsampled to the image grid [H, W]; the target
    # size is the same for every step, so build the layer once.
    height, width = image.size(0), image.size(1)
    upsample = torch.nn.Upsample(size=[height, width], mode='bilinear',
                                 align_corners=True)

    # Titles / labels do not depend on the step: set them once.
    ax_image.set_title('CLEVR image: {}'.format(imgfiles[sample_number]))
    ax_attention_question.set_xticklabels(
        ['h'] + words, rotation='vertical', fontsize=10)
    ax_step.axis('off')
    ax_attention_image.set_title(
        'Predicted Answer: ' + prediction_string[sample_number] +
        ' [ proba: ' + str.format("{0:.3f}", proba_answer) + '] ' +
        'Ground Truth: ' + answer_string[sample_number])

    frames = []
    for step, (attention_mask, attention_question) in zip(
            range(self.max_step), self.mac_unit.cell_state_history):
        # reshape the flat attention mask into a square map, upsample it and
        # pick the requested sample
        attention_size = int(np.sqrt(attention_mask.size(-1)))
        attention_mask = attention_mask.view(-1, 1, attention_size,
                                             attention_size)
        attention_mask = upsample(attention_mask)[sample_number, 0]
        # detach / move to CPU so matplotlib can convert to numpy
        attention_mask = attention_mask.detach().cpu()

        # preprocess question, pick one sample number
        attention_question = attention_question[sample_number]
        attention_question = attention_question.detach().cpu()

        # Create "Artists" drawing data on "ImageAxes".
        artists = [
            ax_image.imshow(image, interpolation='nearest', aspect='auto'),
            ax_attention_image.imshow(image, interpolation='nearest',
                                      aspect='auto'),
            ax_attention_image.imshow(attention_mask,
                                      interpolation='nearest',
                                      aspect='auto', alpha=0.5, cmap='Reds'),
            ax_attention_question.imshow(
                attention_question.transpose(1, 0),
                interpolation='nearest', aspect='auto', cmap='Reds'),
            ax_step.text(0, 0.5, 'Reasoning step index: ' + str(step),
                         fontsize=15),
        ]

        # Add "frame".
        frames.append(artists)

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames)

    return self.plotWindow.is_closed
class SimpleEncoderDecoder(SequentialModel):
    """
    Sequence to Sequence model based on EncoderRNN & DecoderRNN.
    """

    def __init__(self, params):
        """
        Initializes the Encoder-Decoder network.

        :param params: dict containing the main parameters set:
            - max_length: maximal length of the input / output sequence of
              words: i.e, max length of the sentences to translate -> upper
              limit of seq_length
            - input_voc_size: should correspond to the length of the
              vocabulary set of the input language
            - hidden_size: size of the embedding & hidden states vectors.
            - encoder_bidirectional: whether the encoder RNN is
              bidirectional.
            - output_voc_size: should correspond to the length of the
              vocabulary set of the output language
        """
        # call base constructor
        super(SimpleEncoderDecoder, self).__init__(params)

        self.max_length = params['max_length']

        # parse params to create encoder
        self.input_voc_size = params['input_voc_size']
        self.hidden_size = params['hidden_size']
        self.encoder_bidirectional = params['encoder_bidirectional']

        # create encoder
        self.encoder = EncoderRNN(input_voc_size=self.input_voc_size,
                                  hidden_size=self.hidden_size,
                                  bidirectional=self.encoder_bidirectional,
                                  n_layers=1)

        # parse param to create decoder
        self.output_voc_size = params['output_voc_size']

        # create attention decoder
        self.decoder = AttnDecoderRNN(
            self.hidden_size,
            self.output_voc_size,
            dropout_p=0.1,
            max_length=self.max_length,
            encoder_bidirectional=self.encoder_bidirectional)

        print('EncoderDecoderRNN (with Bahdanau attention) created.\n')

    def plot(self, data_tuple, predictions, sample_number=0):
        """
        Plot function to visualize the attention weights on the input
        sequence as the model is generating the output sequence.

        :param data_tuple: data_tuple: Data tuple containing input
            [BATCH_SIZE x SEQUENCE_LENGTH] and target sequences
            [BATCH_SIZE x SEQUENCE_LENGTH]
        :param predictions: logits as dict {'inputs_text', 'logits_text'}
        :param sample_number: unused - a random sample of the batch is
            plotted instead.

        :return: True if the user closed the window, False if visualization
            is disabled.
        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # A standalone Figure (not pyplot) is not registered in pyplot's
        # global figure manager, so repeated calls do not leak figures; this
        # matches the figure type used by the other models' plots.
        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker

        # select 1 random sample in the batch and retrieve corresponding
        # input_text, logit_text, attention_weight
        batch_size = data_tuple.targets.shape[0]
        sample = random.choice(range(batch_size))

        # pred should be a dict {'inputs_text', 'logits_text'} created by
        # Translation.plot_processing()
        input_sentence = predictions['inputs_text'][sample]
        input_text = input_sentence.split()
        print('input sentence: ', input_sentence)

        target_text = predictions['logits_text'][sample]
        print('predicted translation:', target_text)
        # Tick labels need a list of words: split a sentence string.
        if isinstance(target_text, str):
            target_text = target_text.split()

        attn_weights = self.decoder_attentions[sample].cpu().detach().numpy()

        fig = Figure()
        ax = fig.add_subplot(111)
        cax = ax.matshow(attn_weights)

        # set up axes
        ax.set_xticklabels([''] + input_text, rotation=90)
        ax.set_yticklabels([''] + target_text)

        # show label at every tick
        ax.xaxis.set_major_locator(ticker.MultipleLocator(1))
        ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames=[[cax]])

        # Return True if user closed the window.
        return self.plotWindow.is_closed

    # global forward pass
    def forward(self, data_tuple):
        """
        Runs the network.

        The encoder and decoder are unrolled for exactly ``max_length``
        steps. In training mode the decoder is teacher-forced with the
        targets; otherwise it is fed its own greedy (top-1) predictions.
        The decoder attention weights of every step are stored (detached) in
        ``self.decoder_attentions`` [batch_size x max_length x max_length]
        for visualization, in both modes.

        :param data_tuple: (input_tensor, target_tensor) tuple

        :return: decoder outputs [batch_size x max_length x output_voc_size]
            containing the decoder output for each word of the target
            sequence.
        """
        # unpack data_tuple
        (inputs, targets) = data_tuple

        # get batch_size (dim 0)
        batch_size = inputs.size(0)

        # reshape tensors: from [batch_size x max_seq_length] to
        # [max_seq_length x batch_size]
        input_tensor = inputs.transpose(0, 1)
        target_tensor = targets.transpose(0, 1)

        # init encoder hidden states
        encoder_hidden = self.encoder.init_hidden(batch_size)

        # create placeholder for the encoder outputs -> will be passed to
        # attention decoder
        num_directions = 2 if self.encoder.bidirectional else 1
        encoder_outputs = torch.zeros(
            self.max_length, batch_size,
            self.hidden_size * num_directions).type(self.app_state.dtype)

        # create placeholder for the attention weights -> for visualization
        self.decoder_attentions = torch.zeros(
            batch_size, self.max_length,
            self.max_length).type(self.app_state.dtype)

        # encoder manual loop
        for ei in range(self.max_length):
            encoder_output, encoder_hidden = self.encoder(
                input_tensor[ei].unsqueeze(-1), encoder_hidden)
            encoder_outputs[ei] = encoder_output.squeeze()

        # reshape encoder_outputs to be batch_size first: [max_length,
        # batch_size, *] -> [batch_size, max_length, *]
        encoder_outputs = encoder_outputs.transpose(0, 1)

        # decoder input : [batch_size x 1] initialized to the value of Start
        # Of String token
        decoder_input = torch.ones(batch_size, 1).type(
            self.app_state.LongTensor) * SOS_token

        # pass along the hidden states: shape [[(encoder.n_layers *
        # encoder.n_directions) x batch_size x hidden_size]]
        decoder_hidden = encoder_hidden

        # create placeholder for the decoder outputs -> will be the logits
        decoder_outputs = torch.zeros(
            self.max_length, batch_size,
            self.output_voc_size).type(self.app_state.dtype)

        for di in range(self.max_length):
            decoder_output, decoder_hidden, decoder_attention = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_outputs[di] = decoder_output.squeeze()

            # save attention weights (visualization only -> no autograd)
            self.decoder_attentions[:, di, :] = \
                decoder_attention.squeeze().detach()

            if self.training:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[di].unsqueeze(-1)
            else:
                # Without teacher forcing: use the most probable word as the
                # next input, detached from history
                _, topi = decoder_output.topk(k=1, dim=-1)
                decoder_input = topi.view(batch_size, 1).detach()

                # TODO: stop inference when the next predicted word is the
                # EOS token. `if decoder_input.item() == EOS_token: break`
                # works for batch_size = 1, but how to generalize it to any
                # size?

        return decoder_outputs.transpose(0, 1)
def plot_memory_all_model_params_sequence(self, data_tuple, predictions,
                                          sample_number=0):
    """
    Creates list of figures used in interactive visualization, with a
    slider enabling to move forth and back along the time axis (iteration
    in a given episode). The visualization presents input, output and target
    sequences passed as input parameters. Additionally, it utilizes state
    tuples collected during the experiment for displaying the memory state,
    read and write attentions; and gating params. All panels (sequences,
    memory and head parameters) show the same batch sample.

    :param data_tuple: Data tuple containing
        - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
        - target sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param predictions: Prediction sequence
        [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param sample_number: Number of sample in batch (DEFAULT: 0)

    :return: True if the user closed the window, False if visualization is
        disabled.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    def to_numpy(tensor):
        # Works for CPU and GPU tensors, with or without autograd history.
        return tensor.cpu().detach().numpy()

    # Create figure template.
    fig = self.generate_memory_all_model_params_figure_layout()

    # Get axes that artists will draw on.
    (ax_memory, ax_write_gate, ax_write_shift, ax_write_attention,
     ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention,
     ax_read_similarity, ax_inputs, ax_targets, ax_predictions) = fig.axes

    # Unpack data tuple.
    inputs_seq = to_numpy(data_tuple.inputs[sample_number])
    targets_seq = to_numpy(data_tuple.targets[sample_number])
    predictions_seq = to_numpy(predictions[sample_number])

    # Set intial values of displayed inputs, targets and predictions -
    # simply zeros.
    inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
    targets_displayed = np.transpose(np.zeros(targets_seq.shape))
    predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

    # The initial state is only used to size the display buffers.
    (_, interface_state, memory_state, _) = self.cell_state_initial
    (_, write_state_tuple) = interface_state
    (_, _, write_gate, write_shift) = write_state_tuple

    # Buffers: one column per iteration.
    num_addresses = memory_state.shape[1]
    num_steps = targets_seq.shape[0]
    read0_attention_displayed = np.zeros((num_addresses, num_steps))
    read0_similarity_displayed = np.zeros((num_addresses, num_steps))
    # Generally we can use write shapes as are the same.
    read0_gate_displayed = np.zeros((write_gate.shape[1], num_steps))
    read0_shift_displayed = np.zeros((write_shift.shape[1], num_steps))
    write_attention_displayed = np.zeros((num_addresses, num_steps))
    write_similarity_displayed = np.zeros((num_addresses, num_steps))
    write_gate_displayed = np.zeros((write_gate.shape[1], num_steps))
    write_shift_displayed = np.zeros((write_shift.shape[1], num_steps))

    # Log sequence length - so the user can understand what is going on.
    logger = logging.getLogger('ModelBase')
    logger.info(
        "Generating dynamic visualization of {} figures, please wait...".
        format(inputs_seq.shape[0]))

    # Create frames - a list of lists, where each row is a list of artists
    # used to draw a given frame.
    frames = []

    for i, (input_element, target_element, prediction_element,
            cell_state) in enumerate(
                zip(inputs_seq, targets_seq, predictions_seq,
                    self.cell_state_history)):
        # Display information every 10% of figures.
        if (inputs_seq.shape[0] > 10) and (
                i % (inputs_seq.shape[0] // 10) == 0):
            logger.info("Generating figure {}/{}".format(
                i, inputs_seq.shape[0]))

        # Update displayed values on adequate positions.
        inputs_displayed[:, i] = input_element
        targets_displayed[:, i] = target_element
        predictions_displayed[:, i] = prediction_element

        # Unpack cell state.
        (_, interface_state, memory_state, _) = cell_state
        (read_state_tuples, write_state_tuple) = interface_state
        (write_attention, write_similarity, write_gate,
         write_shift) = write_state_tuple
        (read_attentions, read_similarities, read_gates,
         read_shifts) = zip(*read_state_tuples)

        # Memory of the selected sample (same sample as the sequences).
        memory_displayed = to_numpy(memory_state[sample_number])

        # Get params of read head 0 for the selected sample.
        read0_attention_displayed[:, i] = to_numpy(
            read_attentions[0][sample_number][:, 0])
        read0_similarity_displayed[:, i] = to_numpy(
            read_similarities[0][sample_number][:, 0])
        read0_gate_displayed[:, i] = to_numpy(
            read_gates[0][sample_number][:, 0])
        read0_shift_displayed[:, i] = to_numpy(
            read_shifts[0][sample_number][:, 0])

        # Get params of write head for the selected sample.
        write_attention_displayed[:, i] = to_numpy(
            write_attention[sample_number][:, 0])
        write_similarity_displayed[:, i] = to_numpy(
            write_similarity[sample_number][:, 0])
        write_gate_displayed[:, i] = to_numpy(
            write_gate[sample_number][:, 0])
        write_shift_displayed[:, i] = to_numpy(
            write_shift[sample_number][:, 0])

        # Create "Artists" drawing data on "ImageAxes"; order matters for
        # the frame (0: memory, 1-4: read head, 5-8: write head, 9-11:
        # sequences).
        panels = [
            (ax_memory, memory_displayed),
            (ax_read_attention, read0_attention_displayed),
            (ax_read_similarity, read0_similarity_displayed),
            (ax_read_gate, read0_gate_displayed),
            (ax_read_shift, read0_shift_displayed),
            (ax_write_attention, write_attention_displayed),
            (ax_write_similarity, write_similarity_displayed),
            (ax_write_gate, write_gate_displayed),
            (ax_write_shift, write_shift_displayed),
            (ax_inputs, inputs_displayed),
            (ax_targets, targets_displayed),
            (ax_predictions, predictions_displayed),
        ]
        artists = [ax.imshow(data, interpolation='nearest', aspect='auto')
                   for ax, data in panels]

        # Add "frame".
        frames.append(artists)

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames)

    return self.plotWindow.is_closed
class NTM(SequentialModel): """ Class representing the Neural Turing Machine module. """ def __init__(self, params): """ Constructor. Initializes parameters on the basis of dictionary of parameters passed as argument. :param params: Dictionary of parameters. """ # Call constructor of base class. super(NTM, self).__init__(params) # Parse parameters. # It is stored here, but will we used ONLY ONCE - for initialization of # memory in the forward() function. self.num_memory_addresses = params['memory']['num_addresses'] self.num_memory_content_bits = params['memory']['num_content_bits'] # Initialize recurrent NTM cell. self.ntm_cell = NTMCell(params) # Set different visualizations depending on the flags. try: if params['visualization_mode'] == 1: self.plot = self.plot_memory_attention_sequence elif params['visualization_mode'] == 2: self.plot = self.plot_memory_all_model_params_sequence # else: default visualization. except KeyError: # If the 'visualization_mode' key is not present, catch the exception and do nothing # I.e. show default vizualization. pass def forward(self, data_tuple): """ Forward function accepts a tuple consisting of : - a tensor of input data of size [BATCH_SIZE x LENGTH_SIZE x INPUT_SIZE] and - a tensor of targets :return: Predictions being a tensor of size [BATCH_SIZE x LENGTH_SIZE x OUTPUT_SIZE] . """ dtype = self.app_state.dtype # Unpack data tuple. (inputs_BxSxI, targets) = data_tuple batch_size = inputs_BxSxI.size(0) # "Data-driven memory size". # Save as TEMPORAL VARIABLE! # (do not overwrite self.num_memory_addresses, which will cause problem with next batch!) if self.num_memory_addresses == -1: # Set equal to input sequence length. num_memory_addresses = inputs_BxSxI.size(1) else: num_memory_addresses = self.num_memory_addresses # Initialize memory [BATCH_SIZE x MEMORY_ADDRESSES x CONTENT_BITS] init_memory_BxAxC = torch.zeros( batch_size, num_memory_addresses, self.num_memory_content_bits).type(dtype) # Initialize 'zero' state. 
cell_state = self.ntm_cell.init_state(init_memory_BxAxC) # List of output logits [BATCH_SIZE x OUTPUT_SIZE] of length SEQ_LENGTH output_logits_BxO_S = [] # Check if we want to collect cell history for the visualization # purposes. if self.app_state.visualize: self.cell_state_history = [] self.cell_state_initial = cell_state # Divide sequence into chunks of size [BATCH_SIZE x INPUT_SIZE] and # process them one by one. for input_t_Bx1xI in inputs_BxSxI.chunk(inputs_BxSxI.size(1), dim=1): # Process one chunk. output_BxO, cell_state = self.ntm_cell(input_t_Bx1xI.squeeze(1), cell_state) # Append to list of logits. output_logits_BxO_S += [output_BxO] # Collect cell history - for the visualization purposes. if self.app_state.visualize: self.cell_state_history.append(cell_state) # Stack logits along time axis (1). output_logits_BxSxO = torch.stack(output_logits_BxO_S, 1) return output_logits_BxSxO def generate_memory_attention_figure_layout(self): """ Creates a figure template for showing basic NTM attributes (write & write attentions), memory and sequence (inputs, predictions and targets). :returns: Matplot figure object. """ from matplotlib.figure import Figure import matplotlib.ticker as ticker from matplotlib import rc import matplotlib.gridspec as gridspec # Change fonts globally - for all figures/subsplots at once. rc('font', **{'family': 'Times New Roman'}) # Prepare "generic figure template". # Create figure object. fig = Figure() # Create a specific grid for NTM . gs = gridspec.GridSpec(3, 7) # Memory ax_memory = fig.add_subplot(gs[:, 0]) # all rows, col 0 ax_write_attention = fig.add_subplot(gs[:, 1:3]) # all rows, col 2-3 ax_read_attention = fig.add_subplot(gs[:, 3:5]) # all rows, col 4-5 ax_inputs = fig.add_subplot(gs[0, 5:]) # row 0, span 2 columns ax_targets = fig.add_subplot(gs[1, 5:]) # row 0, span 2 columns ax_predictions = fig.add_subplot(gs[2, 5:]) # row 0, span 2 columns # Set ticks - currently for all axes. 
for ax in fig.axes: ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True)) ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True)) # Set labels. ax_inputs.set_title('Inputs') ax_inputs.set_ylabel('Control/Data bits') ax_targets.set_title('Targets') ax_targets.set_ylabel('Data bits') ax_predictions.set_title('Predictions') ax_predictions.set_ylabel('Data bits') ax_predictions.set_xlabel('Item number/Iteration') ax_memory.set_title('Memory') ax_memory.set_ylabel('Memory Addresses') ax_memory.set_xlabel('Content bits') ax_write_attention.set_title('Write Attention') ax_write_attention.set_xlabel('Iteration') ax_read_attention.set_title('Read Attention') ax_read_attention.set_xlabel('Iteration') fig.set_tight_layout(True) return fig def plot_memory_attention_sequence(self, data_tuple, predictions, sample_number=0): """ Creates list of figures used in interactive visualization, with a slider enabling to move forth and back along the time axis (iteration in a given episode). The visualization presents input, output and target sequences passed as input parameters. Additionally, it utilizes state tuples collected during the experiment for displaying the memory state, read and write attentions. :param data_tuple: Data tuple containing - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and - target sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE] :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE] :param sample_number: Number of sample in batch (DEFAULT: 0) """ # Check if we are supposed to visualize at all. if not self.app_state.visualize: return False # Initialize timePlot window - if required. if self.plotWindow is None: from utils.time_plot import TimePlot self.plotWindow = TimePlot() # import time # start_time = time.time() # Create figure template. fig = self.generate_memory_attention_figure_layout() # Get axes that artists will draw on. 
(ax_memory, ax_write_attention, ax_read_attention, ax_inputs, ax_targets, ax_predictions) = fig.axes # Unpack data tuple. inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy() targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy() predictions_seq = predictions[sample_number].cpu().detach().numpy() # Set intial values of displayed inputs, targets and predictions - # simply zeros. inputs_displayed = np.transpose(np.zeros(inputs_seq.shape)) targets_displayed = np.transpose(np.zeros(targets_seq.shape)) predictions_displayed = np.transpose(np.zeros(predictions_seq.shape)) # Set initial values of memory and attentions. # Unpack initial state. (ctrl_state, interface_state, memory_state, read_vectors) = self.cell_state_initial # Initialize "empty" matrices. memory_displayed = memory_state[0] read0_attention_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) write_attention_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) # Log sequence length - so the user can understand what is going on. logger = logging.getLogger('ModelBase') logger.info( "Generating dynamic visualization of {} figures, please wait...". format(inputs_seq.shape[0])) # Create frames - a list of lists, where each row is a list of artists # used to draw a given frame. frames = [] for i, (input_element, target_element, prediction_element, cell_state) in enumerate( zip(inputs_seq, targets_seq, predictions_seq, self.cell_state_history)): # Display information every 10% of figures. if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10) == 0): logger.info("Generating figure {}/{}".format( i, inputs_seq.shape[0])) # Update displayed values on adequate positions. inputs_displayed[:, i] = input_element targets_displayed[:, i] = target_element predictions_displayed[:, i] = prediction_element # Unpack cell state. 
(ctrl_state, interface_state, memory_state, read_vectors) = cell_state (read_state_tuples, write_state_tuple) = interface_state (write_attention, write_similarity, write_gate, write_shift) = write_state_tuple (read_attentions, read_similarities, read_gates, read_shifts) = zip(*read_state_tuples) # Set variables. memory_displayed = memory_state[0].detach().numpy() # Get attention of head 0. read0_attention_displayed[:, i] = read_attentions[0][0][:, 0].detach( ).numpy() write_attention_displayed[:, i] = write_attention[0][:, 0].detach( ).numpy() # Create "Artists" drawing data on "ImageAxes". artists = [None] * len(fig.axes) # "Show" data on "axes". artists[0] = ax_memory.imshow(memory_displayed, interpolation='nearest', aspect='auto') artists[1] = ax_read_attention.imshow(read0_attention_displayed, interpolation='nearest', aspect='auto') artists[2] = ax_write_attention.imshow(write_attention_displayed, interpolation='nearest', aspect='auto') artists[3] = ax_inputs.imshow(inputs_displayed, interpolation='nearest', aspect='auto') artists[4] = ax_targets.imshow(targets_displayed, interpolation='nearest', aspect='auto') artists[5] = ax_predictions.imshow(predictions_displayed, interpolation='nearest', aspect='auto') # Add "frame". frames.append(artists) # print("--- %s seconds ---" % (time.time() - start_time)) # Plot figure and list of frames. self.plotWindow.update(fig, frames) return self.plotWindow.is_closed def generate_memory_all_model_params_figure_layout(self): """ Creates a figure template for showing all NTM attributes (write & write attentions, gates, convolution masks), along with memory and sequence (inputs, predictions and targets). :returns: Matplot figure object. """ from matplotlib.figure import Figure import matplotlib.ticker as ticker from matplotlib import rc import matplotlib.gridspec as gridspec # Change fonts globally - for all figures/subsplots at once. rc('font', **{'family': 'Times New Roman'}) # Prepare "generic figure template". 
# Create figure object. fig = Figure() #axes = fig.subplots(3, 1, sharex=True, sharey=False, gridspec_kw={'width_ratios': [input_seq.shape[0]]}) # Create a specific grid for NTM . gs = gridspec.GridSpec(4, 7) # Memory ax_memory = fig.add_subplot(gs[1:, 0]) # all rows, col 0 ax_write_gate = fig.add_subplot(gs[0, 1:2]) ax_write_shift = fig.add_subplot(gs[0, 2:3]) ax_write_attention = fig.add_subplot(gs[1:, 1:2]) # all rows, col 2-3 ax_write_similarity = fig.add_subplot(gs[1:, 2:3]) ax_read_gate = fig.add_subplot(gs[0, 3:4]) ax_read_shift = fig.add_subplot(gs[0, 4:5]) ax_read_attention = fig.add_subplot(gs[1:, 3:4]) # all rows, col 4-5 ax_read_similarity = fig.add_subplot(gs[1:, 4:5]) ax_inputs = fig.add_subplot(gs[1, 5:]) # row 0, span 2 columns ax_targets = fig.add_subplot(gs[2, 5:]) # row 0, span 2 columns ax_predictions = fig.add_subplot(gs[3, 5:]) # row 0, span 2 columns # Set ticks - currently for all axes. for ax in fig.axes: ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True)) ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True)) # ... except gates - single bit. ax_write_gate.yaxis.set_major_locator(ticker.NullLocator()) ax_read_gate.yaxis.set_major_locator(ticker.NullLocator()) # Set labels. ax_inputs.set_title('Inputs') ax_inputs.set_ylabel('Control/Data bits') ax_targets.set_title('Targets') ax_targets.set_ylabel('Data bits') ax_predictions.set_title('Predictions') ax_predictions.set_ylabel('Data bits') ax_predictions.set_xlabel('Item number/Iteration') ax_memory.set_title('Memory') ax_memory.set_ylabel('Memory Addresses') ax_memory.set_xlabel('Content bits') for ax in [ ax_write_gate, ax_write_shift, ax_write_attention, ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention, ax_read_similarity ]: ax.set_xlabel('Iteration') # Write head. ax_write_gate.set_title('Write Gate') ax_write_shift.set_title('Write Shift') ax_write_attention.set_title('Write Attention') ax_write_similarity.set_title('Write Similarity') # Read head. 
ax_read_gate.set_title('Read Gate') ax_read_shift.set_title('Read Shift') ax_read_attention.set_title('Read Attention') ax_read_similarity.set_title('Read Similarity') fig.set_tight_layout(True) return fig def plot_memory_all_model_params_sequence(self, data_tuple, predictions, sample_number=0): """ Creates list of figures used in interactive visualization, with a slider enabling to move forth and back along the time axis (iteration in a given episode). The visualization presents input, output and target sequences passed as input parameters. Additionally, it utilizes state tuples collected during the experiment for displaying the memory state, read and write attentions; and gating params. :param data_tuple: Data tuple containing - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and - target sequences [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE] :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE] :param sample_number: Number of sample in batch (DEFAULT: 0) """ # Check if we are supposed to visualize at all. if not self.app_state.visualize: return False # Initialize timePlot window - if required. if self.plotWindow is None: from utils.time_plot import TimePlot self.plotWindow = TimePlot() # import time # start_time = time.time() # Create figure template. fig = self.generate_memory_all_model_params_figure_layout() # Get axes that artists will draw on. (ax_memory, ax_write_gate, ax_write_shift, ax_write_attention, ax_write_similarity, ax_read_gate, ax_read_shift, ax_read_attention, ax_read_similarity, ax_inputs, ax_targets, ax_predictions) = fig.axes # Unpack data tuple. inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy() targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy() predictions_seq = predictions[sample_number].cpu().detach().numpy() # Set intial values of displayed inputs, targets and predictions - # simply zeros. 
inputs_displayed = np.transpose(np.zeros(inputs_seq.shape)) targets_displayed = np.transpose(np.zeros(targets_seq.shape)) predictions_displayed = np.transpose(np.zeros(predictions_seq.shape)) # Set initial values of memory and attentions. # Unpack initial state. (ctrl_state, interface_state, memory_state, read_vectors) = self.cell_state_initial (read_state_tuples, write_state_tuple) = interface_state (write_attention, write_similarity, write_gate, write_shift) = write_state_tuple # Initialize "empty" matrices. memory_displayed = memory_state[0] read0_attention_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) read0_similarity_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) read0_gate_displayed = np.zeros( (write_gate.shape[1], targets_seq.shape[0])) read0_shift_displayed = np.zeros( (write_shift.shape[1], targets_seq.shape[0])) write_attention_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) write_similarity_displayed = np.zeros( (memory_state.shape[1], targets_seq.shape[0])) # Generally we can use write shapes as are the same. write_gate_displayed = np.zeros( (write_gate.shape[1], targets_seq.shape[0])) write_shift_displayed = np.zeros( (write_shift.shape[1], targets_seq.shape[0])) # Log sequence length - so the user can understand what is going on. logger = logging.getLogger('ModelBase') logger.info( "Generating dynamic visualization of {} figures, please wait...". format(inputs_seq.shape[0])) # Create frames - a list of lists, where each row is a list of artists # used to draw a given frame. frames = [] for i, (input_element, target_element, prediction_element, cell_state) in enumerate( zip(inputs_seq, targets_seq, predictions_seq, self.cell_state_history)): # Display information every 10% of figures. if (inputs_seq.shape[0] > 10) and (i % (inputs_seq.shape[0] // 10) == 0): logger.info("Generating figure {}/{}".format( i, inputs_seq.shape[0])) # Update displayed values on adequate positions. 
inputs_displayed[:, i] = input_element targets_displayed[:, i] = target_element predictions_displayed[:, i] = prediction_element # Unpack cell state. (ctrl_state, interface_state, memory_state, read_vectors) = cell_state (read_state_tuples, write_state_tuple) = interface_state (write_attention, write_similarity, write_gate, write_shift) = write_state_tuple (read_attentions, read_similarities, read_gates, read_shifts) = zip(*read_state_tuples) # Set variables. memory_displayed = memory_state[0].detach().numpy() # Get params of read head 0. read0_attention_displayed[:, i] = read_attentions[0][0][:, 0].detach( ).numpy() read0_similarity_displayed[:, i] = read_similarities[0][0][:, 0].detach( ).numpy() read0_gate_displayed[:, i] = read_gates[0][0][:, 0].detach().numpy() read0_shift_displayed[:, i] = read_shifts[0][0][:, 0].detach().numpy() # Get params of write head write_attention_displayed[:, i] = write_attention[0][:, 0].detach( ).numpy() write_similarity_displayed[:, i] = write_similarity[0][:, 0].detach( ).numpy() write_gate_displayed[:, i] = write_gate[0][:, 0].detach().numpy() write_shift_displayed[:, i] = write_shift[0][:, 0].detach().numpy() # Create "Artists" drawing data on "ImageAxes". artists = [None] * len(fig.axes) # "Show" data on "axes". artists[0] = ax_memory.imshow(memory_displayed, interpolation='nearest', aspect='auto') # Read head. artists[1] = ax_read_attention.imshow(read0_attention_displayed, interpolation='nearest', aspect='auto') artists[2] = ax_read_similarity.imshow(read0_similarity_displayed, interpolation='nearest', aspect='auto') artists[3] = ax_read_gate.imshow(read0_gate_displayed, interpolation='nearest', aspect='auto') artists[4] = ax_read_shift.imshow(read0_shift_displayed, interpolation='nearest', aspect='auto') # Write head. 
artists[5] = ax_write_attention.imshow(write_attention_displayed, interpolation='nearest', aspect='auto') artists[6] = ax_write_similarity.imshow(write_similarity_displayed, interpolation='nearest', aspect='auto') artists[7] = ax_write_gate.imshow(write_gate_displayed, interpolation='nearest', aspect='auto') artists[8] = ax_write_shift.imshow(write_shift_displayed, interpolation='nearest', aspect='auto') # "Default data". artists[9] = ax_inputs.imshow(inputs_displayed, interpolation='nearest', aspect='auto') artists[10] = ax_targets.imshow(targets_displayed, interpolation='nearest', aspect='auto') artists[11] = ax_predictions.imshow(predictions_displayed, interpolation='nearest', aspect='auto') # Add "frame". frames.append(artists) # print("--- %s seconds ---" % (time.time() - start_time)) # Plot figure and list of frames. self.plotWindow.update(fig, frames) return self.plotWindow.is_closed
class DNC(SequentialModel):
    """
    DNC (Differentiable Neural Computer) model.

    Wraps a ``DNCCell`` and unrolls it over the input sequence one step at a
    time. When visualization is enabled, ``forward`` records a numpy snapshot
    of the memory, usage, precedence weights and read/write weights after
    every step. ``plot`` animates these snapshots.
    """

    def __init__(self, params):
        """
        Initializes the DNC model.

        :param params: Dictionary of parameters. Reads 'output_bits' (falls
            back to 'data_bits'), 'memory_addresses_size', 'name',
            'num_reads' and 'num_writes'. The whole dictionary is also
            passed on to ``DNCCell``.
        """
        # Call base class initialization.
        super(DNC, self).__init__(params)

        try:
            self.output_units = params['output_bits']
        except KeyError:
            self.output_units = params['data_bits']

        # -1 means "not fixed": forward() then uses one address per step.
        self.memory_addresses_size = params["memory_addresses_size"]
        self.label = params["name"]
        # Per-step state snapshots; filled by forward() when visualizing.
        self.cell_state_history = None

        # Number of read and write heads.
        self._num_reads = params["num_reads"]
        self._num_writes = params["num_writes"]

        # Create the DNC components.
        self.DNCCell = DNCCell(self.output_units, params)

    def forward(self, data_tuple):
        """
        Runs the DNC cell over the whole input sequence.

        Steps for which the cell returns no output are skipped. Skipped steps
        contribute neither to the output nor to the visualization history.

        :param data_tuple: Tuple containing inputs [batch_size, seq_len, input_size] and targets.
        :returns: output [batch_size, seq_len, output_size], or None if the
            cell produced no output at any step.
        """
        (inputs, targets) = data_tuple

        if self.app_state.visualize:
            self.cell_state_history = []

        batch_size = inputs.size(0)
        seq_length = inputs.size(1)

        # If the memory size is not fixed (-1), use one address per input step.
        memory_addresses_size = self.memory_addresses_size
        if memory_addresses_size == -1:
            memory_addresses_size = seq_length

        # Init state.
        cell_state = self.DNCCell.init_state(memory_addresses_size, batch_size)

        # Collect per-step outputs and concatenate them once at the end.
        # Re-concatenating the growing tensor at every step is quadratic.
        outputs = []
        for j in range(seq_length):
            output_cell, cell_state = self.DNCCell(inputs[..., j, :], cell_state)

            if output_cell is None:
                continue

            outputs.append(output_cell[..., None, :])

            # This is for the time plot.
            if self.app_state.visualize:
                interface_state = cell_state.int_init_state
                self.cell_state_history.append(
                    (cell_state.memory_state.detach().cpu().numpy(),
                     interface_state.usage.detach().cpu().numpy(),
                     interface_state.links.precedence_weights.detach().cpu().numpy(),
                     interface_state.read_weights.detach().cpu().numpy(),
                     interface_state.write_weights.detach().cpu().numpy()))

        if not outputs:
            return None
        return torch.cat(outputs, dim=-2)

    def plot_memory_attention(self, data_tuple, predictions, sample_number=0):
        """
        Plots memory and attention.

        TODO: fix. Currently it only imports the plotting helper and draws
        nothing.

        :param data_tuple: Data tuple containing input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and target sequences  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param sample_number: Number of sample in batch (DEFAULT: 0)
        """
        # plot attention/memory
        from models.dnc.plot_data import plot_memory_attention

    def generate_figure_layout(self):
        """
        Creates the figure template used by ``plot``.

        The axes are created in a fixed order, so ``fig.axes`` unpacks as:
        memory, read attention, write attention, usage, inputs, targets,
        predictions.

        Side effect: updates the global matplotlib rcParams (title, label
        and tick font sizes).

        :returns: Matplot figure object.
        """
        from matplotlib.figure import Figure
        import matplotlib.ticker as ticker
        import matplotlib.gridspec as gridspec
        import matplotlib.pylab as pylab

        params = {
            # 'legend.fontsize': '28',
            'axes.titlesize': 'large',
            'axes.labelsize': 'large',
            'xtick.labelsize': 'medium',
            'ytick.labelsize': 'medium'
        }
        pylab.rcParams.update(params)

        # Prepare "generic figure template".
        fig = Figure()
        # 3 x 9 grid for the DNC. Memory (col 0), read attention (cols 1-2),
        # write attention (cols 3-4) and usage (cols 5-6) span all rows. The
        # three sequences are stacked in rows 0-2 of columns 7-8.
        gs = gridspec.GridSpec(3, 9)

        # NOTE: the creation order defines fig.axes; plot() unpacks it
        # positionally.
        ax_memory = fig.add_subplot(gs[:, 0])
        ax_read = fig.add_subplot(gs[:, 1:3])
        ax_write = fig.add_subplot(gs[:, 3:5])
        ax_usage = fig.add_subplot(gs[:, 5:7])
        ax_inputs = fig.add_subplot(gs[0, 7:])
        ax_targets = fig.add_subplot(gs[1, 7:])
        ax_predictions = fig.add_subplot(gs[2, 7:])

        # Integer ticks on every axis.
        for ax in fig.axes:
            ax.xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

        # Set labels.
        ax_inputs.set_title('Inputs')
        ax_inputs.set_ylabel('Control/Data bits')
        ax_targets.set_title('Targets')
        ax_targets.set_ylabel('Data bits')
        ax_predictions.set_title('Predictions')
        ax_predictions.set_ylabel('Data bits')
        ax_predictions.set_xlabel('Item number/Iteration')

        ax_memory.set_title('Memory')
        ax_memory.set_ylabel('Memory Addresses')
        ax_memory.set_xlabel('Content bits')
        ax_read.set_title('Read attention')
        ax_read.set_xlabel('Iteration')
        ax_write.set_title('Write Attention')
        ax_write.set_xlabel('Iteration')
        ax_usage.set_title('Usage')
        ax_usage.set_xlabel('Iteration')

        fig.set_tight_layout(True)
        return fig

    def plot(self, data_tuple, predictions, sample_number=0):
        """
        Interactive visualization, with a slider enabling to move forth and
        back along the time axis (iteration in a given episode).

        Shows the memory, the attention of read head 0 and write head 0, the
        usage vector and the input/target/prediction sequences of one sample.
        Expects ``self.cell_state_history`` to hold one
        ``(memory, usage, links, read_weights, write_weights)`` numpy tuple
        per step, as recorded by ``forward``.

        :param data_tuple: Data tuple containing
           - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
           - target sequences  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
        :param sample_number: Number of sample in batch (DEFAULT: 0). Used
            both for the sequences and for the recorded memory/attention/usage.
        :returns: False if visualization is disabled, otherwise True if the
            user closed the plot window.
        """
        # Check if we are supposed to visualize at all.
        if not self.app_state.visualize:
            return False

        # Initialize timePlot window - if required.
        if self.plotWindow is None:
            from utils.time_plot import TimePlot
            self.plotWindow = TimePlot()

        # Detach the selected sample from the batch and copy it to CPU.
        inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy()
        targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy()
        predictions_seq = predictions[sample_number].cpu().detach().numpy()

        # temporary for data with additional channel
        if len(inputs_seq.shape) == 3:
            inputs_seq = inputs_seq[0, :, :]

        # Create figure template.
        fig = self.generate_figure_layout()
        # Get axes that artists will draw on.
        (ax_memory, ax_read, ax_write, ax_usage, ax_inputs, ax_targets,
         ax_predictions) = fig.axes

        # Initial values of displayed inputs, targets and predictions - zeros.
        inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
        targets_displayed = np.transpose(np.zeros(targets_seq.shape))
        predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

        # Size the state matrices from the first recorded snapshot.
        num_steps = targets_seq.shape[0]
        first_state = self.cell_state_history[0]
        head_attention_read = np.zeros((first_state[3].shape[-1], num_steps))
        head_attention_write = np.zeros((first_state[4].shape[-1], num_steps))
        usage_displayed = np.zeros((first_state[1].shape[-1], num_steps))

        # Log sequence length - so the user can understand what is going on.
        seq_length = inputs_seq.shape[0]
        logger = logging.getLogger('ModelBase')
        logger.info(
            "Generating dynamic visualization of {} figures, please wait...".
            format(seq_length))

        # Create frames - a list of lists, where each row is a list of artists
        # used to draw a given frame.
        frames = []
        for i, (input_element, target_element, prediction_element,
                (memory, usage, links, wt_r, wt_w)) in enumerate(
                    zip(inputs_seq, targets_seq, predictions_seq,
                        self.cell_state_history)):
            # Display information every 10% of figures.
            if (seq_length > 10) and (i % (seq_length // 10) == 0):
                logger.info("Generating figure {}/{}".format(i, seq_length))

            # Update displayed values on adequate positions.
            inputs_displayed[:, i] = input_element
            targets_displayed[:, i] = target_element
            predictions_displayed[:, i] = prediction_element

            # State of the selected sample; attention of head 0.
            memory_displayed = memory[sample_number]
            head_attention_read[:, i] = wt_r[sample_number, 0, :]
            head_attention_write[:, i] = wt_w[sample_number, 0, :]
            usage_displayed[:, i] = usage[sample_number, :]

            # One "Artist" per axis; together they form a single frame.
            frames.append([
                ax_memory.imshow(np.transpose(memory_displayed),
                                 interpolation='nearest', aspect='auto'),
                ax_read.imshow(head_attention_read,
                               interpolation='nearest', aspect='auto'),
                ax_write.imshow(head_attention_write,
                                interpolation='nearest', aspect='auto'),
                ax_usage.imshow(usage_displayed,
                                interpolation='nearest', aspect='auto'),
                ax_inputs.imshow(inputs_displayed,
                                 interpolation='nearest', aspect='auto'),
                ax_targets.imshow(targets_displayed,
                                  interpolation='nearest', aspect='auto'),
                ax_predictions.imshow(predictions_displayed,
                                      interpolation='nearest', aspect='auto'),
            ])

        # Plot figure and list of frames.
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed
def plot(self, data_tuple, predictions, sample_number=0):
    """
    Interactive visualization, with a slider enabling to move forth and back
    along the time axis (iteration in a given episode).

    Shows the memory, the attention of read head 0 and write head 0, the
    usage vector and the input/target/prediction sequences of one sample.
    Expects ``self.cell_state_history`` to hold one
    ``(memory, usage, links, read_weights, write_weights)`` numpy tuple per
    step, and ``self.generate_figure_layout()`` to return a figure with 7 axes.

    :param data_tuple: Data tuple containing
       - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
       - target sequences  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param sample_number: Number of sample in batch (DEFAULT: 0). Used both
        for the sequences and for the recorded memory/attention/usage.
    :returns: False if visualization is disabled, otherwise True if the user
        closed the plot window.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    # Detach the selected sample from the batch and copy it to CPU.
    inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy()
    targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy()
    predictions_seq = predictions[sample_number].cpu().detach().numpy()

    # temporary for data with additional channel
    if len(inputs_seq.shape) == 3:
        inputs_seq = inputs_seq[0, :, :]

    # Create figure template.
    fig = self.generate_figure_layout()
    # Get axes that artists will draw on.
    (ax_memory, ax_read, ax_write, ax_usage, ax_inputs, ax_targets,
     ax_predictions) = fig.axes

    # Initial values of displayed inputs, targets and predictions - zeros.
    inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
    targets_displayed = np.transpose(np.zeros(targets_seq.shape))
    predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

    # Size the state matrices from the first recorded snapshot.
    num_steps = targets_seq.shape[0]
    first_state = self.cell_state_history[0]
    head_attention_read = np.zeros((first_state[3].shape[-1], num_steps))
    head_attention_write = np.zeros((first_state[4].shape[-1], num_steps))
    usage_displayed = np.zeros((first_state[1].shape[-1], num_steps))

    # Log sequence length - so the user can understand what is going on.
    seq_length = inputs_seq.shape[0]
    logger = logging.getLogger('ModelBase')
    logger.info(
        "Generating dynamic visualization of {} figures, please wait...".
        format(seq_length))

    # Create frames - a list of lists, where each row is a list of artists
    # used to draw a given frame.
    frames = []
    for i, (input_element, target_element, prediction_element,
            (memory, usage, links, wt_r, wt_w)) in enumerate(
                zip(inputs_seq, targets_seq, predictions_seq,
                    self.cell_state_history)):
        # Display information every 10% of figures.
        if (seq_length > 10) and (i % (seq_length // 10) == 0):
            logger.info("Generating figure {}/{}".format(i, seq_length))

        # Update displayed values on adequate positions.
        inputs_displayed[:, i] = input_element
        targets_displayed[:, i] = target_element
        predictions_displayed[:, i] = prediction_element

        # State of the selected sample; attention of head 0.
        memory_displayed = memory[sample_number]
        head_attention_read[:, i] = wt_r[sample_number, 0, :]
        head_attention_write[:, i] = wt_w[sample_number, 0, :]
        usage_displayed[:, i] = usage[sample_number, :]

        # One "Artist" per axis; together they form a single frame.
        frames.append([
            ax_memory.imshow(np.transpose(memory_displayed),
                             interpolation='nearest', aspect='auto'),
            ax_read.imshow(head_attention_read,
                           interpolation='nearest', aspect='auto'),
            ax_write.imshow(head_attention_write,
                            interpolation='nearest', aspect='auto'),
            ax_usage.imshow(usage_displayed,
                            interpolation='nearest', aspect='auto'),
            ax_inputs.imshow(inputs_displayed,
                             interpolation='nearest', aspect='auto'),
            ax_targets.imshow(targets_displayed,
                              interpolation='nearest', aspect='auto'),
            ax_predictions.imshow(predictions_displayed,
                                  interpolation='nearest', aspect='auto'),
        ])

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames)
    return self.plotWindow.is_closed
def plot(self, data_tuple, predictions, sample_number=0):
    """
    Creates a default interactive visualization, with a slider enabling to
    move forth and back along the time axis (iteration in a given episode).
    The default visualization contains input, target and prediction
    sequences. For more model/problem dependent visualization please
    override this method in the derived model class.

    Side effect: updates the global matplotlib rcParams (title, label and
    tick font sizes).

    :param data_tuple: Data tuple containing
       - input [BATCH_SIZE x SEQUENCE_LENGTH x INPUT_DATA_SIZE] and
       - target sequences  [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param predictions: Prediction sequence [BATCH_SIZE x SEQUENCE_LENGTH x OUTPUT_DATA_SIZE]
    :param sample_number: Number of sample in batch (DEFAULT: 0)
    :returns: False if visualization is disabled, otherwise True if the user
        closed the plot window.
    """
    # Check if we are supposed to visualize at all.
    if not self.app_state.visualize:
        return False

    # Initialize timePlot window - if required.
    if self.plotWindow is None:
        from utils.time_plot import TimePlot
        self.plotWindow = TimePlot()

    from matplotlib.figure import Figure
    import matplotlib.ticker as ticker
    import matplotlib.pylab as pylab

    params = {
        # 'legend.fontsize': '28',
        'axes.titlesize': 'large',
        'axes.labelsize': 'large',
        'xtick.labelsize': 'medium',
        'ytick.labelsize': 'medium'
    }
    pylab.rcParams.update(params)

    # Create a single "figure layout" for all displayed frames: three
    # stacked rows sharing the time (x) axis.
    fig = Figure()
    axes = fig.subplots(3, 1, sharex=True, sharey=False)

    # Set ticks (x is shared, so setting it once is enough).
    axes[0].xaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    axes[0].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    axes[1].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))
    axes[2].yaxis.set_major_locator(ticker.MaxNLocator(integer=True))

    # Set labels.
    axes[0].set_title('Inputs')
    axes[0].set_ylabel('Control/Data bits')
    axes[1].set_title('Targets')
    axes[1].set_ylabel('Data bits')
    axes[2].set_title('Predictions')
    axes[2].set_ylabel('Data bits')
    axes[2].set_xlabel('Item number')
    fig.set_tight_layout(True)

    # Detach a sample from batch and copy it to CPU.
    inputs_seq = data_tuple.inputs[sample_number].cpu().detach().numpy()
    targets_seq = data_tuple.targets[sample_number].cpu().detach().numpy()
    predictions_seq = predictions[sample_number].cpu().detach().numpy()

    # Create empty matrices, each sized after the sequence it displays
    # (targets and predictions may differ in width).
    inputs_displayed = np.transpose(np.zeros(inputs_seq.shape))
    targets_displayed = np.transpose(np.zeros(targets_seq.shape))
    predictions_displayed = np.transpose(np.zeros(predictions_seq.shape))

    # Log sequence length - so the user can understand what is going on.
    seq_length = inputs_seq.shape[0]
    logger = logging.getLogger('ModelBase')
    logger.info(
        "Generating dynamic visualization of {} figures, please wait...".
        format(seq_length))

    # Create frames - a list of lists, where each row is a list of artists
    # used to draw a given frame.
    frames = []
    for i, (input_word, prediction_word, target_word) in enumerate(
            zip(inputs_seq, predictions_seq, targets_seq)):
        # Display information every 10% of figures.
        if (seq_length > 10) and (i % (seq_length // 10) == 0):
            logger.info("Generating figure {}/{}".format(i, seq_length))

        # Add words to adequate positions.
        inputs_displayed[:, i] = input_word
        targets_displayed[:, i] = target_word
        predictions_displayed[:, i] = prediction_word

        # One "Artist" per axis; together they form a single frame.
        frames.append([
            axes[0].imshow(inputs_displayed,
                           interpolation='nearest', aspect='auto'),
            axes[1].imshow(targets_displayed,
                           interpolation='nearest', aspect='auto'),
            axes[2].imshow(predictions_displayed,
                           interpolation='nearest', aspect='auto'),
        ])

    # Plot figure and list of frames.
    self.plotWindow.update(fig, frames)

    # Return True if user closed the window.
    return self.plotWindow.is_closed