Example #1
0
    def _controller_update(self, steps, finished_e_steps, finished_c_steps):
        """Run `steps` controller optimization steps.

        Args:
            steps: number of controller update steps to run.
            finished_e_steps: evaluator steps already finished this epoch
                (for the progress display only).
            finished_c_steps: controller steps already finished this epoch
                (for the progress display only).

        Returns:
            Tuple of (mean controller loss over `steps`, averaged rollout
            performance stats, averaged controller summary stats).
        """
        controller_loss_meter = utils.AverageMeter()
        controller_stat_meters = utils.OrderedStats()
        rollout_stat_meters = utils.OrderedStats()

        self.controller.set_mode("train")
        for i_cont in range(1, steps + 1):
            # single-line progress display, overwritten in place via "\r"
            print("\reva step {}/{} ; controller step {}/{}"\
                  .format(finished_e_steps, self.evaluator_steps,
                          finished_c_steps+i_cont, self.controller_steps),
                  end="")

            rollouts = self.controller.sample(self.controller_samples,
                                              self.rollout_batch_size)
            if self.is_differentiable:
                self.controller.zero_grad()

            # `step_loss` is mutated in place by the backward callback, so the
            # accumulated loss of all samples can be read back after evaluation.
            step_loss = {"_": 0.}
            rollouts = self.evaluator.evaluate_rollouts(
                rollouts,
                is_training=True,
                callback=partial(self._backward_rollout_to_controller,
                                 step_loss=step_loss))
            self.evaluator.update_rollouts(rollouts)

            if self.is_differentiable:
                # differentiable rollout (controller is optimized using a
                # differentiable relaxation): gradients were already
                # accumulated in the callback; adjust the lr and apply the
                # accumulated gradients with `step_current_gradient`.
                controller_loss = step_loss["_"] / self.controller_samples
                if self.controller_samples != 1:
                    # scale down the lr so the effective learning rate stays
                    # unchanged when several samples' gradients accumulate
                    lr_bak = self.controller_optimizer.param_groups[0]["lr"]
                    self.controller_optimizer.param_groups[0]["lr"] \
                        = lr_bak / self.controller_samples
                self.controller.step_current_gradient(
                    self.controller_optimizer)
                if self.controller_samples != 1:
                    self.controller_optimizer.param_groups[0]["lr"] = lr_bak
            else:  # other rollout types: controller handles its own update
                controller_loss = self.controller.step(
                    rollouts, self.controller_optimizer, perf_name="reward")

            # update meters
            controller_loss_meter.update(controller_loss)
            controller_stats = self.controller.summary(rollouts, log=False)
            if controller_stats is not None:
                controller_stat_meters.update(controller_stats)

            # mean performance over the sampled rollouts of this step
            r_stats = OrderedDict()
            for n in rollouts[0].perf:
                r_stats[n] = np.mean([r.perf[n] for r in rollouts])
            rollout_stat_meters.update(r_stats)

        print("\r", end="")

        # BUGFIX: return the loss averaged over all steps. The meter was
        # previously dead code and only the LAST step's loss was returned.
        return controller_loss_meter.avg, rollout_stat_meters.avgs(
        ), controller_stat_meters.avgs()
Example #2
0
    def train_epoch(self, train_queue, model, criterion, optimizer, device,
                    epoch):
        """Train `model` for one epoch over `train_queue`.

        Args:
            train_queue: iterable yielding (inputs, targets) batches.
            model: the network to train.
            criterion: callable (inputs, predictions, targets, model) -> dict
                of named loss components.
            optimizer: optimizer stepping `model`'s parameters.
            device: device the inputs should be moved onto.
            epoch: current epoch index (unused in the body; kept for
                interface parity with the other trainers).

        Returns:
            (top-1 average accuracy, sum of averaged loss components).
        """
        expect(self._is_setup, "trainer.setup should be called first")
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        losses_obj = utils.OrderedStats()
        model.train()

        for step, (inputs, targets) in enumerate(train_queue):
            # BUGFIX: honor the `device` argument instead of always using
            # `self.device` (consistent with the `infer_epoch` methods)
            inputs = inputs.to(device)
            # NOTE(review): `targets` is not moved to `device`; presumably the
            # criterion / acc functions handle that — confirm against callers.

            optimizer.zero_grad()
            predictions = model.forward(inputs)
            losses = criterion(inputs, predictions, targets, model)
            loss = sum(losses.values())
            loss.backward()
            # clip gradients to stabilize training
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)

            n = inputs.size(0)
            losses_obj.update(losses)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.report_every == 0:
                self.logger.info("train %03d %.2f%%; %.2f%%; %s",
                                 step, top1.avg, top5.avg, "; ".join(
                                     ["{}: {:.3f}".format(perf_n, v) \
                                      for perf_n, v in losses_obj.avgs().items()]))
        return top1.avg, sum(losses_obj.avgs().values())
Example #3
0
    def infer_epoch(self, valid_queue, model, criterion, device):
        """Evaluate `model` on `valid_queue` for one epoch (no updates).

        Returns (top-1 avg accuracy, avg loss, averaged objective perfs).
        """
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        model.eval()

        # disable autograd during inference unless explicitly configured otherwise
        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs, target = inputs.to(device), target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                batch_size = inputs.size(0)
                objective_perfs.update(dict(zip(self._perf_names, perfs)),
                                       n=batch_size)
                objs.update(loss.item(), batch_size)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)

                if step % self.report_every == 0:
                    perf_str = "; ".join(
                        "{}: {:.3f}".format(name, val)
                        for name, val in objective_perfs.avgs().items())
                    self.logger.info("valid %03d %e %f %f %s", step, objs.avg,
                                     top1.avg, top5.avg, perf_str)

        return top1.avg, objs.avg, objective_perfs.avgs()
Example #4
0
File: simple.py  Project: zeta1999/aw_nas
    def _evaluator_update(self, steps, finished_e_steps, finished_c_steps):
        """Run `steps` evaluator update steps and return the averaged stats."""
        self.controller.set_mode("eval")
        eva_stat_meters = utils.OrderedStats()

        for i_eva in range(1, steps + 1):
            stats = self.evaluator.update_evaluator(self.controller)
            eva_stat_meters.update(stats)
            # one-line progress display, overwritten in place via "\r"
            stat_str = ";".join(" %.3f" % v
                                for v in eva_stat_meters.avgs().values())
            print("\reva step {}/{} ; controller step {}/{}; {}" \
                  .format(finished_e_steps + i_eva, self.evaluator_steps,
                          finished_c_steps, self.controller_steps, stat_str),
                  end="")
        return eva_stat_meters.avgs()
Example #5
0
    def infer_epoch(self, valid_queue, model, criterion, device):
        """Evaluate `model` on `valid_queue`; the per-batch perfs are reduced
        with the objective's own aggregate functions.

        Returns (top-1 avg accuracy, summed avg loss components, aggregated perfs).
        """
        expect(self._is_setup, "trainer.setup should be called first")
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        losses_obj = utils.OrderedStats()
        all_perfs = []
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, targets) in enumerate(valid_queue):
                inputs = inputs.to(device)
                # targets intentionally not moved onto `device`

                predictions = model.forward(inputs)
                losses = criterion(inputs, predictions, targets, model)
                prec1, prec5 = self._acc_func(inputs, predictions, targets,
                                              model)
                perfs = self._perf_func(inputs, predictions, targets, model)
                all_perfs.append(perfs)
                objective_perfs.update(dict(zip(self._perf_names, perfs)))
                losses_obj.update(losses)
                batch_size = inputs.size(0)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)

                if step % self.report_every == 0:
                    stat_items = (list(objective_perfs.avgs().items())
                                  + list(losses_obj.avgs().items()))
                    self.logger.info(
                        "valid %03d %.2f%%; %.2f%%; %s", step, top1.avg, top5.avg,
                        "; ".join("{}: {:.3f}".format(name, val)
                                  for name, val in stat_items))
        # transpose: list of per-batch perf tuples -> one value list per name
        all_perfs = list(zip(*all_perfs))
        obj_perfs = {
            name: self.objective.aggregate_fn(name, False)(values)
            for name, values in zip(self._perf_names, all_perfs)
        }
        return top1.avg, sum(losses_obj.avgs().values()), obj_perfs
Example #6
0
    def infer_epoch(self, valid_queue, model, criterion, device):
        """Evaluate `model` on `valid_queue`; perfs are reduced with the
        objective's aggregate functions, which supports statistics other than
        the mean (e.g. adversarial-distance median, detection mAP).

        Returns (top-1 avg accuracy, avg loss, aggregated perfs).
        """
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        all_perfs = []
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs, target = inputs.to(device), target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                all_perfs.append(perfs)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                batch_size = inputs.size(0)
                objs.update(loss.item(), batch_size)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)
                del loss  # release the loss tensor as early as possible
                if step % self.report_every == 0:
                    # aggregate with the objective's fn rather than a plain mean
                    transposed = list(zip(*all_perfs))
                    obj_perfs = {
                        name: self.objective.aggregate_fn(name, False)(vals)
                        for name, vals in zip(self._perf_names, transposed)
                    }
                    self.logger.info(
                        "valid %03d %e %f %f %s", step, objs.avg,
                        top1.avg, top5.avg,
                        "; ".join("{}: {:.3f}".format(name, val)
                                  for name, val in obj_perfs.items()))
        all_perfs_by_name = list(zip(*all_perfs))
        obj_perfs = {
            name: self.objective.aggregate_fn(name, False)(vals)
            for name, vals in zip(self._perf_names, all_perfs_by_name)
        }
        return top1.avg, objs.avg, obj_perfs
Example #7
0
    def infer_epoch(self, valid_queue, model, criterion, device):
        """Evaluate a detection model for one epoch; the final metric comes
        from the dataset's `evaluate_detections` on the collected boxes.

        Returns (top-1 avg accuracy, summed cls+loc avg loss, averaged perfs).
        """
        expect(self._is_setup, "trainer.setup should be called first")
        cls_objs = utils.AverageMeter()
        loc_objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, targets) in enumerate(valid_queue):
                inputs = inputs.to(device)
                # targets intentionally not moved onto `device`

                predictions = model.forward(inputs)
                cls_loss, reg_loss = criterion(
                    inputs, predictions, targets, model)

                prec1, prec5 = self._acc_func(inputs, predictions, targets, model)

                perfs = self._perf_func(inputs, predictions, targets, model)
                objective_perfs.update(dict(zip(self._perf_names, perfs)))
                batch_size = inputs.size(0)
                cls_objs.update(cls_loss.item(), batch_size)
                loc_objs.update(reg_loss.item(), batch_size)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)

                if step % self.report_every == 0:
                    self.logger.info(
                        "valid %03d %e %e;  %.2f%%; %.2f%%; %s",
                        step, cls_objs.avg, loc_objs.avg, top1.avg, top5.avg,
                        "; ".join("{}: {:.3f}".format(name, val)
                                  for name, val in objective_perfs.avgs().items()))
        stats = self.dataset.evaluate_detections(self.objective.all_boxes, self.eval_dir)
        self.logger.info("mAP: {}".format(stats[0]))
        return top1.avg, cls_objs.avg + loc_objs.avg, objective_perfs.avgs()
Example #8
0
File: simple.py  Project: zeta1999/aw_nas
    def train(self):  #pylint: disable=too-many-branches
        """Main search loop: interleave evaluator and controller updates.

        Each epoch runs `evaluator_steps` evaluator updates and
        `controller_steps` controller updates — either back-to-back, or
        interleaved when `interleave_controller_every` is set — then logs the
        averaged stats, optionally writes them to tensorboard, saves
        checkpoints, and (every `test_every` epochs) derives archs and tests.
        Ends with `final_save`, which pickles the weights_manager and
        controller directly instead of their state dicts.
        """
        assert self.is_setup, "Must call `trainer.setup` method before calling `trainer.train`."

        # Decide the interleaving layout:
        # - interleaved: `controller_steps` outer iterations, each running
        #   `interleave_controller_every` evaluator steps + 1 controller step
        # - plain: one outer iteration running all evaluator steps first,
        #   then all controller steps
        if self.interleave_controller_every is not None:
            inter_steps = self.controller_steps
            evaluator_steps = self.interleave_controller_every
            controller_steps = 1
        else:
            inter_steps = 1
            evaluator_steps = self.evaluator_steps
            controller_steps = self.controller_steps

        for epoch in range(self.last_epoch + 1, self.epochs + 1):
            c_loss_meter = utils.AverageMeter()
            rollout_stat_meters = utils.OrderedStats(
            )  # rollout performance stats from evaluator
            c_stat_meters = utils.OrderedStats()  # other stats from controller
            eva_stat_meters = utils.OrderedStats(
            )  # other stats from `evaluator.update_evaluator`

            self.epoch = epoch  # this is redundant as Component.on_epoch_start also set this

            # call `on_epoch_start` of sub-components
            # also schedule values and optimizer learning rates
            self.on_epoch_start(epoch)

            # per-epoch progress counters, passed down for progress display
            finished_e_steps = 0
            finished_c_steps = 0
            for i_inter in range(1, inter_steps +
                                 1):  # interleave mepa/controller training
                # meta parameter training
                if evaluator_steps > 0:
                    e_stats = self._evaluator_update(evaluator_steps,
                                                     finished_e_steps,
                                                     finished_c_steps)
                    eva_stat_meters.update(e_stats)
                    finished_e_steps += evaluator_steps

                # controller training starts at `controller_train_begin` and
                # only runs every `controller_train_every` epochs
                if epoch >= self.controller_train_begin and \
                   epoch % self.controller_train_every == 0 and controller_steps > 0:
                    # controller training
                    c_loss, rollout_stats, c_stats \
                        = self._controller_update(controller_steps,
                                                  finished_e_steps, finished_c_steps)
                    # update meters
                    if c_loss is not None:
                        c_loss_meter.update(c_loss)
                    if rollout_stats is not None:
                        rollout_stat_meters.update(rollout_stats)
                    if c_stats is not None:
                        c_stat_meters.update(c_stats)

                    finished_c_steps += controller_steps

                if self.interleave_report_every and i_inter % self.interleave_report_every == 0:
                    # log for every `interleave_report_every` interleaving steps
                    self.logger.info("(inter step %3d): "
                                     "evaluator (%3d/%3d) %s ; "
                                     "controller (%3d/%3d) %s",
                                     i_inter, finished_e_steps, self.evaluator_steps,
                                     "; ".join(
                                         ["{}: {:.3f}".format(n, v) \
                                          for n, v in eva_stat_meters.avgs().items()]),
                                     finished_c_steps, self.controller_steps,
                                     "" if not rollout_stat_meters else "; ".join(
                                         ["{}: {:.3f}".format(n, v) \
                                          for n, v in rollout_stat_meters.avgs().items()]))

            # log summary information of this epoch
            if eva_stat_meters:
                self.logger.info("Epoch %3d: [evaluator update] %s", epoch,
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in eva_stat_meters.avgs().items()]))
            if rollout_stat_meters:
                self.logger.info("Epoch %3d: [controller update] controller loss: %.3f ; "
                                 "rollout performance: %s", epoch, c_loss_meter.avg,
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in rollout_stat_meters.avgs().items()]))
            if c_stat_meters:
                self.logger.info("[controller stats] %s", \
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in c_stat_meters.avgs().items()]))

            # maybe write tensorboard info (tag names cannot contain spaces)
            if not self.writer.is_none():
                if eva_stat_meters:
                    for n, meter in eva_stat_meters.items():
                        self.writer.add_scalar(
                            "evaluator_update/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if rollout_stat_meters:
                    for n, meter in rollout_stat_meters.items():
                        self.writer.add_scalar(
                            "controller_update/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if c_stat_meters:
                    for n, meter in c_stat_meters.items():
                        self.writer.add_scalar(
                            "controller_stats/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if not c_loss_meter.is_empty():
                    self.writer.add_scalar("controller_loss", c_loss_meter.avg,
                                           epoch)

            # maybe save checkpoints
            self.maybe_save()

            # maybe derive archs and test
            if self.test_every and self.epoch % self.test_every == 0:
                self.test()

            self.on_epoch_end(epoch)  # call `on_epoch_end` of sub-components

        # `final_save` pickle dump the weights_manager and controller directly,
        # instead of the state dict
        self.final_save()