def load_vqvae(model_path: Union[str, Path], device: torch.device = None):
    model_path = Path(model_path)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.cuda.is_available():
        data = torch.load(model_path)
    else:
        data = torch.load(model_path, map_location=lambda storage, loc: storage)

    params = data["hyperparameters"]
    if 'channels' in params:
        channels = params['channels']
    else:
        channels = 1 if params['dataset'] == 'MNIST' else 3

    model = VQVAE(channels, params['n_hiddens'], params['n_residual_hiddens'],
                  params['n_residual_layers'], params['n_embeddings'],
                  params['embedding_dim'], params['beta']).to(device)
    model.load_state_dict(data['model'])
    return model, data
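# A minimal usage sketch for load_vqvae. The checkpoint path is a hypothetical
# placeholder; the 'hyperparameters' entry is guaranteed by the helper itself.
model, checkpoint = load_vqvae('./results/vqvae_data.pth')
model.eval()
print('Loaded VQ-VAE with hyperparameters:', checkpoint['hyperparameters'])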
def main(unused_args):
    if args.logdir is None:
        raise ValueError('Please specify the dir to the checkpoint')
    arch = tf.gfile.Glob(join(args.logdir, 'arch*.json'))[0]
    arch = json2dict(arch)
    net = VQVAE(arch)
    data = ByteWavWholeReader(
        speaker_list=txt2list(args.speaker_list),
        filenames=tf.gfile.Glob(args.file_pattern))
    ZH = net.encode(data.x, args.mode)

    # Restore the exponential-moving-average shadow variables.
    ema = tf.train.ExponentialMovingAverage(decay=0.995)
    trg_vars = {ema.average_name(v): v for v in tf.trainable_variables()}
    saver = tf.train.Saver(trg_vars)

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(data.iterator.initializer)
        sess.run(tf.global_variables_initializer())
        load(saver, sess, args.logdir, ckpt=args.ckpt)

        # Count how often each codebook entry (exemplar) is selected.
        hist = np.zeros([arch['num_exemplar']], dtype=np.int64)
        counter = 1
        while True:
            try:
                z_ids = sess.run(ZH)
                print('\rNum of processed files: {:d}'.format(counter), end='')
                counter += 1
                for i in z_ids[0]:  # batch size is 1
                    hist[i] += 1
            except tf.errors.OutOfRangeError:
                print()
                break

    with open('histogram.npf', 'wb') as fp:
        hist.tofile(fp)

    plt.figure(figsize=[10, 2])
    plt.plot(np.log10(hist + 1), '.')
    plt.xlim([0, arch['num_exemplar'] - 1])
    plt.ylabel('log-frequency')
    plt.xlabel('exemplar index')
    plt.savefig('histogram.png')
    plt.close()
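# Sketch for reading the dumped histogram back and checking codebook usage;
# the dtype must match the int64 array written by hist.tofile above.
hist = np.fromfile('histogram.npf', dtype=np.int64)
print('{}/{} exemplars were ever selected'.format(int((hist > 0).sum()), hist.size))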
def get_model(architecture, num_embeddings, device, neighborhood, selection_fn,
              embed_dim, parallel=True, **kwargs):
    """
    Creates a VQVAE object.

    :param architecture: Has to be 'vqvae'.
    :param num_embeddings: Int. Number of dictionary atoms.
    :param device: String. 'cpu', 'cuda' or 'cuda:device_number'.
    :param neighborhood: Int. Not used.
    :param selection_fn: String. 'vanilla' or 'fista'.
    :param embed_dim: Int. Size of the latent space.
    :param parallel: Bool. Whether to wrap the model in DataParallel.
    :return: VQVAE model, or DataParallel(VQVAE model) if parallel is set.
    """
    if architecture == 'vqvae':
        model = VQVAE(n_embed=num_embeddings,
                      neighborhood=neighborhood,
                      selection_fn=selection_fn,
                      embed_dim=embed_dim,
                      **kwargs).to(device)
    else:
        raise ValueError(
            "The only valid architecture is 'vqvae'. Got: {}".format(architecture))
    if parallel and device != 'cpu':
        model = nn.DataParallel(model)
    return model
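# A minimal usage sketch for get_model; every argument value below is an
# illustrative assumption rather than a documented default.
model = get_model(architecture='vqvae',
                  num_embeddings=512,
                  device='cuda',
                  neighborhood=1,          # accepted but unused by the model
                  selection_fn='vanilla',
                  embed_dim=64,
                  parallel=False)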
def __init__(self, args, config):
    super(BrainQA, self).__init__(config)
    self.num_labels = config.num_labels

    # Set up BERT encoder
    self.config_enc = config.to_dict()
    self.config_enc['output_hidden_states'] = True
    self.config_enc = BertConfig.from_dict(self.config_enc)
    self.bert_enc = BertModel.from_pretrained(args.model_name_or_path,
                                              config=self.config_enc)

    # Set up BERT decoder
    self.config_dec = config.to_dict()
    self.config_dec['is_decoder'] = True
    self.config_dec = BertConfig.from_dict(self.config_dec)
    self.bert_dec = BertModel.from_pretrained(args.model_name_or_path,
                                              config=self.config_dec)

    # VQVAE for external memory
    self.vqvae_model = VQVAE(h_dim=256,
                             res_h_dim=256,
                             n_res_layers=4,
                             n_embeddings=args.n_vqvae_embeddings,
                             embedding_dim=256,
                             restart=args.vqvae_random_restart,
                             beta=2)

    # Question answer layer to output spans of question answers
    self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    self.init_weights()
def decode(model: VQVAE, code, plot_path: str = None):
    # Look up the codebook vectors for the given indices and restore the
    # (batch, channels, height, width) layout expected by the decoder.
    emb = model.vector_quantization.embedding(code.squeeze(1)).permute(0, 3, 1, 2)
    hx = model.decoder(emb)
    display_image_grid(hx)
    if plot_path:
        plt.savefig(plot_path)
    else:
        plt.show()
def uniform_sample(model: VQVAE, num_samples: int, device, plot_path: str = None):
    # Run a dummy batch through the encoder just to obtain the latent code shape.
    code_shape = model.encode(
        torch.zeros((num_samples, 3, 32, 32), device=device)).shape
    print('Latent code shape:', code_shape)
    if not plot_path:
        plt.title('Uniform sample')
    # Sample codebook indices uniformly at random and decode them.
    code = torch.randint(0, model.vector_quantization.embedding.num_embeddings,
                         code_shape, device=device)
    decode(model, code, plot_path)
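# A minimal usage sketch for uniform_sample; the checkpoint path is a
# hypothetical placeholder, and it assumes the load_vqvae helper above targets
# the same VQVAE implementation (one exposing .encode and .vector_quantization).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, _ = load_vqvae('./results/vqvae_cifar10.pth', device)
model.eval()
with torch.no_grad():
    uniform_sample(model, num_samples=16, device=device, plot_path='uniform_samples.png')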
def main(_): """Train the model based on the command-line arguments.""" # Parse command-line arguments speaker_list = txt2list(args.speaker_list) dirs = validate_log_dirs(args) arch = json2dict(args.arch) arch.update(dirs) arch.update({'ckpt': args.ckpt}) copy_arch_file(args.arch, arch['logdir']) # Initialize the model net = VQVAE(arch) P = net.n_padding() print('Receptive field: {} samples ({:.2f} sec)'.format(P, P / arch['fs'])) # Read the input data as specified by the command line arguments data = ByteWavReader(speaker_list, args.file_pattern, T=arch['T'], batch_size=arch['training']['batch_size'], buffer_size=5000) # Train the model on the input data net.train(data)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')

"""
Load data and define batch data loaders
"""
training_data, validation_data, training_loader, validation_loader, x_train_var = \
    utils.load_data_and_data_loaders(args.dataset, args.batch_size)

"""
Set up VQ-VAE model with components defined in ./models/ folder
"""
model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta).to(device)

"""
Set up optimizer and training loop
"""
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)

model.train()

results = {
    'n_updates': 0,
    'recon_errors': [],
    'loss_vals': [],
    'perplexities': [],
}
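# A minimal sketch of the update loop that the setup above feeds, assuming the
# model's forward pass returns (embedding_loss, x_hat, perplexity) as the
# results keys suggest; `args.n_updates` is an assumed argument.
for i in range(args.n_updates):
    (x, _) = next(iter(training_loader))
    x = x.to(device)
    optimizer.zero_grad()

    embedding_loss, x_hat, perplexity = model(x)
    # Normalize the reconstruction error by the training-set variance.
    recon_loss = torch.mean((x_hat - x) ** 2) / x_train_var
    loss = recon_loss + embedding_loss

    loss.backward()
    optimizer.step()

    results['recon_errors'].append(recon_loss.item())
    results['perplexities'].append(perplexity.item())
    results['loss_vals'].append(loss.item())
    results['n_updates'] = i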
def main(unused_args):
    if args.logdir is None:
        raise ValueError('Please specify the dir to the checkpoint')
    speaker_list = txt2list(args.speaker_list)
    arch = tf.gfile.Glob(os.path.join(args.logdir, 'arch*.json'))[0]
    arch = json2dict(arch)
    net = VQVAE(arch)

    # These files start roughly at the same position but end very differently
    # (the third is the longest).
    filenames = [
        'dataset/VCTK/tfr/p227/p227_363.tfr',
        # 'dataset/VCTK/tfr/p240/p240_341.tfr',
        # 'dataset/VCTK/tfr/p243/p243_359.tfr',
        'dataset/VCTK/tfr/p225/p225_001.tfr'
    ]
    data = ByteWavWholeReader(speaker_list, filenames)

    X = tf.placeholder(dtype=tf.int64, shape=[None, None])
    Y = tf.placeholder(dtype=tf.int64, shape=[None])
    ZH = net.encode(X, args.mode)
    XH = net.generate(X, ZH, Y)
    # XWAV = mu_law_decode(X)
    # XBIN = tf.contrib.ffmpeg.encode_audio(XWAV, 'wav', arch['fs'])

    # Restore the exponential-moving-average shadow variables.
    ema = tf.train.ExponentialMovingAverage(decay=0.995)
    trg_vars = {ema.average_name(v): v for v in tf.trainable_variables()}
    saver = tf.train.Saver(trg_vars)

    logdir = get_default_logdir(args.logdir)
    tf.gfile.MkDir(logdir)

    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(data.iterator.initializer)

        results = []
        for _ in filenames:
            result = sess.run({'x': data.x, 'y': data.y})
            results.append(result)

        length_input = net.n_padding() + 1
        ini = 15149 - length_input
        end = 42285

        # Pad every utterance to the same length, then crop to [ini, end).
        for i in range(len(results)):
            x = results[i]['x']
            if x.shape[-1] < end:
                x = np.concatenate(
                    [x, x[0, 0] + np.zeros([1, end - x.shape[-1]])], -1)
            results[i]['x'] = x[:, ini:end]

        # Pair each source utterance with its own and the other's speaker id,
        # so two of the four outputs are voice conversions.
        x_source = np.concatenate([
            results[0]['x'], results[0]['x'], results[1]['x'], results[1]['x']
        ], 0)
        B = x_source.shape[0]
        y_input = np.concatenate([
            results[0]['y'], results[1]['y'], results[1]['y'], results[0]['y']
        ], 0)
        length_target = x_source.shape[1] - length_input

        while True:
            sess.run(tf.global_variables_initializer())
            load(saver, sess, args.logdir, ckpt=args.ckpt)
            z_blend = sess.run(ZH, feed_dict={X: x_source})
            # Copy: the sliding-window update below would otherwise write
            # through the slice view back into x_source.
            x_input = x_source[:, :length_input].copy()
            z_input = z_blend[:, :length_input, :]

            # Generate sample by sample.
            try:
                x_gen = np.zeros([B, length_target], dtype=np.int64)
                for i in range(length_target):
                    xh = sess.run(XH, feed_dict={X: x_input, ZH: z_input, Y: y_input})
                    z_input = z_blend[:, i + 1:i + 1 + length_input, :]
                    x_input[:, :-1] = x_input[:, 1:]
                    x_input[:, -1] = xh[:, -1]
                    x_gen[:, i] = xh[:, -1]
                    print('\rGenerating {:5d}/{:5d}... x={:3d}'.format(
                        i + 1, length_target, xh[0, -1]), end='', flush=True)
            except KeyboardInterrupt:
                print("Interrupted by the user.")
            finally:
                print()

            x_wav = mu_law_decode(x_gen)
            for i in range(x_wav.shape[0]):
                x_1ch = np.expand_dims(x_wav[i], -1)
                librosa.output.write_wav('testwav-{}.wav'.format(i), x_1ch, arch['fs'])
                # with open(os.path.join(logdir, 'testwav-{}.wav'.format(i)), 'wb') as fp:
                #     fp.write(x_bin)

            # For periodic gen.
            if args.period > 0:
                try:
                    print('Sleep for a while')
                    sleep(args.period * 60)
                    logdir = get_default_logdir(args.logdir)
                    tf.gfile.MkDir(logdir)
                except KeyboardInterrupt:
                    print('Stop periodic gen.')
                    break
                finally:
                    print('all finished')
            else:
                break
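# Reference sketch of the mu-law expansion that mu_law_decode above is assumed
# to implement (mu = 255, 8-bit codes); included for clarity, not taken from
# the original code.
def mu_law_decode_ref(codes, mu=255):
    # Map integer codes [0, mu] onto [-1, 1], then invert the log companding:
    # F^{-1}(y) = sign(y) * ((1 + mu)**|y| - 1) / mu
    y = 2.0 * codes.astype(np.float64) / mu - 1.0
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu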
def main():
    args = model_args.get_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup CUDA, GPU
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning("Device: %s, n_gpu: %s", device, args.n_gpu)

    # Set seed
    set_seed(args)

    # Set up model with huggingface pre-trained config
    args.model_type = args.model_type.lower()
    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None,
        early_stopping=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = BrainQA(args=args, config=config)
    if args.train_vqvae_instead:
        model = VQVAE(h_dim=256, res_h_dim=256, n_res_layers=4,
                      n_embeddings=4096, embedding_dim=256, beta=2)
    model.to(args.device)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
        if args.train_vqvae_instead:
            global_step, tr_loss = train_vqvae(args, train_dataset, model, model, tokenizer)
        else:
            global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        if args.train_vqvae_instead:
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'vqvae_model.bin'))
        else:
            model_to_save = model.module if hasattr(model, "module") else model
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)

        # Save training args as well
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval:
        logger.info("Loading checkpoints saved during training for evaluation")
        checkpoints = model_args.get_checkpoints(args)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            logger.info('Evaluating checkpoint: {}'.format(checkpoint))
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            state_dict = torch.load(checkpoint + '/pytorch_model.bin')
            model.load_state_dict(state_dict)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)
            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))

    if args.run_visualizations:
        model = BrainQA(args=args, config=config)
        checkpoints = model_args.get_checkpoints(args)
        # Use the latest checkpoint (the original indexed the loop variable
        # `checkpoint[-1]`, which is a single character).
        checkpoint = checkpoints[-1]
        # No leading slash here: os.path.join discards everything before an
        # absolute component.
        path_to_dict = os.path.join(args.output_dir, checkpoint, 'pytorch_model.bin')
        if not os.path.exists(path_to_dict):
            raise FileNotFoundError(
                path_to_dict + ' not found. Please make sure you have passed the '
                'correct output directory, and that it contains a fully trained '
                'model checkpoint. Visualizations are not currently in place for '
                'VQ-VAE alone.')
        state_dict = torch.load(path_to_dict)
        model.load_state_dict(state_dict)
        model.to(args.device)
        eval_dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
        visualizations.visualize(model, eval_dataset, tokenizer, args, latent_vis=True)
        logger.info('Visualizations constructed. Please check /images directory.')

    return results
parser.add_argument("--n_residual_layers", type=int, default=2) parser.add_argument("--embedding_dim", type=int, default=64) parser.add_argument("--n_embeddings", type=int, default=512) parser.add_argument("--beta", type=float, default=.25) parser.add_argument("--loadpth", type=str, default='./results/vqvae_data_bo.pth') parser.add_argument("--data_dir", type=str, default='/home/karam/Downloads/bco/') parser.add_argument("--data", type=str, default='bco') args = parser.parse_args() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") #Load model model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers, args.n_embeddings, args.embedding_dim, args.beta).to(device) assert args.loadpth is not '' model.load_state_dict(torch.load(args.loadpth)['model']) model.eval() print("Loaded model") #Load data save_dir = os.getcwd() + '/data' data_dir = args.data_dir if args.data == 'bco': data1 = np.load(data_dir + "/bcov5_0.npy") data2 = np.load(data_dir + "/bcov5_1.npy") data3 = np.load(data_dir + "/bcov5_2.npy") data4 = np.load(data_dir + "/bcov5_3.npy") data = np.concatenate((data1, data2, data3, data4), axis=0) elif args.data == 'bo':
def main(args):
    writer = SummaryWriter(args.experiment_log_path)
    writer.add_hparams(vars(args), {})

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose(
        [transforms.Resize((32, 32), 3),
         transforms.ToTensor()])
    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('data', train=True, download=True, transform=transform)
        test_dataset = datasets.CIFAR10('data', train=False, download=True, transform=transform)
        args.in_channels = 3
    elif args.dataset == 'mnist':
        train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
        test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
        args.in_channels = 1
    else:
        raise ValueError(f"Invalid dataset: {args.dataset}")

    train_dataloader = DataLoader(train_dataset, args.batch_size, shuffle=True,
                                  pin_memory=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, args.batch_size // 4,
                                 pin_memory=True, num_workers=4)

    # The VQ-VAE is frozen; only the PixelCNN prior is trained.
    vqvae = VQVAE(args.in_channels, args.hidden_channels_vqvae,
                  args.num_embeddings, args.embedding_dim)
    vqvae.load_state_dict(
        torch.load(args.vqvae_state_dict, map_location=torch.device('cpu')))
    vqvae = vqvae.to(device)

    prior = PixelCNN(args.num_embeddings, args.hidden_channels_prior,
                     args.num_layers, args.num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(prior.parameters(), args.lr)

    # Initialize Loggers
    train_metric_logger = MeterLogger(("nll",), writer)
    val_metric_logger = MeterLogger(("nll",), writer)

    print(vqvae)

    for epoch in tqdm(range(args.num_epoch)):
        train_metric_logger.reset()
        prior.train()
        for train_batch in tqdm(train_dataloader):
            images, labels = train_batch
            images = images.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                # TODO repack into one call
                latents = vqvae.encoder(images)
                latents = vqvae.prenet(latents)
                latents = vqvae.vector_quantizer.proposal_distribution(latents)
                latents = latents.unsqueeze(1)

            logits = prior(latents, labels)
            loss = criterion(logits, latents.squeeze())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_logger.update('nll', loss.item(), train_dataloader.batch_size)

        # Save train metrics
        train_metric_logger.write(epoch, 'train')

        val_metric_logger.reset()
        prior.eval()
        for test_batch in tqdm(test_dataloader):
            images, labels = test_batch
            images = images.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                latents = vqvae.encoder(images)
                latents = vqvae.prenet(latents)
                latents = vqvae.vector_quantizer.proposal_distribution(latents)
                latents = latents.unsqueeze(1)
                logits = prior(latents, labels)
                loss = criterion(logits, latents.squeeze())

            val_metric_logger.update('nll', loss.item(), test_dataloader.batch_size)

        # Save val metrics
        val_metric_logger.write(epoch, 'val')

        # Generate class-conditional samples for TensorBoard.
        resolution = 8 if args.dataset == 'cifar10' else 7
        condition = torch.arange(8).repeat(8)
        generated_prior = prior.generate(condition.to(device), resolution).squeeze()
        quantized_prior = vqvae.vector_quantizer.embeddings(generated_prior).permute(0, 3, 1, 2)
        generated = vqvae.decoder(vqvae.postnet(quantized_prior))
        writer.add_images('generated', generated, epoch)

        # Save checkpoint (assumes the model path, like the log path, comes in
        # through args; the original referenced an undefined name here).
        checkpoint_path = pathlib.Path(args.experiment_model_path) / f"{epoch}.pth"
        torch.save(prior.state_dict(), checkpoint_path)
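# A minimal argparse harness for main(); the flag set mirrors the attributes
# read from `args` above, but every default value is an illustrative assumption.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', choices=['cifar10', 'mnist'], default='cifar10')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--hidden_channels_vqvae', type=int, default=128)
    parser.add_argument('--num_embeddings', type=int, default=512)
    parser.add_argument('--embedding_dim', type=int, default=64)
    parser.add_argument('--hidden_channels_prior', type=int, default=128)
    parser.add_argument('--num_layers', type=int, default=15)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--num_epoch', type=int, default=100)
    parser.add_argument('--vqvae_state_dict', required=True)
    parser.add_argument('--experiment_log_path', default='runs/prior')
    parser.add_argument('--experiment_model_path', default='checkpoints/prior')
    main(parser.parse_args())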
data, val = data.reshape(-1, 256), val.reshape(-1, 256)
context = np.load("./data/%s_clatents.npy" % args.data).squeeze()
context, valcon = context[split:], context[:split]
# Reshape both splits (the original reshaped `context` twice, discarding valcon).
context, valcon = context.reshape(-1, 256), valcon.reshape(-1, 256)
n_trajs, length = data.shape[:2]
img_dim = args.img_dim

model = GatedPixelCNN(n_embeddings=args.n_embeddings,
                      imgximg=args.img_dim**2,
                      n_layers=args.n_layers,
                      conditional=args.conditional,
                      x_one_hot=args.x_one_hot,
                      c_one_hot=args.c_one_hot,
                      n_cond_res_block=args.n_cres_layers).to(device)
model.train()
criterion = nn.CrossEntropyLoss().cuda()
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.loadpth_vq != '':
    vae = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
                args.n_embeddings, args.embedding_dim, args.beta).cuda()
    vae.load_state_dict(torch.load(args.loadpth_vq)['model'])
    print("VQ Loaded")
    vae.eval()
    if args.data == 'bco':
        sample_c = vae(sample_c_imgs, latent_only=True).detach().cpu().numpy().reshape(-1, length).squeeze()

if args.loadpth_pcnn != '':
    model.load_state_dict(torch.load(args.loadpth_pcnn))
    print("PCNN Loaded")

n_trajs = len(data)
dt = n_trajs // context.shape[0]
n_batch = int(n_trajs / args.batch_size)
n_trajs_t = len(val)
import operator

import util.torchaudio_transforms as transforms
from experiment_builders.vqvae_builder import VQVAEWORLDExperimentBuilder, VQVAERawExperimentBuilder
from models.vqvae import VQVAE
from models.common_networks import QuantisedInputModuleWrapper
from datasets.vcc_world_dataset import VCCWORLDDataset
from datasets.vcc_raw_dataset import VCCRawDataset
from datasets.vctk_dataset import VCTKDataset
from util.samplers import ChunkEfficientRandomSampler

torch.manual_seed(seed=args.seed)

vqvae_model = VQVAE(input_shape=(1, 1, args.input_len),
                    encoder_arch=args.encoder,
                    vq_arch=args.vq,
                    generator_arch=args.generator,
                    num_speakers=args.num_speakers,
                    speaker_dim=args.speaker_dim,
                    use_gated_convolutions=args.use_gated_convolutions)

if args.dataset == 'VCCWORLD2016':
    print('VCC2016 dataset WORLD features.')
    dataset_path = args.dataset_root_path
    train_dataset = VCCWORLDDataset(root=dataset_path, scale=True)
    val_dataset = VCCWORLDDataset(root=dataset_path, scale=True, eval=True)

    # Create data loaders
    train_data = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])
dataset = ImageDataset(args.dataset, transform)
training_loader = DataLoader(dataset, batch_size=args.batch_size, drop_last=False)

"""
Set up VQ-VAE model with components defined in ./models/ folder
"""
model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta,
              args.n_dimension_changes).to(device)
checkpoint = torch.load(os.path.join(utils.SAVE_MODEL_PATH, args.model_path),
                        map_location=device)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()

encoder = E2EEncoder(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
                     args.embedding_dim, args.n_dimension_changes)
encoder.to(device)
encoder.train()

"""
Set up optimizer and training loop
"""