Example #1
0
def question_1i_sanity_check():
    """ Sanity check for cnn.

    Builds ``CNN(2, emb_size, kernel_size=2)``, loads hand-picked weights and a
    zero bias into its ``cnn`` sub-layer, feeds one small input, and checks:

      * the sub-layer's parameter shapes match the hand-picked ones (so a
        misconfigured layer is reported instead of being silently overwritten);
      * the module output has shape ``[batch_size, emb_size]``;
      * ``F.relu(cnn.cnn(x))`` matches values computed by hand (shape included);
      * the module output matches the hand-computed values ``[[1.2, 0.115]]``.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    print("-" * 80)
    print("Running Sanity Check for Question 1i: Cnn")
    print("-" * 80)

    emb_size = 2
    cnn = CNN(2, emb_size, kernel_size=2)

    # Hand-picked parameters. Weight layout: (out_channels, in_channels, kernel_size).
    weight = torch.tensor([[[1.0, -1.0], [-1.0, 1.0]],
                           [[0.1, 0.2], [-0.3, 0.05]]])
    bias = torch.zeros(emb_size)

    # Check the layer's shapes before loading values: assigning via ``.data``
    # would silently replace a wrongly-shaped parameter and mask a bad layer.
    assert list(cnn.cnn.weight.size()) == list(weight.size()), "conv weight shape is incorrect: it should be:\n {} but is:\n{}".format(list(weight.size()), list(cnn.cnn.weight.size()))
    assert list(cnn.cnn.bias.size()) == list(bias.size()), "conv bias shape is incorrect: it should be:\n {} but is:\n{}".format(list(bias.size()), list(cnn.cnn.bias.size()))
    with torch.no_grad():
        cnn.cnn.weight.copy_(weight)
        cnn.cnn.bias.copy_(bias)

    # validate input & output shape
    # Input has shape (1, 2, 4); dim 0 is the batch. Remaining dims are assumed
    # to be (in_channels, length) as for Conv1d -- TODO confirm against CNN.
    inpt = torch.tensor([[[0.1, 0.2, 0.0, 0.0], [0.0, 1.3, 0.0, 0.0]]])
    batch_size = inpt.size(0)
    output_expected_size = [batch_size, emb_size]
    with torch.no_grad():  # no gradients are needed for these checks
        output = cnn(inpt)
    assert list(output.size()) == output_expected_size, "output shape is incorrect: it should be:\n {} but is:\n{}".format(output_expected_size, list(output.size()))

    with torch.no_grad():
        after_relu = F.relu(cnn.cnn(inpt))
    # Hand computation at window position 0:
    #   out ch 0: (0.1*1 + 0.2*-1) + (0*-1 + 1.3*1)       = 1.2
    #   out ch 1: (0.1*0.1 + 0.2*0.2) + (0*-0.3 + 1.3*0.05) = 0.115
    # Positions 1 and 2 are <= 0 before ReLU for both channels, hence 0.
    # The batch dim is kept so allclose cannot pass through broadcasting.
    expected_after_relu = torch.tensor([[[1.2, 0.0, 0.0], [0.115, 0.0, 0.0]]])
    assert list(after_relu.size()) == list(expected_after_relu.size()), "after_relu shape is incorrect: it should be:\n {} but is:\n{}".format(list(expected_after_relu.size()), list(after_relu.size()))
    assert after_relu.allclose(expected_after_relu), "after_relu is incorrect: it should be:\n {} but is:\n{}".format(expected_after_relu, after_relu)

    # These values equal the per-channel maxima of the ReLU output above
    # (presumably CNN max-pools over positions -- confirm in CNN.forward).
    expected_output = torch.tensor([[1.2, 0.115]])
    assert output.allclose(expected_output), "output is incorrect: it should be:\n {} but is:\n{}".format(expected_output, output)

    print("Sanity Check Passed for Question 1i: Cnn!")
    print("-" * 80)