def question_1i_sanity_check():
    """ Sanity check for cnn.

    Builds a ``CNN`` with hand-picked convolution weights and a zero bias,
    runs it on one hand-crafted example, and checks that:

    * the final output of ``cnn(inpt)`` has shape ``[batch_size, emb_size]``;
    * the convolution layer ``cnn.cnn`` has the expected weight shape;
    * ``relu(cnn.cnn(inpt))`` has the expected shape and values;
    * the final output has the expected values.

    The expected final output ``[[1.2, 0.115]]`` equals the per-channel
    maximum of the ReLU activations over positions. This is consistent with
    max-pooling inside ``CNN.forward``, presumably; that is not visible here.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    print("-" * 80)
    print("Running Sanity Check for Question 1i: Cnn")
    print("-" * 80)

    emb_size = 2
    # NOTE(review): presumably CNN(in_channels, out_channels, kernel_size) with
    # out_channels == emb_size -- confirm against the CNN definition. Both
    # channel counts are 2 here, so the argument order cannot affect this check.
    cnn = CNN(2, emb_size, kernel_size=2)

    # Conv weight layout: (out_channels, in_channels, kernel_size) = (2, 2, 2).
    weight = torch.tensor([[[1.0, -1.0], [-1.0, 1.0]],
                           [[0.1, 0.2], [-0.3, 0.05]]])
    # Check the layer's configuration before overwriting its parameters.
    # Rebinding `.data` to a different-shaped tensor would silently hide a
    # mis-configured layer (e.g. wrong kernel size).
    assert list(cnn.cnn.weight.size()) == list(weight.size()), \
        "conv weight shape is incorrect: it should be:\n {} but is:\n{}".format(
            list(weight.size()), list(cnn.cnn.weight.size()))
    with torch.no_grad():
        # Copy values into the existing Parameters rather than replacing them.
        cnn.cnn.weight.copy_(weight)
        cnn.cnn.bias.zero_()

    # Input layout: (batch_size, in_channels, sequence_length) = (1, 2, 4).
    inpt = torch.tensor([[[0.1, 0.2, 0.0, 0.0],
                          [0.0, 1.3, 0.0, 0.0]]])
    batch_size = inpt.size()[0]

    # validate output shape
    output_expected_size = [batch_size, emb_size]
    output = cnn(inpt)
    assert list(output.size()) == output_expected_size, \
        "output shape is incorrect: it should be:\n {} but is:\n{}".format(
            output_expected_size, list(output.size()))

    # Convolution + ReLU. The expected values assume stride 1 and no padding:
    # length 4 with kernel 2 gives 3 positions per output channel.
    #   channel 0, pos 0: (0.1 - 0.2) + (-0 + 1.3)              = 1.2
    #   channel 1, pos 0: (0.01 + 0.04) + (0 + 0.065)           = 0.115
    #   pos 1 is negative for both channels (-1.1, -0.37) -> 0 after ReLU
    #   pos 2 sees an all-zero window                           -> 0
    after_relu = F.relu(cnn.cnn(inpt))
    # Keep the batch dimension so the comparison cannot pass via broadcasting.
    expected_after_relu = torch.tensor([[[1.2, 0.0, 0.0],
                                         [0.115, 0.0, 0.0]]])
    assert list(after_relu.size()) == list(expected_after_relu.size()), \
        "after_relu shape is incorrect: it should be:\n {} but is:\n{}".format(
            list(expected_after_relu.size()), list(after_relu.size()))
    assert after_relu.allclose(expected_after_relu), \
        "after_relu is incorrect: it should be:\n {} but is:\n{}".format(
            expected_after_relu, after_relu)

    expected_output = torch.tensor([[1.2, 0.115]])
    assert output.allclose(expected_output), \
        "output is incorrect: it should be:\n {} but is:\n{}".format(
            expected_output, output)

    print("Sanity Check Passed for Question 1i: Cnn!")
    print("-" * 80)