def init_hidden(self, batch_size):
    """
    Initializes the hidden states for the encoder.

    :param batch_size: batch size

    :return: initial hidden states.
    """
    if self.bidirectional:
        return torch.zeros(self.n_layers * 2, batch_size,
                           self.hidden_size).type(AppState().dtype)
    else:
        return torch.zeros(self.n_layers, batch_size,
                           self.hidden_size).type(AppState().dtype)
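# A minimal, standalone usage sketch (hypothetical sizes; assumes the encoder
# wraps an nn.GRU, which is not confirmed by this snippet): the tensor returned
# by init_hidden follows PyTorch's (num_layers * num_directions, batch,
# hidden_size) RNN convention, so it can seed the recurrent module directly.
import torch
import torch.nn as nn

rnn = nn.GRU(input_size=8, hidden_size=16, num_layers=2,
             bidirectional=True, batch_first=True)
h0 = torch.zeros(2 * 2, 4, 16)  # what init_hidden returns for batch_size=4
out, h_n = rnn(torch.randn(4, 5, 8), h0)
print(out.shape)  # torch.Size([4, 5, 32]) - hidden size doubled by bidirectionality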
def __init__(self, params):
    """
    Initializes application state and sets plot if visualization flag is
    turned on.

    :param params: Parameters read from configuration file.
    """
    # Call base class inits here.
    super(Model, self).__init__()

    # Initialize app state.
    self.app_state = AppState()

    # Store pointer to params.
    self.params = params

    # Window in which the data will be plotted.
    self.plotWindow = None

    # Initialization of best loss - as INF.
    self.best_loss = np.inf

    # Flag indicating whether intermediate checkpoints should be saved or
    # not (DEFAULT: False).
    if "save_intermediate" not in params:
        params.add_default_params({"save_intermediate": False})
    self.save_intermediate = params["save_intermediate"]

    # "Default" model name.
    self.name = 'Model'
def init_state(self, memory_address_size, batch_size):
    """
    Returns 'zero' (initial) state tuple.

    :param memory_address_size: The number of memory addresses.
    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial state tuple - object of InterfaceStateTuple class.
    """
    dtype = AppState().dtype

    # Read attention weights [BATCH_SIZE x NUM_READS x MEMORY_SIZE]
    read_attention = torch.ones((batch_size, self._num_reads,
                                 memory_address_size)).type(dtype) * 1e-6

    # Write attention weights [BATCH_SIZE x NUM_WRITES x MEMORY_SIZE]
    write_attention = torch.ones((batch_size, self._num_writes,
                                  memory_address_size)).type(dtype) * 1e-6

    # Usage of memory cells [BATCH_SIZE x MEMORY_SIZE]
    usage = self.mem_usage.init_state(memory_address_size, batch_size)

    # Temporal links tuple.
    link_tuple = self.temporal_linkage.init_state(memory_address_size, batch_size)

    return InterfaceStateTuple(read_attention, write_attention, usage, link_tuple)
def forward(self, logits, targets, mask):
    """
    Calculates loss accounting for different numbers of outputs per sample.

    :param logits: Logits being output by the model. [batch, classes, sequence]
    :param targets: LongTensor targets [batch, sequence]
    :param mask: ByteTensor mask [batch, sequence]
    """
    # Calculate the loss per element in the sequence.
    loss_per_element = self.loss_function(logits, targets)

    # Have to convert the mask to floats to multiply by the loss.
    mask_float = mask.type(AppState().dtype)

    # If the loss has one extra dimension then you need an extra unit
    # dimension to multiply element by element.
    if len(mask.shape) < len(loss_per_element.shape):
        mask_float = mask_float.unsqueeze(-1)

    # Set the loss per element to zero for unneeded outputs.
    masked_loss_per = mask_float * loss_per_element

    # Obtain the number of non-zero elements in the mask.
    # nonzero() returns the indices, so you have to divide by the number of
    # dimensions.
    size = mask.nonzero().numel() / len(mask.shape)

    # Add up the loss, scaling by only the needed outputs.
    loss = torch.sum(masked_loss_per) / size
    return loss
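# A minimal, standalone sketch of the masking logic above; it assumes the
# wrapped self.loss_function is nn.CrossEntropyLoss(reduction='none'), so a
# per-element loss survives for masking (shapes and values are hypothetical).
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(reduction='none')
logits = torch.randn(2, 5, 3)                        # [batch, classes, sequence]
targets = torch.randint(0, 5, (2, 3))                # [batch, sequence]
mask = torch.tensor([[1, 1, 0], [1, 0, 0]]).float()  # 1 marks valid time steps

loss_per_element = loss_fn(logits, targets)          # [batch, sequence]
loss = (mask * loss_per_element).sum() / mask.sum()  # mean over unmasked steps only
print(loss)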
def init_state(self, memory_address_size, batch_size):
    """
    Returns 'zero' (initial) state:

    * memory is reset to zeros,
    * read & write weights (and read vector) are set to 1e-6.

    :param memory_address_size: Number of memory addresses.
    :param batch_size: Size of the batch in given iteration/epoch.
    """
    dtype = AppState().dtype

    # Initialize controller state.
    ctrl_init_state = self.control_params.init_state(batch_size)

    # Initialize interface state.
    interface_init_state = self.interface.init_state(
        memory_address_size, batch_size)

    # Memory [BATCH_SIZE x MEMORY_BITS x MEMORY_SIZE]
    init_memory_BxMxA = torch.zeros(batch_size, self.num_memory_bits,
                                    memory_address_size).type(dtype)

    # Read vector [BATCH_SIZE x MEMORY_BITS]
    read_vector_BxM = self.interface.read(interface_init_state, init_memory_BxMxA)

    # Pack and return a tuple.
    return NTMCellStateTuple(ctrl_init_state, interface_init_state,
                             init_memory_BxMxA, read_vector_BxM)
def normalize(x):
    """
    Normalizes the input torch tensor along the last dimension, using the max
    of the one norm. The normalization is "fuzzy" to prevent divergences.

    :param x: input of shape [batch_size, A, A1 .. An]; if the input is the
              attention weight vector, x's shape is (batch_size, num_heads, memory_size)

    :return: normalized x of shape [batch_size, A, A1 .. An]
    """
    dtype = AppState().dtype
    return x / torch.max(torch.sum(x, dim=-1, keepdim=True),
                         torch.Tensor([1e-12]).type(dtype))
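# A quick standalone check of the "fuzzy" normalization (example values are
# made up): a row that sums to zero stays at zero instead of producing NaNs
# from a 0/0 division.
import torch

x = torch.tensor([[0.2, 0.3, 0.5],
                  [0.0, 0.0, 0.0]])
y = x / torch.max(torch.sum(x, dim=-1, keepdim=True), torch.tensor([1e-12]))
print(y)  # first row sums to 1; second row remains all zeros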
def init_state(self, memory_address_size, batch_size):
    """
    Returns 'zero' (initial) usage vector.

    :param memory_address_size: Number of memory addresses.
    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial usage vector of shape [batch_size, memory_address_size].
    """
    dtype = AppState().dtype
    self._memory_size = memory_address_size

    usage = torch.zeros((batch_size, memory_address_size)).type(dtype)
    return usage
def init_state(self, batch_size):
    """
    Returns 'zero' (initial) state tuple.

    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial state tuple - object of RNNStateTuple class.
    """
    # Initialize LSTM hidden state [BATCH_SIZE x CTRL_HIDDEN_SIZE].
    dtype = AppState().dtype
    hidden_state = torch.zeros((batch_size, self.ctrl_hidden_state_size),
                               requires_grad=False).type(dtype)
    return RNNStateTuple(hidden_state)
def exclusive_cumprod_temp(self, sorted_usage, dim=1):
    """
    Applies the exclusive cumulative product (at the moment it assumes the
    shape of the input).

    :param sorted_usage: tensor of shape `[batch_size, memory_size]` indicating
                         current memory usage sorted in ascending order.

    :returns: Tensor of shape `[batch_size, memory_size]` that is the exclusive
              product of the sorted usage, i.e. [1, u1, u1*u2, u1*u2*u3, ...]
    """
    # TODO: expand this so it works for any dim.
    dtype = AppState().dtype
    a = torch.ones((sorted_usage.shape[0], 1)).type(dtype)
    b = torch.cat((a, sorted_usage), dim=dim).type(dtype)
    prod_sorted_usage = torch.cumprod(b, dim=dim)[:, :-1]
    return prod_sorted_usage
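# A standalone check of the exclusive cumulative product (values are made up):
# output[i] = u_1 * ... * u_{i-1}, with output[0] = 1, as used for DNC-style
# allocation weighting over sorted usages.
import torch

u = torch.tensor([[0.1, 0.5, 0.9]])
padded = torch.cat((torch.ones(1, 1), u), dim=1)
print(torch.cumprod(padded, dim=1)[:, :-1])  # tensor([[1.0000, 0.1000, 0.0500]])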
def init_state(self, memory_addresses_size, batch_size):
    """
    Returns 'zero' (initial) state tuple of the DWM cell.

    :param memory_addresses_size: Number of memory addresses.
    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial state tuple - object of DWMCellStateTuple class.
    """
    dtype = AppState().dtype

    # Initialize controller state.
    tuple_ctrl_init_state = self.controller.init_state(batch_size)

    # Initialize interface state.
    tuple_interface_init_state = self.interface.init_state(
        memory_addresses_size, batch_size)

    # Initialize memory [BATCH_SIZE x MEMORY_BITS x MEMORY_ADDRESSES].
    mem_init = (torch.ones((batch_size, self.M,
                            memory_addresses_size)) * 0.01).type(dtype)

    return DWMCellStateTuple(tuple_ctrl_init_state,
                             tuple_interface_init_state, mem_init)
def __init__(self, params):
    """
    Initializes problem object.

    :param params: Dictionary of parameters (read from configuration file).
    """
    # Set default loss function.
    self.loss_function = None

    # Store pointer to params.
    self.params = params

    # Get access to AppState.
    self.app_state = AppState()

    # "Default" problem name.
    self.name = 'Problem'
def init_state(self, memory_address_size, batch_size):
    """
    Returns 'zero' (initial) state tuple.

    :param memory_address_size: Number of memory addresses.
    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial state tuple - object of TemporalLinkageState class.
    """
    dtype = AppState().dtype
    self._memory_size = memory_address_size

    link = torch.ones((batch_size, self._num_writes, memory_address_size,
                       memory_address_size)).type(dtype) * 1e-6
    precedence_weights = torch.ones(
        (batch_size, self._num_writes, memory_address_size)).type(dtype) * 1e-6

    return TemporalLinkageState(link, precedence_weights)
def init_state(self, batch_size, num_memory_addresses):
    """
    Returns 'zero' (initial) state tuple.

    :param batch_size: Size of the batch in given iteration/epoch.
    :param num_memory_addresses: Number of memory addresses.

    :returns: Initial state tuple - object of InterfaceStateTuple class.
    """
    dtype = AppState().dtype

    # Add read head states - one for each read head.
    read_state_tuples = []

    # Initial attention weights [BATCH_SIZE x MEMORY_ADDRESSES x 1]
    # Initialize attention: to address 0.
    zh_attention = torch.zeros(batch_size, num_memory_addresses, 1).type(dtype)
    zh_attention[:, 0, 0] = 1

    # Initialize gating: to previous attention (i.e. zero-hard).
    init_gating = torch.ones(batch_size, 1, 1).type(dtype)

    # Initialize shift - to "no shift" (one-hot at position 1).
    init_shift = torch.zeros(batch_size, self.interface_shift_size, 1).type(dtype)
    init_shift[:, 1, 0] = 1

    for _ in range(self.interface_num_read_heads):
        # Single read head tuple.
        read_ht = HeadStateTuple(zh_attention, zh_attention,
                                 init_gating, init_shift)
        read_state_tuples.append(read_ht)

    # Single write head tuple.
    write_state_tuple = HeadStateTuple(zh_attention, zh_attention,
                                       init_gating, init_shift)

    # Return tuple.
    interface_state = InterfaceStateTuple(read_state_tuples, write_state_tuple)
    return interface_state
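# A standalone sanity check (hypothetical shapes, not from the original code):
# a one-hot attention column vector of the kind initialized above retrieves
# exactly one memory row via a batched matrix product.
import torch

memory = torch.arange(8.0).reshape(1, 2, 4)  # [batch, addresses, content_bits]
attention = torch.tensor([[[1.0], [0.0]]])   # one-hot on address 0, [B x A x 1]
read = torch.bmm(attention.transpose(1, 2), memory)
print(read)  # tensor([[[0., 1., 2., 3.]]]) - the contents of address 0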
def init_state(self, memory_addresses_size, batch_size):
    """
    Returns 'zero' (initial) state of the Interface tuple.

    :param memory_addresses_size: Size of the memory.
    :param batch_size: Size of the batch in given iteration/epoch.

    :returns: Initial state tuple - object of InterfaceStateTuple class:
              (head_weight_init, snapshot_weight_init)
    """
    dtype = AppState().dtype

    # Initial attention vector - all heads focused on address 0.
    head_weight_init = torch.zeros(
        (batch_size, self.num_heads, memory_addresses_size)).type(dtype)
    head_weight_init[:, 0:self.num_heads, 0] = 1.0

    # Bookmark.
    snapshot_weight_init = head_weight_init

    return InterfaceStateTuple(head_weight_init, snapshot_weight_init)
def init_state(self, batch_size, num_memory_addresses,
               final_encoder_attention_BxAx1):
    """
    Returns 'zero' (initial) state tuple.

    :param batch_size: Size of the batch in given iteration/epoch.
    :param num_memory_addresses: Number of memory addresses.
    :param final_encoder_attention_BxAx1: Final attention of the encoder
        [BATCH_SIZE x MEMORY_ADDRESSES x 1]

    :returns: Initial state tuple - object of MASInterfaceStateTuple class.
    """
    # Get dtype.
    dtype = AppState().dtype

    # Initial attention weights [BATCH_SIZE x MEMORY_ADDRESSES x 1]
    # Zero-hard attention.
    zh_attention = torch.zeros(batch_size, num_memory_addresses, 1).type(dtype)
    # Initialize as "hard attention on address 0".
    zh_attention[:, 0, 0] = 1

    # Gating [BATCH_SIZE x 3 x 1]
    init_gating = torch.zeros(batch_size, 3, 1).type(dtype)
    init_gating[:, 0, 0] = 1  # Initialize as "previous attention".

    # Shift [BATCH_SIZE x SHIFT_SIZE x 1]
    init_shift = torch.zeros(batch_size, self.interface_shift_size, 1).type(dtype)
    init_shift[:, 1, 0] = 1  # Initialize as "0 shift".

    # Remember zero-hard attention.
    self.zero_hard_attention_BxAx1 = zh_attention

    # Remember final attention of encoder.
    self.final_encoder_attention_BxAx1 = final_encoder_attention_BxAx1

    # Return tuple.
    return MASInterfaceStateTuple(zh_attention,
                                  self.final_encoder_attention_BxAx1,
                                  init_gating, init_shift)
def forward(self, encoded_image, encoded_question):
    """
    Applies stacked attention.

    :param encoded_image: output of the image encoding (CNN + FC layer),
        [batch_size, new_width * new_height, num_channels_encoded_image]
    :param encoded_question: last hidden layer of the LSTM,
        [batch_size, question_encoding_size]

    :returns: u: attention [batch_size, num_channels_encoded_image]
    """
    for att_layer in self.san:
        u, attention_prob = att_layer(encoded_image, encoded_question)

        if AppState().visualize:
            if self.visualize_attention is None:
                self.visualize_attention = attention_prob
            else:
                # Concatenate output.
                self.visualize_attention = torch.cat(
                    [self.visualize_attention, attention_prob], dim=-1)

    return u
def validation(model, problem, episode, stat_col, data_valid, aux_valid,
               FLAGS, logger, validation_file, validation_writer):
    """
    Performs validation of the model on the provided data, using the given
    criterion. Additionally logs results (to files, tensorboard) and visualizes.

    :param stat_col: Statistics collector object.

    :return: Tuple (validation loss, flag); the flag is True if the training
        loop is supposed to end.
    """
    # Turn on evaluation mode.
    model.eval()

    # Calculate loss of the validation data.
    with torch.no_grad():
        logits_valid, loss_valid = forward_step(model, problem, episode,
                                                stat_col, data_valid, aux_valid)

    # Log to logger.
    logger.info(stat_col.export_statistics_to_string('[Validation]'))

    # Export to csv.
    stat_col.export_statistics_to_csv(validation_file)

    if FLAGS.tensorboard is not None:
        # Save loss + accuracy to tensorboard.
        stat_col.export_statistics_to_tensorboard(validation_writer)

    # Visualization of validation.
    if AppState().visualize:
        # Allow for preprocessing.
        data_valid, aux_valid, logits_valid = problem.plot_preprocessing(
            data_valid, aux_valid, logits_valid)
        # True means that we should terminate.
        return loss_valid, model.plot(data_valid, logits_valid)

    # Else simply return False, i.e. continue training.
    return loss_valid, False
def init_state(self, batch_size, num_memory_addresses):
    """
    Returns 'zero' (initial) state tuple.

    :param batch_size: Size of the batch in given iteration/epoch.
    :param num_memory_addresses: Number of memory addresses.

    :returns: Initial state tuple - object of MAEInterfaceStateTuple class.
    """
    # Get dtype.
    dtype = AppState().dtype

    # Zero-hard attention.
    zh_attention = torch.zeros(batch_size, num_memory_addresses, 1).type(dtype)
    zh_attention[:, 0, 0] = 1

    # Initialize shift - to "no shift" (one-hot at position 1).
    init_shift = torch.zeros(batch_size, self.interface_shift_size, 1).type(dtype)
    init_shift[:, 1, 0] = 1

    # Return tuple.
    return MAEInterfaceStateTuple(zh_attention, init_shift)
    :param x: a combination of the attention and question

    :return: Prediction of the answer [batch_size, num_classes]
    """
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = F.dropout(x)
    x = self.fc3(x)

    return F.log_softmax(x, dim=-1)


if __name__ == '__main__':
    # Set visualization.
    from utils.app_state import AppState
    AppState().visualize = True

    # Test base model.
    from utils.param_interface import ParamInterface
    params = ParamInterface()
    params.add_custom_params({})

    # Model.
    model = MultiHopsAttention(params)

    while True:
        # Generate new sequence.
        # "Image" - batch x channels x width x height.
        input_np = np.random.binomial(1, 0.5, (1, 3, 128, 128))
        image = torch.from_numpy(input_np).type(torch.FloatTensor)
def __init__(self, set, clevr_dir, clevr_humans, embedding_type='random',
             random_embedding_dim=300):
    """
    Instantiates a ClevrDataset object:

    - Mainly checks if the files containing the extracted features & tokenized
      questions already exist. If not, it generates them for the specified
      sub-set.
    - self.img then contains the extracted feature maps.
    - self.data contains the tokenized questions, the associated image
      filenames, the answers & the question strings.

    The questions are then embedded based on the specified embedding. This
    embedding is random by default, but pretrained ones are possible.

    :param set: String to specify which dataset to use: 'train', 'val' or 'test'.
    :param clevr_dir: Directory path to the CLEVR_v1.0 dataset. Will also be
        used to store the generated files (.hdf5, .pkl).
    :param clevr_humans: Boolean to indicate whether to use the questions from
        CLEVR-Humans.
    :param embedding_type: String to indicate the pretrained embedding to use:
        either 'random' to use nn.Embedding, or one of the following:
        "charngram.100d", "fasttext.en.300d", "fasttext.simple.300d",
        "glove.42B.300d", "glove.840B.300d", "glove.twitter.27B.25d",
        "glove.twitter.27B.50d", "glove.twitter.27B.100d",
        "glove.twitter.27B.200d", "glove.6B.50d", "glove.6B.100d",
        "glove.6B.200d", "glove.6B.300d".
    :param random_embedding_dim: In the case of random embedding, this is the
        embedding dimension to use.
    """
    # Call base constructor.
    super(CLEVRDataset, self).__init__()

    # Parse params.
    self.set = set
    self.clevr_dir = clevr_dir
    self.clevr_humans = clevr_humans
    self.embedding_type = embedding_type
    self.random_embedding_dim = random_embedding_dim

    # Get access to app state.
    self.app_state = AppState()

    if self.set == 'test':
        logger.error('Test set generation not supported for now. Exiting.')
        exit(0)

    logger.info('Loading the {} samples from {}'.format(
        set, 'CLEVR-Humans' if self.clevr_humans else 'CLEVR'))

    # Check if the folder /generated_files in self.clevr_dir already exists;
    # if not, create it:
    if not os.path.isdir(self.clevr_dir + '/generated_files'):
        logger.warning('Folder {} not found, creating it.'.format(
            self.clevr_dir + '/generated_files'))
        os.mkdir(self.clevr_dir + '/generated_files')

    # Check if the file containing the image feature maps (processed by
    # ResNet101) exists or not. For the same self.set, this file is the same
    # for CLEVR & CLEVR-Humans.
    feature_maps_filename = self.clevr_dir + \
        '/generated_files/{}_CLEVR_features.hdf5'.format(self.set)
    if os.path.isfile(feature_maps_filename):
        logger.info('The file {} already exists, loading it.'.format(
            feature_maps_filename))
    else:
        logger.warning('File {} not found on disk, generating it:'.format(
            feature_maps_filename))
        self.generate_feature_maps_file(feature_maps_filename)

    # Actually load the file.
    self.h = h5py.File(feature_maps_filename, 'r')
    self.img = self.h['data']

    # Check if the file containing the tokenized questions (& answers, image
    # filenames) exists or not.
    questions_filename = self.clevr_dir + '/generated_files/{}_{}_questions.pkl'.format(
        self.set, 'CLEVR_Humans' if self.clevr_humans else 'CLEVR')
    if os.path.isfile(questions_filename):
        logger.info('The file {} already exists, loading it.'.format(
            questions_filename))

        # Load questions.
        with open(questions_filename, 'rb') as questions:
            self.data = pickle.load(questions)

        # Load word_dic & answer_dic.
        with open(self.clevr_dir + '/generated_files/dics.pkl', 'rb') as f:
            dic = pickle.load(f)
            self.answer_dic = dic['answer_dic']
            self.word_dic = dic['word_dic']
    else:
        logger.warning('File {} not found on disk, generating it.'.format(
            questions_filename))

        # WARNING: We need to ensure that we use the same words & answers dics
        # for both train & val, otherwise we do not have the same reference!
        if self.set == 'val' or self.set == 'valA' or self.set == 'valB':
            # First generate the words dic using the training samples.
            logger.warning(
                'We need to ensure that we use the same words-to-index & '
                'answers-to-index dictionaries for both the train & val samples.')
            logger.warning(
                'First, generating the words-to-index & answers-to-index '
                'dictionaries from the training samples:')
            _, self.word_dic, self.answer_dic = self.generate_questions_dics(
                'train' if self.set == 'val' else 'trainA',
                word_dic=None, answer_dic=None)

            # Then tokenize the questions using the dictionaries created from
            # the training samples.
            logger.warning(
                'Then we can tokenize the validation questions using the '
                'dictionaries created from the training samples.')
            self.data, self.word_dic, self.answer_dic = self.generate_questions_dics(
                self.set, word_dic=self.word_dic, answer_dic=self.answer_dic)

        # If self.set == 'train', we can directly tokenize the questions.
        elif self.set == 'train' or self.set == 'trainA':
            self.data, self.word_dic, self.answer_dic = self.generate_questions_dics(
                self.set, word_dic=None, answer_dic=None)

    # At this point, the objects self.img & self.data contain the feature
    # maps & questions.

    # Create the objects for the specified embeddings.
    if self.embedding_type == 'random':
        logger.info('Constructing random embeddings using a uniform distribution')

        # Instantiate nn.Embedding look-up table with the specified
        # embedding_dim.
        self.n_vocab = len(self.word_dic) + 1
        self.embed_layer = torch.nn.Embedding(
            num_embeddings=self.n_vocab,
            embedding_dim=self.random_embedding_dim)

        # We have to make sure that the weights are the same during train and
        # val!
        weights_filename = self.clevr_dir + \
            '/generated_files/random_embedding_weights.pkl'
        if os.path.isfile(weights_filename):
            logger.info('Found random embedding weights on file, using them.')
            with open(weights_filename, 'rb') as f:
                self.embed_layer.weight.data = pickle.load(f)
        else:
            logger.warning(
                'No weights found on file for random embeddings. Initializing '
                'them from a Uniform distribution and saving to file in {}'.format(
                    weights_filename))
            self.embed_layer.weight.data.uniform_(0, 1)
            with open(weights_filename, 'wb') as f:
                pickle.dump(self.embed_layer.weight.data, f)
    else:
        logger.info('Constructing embeddings using {}'.format(self.embedding_type))

        # Instantiate Language class.
        self.language = Language('lang')
        self.questions = [q['string_question'] for q in self.data]

        # Use the questions set to construct the embedding vectors.
        self.language.build_pretrained_vocab(
            self.questions, vectors=self.embedding_type)
    x = self.f_fc2(x)
    x = F.relu(x)
    x = F.dropout(x, p=0.5)
    x = self.f_fc3(x)

    return x


if __name__ == '__main__':
    """
    Unit tests for g_theta & f_phi.
    """
    input_size = (24 + 2) * 2 + 13
    batch_size = 64
    inputs = np.random.binomial(1, 0.5, (batch_size, 3, input_size))
    inputs = torch.from_numpy(inputs).type(AppState().dtype)

    params_g = {'input_size': input_size}
    g_theta = PairwiseRelationNetwork(params_g)
    g_outputs = g_theta(inputs)
    print('g_outputs:', g_outputs.shape)

    output_size = 29
    params_f = {'output_size': output_size}
    f_phi = SumOfPairsAnalysisNetwork(params_f)
    f_outputs = f_phi(g_outputs)
    print('f_outputs:', f_outputs.shape)
if __name__ == "__main__": input_size = 28 params_dict = { 'context_input_size': 32, 'input_size': input_size, 'output_size': 10, 'center_size': 1, 'center_size_per_module': 32, 'num_modules': 4 } # Initialize the application state singleton. from utils.app_state import AppState app_state = AppState() app_state.visualize = True from utils.param_interface import ParamInterface params = ParamInterface() params.add_custom_params(params_dict) model = ThalNetModel(params) seq_length = 10 batch_size = 2 # Check for different seq_lengts and batch_sizes. for i in range(1): # Create random Tensors to hold inputs and outputs x = torch.randn(batch_size, 1, input_size, input_size) logits = torch.randn(batch_size, 1, params_dict['output_size'])
param_interface["training"].add_custom_params({"seed_torch": seed}) logger.info("Setting torch random seed to: {}".format( param_interface["training"]["seed_torch"])) torch.manual_seed(param_interface["training"]["seed_torch"]) torch.cuda.manual_seed_all(param_interface["training"]["seed_torch"]) if "seed_numpy" not in param_interface["training"] or param_interface[ "training"]["seed_numpy"] == -1: seed = randrange(0, 2**32) param_interface["training"].add_custom_params({"seed_numpy": seed}) logger.info("Setting numpy random seed to: {}".format( param_interface["training"]["seed_numpy"])) np.random.seed(param_interface["training"]["seed_numpy"]) # Initialize the application state singleton. app_state = AppState() # check if CUDA is available turn it on check_and_set_cuda(param_interface['training'], logger) # Build the model. model = ModelFactory.build_model(param_interface['model']) model.cuda() if app_state.use_CUDA else None # Log the model summary. logger.info(model.summarize()) # Build problem for the training problem = ProblemFactory.build_problem( param_interface['training']['problem'])
    x = self.batchNorm2(x)
    x = F.relu(x)

    x = self.conv3(x)
    x = self.batchNorm3(x)
    x = F.relu(x)

    x = self.conv4(x)
    x = self.batchNorm4(x)
    x = F.relu(x)

    return x


if __name__ == '__main__':
    """
    Unit test for the ConvInputModel.
    """
    # "Image" - batch x channels x width x height.
    batch_size = 64
    img_size = 128
    input_np = np.random.binomial(1, 0.5, (batch_size, 3, img_size, img_size))
    image = torch.from_numpy(input_np).type(AppState().dtype)

    cnn = ConvInputModel()

    feature_maps = cnn(image)
    print('feature_maps:', feature_maps.shape)
    return x_out


if __name__ == '__main__':
    question_size = 13
    input_size = (24 + 2) * 2 + question_size
    output_size = 29

    params = {
        'g_theta': {
            'input_size': input_size
        },
        'f_phi': {
            'output_size': output_size
        }
    }

    batch_size = 128
    img_size = 128

    images = np.random.binomial(1, 0.5, (batch_size, 3, img_size, img_size))
    images = torch.from_numpy(images).type(AppState().dtype)
    questions = np.random.binomial(1, 0.5, (batch_size, question_size))
    questions = torch.from_numpy(questions).type(AppState().dtype)
    targets = None

    net = RelationalNetwork(params)
    net(((images, questions), targets))
        self.plotWindow.update(fig, frames)
        return self.plotWindow.is_closed


if __name__ == '__main__':
    dim = 512
    embed_hidden = 300
    max_step = 12
    self_attention = True
    memory_gate = True
    nb_classes = 28
    dropout = 0.15

    from utils.app_state import AppState
    app_state = AppState()

    from utils.param_interface import ParamInterface
    params = ParamInterface()
    params.add_custom_params({
        'dim': dim,
        'embed_hidden': embed_hidden,
        'max_step': max_step,
        'self_attention': self_attention,
        'memory_gate': memory_gate,
        'nb_classes': nb_classes,
        'dropout': dropout
    })

    net = MACNetwork(params)
def circular_convolution(self, attention_BxAx1, shift_BxSx1):
    """
    Performs circular convolution, i.e. shifts the attention according to the
    given shift vector (convolution mask).

    :param attention_BxAx1: Current attention [BATCH_SIZE x ADDRESS_SIZE x 1]
    :param shift_BxSx1: Soft shift mask (convolutional kernel)
        [BATCH_SIZE x SHIFT_SIZE x 1]

    :returns: Attention vector of size [BATCH_SIZE x ADDRESS_SIZE x 1]
    """
    def circular_index(idx, num_addr):
        """
        Calculates the index, taking into consideration the number of
        addresses in memory.

        :param idx: index (single element)
        :param num_addr: number of addresses in memory
        """
        if idx < 0:
            return num_addr + idx
        elif idx >= num_addr:
            return idx - num_addr
        else:
            return idx

    # AppState provides a LongTensor type matching the device (CPU/GPU) the
    # inputs live on.
    dtype = AppState().LongTensor

    # Get number of memory addresses and batch size.
    batch_size = attention_BxAx1.size(0)
    num_addr = attention_BxAx1.size(1)
    shift_size = self.interface_shift_size

    # Create an extended list of indices indicating what elements of the
    # sequence will be where.
    ext_indices_tensor = torch.Tensor([
        circular_index(shift, num_addr)
        for shift in range(-shift_size // 2 + 1, num_addr + shift_size // 2)
    ]).type(dtype)

    # Use the indices to create an extended attention vector.
    ext_attention_BxEAx1 = torch.index_select(
        attention_BxAx1, dim=1, index=ext_indices_tensor)

    # Transpose inputs to convolution.
    ext_att_trans_Bx1xEA = torch.transpose(ext_attention_BxEAx1, 1, 2)
    shift_trans_Bx1xS = torch.transpose(shift_BxSx1, 1, 2)

    # Perform convolution for every batch-filter pair.
    tmp_attention_list = []
    for b in range(batch_size):
        tmp_attention_list.append(
            F.conv1d(ext_att_trans_Bx1xEA.narrow(0, b, 1),
                     shift_trans_Bx1xS.narrow(0, b, 1)))

    # Concatenate the list into a single tensor.
    shifted_attention_BxAx1 = torch.transpose(
        torch.cat(tmp_attention_list, dim=0), 1, 2)

    return shifted_attention_BxAx1
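# A standalone intuition check (not the original implementation, and the sign
# convention of the kernel is an assumption): for a hard, one-hot shift kernel
# the circular convolution above reduces to a circular roll of the attention
# vector; soft kernels blend neighboring addresses instead.
import torch

attention = torch.tensor([0.0, 1.0, 0.0, 0.0])  # focused on address 1
shifted = torch.roll(attention, shifts=1, dims=0)
print(shifted)  # tensor([0., 0., 1., 0.]) - focus moved to address 2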