def WNetTest():
    """Decode the output of ``EncoderTest`` with ``WNet.UDec`` and report stats.

    Calls ``EncoderTest(verbose=False)`` to obtain an encoding, passes it through
    a decoder built as ``WNet.UDec(4)``, and prints the variance and mean of the
    decoded tensor.
    """
    encoded = EncoderTest(verbose=False)
    # NOTE(review): assumes EncoderTest returns a tensor whose channel count
    # matches the 4 input channels given to UDec here -- confirm.
    decoder = WNet.UDec(4)
    reproduced = decoder(encoded)
    var = torch.var(reproduced)
    mean = torch.mean(reproduced)
    # %-formatting (not an f-string) keeps str(tensor) output, e.g. "tensor(...)".
    # The message previously read "Decoder Test", copy-pasted from DecoderTest.
    print('Passed WNet Test with var=%s and mean=%s' % (var, mean))
def DecoderTest():
    """Smoke-test ``WNet.UDec`` on a random input batch.

    Builds a decoder that takes 4 input channels, feeds it a uniform random
    tensor of shape (2, 4, 224, 224), asserts the result has shape
    (2, 3, 224, 224), and prints the variance and mean of the output.
    """
    batch, channels, height, width = 2, 4, 224, 224
    expected_shape = (batch, 3, height, width)

    # The decoder is constructed before sampling the input so the RNG is
    # consumed in the same order as before.
    model = WNet.UDec(channels)
    sample = torch.rand((batch, channels, height, width))
    output = model(sample)

    assert tuple(output.shape) == expected_shape

    stats = (torch.var(output), torch.mean(output))
    # %-formatting keeps str(tensor) rendering of the 0-dim results.
    print('Passed Decoder Test with var=%s and mean=%s' % stats)