예제 #1
0
파일: layers.py 프로젝트: sk409/Cartpole
 def forward(self, x, train=True):
     """Reduce every ``fh`` x ``fw`` window of ``x`` to its mean value.

     Args:
         x: 4-D array whose shape is unpacked as ``(n, c, h, w)``.
         train: Accepted but not used by this method.

     Returns:
         The window means, reshaped to ``(n, out_h, out_w, c)`` and then
         transposed to ``(n, c, out_h, out_w)``.

     Side effects:
         Stores the batch dimension of ``x`` in ``self.batch_size``.
     """
     batch, channels, in_h, in_w = x.shape
     self.batch_size = batch
     pooled_h, pooled_w = filter_out_size(
         in_h, in_w, self.fh, self.fw,
         self.stride_h, self.stride_w, self.padding,
     )
     patches = im2col(
         x, self.fh, self.fw, self.stride_h, self.stride_w, self.padding
     )
     # One row per window of fh*fw values.
     # NOTE(review): assumes im2col keeps each channel's window values
     # contiguous so this reshape does not mix channels -- confirm.
     windows = patches.reshape(-1, self.fh * self.fw)
     means = np.mean(windows, axis=1)
     # Restore the spatial layout, then move channels in front of the
     # spatial axes.
     return means.reshape(batch, pooled_h, pooled_w, channels).transpose(0, 3, 1, 2)
예제 #2
0
파일: layers.py 프로젝트: sk409/Cartpole
 def compile(self, input_shape):
     """Resolve the layer's input shape and return its output shape.

     Args:
         input_shape: A 3-element ``(c, h, w)`` shape, or ``None`` to reuse
             the shape already stored in ``self.input_shape``.

     Returns:
         Tuple ``(c, out_h, out_w)``: the channel count is passed through
         unchanged and the spatial sizes come from ``filter_out_size``.

     Raises:
         AssertionError: If no input shape is available, or if the resolved
             shape does not have exactly three elements.
     """
     if input_shape is None:
         assert self.input_shape is not None
     else:
         self.input_shape = input_shape
     assert len(self.input_shape) == 3
     # Unpack the resolved attribute rather than the argument: the argument
     # is allowed to be None, and unpacking None would raise TypeError.
     c, h, w = self.input_shape
     out_h, out_w = filter_out_size(
         h, w, self.fh, self.fw, self.stride_h, self.stride_w, self.padding
     )
     return (c, out_h, out_w)
예제 #3
0
파일: layers.py 프로젝트: sk409/Cartpole
 def forward(self, x, train=True):
     """Apply the layer's filters to ``x`` as a single matrix product.

     Args:
         x: 4-D array; its shape is unpacked as ``(n, _, h, w)``.
         train: Accepted but not used by this method.

     Returns:
         The product of the ``im2col`` patches with the flattened filters,
         plus ``self.b``, reshaped to ``(n, out_h, out_w, -1)`` and then
         transposed to ``(n, -1, out_h, out_w)``.

     Side effects:
         Stores ``x``, the patch matrix and the flattened filter matrix on
         ``self.x``, ``self.col`` and ``self.col_w``
         (presumably for a later backward pass -- not visible here).
     """
     n_filters, _, kernel_h, kernel_w = self.w.shape
     batch, _, in_h, in_w = x.shape
     out_h, out_w = filter_out_size(
         in_h, in_w, kernel_h, kernel_w,
         self.stride_h, self.stride_w, self.padding,
     )
     patches = im2col(
         x, kernel_h, kernel_w, self.stride_h, self.stride_w, self.padding
     )
     # Flatten each filter to a row, then transpose so that each filter
     # becomes one column of the right-hand matrix.
     flat_filters = self.w.reshape(n_filters, -1).T
     result = np.dot(patches, flat_filters) + self.b
     result = result.reshape(batch, out_h, out_w, -1).transpose(0, 3, 1, 2)
     # Cache the intermediates only after the computation has succeeded.
     self.x, self.col, self.col_w = x, patches, flat_filters
     return result
예제 #4
0
파일: layers.py 프로젝트: sk409/Cartpole
 def compile(self, input_shape):
     """Resolve the input shape, allocate parameters, return the output shape.

     Args:
         input_shape: A 3-element ``(ic, h, w)`` shape, or ``None`` to reuse
             the shape already stored in ``self.input_shape``.

     Returns:
         Tuple ``(out_channel, out_h, out_w)``.

     Raises:
         AssertionError: If no input shape is available, or if the resolved
             shape does not have exactly three elements.

     Side effects:
         Sets ``self.w`` (from ``self.weight_initializer``), ``self.b``
         (zeros, one per output channel), ``self.params`` and ``self.grads``
         (zero arrays shaped like the parameters).
     """
     if input_shape is not None:
         self.input_shape = input_shape
     else:
         assert self.input_shape is not None
     assert len(self.input_shape) == 3

     in_channels, in_h, in_w = self.input_shape
     n_filters = self.out_channel
     kernel_h, kernel_w = self.fh, self.fw
     out_h, out_w = filter_out_size(
         in_h, in_w, kernel_h, kernel_w,
         self.stride_h, self.stride_w, self.padding,
     )

     filter_shape = (n_filters, in_channels, kernel_h, kernel_w)
     self.w = self.weight_initializer(self.input_shape, filter_shape)
     float_type = skml_config.config.f_type
     self.b = np.zeros(n_filters, dtype=float_type)
     self.params = [self.w, self.b]
     # One zero-filled gradient buffer per parameter, in the same order.
     self.grads = [
         np.zeros_like(self.w, dtype=float_type),
         np.zeros_like(self.b, dtype=float_type),
     ]
     return (n_filters, out_h, out_w)