Esempio n. 1
0
    def __init__(self, num_channels, num_actions, N=32, num_cosines=32,
                 embedding_dim=64, dueling_net=False, noisy_net=False,
                 target=False):
        """Assemble the FQF sub-networks and remember the settings used.

        Args:
            num_channels: Forwarded to ``DQNBase``.
            num_actions: Forwarded to ``QuantileNetwork``.
            N: Forwarded to ``FractionProposalNetwork``.
            num_cosines: Forwarded to ``CosineEmbeddingNetwork``.
            embedding_dim: Forwarded to ``CosineEmbeddingNetwork`` and
                ``FractionProposalNetwork``.
            dueling_net: Forwarded to ``QuantileNetwork``.
            noisy_net: Forwarded to ``CosineEmbeddingNetwork`` and
                ``QuantileNetwork``.
            target: When truthy, no ``fraction_net`` attribute is created.
        """
        super().__init__()

        # Every constructor argument stays readable on the instance.
        self.N = N
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net
        self.target = target

        # Sub-modules are created in a fixed order: DQN feature extractor,
        # cosine embedding network, quantile network, then (optionally)
        # the fraction proposal network.
        self.dqn_net = DQNBase(num_channels=num_channels)
        self.cosine_net = CosineEmbeddingNetwork(
            num_cosines=num_cosines,
            embedding_dim=embedding_dim,
            noisy_net=noisy_net,
        )
        self.quantile_net = QuantileNetwork(
            num_actions=num_actions,
            dueling_net=dueling_net,
            noisy_net=noisy_net,
        )

        if target:
            # A target instance gets no fraction proposal network.
            return

        self.fraction_net = FractionProposalNetwork(
            N=N,
            embedding_dim=embedding_dim,
        )
Esempio n. 2
0
    def __init__(self, num_channels, num_actions, K=32, num_cosines=32,
                 embedding_dim=7 * 7 * 64, dueling_net=False,
                 noisy_net=False):
        """Assemble the IQN sub-networks and remember the settings used.

        Args:
            num_channels: Forwarded to ``DQNBase``.
            num_actions: Forwarded to ``QuantileNetwork``.
            K: Only stored on the instance by this constructor.
            num_cosines: Forwarded to ``CosineEmbeddingNetwork``.
            embedding_dim: Forwarded to ``CosineEmbeddingNetwork``.
            dueling_net: Forwarded to ``QuantileNetwork``.
            noisy_net: Forwarded to ``CosineEmbeddingNetwork`` and
                ``QuantileNetwork``.
        """
        super().__init__()

        # Every constructor argument stays readable on the instance.
        self.K = K
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.num_cosines = num_cosines
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net

        # Sub-modules are created in a fixed order: DQN feature extractor,
        # cosine embedding network, then the quantile network.
        self.dqn_net = DQNBase(num_channels=num_channels)
        self.cosine_net = CosineEmbeddingNetwork(
            num_cosines=num_cosines,
            embedding_dim=embedding_dim,
            noisy_net=noisy_net,
        )
        self.quantile_net = QuantileNetwork(
            num_actions=num_actions,
            dueling_net=dueling_net,
            noisy_net=noisy_net,
        )
Esempio n. 3
0
    def __init__(self, num_channels, num_actions, N=200, embedding_dim=7*7*64,
                 dueling_net=False, noisy_net=False):
        """Assemble the QR-DQN feature extractor and its output head(s).

        Args:
            num_channels: Forwarded to ``DQNBase``.
            num_actions: Combined with ``N`` to size the ``num_actions * N``
                output layer.
            N: Output width of the baseline head, and the per-action factor
                in the ``num_actions * N`` output layer.
            embedding_dim: Input width of every head.
            dueling_net: When truthy, ``advantage_net`` and ``baseline_net``
                are created instead of a single ``q_net``.
            noisy_net: When truthy, heads use ``NoisyLinear`` layers instead
                of ``nn.Linear``.
        """
        super().__init__()

        # Pick the layer type once; all heads share it.
        if noisy_net:
            layer_cls = NoisyLinear
        else:
            layer_cls = nn.Linear

        def build_head(out_features):
            # embedding_dim -> 512 -> ReLU -> out_features
            return nn.Sequential(
                layer_cls(embedding_dim, 512),
                nn.ReLU(),
                layer_cls(512, out_features),
            )

        # Feature extractor of DQN (created before any head).
        self.dqn_net = DQNBase(num_channels=num_channels)

        # Quantile network: one plain head, or advantage + baseline heads.
        if dueling_net:
            self.advantage_net = build_head(num_actions * N)
            self.baseline_net = build_head(N)
        else:
            self.q_net = build_head(num_actions * N)

        # Keep the constructor arguments readable on the instance.
        self.N = N
        self.num_channels = num_channels
        self.num_actions = num_actions
        self.embedding_dim = embedding_dim
        self.dueling_net = dueling_net
        self.noisy_net = noisy_net
Esempio n. 4
0
 def __init__(self, num_channels, num_actions, num_gaussians=5,
              embedding_dim=7 * 7 * 64, dueling_net=False,
              noisy_net=False):
     """Assemble the DMoGQ feature extractor and mixture-of-Gaussians head.

     Args:
         num_channels: Forwarded to ``DQNBase``.
         num_actions: Forwarded to ``MoGNet``.
         num_gaussians: Forwarded to ``MoGNet``.
         embedding_dim: Accepted but not used by this constructor.
         dueling_net: Accepted but not used by this constructor.
         noisy_net: Accepted but not used by this constructor.
     """
     super().__init__()

     # Feature extractor of DQN: maps the state (an image) to an embedding.
     self.dqn_net = DQNBase(num_channels=num_channels)

     # Head that maps the embedding to the Q value distribution.
     self.mog_net = MoGNet(
         num_actions=num_actions,
         num_gaussians=num_gaussians,
     )
Esempio n. 5
0
# import torch
# print(cdf_gauss(torch.tensor(0.0), torch.tensor(0.0), torch.tensor(1.0)))

import os
import yaml
import argparse
from datetime import datetime

import torch

from fqf_iqn_qrdqn.env import make_pytorch_env
from fqf_iqn_qrdqn.network import DQNBase
from DMoGDiscrete.network import MoGNet

# Smoke test: push one reset observation of Pong through DQNBase and MoGNet
# and print the three tensors the head returns.
env = make_pytorch_env('PongNoFrameskip-v4')
num_actions = env.action_space.n
print(num_actions)

num_channels = env.observation_space.shape[0]
dqn_net = DQNBase(num_channels=num_channels)
dmog_net = MoGNet(num_actions)

# Turn the first observation into a single-item float batch scaled by 1/255.
state = env.reset()
state = torch.ByteTensor(state).unsqueeze(0)
state = state.to('cpu').float() / 255.

dqn_out = dqn_net(state)
out_pi, out_mu, out_sigma = dmog_net(dqn_out)
for out in (out_pi, out_mu, out_sigma):
    print(out)

# Slice of out_pi at index 3 along its last dimension, for the first item.
print(out_pi[0, :, 3])

# Sum of out_pi over dim 2, then the mean of those sums.
# NOTE(review): presumably out_pi holds mixture weights, so these sums are a
# quick normalisation check — confirm against MoGNet.
pi_sum = torch.sum(out_pi, dim=2, keepdim=True)
print(pi_sum)
print(pi_sum.mean())