# create checkpoint / log directories if needed
Path('/'.join(args.check_point.split('/')[:-1])).mkdir(parents=True, exist_ok=True)
Path(args.logs_root).mkdir(parents=True, exist_ok=True)

optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                            momentum=args.momentum, weight_decay=args.wd)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)

# if continuing training, load weights and resume from the saved epoch
start_epoch = 0
if args.continue_train:
    start_epoch = load_weights(state_dict_path=args.check_point,
                               models=model,
                               model_names='model',
                               optimizers=optimizer,
                               optimizer_names='optimizer',
                               return_val='start_epoch')

pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
    model.train()
    train_losses = []
    for x1, x2 in train_loader:
        x1, x2 = x1.to(device), x2.to(device)
        z1, z2, p1, p2 = model(x1, x2)
        # negative cosine similarity between predictions and projections,
        # symmetrized over the two augmented views if requested
        if args.symmetric:
            loss = (model.module.cosine_loss(p1, z2) + model.module.cosine_loss(p2, z1)) / 2
        else:
            loss = model.module.cosine_loss(p1, z2)
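# The cosine_loss used above is part of the repo's model and is not shown here. For
# reference, a minimal sketch of a SimSiam-style negative cosine similarity with a
# stop-gradient on the projection; the names and details are assumptions and the actual
# implementation in this codebase may differ:
import torch.nn.functional as F

def cosine_loss_sketch(p, z):
    """Negative cosine similarity; z is treated as a constant (stop-gradient)."""
    return -F.cosine_similarity(p, z.detach(), dim=-1).mean()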
# linear learning-rate decay for the discriminators, mirroring the generator schedule
lr_scheduler_D = torch.optim.lr_scheduler.LambdaLR(
    optimizer_D,
    lr_lambda=LambdaLR(args.n_epochs, args.starting_epoch, args.decay_epoch).step
)

# adversarial, cycle-consistency and identity objectives
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# pools of previously generated fakes used for the discriminator updates
pool_A = ReplayBuffer()
pool_B = ReplayBuffer()

# if continuing training, load weights and resume from the saved epoch
if args.continue_train:
    args.starting_epoch = load_weights(
        state_dict_path=f"{args.checkpoint_dir}/{args.data_root.split('/')[-1]}.pth",
        models=[D_A, D_B, G_AB, G_BA],
        model_names=['D_A', 'D_B', 'G_AB', 'G_BA'],
        optimizers=[optimizer_G, optimizer_D],
        optimizer_names=['optimizer_G', 'optimizer_D'],
        return_val='start_epoch')

pbar = tqdm(
    range(args.starting_epoch, args.n_epochs),
    total=(args.n_epochs - args.starting_epoch)
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
sampled_idx = get_random_ids(len(dataloader), args.sample_batches)

for epoch in pbar:
    G_AB.train()
    G_BA.train()
    D_A.train()
    D_B.train()
    disc_A_losses, gen_A_losses, disc_B_losses, gen_B_losses = [], [], [], []
    (gen_ad_A_losses,
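# The LambdaLR helper class and ReplayBuffer used above are defined elsewhere in this repo
# and are not shown here. For reference, a minimal sketch of the conventional CycleGAN
# linear-decay rule that such a helper typically implements; the class name and exact
# formula are assumptions:
class LinearDecaySketch:
    def __init__(self, n_epochs, starting_epoch, decay_epoch):
        self.n_epochs = n_epochs
        self.starting_epoch = starting_epoch
        self.decay_epoch = decay_epoch

    def step(self, epoch):
        # constant LR until decay_epoch, then linear decay to zero at n_epochs
        return 1.0 - max(0, epoch + self.starting_epoch - self.decay_epoch) / (self.n_epochs - self.decay_epoch)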
# start the virtual display (needed for headless rendering of pixel observations)
_ = _display.start()

# wrap the environment so observations are rendered frames instead of state vectors
if args.img_input:
    env.reset()
    env = PixelObservationWrapper(env)

agent = Agent(env, args.alpha, args.beta, args.hidden_dims, args.tau,
              args.batch_size, args.gamma, args.d, 0, args.max_size,
              args.c * max_action, args.sigma * max_action, args.one_device,
              args.log_dir, args.checkpoint_dir, args.img_input, args.in_channels,
              args.order, args.depth, args.multiplier, args.action_embed_dim,
              args.hidden_dim, args.crop_dim, args.img_feature_dim)

best_score = env.reward_range[0]
load_weights(args.checkpoint_dir, [agent.actor], ['actor'])

episodes = tqdm(range(args.n_episodes))
for e in episodes:
    # reset the environment
    state = env.reset()
    if args.img_input:
        # stack the last `order` preprocessed frames along the channel dimension
        # to form the initial state
        state_queue = deque([preprocess_img(state['pixels'], args.crop_dim)
                             for _ in range(args.order)], maxlen=args.order)
        state = torch.cat(list(state_queue), 1).cpu().numpy()
    done, score = False, 0
    while not done:
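# preprocess_img is defined elsewhere in the repo. A minimal sketch of the kind of
# preprocessing assumed here: center-crop the rendered frame, scale to [0, 1] and return a
# (1, C, H, W) tensor so consecutive frames can be concatenated along the channel
# dimension; the repo's actual function may resize or normalize differently:
import torch

def preprocess_img_sketch(frame, crop_dim):
    h, w, _ = frame.shape
    top, left = (h - crop_dim) // 2, (w - crop_dim) // 2
    cropped = frame[top:top + crop_dim, left:left + crop_dim]
    # HWC uint8 -> 1CHW float in [0, 1]
    return torch.from_numpy(cropped).permute(2, 0, 1).unsqueeze(0).float() / 255.0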
def train():
    # create output directories if needed
    Path(opt.checkpoint_dir).mkdir(parents=True, exist_ok=True)
    Path(opt.log_dir).mkdir(parents=True, exist_ok=True)
    if opt.save_local_samples:
        Path(opt.sample_dir).mkdir(parents=True, exist_ok=True)

    writer = SummaryWriter(opt.log_dir + f'/{int(datetime.now().timestamp() * 1e6)}')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # conditional training uses CIFAR (with class labels), unconditional uses CelebA
    if opt.conditional:
        loader, num_classes = get_cifar_loader(opt.batch_size, opt.crop_size, opt.img_size)
    else:
        num_classes = 0
        loader = get_celeba_loaders(opt.data_path, opt.img_ext, opt.crop_size,
                                    opt.img_size, opt.batch_size, opt.download)

    G = torch.nn.DataParallel(
        Generator(opt.h_dim, opt.z_dim, opt.img_channels, opt.img_size, num_classes=num_classes),
        device_ids=opt.devices).to(device)
    D = torch.nn.DataParallel(
        Discriminator(opt.img_channels, opt.h_dim, opt.img_size, num_classes=num_classes),
        device_ids=opt.devices).to(device)

    if opt.criterion == 'wasserstein-gp':
        criterion = Wasserstein_GP_Loss(opt.lambda_gp)
    elif opt.criterion == 'hinge':
        criterion = Hinge_loss()
    else:
        raise NotImplementedError('Please choose criterion [hinge, wasserstein-gp]')

    optimizer_G = torch.optim.Adam(G.parameters(), lr=opt.lr_G, betas=opt.betas)
    optimizer_D = torch.optim.Adam(D.parameters(), lr=opt.lr_D, betas=opt.betas)

    # sample a fixed z to visualize progress throughout training
    fixed_z = torch.randn(opt.sample_size, opt.z_dim).to(device)
    if opt.conditional:
        fixed_fake_labels = get_random_labels(num_classes, opt.sample_size, device)

    # if continuing training, load weights; otherwise start at epoch 0
    if opt.continue_train:
        start_epoch = load_weights(state_dict_path=opt.checkpoint_dir,
                                   models=[D, G],
                                   model_names=['D', 'G'],
                                   optimizers=[optimizer_D, optimizer_G],
                                   optimizer_names=['optimizer_D', 'optimizer_G'],
                                   return_val='start_epoch')
    else:
        start_epoch = 0

    pbar = tqdm(range(start_epoch, opt.n_epochs))
    ckpt_iter = 0
    for epoch in pbar:
        d_losses, g_losses = [], []
        D.train()
        G.train()
        for batch_idx, data in enumerate(loader):
            # data prep
            if opt.conditional:
                reals, labels = [d.to(device) for d in data]
                fake_labels = get_random_labels(num_classes, reals.size(0), device)
            else:
                reals = data.to(device)
                fake_labels, labels = None, None
            z = torch.randn(reals.size(0), opt.z_dim).to(device)

            # forward pass through the generator
            optimizer_G.zero_grad()
            fakes = G(z, fake_labels)
            g_loss = criterion(fake_logits=D(fakes, fake_labels), mode='generator')

            # update the generator
            g_loss.backward()
            optimizer_G.step()

            # forward pass through the discriminator
            optimizer_D.zero_grad()
            logits_fake = D(fakes.detach(), fake_labels)
            logits_real = D(reals, labels)

            # compute loss & update the discriminator
            d_loss = criterion(fake_logits=logits_fake, real_logits=logits_real,
                               mode='discriminator')

            # for WGAN-GP, compute the gradient penalty and add it to the current d_loss
            if opt.criterion == 'wasserstein-gp':
                interpolates = criterion.get_interpolates(reals, fakes)
                interpolated_logits = D(interpolates, labels)
                grad_penalty = criterion.grad_penalty_loss(interpolates, interpolated_logits)
                d_loss = d_loss + grad_penalty

            d_loss.backward()
            optimizer_D.step()

            # logging
            d_losses.append(d_loss.item())
            g_losses.append(g_loss.item())
            pbar.set_postfix({'G Loss': g_loss.item(),
                              'D Loss': d_loss.item(),
                              'Batch ID': batch_idx})

            # log samples to tensorboard every cpt_interval batches
            if batch_idx % opt.cpt_interval == 0:
                ckpt_iter += 1
                G.eval()
                # generate images from the fixed noise vector
                with torch.no_grad():
                    if opt.conditional:
                        samples = G(fixed_z, fixed_fake_labels)
                    else:
                        samples = G(fixed_z)
                samples = (samples + 1) / 2

                # save samples locally
                if opt.save_local_samples:
                    torchvision.utils.save_image(
                        samples, f'{opt.sample_dir}/Interval_{ckpt_iter}.{opt.img_ext}')

                # log samples and running losses to tensorboard
                writer.add_image('Generated Images',
                                 torchvision.utils.make_grid(samples),
                                 global_step=ckpt_iter)
                writer.add_scalars('Train Losses', {
                    'Discriminator Loss': sum(d_losses) / len(d_losses),
                    'Generator Loss': sum(g_losses) / len(g_losses)
                }, global_step=ckpt_iter)

                # back to training mode
                G.train()

        # print epoch losses and save weights
        tqdm.write(
            f'Epoch {epoch + 1}/{opt.n_epochs}, '
            f'Discriminator loss: {sum(d_losses) / len(d_losses):.3f}, '
            f'Generator Loss: {sum(g_losses) / len(g_losses):.3f}'
        )
        torch.save({
            'D': D.state_dict(),
            'G': G.state_dict(),
            'optimizer_D': optimizer_D.state_dict(),
            'optimizer_G': optimizer_G.state_dict(),
            'start_epoch': epoch + 1
        }, f"{opt.checkpoint_dir}/SA_GAN.pth")
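# Wasserstein_GP_Loss and Hinge_loss are defined elsewhere in this repo. For reference, a
# minimal sketch of the standard WGAN-GP gradient penalty that grad_penalty_loss presumably
# computes; the function name and lambda_gp default are assumptions and the repo's actual
# class may differ:
import torch

def grad_penalty_sketch(interpolates, interpolated_logits, lambda_gp=10.0):
    # assumes interpolates was created with requires_grad=True
    grads = torch.autograd.grad(outputs=interpolated_logits.sum(),
                                inputs=interpolates,
                                create_graph=True)[0]
    grad_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()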
criterion = MoCoLoss(args.temperature)
optimizer = torch.optim.SGD(f_q.parameters(), lr=args.lr,
                            momentum=args.momentum, weight_decay=args.wd)
# drop the learning rate by 10x at epochs 120 and 160
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
    optimizer, lambda epoch: 0.1 if epoch in (120, 160) else 1)
# queue of encoded keys maintained with the momentum (key) encoder
memo_bank = MemoryBank(f_k, device, momentum_loader, args.K)
writer = SummaryWriter(args.logs_root + f'/{int(datetime.now().timestamp() * 1e6)}')

# if continuing training, load weights and resume from the saved epoch
start_epoch = 0
if args.continue_train:
    start_epoch = load_weights(state_dict_path=args.check_point,
                               models=[f_q, f_k],
                               model_names=['f_q', 'f_k'],
                               optimizers=[optimizer],
                               optimizer_names=['optimizer'],
                               return_val='start_epoch')

pbar = tqdm(range(start_epoch, args.epochs))
for epoch in pbar:
    train_losses = []
    f_q.train()
    f_k.train()
    for x1, x2 in train_loader:
        # query encodings from the online encoder
        q1, q2 = f_q(x1), f_q(x2)
        # key encodings from the momentum encoder, no gradients needed
        with torch.no_grad():
            momentum_update(f_k, f_q, args.m)
            k1, k2 = f_k(x1), f_k(x2)
        # symmetrized contrastive loss against the memory bank
        loss = criterion(q1, k2, memo_bank) + criterion(q2, k1, memo_bank)
        optimizer.zero_grad()
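# momentum_update is defined elsewhere in the repo. A minimal sketch of the MoCo-style
# exponential moving average it presumably performs, moving the key encoder f_k towards
# the query encoder f_q with momentum m; the name and exact update are assumptions:
import torch

@torch.no_grad()
def momentum_update_sketch(f_k, f_q, m):
    for p_k, p_q in zip(f_k.parameters(), f_q.parameters()):
        # p_k <- m * p_k + (1 - m) * p_q
        p_k.data.mul_(m).add_(p_q.data, alpha=1 - m)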