Пример #1
0
    model_cond.eval()
    print('Complete !')
    optimizer_cond = optim.Adam(model_cond.parameters(), lr=args.lr)

    model_D = MultiscaleDiscriminator(input_nc=3).to(device)
    model_D = nn.DataParallel(model_D).cuda()
    print('Loading Model_D...', end='')
    model_D.load_state_dict(torch.load('/p300/mem/mem_src/SPADE/checkpoint/as_101/vqvae_D_072.pt'))
    model_D.eval()
    print('Complete !')
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.lr)
    # if args.sched == 'cycle':
    #     scheduler = CycleScheduler(
    #         optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None
    #     )

    dic_model = {'model_img': model, 'model_cond': model_cond,
                 # 'model_transfer': model_transfer,
                 'model_D': model_D,
                 'optimizer_img': optimizer, 'optimizer_cond': optimizer_cond,
                 # 'optimizer_transfer': optimizer_transfer
                 'optimizer_D': optimizer_D
                 }

    for i in range(args.epoch):
        viz.text(f'{DESCRIPTION} ##### Epoch: {i} #####', win='board')
        train(i, loader_train, dic_model, scheduler, device)
        val(i, loader_val, dic_model, scheduler, device)
        torch.save(model.state_dict(), f'checkpoint/{EXPERIMENT_CODE}/vqvae_{str(i + 1).zfill(3)}.pt')
        torch.save(model_D.state_dict(), f'checkpoint/{EXPERIMENT_CODE}/vqvae_D_{str(i + 1).zfill(3)}.pt')
Пример #2
0
        torch.load('/p300/mem/mem_src/checkpoint/pose_04/vqvae_462.pt'))
    model_cond.eval()
    print('Complete !')
    optimizer_cond = optim.Adam(model_cond.parameters(), lr=args.lr)

    model_D = MultiscaleDiscriminator(input_nc=3).to(device)
    model_D = nn.DataParallel(model_D).cuda()
    optimizer_D = optim.Adam(model_D.parameters(), lr=args.lr)
    # if args.sched == 'cycle':
    #     scheduler = CycleScheduler(
    #         optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None
    #     )

    dic_model = {
        'model_img': model,
        'model_cond': model_cond,
        # 'model_transfer': model_transfer,
        'model_D': model_D,
        'optimizer_img': optimizer,
        'optimizer_cond': optimizer_cond,
        # 'optimizer_transfer': optimizer_transfer
        'optimizer_D': optimizer_D
    }

    for i in range(args.epoch):
        viz.text(f'{DESCRIPTION} ##### Epoch: {i} #####', win='board')
        train(i, loader, dic_model, scheduler, device)
        torch.save(
            model.state_dict(),
            f'checkpoint/{EXPERIMENT_CODE}/vqvae_{str(i + 1).zfill(3)}.pt')