Beispiel #1
0
class ActorCritic(torch.nn.Module):
    """Grid actor-critic built from four stacked ConvLSTM layers.

    Observations are encoded onto a grid by ``senc_NNGrid``, stacked over
    ``n_frames`` by ``FrameStack``, pushed through ``convlstm1``..``convlstm4``
    and mapped by three single-channel conv heads; ``adec_NNGrid`` then
    decodes each head's grid. Recurrent memory layout:
    ``(convhx, convcx, frames)``.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # Grid encoder and frame stacker (input side), grid decoder (output side).
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        self.adec_nngrid = adec_NNGrid(action_space, args)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        # Output channel count of each ConvLSTM layer, in order.
        channels = [32, 64, 128, 128]
        in_channels = self.frame_stack.n_frames * self.input_size[0]
        self.convlstm1 = ConvLSTM(in_channels, channels[0], 4,
                                  stride=1, padding=1)
        self.convlstm2 = ConvLSTM(channels[0], channels[1], 3,
                                  stride=1, padding=1)
        self.convlstm3 = ConvLSTM(channels[1], channels[2], 3,
                                  stride=1, padding=1)
        self.convlstm4 = ConvLSTM(channels[2], channels[3], 3,
                                  stride=1, padding=1)
        # Plain list kept for ordered iteration; the layers are registered
        # through the convlstm1..4 attributes above.
        self.convlstm = [self.convlstm1, self.convlstm2,
                         self.convlstm3, self.convlstm4]

        # Walk the stacked-input shape through every layer to obtain the
        # per-layer state shape (channels, *spatial) without the batch dim.
        shape = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.memsizes = []
        for layer, n_ch in zip(self.convlstm, channels):
            spatial = layer._spatial_size_output_given_input((1, ) + shape)
            shape = (n_ch, ) + spatial
            self.memsizes.append(copy.deepcopy(shape))

        # Single-channel conv heads over the last ConvLSTM hidden state.
        self.critic_linear = nn.Conv2d(128, 1, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 1, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 1, 3, stride=1, padding=1)

        # Small init scale for the actor heads, unit scale for the critic.
        head_scales = ((self.actor_linear, 0.01),
                       (self.actor_linear2, 0.01),
                       (self.critic_linear, 1.0))
        for head, scale in head_scales:
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        """Feed ``x`` through the ConvLSTM stack, layer by layer.

        Each layer receives the previous layer's new hidden state. The
        ``convhx``/``convcx`` lists are overwritten in place and also returned.
        """
        layer_input = x
        for idx, cell in enumerate(self.convlstm):
            h, c = cell(layer_input, (convhx[idx], convcx[idx]))
            convhx[idx] = h
            convcx[idx] = c
            layer_input = h
        return convhx, convcx

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(ob, info, (convhx, convcx, frames))``.

        Returns:
            ``(critic_out, actor_out, actor_out2, (convhx, convcx, frames))``;
            ``critic_out`` is averaged over its last dim (keepdim=True) and
            the first actor grid is passed through softsign before decoding.
        """
        ob, info, (convhx, convcx, frames) = inputs

        grid = self.senc_nngrid((ob, info))
        stacked, frames = self.frame_stack((grid, frames))

        n_batch = stacked.size(0)
        stacked = stacked.view(n_batch,
                               self.frame_stack.n_frames * self.input_size[0],
                               self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(stacked, convhx, convcx)
        features = convhx[-1]

        critic_grid = self.critic_linear(features)
        actor_grid = F.softsign(self.actor_linear(features))
        actor2_grid = self.actor_linear2(features)

        critic_out = self.adec_nngrid((critic_grid, info)).mean(-1,
                                                                keepdim=True)
        actor_out = self.adec_nngrid((actor_grid, info))
        actor_out2 = self.adec_nngrid((actor2_grid, info))
        return critic_out, actor_out, actor_out2, (convhx, convcx, frames)

    def initialize_memory(self):
        """Return zeroed ConvLSTM states (batch dim 1) plus fresh frame memory.

        Tensors are moved to CUDA when the module's first parameter is on CUDA.
        """
        on_gpu = next(self.parameters()).is_cuda

        def zeros(state_shape):
            blank = torch.zeros((1, ) + state_shape)
            return Variable(blank.cuda() if on_gpu else blank)

        hidden = [zeros(size) for size in self.memsizes]
        cell = [zeros(size) for size in self.memsizes]
        return hidden, cell, self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Rewrap every state tensor's ``.data`` so it no longer carries history."""
        old_convhx, old_convcx, old_frames = old_memory
        hidden = [Variable(h.data) for h in old_convhx]
        cell = [Variable(c.data) for c in old_convcx]
        return hidden, cell, self.frame_stack.reinitialize_memory(old_frames)
class ActorCritic(torch.nn.Module):
    """Feed-forward grid actor-critic: four Conv2d layers, no recurrent cells.

    The only memory carried between calls is the frame stack managed by
    ``FrameStack``; the heads are 2-channel conv layers whose grids are
    decoded by ``adec_NNGrid``.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # Grid encoder and frame stacker (input side), grid decoder (output side).
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        self.adec_nngrid = adec_NNGrid(action_space, args)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        stacked_channels = self.frame_stack.n_frames * self.input_size[0]
        self.conv1 = nn.Conv2d(stacked_channels, 32, 4, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        # 2-channel conv heads over the trunk's feature grid.
        self.critic_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 2, 3, stride=1, padding=1)

        # Generic init first, then rescale the trunk for leaky-ReLU.
        self.apply(weights_init)
        gain = nn.init.calculate_gain('leaky_relu')
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.weight.data.mul_(gain)

        # Small init scale for the actor heads, unit scale for the critic.
        head_scales = ((self.actor_linear, 0.01),
                       (self.actor_linear2, 0.01),
                       (self.critic_linear, 1.0))
        for head, scale in head_scales:
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)

        self.train()

    def _convforward(self, x):
        """Apply conv1..conv4, each followed by leaky-ReLU (slope 0.1)."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.1)
        return x

    def forward(self, inputs):
        """Run one step.

        Args:
            inputs: ``(ob, info, frames)``.

        Returns:
            ``(critic_out, actor_out, actor_out2, frames)``; ``critic_out`` is
            averaged over its last dim (keepdim=True) and the first actor grid
            is passed through softsign before decoding.
        """
        ob, info, frames = inputs

        grid = self.senc_nngrid((ob, info))
        grid, frames = self.frame_stack((grid, frames))

        n_batch = grid.size(0)
        grid = grid.view(n_batch,
                         self.frame_stack.n_frames * self.input_size[0],
                         self.input_size[1], self.input_size[2])
        features = self._convforward(grid)

        critic_grid = self.critic_linear(features)
        actor_grid = F.softsign(self.actor_linear(features))
        actor2_grid = self.actor_linear2(features)

        critic_out = self.adec_nngrid((critic_grid, info)).mean(-1,
                                                                keepdim=True)
        actor_out = self.adec_nngrid((actor_grid, info))
        actor_out2 = self.adec_nngrid((actor2_grid, info))
        return critic_out, actor_out, actor_out2, frames

    def initialize_memory(self):
        """Return fresh frame-stack memory (the only state this model keeps)."""
        return self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Delegate memory re-initialisation to the frame stack."""
        return self.frame_stack.reinitialize_memory(old_memory)
class ActorCritic(torch.nn.Module):
    """Grid actor-critic built from four stacked ConvLSTM layers.

    Observations are encoded onto a grid by ``senc_NNGrid``, stacked over
    ``n_frames`` by ``FrameStack``, passed through ``convlstm1``..``convlstm4``
    and mapped by three 2-channel conv heads whose grids ``adec_NNGrid``
    decodes. Recurrent memory layout: ``(convhx, convcx, frames)``.

    NOTE(review): ``convh0``/``convc0`` are zero-initialised ``nn.Parameter``s
    registered through ``nn.ParameterList``, but ``initialize_memory`` only
    reads their sizes and returns fresh zero tensors, so within this class
    they never feed into ``forward``. They still appear in ``parameters()``
    and ``state_dict()``.
    """
    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # State preprocessing
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)

        # Action postprocessing
        self.adec_nngrid = adec_NNGrid(action_space, args)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        # Output channel count of convlstm1..convlstm4 (used for memsizes).
        _s = [32, 64, 128, 128]
        self.convlstm1 = ConvLSTM(self.frame_stack.n_frames *
                                  self.input_size[0],
                                  32,
                                  4,
                                  stride=1,
                                  padding=1)
        self.convlstm2 = ConvLSTM(32, 64, 3, stride=1, padding=1)
        self.convlstm3 = ConvLSTM(64, 128, 3, stride=1, padding=1)
        self.convlstm4 = ConvLSTM(128, 128, 3, stride=1, padding=1)
        # Plain list for indexed iteration; the layers are registered as
        # submodules through the convlstm1..4 attributes (assuming ConvLSTM
        # is an nn.Module).
        self.convlstm = [
            self.convlstm1,
            self.convlstm2,
            self.convlstm3,
            self.convlstm4,
        ]
        # Propagate the stacked-input shape through each layer to obtain the
        # per-layer state shape (channels, *spatial) without the batch dim.
        # NOTE(review): assumes _spatial_size_output_given_input returns the
        # output spatial dims for a (batch, C, H, W) input - confirm in ConvLSTM.
        _is = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.convh0 = []
        self.convc0 = []
        self.memsizes = []
        for i in range(len(self.convlstm)):
            _is = self.convlstm[i]._spatial_size_output_given_input((1, ) +
                                                                    _is)
            _is = (_s[i], ) + _is
            self.memsizes.append(copy.deepcopy(_is))
            # Zero-initialised per-layer initial states with a batch dim of 1.
            self.convh0.append(
                nn.Parameter(torch.zeros((1, ) + self.memsizes[i])))
            self.convc0.append(
                nn.Parameter(torch.zeros((1, ) + self.memsizes[i])))
        # Plain Python lists do not register parameters; wrapping them in
        # ParameterList does.
        self._convh0_module = nn.ParameterList(self.convh0)
        self._convc0_module = nn.ParameterList(self.convc0)

        # 2-channel conv heads over the last ConvLSTM hidden state.
        self.critic_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 2, 3, stride=1, padding=1)

        # Small init scale for the actor heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        """Feed ``x`` through the ConvLSTM stack, layer by layer.

        Each layer receives the previous layer's new hidden state. The
        ``convhx``/``convcx`` lists are overwritten in place and also returned.
        """
        last_convhx = x
        for i in range(len(self.convlstm)):
            convhx[i], convcx[i] = self.convlstm[i](last_convhx,
                                                    (convhx[i], convcx[i]))
            last_convhx = convhx[i]
        return convhx, convcx

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(ob, info, (convhx, convcx, frames))``.

        Returns:
            ``(critic_out, actor_out, actor_out2, (convhx, convcx, frames))``;
            ``critic_out`` is averaged over its last dim (keepdim=True) and
            the first actor grid is passed through softsign before decoding.
        """
        ob, info, (convhx, convcx, frames) = inputs

        # Get the grid state from vectorized input
        x = self.senc_nngrid((ob, info))

        # Stack it
        x, frames = self.frame_stack((x, frames))

        # Resize to correct dims for convnet
        batch_size = x.size(0)
        x = x.view(batch_size, self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)
        x = convhx[-1]

        # Compute action mean, action var and value grid
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)

        # Extract motor-specific values from action grid
        critic_out = self.adec_nngrid((critic_out, info)).mean(-1,
                                                               keepdim=True)
        actor_out = self.adec_nngrid((actor_out, info))
        actor_out2 = self.adec_nngrid((actor_out2, info))
        return critic_out, actor_out, actor_out2, (convhx, convcx, frames)

    def initialize_memory(self):
        """Return zeroed ConvLSTM states plus fresh frame-stack memory.

        The state tensors have the same sizes as ``convh0``/``convc0`` and are
        moved to CUDA when the module's first parameter is on CUDA.
        """
        #print(np.sum([torch.norm(ch0).item() for ch0 in self.convh0]),
        #      np.sum([torch.norm(cc0).item() for cc0 in self.convc0]))
        use_gpu = next(self.parameters()).is_cuda
        return (
            #self.convh0,
            #self.convc0,
            # <!> DO NOT REMOVE BELOW CODE <!>
            # Below code is needed to fix a strange bug in graph backprop
            # TODO(eparisot): debug this further (low priority, might be pytorch..)
            #[ch0 for ch0 in self.convh0],
            #[cc0 for cc0 in self.convc0],
            # NOTE(review): these are new zero tensors, not the learnable
            # convh0/convc0, so no gradient reaches those parameters here.
            [
                Variable(torch.zeros(ch0.size()).cuda())
                if use_gpu else Variable(torch.zeros(ch0.size()))
                for ch0 in self.convh0
            ],
            [
                Variable(torch.zeros(cc0.size()).cuda())
                if use_gpu else Variable(torch.zeros(cc0.size()))
                for cc0 in self.convc0
            ],
            self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Rewrap each state tensor's ``.data`` so it carries no autograd history.

        Presumably used to truncate backprop between rollouts - confirm with
        the training loop.
        """
        old_convhx, old_convcx, old_frames = old_memory
        return ([Variable(chx.data) for chx in old_convhx],
                [Variable(ccx.data) for ccx in old_convcx],
                self.frame_stack.reinitialize_memory(old_frames))
Beispiel #4
0
class ActorCritic(torch.nn.Module):
    """Per-frame MLP feeding an ``LSTMCell``, over a flat observation vector.

    Each stacked frame goes through the same four Linear + LeakyReLU layers;
    the per-frame features are flattened together and fed to a 128-unit
    ``LSTMCell`` whose hidden state drives the three linear heads.
    Recurrent memory layout: ``(hx, cx, frames)``.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # The encoder learns the observation width through ``args``; note this
        # writes into the caller's ``args``. Only 1-D observation spaces work.
        args['observation_dim'] = observation_space.shape[0]
        self.senc_nngrid = senc_FlatDepthNNGrid(args)
        self.frame_stack = FrameStack(n_frames)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape[0]
        self.output_size = int(np.prod(self.action_space.shape))

        # Shared per-frame trunk: Linear layers act on the last dim, so every
        # frame is transformed by the same weights.
        self.fc1 = nn.Linear(self.input_size, 256)
        self.lrelu1 = nn.LeakyReLU(0.1)
        self.fc2 = nn.Linear(256, 256)
        self.lrelu2 = nn.LeakyReLU(0.1)
        self.fc3 = nn.Linear(256, 128)
        self.lrelu3 = nn.LeakyReLU(0.1)
        self.fc4 = nn.Linear(128, 128)
        self.lrelu4 = nn.LeakyReLU(0.1)

        # LSTM input = 128 features per stacked frame.
        self.m1 = self.frame_stack.n_frames * 128
        self.lstm = nn.LSTMCell(self.m1, 128)
        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.output_size)
        self.actor_linear2 = nn.Linear(128, self.output_size)

        # Generic init first, then rescale the trunk for leaky-ReLU.
        self.apply(weights_init_mlp)
        gain = nn.init.calculate_gain('leaky_relu')
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            fc.weight.data.mul_(gain)

        # Small init scale for the actor heads, unit scale for the critic.
        head_scales = ((self.actor_linear, 0.01),
                       (self.actor_linear2, 0.01),
                       (self.critic_linear, 1.0))
        for head, scale in head_scales:
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)

        for lstm_bias in (self.lstm.bias_ih, self.lstm.bias_hh):
            lstm_bias.data.fill_(0)

        self.train()

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(ob, info, (hx, cx, frames))``.

        Returns:
            ``(critic, softsign(actor), actor2, (hx, cx, frames))``.
        """
        ob, info, (hx, cx, frames) = inputs

        encoded = self.senc_nngrid((ob, info))
        encoded, frames = self.frame_stack((encoded, frames))

        n_batch = encoded.size(0)
        hidden = encoded.view(n_batch, self.frame_stack.n_frames,
                              self.input_size)
        trunk = ((self.fc1, self.lrelu1), (self.fc2, self.lrelu2),
                 (self.fc3, self.lrelu3), (self.fc4, self.lrelu4))
        for fc, activation in trunk:
            hidden = activation(fc(hidden))

        hx, cx = self.lstm(hidden.view(1, self.m1), (hx, cx))

        critic = self.critic_linear(hx)
        actor = F.softsign(self.actor_linear(hx))
        actor2 = self.actor_linear2(hx)
        return critic, actor, actor2, (hx, cx, frames)

    def initialize_memory(self):
        """Return zeroed (1, 128) LSTM states plus fresh frame-stack memory.

        Tensors are moved to CUDA when the module's first parameter is on CUDA.
        """
        on_gpu = next(self.parameters()).is_cuda

        def zeros():
            blank = torch.zeros(1, 128)
            return Variable(blank.cuda() if on_gpu else blank)

        return zeros(), zeros(), self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Rewrap the LSTM states' ``.data`` so they carry no autograd history."""
        old_hx, old_cx, old_frames = old_memory
        fresh_hx = Variable(old_hx.data)
        fresh_cx = Variable(old_cx.data)
        return fresh_hx, fresh_cx, self.frame_stack.reinitialize_memory(old_frames)
class ActorCritic(torch.nn.Module):
    """Grid actor-critic: three ConvLSTM layers followed by an ``LSTMCell``.

    The observation is encoded onto a grid by ``senc_NNGrid``, stacked over
    ``n_frames`` by ``FrameStack``, run through ``convlstm1``..``convlstm3``,
    flattened per sample, and fed to a 128-unit ``LSTMCell`` whose hidden
    state drives the critic and the two actor heads.

    Recurrent memory layout: ``(convhx, convcx, hx, cx, frames)`` where
    ``convhx``/``convcx`` are lists with one state tensor per ConvLSTM layer.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # State preprocessing
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        # Output channel count of each ConvLSTM layer.
        _s = [32, 64, 128]
        self.convlstm1 = ConvLSTM(self.frame_stack.n_frames * self.input_size[0],
                                  _s[0], 4, stride=2, padding=0)
        self.convlstm2 = ConvLSTM(_s[0], _s[1], 3, stride=2, padding=0)
        self.convlstm3 = ConvLSTM(_s[1], _s[2], 3, stride=1, padding=0)

        # Plain list for ordered iteration; the layers are registered through
        # the convlstm1..3 attributes above (assuming ConvLSTM is an nn.Module).
        self.convlstm = [
            self.convlstm1,
            self.convlstm2,
            self.convlstm3,
        ]

        # Propagate the stacked-input shape through every layer to obtain the
        # per-layer state shape (channels, *spatial) without the batch dim.
        _is = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.memsizes = []
        for layer, channels in zip(self.convlstm, _s):
            _is = layer._spatial_size_output_given_input((1, ) + _is)
            _is = (channels, ) + _is
            self.memsizes.append(copy.deepcopy(_is))

        # np.prod yields a NumPy integer; LSTMCell expects a plain Python int.
        self.lstm = nn.LSTMCell(int(np.prod(self.memsizes[-1])), 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.action_space.shape[0])
        self.actor_linear2 = nn.Linear(128, self.action_space.shape[0])

        # Small init scale for the actor heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        """Feed ``x`` through the ConvLSTM stack, layer by layer.

        Each layer receives the previous layer's new hidden state. The
        ``convhx``/``convcx`` lists are overwritten in place and also returned.
        """
        last_convhx = x
        for i, layer in enumerate(self.convlstm):
            convhx[i], convcx[i] = layer(last_convhx, (convhx[i], convcx[i]))
            last_convhx = convhx[i]
        return convhx, convcx

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(ob, info, (convhx, convcx, hx, cx, frames))``.

        Returns:
            ``(critic_out, actor_out, actor_out2, new_memory)`` where
            ``actor_out`` has passed through softsign and ``new_memory`` has
            the same layout as the input memory.
        """
        ob, info, (convhx, convcx, hx, cx, frames) = inputs

        # Get the grid state from vectorized input
        x = self.senc_nngrid((ob, info))

        # Stack it
        x, frames = self.frame_stack((x, frames))

        # Resize to correct dims for convnet
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)
        x = convhx[-1]

        # Flatten each sample's last ConvLSTM state for the LSTMCell.
        x = x.view(batch_size, -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        # Evaluate each head exactly once and return those results.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)

        return critic_out, actor_out, actor_out2, (convhx, convcx, hx, cx,
                                                   frames)

    def initialize_memory(self):
        """Return zeroed ConvLSTM/LSTM states (batch dim 1) and frame memory.

        Tensors are moved to CUDA when the module's first parameter is on CUDA.
        """
        use_gpu = next(self.parameters()).is_cuda

        def zeros(shape):
            blank = torch.zeros(shape)
            return Variable(blank.cuda() if use_gpu else blank)

        return (
            [zeros((1, ) + size) for size in self.memsizes],
            [zeros((1, ) + size) for size in self.memsizes],
            zeros((1, 128)),
            zeros((1, 128)),
            self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Rewrap every state tensor's ``.data`` so it carries no autograd history."""
        old_convhx, old_convcx, old_hx, old_cx, old_frames = old_memory
        return (
            [Variable(chx.data) for chx in old_convhx],
            [Variable(ccx.data) for ccx in old_convcx],
            Variable(old_hx.data),
            Variable(old_cx.data),
            self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """Conv1d trunk feeding an ``LSTMCell``, over raw flat observations.

    There is no grid encoder here: the flattened observation is frame-stacked
    and treated as a 1-D signal with ``n_frames`` channels. The conv output
    is flattened per sample and fed to a 128-unit ``LSTMCell`` whose hidden
    state drives the three linear heads. Recurrent memory layout:
    ``(hx, cx, frames)``.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # Stack preprocessing
        self.frame_stack = FrameStack(n_frames)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = int(np.prod(self.observation_space.shape))
        self.output_size = int(np.prod(self.action_space.shape))

        self.conv1 = nn.Conv1d(self.frame_stack.n_frames, 32, 3,
                               stride=1, padding=1)
        self.lrelu1 = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv1d(32, 32, 3, stride=1, padding=1)
        self.lrelu2 = nn.LeakyReLU(0.1)
        self.conv3 = nn.Conv1d(32, 64, 2, stride=1, padding=1)
        self.lrelu3 = nn.LeakyReLU(0.1)
        self.conv4 = nn.Conv1d(64, 64, 1, stride=1)
        self.lrelu4 = nn.LeakyReLU(0.1)

        # Size the LSTM input by pushing one all-zero sample through the trunk.
        probe = Variable(
            torch.zeros(1, self.frame_stack.n_frames, self.input_size))
        lstm_in_features = self._convforward(probe).nelement()

        self.lstm = nn.LSTMCell(lstm_in_features, 128)
        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.output_size)
        self.actor_linear2 = nn.Linear(128, self.output_size)

        # Generic init first, then rescale the trunk for leaky-ReLU.
        self.apply(weights_init)
        gain = nn.init.calculate_gain('leaky_relu')
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.weight.data.mul_(gain)

        # Small init scale for the actor heads, unit scale for the critic.
        head_scales = ((self.actor_linear, 0.01),
                       (self.actor_linear2, 0.01),
                       (self.critic_linear, 1.0))
        for head, scale in head_scales:
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)

        for lstm_bias in (self.lstm.bias_ih, self.lstm.bias_hh):
            lstm_bias.data.fill_(0)

        self.train()

    def _convforward(self, x):
        """Apply the four Conv1d + LeakyReLU stages in order."""
        stages = ((self.conv1, self.lrelu1), (self.conv2, self.lrelu2),
                  (self.conv3, self.lrelu3), (self.conv4, self.lrelu4))
        for conv, activation in stages:
            x = activation(conv(x))
        return x

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(x, <unused>, (hx, cx, frames))``.

        Returns:
            ``(critic, softsign(actor), actor2, (hx, cx, frames))``.
        """
        x, _, (hx, cx, frames) = inputs

        stacked, frames = self.frame_stack((x, frames))
        stacked = stacked.view(stacked.size(0), self.frame_stack.n_frames,
                               self.input_size)

        features = self._convforward(stacked)
        features = features.view(features.size(0), -1)

        hx, cx = self.lstm(features, (hx, cx))

        critic = self.critic_linear(hx)
        actor = F.softsign(self.actor_linear(hx))
        actor2 = self.actor_linear2(hx)
        return critic, actor, actor2, (hx, cx, frames)

    def initialize_memory(self):
        """Return zeroed (1, 128) LSTM states plus fresh frame-stack memory.

        Tensors are moved to CUDA when the module's first parameter is on CUDA.
        """
        on_gpu = next(self.parameters()).is_cuda

        def zeros():
            blank = torch.zeros(1, 128)
            return Variable(blank.cuda() if on_gpu else blank)

        return zeros(), zeros(), self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Rewrap the LSTM states' ``.data`` so they carry no autograd history."""
        old_hx, old_cx, old_frames = old_memory
        fresh_hx = Variable(old_hx.data)
        fresh_cx = Variable(old_cx.data)
        return fresh_hx, fresh_cx, self.frame_stack.reinitialize_memory(old_frames)
Beispiel #7
0
class ActorCritic(torch.nn.Module):
    """Grid actor-critic: three Conv2d layers followed by an ``LSTMCell``.

    The observation is encoded onto a grid by ``senc_NNGrid``, stacked over
    ``n_frames`` by ``FrameStack``, run through conv1..conv3 (leaky-ReLU),
    flattened per sample and fed to a 128-unit ``LSTMCell`` whose hidden
    state drives the critic and the two actor heads.
    Recurrent memory layout: ``(hx, cx, frames)``.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()

        # State preprocessing
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)

        self.observation_space = observation_space
        self.action_space = action_space

        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        self.conv1 = nn.Conv2d(self.frame_stack.n_frames * self.input_size[0],
                               32,
                               4,
                               stride=2,
                               padding=0)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=0)

        # Size the LSTM input by pushing one all-zero stacked observation
        # through the conv trunk.
        dummy_input = Variable(
            torch.zeros((1, n_frames) +
                        self.senc_nngrid.observation_space.shape))
        dummy_input = dummy_input.view((1, n_frames * self.input_size[0]) +
                                       self.input_size[1:])
        outconv = self._convforward(dummy_input)

        self.lstm = nn.LSTMCell(outconv.nelement(), 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.action_space.shape[0])
        self.actor_linear2 = nn.Linear(128, self.action_space.shape[0])

        # Generic init first, then rescale the trunk for leaky-ReLU.
        self.apply(weights_init)
        lrelu_gain = nn.init.calculate_gain('leaky_relu')
        self.conv1.weight.data.mul_(lrelu_gain)
        self.conv2.weight.data.mul_(lrelu_gain)
        self.conv3.weight.data.mul_(lrelu_gain)

        # Small init scale for the actor heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convforward(self, x):
        """Apply conv1..conv3, each followed by leaky-ReLU (slope 0.1)."""
        x = F.leaky_relu(self.conv1(x), 0.1)
        x = F.leaky_relu(self.conv2(x), 0.1)
        x = F.leaky_relu(self.conv3(x), 0.1)
        return x

    def forward(self, inputs):
        """Run one recurrent step.

        Args:
            inputs: ``(ob, info, (hx, cx, frames))``.

        Returns:
            ``(critic_out, actor_out, actor_out2, (hx, cx, frames))`` where
            ``actor_out`` has passed through softsign.
        """
        ob, info, (hx, cx, frames) = inputs

        # Get the grid state from vectorized input
        x = self.senc_nngrid((ob, info))

        # Stack it
        x, frames = self.frame_stack((x, frames))

        # Resize to correct dims for convnet
        batch_size = x.size(0)
        x = x.view(batch_size, self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        x = self._convforward(x)

        # Flatten each sample's conv features for the LSTMCell.
        x = x.view(batch_size, -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx

        # Evaluate each head exactly once and return those results.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)

        return critic_out, actor_out, actor_out2, (hx, cx, frames)

    def initialize_memory(self):
        """Return zeroed (1, 128) LSTM states plus fresh frame-stack memory.

        Tensors are moved to CUDA when the module's first parameter is on CUDA.
        """
        use_gpu = next(self.parameters()).is_cuda

        def zeros():
            blank = torch.zeros(1, 128)
            return Variable(blank.cuda() if use_gpu else blank)

        return zeros(), zeros(), self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Rewrap the LSTM states' ``.data`` so they carry no autograd history."""
        old_hx, old_cx, old_frames = old_memory
        return (Variable(old_hx.data), Variable(old_cx.data),
                self.frame_stack.reinitialize_memory(old_frames))