def __init__(
        self,
        output_size: int,
        device: str,
        hidden_sizes: List[int],
        input_size: int = None,
        conv: Conv = None,
):
    """Actor-critic with an MLP trunk and separate policy/value heads.

    Args:
        output_size: number of policy outputs (e.g. discrete actions).
        device: device identifier string stored for later use by the caller.
        hidden_sizes: widths of the fully-connected trunk layers; must be
            non-empty.
        input_size: flat feature size fed to the trunk. May be omitted when
            ``conv`` is given, in which case ``conv.output_size`` is used.
        conv: optional convolutional feature extractor providing
            ``output_size``.

    Raises:
        ValueError: if neither ``input_size`` nor ``conv`` is provided, or
            if ``hidden_sizes`` is empty.
    """
    super().__init__()
    self.output_size = output_size
    self.device = device
    self.conv = conv
    if conv is not None:
        input_size = conv.output_size
    # Explicit errors instead of ``assert`` so validation survives ``python -O``.
    if input_size is None:
        raise ValueError('either input_size or conv must be provided')
    if not hidden_sizes:
        raise ValueError('hidden_sizes must contain at least one layer')

    def with_relu(m):
        # Orthogonal init scaled for ReLU, followed by the ReLU itself.
        return nn.Sequential(init_ortho(m, 'relu'), nn.ReLU())

    # Chain Linear+ReLU layers: input_size -> hidden_sizes[0] -> ... -> hidden_sizes[-1].
    self.fc = nn.Sequential(*[
        with_relu(nn.Linear(s_in, s_out))
        for s_in, s_out in zip([input_size] + hidden_sizes[:-1], hidden_sizes)])
    input_size = hidden_sizes[-1]
    # Small gain (.01) on the policy head keeps initial logits near-uniform.
    self.pi = init_ortho(nn.Linear(input_size, output_size), .01)
    self.val = init_ortho(nn.Linear(input_size, 1))
def __init__(self, emb_size, num_layer=64):
    """Two-stage convolutional embedder for 84x84 single-channel frames.

    Spatial sizes along the pipeline: 84x84 -> 20x20 -> 9x9 -> 7x7.
    ``block1`` ends on a pre-activation conv so the caller can tap the raw
    feature map; ``block2`` applies the remaining activations and projects
    the flattened 7x7 map down to ``emb_size``.
    """
    super().__init__()
    half = num_layer // 2
    self.block1 = nn.Sequential(
        init_ortho(nn.Conv2d(1, half, 8, 4), 'relu'),
        nn.ReLU(),
        init_ortho(nn.Conv2d(half, num_layer, 4, 2), 'relu'))
    self.block2 = nn.Sequential(
        nn.ReLU(),
        init_ortho(nn.Conv2d(num_layer, num_layer, 3, 1), 'relu'),
        nn.ReLU(),
        Flatten(),
        init_ortho(nn.Linear(num_layer * 7 * 7, emb_size)))
def __init__(
        self,
        output_size: int,
        device: str,
        emb_size: int = 32,
        history_size: int = 64,
        emb_hidden_size: int = None,
        input_size: int = 4,
        hidden_size: int = 512,
):
    """Actor-critic over stacked frames plus an auxiliary embedding stream.

    A Nature-DQN-style CNN (assumes 84x84 inputs, giving a 7x7x64 feature
    map — TODO confirm input resolution with the caller) encodes the frame
    stack; a second stream flattens (and optionally projects) an
    ``emb_size x history_size`` embedding history. Both are concatenated
    and passed through a shared FC layer before the policy/value heads.

    Args:
        output_size: number of policy outputs (e.g. discrete actions).
        device: device identifier string stored for later use by the caller.
        emb_size: per-step embedding width of the auxiliary stream.
        history_size: number of embedding steps in the history.
        emb_hidden_size: if given, project the flattened history to this
            width through a hidden layer; otherwise use it flattened as-is.
        input_size: number of input channels (stacked frames).
        hidden_size: width of the shared fully-connected layer.
    """
    # Plain super() — consistent with the other modules in this file.
    super().__init__()
    self.output_size = output_size
    self.device = device

    def with_relu(m):
        # Orthogonal init scaled for ReLU, followed by the ReLU itself.
        return nn.Sequential(init_ortho(m, 'relu'), nn.ReLU())

    # 84x84 -> 20x20 -> 9x9 -> 7x7 (standard DQN conv stack).
    self.conv = nn.Sequential(
        with_relu(nn.Conv2d(input_size, 32, 8, 4)),
        with_relu(nn.Conv2d(32, 64, 4, 2)),
        with_relu(nn.Conv2d(64, 64, 3, 1)),
        Flatten())
    conv_output = 64 * 7 * 7
    if emb_hidden_size is not None:
        self.emb_fc = nn.Sequential(
            Flatten(),
            with_relu(nn.Linear(emb_size * history_size, emb_hidden_size)))
        self.emb_output = emb_hidden_size
    else:
        # No projection: just flatten the embedding history.
        self.emb_fc = Flatten()
        self.emb_output = emb_size * history_size
    # Fuse conv features with the embedding stream.
    self.fc = with_relu(
        nn.Linear(conv_output + self.emb_output, hidden_size))
    # Small gain (.01) on the policy head keeps initial logits near-uniform.
    self.pi = init_ortho(nn.Linear(hidden_size, output_size), .01)
    self.val = init_ortho(nn.Linear(hidden_size, 1))
def __init__(self, emb_size, num_heads):
    # Multi-head softmax embedder over single-channel frames.
    super().__init__()
    self.emb_size = emb_size
    # 84 x 84 -> 20 x 20 -> 9 x 9 -> 7 x 7
    self.base = nn.Sequential(
        init_ortho(nn.Conv2d(1, 32, 8, 4), 'relu'), nn.ReLU(),
        init_ortho(nn.Conv2d(32, 64, 4, 2), 'relu'), nn.ReLU(),
        init_ortho(nn.Conv2d(64, 64, 3, 1), 'relu'), nn.ReLU(),
        Flatten())
    # Head 0 is 5x wider than the others — presumably a privileged "main"
    # head with extra capacity; TODO confirm intent with the caller.
    # NOTE(review): unlike self.base, the heads are not init_ortho'd —
    # looks deliberate (softmax outputs), but verify.
    self.heads = [nn.Sequential(
        nn.Linear(64 * 7 * 7, emb_size * (5 if i == 0 else 1)),
        nn.Softmax(-1)) for i in range(num_heads)]
    # A plain Python list is invisible to nn.Module, so each head is also
    # attached via setattr as `heads_{i}` to register its parameters
    # (state_dict, .parameters(), .to(device)). Switching to nn.ModuleList
    # would rename state_dict keys and break existing checkpoints.
    for i, h in enumerate(self.heads):
        setattr(self, f'heads_{i}', h)
    # Selected/main head; populated elsewhere — None until then.
    self.head_main = None
def with_relu(m):
    """Orthogonally initialize *m* with ReLU gain and chain a ReLU after it."""
    initialized = init_ortho(m, 'relu')
    activation = nn.ReLU()
    return nn.Sequential(initialized, activation)