def train_epoch(self, train_queue, model, criterion, optimizer, device, epoch):
    expect(self._is_setup, "trainer.setup should be called first")
    cls_objs = utils.AverageMeter()
    loc_objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    model.train()

    for step, (inputs, targets) in enumerate(train_queue):
        inputs = inputs.to(self.device)
        # targets = targets.to(self.device)
        optimizer.zero_grad()
        predictions = model.forward(inputs)
        classification_loss, regression_loss = criterion(
            inputs, predictions, targets, model)
        loss = classification_loss + regression_loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
        optimizer.step()
        prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
        n = inputs.size(0)
        cls_objs.update(classification_loss.item(), n)
        loc_objs.update(regression_loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        if step % self.report_every == 0:
            self.logger.info("train %03d %.3f %.3f; %.2f%%; %.2f%%",
                             step, cls_objs.avg, loc_objs.avg, top1.avg, top5.avg)
    return top1.avg, cls_objs.avg + loc_objs.avg
def train_epoch(self, train_queue, model, criterion, optimizer, device, epoch):
    expect(self._is_setup, "trainer.setup should be called first")
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses_obj = utils.OrderedStats()
    model.train()

    for step, (inputs, targets) in enumerate(train_queue):
        inputs = inputs.to(self.device)
        optimizer.zero_grad()
        predictions = model.forward(inputs)
        losses = criterion(inputs, predictions, targets, model)
        loss = sum(losses.values())
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
        optimizer.step()
        prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
        n = inputs.size(0)
        losses_obj.update(losses)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        if step % self.report_every == 0:
            self.logger.info("train %03d %.2f%%; %.2f%%; %s",
                             step, top1.avg, top5.avg,
                             "; ".join(["{}: {:.3f}".format(perf_n, v)
                                        for perf_n, v in losses_obj.avgs().items()]))
    return top1.avg, sum(losses_obj.avgs().values())
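# --- Illustrative sketch (not part of the trainers above) ---
# The helpers `utils.AverageMeter` and `utils.OrderedStats` used throughout this
# listing are assumed to behave roughly as below. This is a minimal sketch of
# the assumed interface (update/avg/avgs/is_empty/items), not necessarily the
# project's actual `utils` implementation.
from collections import OrderedDict


class AverageMeter:
    """Track the running (count-weighted) average of a scalar."""

    def __init__(self):
        self.sum = 0.0
        self.cnt = 0
        self.avg = 0.0

    def is_empty(self):
        return self.cnt == 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


class OrderedStats:
    """Keep one AverageMeter per named statistic, preserving insertion order."""

    def __init__(self):
        self.meters = None

    def __bool__(self):
        # truthy once at least one update has been recorded
        return self.meters is not None

    def update(self, stats, n=1):
        if self.meters is None:
            self.meters = OrderedDict((name, AverageMeter()) for name in stats)
        for name, val in stats.items():
            self.meters[name].update(val, n)

    def avgs(self):
        if self.meters is None:
            return None
        return OrderedDict((name, meter.avg) for name, meter in self.meters.items())

    def items(self):
        return self.meters.items() if self.meters is not None else iter(())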
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext

    with context():
        for step, (inputs, target) in enumerate(valid_queue):
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target)
            perfs = self._perf_func(inputs, logits, target, model)
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info("valid %03d %e %f %f %s",
                                 step, objs.avg, top1.avg, top5.avg,
                                 "; ".join(["{}: {:.3f}".format(perf_n, v)
                                            for perf_n, v in objective_perfs.avgs().items()]))
    return top1.avg, objs.avg, objective_perfs.avgs()
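# The `nullcontext` used when `eval_no_grad` is False is available in
# `contextlib` from Python 3.7 on. A fallback sketch for older interpreters
# (an assumption about how it could be provided, not necessarily how this
# codebase imports it):
try:
    from contextlib import nullcontext
except ImportError:
    from contextlib import contextmanager

    @contextmanager
    def nullcontext(enter_result=None):
        # no-op context manager, mirroring contextlib.nullcontext
        yield enter_result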
def train_epoch(self, train_queue, model, criterion, optimizer, device, epoch):
    expect(self._is_setup, "trainer.setup should be called first")
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    model.train()

    for step, (inputs, target) in enumerate(train_queue):
        inputs = inputs.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        if self.auxiliary_head:
            # assume the model returns two logits in train mode
            logits, logits_aux = model(inputs)
            loss = self._obj_loss(
                inputs, logits, target, model,
                add_evaluator_regularization=self.add_regularization)
            loss_aux = criterion(logits_aux, target)
            loss += self.auxiliary_weight * loss_aux
        else:
            logits = model(inputs)
            loss = self._obj_loss(
                inputs, logits, target, model,
                add_evaluator_regularization=self.add_regularization)
        # torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
        optimizer.step()
        prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
        n = inputs.size(0)
        objs.update(loss.item(), n)
        top1.update(prec1.item(), n)
        top5.update(prec5.item(), n)
        del loss
        if step % self.report_every == 0:
            self.logger.info("train %03d %.3f; %.2f%%; %.2f%%",
                             step, objs.avg, top1.avg, top5.avg)
    return top1.avg, objs.avg
def _controller_update(self, steps, finished_e_steps, finished_c_steps):
    controller_loss_meter = utils.AverageMeter()
    controller_stat_meters = utils.OrderedStats()
    rollout_stat_meters = utils.OrderedStats()
    self.controller.set_mode("train")

    for i_cont in range(1, steps + 1):
        print("\reva step {}/{} ; controller step {}/{}"
              .format(finished_e_steps, self.evaluator_steps,
                      finished_c_steps + i_cont, self.controller_steps),
              end="")
        rollouts = self.controller.sample(self.controller_samples,
                                          self.rollout_batch_size)

        # if self.rollout_type == "differentiable":
        if self.is_differentiable:
            self.controller.zero_grad()

        step_loss = {"_": 0.}
        rollouts = self.evaluator.evaluate_rollouts(
            rollouts, is_training=True,
            callback=partial(self._backward_rollout_to_controller,
                             step_loss=step_loss))
        self.evaluator.update_rollouts(rollouts)

        # if self.rollout_type == "differentiable":
        if self.is_differentiable:
            # differentiable rollout (controller is optimized using differentiable relaxation):
            # adjust the lr and call `step_current_gradient`
            # (update using the accumulated gradients)
            controller_loss = step_loss["_"] / self.controller_samples
            if self.controller_samples != 1:
                # adjust the lr to keep the effective learning rate unchanged
                lr_bak = self.controller_optimizer.param_groups[0]["lr"]
                self.controller_optimizer.param_groups[0]["lr"] \
                    = lr_bak / self.controller_samples
            self.controller.step_current_gradient(self.controller_optimizer)
            if self.controller_samples != 1:
                self.controller_optimizer.param_groups[0]["lr"] = lr_bak
        else:
            # other rollout types
            controller_loss = self.controller.step(
                rollouts, self.controller_optimizer, perf_name="reward")

        # update meters
        controller_loss_meter.update(controller_loss)
        controller_stats = self.controller.summary(rollouts, log=False)
        if controller_stats is not None:
            controller_stat_meters.update(controller_stats)
        r_stats = OrderedDict()
        for n in rollouts[0].perf:
            r_stats[n] = np.mean([r.perf[n] for r in rollouts])
        rollout_stat_meters.update(r_stats)

    print("\r", end="")
    return controller_loss, rollout_stat_meters.avgs(), controller_stat_meters.avgs()
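# A tiny self-contained check (illustrative only; plain SGD assumed, which may
# differ from the actual controller optimizer) of why the temporary lr rescaling
# in `_controller_update` keeps the effective step size unchanged: gradients
# accumulated over `controller_samples` rollouts are roughly `controller_samples`
# times larger, so stepping at lr / controller_samples moves the parameters the
# same amount as a single rollout at the original lr.
def _demo_accumulated_gradient_lr_rescaling():
    import torch

    param = torch.zeros(1, requires_grad=True)
    opt = torch.optim.SGD([param], lr=0.1)
    sample_grads = [torch.ones(1), torch.ones(1)]  # two "rollouts" with unit gradients

    opt.param_groups[0]["lr"] = 0.1 / len(sample_grads)
    param.grad = sample_grads[0] + sample_grads[1]  # accumulated gradient = 2.0
    opt.step()
    opt.param_groups[0]["lr"] = 0.1                 # restore the original lr

    return param.item()  # -0.1, same as one unit-gradient step at lr=0.1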
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    all_perfs = []
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext

    with context():
        for step, (inputs, target) in enumerate(valid_queue):
            inputs = inputs.to(device)
            target = target.to(device)
            logits = model(inputs)
            loss = criterion(logits, target)
            perfs = self._perf_func(inputs, logits, target, model)
            all_perfs.append(perfs)
            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            # objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            del loss
            if step % self.report_every == 0:
                all_perfs_by_name = list(zip(*all_perfs))
                # support use objective aggregate fn, for stat method other than mean
                # e.g., adversarial distance median; detection mAP (see det_trainer.py)
                obj_perfs = {
                    k: self.objective.aggregate_fn(k, False)(v)
                    for k, v in zip(self._perf_names, all_perfs_by_name)
                }
                self.logger.info("valid %03d %e %f %f %s",
                                 step, objs.avg, top1.avg, top5.avg,
                                 "; ".join(["{}: {:.3f}".format(perf_n, v)
                                            # for perf_n, v in objective_perfs.avgs().items()
                                            for perf_n, v in obj_perfs.items()]))
    all_perfs_by_name = list(zip(*all_perfs))
    obj_perfs = {
        k: self.objective.aggregate_fn(k, False)(v)
        for k, v in zip(self._perf_names, all_perfs_by_name)
    }
    return top1.avg, objs.avg, obj_perfs
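# Small illustration (not from the trainer code) of the `list(zip(*all_perfs))`
# regrouping used above before applying `aggregate_fn`: per-step perf tuples are
# transposed into one sequence per perf name, so each objective can be
# aggregated with its own statistic (mean, median, detection mAP, ...).
def _demo_perf_regrouping():
    all_perfs = [(0.1, 0.5), (0.2, 0.7), (0.3, 0.6)]  # 3 steps, 2 perf names
    all_perfs_by_name = list(zip(*all_perfs))
    return all_perfs_by_name  # [(0.1, 0.2, 0.3), (0.5, 0.7, 0.6)]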
def train_listwise(train_data, model, epoch, args, arch_network_type):
    objs = utils.AverageMeter()
    model.train()
    num_data = len(train_data)
    idx_list = np.arange(num_data)
    num_batches = getattr(
        args, "num_batch_per_epoch",
        int(num_data / (args.batch_size * args.list_length) * args.max_compare_ratio))
    logging.info("Number of batches: {:d}".format(num_batches))
    update_batch_n = getattr(args, "update_batch_n", 1)
    listwise_compare = getattr(args, "listwise_compare", False)
    if listwise_compare:
        assert args.list_length == 2 and update_batch_n == 1
    model.optimizer.zero_grad()

    for step in range(1, num_batches + 1):
        if getattr(args, "bs_replace", False):
            idxes = np.array([
                np.random.choice(idx_list, size=(args.list_length,), replace=False)
                for _ in range(args.batch_size)
            ])
        else:
            idxes = np.random.choice(idx_list,
                                     size=(args.batch_size, args.list_length),
                                     replace=False)
        flat_idxes = idxes.reshape(-1)
        archs, accs, _ = zip(*[train_data[idx] for idx in flat_idxes])
        archs = np.array(archs).reshape((args.batch_size, args.list_length, -1))
        accs = np.array(accs).reshape((args.batch_size, args.list_length))
        # accs[np.arange(0, args.batch_size)[:, None], np.argsort(accs, axis=1)[:, ::-1]]
        if update_batch_n == 1:
            if listwise_compare:
                loss = model.update_compare(archs[:, 0, :], archs[:, 1, :],
                                            accs[:, 1] > accs[:, 0])
            else:
                loss = model.update_argsort(
                    archs, np.argsort(accs, axis=1)[:, ::-1],
                    first_n=getattr(args, "score_list_length", None))
        else:
            loss = model.update_argsort(
                archs, np.argsort(accs, axis=1)[:, ::-1],
                first_n=getattr(args, "score_list_length", None),
                accumulate_only=True)
        if step % update_batch_n == 0:
            model.optimizer.step()
            model.optimizer.zero_grad()
        if arch_network_type != "random_forest":
            objs.update(loss, args.batch_size)
        if step % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                epoch, step, num_batches, objs.avg))
    return objs.avg
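# Toy illustration (not part of the training script) of the ranking target
# passed to `update_argsort` above: `np.argsort(accs, axis=1)[:, ::-1]` gives,
# for each list in the batch, the architecture positions ordered from best to
# worst accuracy.
def _demo_argsort_ranking_target():
    import numpy as np

    accs = np.array([[0.91, 0.95, 0.89],
                     [0.70, 0.60, 0.80]])
    return np.argsort(accs, axis=1)[:, ::-1]
    # array([[1, 0, 2],
    #        [2, 0, 1]])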
def train_epoch(logger, train_loader, model, epoch, cfg):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()

    for step, (archs, accs) in enumerate(train_loader):
        archs = np.array(archs)
        accs = np.array(accs)
        n = len(archs)
        if cfg["compare"]:
            n_max_pairs = int(cfg["max_compare_ratio"] * n)
            acc_diff = np.array(accs)[:, None] - np.array(accs)
            acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
            ex_thresh_inds = np.where(acc_abs_diff_matrix > cfg["compare_threshold"])
            ex_thresh_num = len(ex_thresh_inds[0])
            if ex_thresh_num > n_max_pairs:
                keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                             n_max_pairs, replace=False)
                ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                  ex_thresh_inds[1][keep_inds])
            archs_1, archs_2, better_lst = archs[ex_thresh_inds[1]], archs[ex_thresh_inds[0]], \
                (acc_diff > 0)[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            loss = model.update_compare(archs_1, archs_2, better_lst)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            objs.update(loss, n)
        if step % cfg["report_freq"] == 0:
            n_pair_per_batch = (cfg["batch_size"] * (cfg["batch_size"] - 1)) // 2
            logger.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d})".format(
                    n_diff_pairs_meter.avg / n_pair_per_batch,
                    n_diff_pairs_meter.avg, n_pair_per_batch)
                if cfg["compare"] else ""))
    return objs.avg
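# Toy walkthrough (assumed values, not from the code above) of the pairwise
# comparison construction used here and in the other `train` variants below:
# the upper triangle keeps each unordered pair once, the threshold filters out
# near-ties, and better_lst[k] says whether archs_2[k] beats archs_1[k].
def _demo_pairwise_comparison_indices():
    import numpy as np

    accs = np.array([0.90, 0.92, 0.95])
    acc_diff = accs[:, None] - accs                        # diff[i, j] = accs[i] - accs[j]
    acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)     # keep each i < j pair once
    ex_thresh_inds = np.where(acc_abs_diff_matrix > 0.01)  # pairs differing by > 0.01
    archs_1_idx, archs_2_idx = ex_thresh_inds[1], ex_thresh_inds[0]
    better_lst = (acc_diff > 0)[ex_thresh_inds]            # True: the arch at archs_2_idx is better
    return [(int(i), int(j), bool(b))
            for i, j, b in zip(archs_2_idx, archs_1_idx, better_lst)]
    # [(0, 1, False), (0, 2, False), (1, 2, False)]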
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    losses_obj = utils.OrderedStats()
    all_perfs = []
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext

    with context():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs = inputs.to(device)
            # targets = targets.to(device)
            predictions = model.forward(inputs)
            losses = criterion(inputs, predictions, targets, model)
            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
            perfs = self._perf_func(inputs, predictions, targets, model)
            all_perfs.append(perfs)
            objective_perfs.update(dict(zip(self._perf_names, perfs)))
            losses_obj.update(losses)
            n = inputs.size(0)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info(
                    "valid %03d %.2f%%; %.2f%%; %s",
                    step, top1.avg, top5.avg,
                    "; ".join(["{}: {:.3f}".format(perf_n, v)
                               for perf_n, v in
                               list(objective_perfs.avgs().items()) +
                               list(losses_obj.avgs().items())]))
    all_perfs = list(zip(*all_perfs))
    obj_perfs = {
        k: self.objective.aggregate_fn(k, False)(v)
        for k, v in zip(self._perf_names, all_perfs)
    }
    return top1.avg, sum(losses_obj.avgs().values()), obj_perfs
def evaluate_epoch(self, data, targets, bptt_steps):
    expect(self._is_setup, "trainer.setup should be called first")
    batch_size = data.shape[1]
    self.model.eval()
    objs = utils.AverageMeter()
    hiddens = self.model.init_hidden(batch_size)

    for i in range(0, data.size(0), bptt_steps):
        seq_len = min(bptt_steps, len(data) - i)
        inp, targ = data[i:i + seq_len], targets[i:i + seq_len]
        logits, _, _, hiddens = self.parallel_model(inp, hiddens)
        objs.update(self._criterion(logits.view(-1, logits.size(-1)),
                                    targ.view(-1)).item(), seq_len)
    return objs.avg
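# Toy view (an assumption about the data layout, not taken from the trainer) of
# the BPTT slicing in `evaluate_epoch`: `data` is organized as
# (num_time_steps, batch_size), and each chunk spans `bptt_steps` time steps
# for the whole batch, with a shorter final chunk.
def _demo_bptt_slicing():
    import torch

    data = torch.arange(20).view(10, 2)   # 10 time steps, batch of 2
    bptt_steps = 4
    chunk_shapes = []
    for i in range(0, data.size(0), bptt_steps):
        seq_len = min(bptt_steps, len(data) - i)
        chunk_shapes.append(tuple(data[i:i + seq_len].shape))
    return chunk_shapes  # [(4, 2), (4, 2), (2, 2)]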
def infer_epoch(self, valid_queue, model, criterion, device):
    expect(self._is_setup, "trainer.setup should be called first")
    cls_objs = utils.AverageMeter()
    loc_objs = utils.AverageMeter()
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    objective_perfs = utils.OrderedStats()
    model.eval()
    context = torch.no_grad if self.eval_no_grad else nullcontext

    with context():
        for step, (inputs, targets) in enumerate(valid_queue):
            inputs = inputs.to(device)
            # targets = targets.to(device)
            predictions = model.forward(inputs)
            classification_loss, regression_loss = criterion(
                inputs, predictions, targets, model)
            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)
            perfs = self._perf_func(inputs, predictions, targets, model)
            objective_perfs.update(dict(zip(self._perf_names, perfs)))
            n = inputs.size(0)
            cls_objs.update(classification_loss.item(), n)
            loc_objs.update(regression_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            if step % self.report_every == 0:
                self.logger.info("valid %03d %e %e; %.2f%%; %.2f%%; %s",
                                 step, cls_objs.avg, loc_objs.avg,
                                 top1.avg, top5.avg,
                                 "; ".join(["{}: {:.3f}".format(perf_n, v)
                                            for perf_n, v in objective_perfs.avgs().items()]))
    stats = self.dataset.evaluate_detections(self.objective.all_boxes, self.eval_dir)
    self.logger.info("mAP: {}".format(stats[0]))
    return top1.avg, cls_objs.avg + loc_objs.avg, objective_perfs.avgs()
def train_multi_stage_pair_pool(all_stages, pairs_list, model, i_epoch, args):
    objs = utils.AverageMeter()
    model.train()
    # try to get through all the pairs
    pairs_pool = list(zip(*[np.concatenate(items) for items in zip(*pairs_list)]))
    num_pairs = len(pairs_pool)
    logging.info("Number of pairs: {}".format(num_pairs))
    np.random.shuffle(pairs_pool)
    num_batch = num_pairs // args.batch_size

    for i_batch in range(num_batch):
        archs_1_inds, archs_2_inds, better_lst = list(
            zip(*pairs_pool[i_batch * args.batch_size:(i_batch + 1) * args.batch_size]))
        loss = model.update_compare(
            np.array([all_stages[idx][0] for idx in archs_1_inds]),
            np.array([all_stages[idx][0] for idx in archs_2_inds]),
            better_lst)
        objs.update(loss, args.batch_size)
        if i_batch % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                i_epoch, i_batch, num_batch, objs.avg))
    return objs.avg
def train(train_loader, model, epoch, args, arch_network_type):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()

    for step, (archs, f_accs, h_accs) in enumerate(train_loader):
        archs = np.array(archs)
        h_accs = np.array(h_accs)
        f_accs = np.array(f_accs)
        n = len(archs)
        if getattr(args, "use_half", False):
            accs = h_accs
        else:
            accs = f_accs
        if args.compare:
            if None in f_accs:
                # some archs only have half-time acc
                n_max_pairs = int(args.max_compare_ratio * n)
                n_max_inter_pairs = int(args.inter_pair_ratio * n_max_pairs)
                half_inds = np.array(
                    [ind for ind, acc in enumerate(accs) if acc is None])
                mask = np.zeros(n)
                mask[half_inds] = 1
                final_inds = np.where(1 - mask)[0]
                half_eche = h_accs[half_inds]
                final_eche = h_accs[final_inds]
                half_acc_diff = final_eche[:, None] - half_eche  # (num_final, num_half)
                assert (half_acc_diff >= 0).all()  # should be > 0
                half_ex_thresh_inds = np.where(
                    np.abs(half_acc_diff) > getattr(
                        args, "half_compare_threshold", 2 * args.compare_threshold))
                half_ex_thresh_num = len(half_ex_thresh_inds[0])
                if half_ex_thresh_num > n_max_inter_pairs:
                    # random choose
                    keep_inds = np.random.choice(np.arange(half_ex_thresh_num),
                                                 n_max_inter_pairs, replace=False)
                    half_ex_thresh_inds = (half_ex_thresh_inds[0][keep_inds],
                                           half_ex_thresh_inds[1][keep_inds])
                inter_archs_1, inter_archs_2, inter_better_lst \
                    = archs[half_inds[half_ex_thresh_inds[1]]], \
                      archs[final_inds[half_ex_thresh_inds[0]]], \
                      (half_acc_diff > 0)[half_ex_thresh_inds]
                n_inter_pairs = len(inter_better_lst)
                # only use intra pairs in the final echelon
                n_intra_pairs = n_max_pairs - n_inter_pairs
                accs = np.array(accs)[final_inds]
                archs = archs[final_inds]
                acc_diff = np.array(accs)[:, None] - np.array(accs)
                acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                ex_thresh_inds = np.where(acc_abs_diff_matrix > args.compare_threshold)
                ex_thresh_num = len(ex_thresh_inds[0])
                if ex_thresh_num > n_intra_pairs:
                    if args.choose_pair_criterion == "diff":
                        keep_inds = np.argpartition(
                            acc_abs_diff_matrix[ex_thresh_inds],
                            -n_intra_pairs)[-n_intra_pairs:]
                    elif args.choose_pair_criterion == "random":
                        keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                                     n_intra_pairs, replace=False)
                    ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                      ex_thresh_inds[1][keep_inds])
                archs_1, archs_2, better_lst = archs[ex_thresh_inds[1]], \
                    archs[ex_thresh_inds[0]], (acc_diff > 0)[ex_thresh_inds]
                archs_1, archs_2, better_lst = np.concatenate((inter_archs_1, archs_1)), \
                    np.concatenate((inter_archs_2, archs_2)), \
                    np.concatenate((inter_better_lst, better_lst))
            else:
                if getattr(args, "compare_split", False):
                    n_pairs = len(archs) // 2
                    accs = np.array(accs)
                    acc_diff_lst = accs[n_pairs:2 * n_pairs] - accs[:n_pairs]
                    keep_inds = np.where(np.abs(acc_diff_lst) > args.compare_threshold)[0]
                    better_lst = (np.array(accs[n_pairs:2 * n_pairs]
                                           - accs[:n_pairs]) > 0)[keep_inds]
                    archs_1 = np.array(archs[:n_pairs])[keep_inds]
                    archs_2 = np.array(archs[n_pairs:2 * n_pairs])[keep_inds]
                else:
                    n_max_pairs = int(args.max_compare_ratio * n)
                    acc_diff = np.array(accs)[:, None] - np.array(accs)
                    acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                    ex_thresh_inds = np.where(acc_abs_diff_matrix > args.compare_threshold)
                    ex_thresh_num = len(ex_thresh_inds[0])
                    if ex_thresh_num > n_max_pairs:
                        if args.choose_pair_criterion == "diff":
                            keep_inds = np.argpartition(
                                acc_abs_diff_matrix[ex_thresh_inds],
                                -n_max_pairs)[-n_max_pairs:]
                        elif args.choose_pair_criterion == "random":
                            keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                                         n_max_pairs, replace=False)
                        ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                          ex_thresh_inds[1][keep_inds])
                    archs_1, archs_2, better_lst = archs[ex_thresh_inds[1]], \
                        archs[ex_thresh_inds[0]], (acc_diff > 0)[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            loss = model.update_compare(archs_1, archs_2, better_lst)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            if arch_network_type != "random_forest":
                objs.update(loss, n)
        if step % args.report_freq == 0:
            n_pair_per_batch = (args.batch_size * (args.batch_size - 1)) // 2
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d})".format(
                    n_diff_pairs_meter.avg / n_pair_per_batch,
                    n_diff_pairs_meter.avg, n_pair_per_batch)
                if args.compare else ""))
    return objs.avg
def sample_batchify(search_space, model, ratio, K, args, conflict_archs=None):
    model.eval()
    inner_sample_n = args.sample_batchify_inner_sample_n
    ss = search_space
    assert K % inner_sample_n == 0
    num_iter = K // inner_sample_n
    want_samples_per_iter = int(ratio * inner_sample_n)
    logging.info(
        "Sample {}. REPEAT {}: Sample {} archs based on the predicted score across {} archs"
        .format(K, num_iter, inner_sample_n, want_samples_per_iter))
    sampled_rollouts = []
    sampled_scores = []
    # the number, mean and max predicted scores of currently sampled archs
    cur_sampled_mean_max = (0, 0, 0)
    i_iter = 1
    # num_steps = (ratio * K + args.batch_size - 1) // args.batch_size
    _r_cls = ss.random_sample().__class__
    conflict_rollouts = [
        _r_cls(arch, info={}, search_space=search_space)
        for arch in conflict_archs or []
    ]
    inner_report_freq = 10
    judge_conflict = False

    while i_iter <= num_iter:
        # # random init
        # if self.inner_iter_random_init \
        #         and hasattr(self.inner_controller, "reinit"):
        #     self.inner_controller.reinit()
        new_per_step_meter = utils.AverageMeter()
        # a list with length self.inner_sample_n
        best_rollouts = []
        best_scores = []
        num_to_sample = inner_sample_n
        iter_r_set = []
        iter_s_set = []
        sampled_r_set = sampled_rollouts
        # for i_inner in range(1, num_steps+1):
        i_inner = 0
        while new_per_step_meter.sum < want_samples_per_iter:
            i_inner += 1
            rollouts = [search_space.random_sample() for _ in range(args.batch_size)]
            batch_archs = [r.arch for r in rollouts]
            step_scores = list(model.predict(batch_archs).cpu().data.numpy())
            if judge_conflict:
                new_inds, new_rollouts = zip(
                    *[(i, r) for i, r in enumerate(rollouts)
                      if r not in conflict_rollouts and r not in sampled_r_set
                      and r not in iter_r_set])
                new_step_scores = [step_scores[i] for i in new_inds]
                iter_r_set += new_rollouts
                iter_s_set += new_step_scores
            else:
                new_rollouts = rollouts
                new_step_scores = step_scores
            new_per_step_meter.update(len(new_rollouts))
            best_rollouts += new_rollouts
            best_scores += new_step_scores
            # iter_r_set += rollouts
            # iter_s_set += step_scores
            if len(best_scores) > num_to_sample:
                keep_inds = np.argpartition(best_scores, -num_to_sample)[-num_to_sample:]
                best_rollouts = [best_rollouts[ind] for ind in keep_inds]
                best_scores = [best_scores[ind] for ind in keep_inds]
            if i_inner % inner_report_freq == 0:
                logging.info((
                    "Seen %d/%d Iter %d (to sample %d) (already sampled %d mean %.5f, best %.5f); "
                    "Step %d: sample %d step mean %.5f best %.5f: {} "
                    # "(iter mean %.5f, best %.5f). "
                    "AVG new/step: %.3f").format(", ".join(
                        ["{:.5f}".format(s) for s in best_scores])),
                    new_per_step_meter.sum, want_samples_per_iter,
                    i_iter, num_to_sample,
                    cur_sampled_mean_max[0], cur_sampled_mean_max[1], cur_sampled_mean_max[2],
                    i_inner, len(rollouts),
                    np.mean(step_scores), np.max(step_scores),
                    # np.mean(iter_s_set), np.max(iter_s_set),
                    new_per_step_meter.avg)
        # if new_per_step_meter.sum < num_to_sample * 10:
        #     # rerun this iter, also reinit!
        #     self.logger.info("Cannot find %d (num_to_sample x min_inner_sample_ratio)"
        #                      " (%d x %d) new rollouts in one run of the inner controller. "
        #                      "Re-init the controller and re-run this iteration.",
        #                      num_to_sample * self.min_inner_sample_ratio,
        #                      num_to_sample, self.min_inner_sample_ratio)
        #     continue
        i_iter += 1
        assert len(best_scores) == num_to_sample
        sampled_rollouts += best_rollouts
        sampled_scores += best_scores
        cur_sampled_mean_max = (len(sampled_scores), np.mean(sampled_scores),
                                np.max(sampled_scores))
    return [r.genotype for r in sampled_rollouts]
def train(train_loader, model, epoch, args):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    n_eq_pairs_meter = utils.AverageMeter()
    model.train()
    margin_diff_coeff = getattr(args, "margin_diff_coeff", None)
    eq_threshold = getattr(args, "eq_threshold", None)
    eq_pair_ratio = getattr(args, "eq_pair_ratio", 0)
    if eq_threshold is not None:
        assert eq_pair_ratio > 0
        assert eq_threshold <= args.compare_threshold

    for step, (archs, all_accs) in enumerate(train_loader):
        archs = np.array(archs)
        n = len(archs)
        use_checkpoint = getattr(args, "use_checkpoint", 3)
        accs = all_accs[:, use_checkpoint]
        if args.compare:
            if getattr(args, "compare_split", False):
                n_pairs = len(archs) // 2
                accs = np.array(accs)
                acc_diff_lst = accs[n_pairs:2 * n_pairs] - accs[:n_pairs]
                keep_inds = np.where(np.abs(acc_diff_lst) > args.compare_threshold)[0]
                better_lst = (np.array(accs[n_pairs:2 * n_pairs]
                                       - accs[:n_pairs]) > 0)[keep_inds]
                archs_1 = np.array(archs[:n_pairs])[keep_inds]
                archs_2 = np.array(archs[n_pairs:2 * n_pairs])[keep_inds]
            else:
                n_max_pairs = int(args.max_compare_ratio * n * (1 - eq_pair_ratio))
                acc_diff = np.array(accs)[:, None] - np.array(accs)
                acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                ex_thresh_inds = np.where(acc_abs_diff_matrix > args.compare_threshold)
                ex_thresh_num = len(ex_thresh_inds[0])
                if ex_thresh_num > n_max_pairs:
                    if args.choose_pair_criterion == "diff":
                        keep_inds = np.argpartition(
                            acc_abs_diff_matrix[ex_thresh_inds],
                            -n_max_pairs)[-n_max_pairs:]
                    elif args.choose_pair_criterion == "random":
                        keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                                     n_max_pairs, replace=False)
                    ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                      ex_thresh_inds[1][keep_inds])
                archs_1, archs_2, better_lst, acc_diff_lst = \
                    archs[ex_thresh_inds[1]], archs[ex_thresh_inds[0]], \
                    (acc_diff > 0)[ex_thresh_inds], acc_diff[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            if eq_threshold is None:
                if margin_diff_coeff is not None:
                    margin = np.abs(acc_diff_lst) * margin_diff_coeff
                    loss = model.update_compare(archs_1, archs_2, better_lst, margin=margin)
                else:
                    loss = model.update_compare(archs_1, archs_2, better_lst)
            else:
                # drag close the scores of arch pairs whose true acc diffs are below args.eq_threshold
                n_eq_pairs = int(args.max_compare_ratio * n * eq_pair_ratio)
                below_eq_thresh_inds = np.where(acc_abs_diff_matrix < eq_threshold)
                below_eq_thresh_num = len(below_eq_thresh_inds[0])
                if below_eq_thresh_num > n_eq_pairs:
                    keep_inds = np.random.choice(np.arange(below_eq_thresh_num),
                                                 n_eq_pairs, replace=False)
                    below_eq_thresh_inds = (below_eq_thresh_inds[0][keep_inds],
                                            below_eq_thresh_inds[1][keep_inds])
                eq_archs_1, eq_archs_2, below_acc_diff_lst = \
                    archs[below_eq_thresh_inds[1]], archs[below_eq_thresh_inds[0]], \
                    acc_abs_diff_matrix[below_eq_thresh_inds]
                if margin_diff_coeff is not None:
                    margin = np.concatenate(
                        (np.abs(acc_diff_lst), np.abs(below_acc_diff_lst))) * margin_diff_coeff
                else:
                    margin = None
                better_pm_lst = np.concatenate((2 * better_lst - 1, np.zeros(len(eq_archs_1))))
                n_eq_pairs_meter.update(float(len(eq_archs_1)))
                loss = model.update_compare_eq(np.concatenate((archs_1, eq_archs_1)),
                                               np.concatenate((archs_2, eq_archs_2)),
                                               better_pm_lst, margin=margin)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            objs.update(loss, n)
        if step % args.report_freq == 0:
            n_pair_per_batch = (args.batch_size * (args.batch_size - 1)) // 2
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d}){}".format(
                    n_diff_pairs_meter.avg / n_pair_per_batch,
                    n_diff_pairs_meter.avg, n_pair_per_batch,
                    "; eq pairs: {:.1f}".format(n_eq_pairs_meter.avg)
                    if eq_threshold is not None else "")
                if args.compare else ""))
    return objs.avg
def train_multi_stage_listwise(train_stages, model, epoch, args, avg_stage_scores,
                               stage_epochs, score_train_stages=None):
    # TODO: multi stage
    objs = utils.AverageMeter()
    n_listlength_meter = utils.AverageMeter()
    model.train()
    num_stages = len(train_stages)
    stage_lens = [len(stage_data) for stage_data in train_stages]
    stage_sep_inds = [np.arange(stage_len) for stage_len in stage_lens]
    sample_acc_temp = getattr(args, "sample_acc_temp", None)
    if sample_acc_temp is not None:
        stage_sep_probs = []
        for i_stage, stage_data in enumerate(train_stages):
            perfs = np.array([
                item[1][stage_epochs[i_stage]] for item in train_stages[i_stage]
            ])
            perfs = perfs / sample_acc_temp
            exp_perfs = np.exp(perfs - np.max(perfs))
            stage_sep_probs.append(exp_perfs / exp_perfs.sum())
    else:
        stage_sep_probs = None
    stage_single_probs = getattr(args, "stage_single_probs", None)
    assert stage_single_probs is not None
    if stage_single_probs is not None:
        stage_probs = np.array([
            single_prob * len_
            for single_prob, len_ in zip(stage_single_probs, stage_lens)
        ])
        stage_probs = stage_probs / stage_probs.sum()
    logging.info("Epoch {:d}: Stage probs {}".format(epoch, stage_probs))
    num_stage_samples_avg = np.zeros(num_stages)
    train_stages = np.array(train_stages)
    listwise_compare = getattr(args, "listwise_compare", False)
    if listwise_compare:
        assert args.list_length == 2

    for step in range(args.num_batch_per_epoch):
        num_stage_samples = np.random.multinomial(args.list_length, stage_probs)
        num_stage_samples = np.minimum(num_stage_samples, stage_lens)
        true_ll = np.sum(num_stage_samples)
        n_listlength_meter.update(true_ll, args.batch_size)
        num_stage_samples_avg += num_stage_samples
        stage_inds = [
            np.array([
                np.random.choice(stage_sep_inds[i_stage], size=sz, replace=False,
                                 p=None if stage_sep_probs is None
                                 else stage_sep_probs[i_stage])
                for _ in range(args.batch_size)
            ]) if sz > 0 else np.zeros((args.batch_size, 0), dtype=int)
            for i_stage, sz in enumerate(num_stage_samples)
        ]
        sorted_stage_inds = [
            s_stage_inds[
                np.arange(args.batch_size)[:, None],
                np.argsort(np.array(
                    np.array(train_stages[i_stage])[s_stage_inds][:, :, 1].tolist()
                )[:, :, stage_epochs[i_stage]], axis=1)]
            if s_stage_inds.shape[1] > 1 else s_stage_inds
            for i_stage, s_stage_inds in enumerate(stage_inds)
        ]
        archs = np.concatenate([
            np.array(train_stages[i_stage])[s_stage_inds][:, :, 0]
            for i_stage, s_stage_inds in enumerate(sorted_stage_inds)
            if s_stage_inds.size > 0
        ], axis=1)
        archs = archs[:, ::-1]  # order: from best to worst
        assert archs.ndim == 2
        # (batch_size, list_length, num_cell_groups, node_or_op, decisions)
        archs = np.array(archs.tolist())
        if listwise_compare:
            loss = model.update_compare(archs[:, 0], archs[:, 1], np.zeros(archs.shape[0]))
        else:
            loss = model.update_argsort(archs, idxes=None,
                                        first_n=getattr(args, "score_list_length", None),
                                        is_sorted=True)
        objs.update(loss, args.batch_size)
        if step % args.report_freq == 0:
            logging.info(
                "train {:03d} [{:03d}/{:03d}] {:.4f} (mean ll: {:.1f}; {})".format(
                    epoch, step, args.num_batch_per_epoch, objs.avg,
                    n_listlength_meter.avg,
                    (num_stage_samples_avg / (step + 1)).tolist()))
    return objs.avg
def train_multi_stage(train_stages, model, epoch, args, avg_stage_scores, stage_epochs):
    # TODO: multi stage
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()
    num_stages = len(train_stages)
    # must specify `stage_probs` or `stage_prob_power`
    stage_probs = getattr(args, "stage_probs", None)
    if stage_probs is None:
        stage_probs = _cal_stage_probs(avg_stage_scores, args.stage_prob_power)
    stage_accept_pair_probs = getattr(args, "stage_accept_pair_probs", [1.0] * num_stages)
    stage_lens = [len(stage_data) for stage_data in train_stages]
    for i, len_ in enumerate(stage_lens):
        if len_ == 0:
            n_j = num_stages - i - 1
            for j in range(i + 1, num_stages):
                stage_probs[j] += stage_probs[i] / float(n_j)
            stage_probs[i] = 0
    # diff_threshold = getattr(args, "diff_threshold", [0.08, 0.04, 0.02, 0.0])
    stage_single_probs = getattr(args, "stage_single_probs", None)
    if stage_single_probs is not None:
        stage_probs = np.array([
            single_prob * len_
            for single_prob, len_ in zip(stage_single_probs, stage_lens)
        ])
        stage_probs = stage_probs / stage_probs.sum()
    logging.info("Epoch {:d}: Stage probs {}".format(epoch, stage_probs))
    diff_threshold = args.diff_threshold

    for step in range(args.num_batch_per_epoch):
        pair_batch = []
        i_pair = 0
        while 1:
            stage_1, stage_2 = np.random.choice(np.arange(num_stages), size=2, p=stage_probs)
            d_1 = train_stages[stage_1][np.random.randint(0, stage_lens[stage_1])]
            d_2 = train_stages[stage_2][np.random.randint(0, stage_lens[stage_2])]
            min_stage = min(stage_2, stage_1)
            if np.random.rand() > stage_accept_pair_probs[min_stage]:
                continue
            # max_stage = stage_2 + stage_1 - min_stage
            # if max_stage - min_stage >= 2:
            #     better = stage_2 > stage_1
            # else:
            min_epoch = stage_epochs[min_stage]
            diff_21 = d_2[1][min_epoch] - d_1[1][min_epoch]
            # print(stage_1, stage_2, diff_21, diff_threshold)
            if np.abs(diff_21) > diff_threshold[min_stage]:
                # if the difference is larger than the threshold of the min stage, this pair counts
                better = diff_21 > 0
            else:
                continue
            pair_batch.append((d_1[0], d_2[0], better))
            i_pair += 1
            if i_pair == args.batch_size:
                break
        archs_1, archs_2, better_lst = zip(*pair_batch)
        n_diff_pairs = len(better_lst)
        n_diff_pairs_meter.update(float(n_diff_pairs))
        loss = model.update_compare(archs_1, archs_2, better_lst)
        objs.update(loss, n_diff_pairs)
        if step % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                epoch, step, args.num_batch_per_epoch, objs.avg))
    return objs.avg
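# `_cal_stage_probs` is referenced above but not shown in this listing. A purely
# illustrative guess at its behavior (hypothetical helper, not the actual one):
# turn per-stage average scores into sampling probabilities, sharpened by
# `stage_prob_power`.
def _cal_stage_probs_sketch(avg_stage_scores, power):
    import numpy as np

    scores = np.asarray(avg_stage_scores, dtype=float) ** power
    return scores / scores.sum()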
def train(self):  # pylint: disable=too-many-branches
    assert self.is_setup, "Must call `trainer.setup` method before calling `trainer.train`."
    if self.interleave_controller_every is not None:
        inter_steps = self.controller_steps
        evaluator_steps = self.interleave_controller_every
        controller_steps = 1
    else:
        inter_steps = 1
        evaluator_steps = self.evaluator_steps
        controller_steps = self.controller_steps

    for epoch in range(self.last_epoch + 1, self.epochs + 1):
        c_loss_meter = utils.AverageMeter()
        rollout_stat_meters = utils.OrderedStats()  # rollout performance stats from evaluator
        c_stat_meters = utils.OrderedStats()  # other stats from controller
        eva_stat_meters = utils.OrderedStats()  # other stats from `evaluator.update_evaluator`

        self.epoch = epoch  # this is redundant as Component.on_epoch_start also sets this

        # call `on_epoch_start` of sub-components
        # also schedule values and optimizer learning rates
        self.on_epoch_start(epoch)

        finished_e_steps = 0
        finished_c_steps = 0
        for i_inter in range(1, inter_steps + 1):
            # interleave mepa/controller training
            # meta parameter training
            if evaluator_steps > 0:
                e_stats = self._evaluator_update(evaluator_steps,
                                                 finished_e_steps, finished_c_steps)
                eva_stat_meters.update(e_stats)
                finished_e_steps += evaluator_steps
            if epoch >= self.controller_train_begin and \
                    epoch % self.controller_train_every == 0 and controller_steps > 0:
                # controller training
                c_loss, rollout_stats, c_stats \
                    = self._controller_update(controller_steps,
                                              finished_e_steps, finished_c_steps)
                # update meters
                if c_loss is not None:
                    c_loss_meter.update(c_loss)
                if rollout_stats is not None:
                    rollout_stat_meters.update(rollout_stats)
                if c_stats is not None:
                    c_stat_meters.update(c_stats)
                finished_c_steps += controller_steps
            if self.interleave_report_every and i_inter % self.interleave_report_every == 0:
                # log for every `interleave_report_every` interleaving steps
                self.logger.info("(inter step %3d): "
                                 "evaluator (%3d/%3d) %s ; "
                                 "controller (%3d/%3d) %s",
                                 i_inter, finished_e_steps, self.evaluator_steps,
                                 "; ".join(["{}: {:.3f}".format(n, v)
                                            for n, v in eva_stat_meters.avgs().items()]),
                                 finished_c_steps, self.controller_steps,
                                 "" if not rollout_stat_meters else
                                 "; ".join(["{}: {:.3f}".format(n, v)
                                            for n, v in rollout_stat_meters.avgs().items()]))

        # log information of this epoch
        if eva_stat_meters:
            self.logger.info("Epoch %3d: [evaluator update] %s", epoch,
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in eva_stat_meters.avgs().items()]))
        if rollout_stat_meters:
            self.logger.info("Epoch %3d: [controller update] controller loss: %.3f ; "
                             "rollout performance: %s", epoch, c_loss_meter.avg,
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in rollout_stat_meters.avgs().items()]))
        if c_stat_meters:
            self.logger.info("[controller stats] %s",
                             "; ".join(["{}: {:.3f}".format(n, v)
                                        for n, v in c_stat_meters.avgs().items()]))

        # maybe write tensorboard info
        if not self.writer.is_none():
            if eva_stat_meters:
                for n, meter in eva_stat_meters.items():
                    self.writer.add_scalar(
                        "evaluator_update/{}".format(n.replace(" ", "-")), meter.avg, epoch)
            if rollout_stat_meters:
                for n, meter in rollout_stat_meters.items():
                    self.writer.add_scalar(
                        "controller_update/{}".format(n.replace(" ", "-")), meter.avg, epoch)
            if c_stat_meters:
                for n, meter in c_stat_meters.items():
                    self.writer.add_scalar(
                        "controller_stats/{}".format(n.replace(" ", "-")), meter.avg, epoch)
            if not c_loss_meter.is_empty():
                self.writer.add_scalar("controller_loss", c_loss_meter.avg, epoch)

        # maybe save checkpoints
        self.maybe_save()

        # maybe derive archs and test
        if self.test_every and self.epoch % self.test_every == 0:
            self.test()

        self.on_epoch_end(epoch)  # call `on_epoch_end` of sub-components

    # `final_save` pickle-dumps the weights_manager and controller directly,
    # instead of their state dicts
    self.final_save()
def train_epoch(self, data, targets, bptt_steps):
    expect(self._is_setup, "trainer.setup should be called first")
    batch_size = data.shape[1]
    num_total_steps = data.shape[0]
    self.model.train()
    objs = utils.AverageMeter()
    losses = utils.AverageMeter()
    hiddens = self.model.init_hidden(batch_size)

    if self.random_bptt:
        # random sequence lengths
        seq_lens = []
        i = 0
        while i < data.size(0):
            mean_ = bptt_steps if np.random.random() < 0.95 else bptt_steps / 2
            seq_len = min(max(5, int(np.random.normal(mean_, 5))), bptt_steps + 20)
            seq_lens.append(seq_len)
            i += seq_len
        seq_lens[-1] -= i - data.size(0)
        num_total_batches = len(seq_lens)
    else:
        # fixed sequence length == bptt_steps
        num_total_batches = int(np.ceil(data.size(0) / bptt_steps))
        seq_lens = [bptt_steps] * num_total_batches
        seq_lens[-1] = num_total_steps - bptt_steps * (num_total_batches - 1)

    lr_bak = self.optimizer.param_groups[0]["lr"]
    i = 0
    for batch in range(1, num_total_batches + 1):
        seq_len = seq_lens[batch - 1]
        inp, targ = data[i:i + seq_len], targets[i:i + seq_len]
        # linearly adjust the learning rate according to the sequence length
        self.optimizer.param_groups[0]["lr"] = lr_bak * seq_len / bptt_steps
        self.optimizer.zero_grad()
        logits, raw_outs, outs, hiddens = self.parallel_model(inp, hiddens)
        raw_loss = self._criterion(logits.view(-1, logits.size(-1)), targ.view(-1))
        loss = raw_loss
        # Activation Regularization
        if self.rnn_act_reg > 0:
            loss = loss + self.rnn_act_reg * outs.pow(2).mean()
        # Temporal Activation Regularization (slowness)
        if self.rnn_slowness_reg > 0:
            loss = loss + self.rnn_slowness_reg * (raw_outs[1:] - raw_outs[:-1]).pow(2).mean()
        loss.backward()
        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        self.model.step_current_gradients(self.optimizer)
        objs.update(raw_loss.item(), seq_len)
        losses.update(loss.item(), seq_len)
        # del logits, raw_outs, outs, raw_loss, loss
        i += seq_len
        if batch % self.report_every == 0:
            self.logger.info("train %3d/%3d: perp %.3f; loss %.3f; loss(with reg) %.3f",
                             batch, num_total_batches,
                             np.exp(objs.avg), objs.avg, losses.avg)
    self.optimizer.param_groups[0]["lr"] = lr_bak
    return objs.avg, losses.avg
def sample(self, n=1, batch_size=1):
    """Sample architectures based on the current predictor."""
    if self.mode == "eval":
        # return the best n rollouts that are evaluated by the ground-truth evaluator
        self.logger.info("Return the best {} rollouts in the population".format(n))
        all_gt_arch_scores = sum(self.gt_arch_scores, [])
        all_rollouts = sum(self.gt_rollouts, [])
        best_inds = np.argpartition(
            [item[1] for item in all_gt_arch_scores], -n)[-n:]
        # all_rollouts, all_scores = zip(
        #     *[(r, r.get_perf("reward")) for rs in self.gt_rollouts for r in rs])
        # best_inds = np.argpartition(all_scores, -n)[-n:]
        return [all_rollouts[ind] for ind in best_inds]

    if not self.is_predictor_trained:
        # if the predictor is not trained, random sample from the search space
        return [self.search_space.random_sample() for _ in range(n)]

    if n % self.inner_sample_n != 0:
        self.logger.warn("sample number %d cannot be divided by inner_sample_n %d",
                         n, self.inner_sample_n)

    # the arch rollouts that have already been evaled; avoid sampling them
    already_evaled_r_set = sum(self.gt_rollouts, [])

    # nb101, nb201 420k, 15k, small. forward 1~2min max
    if self.inner_enumerate_search_space:
        if self.inner_enumerate_sample_ratio is not None:
            assert n % self.inner_sample_n == 0
        max_num = None if self.inner_enumerate_sample_ratio is None \
            else n * self.inner_enumerate_sample_ratio
        iter_ = self.search_space.batch_rollouts(batch_size=self.predict_batch_size,
                                                 shuffle=True, max_num=max_num)
        scores = []
        all_rollouts = []
        num_ignore = 0
        for rollouts in iter_:
            # remove the rollouts that are already evaled
            ori_len_ = len(rollouts)
            rollouts = [rollout for rollout in rollouts
                        if rollout not in already_evaled_r_set]
            num_ignore += ori_len_ - len(rollouts)
            all_rollouts = all_rollouts + self._predict_rollouts(rollouts)
            scores = scores + [i.perf["predicted_score"] for i in rollouts]
        if self.inner_sample_n is not None:
            num_iters = n // self.inner_sample_n
            rs_per_s = len(scores) // num_iters
            scores = np.array(scores)[:rs_per_s * num_iters]
            inds = np.argpartition(scores.reshape([num_iters, rs_per_s]),
                                   -self.inner_sample_n,
                                   axis=1)[:, -self.inner_sample_n:]
            # inds: (num_iters, self.inner_sample_n)
            best_inds = (inds + rs_per_s * np.arange(num_iters)[:, None]).reshape(-1)
            self.logger.info(
                "Random sample %d archs (max num %d), ignore %d already evaled archs, "
                "and choose %d archs per %d archs with highest predict scores",
                len(scores), max_num, num_ignore, self.inner_sample_n, rs_per_s)
        else:
            # finally: rank, and get the first n archs
            best_inds = np.argpartition(scores, -n)[-n:]
            self.logger.info(
                "Random sample %d archs (max num %d), ignore %d already evaled archs, "
                "and choose %d archs with highest predict scores",
                len(scores), max_num, num_ignore, n)
        return [all_rollouts[i] for i in best_inds]

    # if self.inner_controller_reinit:
    self.inner_controller = BaseController.get_class_(self.inner_controller_type)(
        self.search_space, self.device, mode=self.mode,
        rollout_type=self.rollout_type, **self.inner_controller_cfg)
    if hasattr(self.inner_controller, "set_init_population"):
        self.logger.info("re-evaluating %d rollouts using the current predictor",
                         self.num_gt_rollouts)
        # set the init population of the inner controller
        # re-evaluate rollouts using the current predictor
        for rollouts in self.gt_rollouts:
            rollouts = self._predict_rollouts(rollouts)
        if not self.inner_random_init:
            self.inner_controller.set_init_population(
                sum(self.gt_rollouts, []), perf_name="predicted_score")

    # inner_sample_n: how many archs to sample every iter
    num_iter = (n + self.inner_sample_n - 1) // self.inner_sample_n
    sampled_rollouts = []
    sampled_scores = []
    # the number, mean and max predicted scores of currently sampled archs
    cur_sampled_mean_max = (0, 0, 0)
    i_iter = 1
    while i_iter <= num_iter:  # for i_iter in range(1, num_iter+1):
        # random init
        if self.inner_iter_random_init \
                and hasattr(self.inner_controller, "reinit"):
            if i_iter > 1:
                # might use gt rollouts as the init population if `inner_random_init=true`,
                # so do not call reinit when i_iter == 1
                if (not isinstance(self.inner_iter_random_init, int)) or \
                        self.inner_iter_random_init == 1 or \
                        i_iter % self.inner_iter_random_init == 1:
                    # if `inner_iter_random_init` is an integer,
                    # only reinit every `inner_iter_random_init` iterations.
                    # `inner_iter_random_init==True` is the same as `inner_iter_random_init==1`,
                    # and means that every iter (besides iter 1) would call `reinit`
                    self.inner_controller.reinit()
        new_per_step_meter = utils.AverageMeter()
        # a list with length self.inner_sample_n
        best_rollouts = []
        best_scores = []
        num_to_sample = min(n - (i_iter - 1) * self.inner_sample_n, self.inner_sample_n)
        iter_r_set = []
        iter_s_set = []
        sampled_r_set = sampled_rollouts
        for i_inner in range(1, self.inner_steps + 1):
            # self.inner_controller.on_epoch_begin(i_inner)
            # while 1:
            #     rollouts = self.inner_controller.sample(self.inner_samples)
            #     # remove the duplicate rollouts
            #     # *fixme* FIXME: local minimum problem exists!
            #     # random sample is one way, or do not use the best as the init?
            #     # Add a test to test the whole dataset...
            #     # ground-truth evaled, decided rollouts
            #     # rollouts = [r for r in rollouts
            #     #             if r not in already_evaled_r_set \
            #     #             and r not in sampled_r_set]
            #     #             and r not in iter_r_set
            #     if not rollouts:
            #         print("all conflict, resample")
            #         continue
            #     else:
            #         # print("sampled {}".format(i_inner))
            #         break
            rollouts = self.inner_controller.sample(self.inner_samples)
            rollouts = self._predict_rollouts(rollouts)
            self.inner_controller.step(rollouts, self.inner_cont_optimizer,
                                       perf_name="predicted_score")
            # keep the `num_to_sample` archs with highest scores
            step_scores = [r.get_perf(name="predicted_score") for r in rollouts]
            new_rollouts = [r for r in rollouts if r not in already_evaled_r_set
                            and r not in sampled_r_set and r not in iter_r_set]
            new_step_scores = [r.get_perf(name="predicted_score") for r in new_rollouts]
            new_per_step_meter.update(len(new_rollouts))
            best_rollouts += new_rollouts
            best_scores += new_step_scores
            iter_r_set += rollouts
            iter_s_set += step_scores
            if len(best_scores) > num_to_sample:
                keep_inds = np.argpartition(best_scores, -num_to_sample)[-num_to_sample:]
                best_rollouts = [best_rollouts[ind] for ind in keep_inds]
                best_scores = [best_scores[ind] for ind in keep_inds]
            if i_inner % self.inner_report_freq == 0:
                self.logger.info((
                    "Iter %d (to sample %d) (already sampled %d mean %.5f, best %.5f); "
                    "Step %d: sample %d step mean %.5f best %.5f: {} "
                    "(iter mean %.5f, best %.5f). AVG new/step: %.3f"
                ).format(", ".join(["{:.5f}".format(s) for s in best_scores])),
                    i_iter, num_to_sample,
                    cur_sampled_mean_max[0], cur_sampled_mean_max[1], cur_sampled_mean_max[2],
                    i_inner, len(rollouts),
                    np.mean(step_scores), np.max(step_scores),
                    np.mean(iter_s_set), np.max(iter_s_set),
                    new_per_step_meter.avg)
        if new_per_step_meter.sum < num_to_sample * self.min_inner_sample_ratio:
            # rerun this iter, also reinit!
            self.logger.info(
                "Cannot find %d (num_to_sample x min_inner_sample_ratio)"
                " (%d x %d) new rollouts in one run of the inner controller. "
                "Re-init the controller and re-run this iteration.",
                num_to_sample * self.min_inner_sample_ratio,
                num_to_sample, self.min_inner_sample_ratio)
            continue
        i_iter += 1
        assert len(best_scores) == num_to_sample
        sampled_rollouts += best_rollouts
        sampled_scores += best_scores
        cur_sampled_mean_max = (len(sampled_scores), np.mean(sampled_scores),
                                np.max(sampled_scores))
    return sampled_rollouts