Code example #1
def load_vqvae(model_path: Union[str, Path], device: torch.device = None):
    model_path = Path(model_path)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if torch.cuda.is_available():
        data = torch.load(model_path)
    else:
        data = torch.load(model_path,
                          map_location=lambda storage, loc: storage)

    params = data["hyperparameters"]

    if 'channels' in params:
        channels = params['channels']
    else:
        channels = 1 if params['dataset'] == 'MNIST' else 3

    model = VQVAE(channels, params['n_hiddens'], params['n_residual_hiddens'],
                  params['n_residual_layers'], params['n_embeddings'],
                  params['embedding_dim'], params['beta']).to(device)

    model.load_state_dict(data['model'])

    return model, data
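
A minimal usage sketch for the loader above. The checkpoint path, the input shape, and the exact return signature of the model's forward pass are assumptions for illustration; only the 'hyperparameters'/'model' checkpoint layout comes from the example itself.

import torch
from pathlib import Path

model, data = load_vqvae(Path("results/vqvae_mnist.pth"))  # hypothetical checkpoint
model.eval()

device = next(model.parameters()).device
with torch.no_grad():
    x = torch.zeros(1, 1, 28, 28, device=device)  # assumed MNIST-sized input
    out = model(x)  # return values depend on this project's VQVAE.forward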
Code example #2
def main(unused_args):
    if args.logdir is None:
        raise ValueError('Please specify the dir to the checkpoint')

    arch = tf.gfile.Glob(join(args.logdir, 'arch*.json'))[0]
    arch = json2dict(arch)

    net = VQVAE(arch)

    data = ByteWavWholeReader(speaker_list=txt2list(args.speaker_list),
                              filenames=tf.gfile.Glob(args.file_pattern))

    ZH = net.encode(data.x, args.mode)

    ema = tf.train.ExponentialMovingAverage(decay=0.995)
    trg_vars = {ema.average_name(v): v for v in tf.trainable_variables()}
    saver = tf.train.Saver(trg_vars)

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(data.iterator.initializer)
        sess.run(tf.global_variables_initializer())
        load(saver, sess, args.logdir, ckpt=args.ckpt)

        hist = np.zeros([
            arch['num_exemplar'],
        ], dtype=np.int64)
        counter = 1
        while True:
            try:
                z_ids = sess.run(ZH)
                print('\rNum of processed files: {:d}'.format(counter), end='')
                counter += 1
                for i in z_ids[0]:  # bz = 1
                    hist[i] += 1
            except tf.errors.OutOfRangeError:
                print()
                break

        with open('histogram.npf', 'wb') as fp:
            hist.tofile(fp)

        plt.figure(figsize=[10, 2])
        plt.plot(np.log10(hist + 1), '.')
        plt.xlim([0, arch['num_exemplar'] - 1])
        plt.ylabel('log-frequency')
        plt.xlabel('exemplar index')
        plt.savefig('histogram.png')
        plt.close()
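
A small companion sketch: reading the raw histogram written above back into NumPy. The int64 dtype and the 'histogram.npf' file name follow the code; the usage statistic is just one way to summarize it.

import numpy as np

hist = np.fromfile('histogram.npf', dtype=np.int64)
usage = (hist > 0).mean()
print('codebook usage: {:.1%} of {} exemplars'.format(usage, hist.size))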
Code example #3
File: model_utils.py  Project: rlit/sparse-vqvae
def get_model(architecture,
              num_embeddings,
              device,
              neighborhood,
              selection_fn,
              embed_dim,
              parallel=True,
              **kwargs):
    """
    Creates a VQVAE object.

    :param architecture: Has to be 'vqvae'.
    :param num_embeddings: Int. Number of dictionary atoms
    :param device: String. 'cpu', 'cuda' or 'cuda:device_number'
    :param neighborhood: Int. Not used.
    :param selection_fn: String. 'vanilla' or 'fista'
    :param embed_dim: Int. Size of latent space.
    :param parallel: Bool. Use DataParallel or not.

    :return: VQVAE model or DataParallel(VQVAE model)
    """
    if architecture == 'vqvae':
        model = VQVAE(n_embed=num_embeddings,
                      neighborhood=neighborhood,
                      selection_fn=selection_fn,
                      embed_dim=embed_dim,
                      **kwargs).to(device)
    else:
        raise ValueError(
            'Valid architectures are vqvae. Got: {}'.format(architecture))

    if parallel and device != 'cpu':
        model = nn.DataParallel(model)

    return model
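
A hedged example of calling get_model with values consistent with the docstring; the concrete numbers are illustrative and not taken from the project.

model = get_model(architecture='vqvae',
                  num_embeddings=512,
                  device='cuda',
                  neighborhood=1,          # documented as unused
                  selection_fn='vanilla',
                  embed_dim=64,
                  parallel=False)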
Code example #4
File: brainqa.py  Project: yifr/brainqa
    def __init__(self, args, config):
        super(BrainQA, self).__init__(config)
        self.num_labels = config.num_labels

        # Set up BERT encoder
        self.config_enc = config.to_dict()
        self.config_enc['output_hidden_states'] = True
        self.config_enc = BertConfig.from_dict(self.config_enc)
        self.bert_enc = BertModel.from_pretrained(args.model_name_or_path, config=self.config_enc)

        # Set up BERT decoder
        self.config_dec = config.to_dict()
        self.config_dec['is_decoder'] = True
        self.config_dec = BertConfig.from_dict(self.config_dec)
        self.bert_dec = BertModel.from_pretrained(args.model_name_or_path, config=self.config_dec)

        # VQVAE for external memory
        self.vqvae_model = VQVAE(h_dim=256,
                                 res_h_dim=256,
                                 n_res_layers=4,
                                 n_embeddings=args.n_vqvae_embeddings,
                                 embedding_dim=256,
                                 restart=args.vqvae_random_restart,
                                 beta=2)

        # Question answer layer to output spans of question answers
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()
Code example #5
def main(_):
  speaker_list = txt2list(args.speaker_list)
  dirs = validate_log_dirs(args)
  arch = json2dict(args.arch)
  arch.update(dirs)
  arch.update({'ckpt': args.ckpt})
  copy_arch_file(args.arch, arch['logdir'])
  net = VQVAE(arch)
  P = net.n_padding()
  print('Receptive field: {} samples ({:.2f} sec)\n'.format(P, P / arch['fs']))
  data = ByteWavReader(
    speaker_list,
    args.file_pattern,
    T=arch['T'],
    batch_size=arch['training']['batch_size'],
    buffer_size=5000
  )
  net.train(data)
Code example #6
def decode(model: VQVAE, code, plot_path: str = None):
    emb = model.vector_quantization.embedding(code.squeeze(1)).permute(
        0, 3, 1, 2)
    hx = model.decoder(emb)

    display_image_grid(hx)
    if plot_path:
        plt.savefig(plot_path)
    else:
        plt.show()
Code example #7
def uniform_sample(model: VQVAE,
                   num_samples: int,
                   device,
                   plot_path: str = None):
    code_shape = model.encode(
        torch.zeros((num_samples, 3, 32, 32), device=device)).shape
    print('Latent code shape:', code_shape)
    if not plot_path:
        plt.title('Uniform sample')
    code = torch.randint(0,
                         model.vector_quantization.embedding.num_embeddings,
                         code_shape,
                         device=device)
    decode(model, code, plot_path)
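
A minimal sketch of using the two helpers above once a trained model is available; the wrapper name and the output file name are hypothetical.

def sample_and_save(model: VQVAE, device: torch.device) -> None:
    """Decode a grid of uniformly random codebook indices and save it."""
    model.to(device).eval()
    with torch.no_grad():
        uniform_sample(model, num_samples=16, device=device,
                       plot_path='uniform_samples.png')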
Code example #8
def main(_):
    """Train the model based on the command-line arguments."""
    # Parse command-line arguments
    speaker_list = txt2list(args.speaker_list)
    dirs = validate_log_dirs(args)
    arch = json2dict(args.arch)
    arch.update(dirs)
    arch.update({'ckpt': args.ckpt})
    copy_arch_file(args.arch, arch['logdir'])

    # Initialize the model
    net = VQVAE(arch)
    P = net.n_padding()
    print('Receptive field: {} samples ({:.2f} sec)'.format(P, P / arch['fs']))

    # Read the input data as specified by the command line arguments
    data = ByteWavReader(speaker_list,
                         args.file_pattern,
                         T=arch['T'],
                         batch_size=arch['training']['batch_size'],
                         buffer_size=5000)

    # Train the model on the input data
    net.train(data)
Code example #9
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if args.save:
    print('Results will be saved in ./results/vqvae_' + args.filename + '.pth')
"""
Load data and define batch data loaders
"""

training_data, validation_data, training_loader, validation_loader, x_train_var = utils.load_data_and_data_loaders(
    args.dataset, args.batch_size)
"""
Set up VQ-VAE model with components defined in ./models/ folder
"""

model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta).to(device)
"""
Set up optimizer and training loop
"""
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, amsgrad=True)

model.train()

results = {
    'n_updates': 0,
    'recon_errors': [],
    'loss_vals': [],
    'perplexities': [],
}
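
A hedged sketch of the update step that typically follows this setup. The unpacking of model(x) into (embedding_loss, x_hat, perplexity) mirrors the structure suggested by the results dict above, but it is an assumption here, not part of the excerpt.

def train_step(x):
    # One gradient update on a single batch x from training_loader.
    x = x.to(device)
    optimizer.zero_grad()
    embedding_loss, x_hat, perplexity = model(x)  # assumed forward signature
    recon_loss = torch.mean((x_hat - x) ** 2) / x_train_var
    loss = recon_loss + embedding_loss
    loss.backward()
    optimizer.step()

    results['recon_errors'].append(recon_loss.item())
    results['perplexities'].append(perplexity.item())
    results['loss_vals'].append(loss.item())
    results['n_updates'] += 1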

Code example #10
File: generate.py  Project: GPUPhobia/pitch-vqvae
def main(unused_args):
    if args.logdir is None:
        raise ValueError('Please specify the dir to the checkpoint')

    speaker_list = txt2list(args.speaker_list)

    arch = tf.gfile.Glob(os.path.join(args.logdir, 'arch*.json'))[0]
    arch = json2dict(arch)

    net = VQVAE(arch)

    # they start roughly at the same position but end very differently (3 is longest)
    filenames = [
        'dataset/VCTK/tfr/p227/p227_363.tfr',
        # 'dataset/VCTK/tfr/p240/p240_341.tfr',
        # 'dataset/VCTK/tfr/p243/p243_359.tfr',
        'dataset/VCTK/tfr/p225/p225_001.tfr'
    ]
    data = ByteWavWholeReader(speaker_list, filenames)

    X = tf.placeholder(dtype=tf.int64, shape=[None, None])
    Y = tf.placeholder(dtype=tf.int64, shape=[
        None,
    ])
    ZH = net.encode(X, args.mode)
    XH = net.generate(X, ZH, Y)
    # XWAV = mu_law_decode(X)
    # XBIN = tf.contrib.ffmpeg.encode_audio(XWAV, 'wav', arch['fs'])

    ema = tf.train.ExponentialMovingAverage(decay=0.995)
    trg_vars = {ema.average_name(v): v for v in tf.trainable_variables()}
    saver = tf.train.Saver(trg_vars)

    logdir = get_default_logdir(args.logdir)
    tf.gfile.MkDir(logdir)

    sess_config = tf.ConfigProto(allow_soft_placement=True,
                                 gpu_options=tf.GPUOptions(allow_growth=True))
    with tf.Session(config=sess_config) as sess:
        sess.run(tf.tables_initializer())
        sess.run(data.iterator.initializer)

        results = []
        for _ in filenames:
            result = sess.run({'x': data.x, 'y': data.y})
            results.append(result)
        # results1 = sess.run({'x': data.x, 'y': data.y})
        # results2 = sess.run({'x': data.x, 'y': data.y})

        length_input = net.n_padding() + 1  # same as padding + 1

        ini = 15149 - length_input
        end = 42285
        # x_source1 = results1['x'][:, ini: end]
        # x_source2 = results2['x'][:, ini: end]
        for i in range(len(results)):
            x = results[i]['x']
            if x.shape[-1] < end:
                x = np.concatenate(
                    [x, x[0, 0] + np.zeros([1, end - x.shape[-1]])], -1)
            results[i]['x'] = x[:, ini:end]

        # from pdb import set_trace
        # set_trace()
        x_source = np.concatenate([
            results[0]['x'], results[0]['x'], results[1]['x'], results[1]['x']
        ], 0)

        B = x_source.shape[0]

        y_input = np.concatenate([
            results[0]['y'], results[1]['y'], results[1]['y'], results[0]['y']
        ], 0)

        length_target = x_source.shape[1] - length_input

        while True:
            sess.run(tf.global_variables_initializer())
            load(saver, sess, args.logdir, ckpt=args.ckpt)

            z_blend = sess.run(ZH, feed_dict={X: x_source})
            x_input = x_source[:, :length_input]

            z_input = z_blend[:, :length_input, :]

            # Generate
            try:
                x_gen = np.zeros([B, length_target],
                                 dtype=np.int64)  # + results['x'][0, 0]
                for i in range(length_target):
                    xh = sess.run(XH,
                                  feed_dict={
                                      X: x_input,
                                      ZH: z_input,
                                      Y: y_input
                                  })
                    z_input = z_blend[:, i + 1:i + 1 + length_input, :]
                    x_input[:, :-1] = x_input[:, 1:]
                    x_input[:, -1] = xh[:, -1]
                    x_gen[:, i] = xh[:, -1]
                    print('\rGenerating {:5d}/{:5d}... x={:3d}'.format(
                        i + 1, length_target, xh[0, -1]),
                          end='',
                          flush=True)
            except KeyboardInterrupt:
                print("Interrupted by the user.")
            finally:
                print()
                x_wav = mu_law_decode(x_gen)
                for i in range(x_wav.shape[0]):
                    x_1ch = np.expand_dims(x_wav[i], -1)
                    # x_bin = sess.run(XBIN, feed_dict={X: x_1ch})

                    librosa.output.write_wav('testwav-{}.wav'.format(i), x_1ch,
                                             arch['fs'])
                    # with open(os.path.join(logdir, 'testwav-{}.wav'.format(i)), 'wb') as fp:
                    #  fp.write(x_bin)

            # For periodic gen.
            if args.period > 0:
                try:
                    print('Sleep for a while')
                    sleep(args.period * 60)
                    logdir = get_default_logdir(args.logdir)
                    tf.gfile.MkDir(logdir)
                except KeyboardInterrupt:
                    print('Stop periodic gen.')
                    break
                finally:
                    print('all finished')
            else:
                break
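
For reference, a self-contained sketch of 8-bit mu-law expansion, the operation that mu_law_decode performs on the generated codes above; the project's own implementation may differ in details such as the number of quantization levels.

import numpy as np

def mu_law_expand(codes, mu=255):
    """Map integer codes in [0, mu] back to waveform samples in [-1, 1]."""
    x = codes.astype(np.float64) / mu * 2.0 - 1.0  # rescale codes to [-1, 1]
    return np.sign(x) * ((1.0 + mu) ** np.abs(x) - 1.0) / mu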
Code example #11
File: run_brainqa.py  Project: yifr/brainqa
def main():
    args = model_args.get_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup CUDA, GPU
    device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning(
        "Device: %s, n_gpu: %s",
        device,
        args.n_gpu
    )

    # Set seed
    set_seed(args)

    # Set up model with huggingface pre-trained config
    args.model_type = args.model_type.lower()
    config = AutoConfig.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        cache_dir=args.cache_dir if args.cache_dir else None, early_stopping=True
    )
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    model = BrainQA(args=args, config=config)
    if args.train_vqvae_instead:
        model = VQVAE(h_dim=256,
                      res_h_dim=256,
                      n_res_layers=4,
                      n_embeddings=4096,
                      embedding_dim=256,
                      beta=2)

    model.to(args.device)
    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
        if args.train_vqvae_instead:
            global_step, tr_loss = train_vqvae(args, train_dataset, model, model, tokenizer)
        else:
            global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        if args.train_vqvae_instead:
            torch.save(model.state_dict(), os.path.join(args.output_dir, 'vqvae_model.bin'))
        else:
            model_to_save = model.module if hasattr(model, "module") else model
            model_to_save.save_pretrained(args.output_dir)

        tokenizer.save_pretrained(args.output_dir)
        # Save training args as well
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval:
        logger.info("Loading checkpoints saved during training for evaluation")
    
        checkpoints = model_args.get_checkpoints(args)
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            logger.info('Evaluating checkpoint: {}'.format(checkpoint))
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            
            state_dict = torch.load(checkpoint + '/pytorch_model.bin')
            model.load_state_dict(state_dict)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ("_{}".format(global_step) if global_step else ""), v) for k, v in result.items())
            results.update(result)
        logger.info("Results: {}".format(results))

    if args.run_visualizations:
        model = BrainQA(args=args, config=config)
        checkpoints = model_args.get_checkpoints(args)
        checkpoint = checkpoints[-1]
        path_to_dict = os.path.join(args.output_dir, checkpoint, 'pytorch_model.bin')
        if not os.path.exists(path_to_dict):
            raise FileNotFoundError(path_to_dict + ' not found. Please make sure you have passed the correct output directory, \
                and that it contains a fully trained model checkpoint. Visualizations are not currently in place for VQ-VAE alone.')
            
        state_dict = torch.load(path_to_dict)
        model.load_state_dict(state_dict)
        
        model.to(args.device)

        eval_dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)
        visualizations.visualize(model, eval_dataset, tokenizer, args, latent_vis=True)
        
        logger.info('Visualizations constructed. Please check /images directory.')

    return results
Code example #12
File: data_to_latent.py  Project: richardrl/cvqvae
parser.add_argument("--n_residual_layers", type=int, default=2)
parser.add_argument("--embedding_dim", type=int, default=64)
parser.add_argument("--n_embeddings", type=int, default=512)
parser.add_argument("--beta", type=float, default=.25)
parser.add_argument("--loadpth",
                    type=str,
                    default='./results/vqvae_data_bo.pth')
parser.add_argument("--data_dir",
                    type=str,
                    default='/home/karam/Downloads/bco/')
parser.add_argument("--data", type=str, default='bco')
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Load model
model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta).to(device)
assert args.loadpth != ''
model.load_state_dict(torch.load(args.loadpth)['model'])
model.eval()
print("Loaded model")

#Load data
save_dir = os.getcwd() + '/data'
data_dir = args.data_dir
if args.data == 'bco':
    data1 = np.load(data_dir + "/bcov5_0.npy")
    data2 = np.load(data_dir + "/bcov5_1.npy")
    data3 = np.load(data_dir + "/bcov5_2.npy")
    data4 = np.load(data_dir + "/bcov5_3.npy")
    data = np.concatenate((data1, data2, data3, data4), axis=0)
elif args.data == 'bo':
Code example #13
def main(args):
    writer = SummaryWriter(args.experiment_log_path)
    writer.add_hparams(vars(args), {})

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose(
        [transforms.Resize((32, 32), 3),
         transforms.ToTensor()])

    if args.dataset == 'cifar10':
        train_dataset = datasets.CIFAR10('data',
                                         train=True,
                                         download=True,
                                         transform=transform)
        test_dataset = datasets.CIFAR10('data',
                                        train=False,
                                        download=True,
                                        transform=transform)
        args.in_channels = 3
    elif args.dataset == 'mnist':
        train_dataset = datasets.MNIST('data',
                                       train=True,
                                       download=True,
                                       transform=transform)
        test_dataset = datasets.MNIST('data',
                                      train=False,
                                      download=True,
                                      transform=transform)
        args.in_channels = 1
    else:
        raise ValueError(f"Invalid dataset: {args.dataset}")

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=4)
    test_dataloader = DataLoader(test_dataset,
                                 args.batch_size // 4,
                                 pin_memory=True,
                                 num_workers=4)

    vqvae = VQVAE(args.in_channels, args.hidden_channels_vqvae,
                  args.num_embeddings, args.embedding_dim)
    vqvae.load_state_dict(
        torch.load(args.vqvae_state_dict, map_location=torch.device('cpu')))
    vqvae = vqvae.to(device)

    prior = PixelCNN(args.num_embeddings, args.hidden_channels_prior, args.num_layers, args.num_classes) \
        .to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(prior.parameters(), args.lr)

    # Initialize Loggers
    train_metric_logger = MeterLogger(("nll", ), writer)
    val_metric_logger = MeterLogger(("nll", ), writer)

    print(vqvae)

    for epoch in tqdm(range(args.num_epoch)):

        train_metric_logger.reset()
        prior.train()
        for train_batch in tqdm(train_dataloader):
            images, labels = train_batch
            images = images.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                # TODO repack into one call
                latents = vqvae.encoder(images)
                latents = vqvae.prenet(latents)
                latents = vqvae.vector_quantizer.proposal_distribution(latents)
                latents = latents.unsqueeze(1)

            logits = prior(latents, labels)
            loss = criterion(logits, latents.squeeze())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_metric_logger.update('nll', loss.item(),
                                       train_dataloader.batch_size)

        # Save train metrics
        train_metric_logger.write(epoch, 'train')

        val_metric_logger.reset()
        prior.eval()
        for test_batch in tqdm(test_dataloader):
            images, labels = test_batch
            images = images.to(device)
            labels = labels.to(device)

            with torch.no_grad():
                latents = vqvae.encoder(images)
                latents = vqvae.prenet(latents)
                latents = vqvae.vector_quantizer.proposal_distribution(latents)
                latents = latents.unsqueeze(1)

                logits = prior(latents, labels)
                loss = criterion(logits, latents.squeeze())

            val_metric_logger.update('nll', loss.item(),
                                     test_dataloader.batch_size)

        # Save val metrics
        val_metric_logger.write(epoch, 'val')

        # Generate
        resolution = 8 if args.dataset == 'cifar10' else 7
        condition = torch.arange(8).repeat(8)
        generated_prior = prior.generate(condition.to(device), resolution) \
            .squeeze()

        quantized_prior = vqvae.vector_quantizer.embeddings(generated_prior) \
            .permute(0, 3, 1, 2)
        generated = vqvae.decoder(vqvae.postnet(quantized_prior))

        writer.add_images('generated', generated, epoch)

        # Save checkpoint
        checkpoint_path = pathlib.Path(experiment_model_path) / f"{epoch}.pth"
        torch.save(prior.state_dict(), checkpoint_path)
Code example #14
File: gated_pixelcnn.py  Project: richardrl/cvqvae
data, val = data.reshape(-1, 256), val.reshape(-1, 256)
context = np.load("./data/%s_clatents.npy" % args.data).squeeze()
context, valcon = context[split:], context[:split]
context, valcon = context.reshape(-1, 256), valcon.reshape(-1, 256)
n_trajs, length = data.shape[:2]
img_dim = args.img_dim

model = GatedPixelCNN(n_embeddings=args.n_embeddings, imgximg=args.img_dim**2,
                      n_layers=args.n_layers, conditional=args.conditional,
                      x_one_hot=args.x_one_hot, c_one_hot=args.c_one_hot,
                      n_cond_res_block=args.n_cres_layers).to(device)
model.train()
criterion = nn.CrossEntropyLoss().cuda()
opt = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

if args.loadpth_vq != '':
    vae = VQVAE(args.n_hiddens, args.n_residual_hiddens,
                args.n_residual_layers, args.n_embeddings,
                args.embedding_dim, args.beta).cuda()
    vae.load_state_dict(torch.load(args.loadpth_vq)['model'])
    print("VQ Loaded")
    vae.eval()
    if args.data == 'bco':
        sample_c = vae(sample_c_imgs, latent_only=True).detach().cpu().numpy().reshape(-1, length).squeeze()

# 
if args.loadpth_pcnn != '':
    model.load_state_dict(torch.load(args.loadpth_pcnn))
    print("PCNN Loaded")

n_trajs = len(data)
dt = n_trajs // context.shape[0]
n_batch = int(n_trajs / args.batch_size)
n_trajs_t = len(val)
Code example #15
import operator
import util.torchaudio_transforms as transforms
from experiment_builders.vqvae_builder import VQVAEWORLDExperimentBuilder, VQVAERawExperimentBuilder
from models.vqvae import VQVAE
from models.common_networks import QuantisedInputModuleWrapper
from datasets.vcc_world_dataset import VCCWORLDDataset
from datasets.vcc_raw_dataset import VCCRawDataset
from datasets.vctk_dataset import VCTKDataset
from util.samplers import ChunkEfficientRandomSampler

torch.manual_seed(seed=args.seed)

vqvae_model = VQVAE(input_shape=(1, 1, args.input_len),
                    encoder_arch=args.encoder,
                    vq_arch=args.vq,
                    generator_arch=args.generator,
                    num_speakers=args.num_speakers,
                    speaker_dim=args.speaker_dim,
                    use_gated_convolutions=args.use_gated_convolutions)

if args.dataset == 'VCCWORLD2016':
    print('VCC2016 dataset WORLD features.')

    dataset_path = args.dataset_root_path
    train_dataset = VCCWORLDDataset(root=dataset_path, scale=True)
    val_dataset = VCCWORLDDataset(root=dataset_path, scale=True, eval=True)

    # Create data loaders
    train_data = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
Code example #16
File: main_encoder.py  Project: lior1990/vqvae
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
])

dataset = ImageDataset(args.dataset, transform)
training_loader = DataLoader(dataset,
                             batch_size=args.batch_size,
                             drop_last=False)
"""
Set up VQ-VAE model with components defined in ./models/ folder
"""

model = VQVAE(args.n_hiddens, args.n_residual_hiddens, args.n_residual_layers,
              args.n_embeddings, args.embedding_dim, args.beta,
              args.n_dimension_changes).to(device)
checkpoint = torch.load(os.path.join(utils.SAVE_MODEL_PATH, args.model_path),
                        map_location=device)
model.load_state_dict(checkpoint["model"])
model.to(device)
model.eval()

encoder = E2EEncoder(args.n_hiddens, args.n_residual_hiddens,
                     args.n_residual_layers, args.embedding_dim,
                     args.n_dimension_changes)
encoder.to(device)
encoder.train()
"""
Set up optimizer and training loop
"""