コード例 #1
0
    def _generator_training_step(self, generator, model, loader, epoch, tau, info_for_logger=""):
        """Run one epoch of generator training while the supernet stays fixed.

        For every batch:
          1. An architecture is sampled through ``self.set_arch_param``, which
             receives the generator, the supernet and ``tau``.
          2. Its MACs are looked up in ``self.lookup_table`` and turned into a
             hardware-constraint loss with ``cal_hc_loss``.
          3. The batch is classified by ``model``.
          4. ``criterion + hc_loss`` is back-propagated, and only
             ``self.g_optimizer`` takes a step.

        ``generator`` is put in train mode and ``model`` in eval mode before
        the loop.

        Args:
            generator: architecture generator being trained.
            model: supernet used to score the sampled architecture.
            loader: iterable of ``(X, y)`` batches; ``len(loader)`` is passed
                to the per-step logger.
            epoch: current epoch index, forwarded to the stats loggers.
            tau: temperature forwarded to ``self.set_arch_param``.
            info_for_logger: unused; kept for interface compatibility.
        """
        epoch_start = time.time()
        generator.train()
        model.eval()

        for batch_idx, (inputs, targets) in enumerate(loader):
            # Sample an architecture (and its target constraint) for this batch.
            arch_param, hardware_constraint = self.set_arch_param(generator, model, tau=tau)

            gen_macs = self.lookup_table.get_model_macs(arch_param)
            logging.info("Generate model macs : {}".format(gen_macs))
            constraint_loss = cal_hc_loss(gen_macs.cuda(),
                                          hardware_constraint.item(),
                                          self.alpha,
                                          self.loss_penalty)

            inputs = inputs.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)
            batch_size = inputs.shape[0]

            self.g_optimizer.zero_grad()
            logits = model(inputs, True)

            total_loss = self.criterion(logits, targets) + constraint_loss
            logging.info("HC loss : {}".format(constraint_loss))
            total_loss.backward()

            self.g_optimizer.step()
            # Clear gradients right away so none linger until the next batch.
            self.g_optimizer.zero_grad()

            self._intermediate_stats_logging(logits,
                                             targets,
                                             total_loss,
                                             batch_idx,
                                             epoch,
                                             batch_size,
                                             len_loader=len(loader),
                                             val_or_train="Train",
                                             hc_losses=constraint_loss)

        self._epoch_stats_logging(start_time=epoch_start, epoch=epoch, val_or_train="train")
        for meter in (self.top1, self.top5, self.losses, self.hc_losses):
            meter.reset()
コード例 #2
0
ファイル: evaluate.py プロジェクト: eric8607242/SGNAS
def evaluate_generator(generator,
                       prior_pool,
                       lookup_table,
                       CONFIG,
                       device,
                       val=True,
                       step=10):
    """
    Evaluate the generator's Kendall tau and hardware-constraint loss.

    For every target FLOPs value in
    ``range(CONFIG.low_flops, CONFIG.high_flops, step)`` the generator proposes
    an architecture. The architecture's FLOPs are looked up and compared
    against the target with ``cal_hc_loss``.

    Args:
        generator: called as ``generator(prior, normalized_constraint)``; its
            output is converted with ``lookup_table.get_validation_arch_param``.
        prior_pool: provides ``get_prior(constraint_value)`` for each target.
        lookup_table: provides ``get_validation_arch_param`` and
            ``get_model_flops``.
        CONFIG: reads ``low_flops``, ``high_flops``, ``alpha`` and
            ``loss_penalty``.
        device: device on which the constraint, prior and generated-FLOPs
            tensors are placed.
        val: unused; kept for backward compatibility.
        step: spacing between consecutive target FLOPs values. Defaults to 10,
            the previously hard-coded value.

    Returns:
        ``(evaluate_metric, total_loss, tau)``:
          * ``evaluate_metric`` holds the parallel lists ``"gen_flops"``
            (generated) and ``"true_flops"`` (targets).
          * ``total_loss`` is the summed hardware-constraint loss.
          * ``tau`` is ``scipy.stats.kendalltau`` between the two lists.
    """
    total_loss = 0

    evaluate_metric = {"gen_flops": [], "true_flops": []}
    # Pure evaluation: only Python numbers leave this loop, so no autograd
    # graph is needed for the generator forward passes.
    with torch.no_grad():
        for mac in range(CONFIG.low_flops, CONFIG.high_flops, step):
            hardware_constraint = torch.tensor(mac, dtype=torch.float32)
            hardware_constraint = hardware_constraint.view(-1, 1)
            hardware_constraint = hardware_constraint.to(device)
            constraint_value = hardware_constraint.item()

            prior = prior_pool.get_prior(constraint_value)
            prior = prior.to(device)

            normalize_hardware_constraint = min_max_normalize(
                CONFIG.high_flops, CONFIG.low_flops, hardware_constraint)

            arch_param = generator(prior, normalize_hardware_constraint)
            arch_param = lookup_table.get_validation_arch_param(arch_param)

            gen_mac = lookup_table.get_model_flops(arch_param)
            # Honour the caller-supplied device instead of hard-coding CUDA, so
            # the evaluation also runs on CPU-only or non-default-GPU setups.
            hc_loss = cal_hc_loss(gen_mac.to(device), constraint_value,
                                  CONFIG.alpha, CONFIG.loss_penalty)

            evaluate_metric["gen_flops"].append(gen_mac.item())
            evaluate_metric["true_flops"].append(mac)

            total_loss += hc_loss.item()
    tau, _ = stats.kendalltau(evaluate_metric["gen_flops"],
                              evaluate_metric["true_flops"])

    return evaluate_metric, total_loss, tau
コード例 #3
0
    def generator_validate(self, generator, model, loader, epoch, tau, hardware_constraint=360, arch_param=None, sample=False, info_for_logger=""):
        """Validate the supernet on ``loader`` with one architecture.

        When ``sample`` is true:
          * the generator proposes an architecture for ``hardware_constraint``
            via ``self._get_arch_param``;
          * it is converted with ``lookup_table.get_validation_arch_param``;
          * it is passed through ``self.set_arch_param``.

        Otherwise the supplied ``arch_param`` is used as-is for the MACs lookup.

        NOTE(review): in the non-sample path ``set_arch_param`` is not called,
        so ``model`` is evaluated with whatever architecture it currently has.
        Presumably the caller has already applied ``arch_param``; confirm.

        Args:
            generator: architecture generator, or None (only used when sampling).
            model: supernet to evaluate (put in eval mode).
            loader: iterable of ``(X, y)`` batches.
            epoch: epoch index for logging and the TensorBoard scalar.
            tau, info_for_logger: unused; kept for interface compatibility.
            hardware_constraint: target constraint (number or tensor).
            arch_param: architecture to validate when ``sample`` is false.
            sample: whether to sample the architecture from the generator.

        Returns:
            ``(top1_avg, macs)``: average top-1 accuracy over the loader, and
            the MACs of the evaluated architecture as a Python number.
        """
        if generator is not None:
            generator.eval()
        model.eval()
        start_time = time.time()

        if sample:
            hardware_constraint, arch_param = self._get_arch_param(generator, hardware_constraint, valid=True)
            arch_param = self.lookup_table.get_validation_arch_param(arch_param)
            arch_param, hardware_constraint = self.set_arch_param(generator, model, hardware_constraint=hardware_constraint, arch_param=arch_param)
        else:
            # Only ``.item()`` is read from this tensor below, so it is left on
            # its current device. ``as_tensor`` also accepts an existing tensor
            # without the copy-construct warning ``torch.tensor`` emits.
            hardware_constraint = torch.as_tensor(hardware_constraint)

        macs = self.lookup_table.get_model_macs(arch_param)
        logging.info("Generate model macs : {}".format(macs))

        hc_loss = cal_hc_loss(macs.cuda(), hardware_constraint.item(), self.alpha, self.loss_penalty)

        with torch.no_grad():
            for step, (X, y) in enumerate(loader):
                X, y = X.cuda(non_blocking=True), y.cuda(non_blocking=True)
                N = X.shape[0]

                outs = model(X, True)
                loss = self.criterion(outs, y)

                self._intermediate_stats_logging(outs, y, loss, step, epoch, N, len_loader=len(loader), val_or_train="Valid", hc_losses=hc_loss)

        top1_avg = self.top1.get_avg()
        self._epoch_stats_logging(start_time=start_time, epoch=epoch, val_or_train="val")
        self.writer.add_scalar("train_vs_val/"+"val"+"_hc_", macs, epoch)
        for avg in [self.top1, self.top5, self.losses, self.hc_losses]:
            avg.reset()

        return top1_avg, macs.item()
コード例 #4
0
    def search_train_loop(self, generator):
        """Search phase: train the generator against an accuracy predictor.

        For each epoch from ``self.warmup_epochs`` up to ``self.epochs``:
          1. Train: sample architectures with ``self.set_arch_param``. Score
             them with ``self.accuracy_predictor`` (negated as the CE term) and
             ``self.flops_table`` (hardware-constraint term), then step
             ``self.g_optimizer``.
          2. Valid: build a one-hot architecture with ``calculate_one_hot`` and
             log its predicted top-1.
          3. Evaluate: sweep the target FLOPs over
             ``range(CONFIG.low_macs, CONFIG.high_macs, 10)`` with zeroed
             noise. Sum the hardware-constraint loss and compute Kendall's tau
             between generated and target FLOPs.
          4. Checkpoint:
             * on the best total loss so far;
             * on the best top-1 while total loss < 0.4;
             * every epoch under ``./logs/generator/``.

        ``tau`` and ``self.noise_weight`` decay after every epoch.

        NOTE(review): this method cannot run as written. ``val_loader`` and
        ``model`` are never defined in this scope, and the inner training loop
        has no iteration count. Those spots are marked FIXME because the
        intended values are not visible here.
        """
        self.epochs = self.warmup_epochs + self.search_epochs
        # Training generator
        best_loss = 10000.0
        best_top1 = 0
        tau = 5
        # Search epochs follow the warm-up epochs, so iterate up to the total
        # ``self.epochs``. ``range(warmup_epochs, search_epochs)`` is too short,
        # or empty, whenever warm-up is non-zero.
        for epoch in range(self.warmup_epochs, self.epochs):
            logging.info("Start to train for search epoch {}".format(epoch))
            logging.info("Tau: {}".format(tau))
            # FIXME(review): ``val_loader`` is not defined in this scope; confirm
            # which loader (and model, if required) this call should receive.
            self._generator_training_step(generator,
                                          val_loader,
                                          epoch,
                                          tau,
                                          info_for_logger="_gen_train_step")

            # ================ Train ============================================
            # FIXME(review): ``range()`` needs an iteration count (it raises
            # TypeError as written); the intended number of generator updates
            # per epoch is not visible here.
            for i in range():
                # Training generator
                arch_param, hardware_constraint = self.set_arch_param(
                    generator, tau=tau)

                # ============== evaluation flops ===============================
                gen_flops = self.flops_table.predict_arch_param_efficiency(
                    arch_param)
                hc_loss = cal_hc_loss(gen_flops.cuda(),
                                      hardware_constraint.item(),
                                      self.CONFIG.alpha,
                                      self.CONFIG.loss_penalty)
                # ===============================================================
                self.g_optimizer.zero_grad()

                # ============== predict top1 accuracy ==========================
                top1_avg = self.accuracy_predictor(arch_param)
                # Maximise predicted accuracy by minimising its negation.
                ce_loss = -1 * top1_avg
                # ===============================================================
                loss = ce_loss + hc_loss
                logging.info("HC loss : {}".format(hc_loss))
                loss.backward()

                self.g_optimizer.step()
                self.g_optimizer.zero_grad()
            # ====================================================================

            # ============== Valid ===============================================
            hardware_constraint, arch_param = self._get_arch_param(
                generator, hardware_constraint, valid=True)
            arch_param = self.calculate_one_hot(arch_param)
            # FIXME(review): ``model`` is not defined in this scope.
            arch_param, hardware_constraint = self.set_arch_param(
                generator,
                model,
                hardware_constraint=hardware_constraint,
                arch_param=arch_param)
            # ============== evaluation flops ===============================
            gen_flops = self.flops_table.predict_arch_param_efficiency(
                arch_param)

            hc_loss = cal_hc_loss(gen_flops.cuda(), hardware_constraint.item(),
                                  self.CONFIG.alpha, self.CONFIG.loss_penalty)
            # ===============================================================

            # ============== predict top1 accuracy ==========================
            top1_avg = self.accuracy_predictor(arch_param)
            logging.info("Valid : Top-1 avg : {}".format(top1_avg))
            # ===============================================================

            # ====================================================================

            # ============== Evaluate ============================================
            total_loss = 0
            evaluate_metric = {"gen_flops": [], "true_flops": []}
            for flops in range(self.CONFIG.low_macs, self.CONFIG.high_macs,
                               10):
                hardware_constraint = torch.tensor(flops, dtype=torch.float32)
                hardware_constraint = hardware_constraint.view(-1, 1)
                hardware_constraint = hardware_constraint.to(self.device)

                normalize_hardware_constraint = min_max_normalize(
                    self.CONFIG.high_macs, self.CONFIG.low_macs,
                    hardware_constraint)

                # Noise is zeroed so the sweep feeds the generator no noise.
                noise = torch.randn(*self.backbone.shape)
                noise = noise.to(self.device)
                noise *= 0

                arch_param = generator(self.backbone,
                                       normalize_hardware_constraint, noise)
                # ============== evaluation flops ===============================
                gen_flops = self.flops_table.predict_arch_param_efficiency(
                    arch_param)
                hc_loss = cal_hc_loss(gen_flops.cuda(),
                                      hardware_constraint.item(),
                                      self.CONFIG.alpha,
                                      self.CONFIG.loss_penalty)
                # ===============================================================

                # Store plain numbers (like "true_flops") so kendalltau and the
                # metric saver never receive (possibly CUDA) tensors.
                evaluate_metric["gen_flops"].append(gen_flops.item())
                evaluate_metric["true_flops"].append(flops)

                total_loss += hc_loss.item()
            kendall_tau, _ = stats.kendalltau(evaluate_metric["gen_flops"],
                                              evaluate_metric["true_flops"])
            # ====================================================================

            logging.info("Total loss : {}".format(total_loss))
            if best_loss > total_loss:
                logging.info("Best loss by now: {} Tau : {}.Save model".format(
                    total_loss, kendall_tau))
                best_loss = total_loss
                save_generator_evaluate_metric(
                    evaluate_metric, self.CONFIG.path_to_generator_eval)
                save(generator, self.g_optimizer,
                     self.CONFIG.path_to_save_generator)
            if top1_avg > best_top1 and total_loss < 0.4:
                logging.info(
                    "Best top1-avg by now: {}.Save model".format(top1_avg))
                best_top1 = top1_avg
                save(generator, self.g_optimizer,
                     self.CONFIG.path_to_best_avg_generator)
            save(generator, self.g_optimizer,
                 "./logs/generator/{}.pth".format(total_loss))

            tau *= self.CONFIG.tau_decay
            self.noise_weight = self.noise_weight * self.CONFIG.noise_decay if self.noise_weight > 0.0001 else 0
            logging.info("Noise weight : {}".format(self.noise_weight))
        logging.info("Best loss: {}".format(best_loss))
        # NOTE(review): attribute spelled "fianl"; it must match the CONFIG key.
        save(generator, self.g_optimizer, self.CONFIG.path_to_fianl_generator)
コード例 #5
0
    def generator_validate(self,
                           generator,
                           model,
                           loader,
                           epoch,
                           target_hardware_constraint=None,
                           arch_param=None,
                           info_for_logger=""):
        """Validate the supernet on ``loader`` with one architecture.

        How the architecture is chosen:
          * ``arch_param`` given: it is passed to ``self.set_arch_param(model,
            ...)``. Its FLOPs are compared against ``target_hardware_constraint``
            (a tensor or a plain number), which is required in this case.
          * ``arch_param`` omitted: a target is obtained from
            ``self._get_target_hardware_constraint``, called with the given
            value or with no argument when it is None. The generator then
            proposes an architecture for that target via ``self._get_arch_param``.

        Args:
            generator: architecture generator, or None when ``arch_param`` is given.
            model: supernet to evaluate (put in eval mode).
            loader: iterable of ``(X, y)`` batches.
            epoch: epoch index for logging and the TensorBoard scalar.
            target_hardware_constraint: target constraint, or None.
            arch_param: explicit architecture to validate, or None.
            info_for_logger: unused; kept for interface compatibility.

        Returns:
            ``(top1_avg, flops)``: average top-1 accuracy over ``loader``, and
            the FLOPs of the validated architecture as a Python number.

        Raises:
            ValueError: if ``arch_param`` is given without
                ``target_hardware_constraint``, since there is then no target
                to compute the hardware-constraint loss against.
        """
        if generator is not None:
            generator.eval()
        model.eval()
        start_time = time.time()

        if arch_param is None:
            if target_hardware_constraint is None:
                target_hardware_constraint = self._get_target_hardware_constraint(
                )
            else:
                target_hardware_constraint = self._get_target_hardware_constraint(
                    target_hardware_constraint)
            arch_param = self._get_arch_param(generator,
                                              target_hardware_constraint)
        elif target_hardware_constraint is None:
            # Fail before touching the model rather than crashing later on
            # ``None.item()``.
            raise ValueError(
                "target_hardware_constraint is required when arch_param is given")
        arch_param = self.set_arch_param(
            model, arch_param)  # Validate architecture parameter

        flops = self.lookup_table.get_model_flops(arch_param)
        logging.info("Generate model flops : {}".format(flops))

        # Accept a tensor (as returned by ``_get_target_hardware_constraint``)
        # or a plain number supplied together with ``arch_param``.
        if hasattr(target_hardware_constraint, "item"):
            target_value = target_hardware_constraint.item()
        else:
            target_value = target_hardware_constraint
        # Same device as the validation batches below, instead of hard-coded CUDA.
        hc_loss = cal_hc_loss(flops.to(self.device), target_value,
                              self.CONFIG.alpha, self.CONFIG.loss_penalty)

        with torch.no_grad():
            for step, (X, y) in enumerate(loader):
                X, y = X.to(self.device,
                            non_blocking=True), y.to(self.device,
                                                     non_blocking=True)
                N = X.shape[0]

                outs = model(X, True)
                loss = self.criterion(outs, y)

                self._intermediate_stats_logging(outs,
                                                 y,
                                                 loss,
                                                 step,
                                                 epoch,
                                                 N,
                                                 len_loader=len(loader),
                                                 val_or_train="Valid",
                                                 hc_losses=hc_loss)

        top1_avg = self.top1.get_avg()
        self._epoch_stats_logging(start_time=start_time,
                                  epoch=epoch,
                                  val_or_train="val")
        self.writer.add_scalar("train_vs_val/" + "val" + "_hc_", flops, epoch)
        for avg in [self.top1, self.top5, self.losses, self.hc_losses]:
            avg.reset()

        return top1_avg, flops.item()