def _controller_update(self, steps, finished_e_steps, finished_c_steps):
    controller_loss_meter = utils.AverageMeter()
    controller_stat_meters = utils.OrderedStats()
    rollout_stat_meters = utils.OrderedStats()
    self.controller.set_mode("train")
    for i_cont in range(1, steps + 1):
        print("\reva step {}/{} ; controller step {}/{}"
              .format(finished_e_steps, self.evaluator_steps,
                      finished_c_steps + i_cont, self.controller_steps), end="")
        rollouts = self.controller.sample(self.controller_samples,
                                          self.rollout_batch_size)
        # if self.rollout_type == "differentiable":
        if self.is_differentiable:
            self.controller.zero_grad()
        step_loss = {"_": 0.}
        rollouts = self.evaluator.evaluate_rollouts(
            rollouts, is_training=True,
            callback=partial(self._backward_rollout_to_controller,
                             step_loss=step_loss))
        self.evaluator.update_rollouts(rollouts)
        # if self.rollout_type == "differentiable":
        if self.is_differentiable:
            # differentiable rollout (controller is optimized using differentiable relaxation):
            # adjust the lr and call `step_current_gradient`
            # (update using the accumulated gradients)
            controller_loss = step_loss["_"] / self.controller_samples
            if self.controller_samples != 1:
                # adjust the lr to keep the effective learning rate unchanged
                lr_bak = self.controller_optimizer.param_groups[0]["lr"]
                self.controller_optimizer.param_groups[0]["lr"] \
                    = lr_bak / self.controller_samples
            self.controller.step_current_gradient(self.controller_optimizer)
            if self.controller_samples != 1:
                self.controller_optimizer.param_groups[0]["lr"] = lr_bak
        else:
            # other rollout types
            controller_loss = self.controller.step(
                rollouts, self.controller_optimizer, perf_name="reward")
        # update meters
        controller_loss_meter.update(controller_loss)
        controller_stats = self.controller.summary(rollouts, log=False)
        if controller_stats is not None:
            controller_stat_meters.update(controller_stats)
        r_stats = OrderedDict()
        for n in rollouts[0].perf:
            r_stats[n] = np.mean([r.perf[n] for r in rollouts])
        rollout_stat_meters.update(r_stats)
    print("\r", end="")
    return controller_loss, rollout_stat_meters.avgs(), controller_stat_meters.avgs()
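# The lr rescaling in `_controller_update` keeps the effective step size
# unchanged when gradients from several sampled rollouts are accumulated
# (summed) before a single optimizer step. A minimal, self-contained sketch of
# the trick, with a toy parameter and loss (not the controller's actual
# objective):
import torch

param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.SGD([param], lr=0.1)
n_samples = 4  # plays the role of `self.controller_samples`

for _ in range(n_samples):
    loss = (param - 1.0).pow(2).sum()  # toy per-sample loss
    loss.backward()                    # `.grad` accumulates the sum over samples

# dividing the lr by `n_samples` turns the summed gradient into a mean update
lr_bak = optimizer.param_groups[0]["lr"]
optimizer.param_groups[0]["lr"] = lr_bak / n_samples
optimizer.step()
optimizer.param_groups[0]["lr"] = lr_bak  # restore the original lr
optimizer.zero_grad()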
def train_epoch(self, train_queue, model, criterion, optimizer, device, epoch):
    expect(self._is_setup, "trainer.setup should be called first")
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses_obj = utils.OrderedStats()
    model.train()
    for step, (inputs, targets) in enumerate(train_queue):
        inputs = inputs.to(device)
        optimizer.zero_grad()
        predictions = model.forward(inputs)
        losses = criterion(inputs, predictions, targets, model)
        loss = sum(losses.values())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
        optimizer.step()
        prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
        n = inputs.size(0)
        losses_obj.update(losses)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        if step % self.report_every == 0:
            self.logger.info(
                "train %03d %.2f%%; %.2f%%; %s", step, top1.avg, top5.avg,
                "; ".join(["{}: {:.3f}".format(perf_n, v)
                           for perf_n, v in losses_obj.avgs().items()]))
    return top1.avg, sum(losses_obj.avgs().values())
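# `train_epoch` assumes `criterion(inputs, predictions, targets, model)`
# returns a dict of named scalar losses that are summed for backprop and
# tracked per-name in `losses_obj`. A hypothetical criterion following that
# convention (the `ce`/`l2_reg` terms are illustrative only, not this repo's
# actual objective):
from collections import OrderedDict

import torch.nn.functional as F

def example_criterion(inputs, predictions, targets, model):
    """Return named loss terms; the trainer sums `losses.values()`."""
    ce = F.cross_entropy(predictions, targets)
    # toy weight-decay-style term, just to show a second named loss
    l2_reg = 1e-4 * sum(p.pow(2).sum() for p in model.parameters())
    return OrderedDict([("ce", ce), ("l2_reg", l2_reg)])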
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext
    with context():
        for step, (inputs, target) in enumerate(valid_queue):
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target)
            perfs = self._perf_func(inputs, logits, target, model)
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info(
                    "valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                    "; ".join(["{}: {:.3f}".format(perf_n, v)
                               for perf_n, v in objective_perfs.avgs().items()]))
    return top1.avg, objs.avg, objective_perfs.avgs()
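# `context = torch.no_grad if self.eval_no_grad else nullcontext` relies on
# `contextlib.nullcontext`, which is in the stdlib from Python 3.7 on. If this
# code ever needs to run on an older interpreter, a drop-in shim could be:
try:
    from contextlib import nullcontext
except ImportError:  # Python < 3.7
    from contextlib import contextmanager

    @contextmanager
    def nullcontext(enter_result=None):
        # no-op context manager mirroring `contextlib.nullcontext`
        yield enter_result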
def _evaluator_update(self, steps, finished_e_steps, finished_c_steps):
    self.controller.set_mode("eval")
    eva_stat_meters = utils.OrderedStats()
    for i_eva in range(1, steps + 1):
        # mepa stands for meta parameter
        e_stats = self.evaluator.update_evaluator(self.controller)
        eva_stat_meters.update(e_stats)
        print("\reva step {}/{} ; controller step {}/{}; {}"
              .format(finished_e_steps + i_eva, self.evaluator_steps,
                      finished_c_steps, self.controller_steps,
                      ";".join([" %.3f" % v
                                for _, v in eva_stat_meters.avgs().items()])),
              end="")
    return eva_stat_meters.avgs()
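# `AverageMeter` and `OrderedStats` are used throughout these trainers but not
# shown here. A minimal sketch of the interfaces the code above assumes
# (`update`/`avg`/`is_empty`, and `update`/`avgs`/`items` plus truthiness);
# the real `utils` implementations may differ:
from collections import OrderedDict

class AverageMeter:
    """Running mean over scalar updates."""
    def __init__(self):
        self.sum, self.cnt = 0.0, 0

    def is_empty(self):
        return self.cnt == 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n

    @property
    def avg(self):
        return self.sum / self.cnt if self.cnt else 0.0

class OrderedStats:
    """Per-name running means, preserving insertion order."""
    def __init__(self):
        self.meters = OrderedDict()

    def __bool__(self):
        return bool(self.meters)

    def items(self):
        return self.meters.items()

    def update(self, stats, n=1):
        for k, v in stats.items():
            self.meters.setdefault(k, AverageMeter()).update(v, n)

    def avgs(self):
        return OrderedDict((k, m.avg) for k, m in self.meters.items())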
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    losses_obj = utils.OrderedStats()
    all_perfs = []
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext
    with context():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs = inputs.to(device)
            # targets = targets.to(device)
            predictions = model.forward(inputs)
            losses = criterion(inputs, predictions, targets, model)
            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
            perfs = self._perf_func(inputs, predictions, targets, model)
            all_perfs.append(perfs)
            objective_perfs.update(dict(zip(self._perf_names, perfs)))
            losses_obj.update(losses)
            n = inputs.size(0)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info(
                    "valid %03d %.2f%%; %.2f%%; %s", step, top1.avg, top5.avg,
                    "; ".join(["{}: {:.3f}".format(perf_n, v)
                               for perf_n, v in
                               list(objective_perfs.avgs().items()) +
                               list(losses_obj.avgs().items())]))
    all_perfs = list(zip(*all_perfs))
    obj_perfs = {
        k: self.objective.aggregate_fn(k, False)(v)
        for k, v in zip(self._perf_names, all_perfs)
    }
    return top1.avg, sum(losses_obj.avgs().values()), obj_perfs
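# At the end of this `infer_epoch`, the per-batch perf tuples in `all_perfs`
# are transposed into per-name sequences and reduced by
# `self.objective.aggregate_fn(name, is_training)`. A toy illustration of the
# transpose-and-aggregate step (the names, values, and mean/median choices are
# made up):
import numpy as np

perf_names = ["acc", "adv_dist"]                     # hypothetical perf names
all_perfs = [(0.91, 1.2), (0.88, 0.9), (0.93, 1.5)]  # one tuple per batch

all_perfs_by_name = list(zip(*all_perfs))  # [(0.91, 0.88, 0.93), (1.2, 0.9, 1.5)]
aggregate_fns = {"acc": np.mean, "adv_dist": np.median}
obj_perfs = {k: aggregate_fns[k](v)
             for k, v in zip(perf_names, all_perfs_by_name)}
print(obj_perfs)  # {'acc': 0.9066..., 'adv_dist': 1.2}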
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    all_perfs = []
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext
    with context():
        for step, (inputs, target) in enumerate(valid_queue):
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target)
            perfs = self._perf_func(inputs, logits, target, model)
            all_perfs.append(perfs)
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            # objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            del loss
            if step % self.report_every == 0:
                all_perfs_by_name = list(zip(*all_perfs))
                # support using the objective's aggregate fn, for stat methods other than mean
                # e.g., adversarial distance median; detection mAP (see det_trainer.py)
                obj_perfs = {
                    k: self.objective.aggregate_fn(k, False)(v)
                    for k, v in zip(self._perf_names, all_perfs_by_name)
                }
                self.logger.info(
                    "valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                    "; ".join(["{}: {:.3f}".format(perf_n, v)
                               for perf_n, v in obj_perfs.items()]))
    all_perfs_by_name = list(zip(*all_perfs))
    obj_perfs = {
        k: self.objective.aggregate_fn(k, False)(v)
        for k, v in zip(self._perf_names, all_perfs_by_name)
    }
    return top1.avg, objs.avg, obj_perfs
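# Both classification `infer_epoch` variants above call
# `utils.accuracy(logits, target, topk=(1, 5))`. A minimal top-k accuracy that
# matches how its results are consumed (percentages, `.item()`); the repo's
# actual helper may differ in details:
import torch

def accuracy(logits, target, topk=(1,)):
    """Return top-k accuracies (in percent), one per k in `topk`."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = logits.topk(maxk, dim=1)          # (N, maxk) predicted classes
    correct = pred.t().eq(target.view(1, -1))   # (maxk, N) hit mask
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res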
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    cls_objs = utils.AverageMeter()
    loc_objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext
    with context():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs = inputs.to(device)
            # targets = targets.to(device)
            predictions = model.forward(inputs)
            classification_loss, regression_loss = criterion(
                inputs, predictions, targets, model)
            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
            perfs = self._perf_func(inputs, predictions, targets, model)
            objective_perfs.update(dict(zip(self._perf_names, perfs)))
            n = inputs.size(0)
            cls_objs.update(classification_loss.item(), n)
            loc_objs.update(regression_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info(
                    "valid %03d %e %e; %.2f%%; %.2f%%; %s", step,
                    cls_objs.avg, loc_objs.avg, top1.avg, top5.avg,
                    "; ".join(["{}: {:.3f}".format(perf_n, v)
                               for perf_n, v in objective_perfs.avgs().items()]))
    stats = self.dataset.evaluate_detections(self.objective.all_boxes, self.eval_dir)
    self.logger.info("mAP: {}".format(stats[0]))
    return top1.avg, cls_objs.avg + loc_objs.avg, objective_perfs.avgs()
def train(self):  # pylint: disable=too-many-branches
    assert self.is_setup, "Must call `trainer.setup` method before calling `trainer.train`."
    if self.interleave_controller_every is not None:
        inter_steps = self.controller_steps
        evaluator_steps = self.interleave_controller_every
        controller_steps = 1
    else:
        inter_steps = 1
        evaluator_steps = self.evaluator_steps
        controller_steps = self.controller_steps
    for epoch in range(self.last_epoch + 1, self.epochs + 1):
        c_loss_meter = utils.AverageMeter()
        rollout_stat_meters = utils.OrderedStats()  # rollout performance stats from evaluator
        c_stat_meters = utils.OrderedStats()  # other stats from controller
        eva_stat_meters = utils.OrderedStats()  # other stats from `evaluator.update_evaluator`

        # this is redundant as Component.on_epoch_start also sets this
        self.epoch = epoch
        # call `on_epoch_start` of sub-components;
        # also schedule values and optimizer learning rates
        self.on_epoch_start(epoch)

        finished_e_steps = 0
        finished_c_steps = 0
        for i_inter in range(1, inter_steps + 1):
            # interleave mepa/controller training
            # meta parameter training
            if evaluator_steps > 0:
                e_stats = self._evaluator_update(evaluator_steps,
                                                 finished_e_steps, finished_c_steps)
                eva_stat_meters.update(e_stats)
                finished_e_steps += evaluator_steps
            if epoch >= self.controller_train_begin and \
                    epoch % self.controller_train_every == 0 and controller_steps > 0:
                # controller training
                c_loss, rollout_stats, c_stats \
                    = self._controller_update(controller_steps,
                                              finished_e_steps, finished_c_steps)
                # update meters
                if c_loss is not None:
                    c_loss_meter.update(c_loss)
                if rollout_stats is not None:
                    rollout_stat_meters.update(rollout_stats)
                if c_stats is not None:
                    c_stat_meters.update(c_stats)
                finished_c_steps += controller_steps
            if self.interleave_report_every and i_inter % self.interleave_report_every == 0:
                # log once every `interleave_report_every` interleaving steps
                self.logger.info(
                    "(inter step %3d): "
                    "evaluator (%3d/%3d) %s ; "
                    "controller (%3d/%3d) %s",
                    i_inter, finished_e_steps, self.evaluator_steps,
                    "; ".join(["{}: {:.3f}".format(n, v)
                               for n, v in eva_stat_meters.avgs().items()]),
                    finished_c_steps, self.controller_steps,
                    "" if not rollout_stat_meters else "; ".join(
                        ["{}: {:.3f}".format(n, v)
                         for n, v in rollout_stat_meters.avgs().items()]))

        # log information of this epoch
        if eva_stat_meters:
            self.logger.info("Epoch %3d: [evaluator update] %s", epoch,
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in eva_stat_meters.avgs().items()]))
        if rollout_stat_meters:
            self.logger.info("Epoch %3d: [controller update] controller loss: %.3f ; "
                             "rollout performance: %s", epoch, c_loss_meter.avg,
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in rollout_stat_meters.avgs().items()]))
        if c_stat_meters:
            self.logger.info("[controller stats] %s",
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in c_stat_meters.avgs().items()]))

        # maybe write tensorboard info
        if not self.writer.is_none():
            if eva_stat_meters:
                for n, meter in eva_stat_meters.items():
                    self.writer.add_scalar(
                        "evaluator_update/{}".format(n.replace(" ", "-")),
                        meter.avg, epoch)
            if rollout_stat_meters:
                for n, meter in rollout_stat_meters.items():
                    self.writer.add_scalar(
                        "controller_update/{}".format(n.replace(" ", "-")),
                        meter.avg, epoch)
            if c_stat_meters:
                for n, meter in c_stat_meters.items():
                    self.writer.add_scalar(
                        "controller_stats/{}".format(n.replace(" ", "-")),
                        meter.avg, epoch)
            if not c_loss_meter.is_empty():
                self.writer.add_scalar("controller_loss", c_loss_meter.avg, epoch)

        # maybe save checkpoints
        self.maybe_save()

        # maybe derive archs and test
        if self.test_every and self.epoch % self.test_every == 0:
            self.test()

        self.on_epoch_end(epoch)  # call `on_epoch_end` of sub-components

    # `final_save` pickle-dumps the weights_manager and controller directly,
    # instead of their state dicts
    self.final_save()
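# The branch at the top of `train` sets up the interleaving schedule: with
# `interleave_controller_every = k`, each of `controller_steps` outer steps
# runs k evaluator updates followed by one controller update; otherwise a
# single outer step runs all evaluator updates, then all controller updates.
# A toy rendering of the two schedules (hypothetical step counts):
def schedule(controller_steps, evaluator_steps, interleave_controller_every=None):
    """List the order of updates `train` would perform in one epoch."""
    if interleave_controller_every is not None:
        inter_steps = controller_steps
        eva_per_inter, cont_per_inter = interleave_controller_every, 1
    else:
        inter_steps = 1
        eva_per_inter, cont_per_inter = evaluator_steps, controller_steps
    order = []
    for _ in range(inter_steps):
        order += ["eva"] * eva_per_inter + ["cont"] * cont_per_inter
    return order

print(schedule(2, 4))     # ['eva', 'eva', 'eva', 'eva', 'cont', 'cont']
print(schedule(2, 4, 2))  # ['eva', 'eva', 'cont', 'eva', 'eva', 'cont']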