class ActorCritic(torch.nn.Module):
    """Recurrent actor-critic built from a stack of four ConvLSTM layers.

    The vector observation is encoded onto a spatial grid (senc_NNGrid),
    stacked over ``n_frames`` steps, run through the ConvLSTM stack, and the
    final hidden map is decoded back into per-motor critic / action-mean /
    action-variance values via adec_NNGrid.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing: vector observation -> spatial grid.
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        # Action postprocessing: output grids -> per-motor values.
        self.adec_nngrid = adec_NNGrid(action_space, args)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        channels = [32, 64, 128, 128]
        self.convlstm1 = ConvLSTM(
            self.frame_stack.n_frames * self.input_size[0], 32, 4,
            stride=1, padding=1)
        self.convlstm2 = ConvLSTM(32, 64, 3, stride=1, padding=1)
        self.convlstm3 = ConvLSTM(64, 128, 3, stride=1, padding=1)
        self.convlstm4 = ConvLSTM(128, 128, 3, stride=1, padding=1)
        self.convlstm = [
            self.convlstm1,
            self.convlstm2,
            self.convlstm3,
            self.convlstm4,
        ]

        # Track each layer's recurrent-state shape so hidden/cell tensors
        # can be allocated in initialize_memory().
        shape = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.memsizes = []
        for i, layer in enumerate(self.convlstm):
            shape = layer._spatial_size_output_given_input((1, ) + shape)
            shape = (channels[i], ) + shape
            self.memsizes.append(copy.deepcopy(shape))

        # Per-cell output heads (one channel each).
        self.critic_linear = nn.Conv2d(128, 1, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 1, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 1, 3, stride=1, padding=1)

        # Small-scale init for the policy heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        # One time step through the ConvLSTM stack; states updated in place.
        feat = x
        for i in range(len(self.convlstm)):
            convhx[i], convcx[i] = self.convlstm[i](feat,
                                                    (convhx[i], convcx[i]))
            feat = convhx[i]
        return convhx, convcx

    def forward(self, inputs):
        ob, info, (convhx, convcx, frames) = inputs
        # Encode the vectorized observation onto the grid and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        # Collapse the frame dimension into channels for the conv stack.
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)
        x = convhx[-1]
        # Per-cell value / action-mean / action-variance grids.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)
        # Extract motor-specific values from the grids; the critic is
        # averaged over motors into a single value.
        critic_out = self.adec_nngrid((critic_out, info)).mean(-1, keepdim=True)
        actor_out = self.adec_nngrid((actor_out, info))
        actor_out2 = self.adec_nngrid((actor_out2, info))
        return critic_out, actor_out, actor_out2, (convhx, convcx, frames)

    def initialize_memory(self):
        """Fresh zero hidden/cell states (on GPU if the model is there)."""
        use_gpu = next(self.parameters()).is_cuda

        def _zeros(size):
            t = torch.zeros((1, ) + size)
            return Variable(t.cuda()) if use_gpu else Variable(t)

        return ([_zeros(s) for s in self.memsizes],
                [_zeros(s) for s in self.memsizes],
                self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_convhx, old_convcx, old_frames = old_memory
        return ([Variable(chx.data) for chx in old_convhx],
                [Variable(ccx.data) for ccx in old_convcx],
                self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """Feed-forward convolutional actor-critic over a grid-encoded state.

    Same encode/decode pipeline as the recurrent variants, but the trunk is
    four plain Conv2d layers with leaky-ReLU activations; the only memory
    carried between steps is the frame stack.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing: vector observation -> spatial grid.
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        # Action postprocessing: output grids -> per-motor values.
        self.adec_nngrid = adec_NNGrid(action_space, args)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        self.conv1 = nn.Conv2d(
            self.frame_stack.n_frames * self.input_size[0], 32, 4,
            stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, stride=1, padding=1)

        # Per-cell output heads (two channels each).
        self.critic_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 2, 3, stride=1, padding=1)

        # Generic init, then rescale trunk weights by the leaky-ReLU gain.
        self.apply(weights_init)
        lrelu_gain = nn.init.calculate_gain('leaky_relu')
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.weight.data.mul_(lrelu_gain)

        # Small-scale init for the policy heads, unit scale for the critic.
        for head, scale in ((self.actor_linear, 0.01),
                            (self.actor_linear2, 0.01),
                            (self.critic_linear, 1.0)):
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)

        self.train()

    def _convforward(self, x):
        # Four conv layers, each followed by a leaky ReLU (slope 0.1).
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.leaky_relu(conv(x), 0.1)
        return x

    def forward(self, inputs):
        ob, info, frames = inputs
        # Encode the vectorized observation onto the grid and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        # Collapse the frame dimension into channels for the conv stack.
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        x = self._convforward(x)
        # Per-cell value / action-mean / action-variance grids.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)
        # Extract motor-specific values from the grids; the critic is
        # averaged over motors into a single value.
        critic_out = self.adec_nngrid((critic_out, info)).mean(-1, keepdim=True)
        actor_out = self.adec_nngrid((actor_out, info))
        actor_out2 = self.adec_nngrid((actor_out2, info))
        return critic_out, actor_out, actor_out2, frames

    def initialize_memory(self):
        """Only the frame stack carries memory for this model."""
        return self.frame_stack.initialize_memory()

    def reinitialize_memory(self, old_memory):
        """Detach the frame-stack memory from the previous graph."""
        return self.frame_stack.reinitialize_memory(old_memory)
class ActorCritic(torch.nn.Module):
    """ConvLSTM actor-critic with learnable (but currently bypassed) initial states.

    Same architecture as the plain ConvLSTM variant, except per-layer initial
    hidden/cell states are registered as parameters (convh0/convc0).  Because
    of a backprop bug they are not actually handed out by
    ``initialize_memory`` — see the note there.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing: vector observation -> spatial grid.
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        # Action postprocessing: output grids -> per-motor values.
        self.adec_nngrid = adec_NNGrid(action_space, args)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        channels = [32, 64, 128, 128]
        self.convlstm1 = ConvLSTM(
            self.frame_stack.n_frames * self.input_size[0], 32, 4,
            stride=1, padding=1)
        self.convlstm2 = ConvLSTM(32, 64, 3, stride=1, padding=1)
        self.convlstm3 = ConvLSTM(64, 128, 3, stride=1, padding=1)
        self.convlstm4 = ConvLSTM(128, 128, 3, stride=1, padding=1)
        self.convlstm = [
            self.convlstm1,
            self.convlstm2,
            self.convlstm3,
            self.convlstm4,
        ]

        # Track each layer's recurrent-state shape and create learnable
        # initial hidden/cell states of that shape.
        shape = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.convh0 = []
        self.convc0 = []
        self.memsizes = []
        for i, layer in enumerate(self.convlstm):
            shape = layer._spatial_size_output_given_input((1, ) + shape)
            shape = (channels[i], ) + shape
            self.memsizes.append(copy.deepcopy(shape))
            self.convh0.append(
                nn.Parameter(torch.zeros((1, ) + self.memsizes[i])))
            self.convc0.append(
                nn.Parameter(torch.zeros((1, ) + self.memsizes[i])))
        # ParameterList registration so the initial states are tracked by
        # the module (optimizer, .cuda(), state_dict, ...).
        self._convh0_module = nn.ParameterList(self.convh0)
        self._convc0_module = nn.ParameterList(self.convc0)

        # Per-cell output heads (two channels each).
        self.critic_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear = nn.Conv2d(128, 2, 3, stride=1, padding=1)
        self.actor_linear2 = nn.Conv2d(128, 2, 3, stride=1, padding=1)

        # Small-scale init for the policy heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        # One time step through the ConvLSTM stack; states updated in place.
        feat = x
        for i in range(len(self.convlstm)):
            convhx[i], convcx[i] = self.convlstm[i](feat,
                                                    (convhx[i], convcx[i]))
            feat = convhx[i]
        return convhx, convcx

    def forward(self, inputs):
        ob, info, (convhx, convcx, frames) = inputs
        # Encode the vectorized observation onto the grid and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        # Collapse the frame dimension into channels for the conv stack.
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)
        x = convhx[-1]
        # Per-cell value / action-mean / action-variance grids.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)
        # Extract motor-specific values from the grids; the critic is
        # averaged over motors into a single value.
        critic_out = self.adec_nngrid((critic_out, info)).mean(-1, keepdim=True)
        actor_out = self.adec_nngrid((actor_out, info))
        actor_out2 = self.adec_nngrid((actor_out2, info))
        return critic_out, actor_out, actor_out2, (convhx, convcx, frames)

    def initialize_memory(self):
        """Return fresh zero hidden/cell states plus frame-stack memory.

        NOTE: the learnable self.convh0/self.convc0 parameters are
        deliberately NOT returned here.  Handing them to the graph (directly
        or via a shallow list copy) triggered a strange bug in graph
        backprop, so fresh zero Variables of the same sizes are used instead.
        <!> DO NOT "simplify" this back to returning the parameters <!>
        TODO(eparisot): debug this further (low priority, might be pytorch..)
        """
        use_gpu = next(self.parameters()).is_cuda

        def _zeros_like_param(p):
            t = torch.zeros(p.size())
            return Variable(t.cuda()) if use_gpu else Variable(t)

        return ([_zeros_like_param(p) for p in self.convh0],
                [_zeros_like_param(p) for p in self.convc0],
                self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_convhx, old_convcx, old_frames = old_memory
        return ([Variable(chx.data) for chx in old_convhx],
                [Variable(ccx.data) for ccx in old_convcx],
                self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """MLP + LSTM actor-critic for flat (1-D) observation spaces."""

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing.
        # Note: only works for 1d observation spaces.
        args['observation_dim'] = observation_space.shape[0]
        self.senc_nngrid = senc_FlatDepthNNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape[0]
        self.output_size = int(np.prod(self.action_space.shape))

        # Four-layer MLP trunk with leaky-ReLU activations.
        self.fc1 = nn.Linear(self.input_size, 256)
        self.lrelu1 = nn.LeakyReLU(0.1)
        self.fc2 = nn.Linear(256, 256)
        self.lrelu2 = nn.LeakyReLU(0.1)
        self.fc3 = nn.Linear(256, 128)
        self.lrelu3 = nn.LeakyReLU(0.1)
        self.fc4 = nn.Linear(128, 128)
        self.lrelu4 = nn.LeakyReLU(0.1)

        # All stacked frames are concatenated into one LSTM input.
        self.m1 = self.frame_stack.n_frames * 128
        self.lstm = nn.LSTMCell(self.m1, 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.output_size)
        self.actor_linear2 = nn.Linear(128, self.output_size)

        # Generic MLP init, then rescale trunk weights by the leaky-ReLU gain.
        self.apply(weights_init_mlp)
        gain = nn.init.calculate_gain('leaky_relu')
        for fc in (self.fc1, self.fc2, self.fc3, self.fc4):
            fc.weight.data.mul_(gain)

        # Small-scale init for the policy heads, unit scale for the critic.
        for head, scale in ((self.actor_linear, 0.01),
                            (self.actor_linear2, 0.01),
                            (self.critic_linear, 1.0)):
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)
        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def forward(self, inputs):
        ob, info, (hx, cx, frames) = inputs
        # Encode the vectorized observation and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        batch_size = x.size(0)
        x = x.view(batch_size, self.frame_stack.n_frames, self.input_size)
        # Shared MLP trunk, applied per frame.
        x = self.lrelu1(self.fc1(x))
        x = self.lrelu2(self.fc2(x))
        x = self.lrelu3(self.fc3(x))
        x = self.lrelu4(self.fc4(x))
        # Flatten all frames into one LSTM input (assumes batch size 1 —
        # consistent with the other variants in this file).
        x = x.view(1, self.m1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        return (self.critic_linear(x),
                F.softsign(self.actor_linear(x)),
                self.actor_linear2(x),
                (hx, cx, frames))

    def initialize_memory(self):
        """Fresh zero LSTM state (on GPU if the model is there)."""
        use_gpu = next(self.parameters()).is_cuda

        def _zeros():
            t = torch.zeros(1, 128)
            return Variable(t.cuda()) if use_gpu else Variable(t)

        return (_zeros(), _zeros(), self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_hx, old_cx, old_frames = old_memory
        return (Variable(old_hx.data), Variable(old_cx.data),
                self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """ConvLSTM encoder followed by an LSTM cell and linear actor-critic heads.

    The grid-encoded, frame-stacked observation runs through three ConvLSTM
    layers; the final hidden map is flattened into an LSTMCell whose output
    feeds linear value / action-mean / action-variance heads.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing: vector observation -> spatial grid.
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        _s = [32, 64, 128]
        self.convlstm1 = ConvLSTM(
            self.frame_stack.n_frames * self.input_size[0], _s[0], 4,
            stride=2, padding=0)
        self.convlstm2 = ConvLSTM(_s[0], _s[1], 3, stride=2, padding=0)
        self.convlstm3 = ConvLSTM(_s[1], _s[2], 3, stride=1, padding=0)
        self.convlstm = [
            self.convlstm1,
            self.convlstm2,
            self.convlstm3,
        ]

        # Track each layer's recurrent-state shape for initialize_memory().
        _is = (n_frames * self.input_size[0], ) + self.input_size[1:]
        self.memsizes = []
        for i in range(len(self.convlstm)):
            _is = self.convlstm[i]._spatial_size_output_given_input((1, ) + _is)
            _is = (_s[i], ) + _is
            self.memsizes.append(copy.deepcopy(_is))

        self.lstm = nn.LSTMCell(np.prod(self.memsizes[-1]), 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.action_space.shape[0])
        self.actor_linear2 = nn.Linear(128, self.action_space.shape[0])

        # Small-scale init for the policy heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convlstmforward(self, x, convhx, convcx):
        # One time step through the ConvLSTM stack; states updated in place.
        last_convhx = x
        for i in range(len(self.convlstm)):
            convhx[i], convcx[i] = self.convlstm[i](last_convhx,
                                                    (convhx[i], convcx[i]))
            last_convhx = convhx[i]
        return convhx, convcx

    def forward(self, inputs):
        ob, info, (convhx, convcx, hx, cx, frames) = inputs
        # Encode the vectorized observation onto the grid and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        # Collapse the frame dimension into channels for the conv stack.
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        convhx, convcx = self._convlstmforward(x, convhx, convcx)
        x = convhx[-1]
        x = x.view(1, -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        # BUGFIX: the heads were previously evaluated twice — once into
        # these locals and then again inside the return statement — doubling
        # head compute per step.  Each head is now evaluated exactly once.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)
        return (critic_out, actor_out, actor_out2,
                (convhx, convcx, hx, cx, frames))

    def initialize_memory(self):
        """Fresh zero ConvLSTM and LSTM states (on GPU if the model is there)."""
        use_gpu = next(self.parameters()).is_cuda

        def _zeros(size):
            t = torch.zeros(size)
            return Variable(t.cuda()) if use_gpu else Variable(t)

        return ([_zeros((1, ) + s) for s in self.memsizes],
                [_zeros((1, ) + s) for s in self.memsizes],
                _zeros((1, 128)),
                _zeros((1, 128)),
                self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_convhx, old_convcx, old_hx, old_cx, old_frames = old_memory
        return ([Variable(chx.data) for chx in old_convhx],
                [Variable(ccx.data) for ccx in old_convcx],
                Variable(old_hx.data),
                Variable(old_cx.data),
                self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """1-D convolutional actor-critic with an LSTM cell over stacked frames.

    Works directly on the flat observation vector (no grid encoder): frames
    are stacked along the channel dimension, run through four Conv1d layers,
    flattened into an LSTMCell, and fed to linear heads.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # Stack preprocessing.
        self.frame_stack = FrameStack(n_frames)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = int(np.prod(self.observation_space.shape))
        self.output_size = int(np.prod(self.action_space.shape))

        self.conv1 = nn.Conv1d(self.frame_stack.n_frames, 32, 3,
                               stride=1, padding=1)
        self.lrelu1 = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv1d(32, 32, 3, stride=1, padding=1)
        self.lrelu2 = nn.LeakyReLU(0.1)
        self.conv3 = nn.Conv1d(32, 64, 2, stride=1, padding=1)
        self.lrelu3 = nn.LeakyReLU(0.1)
        self.conv4 = nn.Conv1d(64, 64, 1, stride=1)
        self.lrelu4 = nn.LeakyReLU(0.1)

        # Probe the conv stack with a dummy input to size the LSTM input.
        dummy_input = Variable(
            torch.zeros(1, self.frame_stack.n_frames, self.input_size))
        dummy_conv_output = self._convforward(dummy_input)
        self.lstm = nn.LSTMCell(dummy_conv_output.nelement(), 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.output_size)
        self.actor_linear2 = nn.Linear(128, self.output_size)

        # Generic init, then rescale conv weights by the leaky-ReLU gain.
        self.apply(weights_init)
        lrelu_gain = nn.init.calculate_gain('leaky_relu')
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            conv.weight.data.mul_(lrelu_gain)

        # Small-scale init for the policy heads, unit scale for the critic.
        for head, scale in ((self.actor_linear, 0.01),
                            (self.actor_linear2, 0.01),
                            (self.critic_linear, 1.0)):
            head.weight.data = norm_col_init(head.weight.data, scale)
            head.bias.data.fill_(0)
        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

        self.train()

    def _convforward(self, x):
        # Four Conv1d layers, each followed by a leaky ReLU module.
        x = self.lrelu1(self.conv1(x))
        x = self.lrelu2(self.conv2(x))
        x = self.lrelu3(self.conv3(x))
        x = self.lrelu4(self.conv4(x))
        return x

    def forward(self, inputs):
        x, _, (hx, cx, frames) = inputs
        # Stack the raw observation over time.
        x, frames = self.frame_stack((x, frames))
        batch_size = x.size(0)
        x = x.view(batch_size, self.frame_stack.n_frames, self.input_size)
        x = self._convforward(x)
        x = x.view(x.size(0), -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        return (self.critic_linear(x),
                F.softsign(self.actor_linear(x)),
                self.actor_linear2(x),
                (hx, cx, frames))

    def initialize_memory(self):
        """Fresh zero LSTM state (on GPU if the model is there)."""
        use_gpu = next(self.parameters()).is_cuda

        def _zeros():
            t = torch.zeros(1, 128)
            return Variable(t.cuda()) if use_gpu else Variable(t)

        return (_zeros(), _zeros(), self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_hx, old_cx, old_frames = old_memory
        return (Variable(old_hx.data), Variable(old_cx.data),
                self.frame_stack.reinitialize_memory(old_frames))
class ActorCritic(torch.nn.Module):
    """Conv2d encoder + LSTM actor-critic over a grid-encoded state.

    The grid-encoded, frame-stacked observation runs through three Conv2d
    layers with leaky-ReLU activations, is flattened into an LSTMCell, and
    the LSTM output feeds linear value / action-mean / action-variance heads.
    """

    def __init__(self, observation_space, action_space, n_frames, args):
        super(ActorCritic, self).__init__()
        # State preprocessing: vector observation -> spatial grid.
        self.senc_nngrid = senc_NNGrid(args)
        self.frame_stack = FrameStack(n_frames)
        self.observation_space = observation_space
        self.action_space = action_space
        self.input_size = self.senc_nngrid.observation_space.shape
        self.output_size = int(np.prod(self.action_space.shape))

        self.conv1 = nn.Conv2d(self.frame_stack.n_frames * self.input_size[0],
                               32, 4, stride=2, padding=0)
        self.conv2 = nn.Conv2d(32, 64, 3, stride=2, padding=0)
        self.conv3 = nn.Conv2d(64, 128, 3, stride=1, padding=0)

        # Probe the conv stack with a dummy input to size the LSTM input.
        dummy_input = Variable(
            torch.zeros((
                1,
                n_frames,
            ) + self.senc_nngrid.observation_space.shape))
        dummy_input = dummy_input.view((
            1,
            n_frames * self.input_size[0],
        ) + self.input_size[1:])
        outconv = self._convforward(dummy_input)
        self.lstm = nn.LSTMCell(outconv.nelement(), 128)

        self.critic_linear = nn.Linear(128, 1)
        self.actor_linear = nn.Linear(128, self.action_space.shape[0])
        self.actor_linear2 = nn.Linear(128, self.action_space.shape[0])

        # Generic init, then rescale conv weights by the leaky-ReLU gain.
        self.apply(weights_init)
        lrelu_gain = nn.init.calculate_gain('leaky_relu')
        self.conv1.weight.data.mul_(lrelu_gain)
        self.conv2.weight.data.mul_(lrelu_gain)
        self.conv3.weight.data.mul_(lrelu_gain)

        # Small-scale init for the policy heads, unit scale for the critic.
        self.actor_linear.weight.data = norm_col_init(
            self.actor_linear.weight.data, 0.01)
        self.actor_linear.bias.data.fill_(0)
        self.actor_linear2.weight.data = norm_col_init(
            self.actor_linear2.weight.data, 0.01)
        self.actor_linear2.bias.data.fill_(0)
        self.critic_linear.weight.data = norm_col_init(
            self.critic_linear.weight.data, 1.0)
        self.critic_linear.bias.data.fill_(0)

        self.train()

    def _convforward(self, x):
        # Three conv layers, each followed by a leaky ReLU (slope 0.1).
        x = F.leaky_relu(self.conv1(x), 0.1)
        x = F.leaky_relu(self.conv2(x), 0.1)
        x = F.leaky_relu(self.conv3(x), 0.1)
        return x

    def forward(self, inputs):
        ob, info, (hx, cx, frames) = inputs
        # Encode the vectorized observation onto the grid and stack frames.
        x = self.senc_nngrid((ob, info))
        x, frames = self.frame_stack((x, frames))
        # Collapse the frame dimension into channels for the conv stack.
        batch_size = x.size(0)
        x = x.view(batch_size,
                   self.frame_stack.n_frames * self.input_size[0],
                   self.input_size[1], self.input_size[2])
        x = self._convforward(x)
        x = x.view(1, -1)
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        # BUGFIX: the heads were previously evaluated twice — once into
        # these locals and then again inside the return statement — doubling
        # head compute per step.  Each head is now evaluated exactly once.
        critic_out = self.critic_linear(x)
        actor_out = F.softsign(self.actor_linear(x))
        actor_out2 = self.actor_linear2(x)
        return critic_out, actor_out, actor_out2, (hx, cx, frames)

    def initialize_memory(self):
        """Fresh zero LSTM state (on GPU if the model is there)."""
        if next(self.parameters()).is_cuda:
            return (Variable(torch.zeros(1, 128).cuda()),
                    Variable(torch.zeros(1, 128).cuda()),
                    self.frame_stack.initialize_memory())
        return (Variable(torch.zeros(1, 128)),
                Variable(torch.zeros(1, 128)),
                self.frame_stack.initialize_memory())

    def reinitialize_memory(self, old_memory):
        """Detach the recurrent state from the previous backprop graph."""
        old_hx, old_cx, old_frames = old_memory
        return (Variable(old_hx.data), Variable(old_cx.data),
                self.frame_stack.reinitialize_memory(old_frames))