def main(args): device = "cuda" args.distributed = dist.get_world_size() > 1 normMean = [0.5] normStd = [0.5] normTransform = transforms.Normalize(normMean, normStd) transform = transforms.Compose([ transforms.Resize(args.size), transforms.ToTensor(), normTransform, ]) txt_path = 'datd/train.txt' images_path = '/data' labels_path = '/data' dataset = txtDataset(txt_path, images_path, labels_path, transform=transform) sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed) loader = DataLoader(dataset, batch_size=batch_size // args.n_gpu, sampler=sampler, num_workers=16) model = VQVAE().to(device) if args.distributed: model = nn.parallel.DistributedDataParallel( model, device_ids=[dist.get_local_rank()], output_device=dist.get_local_rank(), ) optimizer = optim.Adam(model.parameters(), lr=args.lr) scheduler = None if args.sched == "cycle": scheduler = CycleScheduler( optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None, warmup_proportion=0.05, ) for i in range(args.epoch): train(i, loader, model, optimizer, scheduler, device) if dist.is_primary(): torch.save(model.state_dict(), f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
def load_model_from_file(path):
    """Rebuild a VQVAE from a saved run directory (args.json + best_model.pth)."""
    from vqvae import VQVAE

    # Restore the hyper-parameters the model was trained with.
    args_file = os.path.join(path, 'args.json')
    with open(args_file, 'rb') as f:
        args = dotdict(json.load(f))

    # Instantiate the architecture, then load the best saved weights.
    model = VQVAE(args)
    weights = torch.load(os.path.join(path, 'best_model.pth'))
    model.load_state_dict(weights)
    return model
def main(args):
    """Train a VQVAE on the CUB dataset, optionally distributed across GPUs."""
    device = "cuda"
    # Distributed whenever more than one process participates.
    args.distributed = dist.get_world_size() > 1

    preprocessing = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = CUBDataset(args.path, transform=preprocessing, mode=args.mode)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    # Global batch of 128 split evenly across the GPUs.
    loader = DataLoader(dataset, batch_size=128 // args.n_gpu,
                        sampler=sampler, num_workers=2)

    model = VQVAE().to(device)
    if args.distributed:
        local_rank = dist.get_local_rank()
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    print(args)
    for epoch in range(args.epoch):
        train(epoch, loader, model, optimizer, scheduler, device)
        # Only rank 0 writes checkpoints.
        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/vqvae_{str(epoch + 1).zfill(3)}.pt")
def load_model(model, checkpoint, device):
    """Load a trained model from checkpoint/<checkpoint> and return it in eval mode.

    Parameters
    ----------
    model : str
        Model kind: 'vqvae' or 'pixelsnail_bottom'.
    checkpoint : str
        Checkpoint filename inside the 'checkpoint' directory.
    device : torch.device or str
        Device to move the model to.

    Raises
    ------
    ValueError
        For an unknown model name (previously this fell through and crashed
        later with an AttributeError on the string).
    KeyError
        When a 'pixelsnail_bottom' checkpoint carries no 'args' entry
        (previously an UnboundLocalError on `args`).
    """
    ckpt = torch.load(os.path.join('checkpoint', checkpoint))
    # Hyper-parameters are optional in the checkpoint dict.
    args = ckpt['args'] if 'args' in ckpt else None

    if model == 'vqvae':
        model = VQVAE()
    elif model == 'pixelsnail_bottom':
        if args is None:
            raise KeyError(
                "checkpoint has no 'args' entry; required to build pixelsnail_bottom")
        model = PixelSNAIL(
            [64, 64],
            512,
            args.channel,
            5,
            4,
            args.n_res_block,
            args.n_res_channel,
            attention=False,
            dropout=args.dropout,
            n_cond_res_block=args.n_cond_res_block,
            cond_res_channel=args.n_res_channel,
        )
    else:
        raise ValueError(f"unknown model: {model!r}")

    # Some checkpoints wrap the weights under a 'model' key.
    if 'model' in ckpt:
        ckpt = ckpt['model']
    model.load_state_dict(ckpt)
    model = model.to(device)
    model.eval()
    return model
def main(args):
    """Run latent interpolation with a VQVAE on CPU.

    Builds the ImageFolder loader, restores weights from args.load_path when
    given, wraps in DDP when distributed, then calls interpolate().
    """
    device = "cpu"
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    )

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        # NOTE(review): device is "cpu" here, yet DDP is configured with CUDA
        # local-rank device ids — confirm this branch is actually exercised.
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    # Dead code removed: an Adam optimizer and optional CycleScheduler were
    # constructed here but never used — interpolate() only takes the loader,
    # model, and device.
    interpolate(loader, model, device)
def main():
    """CLI entry point: extract VQVAE codes for a folder of images into an LMDB."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--model_path', type=str)
    parser.add_argument('--name', type=str)
    parser.add_argument('path', type=str)
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Resize/center-crop pipeline normalized to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = ImageFileDataset(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=4)

    model = VQVAE()
    model.load_state_dict(torch.load(args.model_path))
    model = model.to(device)
    model.eval()

    # 100 GiB map size for the LMDB environment.
    map_size = 100 * 1024 * 1024 * 1024
    env = lmdb.open(args.name, map_size=map_size)

    extract(env, loader, model, device)
def construct_model():
    """Assemble the VQ-VAE Keras graph: encoder -> quantizer -> straight-through -> decoder."""
    x_input = tf.keras.layers.Input((28, 28, 1))
    enc_x = EncoderLayer()(x_input)
    quant_x = VQVAE()(enc_x)
    # Straight-through estimator: forward pass uses the quantized codes while
    # gradients flow back to the encoder output.
    x_dec = tf.keras.layers.Lambda(
        lambda q: enc_x + tf.stop_gradient(q - enc_x))(quant_x)
    dec_x = DecoderLayer()(x_dec)
    model = tf.keras.models.Model(x_input, dec_x)
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=vqvae_loss(0.25, enc_x, quant_x),
                  experimental_run_tf_function=False)
    return model
def encode_proc(model_path, model_config_path, img_root_path, img_key_path_list,
                img_size, device, output_path):
    """Worker: encode every image listed in img_key_path_list with a VQVAE.

    Writes one '<key>\\t<comma-joined code ids>' line per image to output_path.
    Unreadable images are skipped. Progress goes to stderr every 100k lines.
    """
    # Use context managers so file handles are always closed.
    with open(model_config_path) as cfg_f:
        model_config_json = cfg_f.read()
    print("ModelConfig:", model_config_json, file=sys.stderr, flush=True)
    model_config = VqvaeConfig.from_json(model_config_json)
    model = VQVAE(model_config).to(device)
    if device.type == "cuda":
        torch.cuda.set_device(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    transforms = build_transform(img_size)
    linecnt = 0
    with open(output_path, "w") as output_fp:
        for f in img_key_path_list:
            with open(f) as key_f:
                for line in key_f:
                    linecnt += 1
                    if linecnt % 100000 == 0:
                        print("{} {} done".format(f, linecnt),
                              file=sys.stderr, flush=True)
                    img_key = line.strip()
                    img_path = get_key_path(img_root_path, line.strip())
                    # Skip unreadable/corrupt images rather than aborting the
                    # worker — but never swallow KeyboardInterrupt/SystemExit
                    # (the original bare `except:` did).
                    try:
                        img = default_loader(img_path)
                    except Exception:
                        continue
                    img = transforms(img)[None].to(device)
                    # Index 2 of the model output holds the code ids.
                    id_t = model(img)[2].detach().cpu().flatten(1)
                    print("{}\t{}".format(img_key, ",".join(
                        (str(x) for x in id_t[0].tolist()))),
                        file=output_fp, flush=True)
def train_vqvae(hparams_path):
    """Train a VQVAE with PyTorch Lightning using hyper-parameters loaded from a file."""
    hparams = load_hparams(hparams_path)
    # Output directory for logs/checkpoints.
    os.makedirs(hparams.folder, exist_ok=True)
    model = VQVAE(hparams)
    logger = pl.loggers.TensorBoardLogger(save_dir=hparams.folder, name="logs")
    # NOTE(review): `default_root` / `show_progress_bar` only exist in very old
    # pytorch-lightning releases (newer versions use `default_root_dir` /
    # `enable_progress_bar`) — confirm against the pinned PL version.
    trainer = pl.Trainer(
        default_root=hparams.folder,
        max_epochs=hparams.epochs,
        show_progress_bar=False,
        gpus=hparams.gpus,
        logger=logger,
    )
    trainer.fit(model)
def __init__(self, in_channels, hidden_channels, res_channels, nb_res_layers,
             nb_levels, embed_dim, nb_entries, scaling_rates, lr, beta,
             batch_size, mini_batch_size, no_amp, random_resets, device):
    """Build the VQ-VAE, its Adam optimizer, and AMP / gradient-accumulation state."""
    self.device = device
    net = VQVAE(
        in_channel=in_channels,
        channel=hidden_channels,
        n_res_channel=res_channels,
        n_res_block=nb_res_layers,
        nb_levels=nb_levels,
        embed_dim=embed_dim,
        n_embed=nb_entries,
        scaling_rates=scaling_rates,
        random_resets=random_resets,
    )
    self.model = net.to(self.device)
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
    self.optimizer.zero_grad()
    self.beta = beta
    # Mixed precision is enabled unless explicitly disabled via no_amp.
    self.scaler = torch.cuda.amp.GradScaler(enabled=not no_amp)
    # Number of mini-batches accumulated per optimizer step.
    self.update_frequency = math.ceil(batch_size / mini_batch_size)
    self.steps = 0
def main(args):
    """Pre-train a VQVAE on selected CIFAR-100 classes, keeping the best-MSE checkpoint.

    Creates <root>/<results_dir> with args.json, checkpoints/, output/, and
    log.txt; trains for args.epoch epochs and saves the best model so far.
    """
    root = args.root
    results_dir = args.results_dir
    save_path = os.path.join(root, results_dir)
    print('root is', root)
    print('save_path is:', save_path)
    os.makedirs(save_path, exist_ok=True)

    # Persist the run configuration. vars() replaces the private
    # argparse API dict(args._get_kwargs()) — identical content.
    json_file_name = os.path.join(save_path, 'args.json')
    with open(json_file_name, 'w') as fp:
        json.dump(vars(args), fp, sort_keys=True, indent=4)

    checkpoints_path = os.path.join(save_path, 'checkpoints')
    os.makedirs(checkpoints_path, exist_ok=True)
    sample_output_path = os.path.join(save_path, 'output')
    os.makedirs(sample_output_path, exist_ok=True)
    log_file = os.path.join(save_path, 'log.txt')
    config_logging(log_file)
    logging.info('====> args{} '.format(args))

    num_workers = args.num_workers
    device = "cuda"
    batch_size = args.batch_size
    dataset_path = args.dataset_path

    transform_train = transform_train_cifar
    train_ds = CIFAR100(root=dataset_path, train=True, download=True,
                        transform=transform_train)

    # Restrict training to the first `pretrain_classes` classes.
    obtain_indices = get_indices
    classes = [i for i in range(args.pretrain_classes)]
    print('pretrain vqvae using ', classes)
    training_idx = obtain_indices(train_ds, classes, is_training=True)
    loader = DataLoader(train_ds, batch_size=batch_size,
                        sampler=SubsetRandomSampler(training_idx),
                        num_workers=num_workers, drop_last=False)

    model = VQVAE(embed_dim=args.dim_emb, n_embed=args.n_emb).to(device)
    if args.checkpoint is not None:
        model_pt = torch.load(args.checkpoint)
        model.load_state_dict(model_pt)
    opt = optim.Adam(model.parameters(), lr=args.lr)

    # float('inf') instead of the 999999 magic sentinel, so the first epoch
    # always establishes a baseline even if its MSE is huge.
    best_mse = float('inf')
    for i in range(args.epoch):
        tmp_mse = train_AE(i, loader, model, opt, device, save_path)
        if best_mse > tmp_mse:
            best_mse = tmp_mse
            logging.info('====> Epoch{}: best_mse {} '.format(i, best_mse))
            # Same destination as before, built from the directory created above.
            pt_path = os.path.join(checkpoints_path, "VQVAE2_cifar_best.pt")
            torch.save(model.state_dict(), pt_path)
def __init__(self, vqvae, in_channel=6, channel=128, n_res_block=2,
             n_res_channel=32, embed_dim=64, n_embed=512, decay=0.99):
    """Offset network built on top of a frozen, pre-trained VQVAE."""
    super(OffsetNetwork, self).__init__(
        in_channel=in_channel,
        channel=channel,
        n_res_block=n_res_block,
        n_res_channel=n_res_channel,
        embed_dim=embed_dim,
        n_embed=n_embed,
        decay=decay,
    )
    # Use the provided pre-trained VQVAE, or a freshly constructed one as fallback.
    if vqvae is None:
        vqvae = VQVAE()
    self.vqvae = vqvae
    # Freeze it — only the offset network itself is trained.
    for p in self.vqvae.parameters():
        p.requires_grad = False
def encode_fn():
    """Build a VQVAE and return (model, encode), where encode maps a numpy
    image to the model's top-level code ids.

    NOTE: relies on module-level `device` and `transform`.
    """
    model = VQVAE().to(device)

    def encode(frame):
        model.eval()
        with torch.no_grad():
            resized = cv2.resize(frame, (160, 160), interpolation=cv2.INTER_AREA)
            tensor = transform(resized).unsqueeze(0).to(device)
            _, _, _, id_t, id_b = model.encode(tensor)
            top_codes = id_t.cpu().numpy()
        # Restore training mode before handing control back.
        model.train()
        return top_codes

    return model, encode
# CLI: dump VQVAE codes for every image under `path` into an LMDB named `--name`.
parser = argparse.ArgumentParser()
parser.add_argument('--size', type=int, default=256)
parser.add_argument('--ckpt', type=str)
parser.add_argument('--name', type=str)
parser.add_argument('path', type=str)
args = parser.parse_args()

device = 'cuda'

# Resize/center-crop pipeline, normalized to [-1, 1].
preprocessing = [
    transforms.Resize(args.size),
    transforms.CenterCrop(args.size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(preprocessing)

dataset = ImageFileDataset(args.path, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

model = VQVAE()
model.load_state_dict(torch.load(args.ckpt))
model = model.to(device)
model.eval()

# 100 GiB map size for the LMDB environment.
map_size = 100 * 1024 * 1024 * 1024
env = lmdb.open(args.name, map_size=map_size)

extract(env, loader, model, device)
def main(args):
    """Train an OffsetNetwork on top of a frozen, pre-trained VQVAE.

    Loads the VQVAE checkpoint (transparently handling DataParallel-prefixed
    keys), optionally wraps in DDP, and saves a checkpoint per epoch.
    """
    device = "cuda"
    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset, batch_size=args.bsize // args.n_gpu,
                        sampler=sampler, num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    # load_state_dict raises RuntimeError on key mismatch; the previous bare
    # `except:` also swallowed KeyboardInterrupt and genuine file errors.
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        print(
            "Seems the checkpoint was trained with data parallel, try loading it that way"
        )
        # Checkpoints saved from nn.DataParallel prefix every key with 'module.'.
        weights = torch.load(args.ckpt)
        weights = {key.replace('module.', ''): value
                   for key, value in weights.items()}
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)
    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)
        # Only rank 0 writes checkpoints.
        if dist.is_primary():
            torch.save(model.state_dict(),
                       f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
def encode(args):
    """Encode every image of an LMDB image dataset into VQVAE code ids.

    Values are written to an output LMDB as raw int16 bytes keyed by the image
    key; per-phase timing statistics are printed to stderr every 100 batches.
    """
    with open(args.config_path) as cfg_f:
        model_config_json = cfg_f.read()
    print("ModelConfig:", model_config_json, file=sys.stderr, flush=True)
    model_config = VqvaeConfig.from_json(model_config_json)

    device = torch.device(args.device)
    n_gpu = torch.cuda.device_count() if args.device == "cuda" else 0

    model = VQVAE(model_config).to(device)
    model.load_state_dict(torch.load(args.model_path, map_location=device))
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    model.eval()

    trans = build_transform(args.img_size)
    dataset = ImageLmdbDataset(args.img_root_path, args.img_key_path, trans,
                               args.batch_size, with_key=True)
    dataloader = IterDataLoader(dataset, batch_size=args.batch_size,
                                num_workers=args.num_workers,
                                collate_fn=collate_fn, pin_memory=True)

    lmdb_env = lmdb.open(args.output_path, map_size=int(1e12))
    lmdb_txn = lmdb_env.begin(write=True)

    # Dead code removed: an unused `cache` list plus a commented-out periodic
    # commit block and its write_cost / write_batch_cnt counters.
    batch_cnt = 0
    input_cost = 0.0
    to_cost = 0.0
    eval_cost = 0.0
    trans_cost = 0.0
    t_point = time.time()
    start_point = t_point
    for key_list, img_batch in dataloader:
        t_point_1 = time.time()
        input_cost += t_point_1 - t_point
        t_point = t_point_1
        # BUG FIX: Tensor.to() is not in-place — the result must be
        # reassigned, otherwise the batch silently stays on the CPU.
        img_batch = img_batch.to(device)
        t_point_1 = time.time()
        to_cost += t_point_1 - t_point
        t_point = t_point_1
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            id_batch = model(img_batch)[3].detach().cpu().flatten(1)
        t_point_1 = time.time()
        eval_cost += t_point_1 - t_point
        t_point = t_point_1
        for key, id_t in zip(key_list, id_batch):
            lmdb_txn.put(key.encode("utf-8"),
                         id_t.to(torch.int16).numpy().tobytes())
        t_point_1 = time.time()
        trans_cost += t_point_1 - t_point
        t_point = t_point_1
        batch_cnt += 1
        if batch_cnt % 100 == 0:
            print(
                "{} batch done, input_c={:.4f}, to_c={:.4f}, eval_c={:.4f}, trans_c={:.4f}, total_c={:.4f}"
                .format(batch_cnt, input_cost / batch_cnt,
                        to_cost / batch_cnt, eval_cost / batch_cnt,
                        trans_cost / batch_cnt,
                        (time.time() - start_point) / batch_cnt),
                file=sys.stderr, flush=True)
            t_point = time.time()
    lmdb_txn.commit()
    lmdb_env.close()
args = parser.parse_args()
print(args)

device = 'cuda'

# Resize/center-crop pipeline, normalized to [-1, 1].
pipeline = [
    transforms.Resize(args.size),
    transforms.CenterCrop(args.size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(pipeline)

dataset = datasets.ImageFolder(args.path, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

model = VQVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = None
if args.sched == 'cycle':
    # One-cycle LR schedule covering the whole run.
    scheduler = CycleScheduler(optimizer, args.lr,
                               n_iter=len(loader) * args.epoch,
                               momentum=None)

for i in range(args.epoch):
    train(i, loader, model, optimizer, scheduler, device)
    # Checkpoint after every epoch.
    torch.save(model.state_dict(),
               f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
device = device_list[device_i]
input_list = img_key_path_list[device_i * proc_pic: (device_i + 1) * proc_pic]
output_path = os.path.join(args.output_path, "part-{}".format(device_i))
proc_list.append(Process(
    target=encode_proc,
    args=(args.model_path, args.config_path, args.img_root_path,
          input_list, args.img_size, device, output_path)
))
for proc in proc_list:
    proc.start()
for proc in proc_list:
    proc.join()
'''
# Everything above, up to the closing ''', is the tail of a triple-quoted
# string literal (opened outside this fragment): disabled multi-process
# dispatch that sharded the key files across devices. It does not execute.
# NOTE(review): `args` must be defined in the enclosing (not shown) scope
# for this call to work — confirm.
encode(args)


if __name__ == "__main__":
    main()
    exit(0)
    # NOTE(review): everything below is unreachable after exit(0) — leftover
    # debug code for encoding a single image by key.
    config_path = os.path.join(sys.argv[1], "config.json")
    model_path = os.path.join(sys.argv[1], "pytorch_model.bin")
    model = VQVAE(VqvaeConfig.from_json(open(config_path).read())).to("cpu")
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()
    img_path = get_key_path("/mnt2/makai/imgs", sys.argv[2])
    trans = build_transform(224)
    img = default_loader(img_path)
    img = trans(img)[None]
    id_t, id_b = model(img)[2:4]
    print(",".join((str(x.item()) for x in id_b.flatten(1)[0])))
print(args)

device = 'cuda'

# Data pipeline comes from get_dataset(); the older commented-out
# torchvision ImageFolder pipeline was removed as dead code.
loader = get_dataset(args.path, batch_size=args.batchsize)

# 32-dim codebook embedding, wrapped for multi-GPU training.
model = nn.DataParallel(VQVAE(embed_dim=32)).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = None
if args.sched == 'cycle':
    scheduler = CycleScheduler(optimizer, args.lr,
                               n_iter=len(loader) * args.epoch,
                               momentum=None)

for i in range(args.epoch):
    train(i, loader, model, optimizer, scheduler, device)
    # Save the underlying module's weights (strip the DataParallel wrapper).
    torch.save(
        model.module.state_dict(),
        f'allCheckpoint/checkpoint32/vqvae_{str(i + 1).zfill(3)}.pt')

writer.close()
def main():
    """Train/evaluate a VQVAE, plot losses in visdom, and keep the best weights."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', help='Image size', type=int, default=256)
    parser.add_argument('--epoch', type=int, default=100)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--sched', type=str, default='cycle')
    parser.add_argument('--vishost', type=str, default='localhost')
    parser.add_argument('--visport', type=int, default=8097)
    parser.add_argument('path', help="root path with train and test folder in it",
                        type=str)
    args = parser.parse_args()
    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3)
    ])

    # Expect <path>/train and <path>/test ImageFolder layouts.
    train_path = os.path.join(args.path, "train")
    test_path = os.path.join(args.path, "test")
    train_dataset = datasets.ImageFolder(train_path, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True,
                              num_workers=4)
    test_dataset = datasets.ImageFolder(test_path, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=args.bs, shuffle=False,
                             num_workers=4)

    model = VQVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    if args.sched == 'cycle':
        scheduler = CycleScheduler(optimizer, args.lr,
                                   n_iter=len(train_loader) * args.epoch,
                                   momentum=None)
    else:
        # Step decay at epochs 50 and 70.
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [50, 70], 0.1)

    train_losses = []
    test_losses = []
    vis = visdom.Visdom(server=args.vishost, port=args.visport)
    win = None
    best_model_loss = np.inf

    for epoch in range(args.epoch):
        # Training stage
        print(f"Training epoch {epoch + 1}")
        train_loss = train(epoch, train_loader, model, optimizer, scheduler,
                           device)
        print(f"Train Loss: {train_loss:.5f}")

        # Testing stage
        print(f"Testing epoch {epoch + 1}")
        test_loss, test_recon_error, test_commitment_loss = test(
            epoch, test_loader, model, device)
        print(f"Test Loss: {test_loss:.5f}")

        # Checkpoint every epoch; additionally keep the best weights so far.
        torch.save(model.state_dict(),
                   f'checkpoints/vqvae_chkpt_{str(epoch + 1).zfill(3)}.pt')
        if test_loss < best_model_loss:
            print("Saving model")
            torch.save(model.state_dict(), f'weights/vqvae.pt')
            best_model_loss = test_loss

        train_losses.append(train_loss)
        test_losses.append(test_loss)
        win = plot(train_losses, test_losses, vis, win)

        # Sampling stage
        recon_sample(epoch, model, test_loader, device)
# Validation stream: 4 stacked condition frames, one-step-ahead targets,
# pixel values normalized by 255.
valid_data_loader = AtariDataset(
    valid_data_file,
    number_condition=4,
    steps_ahead=1,
    batch_size=largs.batch_size,
    norm_by=255.0,)
args.size_training_set = valid_data_loader.num_examples
hsize = valid_data_loader.data_h
wsize = valid_data_loader.data_w
# Three model variants depending on how reward is modeled.
if args.reward_int:
    # Internal (integer) reward head.
    int_reward = info['num_rewards']
    vqvae_model = VQVAE(num_clusters=largs.num_k,
                        encoder_output_size=largs.num_z,
                        num_output_mixtures=info['num_output_mixtures'],
                        in_channels_size=largs.number_condition,
                        n_actions=info['num_actions'],
                        int_reward=info['num_rewards']).to(DEVICE)
elif 'num_rewards' in info.keys():
    # Estimated future (value) reward head.
    print("CREATING model with est future reward")
    vqvae_model = VQVAE(num_clusters=largs.num_k,
                        encoder_output_size=largs.num_z,
                        num_output_mixtures=info['num_output_mixtures'],
                        in_channels_size=largs.number_condition,
                        n_actions=info['num_actions'],
                        int_reward=False,
                        reward_value=True).to(DEVICE)
else:
    # No reward head — this constructor call continues beyond the fragment.
    vqvae_model = VQVAE(num_clusters=largs.num_k,
                        encoder_output_size=largs.num_z,
                        num_output_mixtures=info['num_output_mixtures'],
# Training and validation streams over saved Atari episodes: 4 condition
# frames, one-step-ahead targets, pixels normalized by 255.
train_data_loader = AtariDataset(
    train_data_file,
    number_condition=4,
    steps_ahead=1,
    batch_size=args.batch_size,
    norm_by=255.,)
valid_data_loader = AtariDataset(
    valid_data_file,
    number_condition=4,
    steps_ahead=1,
    batch_size=largs.batch_size,
    norm_by=255.0,)
num_actions = valid_data_loader.n_actions
args.size_training_set = valid_data_loader.num_examples
hsize = valid_data_loader.data_h
wsize = valid_data_loader.data_w
# Rebuild the VQVAE with the hyper-parameters it was trained with (largs)
# and restore its weights from the loaded checkpoint dict.
vqvae_model = VQVAE(num_clusters=largs.num_k,
                    encoder_output_size=largs.num_z,
                    in_channels_size=largs.number_condition).to(DEVICE)
vqvae_model.load_state_dict(model_dict['vqvae_state_dict'])
#valid_data, valid_label, test_batch_index = data_loader.validation_ordered_batch()
#valid_episode_batch, episode_index, episode_reward = valid_data_loader.get_entire_episode()
#sample_batch(valid_episode_batch, episode_index, episode_reward, 'valid')
train_episode_batch, episode_index, episode_reward = train_data_loader.get_entire_episode()
sample_batch(train_episode_batch, episode_index, episode_reward, 'train')
class Trainer:
    """Wraps a VQ-VAE with its optimizer, AMP scaler, and gradient accumulation."""

    def __init__(self, in_channels, hidden_channels, res_channels, nb_res_layers,
                 nb_levels, embed_dim, nb_entries, scaling_rates, lr, beta,
                 batch_size, mini_batch_size, no_amp, random_resets, device):
        self.device = device
        self.model = VQVAE(
            in_channel=in_channels,
            channel=hidden_channels,
            n_res_channel=res_channels,
            n_res_block=nb_res_layers,
            nb_levels=nb_levels,
            embed_dim=embed_dim,
            n_embed=nb_entries,
            scaling_rates=scaling_rates,
            random_resets=random_resets,
        ).to(self.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)
        self.optimizer.zero_grad()
        self.beta = beta
        # AMP is enabled unless explicitly switched off.
        self.scaler = torch.cuda.amp.GradScaler(enabled=not no_amp)
        # Mini-batches accumulated per optimizer step.
        self.update_frequency = math.ceil(batch_size / mini_batch_size)
        self.steps = 0

    def _calculate_loss(self, x):
        """Forward pass; returns (total, reconstruction, latent) losses and the reconstruction."""
        recon, diffs, _, _, _ = self.model(x)
        recon_loss = recon.sub(x).pow(2).mean()
        latent_loss = sum(diffs)
        total = recon_loss + self.beta * latent_loss
        return total, recon_loss, latent_loss, recon

    def train(self, x):
        """One accumulated training step; steps the optimizer every update_frequency calls."""
        self.model.train()
        with torch.cuda.amp.autocast(enabled=self.scaler.is_enabled()):
            total, recon_loss, latent_loss, _ = self._calculate_loss(x)
        # Scale by the accumulation factor so the effective gradient matches
        # a full-size batch.
        self.scaler.scale(total / self.update_frequency).backward()
        self.steps += 1
        if self.steps % self.update_frequency == 0:
            self._update_parameters()
        return total.item(), recon_loss.item(), latent_loss.item()

    def _update_parameters(self):
        """Apply the accumulated gradients and reset them."""
        self.scaler.step(self.optimizer)
        self.optimizer.zero_grad()
        self.scaler.update()

    @torch.no_grad()
    def eval(self, x):
        """Evaluate a batch; returns losses plus the reconstruction tensor."""
        self.model.eval()
        self.optimizer.zero_grad()
        total, recon_loss, latent_loss, recon = self._calculate_loss(x)
        return total.item(), recon_loss.item(), latent_loss.item(), recon

    def save_checkpoint(self, path):
        torch.save(self.model.state_dict(), path)

    def load_checkpoint(self, path):
        self.model.load_state_dict(torch.load(path))

    def save_reconstructions(self, batch, path, sample_size=16):
        """Save inputs stacked over their reconstructions as one image grid."""
        subset = batch[:sample_size]
        _, _, _, out = self.eval(subset)
        utils.save_image(torch.cat([subset, out]), path,
                         nrow=subset.shape[0], normalize=True,
                         value_range=(-1, 1))
args = parser.parse_args()
print(args)

device = 'cuda'

# Resize/center-crop pipeline, normalized to [-1, 1].
steps = [
    transforms.Resize(args.size),
    transforms.CenterCrop(args.size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(steps)

dataset = datasets.ImageFolder(args.path, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

model = VQVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = None
if args.sched == 'cycle':
    scheduler = CycleScheduler(
        optimizer, args.lr,
        n_iter=len(loader) * args.epoch,
        momentum=None
    )

for i in range(args.epoch):
    train(i, loader, model, optimizer, scheduler, device)
    # Checkpoint after every epoch.
    torch.save(model.state_dict(), f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
# Continuation of an AtariDataset(...) constructor call — the opening line is
# outside this fragment.
number_condition=args.number_condition,
steps_ahead=1,
batch_size=args.batch_size,
norm_by=info['norm_by'])
num_actions = train_data_loader.n_actions
args.size_training_set = train_data_loader.num_examples
hsize = train_data_loader.data_h
wsize = train_data_loader.data_w
# output mixtures should be 2*nr_logistic_mix + nr_logistic mix for each
# decorelated channel
info['num_channels'] = 2
info['num_output_mixtures'] = (2 * args.nr_logistic_mix + args.nr_logistic_mix) * info['num_channels']
nmix = int(info['num_output_mixtures'] / 2)
vqvae_model = VQVAE(num_clusters=args.num_k,
                    encoder_output_size=args.num_z,
                    num_output_mixtures=info['num_output_mixtures'],
                    in_channels_size=args.number_condition).to(DEVICE)
parameters = list(vqvae_model.parameters())
opt = optim.Adam(parameters, lr=args.learning_rate)
if args.model_loadpath != '':
    # Resume: restore model weights, optimizer state, and codebook embedding.
    vqvae_model.load_state_dict(model_dict['vqvae_state_dict'])
    opt.load_state_dict(model_dict['optimizer'])
    vqvae_model.embedding = model_dict['embedding']
#args.pred_output_size = 1*80*80
## 10 is result of structure of network
#args.z_input_size = 10*10*args.num_z
train_cnt = train_vqvae(train_cnt)
def main(args):
    """Full VQVAE / DiffVQVAE training driver: data prep, train/eval loop, plots.

    Mutates `args` with derived fields (input_size, downsample, data_variance,
    KL, num_pixels, global_it) consumed by the train/eval helpers; saves the
    best-validation model to args.save_path.

    Raises ValueError for an unrecognized args.model.
    """
    ###############################
    # TRAIN PREP
    ###############################
    print("Loading data")
    train_loader, valid_loader, data_var, input_size = \
        data.get_data(args.data_folder, args.batch_size)
    args.input_size = input_size
    args.downsample = args.input_size[-1] // args.enc_height
    args.data_variance = data_var
    print(f"Training set size {len(train_loader.dataset)}")
    print(f"Validation set size {len(valid_loader.dataset)}")

    print("Loading model")
    if args.model == 'diffvqvae':
        model = DiffVQVAE(args).to(device)
    elif args.model == 'vqvae':
        model = VQVAE(args).to(device)
    else:
        # Previously an unknown name fell through and crashed later with an
        # UnboundLocalError; fail fast with a clear message instead.
        raise ValueError(f"unknown model: {args.model!r}")
    print(
        f'The model has {utils.count_parameters(model):,} trainable parameters'
    )

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate,
                           amsgrad=False)

    print(f"Start training for {args.num_epochs} epochs")
    num_batches = math.ceil(
        len(train_loader.dataset) / train_loader.batch_size)
    pbar = Progress(num_batches, bar_length=10, custom_increment=True)

    # Needed for bpd
    args.KL = args.enc_height * args.enc_height * args.num_codebooks * \
        np.log(args.num_embeddings)
    args.num_pixels = np.prod(args.input_size)

    ###############################
    # MAIN TRAIN LOOP
    ###############################
    best_valid_loss = float('inf')
    train_bpd = []
    train_recon_error = []
    train_perplexity = []
    args.global_it = 0
    for epoch in range(args.num_epochs):
        pbar.epoch_start()
        train_epoch(args, vq_vae_loss, pbar, train_loader, model, optimizer,
                    train_bpd, train_recon_error, train_perplexity)
        # loss, _ = test(valid_loader, model, args)
        # pbar.print_eval(loss)
        valid_loss = evaluate(args, vq_vae_loss, pbar, valid_loader, model)
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_valid_epoch = epoch  # tracked for reference; not read below
            torch.save(model.state_dict(), args.save_path)
        pbar.print_end_epoch()

    print("Plotting training results")
    utils.plot_results(train_recon_error, train_perplexity,
                       "results/train.png")

    print("Evaluate and plot validation set")
    generate_samples(model, valid_loader)
# Sampling configuration.
NB_EMBED = 512
TRY_CUDA = True
NB_SAMPLES = 4
LATENT_TOP = (32, 32)
LATENT_BOTTOM = (64, 64)
TEMPERATURE = 1.0

# Resolve the compute device once.
cuda_ok = TRY_CUDA and torch.cuda.is_available()
device = torch.device('cuda:0' if cuda_ok else 'cpu')
print(f"> Device: {device} ({'CUDA is enabled' if cuda_ok else 'CUDA not available'}) \n")

# Checkpoint paths from the command line.
vqvae_path = sys.argv[1]
pixelsnail_top_path = sys.argv[2]
pixelsnail_bottom_path = sys.argv[3]

# Frozen VQVAE decoder in eval mode.
vqvae = VQVAE(
    i_dim=3, h_dim=128, r_dim=64, nb_r_layers=2,
    nb_emd=NB_EMBED, emd_dim=64
).to(device).eval()
vqvae.load_state_dict(torch.load(vqvae_path))

# Prior over the top latent map.
pixelsnail_top = PixelSnail(
    [32, 32],
    nb_class=NB_EMBED,
    channel=256,
    kernel_size=5,
    nb_pixel_block=2,
    nb_res_block=4,
    res_channel=128,
    dropout=0.0,
    nb_out_res_block=1,
).to(device).eval()
pixelsnail_top.load_state_dict(torch.load(pixelsnail_top_path))
# Bookkeeping for training/plotting; replaced by the checkpoint's `info`
# when a model is loaded below.
info = {
    'train_cnts': [],
    'train_losses': [],
    'test_cnts': [],
    'test_losses': [],
    'save_times': [],
    'args': [args],
    'last_save': 0,
    'last_plot': -args.plot_every,
}

if args.model_loadname is None:
    # Fresh model + optimizer.
    vmodel = VQVAE(nr_logistic_mix=args.nr_logistic_mix,
                   num_clusters=args.num_k,
                   encoder_output_size=args.num_z,
                   in_channels_size=args.number_condition,
                   out_channels_size=1).to(DEVICE)
    opt = torch.optim.Adam(vmodel.parameters(), lr=args.learning_rate)
else:
    model_loadpath = os.path.abspath(
        os.path.join(default_base_savedir, args.model_loadname))
    if os.path.exists(model_loadpath):
        model_dict = torch.load(model_loadpath)
        info = model_dict['info']
        # largs = the args the checkpoint was trained with.
        largs = info['args'][-1]
        args.number_condition = largs.number_condition
        # NOTE(review): suspected bug — copies number_condition into
        # steps_ahead; probably meant largs.steps_ahead. Confirm.
        args.steps_ahead = largs.number_condition
        # NOTE(review): self-assignment no-op; probably meant
        # args.num_z = largs.num_z. Confirm.
        args.num_z = args.num_z
        # NOTE(review): bare expression with no effect; probably meant
        # args.nr_logistic_mix = largs.nr_logistic_mix. Confirm.
        args.nr_logistic_mix
        args.num_k = largs.num_k
args = parser.parse_args()
print(args)

device = 'cuda'

# Resize/center-crop pipeline, normalized to [-1, 1].
ops = [
    transforms.Resize(args.size),
    transforms.CenterCrop(args.size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
transform = transforms.Compose(ops)

dataset = datasets.ImageFolder(args.path, transform=transform)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

# Multi-GPU VQVAE via DataParallel.
model = nn.DataParallel(VQVAE()).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = None
if args.sched == 'cycle':
    scheduler = CycleScheduler(optimizer, args.lr,
                               n_iter=len(loader) * args.epoch,
                               momentum=None)

for i in range(args.epoch):
    train(i, loader, model, optimizer, scheduler, device)
    # Save the unwrapped module's weights each epoch.
    torch.save(model.module.state_dict(),
               f'checkpoint/vqvae_{str(i + 1).zfill(3)}.pt')
# Trailing keyword arguments of a dataset constructor — the opening call
# (and the `dataset = ...(` line) is outside this fragment. Options control
# multi-digit composition and digit selection.
data_root='data',
image_size=img_sz,
num_digits=num_dig,
channels=channels,
to_sort_label=to_sort_label,
dig_to_use=dig_to_use,
nxt_dig_prob=nxt_dig_prob,
rand_dig_combine=rand_dig_combine,
split_dig_set=split_dig_set,
)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Multi-GPU VQVAE over img_chn input channels.
model = nn.DataParallel(VQVAE(in_channel=img_chn)).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = None
if args.sched == 'cycle':
    scheduler = CycleScheduler(optimizer, args.lr,
                               n_iter=len(loader) * args.epoch, momentum=None)

for i in range(args.epoch):
    train(i, loader, model, optimizer, scheduler, device)
    # Save unwrapped weights into this run's experiment folder.
    torch.save(
        model.module.state_dict(),
        f'experiments/{cur_time}/checkpoint/vqvae_{str(i + 1).zfill(3)}.pt'
    )