def check_forward(self, x, W, cp, pp):
    """Check the convolution + max-pooling forward pass against NumPy.

    ``x`` and ``W`` go through ``convolution2D.Forward`` (with ``cp``) and
    then ``pooling2D.Forward`` (with ``pp``).  The pooled values and the
    argmax indexes it returns are compared with a reference built by
    ``im2col_cpu`` plus ``max``/``argmax`` over each window.

    Side effects: stores ``self.indexes_act`` (indexes from
    ``pooling2D.Forward``) and ``self.indexes_ref`` (reference argmax
    indexes), which ``check_backward`` reads later.

    NOTE(review): the reference hard-codes a 3x3 window, stride 2, pad 0
    with ``cover_all=True``; ``pp`` is presumably configured the same way.
    Confirm this against the caller.
    """
    conv_out = convolution2D.Forward(array(x), array(W), None, cp)
    y_act, self.indexes_act = pooling2D.Forward(conv_out, pp)

    conv_np = numpy.array(conv_out, dtype=self.dtype)
    y_act_np = numpy.array(y_act, dtype=self.dtype)
    # Indexes are cast to the test dtype so assert_allclose can compare them.
    idx_act_np = numpy.array(self.indexes_act, dtype=self.dtype)

    # Geometry of the NumPy reference pooling.
    ksize, stride, pad = 3, 2, 0
    # Padding with -inf keeps padded cells from ever winning the max.
    patches = im2col_cpu(conv_np, ksize, ksize, stride, stride, pad, pad,
                         pval=-float('inf'), cover_all=True)
    batch, channels, kh, kw, out_h, out_w = patches.shape
    # Merge the two kernel axes so each window is a single axis of length kh*kw.
    windows = patches.reshape(batch, channels, kh * kw, out_h, out_w)
    self.indexes_ref = windows.argmax(axis=2)
    y_ref = windows.max(axis=2)

    numpy.testing.assert_allclose(
        y_act_np, y_ref, **self.check_forward_options)
    numpy.testing.assert_allclose(
        idx_act_np, self.indexes_ref, **self.check_forward_options)
def send_array(self, array):
    """Return ``array`` in a form suitable for this iDeep device.

    - An ``ideep.mdarray`` is returned unchanged.
    - Any other non-``numpy.ndarray`` input is first moved to the CPU with
      ``_cpu._to_cpu``.
    - A NumPy array with ndim 1, 2 or 4 and no zero-length axis is passed
      to ``ideep.array(..., itype=ideep.wgt_array)``.
    - Everything else is returned as the CPU-side value.
    """
    if isinstance(array, ideep.mdarray):
        return array

    if not isinstance(array, numpy.ndarray):
        array = _cpu._to_cpu(array)  # to numpy.ndarray

    # TODO(kmaehashi): Remove ndim validation once iDeep has fixed.
    # Currently iDeep only supports (1, 2, 4)-dim arrays.
    ideep_compatible = (
        isinstance(array, numpy.ndarray)
        and array.ndim in (1, 2, 4)
        and 0 not in array.shape
    )
    if not ideep_compatible:
        return array

    # Note that array returned from `ideep.array` may not be an
    # iDeep mdarray, e.g., when the dtype is not float32.
    return ideep.array(array, itype=ideep.wgt_array)
def check_backward(self, gy, W, pp):
    """Check the max-pooling backward pass against a NumPy reference.

    ``gy`` is propagated back through ``linear.BackwardData`` (with ``W``)
    and reshaped to ``(batch, self.oc, 6, 6)``.  That gradient is routed
    through ``pooling2D.Backward`` with the indexes in ``self.indexes_act``.
    The reference scatters the same gradient into an im2col-shaped buffer at
    the argmax positions in ``self.indexes_ref`` and folds it back with
    ``col2im_cpu``.

    Requires ``check_forward`` to have run first: it sets
    ``self.indexes_act`` and ``self.indexes_ref``.

    Args:
        gy: upstream gradient passed to ``linear.BackwardData``.
        W: weight passed to ``linear.BackwardData``.
        pp: pooling parameters passed to ``pooling2D.Backward``.
            NOTE(review): presumably they describe the same 3x3 / stride 2 /
            pad 0 pooling hard-coded below. Confirm this against the caller.
    """
    # Geometry of the NumPy reference pooling; must agree with check_forward.
    kh, kw = 3, 3
    sy, sx = 2, 2
    ph, pw = 0, 0
    # Spatial size of the pooling input. With a 3x3 window, stride 2 and
    # cover_all, 13x13 pools down to the 6x6 used in the reshape below.
    h, w = 13, 13
    pooled_h, pooled_w = 6, 6

    gy_in = linear.BackwardData(array(W), array(gy))
    # Un-flatten the linear input gradient back into a feature map; -1 infers
    # the batch size (the original line was missing this argument entirely).
    gy_in = gy_in.reshape(-1, self.oc, pooled_h, pooled_w)
    gx_act = pooling2D.Backward(gy_in, self.indexes_act, pp)

    gy_in_npy = numpy.array(gy_in, dtype=self.dtype)
    gx_act_npy = numpy.array(gx_act, dtype=self.dtype)

    n, c, out_h, out_w = gy_in_npy.shape
    gcol = numpy.zeros(n * c * out_h * out_w * kh * kw, dtype=self.dtype)
    # flatten() returns a copy, so the in-place += does not modify
    # self.indexes_ref.
    indexes = self.indexes_ref.flatten()
    # Shift each window's argmax into that window's own kh*kw slot of the
    # flat buffer.
    indexes += numpy.arange(0, indexes.size * kh * kw, kh * kw)
    gcol[indexes] = gy_in_npy.ravel()
    gcol = gcol.reshape(n, c, out_h, out_w, kh, kw)
    # Reorder (n, c, out_h, out_w, kh, kw) -> (n, c, kh, kw, out_h, out_w),
    # the column layout col2im_cpu consumes.
    gcol = numpy.swapaxes(gcol, 2, 4)
    gcol = numpy.swapaxes(gcol, 3, 5)
    gx_ref = col2im_cpu(gcol, sy, sx, ph, pw, h, w)
    numpy.testing.assert_allclose(
        gx_act_npy, gx_ref, **self.check_backward_options)