コード例 #1
0
def main(args):
    """Train a VQ-VAE on the images listed in a text index file.

    Wraps the model in ``DistributedDataParallel`` when more than one process
    is running and, from the primary process only, writes a checkpoint after
    every epoch to ``checkpoint/vqvae_XXX.pt``.

    Args:
        args: parsed command-line namespace. Reads ``size``, ``n_gpu``, ``lr``,
            ``sched`` and ``epoch``; sets ``args.distributed``.
    """
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]

    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    # NOTE(review): 'datd/' looks like a typo for 'data/' — confirm the path.
    txt_path = 'datd/train.txt'
    images_path = '/data'
    labels_path = '/data'

    dataset = txtDataset(txt_path,
                         images_path,
                         labels_path,
                         transform=transform)

    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    # NOTE(review): `batch_size` is not defined in this function; presumably a
    # module-level constant — confirm it exists.
    loader = DataLoader(dataset,
                        batch_size=batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=16)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            # Save the unwrapped module: a DDP wrapper would prefix every key
            # with "module.", so the file would not load into a plain VQVAE().
            to_save = model.module if args.distributed else model
            torch.save(to_save.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
コード例 #2
0
def load_model_from_file(path):
    """Rebuild a VQ-VAE from a run directory and load its best weights.

    Args:
        path: directory containing ``args.json`` (constructor arguments, wrapped
            in ``dotdict``) and ``best_model.pth`` (a state dict).

    Returns:
        The ``VQVAE`` instance with weights loaded. No device move is done here.
    """
    with open(os.path.join(path, 'args.json'), 'rb') as f:
        args = dotdict(json.load(f))

    from vqvae import VQVAE
    # create model
    model = VQVAE(args)

    # load weights; map_location='cpu' lets a checkpoint saved on a GPU load on
    # a CPU-only host. load_state_dict copies the values into the model's own
    # parameters, so the model's device is unaffected.
    state_dict = torch.load(os.path.join(path, 'best_model.pth'),
                            map_location='cpu')
    model.load_state_dict(state_dict)

    return model
コード例 #3
0
def main(args):
    """Train a VQ-VAE on a ``CUBDataset`` split, optionally with DDP.

    From the primary process only, writes a checkpoint after every epoch to
    ``checkpoint/vqvae_XXX.pt``.

    Args:
        args: parsed command-line namespace. Reads ``size``, ``path``, ``mode``,
            ``n_gpu``, ``lr``, ``sched`` and ``epoch``; sets ``args.distributed``.
    """
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = CUBDataset(args.path, transform=transform, mode=args.mode)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=128 // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    print(args)
    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            # Save the unwrapped module: a DDP wrapper would prefix every key
            # with "module.", so the file would not load into a plain VQVAE().
            to_save = model.module if args.distributed else model
            torch.save(to_save.state_dict(),
                       f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
コード例 #4
0
ファイル: sample.py プロジェクト: kassellevi1/DeepProject
def load_model(model, checkpoint, device):
    """Build a model by name and load weights from ``checkpoint/<checkpoint>``.

    Args:
        model: architecture name, ``'vqvae'`` or ``'pixelsnail_bottom'``.
        checkpoint: file name inside the ``checkpoint`` directory. The file is
            either a bare state dict or a dict with ``'model'`` (state dict)
            and optionally ``'args'`` (hyper-parameters for PixelSNAIL).
        device: device the model is moved to; also used as ``map_location``
            so GPU-saved checkpoints load on CPU-only hosts.

    Returns:
        The model with weights loaded, on ``device`` and in eval mode.

    Raises:
        ValueError: if ``model`` is not a known name, or if a PixelSNAIL model
            is requested from a checkpoint that carries no ``'args'`` entry.
    """
    ckpt = torch.load(os.path.join('checkpoint', checkpoint),
                      map_location=device)

    args = ckpt['args'] if 'args' in ckpt else None

    if model == 'vqvae':
        net = VQVAE()

    elif model == 'pixelsnail_bottom':
        # PixelSNAIL's hyper-parameters come only from the checkpoint.
        if args is None:
            raise ValueError(
                f"checkpoint {checkpoint!r} has no 'args' entry, which is "
                "required to build 'pixelsnail_bottom'")
        net = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=False,
            dropout=args.dropout,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.n_res_channel,
        )

    else:
        raise ValueError(f"unknown model {model!r}; expected 'vqvae' or "
                         "'pixelsnail_bottom'")

    if 'model' in ckpt:
        ckpt = ckpt['model']

    net.load_state_dict(ckpt)
    net = net.to(device)
    net.eval()

    return net
コード例 #5
0
def main(args):
    """Run ``interpolate`` over an ImageFolder dataset with a VQ-VAE on the CPU.

    Args:
        args: parsed command-line namespace. Reads ``size``, ``path``, ``n_gpu``
            and ``load_path`` (optional state-dict file); sets
            ``args.distributed``.
    """
    device = "cpu"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        # DistributedDataParallel requires device_ids/output_device to be None
        # for CPU modules; only pin to the local GPU when running on CUDA.
        on_cuda = torch.device(device).type == "cuda"
        local_rank = dist.get_local_rank()
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank] if on_cuda else None,
            output_device=local_rank if on_cuda else None,
        )

    # The optimizer/scheduler previously built here were never used:
    # interpolation only runs the model forward.
    interpolate(loader, model, device)
コード例 #6
0
def main():
    """Encode every image under ``path`` with a trained VQ-VAE into an LMDB database.

    Command-line arguments: ``--size`` (resize/crop size), ``--model_path``
    (state-dict file), ``--name`` (LMDB output path) and positional ``path``
    (image folder read by ``ImageFileDataset``). The work itself is done by
    ``extract``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--name', type=str)
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = ImageFileDataset(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    model = VQVAE()
    # map_location matches the CPU fallback above: without it a GPU-saved
    # checkpoint cannot be deserialized on a CPU-only host.
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    model = model.to(device)
    model.eval()

    map_size = 100 * 1024 * 1024 * 1024  # 100 GiB maximum database size
    env = lmdb.open(args.name, map_size=map_size)
    try:
        extract(env, loader, model, device)
    finally:
        # Close the environment even if extraction fails part-way.
        env.close()
コード例 #7
0
def construct_model():
    """Build and compile a VQ-VAE Keras model for 28x28x1 inputs.

    The decoder input is ``enc + stop_gradient(quant - enc)``. In the forward
    pass this equals the quantised tensor. In the backward pass gradients flow
    straight into the encoder output (straight-through estimator).

    Returns:
        A compiled ``tf.keras.models.Model`` mapping the input image to the
        decoder output.
    """
    x_input = tf.keras.layers.Input((28, 28, 1))
    enc_x = EncoderLayer()(x_input)
    quant_x = VQVAE()(enc_x)
    # Pass both tensors in as the Lambda's inputs. A Lambda that closes over a
    # tensor it was not given is not tracked by the functional graph.
    x_dec = tf.keras.layers.Lambda(
        lambda t: t[0] + tf.stop_gradient(t[1] - t[0]))([enc_x, quant_x])
    dec_x = DecoderLayer()(x_dec)
    model = tf.keras.models.Model(x_input, dec_x)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=vqvae_loss(0.25, enc_x, quant_x),
                  experimental_run_tf_function=False)
    return model
コード例 #8
0
def encode_proc(model_path, model_config_path, img_root_path,
                img_key_path_list, img_size, device, output_path):
    """Encode the images listed in key files and write their code ids as TSV.

    Each line of every file in ``img_key_path_list`` is an image key. The key
    is resolved with ``get_key_path``, loaded, transformed and passed through
    the model. The third model output (``id_t``) is written as
    ``"<key>\\t<id>,<id>,..."`` to ``output_path``. Images that fail to load
    are skipped.

    Args:
        model_path: state-dict file for ``VQVAE``.
        model_config_path: JSON file parsed by ``VqvaeConfig.from_json``.
        img_root_path: root passed to ``get_key_path``.
        img_key_path_list: files containing one image key per line.
        img_size: size passed to ``build_transform``.
        device: ``torch.device`` to run on (made current when CUDA).
        output_path: TSV file to create/overwrite.
    """
    with open(model_config_path) as config_fp:
        model_config_json = config_fp.read()
    print("ModelConfig:", model_config_json, file=sys.stderr, flush=True)
    model_config = VqvaeConfig.from_json(model_config_json)
    model = VQVAE(model_config).to(device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    trans = build_transform(img_size)

    linecnt = 0
    # no_grad: inference only, so don't build an autograd graph per image.
    with open(output_path, "w") as output_fp, torch.no_grad():
        for f in img_key_path_list:
            with open(f) as key_fp:
                for line in key_fp:
                    linecnt += 1
                    if linecnt % 100000 == 0:
                        print("{} {} done".format(f, linecnt),
                              file=sys.stderr,
                              flush=True)
                    img_key = line.strip()
                    img_path = get_key_path(img_root_path, img_key)
                    try:
                        img = default_loader(img_path)
                    except Exception:
                        # Best-effort: skip images that cannot be read, but let
                        # KeyboardInterrupt/SystemExit propagate.
                        continue
                    img = trans(img)[None].to(device)
                    id_t = model(img)[2].cpu().flatten(1)
                    print("{}\t{}".format(img_key, ",".join(
                        str(x) for x in id_t[0].tolist())),
                          file=output_fp,
                          flush=True)
コード例 #9
0
def train_vqvae(hparams_path):
    """Train a VQ-VAE with PyTorch Lightning using hyper-parameters from a file.

    Args:
        hparams_path: path passed to ``load_hparams``. The resulting object
            must provide ``folder``, ``epochs`` and ``gpus``, and is also
            passed to ``VQVAE``.

    ``hparams.folder`` is created if missing. It is used as the Trainer's root
    directory, and TensorBoard logs go to ``<folder>/logs``.
    """
    hparams = load_hparams(hparams_path)
    os.makedirs(hparams.folder, exist_ok=True)
    model = VQVAE(hparams)
    logger = pl.loggers.TensorBoardLogger(save_dir=hparams.folder, name="logs")
    trainer = pl.Trainer(
        # `default_root` is not a Trainer argument; the output-directory
        # parameter is `default_root_dir`.
        default_root_dir=hparams.folder,
        max_epochs=hparams.epochs,
        show_progress_bar=False,
        gpus=hparams.gpus,
        logger=logger,
    )
    trainer.fit(model)
コード例 #10
0
ファイル: train.py プロジェクト: Ryan-Rudes/Experiment
    def __init__(self, in_channels, hidden_channels, res_channels,
                 nb_res_layers, nb_levels, embed_dim, nb_entries,
                 scaling_rates, lr, beta, batch_size, mini_batch_size, no_amp,
                 random_resets, device):
        """Create the VQ-VAE, its Adam optimiser and the AMP gradient scaler.

        ``update_frequency`` is the number of mini-batches that make up one
        effective batch: ``ceil(batch_size / mini_batch_size)``. AMP scaling
        is enabled unless ``no_amp`` is true.
        """
        self.device = device

        # Map this class's argument names onto the VQVAE constructor's names.
        vqvae_kwargs = dict(
            in_channel=in_channels,
            channel=hidden_channels,
            n_res_channel=res_channels,
            n_res_block=nb_res_layers,
            nb_levels=nb_levels,
            embed_dim=embed_dim,
            n_embed=nb_entries,
            scaling_rates=scaling_rates,
            random_resets=random_resets,
        )
        self.model = VQVAE(**vqvae_kwargs).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.optimizer.zero_grad()

        self.beta = beta
        amp_enabled = not no_amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

        self.update_frequency = math.ceil(batch_size / mini_batch_size)
        self.steps = 0
コード例 #11
0
def main(args):
    """Pre-train a VQ-VAE on the first ``args.pretrain_classes`` CIFAR-100 classes.

    Creates ``<root>/<results_dir>``, which holds ``args.json``,
    ``checkpoints/``, ``output/`` and ``log.txt``. It then trains for
    ``args.epoch`` epochs. Whenever ``train_AE`` returns a new lowest value,
    the weights are saved to ``checkpoints/VQVAE2_cifar_best.pt``.

    Args:
        args: ``argparse.Namespace`` holding the run configuration.
    """
    root = args.root
    results_dir = args.results_dir
    save_path = os.path.join(root, results_dir)
    print('root is', root)
    print('save_path is:', save_path)
    os.makedirs(save_path, exist_ok=True)
    json_file_name = os.path.join(save_path, 'args.json')
    with open(json_file_name, 'w') as fp:
        # vars() is the public way to read a Namespace's fields
        # (_get_kwargs() is private API).
        json.dump(vars(args), fp, sort_keys=True, indent=4)
    checkpoints_path = os.path.join(save_path, 'checkpoints')
    os.makedirs(checkpoints_path, exist_ok=True)
    sample_output_path = os.path.join(save_path, 'output')
    os.makedirs(sample_output_path, exist_ok=True)
    log_file = os.path.join(save_path, 'log.txt')
    config_logging(log_file)

    logging.info('====>  args{} '.format(args))
    num_workers = args.num_workers

    device = "cuda"
    batch_size = args.batch_size
    dataset_path = args.dataset_path

    train_ds = CIFAR100(root=dataset_path, train=True, download=True,
                        transform=transform_train_cifar)

    classes = list(range(args.pretrain_classes))

    print('pretrain vqvae using ', classes)
    training_idx = get_indices(train_ds, classes, is_training=True)

    loader = DataLoader(train_ds, batch_size=batch_size,
                        sampler=SubsetRandomSampler(training_idx),
                        num_workers=num_workers, drop_last=False)

    model = VQVAE(embed_dim=args.dim_emb, n_embed=args.n_emb).to(device)
    if args.checkpoint is not None:
        # map_location lets checkpoints saved on another device/GPU index load.
        model.load_state_dict(torch.load(args.checkpoint, map_location=device))
    opt = optim.Adam(model.parameters(), lr=args.lr)

    # Start from +inf so the first epoch always saves a checkpoint, however
    # large its MSE is.
    best_mse = float('inf')
    best_pt_path = os.path.join(checkpoints_path, "VQVAE2_cifar_best.pt")
    for i in range(args.epoch):
        tmp_mse = train_AE(i, loader, model, opt, device, save_path)
        if tmp_mse < best_mse:
            best_mse = tmp_mse
            logging.info('====>  Epoch{}: best_mse {} '.format(i, best_mse))
            torch.save(model.state_dict(), best_pt_path)
コード例 #12
0
    def __init__(self,
                 vqvae,
                 in_channel=6,
                 channel=128,
                 n_res_block=2,
                 n_res_channel=32,
                 embed_dim=64,
                 n_embed=512,
                 decay=0.99):
        """Build the offset network around a frozen VQ-VAE.

        Args:
            vqvae: pre-trained VQ-VAE to wrap. When ``None``, a fresh
                ``VQVAE()`` is created instead.
            in_channel, channel, n_res_block, n_res_channel, embed_dim,
            n_embed, decay: forwarded unchanged, as keywords, to the parent
                class constructor.
        """
        super().__init__(
            in_channel=in_channel,
            channel=channel,
            n_res_block=n_res_block,
            n_res_channel=n_res_channel,
            embed_dim=embed_dim,
            n_embed=n_embed,
            decay=decay,
        )

        if vqvae is None:
            vqvae = VQVAE()
        self.vqvae = vqvae

        # Freeze the wrapped VQ-VAE: none of its parameters require gradients.
        for frozen_param in self.vqvae.parameters():
            frozen_param.requires_grad = False
コード例 #13
0
def encode_fn():
    """Create a VQ-VAE on ``device`` and a closure that encodes one image.

    Returns:
        ``(model, encode)``. ``encode(x)`` resizes ``x`` to 160x160 with
        ``cv2.INTER_AREA``, applies the module-level ``transform``, adds a
        batch dimension, and runs ``model.encode`` in eval mode without
        gradients. It returns ``id_t`` (the 4th output) as a NumPy array. The
        model's train/eval mode is restored to what it was before the call.
    """
    model = VQVAE().to(device)

    def encode(x):
        # Remember the caller's mode instead of unconditionally forcing
        # train mode afterwards.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                x = cv2.resize(x, (160, 160), interpolation=cv2.INTER_AREA)
                x = transform(x)
                x = x.unsqueeze(0)
                x = x.to(device)

                _, _, _, id_t, id_b = model.encode(x)
        finally:
            model.train(was_training)

        return id_t.cpu().numpy()

    return model, encode
コード例 #14
0
    # Encode the images under `path` with a trained VQ-VAE and write the
    # results to the LMDB database `--name` via `extract`.
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--ckpt', type=str)
    parser.add_argument('--name', type=str)
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    device = 'cuda'

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = ImageFileDataset(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

    model = VQVAE()
    # map_location places the weights on `device` regardless of which device
    # the checkpoint was saved from.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model = model.to(device)
    model.eval()

    map_size = 100 * 1024 * 1024 * 1024  # 100 GiB maximum database size

    env = lmdb.open(args.name, map_size=map_size)
    try:
        extract(env, loader, model, device)
    finally:
        # Close the environment even if extraction fails part-way.
        env.close()
コード例 #15
0
def main(args):
    """Train an ``OffsetNetwork`` on top of a frozen, pre-trained VQ-VAE.

    VQ-VAE weights are read from ``args.ckpt``. Checkpoints saved from a
    (Distributed)DataParallel wrapper, whose keys start with ``module.``, are
    accepted too. The network is wrapped in DDP when several processes run.
    From the primary process, ``checkpoint/offset_enc_XXX.pt`` is written
    after every epoch.

    Args:
        args: parsed command-line namespace. Reads ``size``, ``path``,
            ``offset``, ``bsize``, ``n_gpu``, ``ckpt``, ``lr``, ``sched`` and
            ``epoch``; sets ``args.distributed``.
    """
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset,
                                shuffle=True,
                                distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.bsize // args.n_gpu,
                        sampler=sampler,
                        num_workers=2)

    # Load pre-trained VQVAE; read the file once and reuse it for the fallback.
    vqvae = VQVAE().to(device)
    weights = torch.load(args.ckpt, map_location=device)
    try:
        vqvae.load_state_dict(weights)
    except RuntimeError:
        # load_state_dict raises RuntimeError on missing/unexpected keys.
        # Other errors (bad path, corrupt file) are no longer hidden.
        print(
            "Seems the checkpoint was trained with data parallel, try loading it that way"
        )
        prefix = 'module.'
        # Strip only the leading wrapper prefix; str.replace would also alter
        # keys that merely contain "module." further in.
        weights = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in weights.items()
        }
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            # Save the unwrapped module so keys carry no DDP "module." prefix.
            to_save = model.module if args.distributed else model
            torch.save(to_save.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
コード例 #16
0
def encode(args):
    """Encode every image of an ``ImageLmdbDataset`` into int16 code ids in a new LMDB.

    For each batch, the model's 4th output is flattened per image. It is stored
    under ``key.encode("utf-8")`` as the raw bytes of an int16 array. Average
    per-batch timings are printed to stderr every 100 batches.

    Args:
        args: namespace with ``config_path``, ``device``, ``model_path``,
            ``img_size``, ``img_root_path``, ``img_key_path``, ``batch_size``,
            ``num_workers`` and ``output_path``.
    """
    with open(args.config_path) as config_fp:
        model_config_json = config_fp.read()
    print("ModelConfig:", model_config_json, file=sys.stderr, flush=True)
    model_config = VqvaeConfig.from_json(model_config_json)
    device = torch.device(args.device)
    n_gpu = torch.cuda.device_count() if args.device == "cuda" else 0
    model = VQVAE(model_config).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    model.eval()
    trans = build_transform(args.img_size)
    dataset = ImageLmdbDataset(args.img_root_path,
                               args.img_key_path,
                               trans,
                               args.batch_size,
                               with_key=True)
    dataloader = IterDataLoader(dataset,
                                batch_size=args.batch_size,
                                num_workers=args.num_workers,
                                collate_fn=collate_fn,
                                pin_memory=True)
    lmdb_env = lmdb.open(args.output_path, map_size=int(1e12))
    try:
        # One write transaction for the whole run: commits when the block
        # exits normally, aborts if anything raises. no_grad because this is
        # inference only.
        with lmdb_env.begin(write=True) as lmdb_txn, torch.no_grad():
            batch_cnt = 0
            input_cost = 0.0
            to_cost = 0.0
            eval_cost = 0.0
            trans_cost = 0.0
            t_point = time.time()
            start_point = t_point
            for key_list, img_batch in dataloader:
                t_point_1 = time.time()
                input_cost += t_point_1 - t_point
                t_point = t_point_1
                # Tensor.to() is not in-place; the result must be rebound or
                # the model receives a CPU tensor.
                img_batch = img_batch.to(device)
                t_point_1 = time.time()
                to_cost += t_point_1 - t_point
                t_point = t_point_1
                id_batch = model(img_batch)[3].cpu().flatten(1)
                t_point_1 = time.time()
                eval_cost += t_point_1 - t_point
                t_point = t_point_1
                for key, id_t in zip(key_list, id_batch):
                    lmdb_txn.put(key.encode("utf-8"),
                                 id_t.to(torch.int16).numpy().tobytes())
                t_point_1 = time.time()
                trans_cost += t_point_1 - t_point
                batch_cnt += 1
                if batch_cnt % 100 == 0:
                    print(
                        "{} batch done, input_c={:.4f}, to_c={:.4f}, eval_c={:.4f}, trans_c={:.4f}, total_c={:.4f}"
                        .format(batch_cnt, input_cost / batch_cnt, to_cost / batch_cnt,
                                eval_cost / batch_cnt, trans_cost / batch_cnt,
                                (time.time() - start_point) / batch_cnt),
                        file=sys.stderr,
                        flush=True)
                t_point = time.time()
    finally:
        lmdb_env.close()
コード例 #17
0
    args = parser.parse_args()

    print(args)

    device = 'cuda'

    # Resize + center-crop to a square, then map pixel values to [-1, 1].
    pipeline = [
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    transform = transforms.Compose(pipeline)

    dataset = datasets.ImageFolder(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    model = VQVAE().to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = (
        CycleScheduler(optimizer,
                       args.lr,
                       n_iter=len(loader) * args.epoch,
                       momentum=None)
        if args.sched == 'cycle'
        else None
    )

    # Train one epoch at a time and checkpoint after each one.
    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        epoch_tag = str(i + 1).zfill(3)
        torch.save(model.state_dict(), f'checkpoint/vqvae_{epoch_tag}.pt')
コード例 #18
0
        device = device_list[device_i]
        input_list = img_key_path_list[device_i * proc_pic: (device_i + 1) * proc_pic]
        output_path = os.path.join(args.output_path, "part-{}".format(device_i))
        proc_list.append(Process(
            target=encode_proc,
            args=(args.model_path, args.config_path, args.img_root_path,
                  input_list, args.img_size, device, output_path)
        ))

    for proc in proc_list:
        proc.start()
    for proc in proc_list:
        proc.join()
    '''
    encode(args)


if __name__ == "__main__":
    main()
    exit(0)
    config_path = os.path.join(sys.argv[1], "config.json")
    model_path = os.path.join(sys.argv[1], "pytorch_model.bin")
    model = VQVAE(VqvaeConfig.from_json(open(config_path).read())).to("cpu")
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    img_path = get_key_path("/mnt2/makai/imgs", sys.argv[2])
    trans = build_transform(224)
    img = default_loader(img_path)
    img = trans(img)[None]
    id_t, id_b = model(img)[2:4]
    print(",".join((str(x.item()) for x in id_b.flatten(1)[0])))
コード例 #19
0
ファイル: train_vqvae.py プロジェクト: August-us/vq-vae-2
    print(args)

    device = 'cuda'

    # Data comes from the project helper rather than an ImageFolder pipeline.
    loader = get_dataset(args.path, batch_size=args.batchsize)
    model = nn.DataParallel(VQVAE(embed_dim=32)).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.sched == 'cycle':
        total_iters = len(loader) * args.epoch
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=total_iters,
                                   momentum=None)
    else:
        scheduler = None

    # Save the inner module's weights (without the DataParallel wrapper)
    # after every epoch.
    ckpt_template = 'allCheckpoint/checkpoint32/vqvae_{}.pt'
    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(model.module.state_dict(),
                   ckpt_template.format(str(i + 1).zfill(3)))
    writer.close()
コード例 #20
0
def main():
    """Train a VQ-VAE with per-epoch testing, checkpointing and Visdom plots.

    The positional ``path`` must contain ``train`` and ``test`` sub-folders
    readable by ``datasets.ImageFolder``. After every epoch this function:
    saves ``checkpoints/vqvae_chkpt_XXX.pt``; overwrites ``weights/vqvae.pt``
    when the test loss improves; updates the Visdom loss plot; and calls
    ``recon_sample`` on the test loader.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', help='Image size', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--sched', type=str, default='cycle')
    parser.add_argument('--vishost', type=str, default='localhost')
    parser.add_argument('--visport', type=int, default=8097)
    parser.add_argument('path',
                        help="root path with train and test folder in it",
                        type=str)

    args = parser.parse_args()

    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ])

    train_path = os.path.join(args.path, "train")
    test_path = os.path.join(args.path, "test")

    train_dataset = datasets.ImageFolder(train_path, transform=transform)
    train_loader = DataLoader(train_dataset,
                              batch_size=args.bs,
                              shuffle=True,
                              num_workers=4)

    test_dataset = datasets.ImageFolder(test_path, transform=transform)
    test_loader = DataLoader(test_dataset,
                             batch_size=args.bs,
                             shuffle=False,
                             num_workers=4)

    model = VQVAE().to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=len(train_loader) * args.epoch,
                                   momentum=None)
    else:
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [50, 70], 0.1)

    # torch.save does not create parent directories; make sure they exist so
    # the first checkpoint write cannot fail after a full epoch of training.
    os.makedirs('checkpoints', exist_ok=True)
    os.makedirs('weights', exist_ok=True)

    train_losses = []
    test_losses = []
    vis = visdom.Visdom(server=args.vishost, port=args.visport)
    win = None
    best_model_loss = np.inf
    for i in range(args.epoch):
        # Training stage
        print(f"Training epoch {i + 1}")
        train_loss = train(i, train_loader, model, optimizer, scheduler,
                           device)
        print(f"Train Loss: {train_loss:.5f}")

        # Testing stage (reconstruction / commitment components are unused)
        print(f"Testing epoch {i + 1}")
        test_loss, _, _ = test(i, test_loader, model, device)
        print(f"Test Loss: {test_loss:.5f}")
        torch.save(model.state_dict(),
                   f'checkpoints/vqvae_chkpt_{str(i + 1).zfill(3)}.pt')

        if test_loss < best_model_loss:
            print("Saving model")
            torch.save(model.state_dict(), 'weights/vqvae.pt')
            best_model_loss = test_loss

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        win = plot(train_losses, test_losses, vis, win)

        # Sampling stage
        recon_sample(i, model, test_loader, device)
コード例 #21
0
    valid_data_loader = AtariDataset(
                                   valid_data_file,
                                   number_condition=4,
                                   steps_ahead=1,
                                   batch_size=largs.batch_size,
                                   norm_by=255.0,)

    args.size_training_set = valid_data_loader.num_examples
    hsize = valid_data_loader.data_h
    wsize = valid_data_loader.data_w

    if args.reward_int:
        int_reward = info['num_rewards']
        vqvae_model = VQVAE(num_clusters=largs.num_k,
                            encoder_output_size=largs.num_z,
                            num_output_mixtures=info['num_output_mixtures'],
                            in_channels_size=largs.number_condition,
                            n_actions=info['num_actions'],
                            int_reward=info['num_rewards']).to(DEVICE)
    elif 'num_rewards' in info.keys():
        print("CREATING model with est future reward")
        vqvae_model = VQVAE(num_clusters=largs.num_k,
                            encoder_output_size=largs.num_z,
                            num_output_mixtures=info['num_output_mixtures'],
                            in_channels_size=largs.number_condition,
                            n_actions=info['num_actions'],
                            int_reward=False,
                            reward_value=True).to(DEVICE)
    else:
        vqvae_model = VQVAE(num_clusters=largs.num_k,
                            encoder_output_size=largs.num_z,
                            num_output_mixtures=info['num_output_mixtures'],
コード例 #22
0
    # Both splits use 4 conditioning frames and predict 1 step ahead; only the
    # file and the batch size differ (train uses args, valid uses largs).
    shared_dataset_kwargs = dict(number_condition=4, steps_ahead=1)
    train_data_loader = AtariDataset(
        train_data_file,
        batch_size=args.batch_size,
        norm_by=255.,
        **shared_dataset_kwargs
    )
    valid_data_loader = AtariDataset(
        valid_data_file,
        batch_size=largs.batch_size,
        norm_by=255.0,
        **shared_dataset_kwargs
    )

    num_actions = valid_data_loader.n_actions
    args.size_training_set = valid_data_loader.num_examples
    hsize, wsize = valid_data_loader.data_h, valid_data_loader.data_w

    vqvae_model = VQVAE(
        num_clusters=largs.num_k,
        encoder_output_size=largs.num_z,
        in_channels_size=largs.number_condition,
    ).to(DEVICE)
    vqvae_model.load_state_dict(model_dict['vqvae_state_dict'])

    # Sample from one full training episode.
    train_episode_batch, episode_index, episode_reward = \
        train_data_loader.get_entire_episode()
    sample_batch(train_episode_batch, episode_index, episode_reward, 'train')

コード例 #23
0
ファイル: train.py プロジェクト: Ryan-Rudes/Experiment
class Trainer:
    """Train a VQ-VAE with optional mixed precision and gradient accumulation.

    Each :meth:`train` call back-propagates one mini-batch. The optimiser
    steps once every ``update_frequency = ceil(batch_size / mini_batch_size)``
    calls, so gradients from that many mini-batches are accumulated.
    """

    def __init__(self, in_channels, hidden_channels, res_channels,
                 nb_res_layers, nb_levels, embed_dim, nb_entries,
                 scaling_rates, lr, beta, batch_size, mini_batch_size, no_amp,
                 random_resets, device):
        self.device = device
        self.model = VQVAE(in_channel=in_channels,
                           channel=hidden_channels,
                           n_res_channel=res_channels,
                           n_res_block=nb_res_layers,
                           nb_levels=nb_levels,
                           embed_dim=embed_dim,
                           n_embed=nb_entries,
                           scaling_rates=scaling_rates,
                           random_resets=random_resets).to(self.device)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.optimizer.zero_grad()

        self.beta = beta
        self.scaler = torch.cuda.amp.GradScaler(enabled=not no_amp)

        self.update_frequency = math.ceil(batch_size / mini_batch_size)
        self.steps = 0

    def _calculate_loss(self, x):
        """Return ``(loss, r_loss, l_loss, y)`` for input batch ``x``.

        ``r_loss`` is the mean squared error between the model output ``y``
        and ``x``. ``l_loss`` is the sum of the entries of the model's second
        output. ``loss = r_loss + beta * l_loss``.
        """
        y, d, _, _, _ = self.model(x)
        r_loss, l_loss = y.sub(x).pow(2).mean(), sum(d)
        loss = r_loss + self.beta * l_loss
        return loss, r_loss, l_loss, y

    def train(self, x):
        """Back-propagate one mini-batch; step the optimiser when due.

        Returns:
            ``(loss, r_loss, l_loss)`` as Python floats.
        """
        self.model.train()
        with torch.cuda.amp.autocast(enabled=self.scaler.is_enabled()):
            loss, r_loss, l_loss, _ = self._calculate_loss(x)
        # Divide so the accumulated gradient is the mean over mini-batches.
        self.scaler.scale(loss / self.update_frequency).backward()

        self.steps += 1
        if self.steps % self.update_frequency == 0:
            self._update_parameters()

        return loss.item(), r_loss.item(), l_loss.item()

    def _update_parameters(self):
        """Apply the accumulated (scaled) gradients and reset them."""
        self.scaler.step(self.optimizer)
        self.optimizer.zero_grad()
        self.scaler.update()

    @torch.no_grad()
    def eval(self, x):
        """Compute losses and the reconstruction for ``x`` without gradients.

        Gradients accumulated by earlier :meth:`train` calls are left intact.
        Zeroing them here would silently drop part of an accumulation window
        whenever evaluation runs between optimiser steps.

        Returns:
            ``(loss, r_loss, l_loss, y)`` with the losses as Python floats.
        """
        self.model.eval()
        loss, r_loss, l_loss, y = self._calculate_loss(x)
        return loss.item(), r_loss.item(), l_loss.item(), y

    def save_checkpoint(self, path):
        """Write the model's state dict to ``path``."""
        torch.save(self.model.state_dict(), path)

    def load_checkpoint(self, path):
        """Load a state dict from ``path`` onto this trainer's device."""
        self.model.load_state_dict(torch.load(path, map_location=self.device))

    def save_reconstructions(self, batch, path, sample_size=16):
        """Save a grid of the first ``sample_size`` inputs above their reconstructions."""
        batch = batch[:sample_size]
        _, _, _, out = self.eval(batch)

        utils.save_image(torch.cat([batch, out]),
                         path,
                         nrow=batch.shape[0],
                         normalize=True,
                         value_range=(-1, 1))
コード例 #24
0
    args = parser.parse_args()

    print(args)

    device = 'cuda'

    # Square resize/crop, then normalise each RGB channel to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    model = VQVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    scheduler = None
    if args.sched == 'cycle':
        n_iter = len(loader) * args.epoch
        scheduler = CycleScheduler(optimizer, args.lr, n_iter=n_iter, momentum=None)

    for epoch in range(args.epoch):
        train(epoch, loader, model, optimizer, scheduler, device)
        tag = str(epoch + 1).zfill(3)
        torch.save(model.state_dict(), f'checkpoint/vqvae_{tag}.pt')
コード例 #25
0
                                     number_condition=args.number_condition,
                                     steps_ahead=1,
                                     batch_size=args.batch_size,
                                     norm_by=info['norm_by'])
    num_actions = train_data_loader.n_actions
    args.size_training_set = train_data_loader.num_examples
    hsize = train_data_loader.data_h
    wsize = train_data_loader.data_w
    # output mixtures should be 2*nr_logistic_mix + nr_logistic mix for each
    # decorelated channel
    info['num_channels'] = 2
    info['num_output_mixtures'] = (2 * args.nr_logistic_mix +
                                   args.nr_logistic_mix) * info['num_channels']
    nmix = int(info['num_output_mixtures'] / 2)
    vqvae_model = VQVAE(num_clusters=args.num_k,
                        encoder_output_size=args.num_z,
                        num_output_mixtures=info['num_output_mixtures'],
                        in_channels_size=args.number_condition).to(DEVICE)

    parameters = list(vqvae_model.parameters())
    opt = optim.Adam(parameters, lr=args.learning_rate)
    if args.model_loadpath != '':
        vqvae_model.load_state_dict(model_dict['vqvae_state_dict'])
        opt.load_state_dict(model_dict['optimizer'])
        vqvae_model.embedding = model_dict['embedding']

    #args.pred_output_size = 1*80*80
    ## 10 is result of structure of network
    #args.z_input_size = 10*10*args.num_z
    train_cnt = train_vqvae(train_cnt)
コード例 #26
0
def main(args):
    """Train a VQ-VAE model and keep the checkpoint with the best validation loss.

    Steps:
      1. Load train/validation loaders via ``data.get_data``.
      2. Build the model selected by ``args.model`` ('diffvqvae' or 'vqvae').
      3. Train for ``args.num_epochs`` epochs. After each epoch run
         ``evaluate`` on the validation loader, and save
         ``model.state_dict()`` to ``args.save_path`` whenever the
         validation loss improves.
      4. Plot the training reconstruction-error / perplexity curves to
         ``results/train.png`` and call ``generate_samples`` on the
         validation loader.

    Derived values (``input_size``, ``downsample``, ``data_variance``, ``KL``,
    ``num_pixels``, ``global_it``) are written onto ``args``. This lets
    ``train_epoch`` and ``evaluate``, which both receive ``args``, use them.

    Args:
        args: Parsed configuration namespace. It is read for at least
            ``data_folder``, ``batch_size``, ``enc_height``, ``model``,
            ``learning_rate``, ``num_epochs``, ``num_codebooks``,
            ``num_embeddings`` and ``save_path``.

    Raises:
        ValueError: If ``args.model`` is not 'diffvqvae' or 'vqvae'.
    """
    # Fail fast on a bad model name, before the (potentially slow) data load.
    # Without this check an unknown name would leave ``model`` unassigned and
    # crash later with an UnboundLocalError.
    if args.model not in ('diffvqvae', 'vqvae'):
        raise ValueError(
            f"Unknown model {args.model!r}; expected 'diffvqvae' or 'vqvae'")

    ###############################
    # TRAIN PREP
    ###############################
    print("Loading data")
    train_loader, valid_loader, data_var, input_size = data.get_data(
        args.data_folder, args.batch_size)

    args.input_size = input_size
    # Ratio between the input's last dimension and the encoder output height.
    args.downsample = args.input_size[-1] // args.enc_height
    args.data_variance = data_var
    print(f"Training set size {len(train_loader.dataset)}")
    print(f"Validation set size {len(valid_loader.dataset)}")

    print("Loading model")
    # Conditional expression only evaluates the selected class name.
    model_cls = DiffVQVAE if args.model == 'diffvqvae' else VQVAE
    model = model_cls(args).to(device)
    print(
        f'The model has {utils.count_parameters(model):,} trainable parameters'
    )

    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           amsgrad=False)

    print(f"Start training for {args.num_epochs} epochs")
    num_batches = math.ceil(
        len(train_loader.dataset) / train_loader.batch_size)
    pbar = Progress(num_batches, bar_length=10, custom_increment=True)

    # Needed for bits-per-dim: enc_height**2 * num_codebooks positions, each
    # contributing log(num_embeddings) (natural log, i.e. nats).
    args.KL = (args.enc_height * args.enc_height * args.num_codebooks *
               np.log(args.num_embeddings))
    args.num_pixels = np.prod(args.input_size)

    ###############################
    # MAIN TRAIN LOOP
    ###############################
    best_valid_loss = float('inf')
    # Per-step training metrics; train_epoch appends to these lists in place.
    train_bpd = []
    train_recon_error = []
    train_perplexity = []
    args.global_it = 0
    for epoch in range(args.num_epochs):
        pbar.epoch_start()
        train_epoch(args, vq_vae_loss, pbar, train_loader, model, optimizer,
                    train_bpd, train_recon_error, train_perplexity)
        valid_loss = evaluate(args, vq_vae_loss, pbar, valid_loader, model)
        # Only overwrite the checkpoint when validation loss improves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), args.save_path)
        pbar.print_end_epoch()

    print("Plotting training results")
    utils.plot_results(train_recon_error, train_perplexity,
                       "results/train.png")

    print("Evaluate and plot validation set")
    # NOTE(review): this uses the final-epoch weights, not the best checkpoint
    # saved at args.save_path. Reload that file first if best-model samples
    # are intended.
    generate_samples(model, valid_loader)
コード例 #27
0
# Sampling configuration.
NB_EMBED = 512           # codebook size; also the PixelSnail class count
TRY_CUDA = True          # use CUDA when available
NB_SAMPLES = 4
LATENT_TOP = (32, 32)
LATENT_BOTTOM = (64, 64)
TEMPERATURE = 1.0

# Resolve CUDA availability once so the device choice and the log line agree.
_use_cuda = TRY_CUDA and torch.cuda.is_available()
device = torch.device('cuda:0' if _use_cuda else 'cpu')
print(f"> Device: {device} ({'CUDA is enabled' if _use_cuda else 'CUDA not available'}) \n")

# Expected CLI: <script> VQVAE_CKPT PIXELSNAIL_TOP_CKPT PIXELSNAIL_BOTTOM_CKPT
if len(sys.argv) < 4:
    sys.exit(
        f"usage: {sys.argv[0]} VQVAE_CKPT PIXELSNAIL_TOP_CKPT "
        "PIXELSNAIL_BOTTOM_CKPT"
    )
vqvae_path = sys.argv[1]
pixelsnail_top_path = sys.argv[2]
pixelsnail_bottom_path = sys.argv[3]

vqvae = VQVAE(
    i_dim=3, h_dim=128, r_dim=64, nb_r_layers=2,
    nb_emd=NB_EMBED, emd_dim=64
).to(device).eval()
# map_location: checkpoints saved on a GPU must still load when the script
# has fallen back to CPU above.
vqvae.load_state_dict(torch.load(vqvae_path, map_location=device))

pixelsnail_top = PixelSnail(
    list(LATENT_TOP),
    nb_class=NB_EMBED,
    channel=256,
    kernel_size=5,
    nb_pixel_block=2,
    nb_res_block=4,
    res_channel=128,
    dropout=0.0,
    nb_out_res_block=1,
).to(device).eval()
pixelsnail_top.load_state_dict(
    torch.load(pixelsnail_top_path, map_location=device))
コード例 #28
0
    info = {
        'train_cnts': [],
        'train_losses': [],
        'test_cnts': [],
        'test_losses': [],
        'save_times': [],
        'args': [args],
        'last_save': 0,
        'last_plot': -args.plot_every,
    }

    if args.model_loadname is None:
        vmodel = VQVAE(nr_logistic_mix=args.nr_logistic_mix,
                       num_clusters=args.num_k,
                       encoder_output_size=args.num_z,
                       in_channels_size=args.number_condition,
                       out_channels_size=1).to(DEVICE)
        opt = torch.optim.Adam(vmodel.parameters(), lr=args.learning_rate)
    else:
        model_loadpath = os.path.abspath(
            os.path.join(default_base_savedir, args.model_loadname))
        if os.path.exists(model_loadpath):
            model_dict = torch.load(model_loadpath)
            info = model_dict['info']
            largs = info['args'][-1]
            args.number_condition = largs.number_condition
            args.steps_ahead = largs.number_condition
            args.num_z = args.num_z
            args.nr_logistic_mix
            args.num_k = largs.num_k
コード例 #29
0
    args = parser.parse_args()

    print(args)

    device = 'cuda'

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

    model = nn.DataParallel(VQVAE()).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=len(loader) * args.epoch,
                                   momentum=None)

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(model.module.state_dict(),
                   f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
コード例 #30
0
        data_root='data',
        image_size=img_sz,
        num_digits=num_dig,
        channels=channels,
        to_sort_label=to_sort_label,
        dig_to_use=dig_to_use,
        nxt_dig_prob=nxt_dig_prob,
        rand_dig_combine=rand_dig_combine,
        split_dig_set=split_dig_set,
    )
    loader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=4)

    model = nn.DataParallel(VQVAE(in_channel=img_chn)).to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer,
                                   args.lr,
                                   n_iter=len(loader) * args.epoch,
                                   momentum=None)

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        torch.save(
            model.module.state_dict(),
            f'experiments/{cur_time}/checkpoint/vqvae_{str(i + 1).zfill(3)}.pt'
        )