def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    lr = optimizer.param_groups[0]["lr"]

    for i, (img, _, _) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        # only the primary process wraps the loader in tqdm
        if dist.is_primary():
            loader.set_postfix_str(
                f"Step: {i + 1}: MSE: {recon_loss.item():.5f}; "
                f"Latent: {latent_loss.item():.3f}; Total: {loss.item():.5f}"
            )
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]
    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = 'data/train.txt'
    images_path = '/data'
    labels_path = '/data'
    dataset = txtDataset(txt_path, images_path, labels_path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset,
        batch_size=batch_size // args.n_gpu,  # `batch_size` is assumed to be a module-level constant
        sampler=sampler,
        num_workers=16,
    )

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(), f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
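# `txtDataset` is not shown in this excerpt. Below is a minimal sketch of what
# such a paired image/label dataset might look like -- the file format (one
# "image_path label_path" pair per line) and the (image, label, path) return
# signature are assumptions, not the author's actual implementation.
import os
from PIL import Image
from torch.utils.data import Dataset

class TxtDatasetSketch(Dataset):
    def __init__(self, txt_path, images_path, labels_path, transform=None):
        with open(txt_path) as f:
            self.pairs = [line.split() for line in f if line.strip()]
        self.images_path = images_path
        self.labels_path = labels_path
        self.transform = transform

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, index):
        img_name, label_name = self.pairs[index]
        img = Image.open(os.path.join(self.images_path, img_name))
        label = Image.open(os.path.join(self.labels_path, label_name))
        if self.transform is not None:
            img = self.transform(img)
            label = self.transform(label)
        return img, label, img_name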
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    # composed video transform (renamed from `transforms` to avoid shadowing the module)
    transform = video_transforms.Compose([
        RandomSelectFrames(16),
        video_transforms.Resize(args.size),
        video_transforms.CenterCrop(args.size),
        volume_transforms.ClipToTensor(),
        tensor_transforms.Normalize(0.5, 0.5),
    ])

    with open('/home/shirakawa/movie/code/iVideoGAN/over16frame_list_training.txt', 'rb') as f:
        train_file_list = pickle.load(f)
    print(len(train_file_list))

    dataset = MITDataset(train_file_list, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    # loader = DataLoader(
    #     dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2
    # )
    loader = DataLoader(dataset, batch_size=32 // args.n_gpu, sampler=sampler, num_workers=2)

    model = VQVAE().to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(), f"checkpoint_vid_v2/vqvae_{str(i + 1).zfill(3)}.pt")
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset, batch_size=128 // args.n_gpu, sampler=sampler, num_workers=2)

    model = VQVAE().to(device)

    if args.load_path:
        load_state_dict = torch.load(args.load_path, map_location=device)
        model.load_state_dict(load_state_dict)
        print('successfully loaded model')

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(), f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
def evaluate(loader, model, out_path, sample_size):
    if dist.is_primary():
        loader = tqdm(loader)

    model.eval()

    # take a single batch for reconstruction
    img, label = next(iter(loader))
    sample = img[:sample_size]

    with torch.no_grad():
        out, _ = model(sample)

    utils.save_image(
        torch.cat([sample, out], 0),
        out_path,
        nrow=sample_size,
        normalize=True,
        range=(-1, 1),
    )
def train(epoch, loader, model, optimizer, scheduler, scaler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        with torch.cuda.amp.autocast(scaler.is_enabled()):
            out, latent_loss = model(img)
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss

        scaler.scale(loss).backward()

        if scheduler is not None:
            scheduler.step()

        scaler.step(optimizer)
        scaler.update()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"
            ))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad(), torch.cuda.amp.autocast(scaler.is_enabled()):
                    out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
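# The loop above expects a torch.cuda.amp.GradScaler. A minimal sketch of the
# caller-side setup (the `use_amp` flag is an assumption; it would normally
# come from a CLI argument):
from torch.cuda.amp import GradScaler

use_amp = True
scaler = GradScaler(enabled=use_amp)
# With enabled=False, scale()/step()/update() become no-ops, so the same
# train(epoch, loader, model, optimizer, scheduler, scaler, device) call
# runs in plain fp32.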
def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()
        wandb.log({'train loss': loss.item()})

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"
            ))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample)

                # utils.save_image(
                #     torch.cat([sample, out], 0),
                #     f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                #     nrow=sample_size,
                #     normalize=True,
                #     range=(-1, 1),
                # )
                example_images = [
                    wandb.Image(image, caption=f"{epoch}_{i}") for image in out
                ]
                wandb.log({"Examples": example_images})

                model.train()
def train(epoch, loader, discriminator, generator, scheduler_D, scheduler_G,
          optimizer_D, optimizer_G, device):
    loader_d = tqdm(loader)
    if (epoch + 1) % n_critic == 0:  # `n_critic` is assumed to be a module-level constant
        loader_g = tqdm(loader)

    adversarial_loss = nn.BCEWithLogitsLoss()  # applies sigmoid internally
    pixelwise_loss = nn.L1Loss()
    gdloss = GDLoss()

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size  # `batch_size` is assumed to be a module-level constant

    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    requires_grad(generator, False)
    requires_grad(discriminator, True)

    # ---------------------
    #  Train Discriminator
    # ---------------------
    for i, (img, label, label_path, class_name) in enumerate(loader_d):
        discriminator.zero_grad()

        # adversarial ground-truth targets
        valid = torch.ones(img.shape[0], 1)
        fake = torch.zeros(img.shape[0], 1)

        img = img.to(device)
        valid = valid.to(device)
        fake = fake.to(device)
        label = label.to(device)
        gdloss.conv_x = gdloss.conv_x.to(device)
        gdloss.conv_y = gdloss.conv_y.to(device)

        vqvae2_out, latent_loss = generator(img)

        real_loss = adversarial_loss(discriminator(label), valid)
        fake_loss = adversarial_loss(discriminator(vqvae2_out), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()

        if scheduler_D is not None:
            scheduler_D.step()
        optimizer_D.step()

        if dist.is_primary():
            lr = optimizer_D.param_groups[0]["lr"]
            loader_d.set_description((
                f"Discriminator epoch: {epoch + 1}; class loss: {d_loss.item():.5f}; "
                f"lr: {lr:.5f}"
            ))

    # ---------------------
    #  Train Generator
    # ---------------------
    if (epoch + 1) % n_critic == 0:
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        for i, (img, label, label_path, class_name) in enumerate(loader_g):
            generator.zero_grad()

            valid = torch.ones(img.shape[0], 1)

            img = img.to(device)
            valid = valid.to(device)
            label = label.to(device)
            gdloss.conv_x = gdloss.conv_x.to(device)
            gdloss.conv_y = gdloss.conv_y.to(device)

            vqvae2_out, latent_loss = generator(img)

            recon_loss = pixelwise_loss(vqvae2_out, label)
            gradient_loss = gdloss(vqvae2_out, label)
            gradient_loss = gradient_loss.mean()
            latent_loss = latent_loss.mean()
            g_loss = 0.1 * adversarial_loss(discriminator(vqvae2_out), valid) + \
                0.9 * (recon_loss_weight * recon_loss
                       + latent_loss_weight * latent_loss
                       + gradient_loss_weight * gradient_loss)
            g_loss.backward()

            if scheduler_G is not None:
                scheduler_G.step()
            optimizer_G.step()

            part_mse_sum = recon_loss.item() * img.shape[0]
            part_mse_n = img.shape[0]
            comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
            comm = dist.all_gather(comm)

            for part in comm:
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            part_g_sum = gradient_loss.item() * img.shape[0]
            part_g_n = img.shape[0]
            g_comm = {"g_sum": part_g_sum, "g_n": part_g_n}
            g_comm = dist.all_gather(g_comm)

            for part in g_comm:
                g_sum += part["g_sum"]
                g_n += part["g_n"]

            if dist.is_primary():
                lr = optimizer_G.param_groups[0]["lr"]
                loader_g.set_description((
                    f"Generator epoch: {(epoch + 1) // n_critic + 1}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                ))

                if i % 100 == 0:
                    generator.eval()

                    sample = img[:sample_size]
                    label_sample = label[:sample_size]
                    sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                    sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                    a = (sample1.data.cpu()).numpy()

                    with torch.no_grad():
                        out, _ = generator(sample)

                    utils.save_image(
                        torch.cat([sample0, sample1, label_sample, out], 0),
                        f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                        nrow=sample_size,
                        normalize=True,
                        range=(-1, 1),
                    )

                    generator.train()
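# `requires_grad()` used above is not defined in this excerpt; a common
# implementation (an assumption here) toggles gradients on a whole sub-network
# so that only one of generator/discriminator is updated in each phase:
def requires_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag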
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    normMean = [0.5]
    normStd = [0.5]
    normTransform = transforms.Normalize(normMean, normStd)
    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.ToTensor(),
        normTransform,
    ])

    txt_path = './data/train.txt'
    images_path = './data'
    labels_path = './data'
    dataset = txtDataset(txt_path, images_path, labels_path, transform=transform)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(
        dataset,
        batch_size=batch_size // args.n_gpu,  # `batch_size` is assumed to be a module-level constant
        sampler=sampler,
        num_workers=16,
    )

    # Initialize generator and discriminator
    DpretrainedPath = './checkpoint/vqvae2GAN_040.pt'
    GpretrainedPath = './checkpoint/vqvae_040.pt'
    discriminator = Discriminator()
    generator = Generator()

    if os.path.exists(DpretrainedPath):
        print('Loading model weights...')
        discriminator.load_state_dict(torch.load(DpretrainedPath)['discriminator'])
        print('done')
    if os.path.exists(GpretrainedPath):
        print('Loading model weights...')
        generator.load_state_dict(torch.load(GpretrainedPath))
        print('done')

    discriminator = discriminator.to(device)
    generator = generator.to(device)

    if args.distributed:
        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )
        generator = nn.parallel.DistributedDataParallel(
            generator,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
        )

    optimizer_D = optim.Adam(discriminator.parameters(), lr=args.lr)
    optimizer_G = optim.Adam(generator.parameters(), lr=args.lr)
    scheduler_D = None
    scheduler_G = None
    if args.sched == "cycle":
        scheduler_D = CycleScheduler(
            optimizer_D,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )
        scheduler_G = CycleScheduler(
            optimizer_G,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    # resume from epoch 41 to match the *_040 checkpoints loaded above
    for i in range(41, args.epoch):
        train(i, loader, discriminator, generator, scheduler_D, scheduler_G,
              optimizer_D, optimizer_G, device)

        if dist.is_primary():
            torch.save(
                {
                    'generator': generator.state_dict(),
                    'discriminator': discriminator.state_dict(),
                    'g_optimizer': optimizer_G.state_dict(),
                    'd_optimizer': optimizer_D.state_dict(),
                },
                f'checkpoint/vqvae2GAN_{str(i + 1).zfill(3)}.pt',
            )
            if (i + 1) % n_critic == 0:
                torch.save(generator.state_dict(), f"checkpoint/vqvae_{str(i + 1).zfill(3)}.pt")
def main(args):
    device = "cuda"

    args.distributed = dist.get_world_size() > 1

    transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])

    dataset = OffsetDataset(args.path, transform=transform, offset=args.offset)
    sampler = dist.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset, batch_size=args.bsize // args.n_gpu, sampler=sampler, num_workers=2)

    # Load pre-trained VQVAE
    vqvae = VQVAE().to(device)
    try:
        vqvae.load_state_dict(torch.load(args.ckpt))
    except RuntimeError:
        print("Seems the checkpoint was trained with data parallel, try loading it that way")
        # strip the 'module.' prefix added by (Distributed)DataParallel
        weights = torch.load(args.ckpt)
        renamed_weights = {}
        for key, value in weights.items():
            renamed_weights[key.replace('module.', '')] = value
        weights = renamed_weights
        vqvae.load_state_dict(weights)

    # Init offset encoder
    model = OffsetNetwork(vqvae).to(device)

    if args.distributed:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[dist.get_local_rank()],
            output_device=dist.get_local_rank(),
            find_unused_parameters=True,
        )

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == "cycle":
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    for i in range(args.epoch):
        train(i, loader, model, optimizer, scheduler, device)

        if dist.is_primary():
            torch.save(model.state_dict(), f"checkpoint/offset_enc_{str(i + 1).zfill(3)}.pt")
def train(epoch, loader, model, optimizer, scheduler, device):
    if dist.is_primary():
        loader = tqdm(loader)

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, (frames, next_frames) in enumerate(loader):
        model.zero_grad()

        frames = frames.to(device)
        next_frames = next_frames.to(device)

        out, latent_loss = model(frames, next_frames)
        recon_loss = criterion(out, next_frames)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * frames.shape[0]
        part_mse_n = frames.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"
            ))

            if i % 100 == 0:
                model.eval()

                sample_frames = frames[:sample_size]
                sample_next = next_frames[:sample_size]

                with torch.no_grad():
                    out, _ = model(sample_frames, sample_next)

                utils.save_image(
                    torch.cat([sample_frames, out], 0),
                    f"offset_sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
def train(epoch, loader, model, optimizer, scheduler, device):
    # if dist.is_primary():
    #     loader = tqdm(loader)
    loader = tqdm(loader)

    # criterion = nn.MSELoss()
    criterion = nn.L1Loss()
    gdloss = GDLoss()

    recon_loss_weight = 0.4
    latent_loss_weight = 0.2
    gradient_loss_weight = 0.4
    sample_size = batch_size  # `batch_size` is assumed to be a module-level constant

    mse_sum = 0
    mse_n = 0
    g_sum = 0
    g_n = 0

    for i, (img, label, label_path, class_name) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)
        label = label.to(device)
        gdloss.conv_x = gdloss.conv_x.to(device)
        gdloss.conv_y = gdloss.conv_y.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, label)
        gradient_loss = gdloss(out, label)
        gradient_loss = gradient_loss.mean()
        latent_loss = latent_loss.mean()
        loss = (recon_loss_weight * recon_loss
                + latent_loss_weight * latent_loss
                + gradient_loss_weight * gradient_loss)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        part_g_sum = gradient_loss.item() * img.shape[0]
        part_g_n = img.shape[0]
        g_comm = {"g_sum": part_g_sum, "g_n": part_g_n}
        g_comm = dist.all_gather(g_comm)

        for part in g_comm:
            g_sum += part["g_sum"]
            g_n += part["g_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            loader.set_description((
                f"epoch: {epoch + 1}; mse: {recon_loss.item():.5f}; "
                f"latent: {latent_loss.item():.3f}; gradient: {g_sum / g_n:.5f}; avg mse: {mse_sum / mse_n:.5f}; "
                f"lr: {lr:.5f}"
            ))

            if i % 100 == 0:
                model.eval()

                sample = img[:sample_size]
                label_sample = label[:sample_size]
                sample0 = sample[:, 0, :, :].unsqueeze(dim=1)
                sample1 = sample[:, 1, :, :].unsqueeze(dim=1)
                a = (sample1.data.cpu()).numpy()

                with torch.no_grad():
                    out, _ = model(sample)

                utils.save_image(
                    torch.cat([sample0, sample1, label_sample, out], 0),
                    f"sample/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                    nrow=sample_size,
                    normalize=True,
                    range=(-1, 1),
                )

                model.train()
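# `GDLoss` is referenced above but not defined in this excerpt. Below is a
# minimal sketch of a gradient-difference loss that matches the usage
# (conv_x/conv_y tensors moved to the device by the caller, a per-element
# result reduced with .mean()); the Sobel kernels and the single-channel
# assumption are hypothetical, not the author's actual implementation.
import torch
from torch import nn
import torch.nn.functional as F

class GDLossSketch(nn.Module):
    def __init__(self):
        super().__init__()
        # Fixed 3x3 Sobel-style filters for single-channel images, kept as
        # plain tensors so `loss.conv_x = loss.conv_x.to(device)` works.
        self.conv_x = torch.tensor([[[[-1.0, 0.0, 1.0],
                                      [-2.0, 0.0, 2.0],
                                      [-1.0, 0.0, 1.0]]]])
        self.conv_y = self.conv_x.transpose(2, 3)

    def forward(self, pred, target):
        gx_pred = F.conv2d(pred, self.conv_x, padding=1)
        gy_pred = F.conv2d(pred, self.conv_y, padding=1)
        gx_target = F.conv2d(target, self.conv_x, padding=1)
        gy_target = F.conv2d(target, self.conv_y, padding=1)
        # Per-pixel difference of image gradients; callers apply .mean().
        return (gx_pred - gx_target).abs() + (gy_pred - gy_target).abs()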
def train_epoch(self, epoch):
    if dist.is_primary():
        loader = tqdm(self.dataloader)
    else:
        loader = self.dataloader

    criterion = nn.MSELoss()

    latent_loss_weight = 0.25
    sample_size = 25

    mse_sum = 0
    mse_n = 0

    for i, img in enumerate(loader):
        self.model.zero_grad()

        img = img.to(self.device)

        outputs = self.model(img)
        out, latent_loss = outputs[:2]
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss

        if self.args.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()

        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        self.global_step += 1

        if dist.is_primary() and self.global_step % self.args.logging_steps == 0:
            print("global_step", self.global_step,
                  "mse", "{:.4g}".format(recon_loss.item()),
                  "latent", "{:.4g}".format(latent_loss.item()),
                  "avg_mse", "{:.4g}".format(mse_sum / mse_n),
                  "lr", "{:.4g}".format(self.optimizer.param_groups[0]["lr"]),
                  file=sys.stderr, flush=True)

        if dist.is_primary() and self.global_step % self.args.save_steps == 0:
            self.save_checkpoint()

        if dist.is_primary() and self.global_step % self.args.eval_steps == 0:
            self.model.eval()

            sample = img[:sample_size]

            with torch.no_grad():
                out = self.model(sample)[0]

            utils.save_image(
                torch.cat([sample, out], 0),
                f"{self.args.eval_path}/{str(epoch + 1).zfill(5)}_{str(i).zfill(5)}.png",
                nrow=sample_size,
                normalize=True,
                range=(-1, 1),
            )

            self.model.train()
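# The fp16 branch above relies on NVIDIA apex. Below is a minimal sketch of
# the one-time setup the trainer would need before calling train_epoch; the
# helper name and the opt_level choice are assumptions.
from apex import amp

def setup_fp16(trainer):
    # Wrap model and optimizer once; amp.scale_loss() in the loop above then
    # applies dynamic loss scaling during backward.
    trainer.model, trainer.optimizer = amp.initialize(
        trainer.model, trainer.optimizer, opt_level="O1"
    )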
def run(self, args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = [transforms.ToTensor()]
    if args.normalize:
        transform.append(transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]))
    transform = transforms.Compose(transform)

    dataset = datasets.ImageFolder(args.path, transform=transform)
    sampler = dist_fn.data_sampler(dataset, shuffle=True, distributed=args.distributed)
    loader = DataLoader(dataset,
                        batch_size=args.batch_size // args.n_gpu,
                        sampler=sampler,
                        num_workers=args.num_workers)

    self = self.to(device)
    if args.distributed:
        self = nn.parallel.DistributedDataParallel(
            self,
            device_ids=[dist_fn.get_local_rank()],
            output_device=dist_fn.get_local_rank())

    optimizer = args.optimizer(self.parameters(), lr=args.lr)
    scheduler = None
    if args.sched == 'cycle':
        scheduler = CycleScheduler(
            optimizer,
            args.lr,
            n_iter=len(loader) * args.epoch,
            momentum=None,
            warmup_proportion=0.05,
        )

    # one run directory per launch, named by start time
    start = str(time())
    run_path = os.path.join('runs', start)
    sample_path = os.path.join(run_path, 'sample')
    checkpoint_path = os.path.join(run_path, 'checkpoint')
    os.mkdir(run_path)
    os.mkdir(sample_path)
    os.mkdir(checkpoint_path)

    with Progress() as progress:
        train = progress.add_task(f'epoch 1/{args.epoch}', total=args.epoch, columns='epochs')
        steps = progress.add_task('', total=len(dataset) // args.batch_size)

        for epoch in range(args.epoch):
            progress.update(steps, completed=0, refresh=True)

            for recon_loss, latent_loss, avg_mse, lr in self.train_epoch(
                    epoch, loader, optimizer, scheduler, device, sample_path):
                progress.update(
                    steps,
                    description=(f'mse: {recon_loss:.5f}; latent: {latent_loss:.5f}; '
                                 f'avg mse: {avg_mse:.5f}; lr: {lr:.5f}'))
                progress.advance(steps)

            if dist_fn.is_primary():
                torch.save(
                    self.state_dict(),
                    os.path.join(checkpoint_path, f'vqvae_{str(epoch + 1).zfill(3)}.pt'))

            progress.update(train, description=f'epoch {epoch + 1}/{args.epoch}')
            progress.advance(train)