def question_1i_sanity_check():
    """Sanity check for the word-level convolution layer (Question 1i: CNN).

    Feeds a random 3-D tensor through ``cnn.conv`` and then ``cnn.pooling``
    and checks the resulting shapes.

    Raises:
        AssertionError: if either the convolution or the pooled output has
            an unexpected shape.
    """
    batch_size = 10
    kernel_size = 5

    # NOTE(review): the input is 3-D; it is presumably laid out as
    # (batch * word_num, char_embed_size, max_word_length). The earlier inline
    # comment listed four dims for this 3-D tensor -- confirm against CNN.
    x = torch.randn((batch_size, 50, 20))
    cnn = CNN(20, 50, 300, kernel_size)

    conved = cnn.conv(x)
    # 16 == 20 - kernel_size + 1, consistent with an unpadded, stride-1
    # convolution along the last axis of the input.
    assert conved.shape == (batch_size, 300, 16)

    # Squeeze only the trailing (pooled) axis; a bare squeeze() would also
    # drop the batch axis if batch_size were ever 1.
    pooled = cnn.pooling(conved).squeeze(-1)
    assert pooled.shape == (batch_size, 300)

    print("-" * 80)
    print("Sanity Check Passed for Question 1i: CNN")