Example #1
    def forward(self, input, hx=None, cx=None, doDropMC=False):
        if self.dr > 0 and (doDropMC is True or self.training is True):
            doDrop = True
        else:
            doDrop = False

        batchSize = input.size(1)

        if hx is None:
            hx = input.new_zeros(
                1, batchSize, self.hiddenSize, requires_grad=False)
        if cx is None:
            cx = input.new_zeros(
                1, batchSize, self.hiddenSize, requires_grad=False)

        # cuDNN backend - disabled flat weight
        # NOTE: handle is unused here; newer variants of this code drop the call entirely (see Example #4)
        handle = torch.backends.cudnn.get_handle()
        if doDrop is True:
            self.reset_mask()
            weight = [dropMask.apply(self.w_ih, self.maskW_ih, True),
                      dropMask.apply(self.w_hh, self.maskW_hh, True),
                      self.b_ih, self.b_hh]
        else:
            weight = [self.w_ih, self.w_hh, self.b_ih, self.b_hh]

        # Positional args: input, weight, weight_stride0, weight_buf, hx, cx, mode,
        # hidden_size, num_layers, batch_first, dropout, train, bidirectional,
        # batch_sizes, dropout_state (Example #2 below passes them by name).
        output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
            input, weight, 4, None, hx, cx,
            torch.backends.cudnn.CUDNN_LSTM, self.hiddenSize,
            1, False, 0, self.training, False, (), None)
        return output, (hy, cy)
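For orientation, a minimal self-contained sketch of the corresponding call through the public torch.nn.LSTM API (sizes are arbitrary; the weight-dropout masking used above has no public-API counterpart and is omitted):

import torch
import torch.nn as nn

# One unidirectional LSTM layer, sequence-major input, no built-in dropout,
# mirroring the options passed positionally to torch._cudnn_rnn above.
lstm = nn.LSTM(input_size=10, hidden_size=64, num_layers=1,
               batch_first=False, dropout=0.0, bidirectional=False)
x = torch.randn(30, 8, 10)     # [seq_len, batch, input_size]; batch is dim 1, as in batchSize = input.size(1)
h0 = torch.zeros(1, 8, 64)     # [num_layers, batch, hidden_size]
c0 = torch.zeros(1, 8, 64)
out, (hn, cn) = lstm(x, (h0, c0))  # out: [seq_len, batch, hidden_size]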
Example #2
File: rnn.py  Project: xy1802/PYNQ-Torch
    def forward(input, weight, hx, batch_sizes):
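        # NOTE: mode, cudnn, itertools, flat_weight, hidden_size, num_layers,
        # batch_first, dropout, train, bidirectional, variable_length,
        # dropout_seed and dropout_state are not defined in this excerpt; they
        # are captured from the enclosing scope, so the snippet is not runnable on its own.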
        if mode == cudnn.CUDNN_LSTM:
            hx, cx = hx
        else:
            cx = None

        handle = cudnn.get_handle()
        with torch.cuda.device(input.get_device()):
            dropout_ts = cudnn.rnn.init_dropout_state(dropout, train, dropout_seed, dropout_state)

        weight_arr = list(itertools.chain.from_iterable(weight))
        weight_stride0 = len(weight[0])

        output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
            input, weight_arr, weight_stride0,
            flat_weight,
            hx, cx,
            mode, hidden_size, num_layers,
            batch_first, dropout, train, bool(bidirectional),
            list(batch_sizes.data) if variable_length else (),
            dropout_ts)

        if cx is not None:
            return (output, (hy, cy))
        else:
            return (output, hy)
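The weight_arr / weight_stride0 pair above flattens a per-layer list of parameter lists into the flat tensor list plus per-layer stride that torch._cudnn_rnn expects; a small stand-alone illustration (layer count and shapes are placeholders):

import itertools
import torch

# Two layers, four parameter tensors per layer (w_ih, w_hh, b_ih, b_hh).
weight = [[torch.empty(12, 3), torch.empty(12, 3), torch.empty(12), torch.empty(12)]
          for _ in range(2)]
weight_arr = list(itertools.chain.from_iterable(weight))  # flat list of 8 tensors
weight_stride0 = len(weight[0])                           # 4 tensors per layer
print(len(weight_arr), weight_stride0)                    # prints: 8 4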
Example #3
    def forward(self, input, hx=None, cx=None, do_drop_mc=False, dropout_false=False):
        # dropout_false: forces do_drop to False unless do_drop_mc is True
        if dropout_false and (not do_drop_mc):
            do_drop = False
        elif self.dr > 0 and (do_drop_mc is True or self.training is True):
            do_drop = True
        else:
            do_drop = False

        batch_size = input.size(1)

        if hx is None:
            hx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)
        if cx is None:
            cx = input.new_zeros(1, batch_size, self.hidden_size, requires_grad=False)

        # cuDNN backend - disabled flat weight
        freeze_mask = False
        handle = torch.backends.cudnn.get_handle()
        if do_drop is True:
            if not freeze_mask:
                self.reset_mask()
            weight = [
                DropMask.apply(self.w_ih, self.mask_w_ih, True),
                DropMask.apply(self.w_hh, self.mask_w_hh, True), self.b_ih,
                self.b_hh
            ]
        else:
            weight = [self.w_ih, self.w_hh, self.b_ih, self.b_hh]

        output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
            input, weight, 4, None, hx, cx, torch.backends.cudnn.CUDNN_LSTM,
            self.hidden_size, 1, False, 0, self.training, False, (), None)
        return output, (hy, cy)
Example #4
    def forward(self,
                input,
                hx=None,
                cx=None,
                doDropMC=False,
                dropoutFalse=False):
        # dropoutFalse: forces doDrop to False unless doDropMC is True
        if dropoutFalse and (not doDropMC):
            doDrop = False
        elif self.dr > 0 and (doDropMC is True or self.training is True):
            doDrop = True
        else:
            doDrop = False

        batchSize = input.size(1)

        if hx is None:
            hx = input.new_zeros(1,
                                 batchSize,
                                 self.hiddenSize,
                                 requires_grad=False)
        if cx is None:
            cx = input.new_zeros(1,
                                 batchSize,
                                 self.hiddenSize,
                                 requires_grad=False)

        # cuDNN backend - disabled flat weight
        # handle = torch.backends.cudnn.get_handle()
        if doDrop is True:
            self.reset_mask()
            weight = [
                DropMask.apply(self.w_ih, self.maskW_ih, True),
                DropMask.apply(self.w_hh, self.maskW_hh, True),
                self.b_ih,
                self.b_hh,
            ]
        else:
            weight = [self.w_ih, self.w_hh, self.b_ih, self.b_hh]

        # output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
        # input, weight, 4, None, hx, cx, torch.backends.cudnn.CUDNN_LSTM,
        # self.hiddenSize, 1, False, 0, self.training, False, (), None)
        # NOTE: a plain string comparison misorders versions such as "1.10" vs "1.8",
        # so compare the parsed (major, minor) tuple instead.
        if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 8):
            output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
                input,
                weight,
                4,
                None,
                hx,
                cx,
                2,  # 2 means LSTM
                self.hiddenSize,
                1,  # num_layers (the pre-1.8 signature has no proj_size argument)
                False,
                0,
                self.training,
                False,
                (),
                None,
            )
        else:
            output, hy, cy, reserve, new_weight_buf = torch._cudnn_rnn(
                input,
                weight,
                4,
                None,
                hx,
                cx,
                2,  # 2 means LSTM
                self.hiddenSize,
                0,  # proj_size: extra positional argument in the newer (>= 1.8) signature
                1,  # num_layers
                False,
                0,
                self.training,
                False,
                (),
                None,
            )
        return output, (hy, cy)