def test_rel():
    """Check RelCNN's repr and its output width for every cat/lin setting."""
    model = RelCNN(16, 32, num_layers=2, batch_norm=True, cat=True, lin=True,
                   dropout=0.5)
    expected_repr = ('RelCNN(16, 32, num_layers=2, batch_norm=True, cat=True, '
                     'lin=True, dropout=0.5)')
    assert repr(model) == expected_repr
    assert repr(model.convs[0]) == 'RelConv(16, 32)'

    num_nodes, num_edges = 100, 400
    x = torch.randn(num_nodes, 16)
    edge_index = torch.randint(num_nodes, (2, num_edges), dtype=torch.long)

    for cat, lin in product([False, True], [False, True]):
        model = RelCNN(16, 32, 2, True, cat, lin, 0.5)
        out = model(x, edge_index)
        # With concatenation and no final linear layer, the test expects the
        # input width (16) plus one 32-wide chunk per layer (2 layers);
        # every other combination is expected to yield 32 channels.
        if cat and not lin:
            channels = 16 + 2 * 32
        else:
            channels = 32
        assert out.size() == (num_nodes, channels)
        assert out.size() == (num_nodes, model.out_channels)
# Integer options, both defaulting to 10. `--k` feeds the `k=` argument of
# DGMC_modified below; `--num_steps` is presumably consumed later in the
# script (the model is constructed with num_steps=None) — confirm.
parser.add_argument('--num_steps', default=10, type=int)
parser.add_argument('--k', default=10, type=int)
args = parser.parse_args()


class SumEmbedding:
    """Dataset transform that sums ``data.x1`` and ``data.x2`` over dim 1.

    The incoming ``data`` object is updated in place and then returned, so it
    can be passed as ``transform=`` to a dataset.
    """

    def __call__(self, data):
        # Compute both reductions before touching either attribute.
        summed_x1 = data.x1.sum(dim=1)
        summed_x2 = data.x2.sum(dim=1)
        data.x1 = summed_x1
        data.x2 = summed_x2
        return data


# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Take the first graph pair of the selected DBP15K split, with per-node
# features reduced by SumEmbedding, and move it onto `device`.
path = osp.join('..', 'data', 'DBP15K')
data = DBP15K(path, args.category, transform=SumEmbedding())[0]
data = data.to(device)

# psi_1 consumes the (summed) input features; psi_2 operates purely on
# `args.rnd_dim`-wide inputs and runs without dropout.
psi_1 = RelCNN(
    data.x1.size(-1),
    args.dim,
    num_layers=args.num_layers,
    batch_norm=False,
    cat=True,
    lin=True,
    dropout=0.5,
)
psi_2 = RelCNN(
    args.rnd_dim,
    args.rnd_dim,
    num_layers=args.num_layers,
    batch_norm=False,
    cat=True,
    lin=True,
    dropout=0.0,
)

# NOTE(review): num_steps is deliberately None at construction; presumably
# it is assigned from args.num_steps later in the script — confirm.
model = DGMC_modified(psi_1, psi_2, num_steps=None, k=args.k)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


def train():
    model.train()
    optimizer.zero_grad()

    S_L = model(data.x1, data.edge_index1, None, None, data.x2,
                   data.edge_index2, None, None, data.train_y)

    loss = model.loss(S_L, data.train_y)
    loss.backward()