def _test_conv_backward(input_shape, out_channels, kernel_size, stride):
    """Compare ConvLayer.backward against torch.nn.Conv2d autograd.

    Runs a forward/backward pass through both layers on the same seeded
    random input, then checks that the input gradients agree (within
    ``TOLERANCE``) and that the parameter gradients match via
    ``utils.check_conv_grad_match``.

    Args:
        input_shape: Shape of the random input array; ``input_shape[1]`` is
            used as the number of input channels.
        out_channels: Number of output channels for both layers.
        kernel_size: Kernel size passed to both layers.
        stride: Stride passed to both layers.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    in_channels = input_shape[1]
    # Padding for the torch reference only; ConvLayer is not given this value,
    # so this presumes it derives the same padding internally — TODO confirm.
    padding = (kernel_size - 1) // 2
    inp = np.random.random(input_shape).astype(np.float32) * 20
    original_input = inp.copy()

    layer = ConvLayer(in_channels, out_channels, kernel_size, stride)
    torch_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
    utils.assign_conv_layer_weights(layer, torch_layer)

    output = layer.forward(inp)
    # d(2 * mean(out)) / d(out) equals 2 / out.size for every element, which
    # matches the scalar loss backpropagated on the torch side below.
    out_grad = layer.backward(2 * np.ones_like(output) / output.size)
    # The torch reference is built from the same array, so it must still hold
    # the original values; otherwise the gradient comparison would be skewed.
    assert np.all(inp == original_input)

    torch_input = utils.from_numpy(inp).requires_grad_(True)
    torch_out = torch_layer(torch_input)
    (2 * torch_out.mean()).backward()

    utils.assert_close(out_grad, torch_input.grad, atol=TOLERANCE)
    utils.check_conv_grad_match(layer, torch_layer)
def _test_conv_forward(input_shape, out_channels, kernel_size, stride):
    """Compare ConvLayer.forward against torch.nn.Conv2d on a random input.

    Checks three things: the layer leaves its input array unmodified, the
    output shape matches torch's, and the output values agree within
    ``TOLERANCE``.

    Args:
        input_shape: Shape of the random input array; ``input_shape[1]`` is
            the number of input channels.
        out_channels: Number of output channels for both layers.
        kernel_size: Kernel size passed to both layers.
        stride: Stride passed to both layers.
    """
    # Seed both RNGs so the input and any weight initialisation are repeatable.
    np.random.seed(0)
    torch.manual_seed(0)

    n_in = input_shape[1]
    pad = (kernel_size - 1) // 2
    x = np.random.random(input_shape).astype(np.float32) * 20
    x_snapshot = x.copy()

    ours = ConvLayer(n_in, out_channels, kernel_size, stride)
    reference = nn.Conv2d(n_in, out_channels, kernel_size, stride, pad, bias=True)
    utils.assign_conv_layer_weights(ours, reference)

    y = ours.forward(x)
    y_ref = reference(utils.from_numpy(x))

    # forward() must not modify its input in place.
    assert np.all(x == x_snapshot)
    assert y.shape == y_ref.shape
    utils.assert_close(y, y_ref, atol=TOLERANCE)
def test_networks():
    """End-to-end check of MNISTResNetwork against TorchMNISTResNetwork.

    Copies the weights of every conv and linear layer into the torch
    reference, then compares:

    * the forward outputs;
    * the losses;
    * the gradient with respect to the input;
    * the per-layer parameter gradients.
    """
    np.random.seed(0)
    torch.manual_seed(0)
    data = np.random.random((100, 1, 28, 28)).astype(np.float32) * 10 - 5
    labels = np.random.randint(0, 10, 100).astype(np.int64)

    net = MNISTResNetwork()
    torch_net = TorchMNISTResNetwork()
    ours, theirs = net.layers, torch_net.layers

    # (our layer, torch layer) pairs, listed from the input side to the output side.
    conv_pairs = [
        (ours[0], theirs[0]),
        (ours[3], theirs[3]),
        (ours[4].conv_layers[0], theirs[4].conv1),
        (ours[4].conv_layers[2], theirs[4].conv2),
        (ours[5].conv_layers[0], theirs[5].conv1),
        (ours[5].conv_layers[2], theirs[5].conv2),
    ]
    linear_pairs = [(ours[idx], theirs[idx]) for idx in (9, 11, 13)]

    for mine, ref in conv_pairs:
        utils.assign_conv_layer_weights(mine, ref)
    for mine, ref in linear_pairs:
        utils.assign_linear_layer_weights(mine, ref)

    # Forward pass and loss.
    logits = net(data)
    data_torch = utils.from_numpy(data).requires_grad_(True)
    logits_torch = torch_net(data_torch)
    utils.assert_close(logits, logits_torch)

    loss = net.loss(logits, labels)
    torch_loss = torch_net.loss(logits_torch, utils.from_numpy(labels))
    utils.assert_close(loss, torch_loss)

    # Backward pass: input gradient first, then per-layer parameter grads.
    input_grad = net.backward()
    torch_loss.backward()
    utils.assert_close(input_grad, data_torch.grad, atol=0.01)

    grad_tol = 1e-4
    # Walk the layers from the output side back towards the input.
    for mine, ref in reversed(linear_pairs):
        utils.check_linear_grad_match(mine, ref, tolerance=grad_tol)
    for mine, ref in reversed(conv_pairs):
        utils.check_conv_grad_match(mine, ref, tolerance=grad_tol)