Example #1
0
 def before_train(self, logs=None):
     """Be called before the training process."""
     trainer = self.trainer
     # Cache trainer components locally on the callback.
     self.config = trainer.config
     self.unrolled = trainer.config.unrolled
     self.device = trainer.config.device
     self.model = trainer.model
     self.optimizer = trainer.optimizer
     self.lr_scheduler = trainer.lr_scheduler
     self.loss = trainer.loss
     # Build the architecture-search algorithm and attach the supernet.
     self.search_alg = SearchAlgorithm(SearchSpace())
     self._set_algorithm_model(self.model)
     trainer.train_loader = trainer._init_dataloader(mode='train')
     trainer.valid_loader = trainer._init_dataloader(mode='val')

     def _const_tensor(length, fill, dtype):
         # Constant, non-trainable bookkeeping tensor on the GPU.
         return torch.tensor(length * [fill], requires_grad=False, dtype=dtype).cuda()

     num_normal = len(self.model.alphas_normal)
     num_reduce = len(self.model.alphas_reduce)
     # -1 presumably marks an edge with no operation selected yet, and True
     # marks an edge still in the candidate set — TODO confirm against
     # the search algorithm's edge_decision().
     normal_selected_idxs = _const_tensor(num_normal, -1, torch.int)
     reduce_selected_idxs = _const_tensor(num_reduce, -1, torch.int)
     normal_candidate_flags = _const_tensor(num_normal, True, torch.bool)
     reduce_candidate_flags = _const_tensor(num_reduce, True, torch.bool)
     logging.info('normal_selected_idxs: {}'.format(normal_selected_idxs))
     logging.info('reduce_selected_idxs: {}'.format(reduce_selected_idxs))
     logging.info('normal_candidate_flags: {}'.format(normal_candidate_flags))
     logging.info('reduce_candidate_flags: {}'.format(reduce_candidate_flags))
     self.model.normal_selected_idxs = normal_selected_idxs
     self.model.reduce_selected_idxs = reduce_selected_idxs
     self.model.normal_candidate_flags = normal_candidate_flags
     self.model.reduce_candidate_flags = reduce_candidate_flags
     # Log the initial softmax-normalized alpha distributions.
     logging.info(F.softmax(torch.stack(self.model.alphas_normal, dim=0), dim=-1).detach())
     logging.info(F.softmax(torch.stack(self.model.alphas_reduce, dim=0), dim=-1).detach())
     self.normal_probs_history = []
     self.reduce_probs_history = []
Example #2
0
 def __init__(self):
     """Initialize the generator with search space, algorithm and quota."""
     self.step_name = General.step_name
     self.search_space = SearchSpace()
     self.search_alg = SearchAlgorithm(self.search_space)
     # Only expose objective_keys when the algorithm config defines it,
     # so hasattr() checks elsewhere keep working.
     alg_config = self.search_alg.config
     if hasattr(alg_config, 'objective_keys'):
         self.objective_keys = alg_config.objective_keys
     self.quota = QuotaCompare('restrict')
     # Affinity filtering is optional and driven by the global quota config.
     affinity_cfg = General.quota.affinity
     if affinity_cfg.type is None:
         self.affinity = None
     else:
         self.affinity = QuotaAffinity(affinity_cfg)
Example #3
0
 def __init__(self):
     """Set up the search space/algorithm pair and report bookkeeping."""
     self.step_name = General.step_name
     self.search_space = SearchSpace()
     self.search_alg = SearchAlgorithm(self.search_space)
     self.report = Report()
     self.record = ReportRecord()
     self.record.step_name = self.step_name
     # objective_keys is optional on the algorithm config; copy it onto
     # the report record only when present.
     alg_config = self.search_alg.config
     if hasattr(alg_config, 'objective_keys'):
         self.record.objective_keys = alg_config.objective_keys
     self.quota = QuotaCompare('restrict')
Example #4
0
 def before_train(self, logs=None):
     """Be called before the training process."""
     trainer = self.trainer
     # Cache trainer components locally on the callback.
     self.config = trainer.config
     self.unrolled = trainer.config.unrolled
     self.device = trainer.config.device
     self.model = trainer.model
     self.optimizer = trainer.optimizer
     self.lr_scheduler = trainer.lr_scheduler
     self.loss = trainer.loss
     # Build the architecture-search algorithm and attach the model.
     self.search_alg = SearchAlgorithm(SearchSpace())
     self._set_algorithm_model(self.model)
     trainer.train_loader = trainer._init_dataloader(mode='train')
     trainer.valid_loader = trainer._init_dataloader(mode='val')
Example #5
0
 def before_train(self, logs=None):
     """Be called before the training process."""
     # Zero valid_interval suppresses the trainer's default validation step.
     self.trainer.valid_interval = 0
     self.trainer.config.report_on_epoch = True
     if vega.is_torch_backend():
         cudnn.benchmark = True
         cudnn.enabled = True
     self.search_alg = SearchAlgorithm(SearchSpace())
     self.alg_policy = self.search_alg.config.policy
     self.set_algorithm_model(self.trainer.model)
     # Sample one random architecture path per individual and stack the
     # paths row-wise into the population's alpha matrix.
     num_individual = self.alg_policy.num_individual
     sampled_paths = [self.search_alg.random_sample_path()
                      for _ in range(num_individual)]
     self.alphas = np.stack(sampled_paths, axis=0)
     self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
     self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')
Example #6
0
    def before_train(self, logs=None):
        """Be called before the training process.

        One-shot initialization: copies the modnas config, builds the search
        algorithm (unless in fully-train mode), wires the data loaders, then
        launches the estimator loop in a background daemon thread unless
        estimation is disabled.
        """
        if self.initialized:
            return
        self.initialized = True
        self.trainer_config = self.trainer.config
        # Deep-copy so later mutations don't leak into the trainer config.
        self.config = copy.deepcopy(self.trainer_config.modnas)
        self.model = self.trainer.model
        self.search_alg = None
        if not self.config.get('fully_train'):
            self.search_alg = SearchAlgorithm(SearchSpace())
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')
        self.init()
        if self.config.get('disable_estim'):
            # No estimation: detach the wrapped-trainer hooks and bail out.
            self.wrp_trainer.disable_cond('before_epoch')
            self.wrp_trainer.disable_cond('before_train_step')
            return

        def estim_runner():
            # Runs every configured estimator against the wrapped trainer.
            try:
                for estim in self.estims.values():
                    estim.set_trainer(self.wrp_trainer)
                    estim.config.epochs = estim.config.get(
                        'epochs', self.trainer_config.epochs)
                results = {}
                for estim_name, estim in self.estims.items():
                    logger.info('Running estim: {} type: {}'.format(
                        estim_name, estim.__class__.__name__))
                    self.wrp_trainer.wrap_loss(estim)
                    ret = estim.run(self.search_alg)
                    results[estim_name] = ret
                logger.info('All results: {{\n{}\n}}'.format('\n'.join(
                    ['{}: {}'.format(k, v) for k, v in results.items()])))
                # NOTE(review): `ret` is unbound if self.estims is empty —
                # presumably at least one estimator is always configured.
                results['final'] = ret
                self.estim_ret = results
            except Exception:
                # Best-effort: log the failure but still release the trainer below.
                traceback.print_exc()
            # try to release the trainer
            self.trainer.train_loader = []
            self.trainer.valid_loader = []
            self.wrp_trainer.notify_all()
            self.wrp_trainer.disable_cond('before_epoch')
            self.wrp_trainer.disable_cond('before_train_step')

        # start estim coroutine
        estim_th = threading.Thread(target=estim_runner)
        # Thread.setDaemon() is deprecated since Python 3.10; set the
        # attribute directly so the thread never blocks interpreter exit.
        estim_th.daemon = True
        estim_th.start()
        self.estim_th = estim_th
Example #7
0
class Generator(object):
    """Convert search space and search algorithm, sample a new model."""

    def __init__(self):
        # Bind the current pipeline step and build the search pair.
        self.step_name = General.step_name
        self.search_space = SearchSpace()
        self.search_alg = SearchAlgorithm(self.search_space)
        self.report = Report()
        self.record = ReportRecord()
        self.record.step_name = self.step_name
        # objective_keys is optional on the algorithm config; copy it onto
        # the report record only when present.
        if hasattr(self.search_alg.config, 'objective_keys'):
            self.record.objective_keys = self.search_alg.config.objective_keys
        self.quota = QuotaCompare('restrict')

    @property
    def is_completed(self):
        """Define a property to determine search algorithm is completed."""
        return self.search_alg.is_completed or self.quota.is_halted()

    def sample(self):
        """Sample a work id and model from search algorithm.

        :return: list of (worker_id, model_desc) pairs, or None if the
            algorithm produced nothing.
        """
        res = self.search_alg.search()
        if not res:
            return None
        # Normalize the result to a list of samples.
        if not isinstance(res, list):
            res = [res]
        if len(res) == 0:
            return None
        out = []
        for sample in res:
            # A sample is either a dict with a "desc" key or a
            # (worker_id, desc)-style sequence — TODO confirm with search_alg.
            desc = sample.get("desc") if isinstance(sample, dict) else sample[1]
            desc = self._decode_hps(desc)
            model_desc = deepcopy(desc)
            if "modules" in desc:
                # Full model description: publish it as-is.
                PipeStepConfig.model.model_desc = deepcopy(desc)
            elif "network" in desc:
                # Partial description: merge it into the configured model desc.
                origin_desc = PipeStepConfig.model.model_desc
                desc = update_dict(desc["network"], origin_desc)
                PipeStepConfig.model.model_desc = deepcopy(desc)
            # Drop samples rejected by the quota restriction.
            if self.quota.is_filtered(desc):
                continue
            record = self.record.from_sample(sample, desc)
            Report().broadcast(record)
            out.append((record.worker_id, model_desc))
        return out

    def update(self, step_name, worker_id):
        """Update search algorithm accord to the worker path.

        :param step_name: step name
        :param worker_id: current worker id
        :return:
        """
        report = Report()
        record = report.receive(step_name, worker_id)
        logging.debug("Get Record=%s", str(record))
        self.search_alg.update(record.serialize())
        report.dump_report(record.step_name, record)
        # Persist the generator so the search can be resumed later.
        self.dump()
        logging.info("Update Success. step_name=%s, worker_id=%s", step_name, worker_id)
        logging.info("Best values: %s", Report().print_best(step_name=General.step_name))

    @staticmethod
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert to `zeus.common.config import Config` object
        This Config will be override in Trainer or Datasets class
        The override priority is: input hps > user configuration >  default configuration
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            # Expand a dotted name into nested dicts, innermost key first.
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)

    def dump(self):
        """Dump generator to file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        with open(_file, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def restore(cls):
        """Restore generator from file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        if os.path.exists(_file):
            with open(_file, "rb") as f:
                # NOTE: pickle.load executes arbitrary code on load; only
                # restore from a trusted task directory.
                return pickle.load(f)
        else:
            return None
Example #8
0
class CARSTrainerCallback(Callback):
    """A special callback for CARSTrainer."""

    # Statistics callback conflicts with this trainer and is turned off.
    disable_callbacks = ["ModelStatistics"]

    def __init__(self):
        super(CARSTrainerCallback, self).__init__()
        # Evolution-search policy; filled in by before_train().
        self.alg_policy = None

    def before_train(self, logs=None):
        """Be called before the training process."""
        # Use zero valid_interval to supress default valid step
        self.trainer.valid_interval = 0
        self.trainer.config.report_on_epoch = True
        if vega.is_torch_backend():
            cudnn.benchmark = True
            cudnn.enabled = True
        self.search_alg = SearchAlgorithm(SearchSpace())
        self.alg_policy = self.search_alg.config.policy
        self.set_algorithm_model(self.trainer.model)
        # setup alphas
        # One random architecture path per individual, stacked row-wise.
        n_individual = self.alg_policy.num_individual
        self.alphas = np.stack([
            self.search_alg.random_sample_path() for i in range(n_individual)
        ],
                               axis=0)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoach."""
        # Remember the current epoch for the warmup checks below.
        self.epoch = epoch

    def train_step(self, batch):
        """Replace the default train_step function."""
        self.trainer.model.train()
        input, target = batch
        self.trainer.optimizer.zero_grad()
        alphas = torch.from_numpy(self.alphas).cuda()
        # Accumulate gradients over several sampled individuals per step.
        for j in range(self.alg_policy.num_individual_per_iter):
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                # Warmup: train on a fresh random path instead of the population.
                alpha = torch.from_numpy(
                    self.search_alg.random_sample_path()).cuda()
                # logits = self.trainer.model.forward_random(input)
            else:
                alpha = alphas[i]
            logits = self.trainer.model(input, alpha=alpha)
            loss = self.trainer.loss(logits, target)
            loss.backward(retain_graph=True)
            if self.epoch < self.alg_policy.warmup:
                # Only one sample per step during warmup.
                break
        nn.utils.clip_grad_norm_(self.trainer.model.parameters(),
                                 self.trainer.config.grad_clip)
        self.trainer.optimizer.step()
        return {
            'loss': loss.item(),
            'train_batch_output': logits,
            'lr': self.trainer.lr_scheduler.get_lr()
        }

    def model_fn(self, features, labels, mode):
        """Define cars model_fn used by TensorFlow Estimator.

        :param features: input features tensor(s)
        :param labels: ground-truth labels tensor
        :param mode: a tf.estimator.ModeKeys value
        :return: tf.estimator.EstimatorSpec
        """
        logging.info('Cars model function action')
        self.trainer.loss = Loss()()

        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.compat.v1.train.get_global_step()
            # Fractional epoch derived from the global step and loader length.
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.trainer.train_loader), tf.float32)
            self.trainer.optimizer = Optimizer()(
                distributed=self.trainer.distributed)
            self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
            self.trainer.lr_scheduler.step(epoch)
            self.trainer.model.training = True
            alphas = tf.convert_to_tensor(self.alphas)
            # Mirror of the torch train_step: accumulate gradients over
            # several sampled individuals, then apply them once.
            for j in range(self.alg_policy.num_individual_per_iter):
                i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
                if self.epoch < self.alg_policy.warmup:
                    alpha = tf.convert_to_tensor(
                        self.search_alg.random_sample_path())
                else:
                    alpha = alphas[i]
                logits = self.trainer.model(features, alpha=alpha)
                logits = tf.cast(logits, tf.float32)
                loss = self.trainer.loss(logits=logits, labels=labels)
                loss = self.trainer.optimizer.regularize_loss(loss)
                grads, vars = zip(
                    *self.trainer.optimizer.compute_gradients(loss))
                if j == 0:
                    accum_grads = [
                        tf.Variable(tf.zeros_like(grad), trainable=False)
                        for grad in grads
                    ]
                accum_grads = [
                    accum_grads[k] + grads[k] for k in range(len(grads))
                ]
                if self.epoch < self.alg_policy.warmup:
                    break
            clipped_grads, _ = tf.clip_by_global_norm(
                accum_grads, self.trainer.config.grad_clip)
            minimize_op = self.trainer.optimizer.apply_gradients(
                list(zip(clipped_grads, vars)), global_step)
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            train_op = tf.group(minimize_op, update_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            # Evaluate with the fixed validation alpha.
            alpha = tf.convert_to_tensor(self.trainer.valid_alpha)
            self.trainer.model.training = False
            logits = self.trainer.model(features, alpha=alpha)
            logits = tf.cast(logits, tf.float32)
            loss = self.trainer.loss(logits=logits, labels=labels)
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        # NOTE(review): `loss` is unbound when mode is PREDICT — presumably
        # this model_fn is only ever invoked for TRAIN/EVAL; confirm.
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        # Evolve the population of architectures for the next epoch.
        self.alphas = self.search_alg.search_evol_arch(epoch, self.alg_policy,
                                                       self.trainer,
                                                       self.alphas)

    def set_algorithm_model(self, model):
        """Set model to algorithm.

        :param model: network model
        :type model: torch.nn.Module
        """
        self.search_alg.set_model(model)
Example #9
0
class SGASTrainerCallback(DartsTrainerCallback):
    """A special callback for the SGAS trainer (DARTS variant with greedy edge decisions)."""

    # Checkpointing conflicts with this trainer and is turned off.
    disable_callbacks = ["ModelCheckpoint"]

    def before_train(self, logs=None):
        """Be called before the training process."""
        # Cache trainer components locally on the callback.
        self.config = self.trainer.config
        self.unrolled = self.trainer.config.unrolled
        self.device = self.trainer.config.device
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace())
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')
        # Per-edge bookkeeping; -1 presumably marks an edge with no selected
        # op yet and True marks an edge still in the candidate set — TODO
        # confirm against search_alg.edge_decision().
        normal_selected_idxs = torch.tensor(len(self.model.alphas_normal) * [-1],
                                            requires_grad=False, dtype=torch.int).cuda()
        reduce_selected_idxs = torch.tensor(len(self.model.alphas_reduce) * [-1],
                                            requires_grad=False, dtype=torch.int).cuda()
        normal_candidate_flags = torch.tensor(len(self.model.alphas_normal) * [True],
                                              requires_grad=False, dtype=torch.bool).cuda()
        reduce_candidate_flags = torch.tensor(len(self.model.alphas_reduce) * [True],
                                              requires_grad=False, dtype=torch.bool).cuda()
        logging.info('normal_selected_idxs: {}'.format(normal_selected_idxs))
        logging.info('reduce_selected_idxs: {}'.format(reduce_selected_idxs))
        logging.info('normal_candidate_flags: {}'.format(normal_candidate_flags))
        logging.info('reduce_candidate_flags: {}'.format(reduce_candidate_flags))
        self.model.normal_selected_idxs = normal_selected_idxs
        self.model.reduce_selected_idxs = reduce_selected_idxs
        self.model.normal_candidate_flags = normal_candidate_flags
        self.model.reduce_candidate_flags = reduce_candidate_flags
        # Log the initial softmax-normalized alpha distributions.
        logging.info(F.softmax(torch.stack(self.model.alphas_normal, dim=0), dim=-1).detach())
        logging.info(F.softmax(torch.stack(self.model.alphas_reduce, dim=0), dim=-1).detach())
        self.normal_probs_history = []
        self.reduce_probs_history = []

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoach."""
        if vega.is_torch_backend():
            self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training."""
        # Get current train batch directly from logs
        train_batch = logs['train_batch']
        train_input, train_target = train_batch
        # Prepare valid batch data by using valid loader from trainer
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except Exception:
            # Valid loader exhausted: restart it and take the next batch.
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input, valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        child_desc_temp = self.search_alg.codec.calc_genotype(self._get_arch_weights())
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        # Greedy edge decision for the normal cell; returns whether memory
        # was freed plus the updated selection/candidate state.
        normal_edge_decision = self.search_alg.edge_decision('normal',
                                                             self.model.alphas_normal,
                                                             self.model.normal_selected_idxs,
                                                             self.model.normal_candidate_flags,
                                                             self.normal_probs_history,
                                                             epoch)
        saved_memory_normal, self.model.normal_selected_idxs, self.model.normal_candidate_flags = normal_edge_decision
        # Same decision process for the reduce cell.
        reduce_edge_decision = self.search_alg.edge_decision('reduce',
                                                             self.model.alphas_reduce,
                                                             self.model.reduce_selected_idxs,
                                                             self.model.reduce_candidate_flags,
                                                             self.reduce_probs_history,
                                                             epoch)
        saved_memory_reduce, self.model.reduce_selected_idxs, self.model.reduce_candidate_flags = reduce_edge_decision
        if saved_memory_normal or saved_memory_reduce:
            # An edge was pruned; reclaim its GPU memory.
            torch.cuda.empty_cache()
        self._save_descript()
class DartsTrainerCallback(Callback):
    """A special callback for DartsTrainer."""

    # Checkpointing conflicts with this trainer and is turned off.
    disable_callbacks = ["ModelCheckpoint"]

    def before_train(self, logs=None):
        """Be called before the training process."""
        self.config = self.trainer.config
        self.unrolled = self.trainer.config.unrolled
        # NOTE(review): is_gpu_device() reads as a boolean check, yet its
        # value is used as the device here when it is not True — confirm
        # intended semantics (device id vs flag).
        self.device = vega.is_gpu_device() if vega.is_gpu_device(
        ) is not True else 0
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace())
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoach."""
        if vega.is_torch_backend():
            self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training."""
        # Get current train batch directly from logs
        train_batch = logs['train_batch']
        train_input, train_target = train_batch
        # Prepare valid batch data by using valid loader from trainer
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except Exception:
            # Valid loader exhausted: restart it and take the next batch.
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(
            self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input,
                              valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        child_desc_temp = self.search_alg.codec.calc_genotype(
            self._get_arch_weights())
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        self._save_descript()

    def after_train(self, logs=None):
        """Be called after Training."""
        self.trainer._backup()

    def _train_arch_step(self, train_input, train_target, valid_input,
                         valid_target):
        # Run one architecture-parameter update through the search algorithm.
        lr = self.lr_scheduler.get_lr()[0]
        self.search_alg.step(train_input, train_target, valid_input,
                             valid_target, lr, self.optimizer, self.loss,
                             self.unrolled)

    def _set_algorithm_model(self, model):
        # Hand the supernet model to the search algorithm.
        self.search_alg.set_model(model)

    def train_input_fn(self):
        """Input function for search."""
        def map_to_dict(td, vd):
            # Pair train/valid features and labels under named keys.
            return {
                'train': td[0],
                'valid': vd[0]
            }, {
                'train': td[1],
                'valid': vd[1]
            }

        dataset = tf.data.Dataset.zip((self.trainer.train_loader.input_fn(),
                                       self.trainer.valid_loader.input_fn()))
        dataset = dataset.map(lambda td, vd: map_to_dict(td, vd))
        # dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        return dataset

    def model_fn(self, features, labels, mode):
        """Darts model_fn used by TensorFlow Estimator.

        :param features: dict of train/valid features (TRAIN) or features tensor
        :param labels: dict of train/valid labels (TRAIN) or labels tensor
        :param mode: a tf.estimator.ModeKeys value
        :return: tf.estimator.EstimatorSpec
        """
        logging.info('Darts model function action')
        global_step = tf.compat.v1.train.get_global_step()
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            features, valid_features = features['train'], features['valid']
            labels, valid_labels = labels['train'], labels['valid']
            # update arch
            # Fractional epoch derived from the global step and loader length.
            epoch = tf.cast(global_step, tf.float32) / tf.cast(
                len(self.trainer.train_loader), tf.float32)
            self.trainer.optimizer = Optimizer()(
                distributed=self.trainer.distributed)
            self.trainer.lr_scheduler = LrScheduler()(self.trainer.optimizer)
            self.trainer.lr_scheduler.step(epoch)
            update_ops = tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.UPDATE_OPS)
            arch_minimize_op = self.search_alg.step(
                valid_x=valid_features,
                valid_y=valid_labels,
                lr=self.trainer.lr_scheduler.get_lr()[0])
            train_op = tf.group(arch_minimize_op, update_ops)
        self.model.training = mode == tf.estimator.ModeKeys.TRAIN
        logits = self.model(features)
        logits = tf.cast(logits, tf.float32)
        self.trainer.loss = Loss()()
        loss = self.trainer.loss(logits=logits, labels=labels)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # Chain the weight update after the architecture update.
            with tf.control_dependencies([train_op]):
                weight_ops = self.model.get_weight_ops()
                loss_scale = self.trainer.config.loss_scale if self.trainer.use_amp else 1
                train_op = self.trainer.optimizer.step(loss, loss_scale,
                                                       global_step, weight_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=loss,
                                          train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def _get_arch_weights(self):
        # Fetch architecture weights; for the TF backend, evaluate them from
        # the latest checkpoint in a fresh session.
        if vega.is_torch_backend():
            arch_weights = self.model.arch_weights
        elif vega.is_tf_backend():
            sess_config = self.trainer._init_session_config()
            with tf.compat.v1.Session(config=sess_config) as sess:
                # tf.reset_default_graph()
                checkpoint_file = tf.train.latest_checkpoint(
                    self.trainer.get_local_worker_path())
                saver = tf.train.import_meta_graph(
                    "{}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
                # initializer is necessary here
                sess.run(tf.global_variables_initializer())
                arch_weights = self.model.arch_weights
                arch_weights = [weight.eval() for weight in arch_weights]
        return arch_weights

    def _save_descript(self):
        """Save result descript."""
        template_file = self.config.darts_template_file
        genotypes = self.search_alg.codec.calc_genotype(
            self._get_arch_weights())
        # Built-in templates are referenced by placeholder names; anything
        # else is treated as a file path and copied into the worker dir.
        if template_file == "{default_darts_cifar10_template}":
            template = DartsNetworkTemplateConfig.cifar10
        elif template_file == "{default_darts_cifar100_template}":
            template = DartsNetworkTemplateConfig.cifar100
        elif template_file == "{default_darts_imagenet_template}":
            template = DartsNetworkTemplateConfig.imagenet
        else:
            dst = FileOps.join_path(self.trainer.get_local_worker_path(),
                                    os.path.basename(template_file))
            FileOps.copy_file(template_file, dst)
            template = Config(dst)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.config.codec = model_desc

    def _gen_model_desc(self, genotypes, template):
        # Fill the searched normal/reduce genotypes into a copy of the template.
        model_desc = deepcopy(template)
        model_desc.super_network.cells.normal.genotype = genotypes[0]
        model_desc.super_network.cells.reduce.genotype = genotypes[1]
        return model_desc
Example #11
0
class Generator(object):
    """Convert search space and search algorithm, sample a new model."""
    def __init__(self):
        # Bind the current pipeline step and build the search pair.
        self.step_name = General.step_name
        self.search_space = SearchSpace()
        self.search_alg = SearchAlgorithm(self.search_space)
        # objective_keys is optional; only set the attribute when present so
        # the hasattr() check in sample() still works.
        if hasattr(self.search_alg.config, 'objective_keys'):
            self.objective_keys = self.search_alg.config.objective_keys
        self.quota = QuotaCompare('restrict')
        # Affinity filtering is optional and driven by the global quota config.
        self.affinity = None if General.quota.affinity.type is None else QuotaAffinity(
            General.quota.affinity)

    @property
    def is_completed(self):
        """Define a property to determine search algorithm is completed."""
        return self.search_alg.is_completed or self.quota.is_halted()

    def sample(self):
        """Sample a work id and model from search algorithm.

        :return: list of (worker_id, desc, hps) tuples, or None if the
            algorithm produced nothing.
        """
        # Retry a few times in case every sampled desc gets filtered out.
        for _ in range(10):
            res = self.search_alg.search()
            if not res:
                return None
            # Normalize the result to a list of samples.
            if not isinstance(res, list):
                res = [res]
            if len(res) == 0:
                return None
            out = []
            for sample in res:
                if isinstance(sample, dict):
                    # Dict form carries worker_id/encoded_desc; everything
                    # left over becomes extra report kwargs.
                    id = sample["worker_id"]
                    desc = self._decode_hps(sample["encoded_desc"])
                    sample.pop("worker_id")
                    sample.pop("encoded_desc")
                    kwargs = sample
                    sample = _split_sample((id, desc))
                else:
                    kwargs = {}
                    sample = _split_sample(sample)
                if hasattr(self, "objective_keys") and self.objective_keys:
                    kwargs["objective_keys"] = self.objective_keys
                (id, desc, hps) = sample

                if "modules" in desc:
                    # Full model description: publish it as-is.
                    PipeStepConfig.model.model_desc = deepcopy(desc)
                elif "network" in desc:
                    # Merge the partial network desc into the configured
                    # model desc, then inline the merged result.
                    origin_desc = PipeStepConfig.model.model_desc
                    model_desc = update_dict(desc["network"], origin_desc)
                    PipeStepConfig.model.model_desc = model_desc
                    desc.pop('network')
                    desc.update(model_desc)

                # Drop samples rejected by quota or affinity rules.
                if self.quota.is_filtered(desc):
                    continue
                if self.affinity and not self.affinity.is_affinity(desc):
                    continue
                ReportClient().update(General.step_name,
                                      id,
                                      desc=desc,
                                      hps=hps,
                                      **kwargs)
                out.append((id, desc, hps))
            if out:
                break
        return out

    def update(self, step_name, worker_id):
        """Update search algorithm accord to the worker path.

        :param step_name: step name
        :param worker_id: current worker id
        :return:
        """
        record = ReportClient().get_record(step_name, worker_id)
        logging.debug("Get Record=%s", str(record))
        self.search_alg.update(record.serialize())
        # Persisting is best-effort: some members may not be picklable.
        try:
            self.dump()
        except TypeError:
            logging.warning(
                "The Generator contains object which can't be pickled.")
        logging.info(
            f"Update Success. step_name={step_name}, worker_id={worker_id}")
        logging.info("Best values: %s",
                     ReportServer().print_best(step_name=General.step_name))

    @staticmethod
    def _decode_hps(hps):
        """Decode hps: `trainer.optim.lr : 0.1` to dict format.

        And convert to `vega.common.config import Config` object
        This Config will be override in Trainer or Datasets class
        The override priority is: input hps > user configuration >  default configuration
        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            # Expand a dotted name into nested dicts, innermost key first.
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)

    def dump(self):
        """Dump generator to file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        with open(_file, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def restore(cls):
        """Restore generator from file."""
        step_path = TaskOps().step_path
        _file = os.path.join(step_path, ".generator")
        if os.path.exists(_file):
            with open(_file, "rb") as f:
                # NOTE: pickle.load executes arbitrary code on load; only
                # restore from a trusted task directory.
                return pickle.load(f)
        else:
            return None