Пример #1
0
parser.add_argument("--loadpth",
                    type=str,
                    default='./results/vqvae_data_bo.pth')
parser.add_argument("--data_dir",
                    type=str,
                    default='/home/karam/Downloads/bco/')
# Only these datasets are handled below; any other value used to leave `data`
# undefined and crash later with a NameError, so reject it at parse time.
parser.add_argument("--data", type=str, default='bco', choices=['bco', 'bo'])
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model
model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta).to(device)
# Compare by value, not identity (`is not ''` is implementation-dependent and a
# SyntaxWarning), and validate explicitly: `assert` is stripped under `python -O`.
if not args.loadpth:
    parser.error("--loadpth must point to a saved VQVAE checkpoint")
# The checkpoint is a dict whose 'model' entry holds the state dict.
# map_location lets a checkpoint saved on GPU load when `device` is the CPU.
checkpoint = torch.load(args.loadpth, map_location=device)
model.load_state_dict(checkpoint['model'])
model.eval()
print("Loaded model")

# Load data
save_dir = os.path.join(os.getcwd(), 'data')
data_dir = args.data_dir
if args.data == 'bco':
    # The bco dataset is split across four .npy shards; stack them along axis 0.
    shards = [np.load(os.path.join(data_dir, "bcov5_{}.npy".format(i)))
              for i in range(4)]
    data = np.concatenate(shards, axis=0)
elif args.data == 'bo':
    # NOTE(review): allow_pickle=True can execute arbitrary code from the file;
    # only load trusted data here.
    data = np.load(os.path.join(data_dir, "bo-150-50-20.npy"), allow_pickle=True)
else:
    # Unreachable while `choices` is set on --data; kept as a defensive guard.
    parser.error("unsupported --data value: {!r}".format(args.data))

# The first two axes are read as (number of trajectories, trajectory length).
n_trajs, length = data.shape[:2]
Пример #2
0
# The first two axes of `data` are read as (number of trajectories, length).
n_trajs, length = data.shape[:2]
img_dim = args.img_dim

# Gated PixelCNN over the VQ codebook indices; imgximg is the number of latent
# positions per image (img_dim squared).
model = GatedPixelCNN(n_embeddings=args.n_embeddings, imgximg=args.img_dim**2,
                      n_layers=args.n_layers, conditional=args.conditional,
                      x_one_hot=args.x_one_hot, c_one_hot=args.c_one_hot,
                      n_cond_res_block=args.n_cres_layers).to(device)
model.train()
# Follow `device` like the model does; a hard-coded .cuda() crashes on
# CPU-only machines.
criterion = nn.CrossEntropyLoss().to(device)
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Optionally load a pretrained VQVAE (checkpoint dict with a 'model' entry).
# Truthiness check replaces `is not ''`, which compared identity, not value.
if args.loadpth_vq:
    vae = VQVAE(args.n_hiddens, args.n_residual_hiddens,
                args.n_residual_layers, args.n_embeddings, args.embedding_dim,
                args.beta).to(device)
    vae.load_state_dict(
        torch.load(args.loadpth_vq, map_location=device)['model'])
    print("VQ Loaded")
    vae.eval()
    if args.data == 'bco':
        # Encode the conditioning images to latent indices; no autograd graph
        # is needed since the result is detached and moved to NumPy.
        # NOTE(review): sample_c_imgs is defined outside this chunk — assumed
        # to already live on `device`.
        with torch.no_grad():
            latents = vae(sample_c_imgs, latent_only=True)
        sample_c = latents.detach().cpu().numpy().reshape(-1, length).squeeze()

# Optionally resume the PixelCNN from a saved state dict.
if args.loadpth_pcnn:
    model.load_state_dict(torch.load(args.loadpth_pcnn, map_location=device))
    print("PCNN Loaded")

n_trajs = len(data)
# Trajectories per row of `context` (presumably trajectories per conditioning
# sample — confirm against where `context` is built).
dt = n_trajs // context.shape[0]
# Number of full batches over `data`; a trailing partial batch is dropped.
n_batch = n_trajs // args.batch_size
n_trajs_t = len(val)
# Same quantities for the validation split (`val` / `valcon`).
dv = n_trajs_t // valcon.shape[0]
n_batch_t = n_trajs_t // args.batch_size