def __init__(
        self,
        image_shape,
        output_size,
        n_atoms=51,
        fc_sizes=512,
        dueling=False,
        use_maxpool=False,
        channels=None,  # None uses default.
        kernel_sizes=None,
        strides=None,
        paddings=None,
        ):
    """Instantiates the neural network according to arguments; network defaults
    stored within this method."""
    super().__init__()
    self.dueling = dueling
    c, h, w = image_shape
    self.conv = Conv2dModel(
        in_channels=c,
        channels=channels or [32, 64, 64],
        kernel_sizes=kernel_sizes or [8, 4, 3],
        strides=strides or [4, 2, 1],
        paddings=paddings or [0, 1, 1],
        use_maxpool=use_maxpool,
    )
    conv_out_size = self.conv.conv_out_size(h, w)
    if dueling:
        self.head = DistributionalDuelingHeadModel(conv_out_size, fc_sizes,
            output_size=output_size, n_atoms=n_atoms)
    else:
        self.head = DistributionalHeadModel(conv_out_size, fc_sizes,
            output_size=output_size, n_atoms=n_atoms)
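
# A minimal shape-check sketch (not part of the original source): assuming a standard
# 84x84 Atari frame, the default conv stack above flattens to a 6400-dimensional
# feature vector, which is what self.conv.conv_out_size(h, w) feeds into the
# distributional head. The hypothetical helper below just applies the usual conv
# output formula out = (in + 2*pad - kernel) // stride + 1 per layer.
def _example_conv_out_size(h=84, w=84,
                           channels=(32, 64, 64),
                           kernel_sizes=(8, 4, 3),
                           strides=(4, 2, 1),
                           paddings=(0, 1, 1)):
    """Hypothetical helper, for illustration only; mirrors the defaults above."""
    for k, s, p in zip(kernel_sizes, strides, paddings):
        h = (h + 2 * p - k) // s + 1
        w = (w + 2 * p - k) // s + 1
    # 84x84 -> 20x20 -> 10x10 -> 10x10, so 64 * 10 * 10 = 6400.
    return channels[-1] * h * w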
def __init__(
        self,
        observation_shape,
        action_size,
        linear_value_output=True,
        sequence_length=50,
        seperate_value_network=True,
        size='small',
        channels=None,  # None uses default.
        kernel_sizes=None,
        strides=None,
        paddings=None,
        ):
    super().__init__()
    self.action_size = action_size
    c, h, w = observation_shape
    self.conv = Conv2dModel(
        in_channels=c,
        channels=channels or [64, 32, 32],
        kernel_sizes=kernel_sizes or [3, 2, 2],
        strides=strides or [2, 1, 1],
        paddings=paddings or [0, 0, 0],
        use_maxpool=False,
    )
    self.conv_out_size = self.conv.conv_out_size(h, w)
    self.sequence_length = sequence_length
    self.transformer_dim = 32  # SIZES[size]['dim']
    self.depth = SIZES[size]['depth']
    self.cmem_ratio = SIZES[size]['cmem_ratio']
    self.cmem_length = self.sequence_length // self.cmem_ratio
    memory_layers = range(1, self.depth + 1)
    self.transformer = CompressiveTransformerPyTorch(
        num_tokens=20000,
        emb_dim=self.conv_out_size,
        dim=self.transformer_dim,
        heads=SIZES[size]['num_heads'],
        depth=self.depth,
        seq_len=self.sequence_length,
        mem_len=self.sequence_length,  # memory length
        reconstruction_loss_weight=1,  # weight placed on the compressed-memory reconstruction loss
        gru_gated_residual=True,  # gate the residual connection, as in the "Stabilizing Transformers for RL" paper
        memory_layers=memory_layers,
    )
    # Inputs are already continuous conv features, so bypass the compressive
    # transformer's token embedding and output projection.
    self.transformer.token_emb = torch.nn.Identity()
    self.transformer.to_logits = torch.nn.Identity()
    # self.input_layer_norm = torch.nn.LayerNorm(self.state_size)
    self.input_layer_norm = torch.nn.Identity()
    self.output_layer_norm = torch.nn.LayerNorm(self.transformer_dim)
    # self.output_layer_norm = torch.nn.Identity()
    self.softplus = torch.nn.Softplus()
    self.pi_head = MlpModel(
        input_size=self.transformer_dim,
        hidden_sizes=[256],
        output_size=action_size,
    )
    self.value_head = MlpModel(
        input_size=self.transformer_dim,
        hidden_sizes=[256],
        output_size=1 if linear_value_output else None,
    )
    self.mask = torch.ones((self.sequence_length, self.sequence_length),
                           dtype=torch.int8).triu()
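
# Illustration only (not part of the original source): a small sketch of the attention
# mask constructed above. triu() keeps the main diagonal and everything above it, so
# row i has ones from column i onward. How the mask is interpreted (which positions are
# kept vs. dropped) depends on the CompressiveTransformerPyTorch implementation; this
# hypothetical helper only shows the tensor that is built here, assuming torch is
# already imported at module level as in the surrounding code.
def _example_mask(sequence_length=4):
    """Hypothetical helper, for illustration; mirrors the mask construction above."""
    mask = torch.ones((sequence_length, sequence_length), dtype=torch.int8).triu()
    # For sequence_length=4:
    # tensor([[1, 1, 1, 1],
    #         [0, 1, 1, 1],
    #         [0, 0, 1, 1],
    #         [0, 0, 0, 1]], dtype=torch.int8)
    return mask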