Example #1
    def train_epoch(self, train_queue, model, criterion, optimizer, device, epoch):
        expect(self._is_setup, "trainer.setup should be called first")
        cls_objs = utils.AverageMeter()
        loc_objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        model.train()

        for step, (inputs, targets) in enumerate(train_queue):
            inputs = inputs.to(self.device)
            # targets = targets.to(self.device)

            optimizer.zero_grad()
            predictions = model.forward(inputs)
            classification_loss, regression_loss = criterion(inputs, predictions, targets, model)
            loss = classification_loss + regression_loss
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)

            n = inputs.size(0)
            cls_objs.update(classification_loss.item(), n)
            loc_objs.update(regression_loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.report_every == 0:
                self.logger.info("train %03d %.3f %.3f; %.2f%%; %.2f%%",
                                 step, cls_objs.avg, loc_objs.avg, top1.avg, top5.avg)
        return top1.avg, cls_objs.avg + loc_objs.avg
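All of these snippets rely on utils.AverageMeter from the aw_nas utilities to track running statistics; its definition is not shown on this page. A minimal sketch of a meter with the interface the examples use (update(val, n), .sum, .avg, and the is_empty() check seen in Example #18) could look like the following; the actual aw_nas helper may differ in detail.

class AverageMeter:
    """Running average of a scalar, weighted by batch size (sketch, not the aw_nas original)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.sum = 0.
        self.cnt = 0
        self.avg = 0.

    def is_empty(self):
        # True until the first update() call
        return self.cnt == 0

    def update(self, val, n=1):
        # accumulate `val` weighted by the number of samples `n`
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt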
Example #2
    def train_epoch(self, train_queue, model, criterion, optimizer, device,
                    epoch):
        expect(self._is_setup, "trainer.setup should be called first")
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        losses_obj = utils.OrderedStats()
        model.train()

        for step, (inputs, targets) in enumerate(train_queue):
            inputs = inputs.to(self.device)

            optimizer.zero_grad()
            predictions = model.forward(inputs)
            losses = criterion(inputs, predictions, targets, model)
            loss = sum(losses.values())
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            prec1, prec5 = self._acc_func(inputs, predictions, targets, model)

            n = inputs.size(0)
            losses_obj.update(losses)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)

            if step % self.report_every == 0:
                self.logger.info("train %03d %.2f%%; %.2f%%; %s",
                                 step, top1.avg, top5.avg, "; ".join(
                                     ["{}: {:.3f}".format(perf_n, v) \
                                      for perf_n, v in losses_obj.avgs().items()]))
        return top1.avg, sum(losses_obj.avgs().values())
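Several examples also use utils.OrderedStats, updated with a dict of per-batch values (optionally with n=batch_size, as in Example #3) and read back through avgs(); Example #18 additionally iterates it with .items() and tests it for truthiness. A minimal sketch that wraps one AverageMeter per key, assuming this is all the interface these snippets need:

from collections import OrderedDict


class OrderedStats:
    """One AverageMeter per named statistic, kept in insertion order (sketch)."""

    def __init__(self):
        self.meters = OrderedDict()

    def __bool__(self):
        # empty stats objects are falsy, matching the `if eva_stat_meters:` checks below
        return bool(self.meters)

    def update(self, stats, n=1):
        # `stats` maps names to values for the current batch
        for name, value in stats.items():
            if name not in self.meters:
                self.meters[name] = AverageMeter()
            self.meters[name].update(value, n)

    def avgs(self):
        return OrderedDict((name, meter.avg) for name, meter in self.meters.items())

    def items(self):
        return self.meters.items()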
Example #3
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(device)
                target = target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)
                objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.report_every == 0:
                    self.logger.info("valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                                     "; ".join(["{}: {:.3f}".format(perf_n, v) \
                                                for perf_n, v in objective_perfs.avgs().items()]))

        return top1.avg, objs.avg, objective_perfs.avgs()
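The line context = torch.no_grad if self.eval_no_grad else nullcontext picks a context-manager class and instantiates it inside the with statement, so the same evaluation loop can run with or without gradient tracking. A standalone illustration of the pattern (nullcontext lives in contextlib from Python 3.7 on; the run_eval helper below is only for illustration):

from contextlib import nullcontext

import torch


def run_eval(model, inputs, eval_no_grad=True):
    # pick the context-manager class, then call it in the `with` statement
    context = torch.no_grad if eval_no_grad else nullcontext
    with context():
        return model(inputs)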
Example #4
    def train_epoch(self, train_queue, model, criterion, optimizer, device,
                    epoch):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        model.train()

        for step, (inputs, target) in enumerate(train_queue):
            inputs = inputs.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            if self.auxiliary_head:  # assume the model returns two logits in train mode
                logits, logits_aux = model(inputs)
                loss = self._obj_loss(
                    inputs,
                    logits,
                    target,
                    model,
                    add_evaluator_regularization=self.add_regularization)
                loss_aux = criterion(logits_aux, target)
                loss += self.auxiliary_weight * loss_aux
            else:
                logits = model(inputs)
                loss = self._obj_loss(
                    inputs,
                    logits,
                    target,
                    model,
                    add_evaluator_regularization=self.add_regularization)
            #torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
            optimizer.step()

            prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
            n = inputs.size(0)
            objs.update(loss.item(), n)
            top1.update(prec1.item(), n)
            top5.update(prec5.item(), n)
            del loss

            if step % self.report_every == 0:
                self.logger.info("train %03d %.3f; %.2f%%; %.2f%%", step,
                                 objs.avg, top1.avg, top5.avg)

        return top1.avg, objs.avg
Example #5
    def _controller_update(self, steps, finished_e_steps, finished_c_steps):
        controller_loss_meter = utils.AverageMeter()
        controller_stat_meters = utils.OrderedStats()
        rollout_stat_meters = utils.OrderedStats()

        self.controller.set_mode("train")
        for i_cont in range(1, steps + 1):
            print("\reva step {}/{} ; controller step {}/{}"\
                  .format(finished_e_steps, self.evaluator_steps,
                          finished_c_steps+i_cont, self.controller_steps),
                  end="")

            rollouts = self.controller.sample(self.controller_samples,
                                              self.rollout_batch_size)
            # if self.rollout_type == "differentiable":
            if self.is_differentiable:
                self.controller.zero_grad()

            step_loss = {"_": 0.}
            rollouts = self.evaluator.evaluate_rollouts(
                rollouts,
                is_training=True,
                callback=partial(self._backward_rollout_to_controller,
                                 step_loss=step_loss))
            self.evaluator.update_rollouts(rollouts)

            # if self.rollout_type == "differentiable":
            if self.is_differentiable:
                # differentiable rollout (controller is optimized using differentiable relaxation)
                # adjust lr and call step_current_gradients
                # (update using the accumulated gradients)
                controller_loss = step_loss["_"] / self.controller_samples
                if self.controller_samples != 1:
                    # adjust the lr to keep the effective learning rate unchanged
                    lr_bak = self.controller_optimizer.param_groups[0]["lr"]
                    self.controller_optimizer.param_groups[0]["lr"] \
                        = lr_bak / self.controller_samples
                self.controller.step_current_gradient(
                    self.controller_optimizer)
                if self.controller_samples != 1:
                    self.controller_optimizer.param_groups[0]["lr"] = lr_bak
            else:  # other rollout types
                controller_loss = self.controller.step(
                    rollouts, self.controller_optimizer, perf_name="reward")

            # update meters
            controller_loss_meter.update(controller_loss)
            controller_stats = self.controller.summary(rollouts, log=False)
            if controller_stats is not None:
                controller_stat_meters.update(controller_stats)

            r_stats = OrderedDict()
            for n in rollouts[0].perf:
                r_stats[n] = np.mean([r.perf[n] for r in rollouts])
            rollout_stat_meters.update(r_stats)

        print("\r", end="")

        return (controller_loss, rollout_stat_meters.avgs(),
                controller_stat_meters.avgs())
Example #6
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        all_perfs = []
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, target) in enumerate(valid_queue):
                inputs = inputs.to(device)
                target = target.to(device)

                logits = model(inputs)
                loss = criterion(logits, target)
                perfs = self._perf_func(inputs, logits, target, model)
                all_perfs.append(perfs)
                prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
                n = inputs.size(0)
                # objective_perfs.update(dict(zip(self._perf_names, perfs)), n=n)
                objs.update(loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)
                del loss
                if step % self.report_every == 0:
                    all_perfs_by_name = list(zip(*all_perfs))
                    # support use objective aggregate fn, for stat method other than mean
                    # e.g., adversarial distance median; detection mAP (see det_trainer.py)
                    obj_perfs = {
                        k: self.objective.aggregate_fn(k, False)(v)
                        for k, v in zip(self._perf_names, all_perfs_by_name)
                    }
                    self.logger.info("valid %03d %e %f %f %s", step, objs.avg, top1.avg, top5.avg,
                                     "; ".join(["{}: {:.3f}".format(perf_n, v) \
                                                # for perf_n, v in objective_perfs.avgs().items()]))

                                                for perf_n, v in obj_perfs.items()]))
        all_perfs_by_name = list(zip(*all_perfs))
        obj_perfs = {
            k: self.objective.aggregate_fn(k, False)(v)
            for k, v in zip(self._perf_names, all_perfs_by_name)
        }
        return top1.avg, objs.avg, obj_perfs
Example #7
def train_listwise(train_data, model, epoch, args, arch_network_type):
    objs = utils.AverageMeter()
    model.train()
    num_data = len(train_data)
    idx_list = np.arange(num_data)
    num_batches = getattr(
        args, "num_batch_per_epoch",
        int(num_data / (args.batch_size * args.list_length) *
            args.max_compare_ratio))
    logging.info("Number of batches: {:d}".format(num_batches))
    update_batch_n = getattr(args, "update_batch_n", 1)
    listwise_compare = getattr(args, "listwise_compare", False)
    if listwise_compare:
        assert args.list_length == 2 and update_batch_n == 1
    model.optimizer.zero_grad()
    for step in range(1, num_batches + 1):
        if getattr(args, "bs_replace", False):
            idxes = np.array([
                np.random.choice(idx_list,
                                 size=(args.list_length, ),
                                 replace=False) for _ in range(args.batch_size)
            ])
        else:
            idxes = np.random.choice(idx_list,
                                     size=(args.batch_size, args.list_length),
                                     replace=False)
        flat_idxes = idxes.reshape(-1)
        archs, accs, _ = zip(*[train_data[idx] for idx in flat_idxes])
        archs = np.array(archs).reshape(
            (args.batch_size, args.list_length, -1))
        accs = np.array(accs).reshape((args.batch_size, args.list_length))
        # accs[np.arange(0, args.batch_size)[:, None], np.argsort(accs, axis=1)[:, ::-1]]
        if update_batch_n == 1:
            if listwise_compare:
                loss = model.update_compare(archs[:, 0, :], archs[:, 1, :],
                                            accs[:, 1] > accs[:, 0])
            else:
                loss = model.update_argsort(archs,
                                            np.argsort(accs, axis=1)[:, ::-1],
                                            first_n=getattr(
                                                args, "score_list_length",
                                                None))
        else:
            loss = model.update_argsort(archs,
                                        np.argsort(accs, axis=1)[:, ::-1],
                                        first_n=getattr(
                                            args, "score_list_length", None),
                                        accumulate_only=True)
            if step % update_batch_n == 0:
                model.optimizer.step()
                model.optimizer.zero_grad()
        if arch_network_type != "random_forest":
            objs.update(loss, args.batch_size)
        if step % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                epoch, step, num_batches, objs.avg))
    return objs.avg
Example #8
def train_epoch(logger, train_loader, model, epoch, cfg):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()
    for step, (archs, accs) in enumerate(train_loader):
        archs = np.array(archs)
        accs = np.array(accs)
        n = len(archs)
        if cfg["compare"]:
            n_max_pairs = int(cfg["max_compare_ratio"] * n)
            acc_diff = np.array(accs)[:, None] - np.array(accs)
            acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
            ex_thresh_inds = np.where(
                acc_abs_diff_matrix > cfg["compare_threshold"])
            ex_thresh_num = len(ex_thresh_inds[0])
            if ex_thresh_num > n_max_pairs:
                keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                             n_max_pairs,
                                             replace=False)
                ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                  ex_thresh_inds[1][keep_inds])
            archs_1, archs_2, better_lst = archs[ex_thresh_inds[1]], archs[ex_thresh_inds[0]], \
                                           (acc_diff > 0)[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            loss = model.update_compare(archs_1, archs_2, better_lst)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            objs.update(loss, n)
        if step % cfg["report_freq"] == 0:
            n_pair_per_batch = (cfg["batch_size"] *
                                (cfg["batch_size"] - 1)) // 2
            logger.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d})".format(
                    n_diff_pairs_meter.avg /
                    n_pair_per_batch, n_diff_pairs_meter.avg, n_pair_per_batch)
                if cfg["compare"] else ""))
    return objs.avg
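The compare branch above (and in Examples #13 and #15) turns a batch of accuracies into training pairs: every pair whose accuracy gap exceeds the threshold is kept, and the label records which side wins. A small worked illustration with plain NumPy, using a made-up accs array and a threshold of 0.005:

import numpy as np

accs = np.array([0.90, 0.92, 0.91])
acc_diff = accs[:, None] - accs                      # acc_diff[i, j] = accs[i] - accs[j]
acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)   # keep each unordered pair once (i < j)
ex_thresh_inds = np.where(acc_abs_diff_matrix > 0.005)
# ex_thresh_inds == (array([0, 0, 1]), array([1, 2, 2])): pairs (0,1), (0,2), (1,2)
better_lst = (acc_diff > 0)[ex_thresh_inds]
# better_lst == [False, False, True]; True means the row-index arch (passed as archs_2) is better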
Example #9
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        losses_obj = utils.OrderedStats()
        all_perfs = []
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, targets) in enumerate(valid_queue):
                inputs = inputs.to(device)
                # targets = targets.to(device)

                predictions = model.forward(inputs)
                losses = criterion(inputs, predictions, targets, model)
                prec1, prec5 = self._acc_func(inputs, predictions, targets,
                                              model)
                perfs = self._perf_func(inputs, predictions, targets, model)
                all_perfs.append(perfs)
                objective_perfs.update(dict(zip(self._perf_names, perfs)))
                losses_obj.update(losses)
                n = inputs.size(0)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.report_every == 0:
                    self.logger.info(
                        "valid %03d %.2f%%; %.2f%%; %s", step, top1.avg, top5.avg,
                        "; ".join([
                            "{}: {:.3f}".format(perf_n, v) for perf_n, v in \
                            list(objective_perfs.avgs().items()) + \
                            list(losses_obj.avgs().items())]))
        all_perfs = list(zip(*all_perfs))
        obj_perfs = {
            k: self.objective.aggregate_fn(k, False)(v)
            for k, v in zip(self._perf_names, all_perfs)
        }
        return top1.avg, sum(losses_obj.avgs().values()), obj_perfs
Example #10
    def evaluate_epoch(self, data, targets, bptt_steps):
        expect(self._is_setup, "trainer.setup should be called first")
        batch_size = data.shape[1]
        self.model.eval()
        objs = utils.AverageMeter()
        hiddens = self.model.init_hidden(batch_size)
        for i in range(0, data.size(0), bptt_steps):
            seq_len = min(bptt_steps, len(data) - i)
            inp, targ = data[i:i + seq_len], targets[i:i + seq_len]
            logits, _, _, hiddens = self.parallel_model(inp, hiddens)
            objs.update(self._criterion(logits.view(-1, logits.size(-1)),
                                        targ.view(-1)).item(),
                        seq_len)
        return objs.avg
Example #11
    def infer_epoch(self, valid_queue, model, criterion, device):
        expect(self._is_setup, "trainer.setup should be called first")
        cls_objs = utils.AverageMeter()
        loc_objs = utils.AverageMeter()
        top1 = utils.AverageMeter()
        top5 = utils.AverageMeter()
        objective_perfs = utils.OrderedStats()
        model.eval()

        context = torch.no_grad if self.eval_no_grad else nullcontext
        with context():
            for step, (inputs, targets) in enumerate(valid_queue):
                inputs = inputs.to(device)
                # targets = targets.to(device)

                predictions = model.forward(inputs)
                classification_loss, regression_loss = criterion(
                    inputs, predictions, targets, model)

                prec1, prec5 = self._acc_func(inputs, predictions, targets, model)

                perfs = self._perf_func(inputs, predictions, targets, model)
                objective_perfs.update(dict(zip(self._perf_names, perfs)))
                n = inputs.size(0)
                cls_objs.update(classification_loss.item(), n)
                loc_objs.update(regression_loss.item(), n)
                top1.update(prec1.item(), n)
                top5.update(prec5.item(), n)

                if step % self.report_every == 0:
                    self.logger.info("valid %03d %e %e;  %.2f%%; %.2f%%; %s",
                                     step, cls_objs.avg, loc_objs.avg, top1.avg, top5.avg,
                                     "; ".join(["{}: {:.3f}".format(perf_n, v) \
                                                for perf_n, v in objective_perfs.avgs().items()]))
        stats = self.dataset.evaluate_detections(self.objective.all_boxes, self.eval_dir)
        self.logger.info("mAP: {}".format(stats[0]))
        return top1.avg, cls_objs.avg + loc_objs.avg, objective_perfs.avgs()
Example #12
def train_multi_stage_pair_pool(all_stages, pairs_list, model, i_epoch, args):
    objs = utils.AverageMeter()
    model.train()

    # try get through all the pairs
    pairs_pool = list(
        zip(*[np.concatenate(items) for items in zip(*pairs_list)]))
    num_pairs = len(pairs_pool)
    logging.info("Number of pairs: {}".format(num_pairs))
    np.random.shuffle(pairs_pool)
    num_batch = num_pairs // args.batch_size

    for i_batch in range(num_batch):
        archs_1_inds, archs_2_inds, better_lst = list(
            zip(*pairs_pool[i_batch * args.batch_size:(i_batch + 1) *
                            args.batch_size]))
        loss = model.update_compare(
            np.array([all_stages[idx][0] for idx in archs_1_inds]),
            np.array([all_stages[idx][0] for idx in archs_2_inds]), better_lst)
        objs.update(loss, args.batch_size)
        if i_batch % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                i_epoch, i_batch, num_batch, objs.avg))
    return objs.avg
Example #13
def train(train_loader, model, epoch, args, arch_network_type):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()
    for step, (archs, f_accs, h_accs) in enumerate(train_loader):
        archs = np.array(archs)
        h_accs = np.array(h_accs)
        f_accs = np.array(f_accs)
        n = len(archs)
        if getattr(args, "use_half", False):
            accs = h_accs
        else:
            accs = f_accs
        if args.compare:
            if None in f_accs:
                # some archs only have half-time acc
                n_max_pairs = int(args.max_compare_ratio * n)
                n_max_inter_pairs = int(args.inter_pair_ratio * n_max_pairs)
                half_inds = np.array(
                    [ind for ind, acc in enumerate(accs) if acc is None])
                mask = np.zeros(n)
                mask[half_inds] = 1
                final_inds = np.where(1 - mask)[0]

                half_eche = h_accs[half_inds]
                final_eche = h_accs[final_inds]
                half_acc_diff = final_eche[:,
                                           None] - half_eche  # (num_final, num_half)
                assert (half_acc_diff >= 0).all()  # should be >0
                half_ex_thresh_inds = np.where(
                    np.abs(half_acc_diff) >
                    getattr(args, "half_compare_threshold", 2 *
                            args.compare_threshold))
                half_ex_thresh_num = len(half_ex_thresh_inds[0])
                if half_ex_thresh_num > n_max_inter_pairs:
                    # random choose
                    keep_inds = np.random.choice(np.arange(half_ex_thresh_num),
                                                 n_max_inter_pairs,
                                                 replace=False)
                    half_ex_thresh_inds = (half_ex_thresh_inds[0][keep_inds],
                                           half_ex_thresh_inds[1][keep_inds])
                inter_archs_1, inter_archs_2, inter_better_lst \
                    = archs[half_inds[half_ex_thresh_inds[1]]], archs[final_inds[half_ex_thresh_inds[0]]], \
                    (half_acc_diff > 0)[half_ex_thresh_inds]
                n_inter_pairs = len(inter_better_lst)

                # only use intra pairs in the final echelon
                n_intra_pairs = n_max_pairs - n_inter_pairs
                accs = np.array(accs)[final_inds]
                archs = archs[final_inds]
                acc_diff = np.array(accs)[:, None] - np.array(accs)
                acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                ex_thresh_inds = np.where(
                    acc_abs_diff_matrix > args.compare_threshold)
                ex_thresh_num = len(ex_thresh_inds[0])
                if ex_thresh_num > n_intra_pairs:
                    if args.choose_pair_criterion == "diff":
                        keep_inds = np.argpartition(
                            acc_abs_diff_matrix[ex_thresh_inds],
                            -n_intra_pairs)[-n_intra_pairs:]
                    elif args.choose_pair_criterion == "random":
                        keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                                     n_intra_pairs,
                                                     replace=False)
                    ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                      ex_thresh_inds[1][keep_inds])
                archs_1, archs_2, better_lst = archs[ex_thresh_inds[1]], archs[
                    ex_thresh_inds[0]], (acc_diff > 0)[ex_thresh_inds]
                archs_1, archs_2, better_lst = np.concatenate((inter_archs_1, archs_1)),\
                                               np.concatenate((inter_archs_2, archs_2)),\
                                               np.concatenate((inter_better_lst, better_lst))
            else:
                if getattr(args, "compare_split", False):
                    n_pairs = len(archs) // 2
                    accs = np.array(accs)
                    acc_diff_lst = accs[n_pairs:2 * n_pairs] - accs[:n_pairs]
                    keep_inds = np.where(
                        np.abs(acc_diff_lst) > args.compare_threshold)[0]
                    better_lst = (np.array(accs[n_pairs:2 * n_pairs] -
                                           accs[:n_pairs]) > 0)[keep_inds]
                    archs_1 = np.array(archs[:n_pairs])[keep_inds]
                    archs_2 = np.array(archs[n_pairs:2 * n_pairs])[keep_inds]
                else:
                    n_max_pairs = int(args.max_compare_ratio * n)
                    acc_diff = np.array(accs)[:, None] - np.array(accs)
                    acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                    ex_thresh_inds = np.where(
                        acc_abs_diff_matrix > args.compare_threshold)
                    ex_thresh_num = len(ex_thresh_inds[0])
                    if ex_thresh_num > n_max_pairs:
                        if args.choose_pair_criterion == "diff":
                            keep_inds = np.argpartition(
                                acc_abs_diff_matrix[ex_thresh_inds],
                                -n_max_pairs)[-n_max_pairs:]
                        elif args.choose_pair_criterion == "random":
                            keep_inds = np.random.choice(
                                np.arange(ex_thresh_num),
                                n_max_pairs,
                                replace=False)
                        ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                          ex_thresh_inds[1][keep_inds])
                    archs_1, archs_2, better_lst = archs[
                        ex_thresh_inds[1]], archs[ex_thresh_inds[0]], (
                            acc_diff > 0)[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            loss = model.update_compare(archs_1, archs_2, better_lst)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            if arch_network_type != "random_forest":
                objs.update(loss, n)
        if step % args.report_freq == 0:
            n_pair_per_batch = (args.batch_size * (args.batch_size - 1)) // 2
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d})".format(
                    n_diff_pairs_meter.avg /
                    n_pair_per_batch, n_diff_pairs_meter.avg, n_pair_per_batch)
                if args.compare else ""))
    return objs.avg
Example #14
def sample_batchify(search_space, model, ratio, K, args, conflict_archs=None):
    model.eval()
    inner_sample_n = args.sample_batchify_inner_sample_n
    ss = search_space
    assert K % inner_sample_n == 0
    num_iter = K // inner_sample_n
    want_samples_per_iter = int(ratio * inner_sample_n)
    logging.info(
        "Sample {}. REPEAT {}: Sample {} archs based on the predicted score across {} archs"
        .format(K, num_iter, inner_sample_n, want_samples_per_iter))
    sampled_rollouts = []
    sampled_scores = []
    # the number, mean and max predicted scores of current sampled archs
    cur_sampled_mean_max = (0, 0, 0)
    i_iter = 1
    # num_steps = (ratio * K + args.batch_size - 1) // args.batch_size
    _r_cls = ss.random_sample().__class__
    conflict_rollouts = [
        _r_cls(arch, info={}, search_space=search_space)
        for arch in conflict_archs or []
    ]
    inner_report_freq = 10
    judget_conflict = False
    while i_iter <= num_iter:
        # # random init
        # if self.inner_iter_random_init \
        #    and hasattr(self.inner_controller, "reinit"):
        #     self.inner_controller.reinit()

        new_per_step_meter = utils.AverageMeter()

        # a list with length self.inner_sample_n
        best_rollouts = []
        best_scores = []
        num_to_sample = inner_sample_n
        iter_r_set = []
        iter_s_set = []
        sampled_r_set = sampled_rollouts
        # for i_inner in range(1, num_steps+1):
        i_inner = 0
        while new_per_step_meter.sum < want_samples_per_iter:
            i_inner += 1
            rollouts = [
                search_space.random_sample() for _ in range(args.batch_size)
            ]
            batch_archs = [r.arch for r in rollouts]
            step_scores = list(model.predict(batch_archs).cpu().data.numpy())
            if judget_conflict:
                new_inds, new_rollouts = zip(
                    *[(i, r) for i, r in enumerate(rollouts)
                      if r not in conflict_rollouts and r not in sampled_r_set
                      and r not in iter_r_set])
                new_step_scores = [step_scores[i] for i in new_inds]
                iter_r_set += new_rollouts
                iter_s_set += new_step_scores
            else:
                new_rollouts = rollouts
                new_step_scores = step_scores
            new_per_step_meter.update(len(new_rollouts))
            best_rollouts += new_rollouts
            best_scores += new_step_scores
            # iter_r_set += rollouts
            # iter_s_set += step_scores

            if len(best_scores) > num_to_sample:
                keep_inds = np.argpartition(best_scores,
                                            -num_to_sample)[-num_to_sample:]
                best_rollouts = [best_rollouts[ind] for ind in keep_inds]
                best_scores = [best_scores[ind] for ind in keep_inds]
            if i_inner % inner_report_freq == 0:
                logging.info(
                    (
                        "Seen %d/%d Iter %d (to sample %d) (already sampled %d mean %.5f, best %.5f); "
                        "Step %d: sample %d step mean %.5f best %.5f: {} "
                        # "(iter mean %.5f, best %.5f).
                        "AVG new/step: %.3f").format(", ".join(
                            ["{:.5f}".format(s) for s in best_scores])),
                    new_per_step_meter.sum,
                    want_samples_per_iter,
                    i_iter,
                    num_to_sample,
                    cur_sampled_mean_max[0],
                    cur_sampled_mean_max[1],
                    cur_sampled_mean_max[2],
                    i_inner,
                    len(rollouts),
                    np.mean(step_scores),
                    np.max(step_scores),
                    #np.mean(iter_s_set), np.max(iter_s_set),
                    new_per_step_meter.avg)
        # if new_per_step_meter.sum < num_to_sample * 10:
        #         # rerun this iter, also reinit!
        #         self.logger.info("Cannot find %d (num_to_sample x min_inner_sample_ratio)"
        #                          " (%d x %d) new rollouts in one run of the inner controller"
        #                          "Re-init the controller and re-run this iteration.",
        #                          num_to_sample * self.min_inner_sample_ratio,
        #                          num_to_sample, self.min_inner_sample_ratio)
        #         continue

        i_iter += 1
        assert len(best_scores) == num_to_sample
        sampled_rollouts += best_rollouts
        sampled_scores += best_scores
        cur_sampled_mean_max = (len(sampled_scores), np.mean(sampled_scores),
                                np.max(sampled_scores))

    return [r.genotype for r in sampled_rollouts]
Example #15
def train(train_loader, model, epoch, args):
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    n_eq_pairs_meter = utils.AverageMeter()
    model.train()

    margin_diff_coeff = getattr(args, "margin_diff_coeff", None)
    eq_threshold = getattr(args, "eq_threshold", None)
    eq_pair_ratio = getattr(args, "eq_pair_ratio", 0)
    if eq_threshold is not None:
        assert eq_pair_ratio > 0
        assert eq_threshold <= args.compare_threshold
    for step, (archs, all_accs) in enumerate(train_loader):
        archs = np.array(archs)
        n = len(archs)
        use_checkpoint = getattr(args, "use_checkpoint", 3)
        accs = all_accs[:, use_checkpoint]
        if args.compare:
            if getattr(args, "compare_split", False):
                n_pairs = len(archs) // 2
                accs = np.array(accs)
                acc_diff_lst = accs[n_pairs:2 * n_pairs] - accs[:n_pairs]
                keep_inds = np.where(
                    np.abs(acc_diff_lst) > args.compare_threshold)[0]
                better_lst = (np.array(accs[n_pairs:2 * n_pairs] -
                                       accs[:n_pairs]) > 0)[keep_inds]
                archs_1 = np.array(archs[:n_pairs])[keep_inds]
                archs_2 = np.array(archs[n_pairs:2 * n_pairs])[keep_inds]
            else:
                n_max_pairs = int(args.max_compare_ratio * n *
                                  (1 - eq_pair_ratio))
                acc_diff = np.array(accs)[:, None] - np.array(accs)
                acc_abs_diff_matrix = np.triu(np.abs(acc_diff), 1)
                ex_thresh_inds = np.where(
                    acc_abs_diff_matrix > args.compare_threshold)
                ex_thresh_num = len(ex_thresh_inds[0])
                if ex_thresh_num > n_max_pairs:
                    if args.choose_pair_criterion == "diff":
                        keep_inds = np.argpartition(
                            acc_abs_diff_matrix[ex_thresh_inds],
                            -n_max_pairs)[-n_max_pairs:]
                    elif args.choose_pair_criterion == "random":
                        keep_inds = np.random.choice(np.arange(ex_thresh_num),
                                                     n_max_pairs,
                                                     replace=False)
                    ex_thresh_inds = (ex_thresh_inds[0][keep_inds],
                                      ex_thresh_inds[1][keep_inds])
                archs_1, archs_2, better_lst, acc_diff_lst = archs[
                    ex_thresh_inds[1]], archs[ex_thresh_inds[0]], (
                        acc_diff > 0)[ex_thresh_inds], acc_diff[ex_thresh_inds]
            n_diff_pairs = len(better_lst)
            n_diff_pairs_meter.update(float(n_diff_pairs))
            if eq_threshold is None:
                if margin_diff_coeff is not None:
                    margin = np.abs(acc_diff_lst) * margin_diff_coeff
                    loss = model.update_compare(archs_1,
                                                archs_2,
                                                better_lst,
                                                margin=margin)
                else:
                    loss = model.update_compare(archs_1, archs_2, better_lst)
            else:
                # drag close the score of arch pairs whose true acc diffs are below args.eq_threshold
                n_eq_pairs = int(args.max_compare_ratio * n * eq_pair_ratio)
                below_eq_thresh_inds = np.where(
                    acc_abs_diff_matrix < eq_threshold)
                below_eq_thresh_num = len(below_eq_thresh_inds[0])
                if below_eq_thresh_num > n_eq_pairs:
                    keep_inds = np.random.choice(
                        np.arange(below_eq_thresh_num),
                        n_eq_pairs,
                        replace=False)
                    below_eq_thresh_inds = (below_eq_thresh_inds[0][keep_inds],
                                            below_eq_thresh_inds[1][keep_inds])
                eq_archs_1, eq_archs_2, below_acc_diff_lst = \
                    archs[below_eq_thresh_inds[1]], archs[below_eq_thresh_inds[0]], acc_abs_diff_matrix[below_eq_thresh_inds]
                if margin_diff_coeff is not None:
                    margin = np.concatenate(
                        (np.abs(acc_diff_lst),
                         np.abs(below_acc_diff_lst))) * margin_diff_coeff
                else:
                    margin = None
                better_pm_lst = np.concatenate(
                    (2 * better_lst - 1, np.zeros(len(eq_archs_1))))
                n_eq_pairs_meter.update(float(len(eq_archs_1)))
                loss = model.update_compare_eq(np.concatenate(
                    (archs_1, eq_archs_1)),
                                               np.concatenate(
                                                   (archs_2, eq_archs_2)),
                                               better_pm_lst,
                                               margin=margin)
            objs.update(loss, n_diff_pairs)
        else:
            loss = model.update_predict(archs, accs)
            objs.update(loss, n)
        if step % args.report_freq == 0:
            n_pair_per_batch = (args.batch_size * (args.batch_size - 1)) // 2
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}; {}".format(
                epoch, step, len(train_loader), objs.avg,
                "different pair ratio: {:.3f} ({:.1f}/{:3d}){}".format(
                    n_diff_pairs_meter.avg /
                    n_pair_per_batch, n_diff_pairs_meter.avg, n_pair_per_batch,
                    "; eq pairs: {.3d}".format(n_eq_pairs_meter.avg) if
                    eq_threshold is not None else "") if args.compare else ""))
    return objs.avg
Example #16
def train_multi_stage_listwise(train_stages,
                               model,
                               epoch,
                               args,
                               avg_stage_scores,
                               stage_epochs,
                               score_train_stages=None):
    # TODO: multi stage
    objs = utils.AverageMeter()
    n_listlength_meter = utils.AverageMeter()
    model.train()

    num_stages = len(train_stages)

    stage_lens = [len(stage_data) for stage_data in train_stages]
    stage_sep_inds = [np.arange(stage_len) for stage_len in stage_lens]
    sample_acc_temp = getattr(args, "sample_acc_temp", None)
    if sample_acc_temp is not None:
        stage_sep_probs = []
        for i_stage, stage_data in enumerate(train_stages):
            perfs = np.array([
                item[1][stage_epochs[i_stage]]
                for item in train_stages[i_stage]
            ])
            perfs = perfs / sample_acc_temp
            exp_perfs = np.exp(perfs - np.max(perfs))
            stage_sep_probs.append(exp_perfs / exp_perfs.sum())
    else:
        stage_sep_probs = None
    stage_single_probs = getattr(args, "stage_single_probs", None)
    assert stage_single_probs is not None
    if stage_single_probs is not None:
        stage_probs = np.array([
            single_prob * len_
            for single_prob, len_ in zip(stage_single_probs, stage_lens)
        ])
        stage_probs = stage_probs / stage_probs.sum()
    logging.info("Epoch {:d}: Stage probs {}".format(epoch, stage_probs))

    num_stage_samples_avg = np.zeros(num_stages)
    train_stages = np.array(train_stages)

    listwise_compare = getattr(args, "listwise_compare", False)
    if listwise_compare:
        assert args.list_length == 2

    for step in range(args.num_batch_per_epoch):
        num_stage_samples = np.random.multinomial(args.list_length,
                                                  stage_probs)
        num_stage_samples = np.minimum(num_stage_samples, stage_lens)
        true_ll = np.sum(num_stage_samples)
        n_listlength_meter.update(true_ll, args.batch_size)
        num_stage_samples_avg += num_stage_samples
        stage_inds = [
            np.array([
                np.random.choice(stage_sep_inds[i_stage],
                                 size=(sz),
                                 replace=False,
                                 p=None if stage_sep_probs is None else
                                 stage_sep_probs[i_stage])
                for _ in range(args.batch_size)
            ]) if sz > 0 else np.zeros((args.batch_size, 0), dtype=int)
            for i_stage, sz in enumerate(num_stage_samples)
        ]
        sorted_stage_inds = [
            s_stage_inds[
                np.arange(args.batch_size)[:, None],
                np.argsort(np.array(
                    np.array(train_stages[i_stage])[s_stage_inds][:, :, 1].
                    tolist())[:, :, stage_epochs[i_stage]],
                           axis=1)]
            if s_stage_inds.shape[1] > 1 else s_stage_inds
            for i_stage, s_stage_inds in enumerate(stage_inds)
        ]
        archs = np.concatenate([
            np.array(train_stages[i_stage])[s_stage_inds][:, :, 0]
            for i_stage, s_stage_inds in enumerate(sorted_stage_inds)
            if s_stage_inds.size > 0
        ],
                               axis=1)
        archs = archs[:, ::-1]  # order: from best to worst
        assert archs.ndim == 2
        archs = np.array(archs.tolist(
        ))  # (batch_size, list_length, num_cell_groups, node_or_op, decisions)
        if listwise_compare:
            loss = model.update_compare(archs[:, 0], archs[:, 1],
                                        np.zeros(archs.shape[0]))
        else:
            loss = model.update_argsort(archs,
                                        idxes=None,
                                        first_n=getattr(
                                            args, "score_list_length", None),
                                        is_sorted=True)
        objs.update(loss, args.batch_size)
        if step % args.report_freq == 0:
            logging.info(
                "train {:03d} [{:03d}/{:03d}] {:.4f} (mean ll: {:.1f}; {})".
                format(epoch, step, args.num_batch_per_epoch, objs.avg,
                       n_listlength_meter.avg,
                       (num_stage_samples_avg / (step + 1)).tolist()))
    return objs.avg
Example #17
def train_multi_stage(train_stages, model, epoch, args, avg_stage_scores,
                      stage_epochs):
    # TODO: multi stage
    objs = utils.AverageMeter()
    n_diff_pairs_meter = utils.AverageMeter()
    model.train()

    num_stages = len(train_stages)
    # must specify `stage_probs` or `stage_prob_power`
    stage_probs = getattr(args, "stage_probs", None)
    if stage_probs is None:
        stage_probs = _cal_stage_probs(avg_stage_scores, args.stage_prob_power)
    stage_accept_pair_probs = getattr(args, "stage_accept_pair_probs",
                                      [1.0] * num_stages)

    stage_lens = [len(stage_data) for stage_data in train_stages]
    for i, len_ in enumerate(stage_lens):
        if len_ == 0:
            n_j = num_stages - i - 1
            for j in range(i + 1, num_stages):
                stage_probs[j] += stage_probs[i] / float(n_j)
            stage_probs[i] = 0
    # diff_threshold = getattr(args, "diff_threshold", [0.08, 0.04, 0.02, 0.0])
    stage_single_probs = getattr(args, "stage_single_probs", None)
    if stage_single_probs is not None:
        stage_probs = np.array([
            single_prob * len_
            for single_prob, len_ in zip(stage_single_probs, stage_lens)
        ])
        stage_probs = stage_probs / stage_probs.sum()
    logging.info("Epoch {:d}: Stage probs {}".format(epoch, stage_probs))

    diff_threshold = args.diff_threshold
    for step in range(args.num_batch_per_epoch):
        pair_batch = []
        i_pair = 0
        while 1:
            stage_1, stage_2 = np.random.choice(np.arange(num_stages),
                                                size=2,
                                                p=stage_probs)
            d_1 = train_stages[stage_1][np.random.randint(
                0, stage_lens[stage_1])]
            d_2 = train_stages[stage_2][np.random.randint(
                0, stage_lens[stage_2])]
            min_stage = min(stage_2, stage_1)
            if np.random.rand() > stage_accept_pair_probs[min_stage]:
                continue
            # max_stage = stage_2 + stage_1 - min_stage
            # if max_stage - min_stage >= 2:
            #     better = stage_2 > stage_1
            # else:
            min_epoch = stage_epochs[min_stage]
            diff_21 = d_2[1][min_epoch] - d_1[1][min_epoch]
            # print(stage_1, stage_2, diff_21, diff_threshold)
            if np.abs(diff_21) > diff_threshold[min_stage]:
                # if the difference is larger than the threshold of the min stage, this pair count
                better = diff_21 > 0
            else:
                continue
            pair_batch.append((d_1[0], d_2[0], better))
            i_pair += 1
            if i_pair == args.batch_size:
                break
        archs_1, archs_2, better_lst = zip(*pair_batch)
        n_diff_pairs = len(better_lst)
        n_diff_pairs_meter.update(float(n_diff_pairs))
        loss = model.update_compare(archs_1, archs_2, better_lst)
        objs.update(loss, n_diff_pairs)
        if step % args.report_freq == 0:
            logging.info("train {:03d} [{:03d}/{:03d}] {:.4f}".format(
                epoch, step, args.num_batch_per_epoch, objs.avg))
    return objs.avg
Example #18
    def train(self):  #pylint: disable=too-many-branches
        assert self.is_setup, "Must call `trainer.setup` method before calling `trainer.train`."

        if self.interleave_controller_every is not None:
            inter_steps = self.controller_steps
            evaluator_steps = self.interleave_controller_every
            controller_steps = 1
        else:
            inter_steps = 1
            evaluator_steps = self.evaluator_steps
            controller_steps = self.controller_steps

        for epoch in range(self.last_epoch + 1, self.epochs + 1):
            c_loss_meter = utils.AverageMeter()
            rollout_stat_meters = utils.OrderedStats()  # rollout performance stats from evaluator
            c_stat_meters = utils.OrderedStats()  # other stats from controller
            eva_stat_meters = utils.OrderedStats()  # other stats from `evaluator.update_evaluator`

            self.epoch = epoch  # this is redundant as Component.on_epoch_start also sets this

            # call `on_epoch_start` of sub-components
            # also schedule values and optimizer learning rates
            self.on_epoch_start(epoch)

            finished_e_steps = 0
            finished_c_steps = 0
            for i_inter in range(1, inter_steps + 1):  # interleave mepa/controller training
                # meta parameter training
                if evaluator_steps > 0:
                    e_stats = self._evaluator_update(evaluator_steps,
                                                     finished_e_steps,
                                                     finished_c_steps)
                    eva_stat_meters.update(e_stats)
                    finished_e_steps += evaluator_steps

                if epoch >= self.controller_train_begin and \
                   epoch % self.controller_train_every == 0 and controller_steps > 0:
                    # controller training
                    c_loss, rollout_stats, c_stats \
                        = self._controller_update(controller_steps,
                                                  finished_e_steps, finished_c_steps)
                    # update meters
                    if c_loss is not None:
                        c_loss_meter.update(c_loss)
                    if rollout_stats is not None:
                        rollout_stat_meters.update(rollout_stats)
                    if c_stats is not None:
                        c_stat_meters.update(c_stats)

                    finished_c_steps += controller_steps

                if self.interleave_report_every and i_inter % self.interleave_report_every == 0:
                    # log for every `interleave_report_every` interleaving steps
                    self.logger.info("(inter step %3d): "
                                     "evaluator (%3d/%3d) %s ; "
                                     "controller (%3d/%3d) %s",
                                     i_inter, finished_e_steps, self.evaluator_steps,
                                     "; ".join(
                                         ["{}: {:.3f}".format(n, v) \
                                          for n, v in eva_stat_meters.avgs().items()]),
                                     finished_c_steps, self.controller_steps,
                                     "" if not rollout_stat_meters else "; ".join(
                                         ["{}: {:.3f}".format(n, v) \
                                          for n, v in rollout_stat_meters.avgs().items()]))

            # log information for this epoch
            if eva_stat_meters:
                self.logger.info("Epoch %3d: [evaluator update] %s", epoch,
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in eva_stat_meters.avgs().items()]))
            if rollout_stat_meters:
                self.logger.info("Epoch %3d: [controller update] controller loss: %.3f ; "
                                 "rollout performance: %s", epoch, c_loss_meter.avg,
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in rollout_stat_meters.avgs().items()]))
            if c_stat_meters:
                self.logger.info("[controller stats] %s", \
                                 "; ".join(["{}: {:.3f}".format(n, v) \
                                            for n, v in c_stat_meters.avgs().items()]))

            # maybe write tensorboard info
            if not self.writer.is_none():
                if eva_stat_meters:
                    for n, meter in eva_stat_meters.items():
                        self.writer.add_scalar(
                            "evaluator_update/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if rollout_stat_meters:
                    for n, meter in rollout_stat_meters.items():
                        self.writer.add_scalar(
                            "controller_update/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if c_stat_meters:
                    for n, meter in c_stat_meters.items():
                        self.writer.add_scalar(
                            "controller_stats/{}".format(n.replace(" ", "-")),
                            meter.avg, epoch)
                if not c_loss_meter.is_empty():
                    self.writer.add_scalar("controller_loss", c_loss_meter.avg,
                                           epoch)

            # maybe save checkpoints
            self.maybe_save()

            # maybe derive archs and test
            if self.test_every and self.epoch % self.test_every == 0:
                self.test()

            self.on_epoch_end(epoch)  # call `on_epoch_end` of sub-components

        # `final_save` pickle dump the weights_manager and controller directly,
        # instead of the state dict
        self.final_save()
Example #19
    def train_epoch(self, data, targets, bptt_steps):
        expect(self._is_setup, "trainer.setup should be called first")
        batch_size = data.shape[1]
        num_total_steps = data.shape[0]
        self.model.train()
        objs = utils.AverageMeter()
        losses = utils.AverageMeter()

        hiddens = self.model.init_hidden(batch_size)

        if self.random_bptt:
            # random sequence lengths
            seq_lens = []
            i = 0
            while i < data.size(0):
                mean_ = bptt_steps if np.random.random() < 0.95 else bptt_steps / 2
                seq_len = min(max(5, int(np.random.normal(mean_, 5))), bptt_steps + 20)
                seq_lens.append(seq_len)
                i += seq_len
            seq_lens[-1] -= i - data.size(0)
            num_total_batches = len(seq_lens)
        else:
            # fixed sequence length == bptt_steps
            num_total_batches = int(np.ceil(data.size(0) / bptt_steps))
            seq_lens = [bptt_steps] * num_total_batches
            seq_lens[-1] = num_total_steps - bptt_steps * (num_total_batches-1)

        lr_bak = self.optimizer.param_groups[0]["lr"]
        i = 0
        for batch in range(1, num_total_batches+1):
            seq_len = seq_lens[batch-1]
            inp, targ = data[i:i+seq_len], targets[i:i+seq_len]

            # linear adjusting learning rate
            self.optimizer.param_groups[0]["lr"] = lr_bak * seq_len / bptt_steps
            self.optimizer.zero_grad()

            logits, raw_outs, outs, hiddens = self.parallel_model(inp, hiddens)

            raw_loss = self._criterion(logits.view(-1, logits.size(-1)), targ.view(-1))

            loss = raw_loss
            # Activation Regularization
            if self.rnn_act_reg > 0:
                loss = loss + self.rnn_act_reg * outs.pow(2).mean()
            # Temporal Activation Regularization (slowness)
            if self.rnn_slowness_reg > 0:
                loss = loss + self.rnn_slowness_reg * (raw_outs[1:] - raw_outs[:-1]).pow(2).mean()

            loss.backward()
            # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
            self.model.step_current_gradients(self.optimizer)

            objs.update(raw_loss.item(), seq_len)
            losses.update(loss.item(), seq_len)

            # del logits, raw_outs, outs, raw_loss, loss

            i += seq_len
            if batch % self.report_every == 0:
                self.logger.info("train %3d/%3d: perp %.3f; loss %.3f; loss(with reg) %.3f",
                                 batch, num_total_batches, np.exp(objs.avg), objs.avg, losses.avg)

        self.optimizer.param_groups[0]["lr"] = lr_bak
        return objs.avg, losses.avg
Example #20
    def sample(self, n=1, batch_size=1):
        """Sample architectures based on the current predictor"""

        if self.mode == "eval":
            # return the best n rollouts that are evaluated by the ground-truth evaluator
            self.logger.info(
                "Return the best {} rollouts in the population".format(n))
            all_gt_arch_scores = sum(self.gt_arch_scores, [])
            all_rollouts = sum(self.gt_rollouts, [])
            best_inds = np.argpartition(
                [item[1] for item in all_gt_arch_scores], -n)[-n:]
            # all_rollouts, all_scores = zip(
            #     *[(r, r.get_perf("reward")) for rs in self.gt_rollouts for r in rs])
            # best_inds = np.argpartition(all_scores, -n)[-n:]
            return [all_rollouts[ind] for ind in best_inds]

        if not self.is_predictor_trained:
            # if predictor is not trained, random sample from search space
            return [self.search_space.random_sample() for _ in range(n)]

        if n % self.inner_sample_n != 0:
            self.logger.warn(
                "samle number %d cannot be divided by inner_sample_n %d", n,
                self.inner_sample_n)

        # the arch rollouts that have already been evaluated; avoid sampling them again
        already_evaled_r_set = sum(self.gt_rollouts, [])
        # nb101 (~420k archs) and nb201 (~15k archs) are small search spaces; a full predictor forward pass takes 1~2 min at most
        if self.inner_enumerate_search_space:
            if self.inner_enumerate_sample_ratio is not None:
                assert n % self.inner_sample_n == 0

            max_num = None if self.inner_enumerate_sample_ratio is None \
                      else n * self.inner_enumerate_sample_ratio
            iter_ = self.search_space.batch_rollouts(
                batch_size=self.predict_batch_size,
                shuffle=True,
                max_num=max_num)
            scores = []
            all_rollouts = []
            num_ignore = 0
            for rollouts in iter_:
                # remove the rollouts that are already evaluated
                ori_len_ = len(rollouts)
                rollouts = [
                    rollout for rollout in rollouts
                    if rollout not in already_evaled_r_set
                ]
                num_ignore += ori_len_ - len(rollouts)
                all_rollouts = all_rollouts + self._predict_rollouts(rollouts)
                scores = scores + [i.perf["predicted_score"] for i in rollouts]

            if self.inner_sample_n is not None:
                num_iters = n // self.inner_sample_n
                rs_per_s = len(scores) // num_iters
                scores = np.array(scores)[:rs_per_s * num_iters]
                inds = np.argpartition(scores.reshape([num_iters, rs_per_s]),
                                       -self.inner_sample_n,
                                       axis=1)[:, -self.inner_sample_n:]
                # inds: (num_iters, self.inner_sample_n)
                best_inds = (
                    inds +
                    rs_per_s * np.arange(num_iters)[:, None]).reshape(-1)
                self.logger.info(
                    "Random sample %d archs (max num %d), ignore %d already evaled archs, "
                    "and choose %d archs per %d archs with highest predict scores",
                    len(scores), max_num, num_ignore, self.inner_sample_n,
                    rs_per_s)
            else:
                # finally: rank and take the top-n archs (cf. the `sample` function in train_cellss_pkl.py)
                best_inds = np.argpartition(scores, -n)[-n:]
                self.logger.info(
                    "Random sample %d archs (max num %d), ignore %d already evaled archs, "
                    "and choose %d archs with highest predict scores",
                    len(scores), max_num, num_ignore, n)
            return [all_rollouts[i] for i in best_inds]

        # if self.inner_controller_reinit:
        self.inner_controller = BaseController.get_class_(
            self.inner_controller_type)(self.search_space,
                                        self.device,
                                        mode=self.mode,
                                        rollout_type=self.rollout_type,
                                        **self.inner_controller_cfg)
        if hasattr(self.inner_controller, "set_init_population"):
            self.logger.info(
                "re-evaluating %d rollouts using the current predictor",
                self.num_gt_rollouts)
            # set the init population of the inner controller
            # re-evaluate rollouts using the current predictor
            for rollouts in self.gt_rollouts:
                rollouts = self._predict_rollouts(rollouts)

            if not self.inner_random_init:
                self.inner_controller.set_init_population(
                    sum(self.gt_rollouts, []), perf_name="predicted_score")

        # inner_sample_n: how many archs to sample every iter
        num_iter = (n + self.inner_sample_n - 1) // self.inner_sample_n
        sampled_rollouts = []
        sampled_scores = []
        # the number, mean and max predicted scores of current sampled archs
        cur_sampled_mean_max = (0, 0, 0)
        i_iter = 1
        while i_iter <= num_iter:
            # for i_iter in range(1, num_iter+1):
            # random init
            if self.inner_iter_random_init \
               and hasattr(self.inner_controller, "reinit"):
                if i_iter > 1:
                    # gt rollouts might have been used as the init population (when `inner_random_init=false`),
                    # so, do not call reinit when i_iter == 1
                    if (not isinstance(self.inner_iter_random_init, int)) or \
                       self.inner_iter_random_init == 1 or \
                       i_iter % self.inner_iter_random_init == 1:
                        # if `inner_iter_random_init` is an integer,
                        # only reinit every `inner_iter_random_init` iterations.
                        # `inner_iter_random_init==True` is the same as `inner_iter_random_init==1`,
                        # and means that every iter (besides iter 1) would call `reinit`
                        self.inner_controller.reinit()

            new_per_step_meter = utils.AverageMeter()

            # a list with length self.inner_sample_n
            best_rollouts = []
            best_scores = []
            num_to_sample = min(n - (i_iter - 1) * self.inner_sample_n,
                                self.inner_sample_n)
            iter_r_set = []
            iter_s_set = []
            sampled_r_set = sampled_rollouts
            for i_inner in range(1, self.inner_steps + 1):
                # self.inner_controller.on_epoch_begin(i_inner)
                # while 1:
                #     rollouts = self.inner_controller.sample(self.inner_samples)
                #     # remove the duplicate rollouts
                #     # *fixme* FIXME: local minimum problem exists!
                #     # random sample is one way, or do not use the best as the init?
                #     # Add a test to test the whole dataset...
                #     # ground-truth evaluated, decided rollouts
                #     # rollouts = [r for r in rollouts
                #     #             if r not in already_evaled_r_set \
                #     #             and r not in sampled_r_set]
                #     # and r not in iter_r_set

                #     if not rollouts:
                #         print("all conflict, resample")
                #         continue
                #     else:
                #         # print("sampled {}".format(i_inner))
                #         break
                rollouts = self.inner_controller.sample(self.inner_samples)
                rollouts = self._predict_rollouts(rollouts)
                self.inner_controller.step(rollouts,
                                           self.inner_cont_optimizer,
                                           perf_name="predicted_score")

                # keep the `num_to_sample` archs with the highest scores (top-k via np.argpartition; see the sketch after this example)
                step_scores = [
                    r.get_perf(name="predicted_score") for r in rollouts
                ]
                new_rollouts = [r for r in rollouts
                                if r not in already_evaled_r_set \
                                and r not in sampled_r_set
                                and r not in iter_r_set]
                new_step_scores = [
                    r.get_perf(name="predicted_score") for r in new_rollouts
                ]
                new_per_step_meter.update(len(new_rollouts))
                best_rollouts += new_rollouts
                best_scores += new_step_scores
                iter_r_set += rollouts
                iter_s_set += step_scores

                if len(best_scores) > num_to_sample:
                    keep_inds = np.argpartition(
                        best_scores, -num_to_sample)[-num_to_sample:]
                    best_rollouts = [best_rollouts[ind] for ind in keep_inds]
                    best_scores = [best_scores[ind] for ind in keep_inds]
                if i_inner % self.inner_report_freq == 0:
                    self.logger.info(
                        ("Iter %d (to sample %d) (already sampled %d mean %.5f, best %.5f); "
                         "Step %d: sample %d step mean %.5f best %.5f: {} "
                         "(iter mean %.5f, best %.5f). AVG new/step: %.3f").format(
                             ", ".join(["{:.5f}".format(s) for s in best_scores])),
                        i_iter, num_to_sample, cur_sampled_mean_max[0],
                        cur_sampled_mean_max[1], cur_sampled_mean_max[2],
                        i_inner, len(rollouts), np.mean(step_scores), np.max(step_scores),
                        np.mean(iter_s_set), np.max(iter_s_set), new_per_step_meter.avg)
            if new_per_step_meter.sum < num_to_sample * self.min_inner_sample_ratio:
                # rerun this iter, also reinit!
                self.logger.info(
                    "Cannot find %d (num_to_sample x min_inner_sample_ratio)"
                    " (%d x %d) new rollouts in one run of the inner controller"
                    "Re-init the controller and re-run this iteration.",
                    num_to_sample * self.min_inner_sample_ratio, num_to_sample,
                    self.min_inner_sample_ratio)
                continue

            i_iter += 1
            assert len(best_scores) == num_to_sample
            sampled_rollouts += best_rollouts
            sampled_scores += best_scores
            cur_sampled_mean_max = (len(sampled_scores),
                                    np.mean(sampled_scores),
                                    np.max(sampled_scores))

        return sampled_rollouts
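
The `sample` method above relies on `np.argpartition` in two ways: a flat top-`n` selection (the eval-mode branch and the non-ratio branch), and a per-slice top-`inner_sample_n` selection where the flat score list is reshaped into `(num_iters, rs_per_s)` and the per-row indices are shifted back into the flat array. A small standalone illustration of both, with made-up scores and shapes:

import numpy as np

rng = np.random.RandomState(0)

# 1) Flat top-k without a full sort, as in the eval-mode branch.
scores = rng.rand(10)
k = 3
best_inds = np.argpartition(scores, -k)[-k:]  # indices of the k largest scores (unordered)
assert set(best_inds) == set(np.argsort(scores)[-k:])

# 2) Per-slice top-k, as in the enumerate branch: reshape into
#    (num_iters, rs_per_s), take the top `inner_sample_n` per row,
#    then offset each row's indices by rs_per_s * row to index the flat array.
num_iters, rs_per_s, inner_sample_n = 4, 5, 2
flat_scores = rng.rand(num_iters * rs_per_s)
inds = np.argpartition(flat_scores.reshape(num_iters, rs_per_s),
                       -inner_sample_n, axis=1)[:, -inner_sample_n:]
best_inds = (inds + rs_per_s * np.arange(num_iters)[:, None]).reshape(-1)
print(best_inds)  # 8 flat indices, 2 chosen from every block of 5

Using argpartition instead of a full sort keeps the selection linear in the number of scored rollouts, which matters when the enumerate branch scores a large fraction of the search space per call.
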