    def reset_block_lr(self, bidx, lr=0.0):
        '''Set the learning rate of the hidden(t) -> hidden(t+1) connections
        covered by block ``bidx``.

        The new rate is written into the last slot of the block entry in
        ``self.blocks`` (the entry is mutated in place, so it must be a
        mutable sequence). If ``util.get_diagonal_elements_range`` reports
        diagonal units for the block, the learning rate of the recurrent bias
        of those units is updated as well.

        :param bidx: index of the block in ``self.blocks``.
        :param lr: learning rate to store and fill in (default ``0.0``).
        '''
        block = self.blocks[bidx]
        block[-1] = lr
        r1, r2, c1, c2, lr = block

        units = util.get_diagonal_elements_range(block)

        # All the units touched by the block must already have been added.
        assert self.active_units[r1:r2].nonzero().numel() == r2 - r1
        assert self.active_units[c1:c2].nonzero().numel() == c2 - c1

        if units is not None:
            # Diagonal block: it also owns the recurrent bias of its units.
            unit_from, unit_to = units
            lr_hhb, = util.channel_view(self.nchan, self.lr_bias_hh)
            lr_hhb[:, unit_from:unit_to].fill_(lr)

        lr_hh, = util.channel_view(self.nchan, self.lr_weight_hh)
        lr_hh[:, r1:r2, c1:c2].fill_(lr)
    def blockwise_broadcast(self, vals, t=None):
        ''' blockwise broadcast of values into an empty matrix

        Each ``vals[i]`` is assigned to the region ``[:, r1:r2, c1:c2]`` of
        block ``self.blocks[i]`` in the channel view of the destination.

        :param vals: sequence with exactly one entry per block.
        :param t: optional destination tensor, written in place. When omitted
            a zero tensor with the same shape as ``self.rnn.weight_hh`` is
            allocated (the recurrent weights are channel-viewed and indexed
            per block in the same way).
        :returns: the destination tensor.
        '''
        assert len(vals) == len(self.blocks)

        if t is not None:
            result = t
        else:
            # NOTE(review): the original default allocation was lost from the
            # source; a zero matrix shaped like the recurrent weights matches
            # the per-block indexing below. Confirm against callers.
            result = self.rnn.weight_hh.data.clone().zero_()
        c_result, = util.channel_view(self.nchan, result)

        for block, val in zip(self.blocks, vals):
            c_result[:, block[0]:block[1], block[2]:block[3]] = val

        return result
    def init_weights(self, parent):
        '''Initialize the encoder/decoder weights, then optionally inherit.

        The encoder and decoder weight matrices are drawn uniformly from
        ``[-initrange, initrange]``. ``initrange`` is not bound in this
        method; presumably it is a module-level constant.

        If ``parent`` is not ``None``, then for every key of this model's
        ``state_dict()`` the parent's tensor under the same key (which must
        exist) is copied via ``util.subset(src, dst).copy_(src)``. Entries
        whose key contains ``'rnn'`` and whose first dimension exceeds
        ``self.nhid`` are passed through ``util.channel_view`` first.

        :param parent: model to inherit weights from, or ``None``.
        '''
        for layer in (self.encoder, self.decoder):
            layer.weight.data.uniform_(-initrange, initrange)

        # Weight Inheritance
        if parent is None:
            return

        own_params = self.state_dict()
        parent_params = parent.state_dict()
        for key, dst in own_params.items():
            src = parent_params[key]
            # Multi-channel recurrent tensors are split into channels first.
            if 'rnn' in key and dst.size(0) > self.nhid:
                src, dst = util.channel_view(self.nchan, src, dst)
            util.subset(src, dst).copy_(src)
    def add_block(self, b, init=True):
        '''Add a block to the model, optionally initializing its weights.

        ``b`` is unpacked as ``(r1, r2, c1, c2, lr)``: rows ``r1:r2`` and
        columns ``c1:c2`` of the channel-viewed hidden-to-hidden matrix, plus
        the learning rate for the parameters the block introduces.

        When ``init`` is truthy, the block's hidden-to-hidden weights are
        drawn uniformly from ``[-initrange, initrange]``. If
        ``util.get_diagonal_elements_range`` reports new (diagonal) units, those
        units are marked active. Their input weights/biases, recurrent bias
        and decoder columns are then initialized uniformly, and the matching
        learning-rate tensors are filled with ``lr``. Note that ``init`` is
        only tested for truthiness; initialization is always uniform.

        :param b: block description ``(r1, r2, c1, c2, lr)``.
        :param init: whether to initialize the parameters covered by ``b``.
        :returns: ``len(self.blocks)`` at the end of the call.
        '''
        self.logger.info('Adding block {}'.format(b))

        util.check_block(b, self.nhid)
        if self.strict:
            for b2 in self.blocks:
                assert not util.overlap(b, b2)

        r1, r2, c1, c2, lr = b

        if init:
            new_units = util.get_diagonal_elements_range(b)
            if new_units is None:
                # If we're not adding any new units, check that the connections
                # we add are to units that have already been initialized.
                # NOTE(review): only the first row unit (r1) is checked.
                assert self.active_units[r1]
            else:
                self.logger.info('Adding new units {}'.format(new_units))
                unit_from, unit_to = new_units

                self.active_units[unit_from:unit_to] = 1

                # Initialize input connections to new units
                ih, = util.channel_view(self.nchan, self.rnn.weight_ih.data)
                ih[:, unit_from:unit_to, :].uniform_(-initrange, initrange)

                lr_ih, = util.channel_view(self.nchan, self.lr_weight_ih)
                lr_ih[:, unit_from:unit_to, :].fill_(lr)

                ihb, = util.channel_view(self.nchan, self.rnn.bias_ih.data)
                ihb[:, unit_from:unit_to].uniform_(-initrange, initrange)

                # Biases are 2-D in the channel view: (channel, unit).
                lr_ihb, = util.channel_view(self.nchan, self.lr_bias_ih)
                lr_ihb[:, unit_from:unit_to].fill_(lr)

                # Initialize the bias of recurrent connections (weights are
                # initialized below)
                hhb, = util.channel_view(self.nchan, self.rnn.bias_hh.data)
                hhb[:, unit_from:unit_to].uniform_(-initrange, initrange)

                # Initialize decoder connections to new units
                dw = self.decoder.weight.data
                dw[:, unit_from:unit_to].uniform_(-initrange, initrange)
                self.lr_decoder_weight[:, unit_from:unit_to].fill_(lr)

            # Initialize hidden to hidden connections.
            # This update is the same irrespective of whether the block is
            # diagonal (adding new units) or non-diagonal (new connections
            # between previously initialized units).
            hh, = util.channel_view(self.nchan, self.rnn.weight_hh.data)
            hh[:, r1:r2, c1:c2].uniform_(-initrange, initrange)

            self.set_block_lr(self.get_block(b), lr)
            if len(self.blocks) > 1 and self.frozen_h:
                # TODO(review): the body of this branch is missing from the
                # source. Judging by the name it should freeze earlier blocks
                # (e.g. via reset_block_lr); restore the intended logic.
                pass

        # NOTE(review): ``b`` is not appended to ``self.blocks`` here; confirm
        # whether the caller (or lost code) is expected to record it.
        bidx = len(self.blocks)
        return bidx