Ejemplo n.º 1
0
def cli(ctx, profile, snapshot):
    """Prepare hyper-parameters, dataset and inferer and store them in ``ctx.obj``.

    Loads the hyper-parameter profile, seeds the RNGs, optionally overrides the
    profile so that it warm-starts from ``snapshot``, builds the graph with
    ``training=False``, creates a CelebA dataset and an ``Inferer``, and stores
    them under the keys ``'hps'``, ``'dataset'`` and ``'inferer'``.

    Args:
        ctx: context object whose ``obj`` attribute receives the results
            (presumably a Click context -- confirm against the decorator).
        profile: path of the hyper-parameter profile passed to
            ``util.load_profile``.
        snapshot: checkpoint to warm-start from, or ``None`` to keep the
            profile's own ``general.warm_start`` / ``general.pre_trained``.
    """
    # load hyper-parameters
    hps = util.load_profile(profile)
    util.manual_seed(hps.ablation.seed)
    if snapshot is not None:
        # override the profile; presumably read by Builder to load the
        # checkpoint -- confirm in Builder.build
        hps.general.warm_start = True
        hps.general.pre_trained = snapshot

    # build graph (inference mode)
    builder = Builder(hps)
    state = builder.build(training=False)

    # load dataset: central 160x160 crop, resized to 128, converted to tensor
    dataset = CelebA(root=hps.dataset.root,
                     transform=transforms.Compose(
                         (transforms.CenterCrop(160), transforms.Resize(128),
                          transforms.ToTensor())))

    # start inference
    inferer = Inferer(hps=hps,
                      graph=state['graph'],
                      devices=state['devices'],
                      data_device=state['data_device'])

    # ctx.obj is None when no obj was supplied by the caller (Click's
    # default); create the dict instead of failing with a TypeError.
    if ctx.obj is None:
        ctx.obj = {}
    ctx.obj['hps'] = hps
    ctx.obj['dataset'] = dataset
    ctx.obj['inferer'] = inferer
Ejemplo n.º 2
0
 def test_glow_model(self):
     """Run a Glow model forward and in reverse on a batch built from a test image.

     The model comes from ``profile/test.json`` and the image from
     ``misc/test.png``. The tensor produced by the reverse pass must have the
     same shape as the input batch.
     """
     # the model and all inputs are moved to the GPU with .cuda(); skip
     # instead of erroring on machines without CUDA
     if not torch.cuda.is_available():
         self.skipTest('CUDA is required for this test')
     # build model
     hps = util.load_profile('profile/test.json')
     glow_model = Glow(hps).cuda()
     image_shape = hps.model.image_shape
     # batch size is the first dimension of the model's h_top attribute
     batch_size = glow_model.h_top.shape[0]
     # read image and replicate it into a batch
     img = Image.open('misc/test.png').convert('RGB')
     x = util.pil_to_tensor(img, shape=image_shape)
     x = util.make_batch(x, batch_size).cuda()
     # all-zero label tensor of shape (batch_size, num_classes)
     y_onehot = torch.zeros((batch_size, hps.dataset.num_classes)).cuda()
     # forward and reverse flow
     z, logdet, y_logits = glow_model(x=x, y_onehot=y_onehot, reverse=False)
     x_ = glow_model(z=z, y_onehot=y_onehot, reverse=True)
     # without an assertion this test could only fail by raising; the reverse
     # pass must at least map back to the input's shape
     self.assertEqual(tuple(x_.shape), tuple(x.shape))
Ejemplo n.º 3
0
#     return parser.parse_args()


if __name__ == '__main__':
    import argparse

    # this enables a Ctrl-C without triggering errors
    signal.signal(signal.SIGINT, lambda x, y: sys.exit(0))

    # parse arguments; without an argument the previously hard-coded
    # CelebA profile is used, so existing invocations behave the same
    parser = argparse.ArgumentParser(
        description='Build the graph and CelebA dataset from a '
                    'hyper-parameter profile.')
    parser.add_argument('profile', nargs='?', default='profile/celeba.json',
                        help='path to the hyper-parameter profile')
    args = parser.parse_args()

    # initialize logging
    util.init_output_logging()

    # load hyper-parameters
    hps = util.load_profile(args.profile)
    util.manual_seed(hps.ablation.seed)

    # build graph
    builder = Builder(hps)
    state = builder.build()

    # load dataset: central 160x160 crop, resized to 64, converted to tensor
    dataset = CelebA(root=hps.dataset.root,
                     transform=transforms.Compose((
                         transforms.CenterCrop(160),
                         transforms.Resize(64),
                         transforms.ToTensor()
                     )))

    # start training
Ejemplo n.º 4
0
 def test_load_profile(self):
     """``util.load_profile`` must return a ``dict`` instance for a profile file."""
     profile_path = 'profile/celebahq_256x256_5bit.json'
     loaded = util.load_profile(profile_path)
     self.assertIsInstance(loaded, dict)