import time
import logging

import torch
from scipy import stats

# cal_hc_loss, min_max_normalize, save, and save_generator_evaluate_metric are
# provided by the repo's utility modules.


def _generator_training_step(self, generator, model, loader, epoch, tau, info_for_logger=""):
    start_time = time.time()
    generator.train()
    model.eval()

    for step, (X, y) in enumerate(loader):
        # Sample architecture parameters and a target hardware constraint.
        arch_param, hardware_constraint = self.set_arch_param(generator, model, tau=tau)

        macs = self.lookup_table.get_model_macs(arch_param)
        logging.info("Generated model MACs: {}".format(macs))

        # Penalize deviation of the generated MACs from the target constraint.
        hc_loss = cal_hc_loss(macs.cuda(), hardware_constraint.item(), self.alpha, self.loss_penalty)

        X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
        N = X.shape[0]

        self.g_optimizer.zero_grad()
        outs = model(X, True)
        ce_loss = self.criterion(outs, y)
        loss = ce_loss + hc_loss
        logging.info("HC loss: {}".format(hc_loss))
        loss.backward()

        self.g_optimizer.step()
        self.g_optimizer.zero_grad()

        self._intermediate_stats_logging(outs, y, loss, step, epoch, N, len_loader=len(loader), val_or_train="Train", hc_losses=hc_loss)

    self._epoch_stats_logging(start_time=start_time, epoch=epoch, val_or_train="train")
    for avg in [self.top1, self.top5, self.losses, self.hc_losses]:
        avg.reset()
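# `cal_hc_loss` itself is not shown in this excerpt. The sketch below is an
# assumption, not the repo's implementation: it is assumed to penalize the
# squared deviation of the generated cost (MACs/FLOPs) from the target
# constraint, scaled by `alpha`, with an extra `loss_penalty` factor when the
# budget is exceeded. The name `cal_hc_loss_sketch` is hypothetical.
def cal_hc_loss_sketch(gen_cost, target_cost, alpha, loss_penalty):
    diff = gen_cost - target_cost
    loss = alpha * diff.pow(2).mean()
    # Assumed design choice: overshooting the budget costs more than undershooting.
    return loss * loss_penalty if diff.mean().item() > 0 else loss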
def evaluate_generator(generator, prior_pool, lookup_table, CONFIG, device, val=True):
    """Evaluate the Kendall's tau and hardware-constraint loss of the generator."""
    total_loss = 0
    evaluate_metric = {"gen_flops": [], "true_flops": []}

    # Sweep the target FLOPs range in steps of 10 and check how closely the
    # generated architectures track each requested constraint.
    for mac in range(CONFIG.low_flops, CONFIG.high_flops, 10):
        hardware_constraint = torch.tensor(mac, dtype=torch.float32)
        hardware_constraint = hardware_constraint.view(-1, 1)
        hardware_constraint = hardware_constraint.to(device)

        prior = prior_pool.get_prior(hardware_constraint.item())
        prior = prior.to(device)

        normalize_hardware_constraint = min_max_normalize(CONFIG.high_flops, CONFIG.low_flops, hardware_constraint)

        arch_param = generator(prior, normalize_hardware_constraint)
        arch_param = lookup_table.get_validation_arch_param(arch_param)
        layers_config = lookup_table.decode_arch_param(arch_param)

        gen_mac = lookup_table.get_model_flops(arch_param)
        hc_loss = cal_hc_loss(gen_mac.to(device), hardware_constraint.item(), CONFIG.alpha, CONFIG.loss_penalty)

        evaluate_metric["gen_flops"].append(gen_mac.item())
        evaluate_metric["true_flops"].append(mac)
        total_loss += hc_loss.item()

    # Kendall's tau: rank correlation between requested and generated FLOPs.
    tau, _ = stats.kendalltau(evaluate_metric["gen_flops"], evaluate_metric["true_flops"])
    return evaluate_metric, total_loss, tau
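# `min_max_normalize` is also imported from elsewhere in the repo. A minimal
# sketch, assuming it linearly rescales the constraint into [0, 1] over the
# configured range (note the call sites pass `high` before `low`); the name
# `min_max_normalize_sketch` is hypothetical.
def min_max_normalize_sketch(high, low, value):
    return (value - low) / (high - low)


# Hedged usage sketch for evaluate_generator (object names are illustrative):
# metric, hc_loss_total, tau = evaluate_generator(generator, prior_pool, lookup_table, CONFIG, device)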
def generator_validate(self, generator, model, loader, epoch, tau, hardware_constraint=360, arch_param=None, sample=False, info_for_logger=""):
    if generator is not None:
        generator.eval()
    model.eval()
    start_time = time.time()

    if sample:
        # Sample an architecture from the generator for the given constraint.
        hardware_constraint, arch_param = self._get_arch_param(generator, hardware_constraint, valid=True)
        arch_param = self.lookup_table.get_validation_arch_param(arch_param)
        arch_param, hardware_constraint = self.set_arch_param(generator, model, hardware_constraint=hardware_constraint, arch_param=arch_param)
    else:
        hardware_constraint = torch.tensor(hardware_constraint)
        hardware_constraint = hardware_constraint.cuda()

    macs = self.lookup_table.get_model_macs(arch_param)
    logging.info("Generated model MACs: {}".format(macs))

    hc_loss = cal_hc_loss(macs.cuda(), hardware_constraint.item(), self.alpha, self.loss_penalty)

    with torch.no_grad():
        for step, (X, y) in enumerate(loader):
            X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
            N = X.shape[0]

            outs = model(X, True)
            loss = self.criterion(outs, y)

            self._intermediate_stats_logging(outs, y, loss, step, epoch, N, len_loader=len(loader), val_or_train="Valid", hc_losses=hc_loss)

    top1_avg = self.top1.get_avg()
    self._epoch_stats_logging(start_time=start_time, epoch=epoch, val_or_train="val")
    self.writer.add_scalar("train_vs_val/val_hc_", macs, epoch)
    for avg in [self.top1, self.top5, self.losses, self.hc_losses]:
        avg.reset()
    return top1_avg, macs.item()
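# Hedged usage sketch (argument values are illustrative only): sample a fresh
# architecture at a 360-MAC constraint and validate it on the held-out loader.
# top1, macs = self.generator_validate(generator, model, val_loader, epoch, tau, hardware_constraint=360, sample=True)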
def search_train_loop(self, generator):
    self.epochs = self.warmup_epochs + self.search_epochs

    # Training generator
    best_loss = 10000.0
    best_top1 = 0
    tau = 5

    for epoch in range(self.warmup_epochs, self.epochs):
        logging.info("Start to train for search epoch {}".format(epoch))
        logging.info("Tau: {}".format(tau))

        # ================ Train =============================================
        # The config field bounding the number of generator updates per epoch
        # is assumed here; the original loop bound was missing.
        for i in range(self.CONFIG.generator_iterations):
            arch_param, hardware_constraint = self.set_arch_param(generator, tau=tau)

            # ============== evaluate FLOPs ==================================
            gen_flops = self.flops_table.predict_arch_param_efficiency(arch_param)
            hc_loss = cal_hc_loss(gen_flops.cuda(), hardware_constraint.item(), self.CONFIG.alpha, self.CONFIG.loss_penalty)
            # ================================================================

            self.g_optimizer.zero_grad()

            # ============== predict top-1 accuracy ==========================
            top1_avg = self.accuracy_predictor(arch_param)
            ce_loss = -1 * top1_avg  # maximize predicted accuracy
            # ================================================================

            loss = ce_loss + hc_loss
            logging.info("HC loss: {}".format(hc_loss))
            loss.backward()

            self.g_optimizer.step()
            self.g_optimizer.zero_grad()
        # ====================================================================

        # ================ Valid =============================================
        hardware_constraint, arch_param = self._get_arch_param(generator, hardware_constraint, valid=True)
        arch_param = self.calculate_one_hot(arch_param)
        arch_param, hardware_constraint = self.set_arch_param(generator, hardware_constraint=hardware_constraint, arch_param=arch_param)

        # ============== evaluate FLOPs ======================================
        gen_flops = self.flops_table.predict_arch_param_efficiency(arch_param)
        hc_loss = cal_hc_loss(gen_flops.cuda(), hardware_constraint.item(), self.CONFIG.alpha, self.CONFIG.loss_penalty)
        # ====================================================================

        # ============== predict top-1 accuracy ==============================
        top1_avg = self.accuracy_predictor(arch_param)
        logging.info("Valid : Top-1 avg : {}".format(top1_avg))
        # ====================================================================

        # ================ Evaluate ==========================================
        total_loss = 0
        evaluate_metric = {"gen_flops": [], "true_flops": []}

        for flops in range(self.CONFIG.low_macs, self.CONFIG.high_macs, 10):
            hardware_constraint = torch.tensor(flops, dtype=torch.float32)
            hardware_constraint = hardware_constraint.view(-1, 1)
            hardware_constraint = hardware_constraint.to(self.device)

            normalize_hardware_constraint = min_max_normalize(self.CONFIG.high_macs, self.CONFIG.low_macs, hardware_constraint)

            # Zero the noise so evaluation is deterministic.
            noise = torch.randn(*self.backbone.shape)
            noise = noise.to(self.device)
            noise *= 0

            arch_param = generator(self.backbone, normalize_hardware_constraint, noise)

            # ============== evaluate FLOPs ==================================
            gen_flops = self.flops_table.predict_arch_param_efficiency(arch_param)
            hc_loss = cal_hc_loss(gen_flops.cuda(), hardware_constraint.item(), self.CONFIG.alpha, self.CONFIG.loss_penalty)
            # ================================================================

            evaluate_metric["gen_flops"].append(gen_flops.item())
            evaluate_metric["true_flops"].append(flops)
            total_loss += hc_loss.item()

        kendall_tau, _ = stats.kendalltau(evaluate_metric["gen_flops"], evaluate_metric["true_flops"])
        # ====================================================================

        logging.info("Total loss : {}".format(total_loss))
        if best_loss > total_loss:
            logging.info("Best loss so far: {} (Kendall tau: {}). Saving model".format(total_loss, kendall_tau))
            best_loss = total_loss
            save_generator_evaluate_metric(evaluate_metric, self.CONFIG.path_to_generator_eval)
            save(generator, self.g_optimizer, self.CONFIG.path_to_save_generator)

        # Only track the best predicted top-1 once the HC loss is small enough.
        if top1_avg > best_top1 and total_loss < 0.4:
            logging.info("Best top-1 avg so far: {}. Saving model".format(top1_avg))
            best_top1 = top1_avg
            save(generator, self.g_optimizer, self.CONFIG.path_to_best_avg_generator)

        save(generator, self.g_optimizer, "./logs/generator/{}.pth".format(total_loss))

        # Anneal the Gumbel temperature and the generator input noise.
        tau *= self.CONFIG.tau_decay
        self.noise_weight = self.noise_weight * self.CONFIG.noise_decay if self.noise_weight > 0.0001 else 0
        logging.info("Noise weight : {}".format(self.noise_weight))

    logging.info("Best loss: {}".format(best_loss))
    save(generator, self.g_optimizer, self.CONFIG.path_to_final_generator)
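# `save` and `save_generator_evaluate_metric` are the repo's checkpoint/metric
# helpers and are not shown in this excerpt. A minimal sketch of `save` under
# assumed behavior (dump the model and optimizer state dicts to `path`); the
# name `save_sketch` is hypothetical.
def save_sketch(model, optimizer, path):
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)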
def generator_validate(self, generator, model, loader, epoch, target_hardware_constraint=None, arch_param=None, info_for_logger=""):
    if generator is not None:
        generator.eval()
    model.eval()
    start_time = time.time()

    if arch_param is None:
        # Draw (or normalize the given) target constraint and generate an
        # architecture for it.
        if target_hardware_constraint is None:
            target_hardware_constraint = self._get_target_hardware_constraint()
        else:
            target_hardware_constraint = self._get_target_hardware_constraint(target_hardware_constraint)
        arch_param = self._get_arch_param(generator, target_hardware_constraint)

    # Apply the architecture parameters to the supernet for validation.
    arch_param = self.set_arch_param(model, arch_param)

    flops = self.lookup_table.get_model_flops(arch_param)
    logging.info("Generated model FLOPs: {}".format(flops))

    hc_loss = cal_hc_loss(flops.cuda(), target_hardware_constraint.item(), self.CONFIG.alpha, self.CONFIG.loss_penalty)

    with torch.no_grad():
        for step, (X, y) in enumerate(loader):
            X, y = X.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True)
            N = X.shape[0]

            outs = model(X, True)
            loss = self.criterion(outs, y)

            self._intermediate_stats_logging(outs, y, loss, step, epoch, N, len_loader=len(loader), val_or_train="Valid", hc_losses=hc_loss)

    top1_avg = self.top1.get_avg()
    self._epoch_stats_logging(start_time=start_time, epoch=epoch, val_or_train="val")
    self.writer.add_scalar("train_vs_val/val_hc_", flops, epoch)
    for avg in [self.top1, self.top5, self.losses, self.hc_losses]:
        avg.reset()
    return top1_avg, flops.item()