def train_bottleneck_model(nb_epochs, batch_size):
    """Train bottleneck model"""
    # Load the training and validation bottleneck features
    train_data = load_np_array(cfg.bf_train_path)
    val_data = load_np_array(cfg.bf_val_path)

    # Get training and validation labels for bottleneck features
    # (we know the images are in sorted order)
    train_labels = []
    val_labels = []
    k = 0
    for class_name in cfg.classes:
        train_labels += [k] * len(
            os.listdir(osp.join(cfg.train_data_dir, class_name)))
        val_labels += [k] * len(
            os.listdir(osp.join(cfg.val_data_dir, class_name)))
        k += 1

    # Create custom model
    model = BottleneckModel.build(input_shape=train_data.shape[1:],
                                  nb_classes=cfg.nb_classes)

    # If multiclass, encode the labels to 1-to-K binary format
    if cfg.nb_classes != 2:
        train_labels = np_utils.to_categorical(train_labels, cfg.nb_classes)
        val_labels = np_utils.to_categorical(val_labels, cfg.nb_classes)

    # Compile model
    model.compile(loss='{}_crossentropy'.format(cfg.classmode),
                  optimizer=Adam(lr=5e-5),
                  metrics=['accuracy'])

    # Print model summary
    model.summary()

    # Save weights with best val loss
    model_checkpoint = ModelCheckpoint(cfg.model_weights_path,
                                       save_best_only=True,
                                       save_weights_only=True,
                                       monitor='val_loss')

    # Decay learning rate by half every 20 epochs
    decay = decay_lr(20, 0.5)

    # Start training
    history = model.fit(train_data, train_labels,
                        nb_epoch=nb_epochs,
                        batch_size=batch_size,
                        validation_data=(val_data, val_labels),
                        callbacks=[model_checkpoint, decay])

    # Load best weights to get val data predictions
    model.load_weights(cfg.model_weights_path)

    # Get val data predictions
    val_pred_proba = model.predict(val_data)

    return model, history, val_pred_proba
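# `decay_lr(20, 0.5)` above refers to a project helper that is not shown in
# this file. A minimal sketch of what it might look like, assuming it wraps a
# Keras LearningRateScheduler implementing step decay (scale the rate by
# `decay_rate` every `n_epochs` epochs). The name and positional signature
# come from the call site; the body and the `base_lr` default are assumptions,
# not the project's actual implementation.
from keras.callbacks import LearningRateScheduler


def decay_lr(n_epochs, decay_rate, base_lr=5e-5):
    """Return a callback that multiplies `base_lr` by `decay_rate`
    every `n_epochs` epochs (step decay)."""

    def schedule(epoch):
        return float(base_lr * (decay_rate ** (epoch // n_epochs)))

    return LearningRateScheduler(schedule)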
def run(net, loader, optimizer, scheduler, tracker, train=False,
        has_answers=True, prefix='', epoch=0):
    """ Run an epoch over the given loader """
    assert not (train and not has_answers)
    if train:
        net.train()
        tracker_class, tracker_params = tracker.MovingMeanMonitor, {
            'momentum': 0.99
        }
    else:
        net.eval()
        tracker_class, tracker_params = tracker.MeanMonitor, {}
        answ = []
        idxs = []
        accs = []

    # set learning rate decay policy
    if (epoch < len(config.gradual_warmup_steps)
            and config.schedule_method == 'warm_up'):
        utils.set_lr(optimizer, config.gradual_warmup_steps[epoch])
        utils.print_lr(optimizer, prefix, epoch)
    elif (epoch in config.lr_decay_epochs and train
            and config.schedule_method == 'warm_up'):
        utils.decay_lr(optimizer, config.lr_decay_rate)
        utils.print_lr(optimizer, prefix, epoch)
    else:
        utils.print_lr(optimizer, prefix, epoch)

    loader = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
    loss_tracker = tracker.track('{}_loss'.format(prefix),
                                 tracker_class(**tracker_params))
    acc_tracker = tracker.track('{}_acc'.format(prefix),
                                tracker_class(**tracker_params))

    for v, q, a, b, idx, v_mask, q_mask, q_len in loader:
        var_params = {
            'requires_grad': False,
        }
        v = Variable(v.cuda(), **var_params)
        q = Variable(q.cuda(), **var_params)
        a = Variable(a.cuda(), **var_params)
        b = Variable(b.cuda(), **var_params)
        q_len = Variable(q_len.cuda(), **var_params)
        v_mask = Variable(v_mask.cuda(), **var_params)
        q_mask = Variable(q_mask.cuda(), **var_params)

        out = net(v, b, q, v_mask, q_mask, q_len)
        if has_answers:
            answer = utils.process_answer(a)
            loss = utils.calculate_loss(answer, out, method=config.loss_method)
            acc = utils.batch_accuracy(out, answer).data.cpu()

        if train:
            optimizer.zero_grad()
            loss.backward()
            # print gradient
            if config.print_gradient:
                utils.print_grad([(n, p) for n, p in net.named_parameters()
                                  if p.grad is not None])
            # clip gradient
            clip_grad_norm_(net.parameters(), config.clip_value)
            optimizer.step()
            if config.schedule_method == 'batch_decay':
                scheduler.step()
        else:
            # store information about evaluation of this minibatch
            _, answer = out.data.cpu().max(dim=1)
            answ.append(answer.view(-1))
            if has_answers:
                accs.append(acc.view(-1))
            idxs.append(idx.view(-1).clone())

        if has_answers:
            loss_tracker.append(loss.item())
            acc_tracker.append(acc.mean())
            fmt = '{:.4f}'.format
            loader.set_postfix(loss=fmt(loss_tracker.mean.value),
                               acc=fmt(acc_tracker.mean.value))

    if not train:
        answ = list(torch.cat(answ, dim=0))
        if has_answers:
            accs = list(torch.cat(accs, dim=0))
        else:
            accs = []
        idxs = list(torch.cat(idxs, dim=0))
        #print('{} E{:03d}:'.format(prefix, epoch), ' Total num: ', len(accs))
        #print('{} E{:03d}:'.format(prefix, epoch), ' Average Score: ',
        #      float(sum(accs) / len(accs)))
        return answ, accs, idxs
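# The `utils.set_lr`, `utils.decay_lr`, and `utils.print_lr` helpers used by
# the run() epochs in this file are not defined here. A minimal sketch under
# the assumption that they operate directly on the optimizer's parameter
# groups (the bodies below are illustrative, not the project's actual code):
def set_lr(optimizer, value):
    """Set every parameter group's learning rate to `value`."""
    for group in optimizer.param_groups:
        group['lr'] = value


def decay_lr(optimizer, decay_rate):
    """Multiply every parameter group's learning rate by `decay_rate`."""
    for group in optimizer.param_groups:
        group['lr'] *= decay_rate


def print_lr(optimizer, prefix, epoch):
    """Log the current learning rate(s) for the epoch being run."""
    lrs = [group['lr'] for group in optimizer.param_groups]
    print('{} E{:03d}: lr = {}'.format(prefix, epoch, lrs))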
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("-conf", type=str)
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    debug = args.debug

    config = configparser.ConfigParser()
    config.read(args.conf)

    log_path = config["log"]["log_path"]
    log_step = int(config["log"]["log_step"])
    log_dir = os.path.dirname(log_path)
    os.makedirs(log_dir, exist_ok=True)

    save_prefix = config["save"]["save_prefix"]
    save_format = save_prefix + ".network.epoch{}"
    optimizer_save_format = save_prefix + ".optimizer.epoch{}"
    save_step = int(config["save"]["save_step"])
    save_dir = os.path.dirname(save_prefix)
    os.makedirs(save_dir, exist_ok=True)

    num_epochs = int(config["train"]["num_epochs"])
    batch_size = int(config["train"]["batch_size"])
    decay_start_epoch = int(config["train"]["decay_start_epoch"])
    decay_rate = float(config["train"]["decay_rate"])
    vocab_size = int(config["vocab"]["vocab_size"])
    ls_prob = float(config["train"]["ls_prob"])
    distill_weight = float(config["distill"]["distill_weight"])

    if debug:
        logging.basicConfig(format="%(asctime)s %(message)s",
                            level=logging.INFO)  # to stdout
    else:
        logging.basicConfig(filename=log_path,
                            format="%(asctime)s %(message)s",
                            level=logging.DEBUG)

    model = AttnModel(args.conf)
    model.apply(init_weight)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), weight_decay=1e-5)

    dataset = SpeechDataset(args.conf)
    dataloader = DataLoader(dataset=dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            collate_fn=collate_fn,
                            num_workers=2,
                            pin_memory=True)
    num_steps = len(dataloader)

    for epoch in range(num_epochs):
        loss_sum = 0
        for step, data in enumerate(dataloader):
            loss_step = train_step(model, optimizer, data, vocab_size,
                                   ls_prob, distill_weight)
            loss_sum += loss_step
            if (step + 1) % log_step == 0:
                logging.info(
                    "epoch = {:>2} step = {:>6} / {:>6} loss = {:.3f}".format(
                        epoch + 1, step + 1, num_steps, loss_sum / log_step))
                loss_sum = 0

        if epoch == 0 or (epoch + 1) % save_step == 0:
            save_path = save_format.format(epoch + 1)
            torch.save(model.state_dict(), save_path)
            optimizer_save_path = optimizer_save_format.format(epoch + 1)
            torch.save(optimizer.state_dict(), optimizer_save_path)
            logging.info("model saved to: {}".format(save_path))
            logging.info("optimizer saved to: {}".format(optimizer_save_path))

        update_epoch(model, epoch + 1)
        decay_lr(optimizer, epoch + 1, decay_start_epoch, decay_rate)
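# `update_epoch` and the `decay_lr` used above are project helpers that are
# not shown in this file. A plausible sketch of a decay_lr with this
# signature, assuming it multiplies the learning rate by `decay_rate` once per
# epoch after `decay_start_epoch` (an assumption based on the config keys,
# not the actual implementation):
def decay_lr(optimizer, epoch, decay_start_epoch, decay_rate):
    """Exponentially decay the LR once `epoch` reaches `decay_start_epoch`."""
    if epoch >= decay_start_epoch:
        for group in optimizer.param_groups:
            group['lr'] *= decay_rate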
def run(net, loader, optimizer, scheduler, tracker, train=False, prefix='',
        epoch=0):
    """ Run an epoch over the given loader """
    if train:
        net.train()
        # tracker_class, tracker_params = tracker.MovingMeanMonitor, {'momentum': 0.99}
    else:
        net.eval()
    tracker_class, tracker_params = tracker.MeanMonitor, {}

    # set learning rate decay policy
    if (epoch < len(config.gradual_warmup_steps)
            and config.schedule_method == 'warm_up'):
        utils.set_lr(optimizer, config.gradual_warmup_steps[epoch])
    elif (epoch in config.lr_decay_epochs and train
            and config.schedule_method == 'warm_up'):
        utils.decay_lr(optimizer, config.lr_decay_rate)
    utils.print_lr(optimizer, prefix, epoch)

    loader = tqdm(loader, desc='{} E{:03d}'.format(prefix, epoch), ncols=0)
    loss_tracker = tracker.track('{}_loss'.format(prefix),
                                 tracker_class(**tracker_params))
    acc_tracker = tracker.track('{}_acc'.format(prefix),
                                tracker_class(**tracker_params))

    for v, q, a, b, idx, v_mask, q_mask, q_len in loader:
        var_params = {
            'requires_grad': False,
        }
        v = Variable(v.cuda(), **var_params)
        q = Variable(q.cuda(), **var_params)
        a = Variable(a.cuda(), **var_params)
        b = Variable(b.cuda(), **var_params)
        q_len = Variable(q_len.cuda(), **var_params)
        v_mask = Variable(v_mask.cuda(), **var_params)
        q_mask = Variable(q_mask.cuda(), **var_params)

        out = net(v, b, q, v_mask, q_mask, q_len)
        answer = utils.process_answer(a)
        loss = utils.calculate_loss(answer, out, method=config.loss_method)
        acc = utils.batch_accuracy(out, answer).data.cpu()

        if train:
            optimizer.zero_grad()
            loss.backward()
            # clip gradient
            clip_grad_norm_(net.parameters(), config.clip_value)
            optimizer.step()
            if config.schedule_method == 'batch_decay':
                scheduler.step()

        loss_tracker.append(loss.item())
        acc_tracker.append(acc.mean())
        fmt = '{:.4f}'.format
        loader.set_postfix(loss=fmt(loss_tracker.mean.value),
                           acc=fmt(acc_tracker.mean.value))

    return acc_tracker.mean.value, loss_tracker.mean.value
def train(self):
    loss = {}
    nrow = min(int(np.sqrt(self.batch_size)), 8)
    n_samples = nrow * nrow
    iter_per_epoch = len(self.train_loader.dataset) // self.batch_size
    max_iteration = self.num_epoch * iter_per_epoch
    lambda_l1 = 0.2

    print('Start training...')
    for epoch in tqdm(range(self.resume_epoch, self.num_epoch)):
        for i, (x_real, noise, label) in enumerate(tqdm(self.train_loader)):
            # lr decay
            if epoch * iter_per_epoch + i >= self.lr_decay_start:
                utils.decay_lr(self.g_optimizer, max_iteration,
                               self.lr_decay_start, self.g_lr)
                utils.decay_lr(self.d_optimizer, max_iteration,
                               self.lr_decay_start, self.d_lr)
                if i % 1000 == 0:
                    print('d_lr / g_lr is updated to {:.8f} / {:.8f} !'.format(
                        self.d_optimizer.param_groups[0]['lr'],
                        self.g_optimizer.param_groups[0]['lr']))

            x_real = x_real.to(self.device)
            noise = noise.to(self.device)
            label = label.to(self.device)

            #'''
            # =============================================================== #
            #                   1. Train the discriminator                    #
            # =============================================================== #
            for param in self.D.parameters():
                param.requires_grad = True

            dis_real, real_list = self.D(x_real, label)
            real_list = [h.detach() for h in real_list]
            x_fake = self.G(noise, label).detach()
            dis_fake, _ = self.D(x_fake, label)
            d_loss_real, d_loss_fake = self.dis_hinge(dis_real, dis_fake)

            # sample
            try:
                x_real2, label2 = next(real_iter)
            except:
                real_iter = iter(self.real_loader)
                x_real2, label2 = next(real_iter)
            x_real2 = x_real2.to(self.device)
            label2 = label2.to(self.device)
            noise2 = torch.FloatTensor(utils.truncated_normal(self.batch_size * self.z_dim)) \
                .view(self.batch_size, self.z_dim).to(self.device)
            # noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
            dis_real2, _ = self.D(x_real2, label2)
            x_fake2 = self.G(noise2, label2).detach()
            dis_fake2, _ = self.D(x_fake2, label2)
            d_loss_real2, d_loss_fake2 = self.dis_hinge(dis_real2, dis_fake2)

            # Backward and optimize.
            d_loss = d_loss_real + d_loss_fake + 0.2 * (d_loss_real2 +
                                                        d_loss_fake2)
            self.d_optimizer.zero_grad()
            d_loss.backward()
            self.d_optimizer.step()

            # Logging.
            loss['D/loss_real'] = d_loss_real.item()
            loss['D/loss_fake'] = d_loss_fake.item()
            loss['D/loss_real2'] = d_loss_real2.item()
            loss['D/loss_fake2'] = d_loss_fake2.item()

            # =============================================================== #
            #                     2. Train the generator                      #
            # =============================================================== #
            #'''
            x_fake = self.G(noise, label)
            for param in self.D.parameters():
                param.requires_grad = False
            dis_fake, fake_list = self.D(x_fake, label)
            g_loss_feat = self.KDLoss(real_list, fake_list)
            g_loss_pix = F.l1_loss(x_fake, x_real)
            g_loss = g_loss_feat + lambda_l1 * g_loss_pix
            loss['G/loss_ft'] = g_loss_feat.item()
            loss['G/loss_l1'] = g_loss_pix.item()

            if (i + 1) % self.n_critic == 0:
                dis_fake, _ = self.D(x_fake, label)
                g_loss_fake = self.gen_hinge(dis_fake)
                g_loss += self.lambda_gan * g_loss_fake

                # sample
                noise2 = torch.FloatTensor(utils.truncated_normal(self.batch_size * self.z_dim)) \
                    .view(self.batch_size, self.z_dim).to(self.device)
                # noise2 = torch.randn(self.batch_size, self.z_dim).to(self.device)
                x_fake2 = self.G(noise2, label2)
                dis_fake2, _ = self.D(x_fake2, label2)
                g_loss_fake2 = self.gen_hinge(dis_fake2)
                g_loss += 0.2 * self.lambda_gan * g_loss_fake2

                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_fake2'] = g_loss_fake2.item()

            self.g_optimizer.zero_grad()
            g_loss.backward()
            self.g_optimizer.step()

            # =============================================================== #
            #                       3. Miscellaneous                          #
            # =============================================================== #
            # Print out training information.
            if (i + 1) % self.log_step == 0:
                log = "[{}/{}]".format(epoch, i)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)

                if self.use_tensorboard:
                    for tag, value in loss.items():
                        self.logger.scalar_summary(tag, value, i + 1)

        if epoch == 0 or (epoch + 1) % self.sample_step == 0:
            with torch.no_grad():
                """
                # randomly sampled noise
                noise = torch.FloatTensor(utils.truncated_normal(n_samples*self.z_dim)) \
                    .view(n_samples, self.z_dim).to(self.device)
                label = label[:nrow].repeat(nrow)
                #label = np.random.choice(1000, nrow, replace=False)
                #label = torch.tensor(label).repeat(10).to(self.device)
                x_sample = self.G(noise, label)
                sample_path = os.path.join(self.sample_dir,
                                           '{}-sample.png'.format(epoch+1))
                save_image(utils.denorm(x_sample.cpu()), sample_path,
                           nrow=nrow, padding=0)
                """
                # recons
                n = min(x_real.size(0), 8)
                comparison = torch.cat([x_real[:n], x_fake[:n]])
                sample_path = os.path.join(self.sample_dir,
                                           '{}-train.png'.format(epoch + 1))
                save_image(utils.denorm(comparison.cpu()), sample_path)
                print('Save fake images into {}...'.format(sample_path))

                # noise2
                comparison = torch.cat([x_real2[:n], x_fake2[:n]])
                sample_path = os.path.join(self.sample_dir,
                                           '{}-random.png'.format(epoch + 1))
                save_image(utils.denorm(comparison.cpu()), sample_path)
                print('Save fake images into {}...'.format(sample_path))

                # noise sampled from BigGAN's test set
                try:
                    x_real, noise, label = next(test_iter)
                except:
                    test_iter = iter(self.test_loader)
                    x_real, noise, label = next(test_iter)
                noise = noise.to(self.device)
                label = label.to(self.device)
                x_fake = self.G(noise, label).detach().cpu()
                n = min(x_real.size(0), 8)
                comparison = torch.cat([x_real[:n], x_fake[:n]])
                sample_path = os.path.join(self.sample_dir,
                                           '{}-test.png'.format(epoch + 1))
                save_image(utils.denorm(comparison.cpu()), sample_path)
                print('Save fake images into {}...'.format(sample_path))

        lambda_l1 = max(0.00, lambda_l1 - 0.01)

        # Save model checkpoints.
        if (epoch + 1) % self.model_save_step == 0:
            utils.save_model(self.model_save_dir, epoch + 1, self.G, self.D,
                             self.g_optimizer, self.d_optimizer)
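# `utils.decay_lr(optimizer, max_iteration, lr_decay_start, base_lr)` is
# called once per iteration in the trainers above and below, but its body is
# not shown in this file. A sketch assuming the common SNGAN-style linear
# schedule, where each call subtracts a fixed amount so the rate reaches zero
# at `max_iter` (illustrative; the project's own helper may differ):
def decay_lr(optimizer, max_iter, decay_start_iter, base_lr):
    """Linearly anneal the LR to zero between `decay_start_iter` and `max_iter`."""
    step = base_lr / float(max_iter - decay_start_iter)
    for group in optimizer.param_groups:
        group['lr'] = max(group['lr'] - step, 0.0)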
def train():
    dir_name, save_path = getSavePath()

    if args.norm == 'sn':
        netG, netD = networks.getGD_SN(args.structure, args.dataset,
                                       args.Gnum_features, args.Dnum_features)
    elif args.norm == 'bn':
        netG, netD = networks.getGD_batchnorm(args.structure, args.dataset,
                                              args.Gnum_features,
                                              args.Dnum_features,
                                              dim_z=args.input_dim)
    if args.ema_trick:
        ema_netG_9999 = copy.deepcopy(netG)

    netG.cuda()
    netD.cuda()

    g_optimizer = torch.optim.Adam(netG.parameters(), lr=args.g_lr,
                                   betas=(args.beta1, args.beta2))
    d_optimizer = torch.optim.Adam(netD.parameters(), lr=args.d_lr,
                                   betas=(args.beta1, args.beta2))

    g_losses, d_losses = [], []
    grad_normD, grad_normG = [], []

    loader = datasets.getDataLoader(args.dataset, args.image_size,
                                    batch_size=args.batch_size,
                                    shuffle=not args.fixz)
    data_iter = iter(loader)
    data_num = len(loader.dataset)

    zs = TensorDataset(torch.randn(data_num, args.input_dim))
    zloader = DataLoader(zs, batch_size=args.batch_size,
                         shuffle=not args.fixz, num_workers=0, drop_last=True)
    z_iter = iter(zloader)

    for i in range(1, args.num_iters + 1):
        if i >= args.lr_decay_start:
            utils.decay_lr(g_optimizer, args.num_iters, args.lr_decay_start,
                           args.g_lr)
            utils.decay_lr(d_optimizer, args.num_iters, args.lr_decay_start,
                           args.d_lr)
        if i == 1:
            torch.save(netG.state_dict(), save_path + 'G_epoch0.pth')
            torch.save(netD.state_dict(), save_path + 'D_epoch0.pth')

        # G-step
        for _ in range(args.g_freq):
            if args.fixz:
                try:
                    z = next(z_iter)[0].cuda()
                except:
                    z_iter = iter(zloader)
                    z = next(z_iter)[0].cuda()
            else:
                z = torch.randn(args.batch_size, args.input_dim, device=device)

            g_optimizer.zero_grad()
            x_hat = netG(z)
            y_hat = netD(x_hat)
            g_loss = get_gloss(args.losstype, y_hat)
            g_losses.append(g_loss.item())
            g_loss.backward()
            g_optimizer.step()
            grad_normG.append(utils.getGradNorm(netG))
            if args.ema_trick:
                moving_average.soft_copy_param(ema_netG_9999, netG, 0.9999)

        for _ in range(args.d_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            if args.fixz:
                try:
                    z = next(z_iter)[0].cuda()
                except:
                    z_iter = iter(zloader)
                    z = next(z_iter)[0].cuda()
            else:
                z = torch.randn(args.batch_size, args.input_dim, device=device)

            d_optimizer.zero_grad()
            x_hat = netG(z).detach()
            y_hat = netD(x_hat)
            y = netD(x)
            d_loss = get_dloss(args.losstype, y_hat, y)
            d_losses.append(d_loss.item())
            d_loss.backward()
            d_optimizer.step()
            grad_normD.append(utils.getGradNorm(netD))

        if i % args.print_freq == 0:
            print('Iteration: {}; G-Loss: {}; D-Loss: {};'.format(
                i, g_loss, d_loss))
        if i == 1:
            save_image((x / 2. + 0.5)[:36], os.path.join(dir_name, 'real.png'))
        if i == 1 or i % args.plot_freq == 0:
            plot_x = netG(torch.randn(36, args.input_dim, device=device)).data
            plot_x = plot_x / 2. + 0.5
            save_image(
                plot_x,
                os.path.join(dir_name, 'fake_images-{}.png'.format(i + 1)))
            utils.plot_losses(g_losses, d_losses, grad_normG, grad_normD,
                              dir_name)
            utils.saveproj(y.cpu(), y_hat.cpu(), i, save_path)
        if i % args.save_freq == 0:
            torch.save(netG.state_dict(), save_path + 'G_epoch{}.pth'.format(i))
            torch.save(netD.state_dict(), save_path + 'D_epoch{}.pth'.format(i))
            if args.ema_trick:
                torch.save(ema_netG_9999.state_dict(),
                           save_path + 'emaG0.9999_epoch{}.pth'.format(i))
def SWAG(model, dataloader, optimizer, criterion, epochs=3000, print_freq=1000,
         swag_start=2000, M=1e5, lr_ratio=1, verbose=False):
    '''Implementation of SWAG (Stochastic Weight Averaging-Gaussian)'''
    model.train()  # prep model layers for training

    # initialize first moment as vectors w/ length = num model parameters
    num_params = sum(param.numel() for param in model.parameters())
    first_moment = torch.zeros(num_params)

    # initialize deviation matrix 'A'
    A = torch.empty(0, num_params, dtype=torch.float32)

    lr_init = optimizer.defaults['lr']
    n_iterates = 0

    for epoch in tqdm(range(epochs)):
        epoch_loss = 0

        # Implementation of learning rate decay from paper
        epoch_ratio = (epoch + 1) / swag_start
        lr = decay_lr(optimizer, epoch_ratio, lr_init=lr_init,
                      lr_ratio=lr_ratio)

        for inputs, labels in dataloader:
            optimizer.zero_grad()  # clear gradients
            preds = model(inputs)  # perform a forward pass
            loss = criterion(preds, labels)  # compute the loss
            loss.backward()  # backpropagate
            optimizer.step()  # update the weights
            epoch_loss += loss.data.item() * inputs.shape[0]

        # Print output
        if (epoch % print_freq == 0 or epoch == epochs - 1) and verbose:
            print('Epoch %d | LR: %g | Loss: %.4f' % (epoch, lr, epoch_loss))

        # Average the weight iterates (SWA collection phase)
        if epoch > swag_start:
            # obtain a flattened vector of weights
            weights_list = [param.detach() for param in model.parameters()]
            w = torch.cat([w.contiguous().view(-1, 1)
                           for w in weights_list]).view(-1)

            # update the first moment
            first_moment = (n_iterates * first_moment + w) / (n_iterates + 1)

            # update 'a' matrix (following their code implementation)
            a = w - first_moment
            A = torch.cat((A, a.view(1, -1)), dim=0)

            # only store the last 'M' deviation vectors if memory limited
            # (rows of A are the stored deviation vectors)
            if A.shape[0] > M:
                A = A[1:, :]

            n_iterates += 1

    return first_moment.double(), A.numpy()
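# The `decay_lr` call above (taking an epoch ratio) is not defined in this
# file. A sketch modeled on the schedule described in the SWA paper: hold the
# initial rate, decay linearly, then hold a constant low rate until SWAG
# collection starts. The 50% / 90% breakpoints are an assumption, not
# necessarily this project's exact values:
def decay_lr(optimizer, epoch_ratio, lr_init, lr_ratio):
    """Piecewise schedule: lr_init -> lr_init * lr_ratio as epoch_ratio goes 0 -> 1."""
    if epoch_ratio <= 0.5:
        factor = 1.0
    elif epoch_ratio <= 0.9:
        factor = 1.0 - (1.0 - lr_ratio) * (epoch_ratio - 0.5) / 0.4
    else:
        factor = lr_ratio
    lr = lr_init * factor
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr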
def train():
    dir_name, save_path = getSavePath()

    netG, netD = networks.getGD_SN(args.structure, args.dataset,
                                   args.image_size, args.num_features,
                                   dim_z=args.input_dim,
                                   bottleneck=args.bottleneck)
    if args.ema_trick:
        ema_netG_9999 = copy.deepcopy(netG)

    if args.reload > 0:
        netG.load_state_dict(
            torch.load(save_path + 'G_epoch{}.pth'.format(args.reload)))
        netD.load_state_dict(
            torch.load(save_path + 'D_epoch{}.pth'.format(args.reload)))
        if args.ema_trick:
            ema_netG_9999.load_state_dict(
                torch.load(save_path +
                           'emaG0.9999_epoch{}.pth'.format(args.reload),
                           map_location=torch.device('cpu')))

    netG.cuda()
    netD.cuda()

    g_optimizer = torch.optim.Adam(netG.parameters(), lr=args.g_lr,
                                   betas=(args.beta1, args.beta2))
    d_optimizer = torch.optim.Adam(netD.parameters(), lr=args.d_lr,
                                   betas=(args.beta1, args.beta2))

    g_losses, d_losses = [], []
    grad_normD, grad_normG = [], []

    loader = datasets.getDataLoader(args.dataset, args.image_size,
                                    batch_size=args.batch_size)
    data_iter = iter(loader)

    for i in range(1, args.num_iters + 1):
        if i >= args.lr_decay_start:
            utils.decay_lr(g_optimizer, args.num_iters, args.lr_decay_start,
                           args.g_lr)
            utils.decay_lr(d_optimizer, args.num_iters, args.lr_decay_start,
                           args.d_lr)
        if i <= args.reload:
            continue
        if i == 1:
            torch.save(netG.state_dict(), save_path + 'G_epoch0.pth')
            torch.save(netD.state_dict(), save_path + 'D_epoch0.pth')

        # G-step
        for _ in range(args.g_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            z = torch.randn(args.batch_size, args.input_dim, device=device)

            g_optimizer.zero_grad()
            x_hat = netG(z)
            y_hat = netD(x_hat)
            y = netD(x)
            g_loss = get_gloss(args.losstype, y_hat, y)
            g_losses.append(g_loss.item())
            g_loss.backward()
            g_optimizer.step()
            grad_normG.append(utils.getGradNorm(netG))
            if args.ema_trick:
                utils.soft_copy_param(ema_netG_9999, netG, 0.9999)

        for _ in range(args.d_freq):
            try:
                x = next(data_iter)[0].cuda().float()
            except StopIteration:
                data_iter = iter(loader)
                x = next(data_iter)[0].cuda().float()
            z = torch.randn(args.batch_size, args.input_dim, device=device)

            d_optimizer.zero_grad()
            x_hat = netG(z).detach()
            y_hat = netD(x_hat)
            y = netD(x)
            d_loss = get_dloss(args.losstype, y_hat, y)
            d_losses.append(d_loss.item())
            d_loss.backward()
            d_optimizer.step()
            grad_normD.append(utils.getGradNorm(netD))

        netD.proj.weight.data = F.normalize(netD.proj.weight.data, dim=1)

        if i % args.print_freq == 0:
            print('Iteration: {}; G-Loss: {}; D-Loss: {};'.format(
                i, g_loss, d_loss))
        if i == 1:
            save_image((x / 2. + 0.5), os.path.join(dir_name, 'real.pdf'))
        if i == 1 or i % args.plot_freq == 0:
            plot_x = netG(torch.randn(args.batch_size, args.input_dim,
                                      device=device)).data
            plot_x = plot_x / 2. + 0.5
            save_image(
                plot_x,
                os.path.join(dir_name, 'fake_images-{}.pdf'.format(i + 1)))
            utils.plot_losses(g_losses, d_losses, grad_normG, grad_normD,
                              dir_name)
        if i % args.save_freq == 0:
            torch.save(netG.state_dict(), save_path + 'G_epoch{}.pth'.format(i))
            torch.save(netD.state_dict(), save_path + 'D_epoch{}.pth'.format(i))
            if args.ema_trick:
                torch.save(ema_netG_9999.state_dict(),
                           save_path + 'emaG0.9999_epoch{}.pth'.format(i))
def train():
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
            'Batch size must be divisible by number of GPUs')

        bs_l = FLAGS.batch_size
        num_iter_per_epoch = int(FLAGS.num_train_l / bs_l)
        max_steps = int(FLAGS.num_epochs * num_iter_per_epoch)
        num_classes = FLAGS.num_classes

        global_step = slim.create_global_step()
        lr = tf.placeholder(tf.float32, shape=[], name="learning_rate")
        opt = tf.train.MomentumOptimizer(learning_rate=lr, momentum=0.9,
                                         use_nesterov=True)

        images, labels = utils.prepare_traindata(FLAGS.dataset_dir_l, int(bs_l))
        images_splits = tf.split(images, FLAGS.num_gpus, 0)
        labels_splits = tf.split(labels, FLAGS.num_gpus, 0)

        tower_grads = []
        top_1_op = []
        reuse_variables = None
        for i in range(FLAGS.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (network.TOWER_NAME, i)) as scope:
                    with slim.arg_scope(slim.get_model_variables(scope=scope),
                                        device='/cpu:0'):
                        loss, logits = \
                            _build_training_graph(images_splits[i],
                                                  labels_splits[i],
                                                  num_classes, reuse_variables)
                    top_1_op.append(tf.nn.in_top_k(logits, labels_splits[i], 1))
                    reuse_variables = True
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
                    batchnorm_updates = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    grads = opt.compute_gradients(loss)
                    tower_grads.append(grads)

        grads = network.average_gradients(tower_grads)
        gradient_op = opt.apply_gradients(grads, global_step=global_step)

        var_averages = tf.train.ExponentialMovingAverage(FLAGS.ema_decay,
                                                         global_step)
        var_op = var_averages.apply(tf.trainable_variables() +
                                    tf.moving_average_variables())
        batchnorm_op = tf.group(*batchnorm_updates)
        train_op = tf.group(gradient_op, var_op, batchnorm_op)

        saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)
        summary_op = tf.summary.merge(summaries)
        init_op = tf.global_variables_initializer()

        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False)
        if FLAGS.gpu_memory:
            config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory
        sess = tf.Session(config=config)

        boundaries, values = utils.config_lr(max_steps)
        sess.run([init_op], feed_dict={lr: values[0]})
        tf.train.start_queue_runners(sess=sess)
        summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                               graph=sess.graph)

        iter_count = epoch = sum_loss = sum_top_1 = 0
        start = time.time()
        for step in range(max_steps):
            decayed_lr = utils.decay_lr(step, boundaries, values, max_steps)
            _, loss_value, top_1_value = \
                sess.run([train_op, loss, top_1_op], feed_dict={lr: decayed_lr})
            sum_loss += loss_value
            top_1_value = np.sum(top_1_value) / bs_l
            sum_top_1 += top_1_value
            iter_count += 1
            assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

            if step % num_iter_per_epoch == 0:
                end = time.time()
                sum_loss = sum_loss / num_iter_per_epoch
                sum_top_1 = min(sum_top_1 / num_iter_per_epoch, 1.0)
                time_per_iter = float(end - start) / iter_count
                format_str = ('epoch %d, L = %.2f, top_1 = %.2f, lr = %.4f '
                              '(time_per_iter: %.4f s)')
                print(format_str % (epoch, sum_loss, sum_top_1 * 100,
                                    decayed_lr, time_per_iter))
                epoch += 1
                sum_loss = sum_top_1 = 0

            if step % 100 == 0:
                summary_str = sess.run(summary_op, feed_dict={lr: decayed_lr})
                summary_writer.add_summary(summary_str, step)

            if (step + 1) == max_steps:
                checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
                saver.save(sess, checkpoint_path, global_step=epoch)
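# `utils.config_lr` and `utils.decay_lr` are project helpers that are not
# defined in this file. A sketch assuming a piecewise-constant schedule, where
# `boundaries` holds global-step thresholds and `values` holds one more
# learning rate than there are boundaries. The breakpoints and rates below are
# placeholders, not the project's actual configuration:
def config_lr(max_steps):
    """Example: drop the learning rate 10x at 50% and 75% of training."""
    boundaries = [int(0.5 * max_steps), int(0.75 * max_steps)]
    values = [0.1, 0.01, 0.001]
    return boundaries, values


def decay_lr(step, boundaries, values, max_steps):
    """Return the learning rate to feed for `step` under the schedule.
    `max_steps` is unused here; it is kept only to match the call site."""
    for boundary, value in zip(boundaries, values):
        if step < boundary:
            return value
    return values[-1]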