import gzip
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from einops import rearrange
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
from wandb.sdk.lib import RunDisabled  # import path may vary across wandb versions

# Project-local symbols; the paths below are assumptions about this repo's
# layout -- uncomment and adjust them to wherever these actually live.
# from models import VqVae
# from train import TrainVqVae
# from utils import (ImagesDataset, NumpyDataset, NormalizeInverse, Rescale,
#                    get_model_size, load_checkpoint, save_images, save_images2,
#                    save_video, train_visualize, video_pipe)


def reconstruct_images(checkpoint_path, data_args, model_args):
    # ImageNet channel statistics, matching the normalization used in training.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    training_data = ImagesDataset(
        root_dir=data_args['root_dir'],
        transform=transforms.Compose([transforms.ToTensor(), normalize]))
    training_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=data_args['num_workers'])

    checkpoint = load_checkpoint(checkpoint_path, device_id=0)
    model = VqVae(**model_args).to('cuda')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    data = next(iter(training_loader))
    data = data.to('cuda')
    with torch.no_grad():  # inference only; skip building the autograd graph
        _, data_recon, _ = model(data)
        recon_error = F.mse_loss(data_recon, data)
    print('reconstruction error: %6.2f' % recon_error.item())

    # Undo the ImageNet normalization before writing the image grids to disk.
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
    data_recon_unnormalized = unnormalize(data_recon)
    data_orig_unnormalized = unnormalize(data)
    save_images2(make_grid(data_recon_unnormalized.cpu().data), 'recon')
    save_images2(make_grid(data_orig_unnormalized.cpu().data), 'orig')
def reconstruct(checkpoint_path, batch_size, model_args, video_file,
                max_seq_length, resolution=256):
    checkpoint = load_checkpoint(checkpoint_path, device_id=0)
    model = VqVae(**model_args).to('cuda')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])

    video_in = cv2.VideoCapture(video_file)
    fps = video_in.get(cv2.CAP_PROP_FPS)
    video_out = cv2.VideoWriter('reconst.mp4', cv2.VideoWriter_fourcc(*'FMP4'),
                                fps, (resolution, resolution))

    # Read frames, grouping them into sequences of max_seq_length, and collect
    # batch_size sequences before pushing a batch through the model.
    raw_images = []
    raw_seqs = []
    i = 1
    while video_in.isOpened():
        ret, frame = video_in.read()
        if ret:
            raw_seqs.append(frame)
        else:
            break
        if len(raw_seqs) == max_seq_length:
            raw_images.append(np.stack(raw_seqs, 0))
            raw_seqs = []
        if len(raw_images) == batch_size:
            print('batch %d' % i)
            i += 1
            raw_images = np.stack(raw_images, 0)
            images_recon = save_video(model, raw_images, normalize, unnormalize)
            for seqs in images_recon:
                for frame in seqs:
                    video_out.write(frame)
            raw_images = []

    # Flush any remaining full sequences. Note that a trailing partial
    # sequence (fewer than max_seq_length frames) is dropped.
    if len(raw_images) > 0:
        print('batch %d' % i)
        i += 1
        raw_images = np.stack(raw_images, 0)
        images_recon = save_video(model, raw_images, normalize, unnormalize)
        for seqs in images_recon:
            for frame in seqs:
                video_out.write(frame)

    video_out.release()
    video_in.release()
    cv2.destroyAllWindows()
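# Example invocation (paths and values are hypothetical, shown only to make
# the calling convention concrete; model_args must match the checkpoint):
#
#   reconstruct('checkpoints/video_vqvae.pt', batch_size=4,
#               model_args=model_args, video_file='clips/demo.mp4',
#               max_seq_length=16, resolution=256)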
def train_images():
    from train.images.image_utils import params

    data_args = params['data_args']
    train_args = params['train_args']
    model_args = params['model_args']

    if params['use_wandb']:
        wandb.login(key=os.environ['wanda_api_key'])
        run_wandb = wandb.init(project='dalle_train_vae',
                               job_type='train_model',
                               config=params,
                               resume=train_args['checkpoint_path'] is not None)
    else:
        run_wandb = RunDisabled()

    model = VqVae(**model_args).to('cuda')
    print('num of trainable parameters: %d' % get_model_size(model))
    print(model)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])
    training_data = ImageFolder(
        data_args['root_dir'],
        transforms.Compose([
            transforms.RandomResizedCrop(256),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]))
    training_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=data_args['num_workers'])

    train_object = TrainVqVae(model=model, training_loader=training_loader,
                              run_wandb=run_wandb, unnormalize=unnormalize,
                              **train_args)
    try:
        train_object.train()
    finally:
        run_wandb.finish()
def train_codes():
    from train.codes.code_utils import params

    data_args = params['data_args']
    train_args = params['train_args']
    model_args = params['model_args']

    if params['use_wandb']:
        wandb.login(key=os.environ['wanda_api_key'])
        run_wandb = wandb.init(project='dalle_train_vae',
                               job_type='train_model',
                               config=params,
                               resume=train_args['checkpoint_path'] is not None)
    else:
        run_wandb = RunDisabled()

    model = VqVae(**model_args).to('cuda')
    print('num of trainable parameters: %d' % get_model_size(model))
    print(model)

    # Mean and std of the codes were computed on the ImageNet dataset.
    normalize = transforms.Normalize(mean=[0.1635], std=[0.1713])
    unnormalize = NormalizeInverse(mean=[0.1635], std=[0.1713])
    training_data = NumpyDataset(
        data_args['root_dir'],
        data_args['max_seq_length'],
        data_args['padding_file'],
        transforms.Compose([Rescale(code_size=8192.), normalize]))
    training_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=data_args['num_workers'])

    train_object = TrainVqVae(model=model, training_loader=training_loader,
                              run_wandb=run_wandb, unnormalize=unnormalize,
                              **train_args)
    try:
        train_object.train()
    finally:
        run_wandb.finish()
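# Rescale is assumed to map integer codebook indices in [0, code_size) to
# floats in [0, 1] so that the Normalize above (mean=0.1635, std=0.1713) can
# treat them like pixel intensities; reconstruct_codes() below inverts this
# with unnormalize(...) * 8192. A minimal stand-in, illustrative only (the
# repo's actual Rescale transform may differ):
def _rescale_sketch(codes, code_size=8192.):
    # integer tensor of codebook indices -> float tensor in [0, 1]
    return codes.float() / code_size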
def train_videos():
    from video_utils import params

    data_args = params['data_args']
    train_args = params['train_args']
    model_args = params['model_args']

    if params['use_wandb']:
        wandb.login(key=os.environ['wanda_api_key'])
        run_wandb = wandb.init(project='dalle_train_vae',
                               job_type='train_model',
                               config=params,
                               resume=train_args['checkpoint_path'] is not None)
    else:
        run_wandb = RunDisabled()

    model = VqVae(**model_args).to('cuda')
    print('num of trainable parameters: %d' % get_model_size(model))
    print(model)

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])

    # Videos are decoded with an NVIDIA DALI pipeline rather than a torch
    # DataLoader, so normalization is passed down to the training loop.
    training_loader = video_pipe(batch_size=data_args['batch_size'],
                                 num_threads=data_args['num_threads'],
                                 device_id=data_args['device_id'],
                                 filenames=data_args['training_data_files'],
                                 seed=data_args['seed'])
    training_loader.build()

    train_object = TrainVqVae(model=model, training_loader=training_loader,
                              run_wandb=run_wandb, normalize=normalize,
                              unnormalize=unnormalize, **train_args)
    try:
        train_object.train()
    finally:
        run_wandb.finish()
def reconstruct_videos(checkpoint_path, data_args, model_args):
    # Renamed from reconstruct() to match reconstruct_images()/
    # reconstruct_codes() and to avoid shadowing the file-based reconstruct()
    # above.
    training_loader = video_pipe(batch_size=data_args['batch_size'],
                                 num_threads=data_args['num_threads'],
                                 device_id=data_args['device_id'],
                                 filenames=data_args['training_data_files'],
                                 seed=data_args['seed'])
    training_loader.build()
    training_loader = DALIGenericIterator(training_loader, ['data'])

    checkpoint = load_checkpoint(checkpoint_path, device_id=0)
    model = VqVae(**model_args).to('cuda')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    unnormalize = NormalizeInverse(mean=[0.485, 0.456, 0.406],
                                   std=[0.229, 0.224, 0.225])

    # DALI yields uint8 frames as (batch, depth, height, width, channel);
    # fold the temporal axis into the channel axis before the forward pass.
    images = next(training_loader)[0]['data']
    b, d, _, _, c = images.size()
    images = rearrange(images, 'b d h w c -> (b d) c h w')
    images = normalize(images.float() / 255.)
    images = rearrange(images, '(b d) c h w -> b (d c) h w', b=b, d=d, c=c)

    vq_loss, images_recon, _ = model(images)
    print('vq loss: %6.2f' % vq_loss.item())

    images, images_recon = map(
        lambda t: rearrange(t, 'b (d c) h w -> (b d) c h w', b=b, d=d, c=c),
        [images, images_recon])
    images_orig, images_recs = train_visualize(unnormalize=unnormalize,
                                               images=images, n_images=b * d,
                                               image_recs=images_recon)
    save_images(file_name='images_orig.png', image=images_orig)
    save_images(file_name='images_recon.png', image=images_recs)
def reconstruct_codes(checkpoint_path, data_args, model_args, np_folder):
    # The codes' mean/std match the normalization used in train_codes().
    normalize = transforms.Normalize(mean=[0.1635], std=[0.1713])
    unnormalize = NormalizeInverse(mean=[0.1635], std=[0.1713])
    training_data = NumpyDataset(
        data_args['root_dir'],
        data_args['max_seq_length'],
        data_args['padding_file'],
        transforms.Compose([Rescale(code_size=8192.), normalize]))
    training_loader = torch.utils.data.DataLoader(
        training_data,
        batch_size=data_args['batch_size'],
        shuffle=True,
        num_workers=data_args['num_workers'])

    checkpoint = load_checkpoint(checkpoint_path, device_id=0)
    model = VqVae(**model_args).to('cuda')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    data = next(iter(training_loader))
    data = data.to('cuda')
    with torch.no_grad():  # inference only
        _, data_recon, _ = model(data)

    # Invert normalization and rescaling to recover integer codebook indices.
    data_recon = torch.clip(unnormalize(data_recon) * 8192, 0, 8192).int()
    data_recon = data_recon.to('cpu').numpy()
    data = torch.clip(unnormalize(data) * 8192, 0, 8192).int()
    data = data.to('cpu').numpy()

    # Make sure the output folders exist before writing.
    os.makedirs(os.path.join(np_folder, 'recon'), exist_ok=True)
    os.makedirs(os.path.join(np_folder, 'orig'), exist_ok=True)
    for i in range(data_args['batch_size']):
        p_recon = os.path.join(np_folder, 'recon', f'{i}.npy.gz')
        with gzip.GzipFile(p_recon, 'w') as f:
            np.save(file=f, arr=data_recon[i])
        p_orig = os.path.join(np_folder, 'orig', f'{i}.npy.gz')
        with gzip.GzipFile(p_orig, 'w') as f:
            np.save(file=f, arr=data[i])
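# Hypothetical entry point, sketched only to show how the routines above might
# be driven; the repo may wire them up differently (e.g. via a CLI parser).
if __name__ == '__main__':
    train_images()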