Example #1
0
def main():
    """Load a trained PixelCNN checkpoint and sample a batch of images from it.

    Sub-pixels are generated one at a time: positions are visited with ``i``
    as the outer loop, then ``j``, then channel ``c``. Each step runs the
    network on the partially filled batch, turns the logits for position
    ``(c, i, j)`` into probabilities and draws one discrete level per image.
    The finished batch is written to ``sample.png``.
    """
    save_path = 'models/model_23.pt'
    no_images = 64
    images_size = 32
    images_channels = 3
    # Number of discrete intensity levels; sampled indices 0..n_levels-1 are
    # rescaled into [0, 1]. The original hard-coded a divisor of 63.0, i.e.
    # 64 levels. NOTE(review): confirm this matches the network's output size.
    n_levels = 64

    # Define and load model.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = PixelCNN().to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only
    # machine (and vice versa).
    net.load_state_dict(torch.load(save_path, map_location=device))
    net.eval()

    sample = torch.zeros(no_images, images_channels, images_size, images_size, device=device)
    print('-------------------------------------SAMPLING!!!!!---------------------------------')

    # Pure inference: disable autograd so no graph is built for any of the
    # images_size * images_size * images_channels forward passes.
    with torch.no_grad():
        for i in tqdm(range(images_size)):
            for j in range(images_size):
                for c in range(images_channels):
                    out = net(sample)
                    # NOTE(review): assumes out is (batch, n_levels, channels, H, W);
                    # confirm against PixelCNN's forward.
                    probs = torch.softmax(out[:, :, c, i, j], dim=1)
                    # squeeze(1) drops only the draw dimension, keeping the batch
                    # dimension even when no_images == 1.
                    sampled_levels = torch.multinomial(probs, 1).squeeze(1).float() / (n_levels - 1)
                    sample[:, c, i, j] = sampled_levels

    torchvision.utils.save_image(sample, 'sample.png', nrow=12, padding=0)
    # NOTE(review): the `if` header this branch belongs to is not visible in this
    # fragment; presumably it checks whether hdf5_path already exists. Confirm.
    # Open the HDF5 file at hdf5_path read-only.
    h5f = h5py.File(hdf5_path, 'r')
else:
    # Dump selected TensorFlow checkpoint variables into a new HDF5 file.
    with h5py.File(hdf5_path, 'w') as h5f:
        # init_vars yields (name, shape) pairs; presumably from
        # tf.train.list_variables(tf_path). Verify against the caller.
        for name, shape in init_vars:

            # Read the raw value of variable `name` from the checkpoint at tf_path.
            val = tf.train.load_variable(tf_path, name)

            print(val.dtype)
            print("Loading TF weight {} with shape {}, {}".format(
                name, shape, val.shape))
            # NOTE(review): the result is discarded. Its only effect is to raise
            # if the array's dtype cannot become a tensor; likely removable.
            torch.from_numpy(np.array(val))
            # Keep only variables whose name contains 'model'. TF scope
            # separators '/' become '.' in the stored dataset name.
            if 'model' in name:
                new_name = name.replace('/', '.')
                print(new_name)
                h5f.create_dataset(str(new_name), data=val)

    # Reopen the file just written, read-only, for the conversion step.
    h5f = h5py.File(hdf5_path, 'r')

# Build a PixelCNN, fill it from the datasets in the open HDF5 handle `h5f`
# and save its state dict to ckpt_path.
# NOTE(review): the hyperparameters presumably match the TF checkpoint being
# converted; h5f and ckpt_path are defined outside this fragment. Confirm.
try:
    model = PixelCNN(nr_resnet=5,
                     nr_filters=160,
                     input_channels=3,
                     nr_logistic_mix=10)

    converter = TF2Pytorch(h5f)
    converter.load_pixelcnn()

    model.load_state_dict(converter.state_dict)
    torch.save(model.state_dict(), ckpt_path)
finally:
    # Release the HDF5 handle even if construction, conversion or saving fails.
    h5f.close()
Example #3
0
# Number of class labels; the sampling code draws NB_SAMPLES images per class.
NB_CLASSES = 10


def get_device():
    """Return the torch device to run on.

    CUDA is chosen only when the module-level ``TRY_CUDA`` flag is truthy and
    a CUDA device is available; otherwise the CPU is used.
    """
    # Truthiness test instead of `== False` (PEP 8). The `and` short-circuits,
    # so CUDA is not probed when TRY_CUDA is off.
    use_cuda = bool(TRY_CUDA) and torch.cuda.is_available()
    return torch.device('cuda' if use_cuda else 'cpu')

# Pick the compute device through get_device() so the TRY_CUDA policy lives in
# one place instead of being duplicated here.
device = get_device()
print(f"> Using device {device}")

# Load trained PixelCNN weights from the path given as the first CLI argument.
# Exit with a non-zero status on any failure.
if len(sys.argv) < 2:
    # Without this guard the IndexError would be misreported as a bad state dict.
    print("! Usage: pass the path to a PixelCNN state dict as the first argument")
    sys.exit(1)

try:
    print(f"> Loading PixelCNN from file {sys.argv[1]}")
    model = PixelCNN(IMAGE_DIM, 16, 5, 256, 10).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(sys.argv[1], map_location=device))
    model.eval()
    print("> Loaded PixelCNN successfully!")
except Exception as err:
    # Top-level boundary: report the cause instead of a bare `except:`, which
    # would also swallow KeyboardInterrupt/SystemExit.
    print("! Failed to load state dict!")
    print("! Make sure model is of correct size and path is correct!")
    print(f"! Underlying error: {err!r}")
    # sys.exit(1) signals failure; the site builtin exit() returned status 0.
    sys.exit(1)

with torch.no_grad():
    sample = torch.zeros(NB_SAMPLES*NB_CLASSES, *IMAGE_DIM).to(device)
    cond = torch.tensor([d for d in range(NB_CLASSES) for _ in range(NB_SAMPLES)]).to(device)

    pb = tqdm(total=IMAGE_DIM[0]*IMAGE_DIM[1]*IMAGE_DIM[2])

    for c in range(IMAGE_DIM[0]):
        for i in range(IMAGE_DIM[1]):