예제 #1
0
파일: main.py 프로젝트: augustdemi/jcvae
# Ratios of the form (N - 1) / (B - 1), with N a dataset size and B =
# args.batch_size. NOTE(review): presumably a minibatch weighting correction
# for the training loss — confirm against where BIAS_TRAIN / BIAS_TEST are used.
# The ratio is undefined for B == 1, so fail early with a clear message rather
# than an opaque ZeroDivisionError.
if args.batch_size < 2:
    raise ValueError(
        "args.batch_size must be at least 2 to compute BIAS_TRAIN/BIAS_TEST, "
        "got {}".format(args.batch_size))
BIAS_TRAIN = (train_data_size - 1) / (args.batch_size - 1)
# Use len() rather than calling the dunder method directly.
BIAS_TEST = (len(test_data.dataset) - 1) / (args.batch_size - 1)


def cuda_tensors(obj):
    """Rebind every ``torch.Tensor`` attribute of *obj* to its CUDA copy.

    Walks the attribute names reported by ``dir(obj)``. For each attribute
    whose value is a ``torch.Tensor``, replaces it in place with
    ``value.cuda()`` via ``setattr``. Non-tensor attributes are left
    untouched. Returns ``None``.

    NOTE(review): presumably meant to catch tensors held as plain attributes
    that the object's own ``.cuda()`` does not move — confirm. A tensor
    exposed through a read-only property would make ``setattr`` raise.
    """
    for name in dir(obj):
        candidate = getattr(obj, name)
        if not isinstance(candidate, torch.Tensor):
            continue  # only tensor-valued attributes are relocated
        setattr(obj, name, candidate.cuda())


# Build the four networks. The A-side pair also receives args.n_private as
# zPrivate_dim; the B-side pair is given only the seed argument.
encA = EncoderA(args.wseed, zPrivate_dim=args.n_private)
decA = DecoderA(args.wseed, zPrivate_dim=args.n_private)
encB = EncoderB(args.wseed)
decB = DecoderB(args.wseed)

if CUDA:
    # First call each network's own .cuda(), then sweep any remaining tensor
    # attributes with cuda_tensors(). Call order is encA, decA, encB, decB
    # for both passes.
    for _net in (encA, decA, encB, decB):
        _net.cuda()
    for _net in (encA, decA, encB, decB):
        cuda_tensors(_net)
    del _net  # keep the loop variable out of the module namespace

# A single Adam optimizer over the parameters of all four networks, collected
# in the order encB, decB, encA, decA, with learning rate args.lr.
optimizer = torch.optim.Adam(
    [param
     for net in (encB, decB, encA, decA)
     for param in net.parameters()],
    lr=args.lr,
)