コード例 #1
0
ファイル: test_dims.py プロジェクト: moxime/joint-vae
    print('TYPE:', ntype)
    n = Net(D, C,
            type_of_net=ntype,
            y_is_coded=y_coded and ntype not in ('vib', 'vae'),
            batch_norm='encoder',
            features='vgg16',
            encoder_layer_sizes=[],
            decoder_layer_sizes=[],
            classifier_layer_sizes=cls_cvae if ntype == 'cvae' else [20, 10],
            sigma=0,
            gamma=gamma,
            force_cross_y=0,
            latent_sampling=L,
            latent_dim=K)
    n.to(d)
    nets[ntype] = n
    # n.compute_max_batch_size(batch_size=1024)
    # print(n.max_batch_sizes)
    
    if n.y_is_coded:
        pass

    if ntype != 'vae':
        print('y in input')
        out_y[ntype] = n.evaluate(x, y)
    print('y is none')
    out[ntype] = n.evaluate(x)


for o, _y in zip((out, out_y), ('*', 'y')):
コード例 #2
0
ファイル: train.py プロジェクト: moxime/joint-vae
    jvae.saved_dir = save_dir

    if args.resume:
        with open(os.path.join(resumed_from, 'RESUMED'), 'w') as f:
            f.write(str(job_number) + '\n')

    with open(os.path.join(job_dir, f'number-{hostname}'), 'w') as f:
        f.write(str(job_number + 1) + '\n')

    log.info('Network built, will be saved in')
    log.info(save_dir)

    log.debug('%s: %s', 'Network architecture',
              jvae.print_architecture(True, True))

    jvae.to(device)

    # print('*** .device', jvae.device)

    x, y = torchdl.get_batch(trainset, device=device, batch_size=8)

    if debug:
        log.debug('Trying a first pass')
        log.debug('x in [%.2f, %.2f] with mean (std) %.2f (%.2f)',
                  x.min().item(),
                  x.max().item(),
                  x.mean().item(),
                  x.std().item())
        outs = jvae(x, y if jvae.y_is_coded else None)
        log.debug(' -- '.join(map(str, ([tuple(u.shape) for u in outs]))))