def reset_block_lr(self, bidx, lr=0.0):
    ''' Set the learning rate of the hidden(t) -> hidden(t+1) connections
    of block `bidx` (default 0.0, i.e. freeze the block). '''
    block = self.blocks[bidx]
    block[-1] = lr
    r1, r2, c1, c2, lr = self.blocks[bidx]
    units = util.get_diagonal_elements_range(block)
    # All of the block's units must already have been added
    assert self.active_units[r1:r2].nonzero().numel() == r2 - r1
    assert self.active_units[c1:c2].nonzero().numel() == c2 - c1
    if units is not None:
        # The block has diagonal units: update the recurrent bias lr too
        unit_from, unit_to = units
        lr_hhb = self.lr_bias_hh
        lr_hhb, = util.channel_view(self.nchan, lr_hhb)
        lr_hhb[:, unit_from:unit_to].fill_(lr)
    lr_hh = self.lr_weight_hh
    lr_hh, = util.channel_view(self.nchan, lr_hh)
    lr_hh[:, r1:r2, c1:c2].fill_(lr)
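
# A minimal sketch (an assumption, not part of this module) of how the
# per-parameter learning-rate tensors above could be consumed by an update
# step: each entry of an `lr_*` tensor scales the update of the matching
# weight, so filling a block's region with 0.0 freezes exactly that block.
# `model` and `sgd_step` are hypothetical names for illustration.
#
#   def sgd_step(param, lr_mask, base_lr=1.0):
#       # elementwise update; entries where lr_mask == 0 stay frozen
#       param.data.add_(-base_lr * lr_mask * param.grad.data)
#
#   sgd_step(model.rnn.weight_hh, model.lr_weight_hh)
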
def blockwise_broadcast(self, vals, t=None):
    ''' blockwise broadcast of values into an empty matrix '''
    assert len(vals) == len(self.blocks)
    result = t if t is not None else \
        torch.zeros(self.rnn.state_dict()['weight_hh'].size())
    c_result, = util.channel_view(self.nchan, result)
    for i, b in enumerate(self.blocks):
        c_result[:, b[0]:b[1], b[2]:b[3]] = vals[i]
    return result
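
# Usage sketch (hypothetical caller): paint one scalar per block into a
# full-size hidden-to-hidden matrix, e.g. to inspect which region each
# block owns and at what learning rate it trains.
#
#   lrs = [blk[-1] for blk in model.blocks]     # per-block learning rate
#   lr_map = model.blockwise_broadcast(lrs)     # zero outside all blocks
#   print(lr_map)
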
def init_weights(self, parent):
    ''' Initialize encoder/decoder weights and, if `parent` is given,
    inherit its parameters. '''
    self.encoder.weight.data.uniform_(-initrange, initrange)
    self.decoder.bias.data.fill_(0)
    self.decoder.weight.data.uniform_(-initrange, initrange)
    # Weight inheritance: copy every parameter of the parent model into
    # the corresponding region of this (possibly larger) model
    if parent is not None:
        new_state = self.state_dict()
        old_state = parent.state_dict()
        for k in new_state:
            src, dst = old_state[k], new_state[k]
            if 'rnn' in k and new_state[k].size(0) > self.nhid:
                # Flattened multi-channel RNN parameters: reshape both
                # sides so corresponding channels line up before copying
                src, dst = util.channel_view(self.nchan, src, dst)
            util.subset(src, dst).copy_(src)
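
# `util.subset` is defined elsewhere; a minimal sketch of the semantics
# this method relies on (an assumption, valid only if `dst` is at least as
# large as `src` in every dimension): inheritance copies the parent's
# parameters into the leading corner of the child's, leaving the newly
# grown region at its fresh initialization.
#
#   def subset(src, dst):
#       # view of dst's leading corner with src's shape
#       return dst[tuple(slice(0, s) for s in src.size())]
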
def add_block(self, b, init=True):
    ''' Add a block to the model, optionally initializing its weights
    uniformly or with a closure. Returns the index of the new block. '''
    self.logger.info('Adding block {}'.format(b))
    util.check_block(b, self.nhid)
    if self.strict:
        for b2 in self.blocks:
            assert not util.overlap(b, b2)
    self.blocks.append(b)
    r1, r2, c1, c2, lr = b
    if init:
        new_units = util.get_diagonal_elements_range(b)
        if new_units is None:
            # Not adding any new units: check that the connections we add
            # are between units that have already been initialized
            assert self.active_units[b[0]]
        else:
            self.logger.info('Adding new units {}'.format(new_units))
            unit_from, unit_to = new_units
            self.active_units[unit_from:unit_to] = 1

            # Initialize input connections to the new units
            ih = self.rnn.weight_ih.data
            ih, = util.channel_view(self.nchan, ih)
            ih[:, unit_from:unit_to, :].uniform_(-initrange, initrange)
            lr_ih = self.lr_weight_ih
            lr_ih, = util.channel_view(self.nchan, lr_ih)
            lr_ih[:, unit_from:unit_to, :].fill_(lr)

            ihb = self.rnn.bias_ih.data
            ihb, = util.channel_view(self.nchan, ihb)
            ihb[:, unit_from:unit_to].uniform_(-initrange, initrange)
            lr_ihb = self.lr_bias_ih
            lr_ihb, = util.channel_view(self.nchan, lr_ihb)
            lr_ihb[:, unit_from:unit_to].fill_(lr)

            # Initialize the bias of the recurrent connections (the
            # weights are initialized below)
            hhb = self.rnn.bias_hh.data
            hhb, = util.channel_view(self.nchan, hhb)
            hhb[:, unit_from:unit_to].uniform_(-initrange, initrange)

            # Initialize decoder connections to the new units
            dw = self.decoder.weight.data
            dw[:, unit_from:unit_to].uniform_(-initrange, initrange)
            self.lr_decoder_weight[:, unit_from:unit_to].fill_(lr)

        # Initialize hidden-to-hidden connections. This update is the same
        # whether the block is diagonal (adding new units) or off-diagonal
        # (new connections between previously initialized units).
        hh = self.rnn.weight_hh.data
        hh, = util.channel_view(self.nchan, hh)
        hh[:, b[0]:b[1], b[2]:b[3]].uniform_(-initrange, initrange)
        self.set_block_lr(self.get_block(b), lr)

    if len(self.blocks) > 1 and self.frozen_h:
        # Freeze the previously added block's recurrent connections
        self.reset_block_lr(self.get_block(self.blocks[-2]))

    bidx = len(self.blocks) - 1  # index of the block just appended
    return bidx
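
# Growth sketch (hypothetical sizes; block format [r1, r2, c1, c2, lr] as
# unpacked above): diagonal blocks create and initialize new units, while
# off-diagonal blocks only wire already-active units together. With
# `frozen_h` set, each add_block call freezes the previous block's
# recurrent learning rates via reset_block_lr.
#
#   model.add_block([0, 64, 0, 64, 1.0])        # create units 0..64
#   model.add_block([64, 128, 64, 128, 1.0])    # create units 64..128
#   model.add_block([64, 128, 0, 64, 1.0])      # connect 0..64 -> 64..128
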