def test_hybridnet(input_sizes):
    """Smoke-test a HybridNet forward pass using the shared ``input_sizes`` fixture.

    The model is built from the channel, class and input-time counts stored in
    ``input_sizes``; the actual forward-pass check is delegated to
    ``check_forward_pass`` with ``only_check_until_dim=2``.
    """
    # NOTE(review): only_check_until_dim=2 presumably limits the shape check to
    # the leading two output dims (batch, classes) — confirm against the helper.
    n_chans = input_sizes['n_channels']
    n_outputs = input_sizes['n_classes']
    n_times = input_sizes['n_in_times']
    hybrid = HybridNet(n_chans, n_outputs, n_times)
    check_forward_pass(hybrid, input_sizes, only_check_until_dim=2)
def test_hybridnet_output_shape():
    """Check that HybridNet maps a random batch to (n_samples, n_classes, ...).

    Renamed from a second ``test_hybridnet`` definition: a duplicate module-level
    name rebinds the earlier fixture-based test, so pytest would only ever
    collect one of the two.
    """
    rng = np.random.RandomState(42)  # fixed seed for a reproducible input
    n_channels = 18
    n_in_times = 600
    n_classes = 2
    n_samples = 7
    # Input layout built here: (n_samples, n_channels, n_in_times, 1).
    X = rng.randn(n_samples, n_channels, n_in_times, 1)
    X = torch.from_numpy(X.astype(np.float32))
    model = HybridNet(n_channels, n_classes, n_in_times)
    y_pred = model(X)
    # Only the leading two dims are asserted; trailing dims are not pinned here.
    assert y_pred.shape[:2] == (n_samples, n_classes)