コード例 #1
0
    def predict(self, board: np.ndarray):
        """Run the network on a single board and return ``(policy, value)``.

        Args:
            board: Board array. It is reshaped to ``(1, self.board_x, self.board_y)``
                below, so it must contain exactly ``board_x * board_y`` elements.

        Returns:
            Tuple ``(policy, value)`` of numpy arrays: ``exp`` of the network's first
            output and its second output, each taken at batch index 0.
            NOTE(review): the ``exp`` suggests the network emits log-probabilities;
            confirm against the model definition.
        """
        # np.array(..., copy=True default) always allocates a fresh C-contiguous
        # float32 buffer, so the caller's array (e.g. game.board) is never aliased
        # or modified. It also converts dtype once instead of float64 -> float32.
        board_tensor = torch.from_numpy(np.array(board, dtype=np.float32, order="C"))
        if args["cuda"]:
            board_tensor = board_tensor.contiguous().cuda()
        # NOTE(review): the input is a single plane of shape (board_x, board_y),
        # not one plane per player. Confirm this matches the network's expected input.
        board_tensor = board_tensor.view(1, self.board_x, self.board_y)
        self.net.eval()
        with torch.no_grad():  # inference only: no autograd graph needed
            log_policy, value = self.net(board_tensor)
            policy = torch.exp(log_policy)

        return policy.detach().cpu().numpy()[0], value.detach().cpu().numpy()[0]
コード例 #2
0
    def encode_rnn(self, feats: numpy.ndarray) -> torch.Tensor:
        """Encode acoustic features with the encoder ``self.enc``.

        Args:
            feats: Feature sequence. (F, D_feats)

        Returns:
            enc_out: Encoded feature sequence. (T, D_enc)

        """
        # Use the device and dtype of the model's (first) parameter for the input.
        p = next(self.parameters())

        # Input frame subsampling: keep every subsample[0]-th frame.
        feats = feats[:: self.subsample[0], :]

        # The length must describe the sequence actually passed to the encoder, so
        # it is measured *after* subsampling. Measuring it before over-reports the
        # length whenever subsample[0] > 1.
        feats_len = [feats.shape[0]]

        feats = torch.as_tensor(feats, device=p.device, dtype=p.dtype)
        # Add a batch dimension of size 1: (1, F', D_feats).
        feats = feats.contiguous().unsqueeze(0)

        enc_out, _, _ = self.enc(feats, feats_len)

        # Drop the batch dimension again.
        return enc_out.squeeze(0)
コード例 #3
0
ファイル: NNet.py プロジェクト: MarkTakken/alpha-zero-general
    def predict(self, board: np.ndarray):
        """Evaluate a single board position with ``self.nnet``.

        Args:
            board: Board array. It is reshaped to
                ``(1, 2 * self.history + 1, self.board_x, self.board_y)``, so it
                must contain exactly that many elements.

        Returns:
            Tuple ``(p, v)``: the network's two outputs at batch index 0, moved to
            the CPU and converted to numpy arrays.
        """
        x = torch.as_tensor(board, dtype=torch.float)

        if args.cuda:
            # Move both the input and the model to the GPU before the forward pass.
            x = x.contiguous().cuda()
            self.nnet.cuda()

        # Batch of one, with 2 * history + 1 input planes.
        x = x.view(1, 2 * self.history + 1, self.board_x, self.board_y)

        self.nnet.eval()
        with torch.no_grad():
            pi, v = self.nnet(x)

        pi_out = pi.detach().cpu().numpy()[0]
        v_out = v.detach().cpu().numpy()[0]
        return pi_out, v_out