Пример #1
0
    def __init__(self,
                 perception_unit,
                 use_gru=False,
                 internal_state_size=512,
                 perception_unit_class=None,
                 perception_unit_kwargs=None):
        """Recurrent actor-critic module with a pluggable perception unit.

        Args:
            perception_unit: a pre-built perception network, or None to
                construct one from ``perception_unit_class``.
            use_gru: if True, attach a GRUCell over the internal state.
            internal_state_size: width of the internal state vector
                (GRU input/hidden size and critic input size).
            perception_unit_class: name of a perception-unit factory,
                evaluated with ``eval()`` when ``perception_unit`` is None.
            perception_unit_kwargs: keyword arguments for that factory
                (defaults to an empty dict).
        """
        super(NaivelyRecurrentACModule, self).__init__()
        self._internal_state_size = internal_state_size
        # BUG FIX: the default was a mutable dict literal ({}), which is
        # shared across all calls; build a fresh dict per call instead.
        if perception_unit_kwargs is None:
            perception_unit_kwargs = {}

        if use_gru:
            self.gru = nn.GRUCell(input_size=internal_state_size,
                                  hidden_size=internal_state_size)

        if perception_unit is None:
            # SECURITY NOTE(review): eval() on a configuration-supplied
            # class name — only safe with fully trusted config input.
            self.perception_unit = eval(perception_unit_class)(
                **perception_unit_kwargs)
        else:
            self.perception_unit = perception_unit

        # Critic head: orthogonal init, zero bias, single linear output.
        init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.
                               constant_(x, 0))
        self.critic_linear = init_(nn.Linear(internal_state_size, 1))
Пример #2
0
def atari_small_conv(num_inputs):
    """Small two-layer Atari conv stack: num_inputs -> 32 -> 64 channels."""
    def init_(module):
        # Orthogonal weights, zero bias, ReLU gain.
        return init(module, nn.init.orthogonal_,
                    lambda x: nn.init.constant_(x, 0),
                    nn.init.calculate_gain('relu'))

    layers = [
        init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(32, 64, 4, stride=2)),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
Пример #3
0
def atari_nature_vae(num_inputs, num_outputs=512):
    """Nature-DQN-style conv encoder ending in a linear embedding.

    Args:
        num_inputs: number of input image channels.
        num_outputs: size of the final linear feature vector (default 512).

    Returns:
        nn.Sequential mapping images to a ``num_outputs``-dim feature
        (the 32 * 7 * 7 flatten implies 84x84 inputs — TODO confirm).
    """
    init_ = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(
        x, 0), nn.init.calculate_gain('relu'))

    # BUG FIX: the constructed module was previously discarded and the
    # function implicitly returned None; return it like the sibling
    # atari_* factory functions do.
    return nn.Sequential(init_(nn.Conv2d(num_inputs, 32, 8, stride=4)),
                         nn.ReLU(),
                         init_(nn.Conv2d(32, 64, 4, stride=2)), nn.ReLU(),
                         init_(nn.Conv2d(64, 32, 3, stride=1)), nn.ReLU(),
                         Flatten(),
                         init_(nn.Linear(32 * 7 * 7, num_outputs)), nn.ReLU())
Пример #4
0
def atari_match_conv(num_frames, num_inputs_per_frame):
    """Conv stack over stacked frames.

    Expected input size: (3*N, 84, 84); expected output size: (8*N, 16, 16).
    """
    total_channels = num_frames * num_inputs_per_frame

    def init_(module):
        # Orthogonal weights, zero bias, ReLU gain.
        return init(module, nn.init.orthogonal_,
                    lambda x: nn.init.constant_(x, 0),
                    nn.init.calculate_gain('relu'))

    return nn.Sequential(
        init_(nn.Conv2d(total_channels, 64, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 8 * num_frames, 5, stride=1)),
        nn.ReLU(),
    )
Пример #5
0
def atari_big_conv(num_frames, num_inputs_per_frame):
    """Three-layer conv stack for large inputs.

    Expected input size: (3*N, 256, 256); expected output size: (8*N, 16, 16).
    TODO get the dimenions here working
    """
    total_channels = num_frames * num_inputs_per_frame

    def init_(module):
        # Orthogonal weights, zero bias, ReLU gain.
        return init(module, nn.init.orthogonal_,
                    lambda x: nn.init.constant_(x, 0),
                    nn.init.calculate_gain('relu'))

    layers = [
        init_(nn.Conv2d(total_channels, 64, 8, stride=4)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 64, 5, stride=1)),
        nn.ReLU(),
        init_(nn.Conv2d(64, 8 * num_frames, 5, stride=1)),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)
Пример #6
0
    def __init__(self, img_channels, latent_size):
        """Conv encoder: three conv layers, flatten, then a linear head.

        Args:
            img_channels: number of input image channels.
            latent_size: width of the output feature vector
                (the 32 * 7 * 7 flatten implies 84x84 inputs — TODO confirm).
        """
        super().__init__()
        self.latent_size = latent_size
        self.img_channels = img_channels

        def init_(module):
            # Orthogonal weights, zero bias, ReLU gain.
            return init(module, nn.init.orthogonal_,
                        lambda x: nn.init.constant_(x, 0),
                        nn.init.calculate_gain('relu'))

        self.conv1 = init_(nn.Conv2d(img_channels, 32, 8, stride=4))
        self.conv2 = init_(nn.Conv2d(32, 64, 4, stride=2))
        self.conv3 = init_(nn.Conv2d(64, 32, 3, stride=1))
        self.flatten = Flatten()
        self.fc1 = init_(nn.Linear(32 * 7 * 7, latent_size))
import multiprocessing
import numpy as np
import os
import torch
import torch
import torch.nn as nn
from torch.nn import Parameter, ModuleList
import torch.nn.functional as F

from evkit.rl.utils import init, init_normc_
from evkit.utils.misc import is_cuda
from evkit.preprocess import transforms

import pickle as pkl

# Module-level default initializer: orthogonal weights, zero bias, ReLU gain.
def init_(m):
    return init(m, nn.init.orthogonal_,
                lambda x: nn.init.constant_(x, 0),
                nn.init.calculate_gain('relu'))


################################
# Inverse Models
#   Predict  s_{t+1} | s_t, a_t
################################
class ForwardModel(nn.Module):
    def __init__(self, state_shape, action_shape, hidden_size):
        super().__init__()
        self.fc1 = init_(nn.Linear(state_shape + action_shape[1], hidden_size))
        self.fc2 = init_(nn.Linear(hidden_size, state_shape))

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = F.relu(self.fc1(x))