def infer(valid_queue, model, epoch, Latency, criterion, writer):
    """Evaluate the currently-chosen sub-network on the validation queue.

    Activates the chosen candidate op in every mixed layer, switches the
    unused candidate modules off, then runs a no-grad evaluation loop.
    Loss and accuracies are all-reduced across distributed workers.

    Args:
        valid_queue: validation DataLoader yielding (input, target) batches.
        model: DataParallel/DDP-wrapped supernet (accessed via ``model.module``).
        epoch: current epoch index, used only to derive the writer step.
        Latency: latency predictor with a ``predict_latency(model)`` method.
        criterion: loss function applied to (logits, target).
        writer: TensorBoard-style summary writer.

    Returns:
        Tuple ``(top1.avg, losses.avg)`` over the whole validation pass.
    """
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    # set chosen op active
    model.module.set_chosen_op_active()
    model.module.unused_modules_off()
    model.eval()
    progress = utils.ProgressMeter(len(valid_queue), batch_time, losses,
                                   top1, top5, prefix='Test: ')
    cur_step = epoch * len(valid_queue)
    end = time.time()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            # input = input.cuda()
            # target = target.cuda(non_blocking=True)
            # NOTE(review): Variable/volatile is a deprecated no-op on modern
            # PyTorch; torch.no_grad() above already disables autograd.
            input = Variable(input, volatile=True).cuda()
            # target = Variable(target, volatile=True).cuda(async=True)
            target = Variable(target, volatile=True).cuda()
            logits = model(input)
            loss = criterion(logits, target)
            acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            # average metrics over all distributed workers
            reduced_loss = reduce_tensor(loss.data, world_size=config.world_size)
            acc1 = reduce_tensor(acc1, world_size=config.world_size)
            acc5 = reduce_tensor(acc5, world_size=config.world_size)
            losses.update(to_python_float(reduced_loss), n)
            top1.update(to_python_float(acc1), n)
            top5.update(to_python_float(acc5), n)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            # FLOPs of the derived net on a fixed 1x3x224x224 dummy input.
            # NOTE(review): recomputed every iteration though it cannot change
            # inside the loop — looks hoistable; confirm intent.
            shape = [1, 3, 224, 224]
            input_var = torch.zeros(shape, device=device)
            flops = model.module.get_flops(input_var)
            if config.target_hardware in [None, 'flops']:
                latency = 0
            else:
                latency = Latency.predict_latency(model)
            # NOTE(review): re-enables the unused modules on every iteration
            # even though unused_modules_off() ran only once before the loop;
            # verify this call shouldn't instead sit after the loop.
            model.module.unused_modules_back()
            if step % config.print_freq == 0:
                progress.print(step)
                logger.info(
                    'valid %03d\t loss: %e\t top1: %f\t top5: %f\t flops: %f\t latency: %f',
                    step, losses.avg, top1.avg, top5.avg, flops/1e6, latency)
    writer.add_scalar('val/loss', losses.avg, cur_step)
    writer.add_scalar('val/top1', top1.avg, cur_step)
    writer.add_scalar('val/top5', top5.avg, cur_step)
    return top1.avg, losses.avg
def validate_warmup(valid_queue, model, epoch, criterion, writer):
    """Validation pass used during the warm-up phase.

    Runs under ``torch.no_grad`` but keeps the model in ``train()`` mode
    (NOTE(review): presumably deliberate so BatchNorm uses batch statistics
    during warm-up — confirm; a standard eval pass would call ``model.eval()``).

    Args:
        valid_queue: validation DataLoader yielding (input, target) batches.
        model: the (wrapped) supernet under warm-up.
        epoch: current epoch, used to derive the writer step.
        criterion: loss function.
        writer: TensorBoard-style summary writer.

    Returns:
        Tuple ``(top1.avg, top5.avg, losses.avg)``.
    """
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    model.train()
    progress = utils.ProgressMeter(len(valid_queue), batch_time, losses,
                                   top1, top5, prefix='Warmup-Test: ')
    cur_step = epoch * len(valid_queue)
    end = time.time()
    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            # input = input.cuda()
            # target = target.cuda(non_blocking=True)
            # NOTE(review): Variable/volatile is deprecated; no_grad already
            # disables autograd here.
            input = Variable(input, volatile=True).cuda()
            # target = Variable(target, volatile=True).cuda(async=True)
            target = Variable(target, volatile=True).cuda()
            logits = model(input)
            loss = criterion(logits, target)
            acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            # NOTE(review): raw tensors (not .item()) are stored in the
            # meters here, unlike infer(); harmless under no_grad but
            # inconsistent — consider loss.item()/acc1.item().
            losses.update(loss, n)
            top1.update(acc1, n)
            top5.update(acc5, n)
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if step % config.print_freq == 0:
                progress.print(step)
                logger.info('warmup-valid %03d %e %f %f',
                            step, losses.avg, top1.avg, top5.avg)
    writer.add_scalar('warmup-val/loss', losses.avg, cur_step)
    writer.add_scalar('warmup-val/top1', top1.avg, cur_step)
    writer.add_scalar('warmup-val/top5', top5.avg, cur_step)
    return top1.avg, top5.avg, losses.avg
def train(self, loader, st_step=0, max_step=100000):
    """Main GAN training loop (factorized style/content representation).

    Alternates, per batch: discriminator update, auxiliary-classifier
    update, then generator update, while maintaining an EMA copy of the
    generator. Logging / plotting / validation / checkpointing happen only
    on the rank-0 process (``self.cfg.gpu <= 0``).

    Args:
        loader: training DataLoader; cycled indefinitely via ``cyclize``.
        st_step: step counter to resume from.
        max_step: stop once ``self.step`` reaches this value.
    """
    self.gen.train()
    if self.disc is not None:
        self.disc.train()
    # loss stats
    losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                 "indp_exp", "indp_fact",
                                 "ac_s", "ac_c", "cross_ac_s", "cross_ac_c",
                                 "ac_gen_s", "ac_gen_c",
                                 "cross_ac_gen_s", "cross_ac_gen_c")
    # discriminator stats
    discs = utils.AverageMeters("real_font", "real_uni", "fake_font", "fake_uni",
                                "real_font_acc", "real_uni_acc",
                                "fake_font_acc", "fake_uni_acc")
    # etc stats
    stats = utils.AverageMeters("B", "ac_acc_s", "ac_acc_c",
                                "ac_gen_acc_s", "ac_gen_acc_c")
    self.step = st_step
    self.clear_losses()
    self.logger.info("Start training ...")
    for batch in cyclize(loader):
        epoch = self.step // len(loader)
        # DDP samplers must be re-seeded once per epoch for proper shuffling.
        if self.cfg.use_ddp and (self.step % len(loader)) == 0:
            loader.sampler.set_epoch(epoch)
        style_imgs = batch["style_imgs"].cuda()
        style_fids = batch["style_fids"].cuda()
        style_decs = batch["style_decs"]
        char_imgs = batch["char_imgs"].cuda()
        char_fids = batch["char_fids"].cuda()
        char_decs = batch["char_decs"]
        trg_imgs = batch["trg_imgs"].cuda()
        trg_fids = batch["trg_fids"].cuda()
        trg_cids = batch["trg_cids"].cuda()
        trg_decs = batch["trg_decs"]
        ##############################################################
        # infer
        ##############################################################
        B = len(trg_imgs)
        n_s = style_imgs.shape[1]  # reference images per style
        n_c = char_imgs.shape[1]   # reference images per character
        style_feats = self.gen.encode(style_imgs.flatten(0, 1))  # (B*n_s, n_exp, *feat_shape)
        char_feats = self.gen.encode(char_imgs.flatten(0, 1))
        self.add_indp_exp_loss(
            torch.cat([style_feats["last"], char_feats["last"]]))
        # factorize into style (index 0) and content (index 1) components
        style_facts_s = self.gen.factorize(style_feats, 0)  # (B*n_s, n_exp, *feat_shape)
        style_facts_c = self.gen.factorize(style_feats, 1)
        char_facts_s = self.gen.factorize(char_feats, 0)
        char_facts_c = self.gen.factorize(char_feats, 1)
        self.add_indp_fact_loss(
            [style_facts_s["last"], style_facts_c["last"]],
            [style_facts_s["skip"], style_facts_c["skip"]],
            [char_facts_s["last"], char_facts_c["last"]],
            [char_facts_s["skip"], char_facts_c["skip"]],
        )
        # average the per-reference factors over the reference dimension
        mean_style_facts = {
            k: utils.add_dim_and_reshape(v, 0, (-1, n_s)).mean(1)
            for k, v in style_facts_s.items()
        }
        mean_char_facts = {
            k: utils.add_dim_and_reshape(v, 0, (-1, n_c)).mean(1)
            for k, v in char_facts_c.items()
        }
        gen_feats = self.gen.defactorize([mean_style_facts, mean_char_facts])
        gen_imgs = self.gen.decode(gen_feats)
        stats.updates({
            "B": B,
        })
        # ---- discriminator update (generated images detached) ----
        real_font, real_uni, *real_feats = self.disc(
            trg_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])
        fake_font, fake_uni = self.disc(gen_imgs.detach(), trg_fids, trg_cids)
        self.add_gan_d_loss([real_font, real_uni], [fake_font, fake_uni])
        self.d_optim.zero_grad()
        self.d_backward()
        self.d_optim.step()
        # ---- generator adversarial + feature-matching losses ----
        fake_font, fake_uni, *fake_feats = self.disc(
            gen_imgs, trg_fids, trg_cids, out_feats=self.cfg['fm_layers'])
        self.add_gan_g_loss(fake_font, fake_uni)
        self.add_fm_loss(real_feats, fake_feats)

        # hinge-style accuracy: real scores should be > 0, fake < 0
        def racc(x):
            return (x > 0.).float().mean().item()

        def facc(x):
            return (x < 0.).float().mean().item()

        discs.updates(
            {
                "real_font": real_font.mean().item(),
                "real_uni": real_uni.mean().item(),
                "fake_font": fake_font.mean().item(),
                "fake_uni": fake_uni.mean().item(),
                'real_font_acc': racc(real_font),
                'real_uni_acc': racc(real_uni),
                'fake_font_acc': facc(fake_font),
                'fake_uni_acc': facc(fake_uni)
            }, B)
        self.add_pixel_loss(gen_imgs, trg_imgs)
        # zero G grads BEFORE the AC losses: the AC backward leaves grads on G
        self.g_optim.zero_grad()
        self.add_ac_losses_and_update_stats(
            torch.cat([style_facts_s["last"], char_facts_s["last"]]),
            torch.cat([style_fids.flatten(), char_fids.flatten()]),
            torch.cat([style_facts_c["last"], char_facts_c["last"]]),
            style_decs + char_decs,
            gen_imgs, trg_fids, trg_decs, stats)
        self.ac_optim.zero_grad()
        self.ac_backward()
        self.ac_optim.step()
        self.g_backward()
        self.g_optim.step()
        loss_dic = self.clear_losses()
        losses.updates(loss_dic, B)  # accum loss stats
        # EMA g
        self.accum_g()
        if self.is_bn_gen:
            self.sync_g_ema(style_imgs, char_imgs)
        torch.cuda.synchronize()
        # rank-0 only: logging / validation / checkpointing
        if self.cfg.gpu <= 0:
            if self.step % self.cfg.tb_freq == 0:
                self.plot(losses, discs, stats)
            if self.step % self.cfg.print_freq == 0:
                self.log(losses, discs, stats)
                self.logger.debug(
                    "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                    torch.cuda.max_memory_allocated() / 1000 / 1000,
                    torch.cuda.max_memory_cached() / 1000 / 1000)
                losses.resets()
                discs.resets()
                stats.resets()
                nrow = len(trg_imgs)
                grid = utils.make_comparable_grid(trg_imgs.detach().cpu(),
                                                  gen_imgs.detach().cpu(),
                                                  nrow=nrow)
                self.writer.add_image("last", grid)
            if self.step > 0 and self.step % self.cfg.val_freq == 0:
                epoch = self.step / len(loader)
                self.logger.info(
                    "Validation at Epoch = {:.3f}".format(epoch))
                # non-BN generators sync the EMA copy lazily, only here
                if not self.is_bn_gen:
                    self.sync_g_ema(style_imgs, char_imgs)
                self.evaluator.comparable_val_saveimg(
                    self.gen_ema, self.test_loader, self.step,
                    n_row=self.test_n_row)
                self.save(loss_dic['g_total'], self.cfg.save,
                          self.cfg.get('save_freq', self.cfg.val_freq))
        else:
            pass
        if self.step >= max_step:
            break
        self.step += 1
    self.logger.info("Iteration finished.")
def train(self, loader, st_step=1, max_step=100000):
    """Component-conditioned GAN training loop (encode_write/read_decode).

    Per batch: write style/component features into the generator's memory,
    decode the target glyphs, then update discriminator, (optional)
    auxiliary component classifier, and generator in that order, keeping
    an EMA generator copy.

    Args:
        loader: training DataLoader, cycled indefinitely via ``cyclize``.
        st_step: step counter to resume from.
        val: unused placeholder dict (kept for interface compatibility).
    """
    val = val or {}
    self.gen.train()
    self.disc.train()
    # loss stats
    losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                 "ac", "ac_gen")
    # discriminator stats
    discs = utils.AverageMeters("real", "fake", "real_font", "real_char",
                                "fake_font", "fake_char",
                                "real_acc", "fake_acc",
                                "real_font_acc", "real_char_acc",
                                "fake_font_acc", "fake_char_acc")
    # etc stats
    stats = utils.AverageMeters("B_style", "B_target", "ac_acc", "ac_gen_acc")
    self.step = st_step
    self.clear_losses()
    self.logger.info("Start training ...")
    for (style_ids, style_char_ids, style_comp_ids, style_imgs,
         trg_ids, trg_char_ids, trg_comp_ids, trg_imgs,
         *content_imgs) in cyclize(loader):
        B = trg_imgs.size(0)
        stats.updates({"B_style": style_imgs.size(0), "B_target": B})
        style_ids = style_ids.cuda()
        # style_char_ids = style_char_ids.cuda()
        style_comp_ids = style_comp_ids.cuda()
        style_imgs = style_imgs.cuda()
        trg_ids = trg_ids.cuda()
        trg_char_ids = trg_char_ids.cuda()
        trg_comp_ids = trg_comp_ids.cuda()
        trg_imgs = trg_imgs.cuda()
        # infer
        comp_feats = self.gen.encode_write(style_ids, style_comp_ids,
                                           style_imgs)
        out = self.gen.read_decode(trg_ids, trg_comp_ids)
        # D loss (generated images detached so only D is updated)
        real, real_font, real_char, real_feats = self.disc(
            trg_imgs, trg_ids, trg_char_ids, out_feats=True)
        fake, fake_font, fake_char = self.disc(out.detach(), trg_ids,
                                               trg_char_ids)
        self.add_gan_d_loss(real, real_font, real_char,
                            fake, fake_font, fake_char)
        self.d_optim.zero_grad()
        self.d_backward()
        self.d_optim.step()
        # G loss
        fake, fake_font, fake_char, fake_feats = self.disc(
            out, trg_ids, trg_char_ids, out_feats=True)
        self.add_gan_g_loss(real, real_font, real_char,
                            fake, fake_font, fake_char)
        # feature matching loss
        self.add_fm_loss(real_feats, fake_feats)
        # disc stats: hinge-style accuracy (real > 0, fake < 0)
        racc = lambda x: (x > 0.).float().mean().item()
        facc = lambda x: (x < 0.).float().mean().item()
        discs.updates(
            {
                "real": real.mean().item(),
                "fake": fake.mean().item(),
                "real_font": real_font.mean().item(),
                "real_char": real_char.mean().item(),
                "fake_font": fake_font.mean().item(),
                "fake_char": fake_char.mean().item(),
                'real_acc': racc(real),
                'fake_acc': facc(fake),
                'real_font_acc': racc(real_font),
                'real_char_acc': racc(real_char),
                'fake_font_acc': facc(fake_font),
                'fake_char_acc': facc(fake_char)
            }, B)
        # pixel loss
        self.add_pixel_loss(out, trg_imgs)
        self.g_optim.zero_grad()
        # NOTE ac loss generates & leaves grads to G.
        # so g_optim.zero_grad() should place in front of ac loss and
        # g_backward() should follow ac loss.
        if self.aux_clf is not None:
            self.add_ac_losses_and_update_stats(comp_feats, style_comp_ids,
                                                out, trg_comp_ids, stats)
            self.ac_optim.zero_grad()
            # retain_graph: the same graph is reused by g_backward() below
            self.ac_backward(retain_graph=True)
            self.ac_optim.step()
        self.g_backward()
        self.g_optim.step()
        loss_dic = self.clear_losses()
        losses.updates(loss_dic, B)
        # generator EMA
        self.accum_g()
        if self.is_bn_gen:
            self.sync_g_ema(style_ids, style_comp_ids, style_imgs,
                            trg_ids, trg_comp_ids)
        # after step
        if self.step % self.cfg['tb_freq'] == 0:
            self.plot(losses, discs, stats)
        if self.step % self.cfg['print_freq'] == 0:
            self.log(losses, discs, stats)
            losses.resets()
            discs.resets()
            stats.resets()
        if self.step % self.cfg['val_freq'] == 0:
            epoch = self.step / len(loader)
            self.logger.info("Validation at Epoch = {:.3f}".format(epoch))
            self.evaluator.merge_and_log_image('d1', out, trg_imgs, self.step)
            self.evaluator.validation(self.gen, self.step)
            # if non-BN generator, sync max singular value of spectral norm.
            if not self.is_bn_gen:
                self.sync_g_ema(style_ids, style_comp_ids, style_imgs,
                                trg_ids, trg_comp_ids)
            self.evaluator.validation(self.gen_ema, self.step,
                                      extra_tag='_EMA')
            # save freq == val freq
            self.save(loss_dic['g_total'], self.cfg['save'],
                      self.cfg.get('save_freq', self.cfg['val_freq']))
        if self.step >= self.cfg['max_iter']:
            self.logger.info("Iteration finished.")
            break
        self.step += 1
def cross_validation(self, gen, step, loader, tag, n_batches, n_log=64,
                     save_dir=None):
    """Validation using splitted cross-validation set

    Runs up to ``n_batches`` batches through the generator, accumulates
    L1 / SSIM / MS-SSIM against the targets, logs the first ``n_log``
    image pairs, and optionally saves every generated glyph to disk.

    Args:
        gen: generator with ``encode_write`` / ``read_decode``.
        step: global step, used for logging.
        loader: cross-validation DataLoader.
        tag: label used for image logging and the summary line.
        n_batches: number of batches to evaluate.
        n_log: # of images to log
        save_dir: if given, images are saved to save_dir

    Returns:
        Tuple ``(l1.avg, ssim.avg, msssim.avg)``.
    """
    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
    outs = []
    trgs = []
    n_accum = 0
    losses = utils.AverageMeters("l1", "ssim", "msssim")
    for i, (style_ids, style_comp_ids, style_imgs, trg_ids, trg_comp_ids,
            content_imgs, trg_imgs) in enumerate(loader):
        if i == n_batches:
            break
        style_ids = style_ids.cuda()
        style_comp_ids = style_comp_ids.cuda()
        style_imgs = style_imgs.cuda()
        trg_ids = trg_ids.cuda()
        trg_comp_ids = trg_comp_ids.cuda()
        trg_imgs = trg_imgs.cuda()
        gen.encode_write(style_ids, style_comp_ids, style_imgs)
        out = gen.read_decode(trg_ids, trg_comp_ids)
        B = len(out)
        # log images: accumulate until n_log samples are collected, then
        # emit the comparison grid exactly once
        if n_accum < n_log:
            trgs.append(trg_imgs)
            outs.append(out)
            n_accum += B
            if n_accum >= n_log:
                # log results
                outs = torch.cat(outs)[:n_log]
                trgs = torch.cat(trgs)[:n_log]
                self.merge_and_log_image(tag, outs, trgs, step)
        l1, ssim, msssim = self.get_pixel_losses(out, trg_imgs,
                                                 self.unify_resize_method)
        losses.updates(
            {
                "l1": l1.item(),
                "ssim": ssim.item(),
                "msssim": msssim.item()
            }, B)
        # save images
        if save_dir:
            font_ids = trg_ids.detach().cpu().numpy()
            images = out.detach().cpu()  # [B, 1, 128, 128]
            char_comp_ids = trg_comp_ids.detach().cpu().numpy()  # [B, n_comp_types]
            for font_id, image, comp_ids in zip(font_ids, images,
                                                char_comp_ids):
                font_name = loader.dataset.fonts[font_id]  # name.ttf
                font_name = Path(font_name).stem  # remove ext
                (save_dir / font_name).mkdir(parents=True, exist_ok=True)
                # NOTE(review): `char` is only assigned for 'kor'/'thai';
                # any other self.language raises NameError below — confirm
                # the supported-language invariant.
                if self.language == 'kor':
                    char = kor.compose(*comp_ids)
                elif self.language == 'thai':
                    char = thai.compose_ids(*comp_ids)
                uni = "".join([f'{ord(each):04X}' for each in char])
                path = save_dir / font_name / "{}_{}.png".format(
                    font_name, uni)
                utils.save_tensor_to_image(image, path)
    self.logger.info(
        " [Valid] {tag:30s} | Step {step:7d} L1 {L.l1.avg:7.4f} SSIM {L.ssim.avg:7.4f}"
        " MSSSIM {L.msssim.avg:7.4f}".format(tag=tag, step=step, L=losses))
    return losses.l1.avg, losses.ssim.avg, losses.msssim.avg
def train(train_queue, valid_queue, model, criterion, LatencyLoss, optimizer,
          alpha_optimizer, lr, epoch, writer, update_schedule):
    """One epoch of NAS training: network weights plus architecture params.

    For every batch a random path of the supernet is sampled (binary gates)
    and the weight parameters are updated on the training batch; after the
    first epoch, ``update_schedule`` decides how many architecture-parameter
    updates (on validation data, with a latency-augmented loss) follow each
    weight step.

    Args:
        train_queue: training DataLoader.
        valid_queue: validation DataLoader used for architecture updates.
        model: wrapped supernet (accessed via ``model.module``).
        criterion: CE loss for logits/targets.
        LatencyLoss: callable combining CE loss with expected latency; also
            exposes ``expected_latency(model)``.
        optimizer: optimizer for weight parameters.
        alpha_optimizer: optimizer for architecture parameters.
        lr: current learning rate (logged only).
        epoch: current epoch index.
        writer: TensorBoard-style summary writer.
        update_schedule: dict step -> number of arch updates at that step.

    Returns:
        Tuple ``(top1.avg, losses.avg)``.
    """
    arch_param_num = np.sum(
        np.prod(params.size()) for params in model.module.arch_parameters())
    binary_gates_num = len(list(model.module.binary_gates()))
    weight_param_num = len(list(model.module.weight_parameters()))
    print('#arch_params: %d\t#binary_gates: %d\t#weight_params: %d' %
          (arch_param_num, binary_gates_num, weight_param_num))
    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    entropy = utils.AverageMeters('Entropy', ':.4e')
    progress = utils.ProgressMeter(len(train_queue), batch_time, data_time,
                                   losses, top1, top5,
                                   prefix="Epoch: [{}]".format(epoch))
    cur_step = epoch * len(train_queue)
    writer.add_scalar('train/lr', lr, cur_step)
    model.train()
    end = time.time()
    for step, (input, target) in enumerate(train_queue):
        # measure data loading time
        data_time.update(time.time() - end)
        # architecture-distribution entropy, normalized per parameter
        net_entropy = model.module.entropy()
        entropy.update(net_entropy.data.item() / arch_param_num, 1)
        # sample random path
        model.module.reset_binary_gates()
        # close unused module
        model.module.unused_modules_off()
        n = input.size(0)
        input = Variable(input, requires_grad=False).cuda()
        # target = Variable(target, requires_grad=False).cuda(async=True)
        target = Variable(target, requires_grad=False).cuda()
        logits = model(input)
        if config.label_smooth > 0.0:
            loss = utils.cross_entropy_with_label_smoothing(
                logits, target, config.label_smooth)
        else:
            loss = criterion(logits, target)
        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
        # NOTE(review): raw tensors (not .item()) are stored in the meters;
        # `loss` still holds its autograd graph, so this retains graphs for
        # the whole epoch — consider loss.item()/acc1.item().
        losses.update(loss, n)
        top1.update(acc1, n)
        top5.update(acc5, n)
        model.zero_grad()
        loss.backward()
        # NOTE(review): nn.utils.clip_grad_norm is the deprecated alias of
        # clip_grad_norm_ (same behavior, removed in newer PyTorch).
        nn.utils.clip_grad_norm(model.parameters(), config.grad_clip)
        optimizer.step()
        # unused module back
        model.module.unused_modules_back()
        # Training weights firstly, after few epoch, train arch parameters
        if epoch > 0:
            #### office warm up lr ####
            # T_cur = epoch * len(train_queue) + step
            # lr_max = 0.05
            # T_total = config.warmup_epochs * len(train_queue)
            # lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
            #### office warm up lr ####
            for j in range(update_schedule.get(step, 0)):
                model.train()
                latency_loss = 0
                expected_loss = 0
                # NOTE(review): a fresh iterator is created per update, so
                # each arch step pays DataLoader startup cost (and with a
                # non-shuffling loader always sees the first batch) —
                # confirm intent.
                valid_iter = iter(valid_queue)
                input_valid, target_valid = next(valid_iter)
                # alpha_optimizer.zero_grad()
                input_valid = Variable(input_valid,
                                       requires_grad=False).cuda()
                # target = Variable(target, requires_grad=False).cuda(async=True)
                target_valid = Variable(target_valid,
                                        requires_grad=False).cuda()
                # resample a path for the architecture update
                model.module.reset_binary_gates()
                model.module.unused_modules_off()
                output_valid = model(input_valid).float()
                loss_ce = criterion(output_valid, target_valid)
                expected_loss = LatencyLoss.expected_latency(model)
                expected_loss_tensor = torch.cuda.FloatTensor([expected_loss])
                latency_loss = LatencyLoss(loss_ce, expected_loss_tensor,
                                           config)
                # compute gradient and do SGD step
                # zero grads of weight_param, arch_param & binary_param
                model.zero_grad()
                latency_loss.backward()
                # set architecture parameter gradients
                model.module.set_arch_param_grad()
                alpha_optimizer.step()
                model.module.rescale_updated_arch_param()
                model.module.unused_modules_back()
                log_str = 'Architecture [%d-%d]\t Loss %.4f\t %s LatencyLoss: %s' % (
                    epoch, step, latency_loss, config.target_hardware,
                    expected_loss)
                utils.write_log(arch_logger_path, log_str)
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % config.print_freq == 0 or step == len(train_queue) - 1:
            logger.info('train step:%03d %03d loss:%e top1:%05f top5:%05f',
                        step, len(train_queue), losses.avg, top1.avg,
                        top5.avg)
            progress.print(step)
    writer.add_scalar('train/loss', losses.avg, cur_step)
    writer.add_scalar('train/top1', top1.avg, cur_step)
    writer.add_scalar('train/top5', top5.avg, cur_step)
    return top1.avg, losses.avg
def warm_up(train_queue, valid_queue, model, criterion, Latency, optimizer,
            epoch, writer):
    """One warm-up epoch for the supernet weights (no architecture updates).

    Uses a cosine learning-rate ramp over ``config.warmup_epochs``, samples
    a random path per batch, then validates the currently-chosen subnet and
    saves a 'warmup.pth.tar' checkpoint with architecture params stripped.

    Args:
        train_queue / valid_queue: training / validation DataLoaders.
        model: wrapped supernet (accessed via ``model.module``).
        criterion: CE loss.
        Latency: latency predictor (used unless target is None/'flops').
        optimizer: weight optimizer; its lr is overwritten per step.
        epoch: current warm-up epoch index.
        writer: TensorBoard-style summary writer.

    Returns:
        Tuple ``(top1.avg, losses.avg)`` over the training pass.
    """
    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    progress = utils.ProgressMeter(len(train_queue), batch_time, data_time,
                                   losses, top1, top5,
                                   prefix="Epoch: [{}]".format(epoch))
    cur_step = epoch * len(train_queue)
    model.train()
    print('\n', '-' * 30, 'Warmup epoch: %d' % (epoch), '-' * 30, '\n')
    end = time.time()
    lr = 0
    for step, (input, target) in enumerate(train_queue):
        # measure data loading time
        data_time.update(time.time() - end)
        # office warm up lr: half-cosine ramp from lr_max down to 0 over
        # the whole warm-up phase
        T_cur = epoch * len(train_queue) + step
        lr_max = 0.05
        T_total = config.warmup_epochs * len(train_queue)
        lr = 0.5 * lr_max * (1 + math.cos(math.pi * T_cur / T_total))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        writer.add_scalar('warm-up/lr', lr, cur_step + step)
        #### office warm up lr ####
        n = input.size(0)
        input = Variable(input, requires_grad=False).cuda()
        # target = Variable(target, requires_grad=False).cuda(async=True)
        target = Variable(target, requires_grad=False).cuda()
        # sample random path and disable the unused candidate ops
        model.module.reset_binary_gates()
        model.module.unused_modules_off()
        logits = model(input)
        # NOTE(review): the epoch > warmup_epochs guard makes this branch
        # unreachable during warm-up, so plain CE is always used here —
        # presumably intentional; confirm.
        if config.label_smooth > 0 and epoch > config.warmup_epochs:
            loss = utils.cross_entropy_with_label_smoothing(
                logits, target, config.label_smooth)
        else:
            loss = criterion(logits, target)
        model.zero_grad()
        loss.backward()
        optimizer.step()
        acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
        # NOTE(review): raw tensors stored in meters (cf. infer() which
        # stores floats); `loss` keeps its graph alive — consider .item().
        losses.update(loss, n)
        top1.update(acc1, n)
        top5.update(acc5, n)
        # unused modules back
        model.module.unused_modules_back()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % config.print_freq == 0 or step == len(train_queue) - 1:
            logger.info(
                'warmup train step:%03d %03d loss:%e top1:%05f top5:%05f',
                step, len(train_queue), losses.avg, top1.avg, top5.avg)
            progress.print(step)
    writer.add_scalar('warmup-train/loss', losses.avg, cur_step)
    writer.add_scalar('warmup-train/top1', top1.avg, cur_step)
    writer.add_scalar('warmup-train/top5', top5.avg, cur_step)
    logger.info('warmup epoch %d lr %e', epoch, lr)
    # set chosen op active
    model.module.set_chosen_op_active()
    # remove unused modules
    model.module.unused_modules_off()
    valid_top1, valid_top5, valid_loss = validate_warmup(
        valid_queue, model, epoch, criterion, writer)
    # FLOPs / latency of the chosen subnet on a fixed 224x224 dummy input
    shape = [1, 3, 224, 224]
    input_var = torch.zeros(shape, device=device)
    flops = model.module.get_flops(input_var)
    latency = 0
    if config.target_hardware in [None, 'flops']:
        latency = 0
    else:
        latency = Latency.predict_latency(model)
    # unused modules back
    logger.info(
        'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc '
        '{4:.3f}\tflops: {5:.1f}M {6:.3f}ms'.format(epoch,
                                                    config.warmup_epochs,
                                                    valid_loss, valid_top1,
                                                    valid_top5, flops / 1e6,
                                                    latency))
    model.module.unused_modules_back()
    # still warming up if more warm-up epochs remain
    config.warmup = epoch + 1 < config.warmup_epochs
    state_dict = model.state_dict()
    # rm architect params and binary gates
    for key in list(state_dict.keys()):
        if 'alpha' in key or 'path' in key:
            state_dict.pop(key)
    checkpoint = {
        'state_dict': state_dict,
        'warmup': config.warmup,
    }
    # NOTE(review): grouping reconstructed from mangled source — only
    # 'warmup_epoch' is assumed conditional on config.warmup; confirm
    # against the original layout.
    if config.warmup:
        checkpoint['warmup_epoch'] = epoch
    checkpoint['epoch'] = epoch
    checkpoint['w_optimizer'] = optimizer.state_dict()
    save_model(model, checkpoint, model_name='warmup.pth.tar')
    return top1.avg, losses.avg
def train(self, loader, st_step=1, max_step=100000):
    """Factorization-phase GAN training loop (encode_write_fact path).

    Per batch: write factorized style/component features, decode targets
    in "fact" phase, then update discriminator, (optional) auxiliary
    classifier, and generator, maintaining an EMA generator. Logging,
    validation and checkpointing run on rank 0 only (``self.cfg.gpu <= 0``).

    Args:
        loader: training DataLoader, cycled indefinitely via ``cyclize``.
        st_step: step counter to resume from.
        max_step: stop once ``self.step`` reaches this value.
    """
    self.gen.train()
    self.disc.train()
    losses = utils.AverageMeters("g_total", "pixel", "disc", "gen", "fm",
                                 "ac", "ac_gen", "dec_const")
    discs = utils.AverageMeters("real_font", "real_uni", "fake_font",
                                "fake_uni", "real_font_acc", "real_uni_acc",
                                "fake_font_acc", "fake_uni_acc")
    # etc stats
    stats = utils.AverageMeters("B_style", "B_target", "ac_acc", "ac_gen_acc")
    self.step = st_step
    self.clear_losses()
    self.logger.info("Start training ...")
    for (in_style_ids, in_comp_ids, in_imgs, trg_style_ids, trg_uni_ids,
         trg_comp_ids, trg_imgs, content_imgs) in cyclize(loader):
        epoch = self.step // len(loader)
        # DDP samplers must be re-seeded once per epoch for proper shuffling
        if self.cfg.use_ddp and (self.step % len(loader)) == 0:
            loader.sampler.set_epoch(epoch)
        B = trg_imgs.size(0)
        stats.updates({"B_style": in_imgs.size(0), "B_target": B})
        in_style_ids = in_style_ids.cuda()
        in_comp_ids = in_comp_ids.cuda()
        in_imgs = in_imgs.cuda()
        trg_style_ids = trg_style_ids.cuda()
        trg_imgs = trg_imgs.cuda()
        content_imgs = content_imgs.cuda()
        if self.cfg.use_half:
            in_imgs = in_imgs.half()
            content_imgs = content_imgs.half()
        feat_styles, feat_comps = self.gen.encode_write_fact(
            in_style_ids, in_comp_ids, in_imgs, write_comb=True)
        # recombine style/component factors; fed to the aux classifier
        feats_rc = (feat_styles * feat_comps).sum(1)
        ac_feats = feats_rc
        self.add_dec_const_loss()
        out = self.gen.read_decode(trg_style_ids, trg_comp_ids,
                                   content_imgs=content_imgs, phase="fact",
                                   try_comb=True)
        trg_uni_disc_ids = trg_uni_ids.cuda()
        # ---- discriminator update (generated images detached) ----
        real_font, real_uni, *real_feats = self.disc(
            trg_imgs, trg_style_ids, trg_uni_disc_ids,
            out_feats=self.cfg['fm_layers'])
        fake_font, fake_uni = self.disc(out.detach(), trg_style_ids,
                                        trg_uni_disc_ids)
        self.add_gan_d_loss(real_font, real_uni, fake_font, fake_uni)
        self.d_optim.zero_grad()
        self.d_backward()
        self.d_optim.step()
        # ---- generator adversarial + feature-matching losses ----
        fake_font, fake_uni, *fake_feats = self.disc(
            out, trg_style_ids, trg_uni_disc_ids,
            out_feats=self.cfg['fm_layers'])
        self.add_gan_g_loss(real_font, real_uni, fake_font, fake_uni)
        self.add_fm_loss(real_feats, fake_feats)

        # hinge-style accuracy: real scores should be > 0, fake < 0
        def racc(x):
            return (x > 0.).float().mean().item()

        def facc(x):
            return (x < 0.).float().mean().item()

        discs.updates(
            {
                "real_font": real_font.mean().item(),
                "real_uni": real_uni.mean().item(),
                "fake_font": fake_font.mean().item(),
                "fake_uni": fake_uni.mean().item(),
                'real_font_acc': racc(real_font),
                'real_uni_acc': racc(real_uni),
                'fake_font_acc': facc(fake_font),
                'fake_uni_acc': facc(fake_uni)
            }, B)
        self.add_pixel_loss(out, trg_imgs)
        # zero G grads BEFORE AC losses: AC backward leaves grads on G
        self.g_optim.zero_grad()
        if self.aux_clf is not None:
            self.add_ac_losses_and_update_stats(ac_feats, in_comp_ids, out,
                                                trg_comp_ids, stats)
            self.ac_optim.zero_grad()
            self.ac_backward()
            self.ac_optim.step()
        self.g_backward()
        self.g_optim.step()
        loss_dic = self.clear_losses()
        losses.updates(loss_dic, B)  # accum loss stats
        self.accum_g()
        if self.is_bn_gen:
            self.sync_g_ema(in_style_ids, in_comp_ids, in_imgs,
                            trg_style_ids, trg_comp_ids,
                            content_imgs=content_imgs)
        torch.cuda.synchronize()
        # rank-0 only: logging / validation / checkpointing
        if self.cfg.gpu <= 0:
            if self.step % self.cfg['tb_freq'] == 0:
                self.baseplot(losses, discs, stats)
                self.plot(losses)
            if self.step % self.cfg['print_freq'] == 0:
                self.log(losses, discs, stats)
                self.logger.debug(
                    "GPU Memory usage: max mem_alloc = %.1fM / %.1fM",
                    torch.cuda.max_memory_allocated() / 1000 / 1000,
                    torch.cuda.max_memory_cached() / 1000 / 1000)
                losses.resets()
                discs.resets()
                stats.resets()
            if self.step % self.cfg['val_freq'] == 0:
                epoch = self.step / len(loader)
                self.logger.info(
                    "Validation at Epoch = {:.3f}".format(epoch))
                # non-BN generators sync the EMA copy lazily, only here
                if not self.is_bn_gen:
                    self.sync_g_ema(in_style_ids, in_comp_ids, in_imgs,
                                    trg_style_ids, trg_comp_ids,
                                    content_imgs=content_imgs)
                self.evaluator.cp_validation(self.gen_ema, self.cv_loaders,
                                             self.step, phase="fact",
                                             ext_tag="factorize")
                self.save(loss_dic['g_total'], self.cfg['save'],
                          self.cfg.get('save_freq', self.cfg['val_freq']))
        else:
            pass
        if self.step >= max_step:
            break
        self.step += 1
    self.logger.info("Iteration finished.")
def validate(val_loader, model, epoch, criterion, config, early_stopping,
             writer, start):
    """Validate the model, supporting both DALI and plain PyTorch loaders.

    Picks the DALI code path when the loader class name contains
    'DALIClassificationIterator' (DALI batches arrive as
    ``data[0]['data']`` / ``data[0]['label']``). Feeds the averaged loss to
    ``early_stopping`` and hard-exits the process when it triggers.

    Args:
        val_loader: validation loader (DALI iterator or DataLoader).
        model: model under evaluation.
        epoch: epoch index used as the writer step.
        criterion: loss function.
        config: run configuration (batch_size, print_freq, distributed, ...).
        early_stopping: callable tracker with an ``early_stop`` flag.
        writer: TensorBoard-style summary writer.
        start: wall-clock start time, reported on early stop.

    Returns:
        ``top1.avg`` for the validation pass.
    """
    batch_time = utils.AverageMeters('Time', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    if 'DALIClassificationIterator' in val_loader.__class__.__name__:
        # DALI iterators expose _size instead of __len__
        progress = utils.ProgressMeter(
            math.ceil(val_loader._size / config.batch_size), batch_time,
            losses, top1, top5, prefix='Test: ')
    else:
        progress = utils.ProgressMeter(len(val_loader), batch_time, losses,
                                       top1, top5, prefix='Test: ')
    # switch to evaluate mode
    model.eval()
    with torch.no_grad():
        end = time.time()
        if 'DALIClassificationIterator' in val_loader.__class__.__name__:
            for i, data in enumerate(val_loader):
                images = Variable(data[0]['data'])
                target = Variable(
                    data[0]['label'].squeeze().cuda().long().cuda(
                        non_blocking=True))
                # compute output
                output = model(images)
                loss = criterion(output, target)
                # measure accuracy and record loss
                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                if config.distributed:
                    reduced_loss = reduce_tensor(loss.data,
                                                 world_size=config.world_size)
                    acc1 = reduce_tensor(acc1, world_size=config.world_size)
                    acc5 = reduce_tensor(acc5, world_size=config.world_size)
                else:
                    reduced_loss = loss.data
                losses.update(to_python_float(reduced_loss), images.size(0))
                top1.update(to_python_float(acc1), images.size(0))
                top5.update(to_python_float(acc5), images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % config.print_freq == 0:
                    progress.print(i)
        else:
            for i, (images, target) in enumerate(val_loader):
                images = images.cuda(device, non_blocking=True)
                target = target.cuda(device, non_blocking=True)
                # compute output
                output = model(images)
                loss = criterion(output, target)
                # measure accuracy and record loss
                acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
                losses.update(loss.item(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()
                if i % config.print_freq == 0:
                    progress.print(i)
        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))
    early_stopping(losses.avg, model, ckpt_dir=config.path)
    if early_stopping.early_stop:
        print("Early stopping")
        utils.time(time.time() - start)
        # hard exit: skips Python cleanup by design here
        os._exit(0)
    writer.add_scalar('val/loss', losses.avg, epoch)
    # NOTE(review): top1.val/top5.val log the LAST batch value while loss
    # logs the average — verify .avg wasn't intended for consistency.
    writer.add_scalar('val/top1', top1.val, epoch)
    writer.add_scalar('val/top5', top5.val, epoch)
    return top1.avg
def train(train_loader, model, criterion, optimizer, epoch, config, writer):
    """One training epoch, supporting both DALI and plain PyTorch loaders.

    Picks the DALI code path when the loader class name contains
    'DALIClassificationIterator'. The DALI branch additionally supports a
    distributed metric all-reduce and an fp16-allreduce optimizer that owns
    the backward pass (``optimizer.backward(loss)``).

    Args:
        train_loader: training loader (DALI iterator or DataLoader).
        model: model to train.
        criterion: loss function.
        optimizer: optimizer (possibly an fp16 wrapper, see above).
        epoch: current epoch; also used for lr scheduling and writer steps.
        config: run configuration (lr, batch_size, print_freq, ...).
        writer: TensorBoard-style summary writer.
    """
    utils.adjust_learning_rate(optimizer, epoch, config)
    batch_time = utils.AverageMeters('Time', ':6.3f')
    data_time = utils.AverageMeters('Data', ':6.3f')
    losses = utils.AverageMeters('Loss', ':.4e')
    top1 = utils.AverageMeters('Acc@1', ':6.2f')
    top5 = utils.AverageMeters('Acc@5', ':6.2f')
    if 'DALIClassificationIterator' in train_loader.__class__.__name__:
        # TODO: IF need * config.world_size
        progress = utils.ProgressMeter(
            math.ceil(train_loader._size / config.batch_size), batch_time,
            data_time, losses, top1, top5,
            prefix="Epoch: [{}]".format(epoch))
        cur_step = epoch * math.ceil(train_loader._size / config.batch_size)
    else:
        progress = utils.ProgressMeter(len(train_loader), batch_time,
                                       data_time, losses, top1, top5,
                                       prefix="Epoch: [{}]".format(epoch))
        cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', config.lr, cur_step)
    model.train()
    end = time.time()
    if 'DALIClassificationIterator' in train_loader.__class__.__name__:
        for i, data in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            images = Variable(data[0]['data'])
            target = Variable(data[0]['label'].squeeze().cuda().long())
            # compute output
            output = model(images)
            loss = criterion(output, target)
            # measure accuracy and record loss
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            if config.distributed:
                reduced_loss = reduce_tensor(loss.data,
                                             world_size=config.world_size)
                acc1 = reduce_tensor(acc1, world_size=config.world_size)
                acc5 = reduce_tensor(acc5, world_size=config.world_size)
            else:
                reduced_loss = loss.data
            losses.update(to_python_float(reduced_loss), images.size(0))
            top1.update(to_python_float(acc1), images.size(0))
            top5.update(to_python_float(acc5), images.size(0))
            # compute gradient and do SGD step
            optimizer.zero_grad()
            if config.fp16_allreduce:
                # fp16 optimizer wrapper owns scaling + backward
                optimizer.backward(loss)
            else:
                loss.backward()
            optimizer.step()
            torch.cuda.synchronize()
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % config.print_freq == 0:
                progress.print(i)
                # NOTE(review): logged at a constant x-value (cur_step is
                # not advanced per batch) — later writes overwrite earlier
                # ones within the epoch; confirm intended.
                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', top1.avg, cur_step)
                writer.add_scalar('train/top5', top5.avg, cur_step)
    else:
        for i, (images, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            images = images.cuda(device, non_blocking=True)
            target = target.cuda(device, non_blocking=True)
            # compute output
            output = model(images)
            loss = criterion(output, target)
            # measure accuracy and record loss
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1.item(), images.size(0))
            top5.update(acc5.item(), images.size(0))
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % config.print_freq == 0:
                progress.print(i)
                writer.add_scalar('train/loss', loss.item(), cur_step)
                writer.add_scalar('train/top1', top1.avg, cur_step)
                writer.add_scalar('train/top5', top5.avg, cur_step)