Example #1
0
class CARSTrainerCallback(Callback):
    """A special callback for CARSTrainer.

    Trains a weight-sharing supernet on architecture encodings ("alphas")
    drawn from a population of ``num_individual`` candidates. After every
    epoch, it hands the population to ``search_alg.search_evol_arch``, which
    returns the evolved population. The callback provides a PyTorch path
    (``train_step``) and a TensorFlow Estimator path (``model_fn``).
    """

    # Names of callbacks this callback asks to have disabled; presumably
    # honoured by the trainer's callback list -- confirm against Callback.
    disable_callbacks = ["ModelStatistics"]

    def __init__(self):
        super(CARSTrainerCallback, self).__init__()
        # Search policy; filled in by before_train from the search-algorithm config.
        self.alg_policy = None

    def before_train(self, logs=None):
        """Be called before the training process.

        Turns off the trainer's periodic validation, builds the search
        algorithm and initialises ``self.alphas``. The population is a numpy
        array holding ``num_individual`` randomly sampled paths, stacked along
        axis 0.

        :param logs: unused
        """
        # Use zero valid_interval to suppress the default valid step.
        self.trainer.valid_interval = 0
        self.trainer.config.report_on_epoch = True
        if vega.is_torch_backend():
            cudnn.benchmark = True
            cudnn.enabled = True
        self.search_alg = SearchAlgorithm(SearchSpace().search_space)
        self.alg_policy = self.search_alg.config.policy
        self.set_algorithm_model(self.trainer.model)
        # setup alphas: one randomly sampled path per individual
        n_individual = self.alg_policy.num_individual
        self.alphas = np.stack([self.search_alg.random_sample_path()
                                for i in range(n_individual)], axis=0)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch.

        Records ``epoch``. ``train_step`` and ``model_fn`` compare it with
        ``alg_policy.warmup`` to decide whether warmup is active.
        """
        self.epoch = epoch

    def train_step(self, batch):
        """Replace the default train_step function.

        Runs up to ``num_individual_per_iter`` forward/backward passes, each on
        a randomly chosen individual of ``self.alphas``. ``zero_grad`` is
        called only once before the loop, so the gradients of all passes
        accumulate before a single clipped optimizer step. During warmup
        (``epoch < warmup``), each pass uses a freshly sampled random path
        instead, and the loop stops after the first pass.

        :param batch: ``(input, target)`` pair from the train loader
        :return: dict with the scalar loss and the logits of the last pass
        """
        self.trainer.model.train()
        input, target = batch
        self.trainer.optimizer.zero_grad()
        alphas = torch.from_numpy(self.alphas).cuda()
        for j in range(self.alg_policy.num_individual_per_iter):
            # Drawn on every pass, including warmup where it is not used.
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                alpha = torch.from_numpy(self.search_alg.random_sample_path()).cuda()
            else:
                alpha = alphas[i]
            logits = self.trainer.model(input, alpha)
            loss = self.trainer.loss(logits, target)
            # retain_graph=True: autograd buffers are kept after this backward call.
            loss.backward(retain_graph=True)
            if self.epoch < self.alg_policy.warmup:
                break
        # NOTE(review): if num_individual_per_iter is 0 the loop never runs and
        # `loss` / `logits` are unbound below.
        nn.utils.clip_grad_norm_(
            self.trainer.model.parameters(), self.trainer.config.grad_clip)
        self.trainer.optimizer.step()
        return {'loss': loss.item(),
                'train_batch_output': logits}

    def model_fn(self, features, labels, mode):
        """Define cars model_fn used by TensorFlow Estimator.

        TRAIN: sums the gradients of up to ``num_individual_per_iter`` sampled
        individuals (a single random path during warmup). It then clips them
        by global norm to ``config.grad_clip`` and groups the apply op with
        the UPDATE_OPS collection.
        EVAL: evaluates the single encoding ``self.trainer.valid_alpha`` with
        ``self.trainer.valid_metrics``.

        ``self.epoch`` and ``np.random.randint`` are evaluated in Python while
        the graph is being built. The warmup decision and the chosen
        individuals are therefore fixed into the graph at construction time.

        :param features: input batch supplied by the Estimator
        :param labels: label batch supplied by the Estimator
        :param mode: a ``tf.estimator.ModeKeys`` value
        :return: a ``tf.estimator.EstimatorSpec``
        """
        logging.info('Cars model function action')
        self.trainer.loss = Loss()()

        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            # global_step / len(train_loader): presumably a fractional epoch
            # (assumes len(train_loader) is steps per epoch -- confirm).
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.trainer.train_loader), tf.float32)
            self.trainer.lr_scheduler = LrScheduler()()
            self.trainer.optimizer = Optimizer()(lr_scheduler=self.trainer.lr_scheduler,
                                                 epoch=epoch,
                                                 distributed=self.trainer.distributed)
            alphas = tf.convert_to_tensor(self.alphas)
            for j in range(self.alg_policy.num_individual_per_iter):
                i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
                if self.epoch < self.alg_policy.warmup:
                    alpha = tf.convert_to_tensor(self.search_alg.random_sample_path())
                else:
                    alpha = alphas[i]
                logits = self.trainer.model(features, alpha, True)
                logits = tf.cast(logits, tf.float32)
                loss = self.trainer.loss(logits=logits, labels=labels)
                grads, vars = zip(*self.trainer.optimizer.compute_gradients(loss))
                if j == 0:
                    # Zero-initialised, non-trainable accumulators created on the first pass.
                    accum_grads = [tf.Variable(tf.zeros_like(grad), trainable=False) for grad in grads]
                # Sum gradients over passes; `vars` of the last pass pairs them with variables below.
                accum_grads = [accum_grads[k] + grads[k] for k in range(len(grads))]
                if self.epoch < self.alg_policy.warmup:
                    break
            clipped_grads, _ = tf.clip_by_global_norm(accum_grads, self.trainer.config.grad_clip)
            minimize_op = self.trainer.optimizer.apply_gradients(list(zip(clipped_grads, vars)), global_step)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = tf.group(minimize_op, update_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            alpha = tf.convert_to_tensor(self.trainer.valid_alpha)
            logits = self.trainer.model(features, alpha, False)
            logits = tf.cast(logits, tf.float32)
            loss = self.trainer.loss(logits=logits, labels=labels)
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        # NOTE(review): `loss` is only assigned in TRAIN and EVAL mode; any other
        # mode (e.g. PREDICT) reaches this line with `loss` unbound.
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch.

        Replaces the population with the one returned by
        ``search_alg.search_evol_arch``.
        """
        self.alphas = self.search_alg.search_evol_arch(epoch, self.alg_policy, self.trainer, self.alphas)

    def set_algorithm_model(self, model):
        """Set model to algorithm.

        :param model: network model
        :type model: torch.nn.Module
        """
        self.search_alg.set_model(model)
class DartsTrainerCallback(Callback):
    """A special callback for DartsTrainer.

    On the torch backend, it runs one architecture step of the search
    algorithm before every training batch. The step is fed the current train
    batch and the next batch of the validation loader.

    After every epoch, it logs the current genotype and stores the resulting
    model description in ``trainer.config.codec``. A TensorFlow Estimator
    path is provided by ``train_input_fn`` and ``model_fn``.
    """

    def before_train(self, logs=None):
        """Be called before the training process.

        Caches the trainer's config, model, optimizer, lr scheduler and loss.
        It then builds the search algorithm and (re)initialises the train and
        valid loaders.
        """
        self.config = self.trainer.config
        self.unrolled = self.trainer.config.unrolled
        self.device = self.trainer.config.device
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace().search_space)
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch: restart the validation iterator (torch only)."""
        if vega.is_torch_backend():
            self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training.

        Pairs the current train batch (``logs['train_batch']``) with the next
        validation batch and runs one architecture step on them. The
        validation iterator is restarted when it is exhausted.
        """
        # Get the current train batch directly from logs.
        train_input, train_target = logs['train_batch']
        # Prepare a valid batch. Only exhaustion restarts the loader; any
        # other data-loading error propagates instead of being swallowed.
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except StopIteration:
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input, valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch: log the current genotype and save the model description."""
        genotypes = self.search_alg.codec.calc_genotype(self._get_arch_weights())
        logging.info('normal = %s', genotypes[0])
        logging.info('reduce = %s', genotypes[1])
        # Reuse the genotypes just computed: on TF, reading the arch weights
        # restores a checkpoint, so doing it twice per epoch is wasted work.
        self._save_descript(genotypes)

    def after_train(self, logs=None):
        """Be called after Training: back up the trainer output via ``trainer._backup()``."""
        self.trainer._backup()

    def _train_arch_step(self, train_input, train_target, valid_input, valid_target):
        """Run one architecture step with the scheduler's current learning rate."""
        lr = self.lr_scheduler.get_lr()[0]
        self.search_alg.step(train_input, train_target, valid_input, valid_target,
                             lr, self.optimizer, self.loss, self.unrolled)

    def _set_algorithm_model(self, model):
        """Hand the supernet to the search algorithm."""
        self.search_alg.set_model(model)

    def train_input_fn(self):
        """Input function for search.

        Zips the train and valid datasets so each element is
        ``({'train': x, 'valid': x}, {'train': y, 'valid': y})``, as consumed
        by ``model_fn``.
        """

        def map_to_dict(td, vd):
            return {'train': td[0], 'valid': vd[0]}, {'train': td[1], 'valid': vd[1]}

        dataset = tf.data.Dataset.zip((self.trainer.train_loader.input_fn(),
                                       self.trainer.valid_loader.input_fn()))
        # tf.data unpacks each zipped (td, vd) element into positional arguments.
        dataset = dataset.map(map_to_dict)
        dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        return dataset

    def model_fn(self, features, labels, mode):
        """Darts model_fn used by TensorFlow Estimator.

        In TRAIN mode, ``features`` and ``labels`` are the dicts produced by
        ``train_input_fn``. An architecture step is built on the ``'valid'``
        part, and the weight-update op (``trainer._init_minimize_op`` over
        ``model.get_weight_ops()``) is made to depend on it. In EVAL mode, the
        trainer's ``valid_metrics`` are attached.

        :param features: input batch (dict in TRAIN mode)
        :param labels: label batch (dict in TRAIN mode)
        :param mode: a ``tf.estimator.ModeKeys`` value
        :return: a ``tf.estimator.EstimatorSpec``
        """
        logging.info('Darts model function action')
        global_step = tf.train.get_global_step()
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            features, valid_features = features['train'], features['valid']
            labels, valid_labels = labels['train'], labels['valid']
            # update arch
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.trainer.train_loader), tf.float32)
            self.trainer.lr_scheduler = LrScheduler()()
            self.trainer.optimizer = Optimizer()(lr_scheduler=self.trainer.lr_scheduler,
                                                 epoch=epoch,
                                                 distributed=self.trainer.distributed)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            arch_minimize_op = self.search_alg.step(valid_x=valid_features,
                                                    valid_y=valid_labels,
                                                    lr=self.trainer.lr_scheduler.get_lr()[0])
            train_op = tf.group(arch_minimize_op, update_ops)

        logits = self.model(features, mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)
        self.trainer.loss = Loss()()
        loss = self.trainer.loss(logits=logits, labels=labels)

        if mode == tf.estimator.ModeKeys.TRAIN:
            # The weight update runs only after the architecture update op.
            with tf.control_dependencies([train_op]):
                weight_ops = self.model.get_weight_ops()
                train_op = self.trainer._init_minimize_op(loss, global_step, weight_ops)

        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)

        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def _get_arch_weights(self):
        """Return the current architecture weights.

        Torch backend: ``self.model.arch_weights`` as-is.
        TF backend: restores the latest checkpoint of the local worker in a
        new session and returns the evaluated value of every arch weight.

        :raises ValueError: if the backend is neither torch nor tf
        """
        if vega.is_torch_backend():
            return self.model.arch_weights
        if vega.is_tf_backend():
            sess_config = self.trainer._init_session_config()
            with tf.Session(config=sess_config) as sess:
                # NOTE(review): import_meta_graph imports into the current default
                # graph, so every call adds nodes to it.
                checkpoint_file = tf.train.latest_checkpoint(self.trainer.get_local_worker_path())
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
                return [weight.eval() for weight in self.model.arch_weights]
        raise ValueError('Unsupported backend: DARTS arch weights can only be read on torch or tf.')

    def _save_descript(self, genotypes=None):
        """Save result descript.

        Builds a model description from the darts template and the genotypes,
        and stores it in ``self.trainer.config.codec``. A custom template file
        is first copied into the local worker path.

        :param genotypes: ``(normal, reduce)`` genotypes; computed from the
            current arch weights when omitted
        """
        template_file = self.config.darts_template_file
        if genotypes is None:
            genotypes = self.search_alg.codec.calc_genotype(self._get_arch_weights())
        if template_file == "{default_darts_cifar10_template}":
            template = DartsNetworkTemplateConfig.cifar10
        elif template_file == "{default_darts_imagenet_template}":
            template = DartsNetworkTemplateConfig.imagenet
        else:
            dst = FileOps.join_path(self.trainer.get_local_worker_path(), os.path.basename(template_file))
            FileOps.copy_file(template_file, dst)
            template = Config(dst)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.config.codec = model_desc

    def _gen_model_desc(self, genotypes, template):
        """Return a deep copy of ``template`` with the normal/reduce genotypes filled in."""
        model_desc = deepcopy(template)
        model_desc.super_network.normal.genotype = genotypes[0]
        model_desc.super_network.reduce.genotype = genotypes[1]
        return model_desc
Example #3
0
class CARSTrainerCallback(Callback):
    """A special callback for CARSTrainer.

    Trains a weight-sharing supernet on a population of architecture
    encodings (``self.alphas``, one row per individual). On GA epochs it
    evolves that population: mutation/cross-over, then NSGA-style selection
    on (validation accuracy, model size).
    """

    def __init__(self):
        super(CARSTrainerCallback, self).__init__()
        self.alg_policy = ClassFactory.__configs__.search_algorithm.policy

    def before_train(self, epoch, logs=None):
        """Be called before the training process.

        Disables the trainer's automatic checkpoint/performance saving and its
        periodic validation. It then builds the search algorithm and
        initialises the population with ``num_individual`` random single
        paths, stacked along dim 0.
        """
        # Use zero valid_freq to suppress the default valid step.
        self.trainer.auto_save_ckpt = False
        self.trainer.auto_save_perf = False
        self.trainer.valid_freq = 0
        cudnn.benchmark = True
        cudnn.enabled = True
        self.search_alg = SearchAlgorithm(SearchSpace())
        self.set_algorithm_model(self.trainer.model)
        # setup alphas: one random single path per individual
        n_individual = self.alg_policy.num_individual
        self.alphas = torch.cat(
            [self.trainer.model.random_single_path().unsqueeze(0)
             for _ in range(n_individual)],
            dim=0)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch: record it and step the lr scheduler once."""
        self.epoch = epoch
        self.trainer.lr_scheduler.step()

    def train_step(self, batch):
        """Replace the default train_step function.

        Runs up to ``num_individual_per_iter`` forward/backward passes on
        randomly chosen individuals. ``zero_grad`` is called once, so the
        gradients accumulate before a single clipped optimizer step. During
        warmup (``epoch < warmup``), one pass of ``forward_random`` is used
        instead.

        :param batch: ``(input, target)`` pair from the train loader
        :return: dict with the scalar loss and the logits of the last pass
        """
        self.trainer.model.train()
        inputs, target = batch
        self.trainer.optimizer.zero_grad()
        for _ in range(self.alg_policy.num_individual_per_iter):
            # Drawn on every pass, including warmup where it is not used.
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                logits = self.trainer.model.forward_random(inputs)
            else:
                logits = self.trainer.model(inputs, self.alphas[i])
            loss = self.trainer.loss(logits, target)
            loss.backward(retain_graph=True)
            if self.epoch < self.alg_policy.warmup:
                break
        # clip_grad_norm_ replaces the deprecated clip_grad_norm (same semantics).
        nn.utils.clip_grad_norm_(self.trainer.model.parameters(),
                                 self.trainer.cfg.grad_clip)
        self.trainer.optimizer.step()
        return {'loss': loss.item(), 'train_batch_output': logits}

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch: run the evolutionary update if due."""
        self.search_evol_arch(epoch)

    def set_algorithm_model(self, model):
        """Set model to algorithm.

        :param model: network model
        :type model: torch.nn.Module
        """
        self.search_alg.set_model(model)

    def search_evol_arch(self, epoch):
        """Update architectures.

        Does nothing unless ``epoch >= start_ga_epoch`` and
        ``epoch - start_ga_epoch`` is a multiple of ``ga_interval``.
        Otherwise, it saves a supernet checkpoint and then, for each of
        ``num_generation`` generations:

        - appends offspring to the population;
        - scores every individual by validation accuracy and model size;
        - keeps ``num_individual`` of them with the configured NSGA variant;
        - saves the selected and kept genotypes.

        :param epoch: The current epoch
        :type epoch: int
        :raises ValueError: if ``alg_policy.nsga_method`` is unsupported
        """
        policy = self.alg_policy
        if epoch < policy.start_ga_epoch or \
                (epoch - policy.start_ga_epoch) % policy.ga_interval != 0:
            return
        self.save_model_checkpoint(self.trainer.model,
                                   'weights_{}.pt'.format(epoch))
        # Parents plus offspring: the number of individuals scored per generation.
        pop_size = int(policy.num_individual * (1 + policy.expand))
        ga_epoch = int((epoch - policy.start_ga_epoch) / policy.ga_interval)
        for _ in range(policy.num_generation):
            fitness = np.zeros(pop_size)
            model_sizes = np.zeros(pop_size)
            genotypes = []
            # generate offsprings using mutation and cross-over
            offsprings = self.search_alg.gen_offspring(self.alphas)
            self.alphas = torch.cat((self.alphas, offsprings), dim=0)
            # calculate fitness (accuracy) and #parameters
            for i in range(pop_size):
                fitness[i], _ = self.search_infer_step(self.alphas[i])
                genotypes.append(self.genotype_namedtuple(self.alphas[i]))
                model_sizes[i] = self.eval_model_sizes(self.alphas[i])
                logging.info(
                    'Valid_acc for invidual {} %f, size %f'.format(i),
                    fitness[i], model_sizes[i])
            # update population using pNSGA-III (CARS_NSGA)
            logging.info('############## Begin update alpha ############')
            keep = self._select_survivors(fitness, model_sizes)
            kept = set(keep.tolist())
            drop = [i for i in range(pop_size) if i not in kept]
            logging.info('############## KEEP ############')
            for i in keep:
                logging.info(
                    'KEEP Valid_acc for invidual {} %f, size %f, genotype %s'
                    .format(i), fitness[i], model_sizes[i], genotypes[i])
            fitness_keep = np.array([fitness[i] for i in keep])
            size_keep = np.array([model_sizes[i] for i in keep])
            genotype_keep = [genotypes[i] for i in keep]
            logging.info('############## DROP ############')
            for i in drop:
                logging.info(
                    'DROP Valid_acc for invidual {} %f, size %f, genotype %s'
                    .format(i), fitness[i], model_sizes[i], genotypes[i])
            if policy.select_method == 'uniform':
                select_front = self.select_uniform_pareto_front
            else:  # default: first
                select_front = self.select_first_pareto_front
            selected_genotypes, selected_acc, selected_model_sizes = \
                select_front(fitness_keep, size_keep, genotype_keep)
            # NOTE(review): file names depend only on ga_epoch, so every
            # generation overwrites the files of the previous one.
            self.save_genotypes(
                selected_genotypes, selected_acc, selected_model_sizes,
                'genotype_selected_{}.txt'.format(ga_epoch))
            self.save_genotypes(genotype_keep, fitness_keep, size_keep,
                                'genotype_keep_{}.txt'.format(ga_epoch))
            self.save_genotypes_to_json(genotype_keep, fitness_keep, size_keep,
                                        'genotype_keep_jsons', ga_epoch)
            self.alphas = self.alphas[keep].clone()
            logging.info('############## End update alpha ############')

    def _select_survivors(self, fitness, model_sizes):
        """Return indices of the individuals kept by ``alg_policy.nsga_method``.

        :raises ValueError: if the method is neither 'nsga3' nor 'cars_nsga'
        """
        nsga_method = self.alg_policy.nsga_method
        if nsga_method == 'nsga3':
            _, _, keep = SortAndSelectPopulation(
                np.vstack((1 / fitness, model_sizes)),
                self.alg_policy.num_individual)
            return keep
        if nsga_method == 'cars_nsga':
            return CARS_NSGA(fitness, [model_sizes],
                             self.alg_policy.num_individual)
        raise ValueError(
            "Unsupported nsga_method {!r}; expected 'nsga3' or 'cars_nsga'.".format(nsga_method))

    def search_infer_step(self, alpha):
        """Infer in search stage.

        Evaluates the supernet restricted to ``alpha`` on the whole validation
        loader, without gradients.

        :param alpha: encoding of a model
        :type alpha: torch.Tensor
        :return: ``metrics.results[0]``, the result of the first configured
            metric. NOTE(review): search_evol_arch unpacks this into two
            values, so this result is expected to be a pair (e.g. top-1 and
            top-5) -- confirm against Metrics.
        """
        metrics = Metrics(self.trainer.cfg.metric)
        self.trainer.model.eval()
        with torch.no_grad():
            for inputs, target in self.trainer.valid_loader:
                inputs = inputs.cuda()
                target = target.cuda(non_blocking=True)
                logits = self.trainer.model(inputs, alpha)
                metrics(logits, target)
        return metrics.results[0]

    def select_first_pareto_front(self, fitness, obj, genotypes):
        """Select models in the first pareto front.

        :param fitness: fitness, e.g. accuracy
        :type fitness: ndarray
        :param obj: objectives (model sizes, FLOPS, latency etc)
        :type obj: ndarray
        :param genotypes: genotypes for searched models
        :type genotypes: list
        :return: selected genotypes, their fitness and their objectives
        :rtype: tuple of three lists
        """
        _, _, selected_idx = SortAndSelectPopulation(
            np.vstack((1 / fitness, obj)), self.alg_policy.pareto_model_num)
        selected_genotypes = [genotypes[idx] for idx in selected_idx]
        selected_acc = [fitness[idx] for idx in selected_idx]
        selected_model_sizes = [obj[idx] for idx in selected_idx]
        return selected_genotypes, selected_acc, selected_model_sizes

    def select_uniform_pareto_front(self, fitness, obj, genotypes):
        """Select the most accurate model in each of uniformly spaced objective bins.

        Models whose fitness is not above half of the best fitness are
        discarded first. The remaining objective range ``[min, max]`` is split
        into ``pareto_model_num`` equal-width bins. The first bin is closed on
        both ends and the others are ``(lo, hi]``, so every model falls into
        exactly one bin. The most accurate model of each non-empty bin is
        selected; empty bins contribute nothing.

        :param fitness: fitness, e.g. accuracy
        :type fitness: ndarray
        :param obj: objectives (model sizes, FLOPS, latency etc)
        :type obj: ndarray
        :param genotypes: genotypes for searched models
        :type genotypes: list
        :return: selected genotypes, their fitness and their objectives
        :rtype: tuple of three lists
        """
        # preprocess: drop clearly inferior models
        max_acc = fitness.max()
        good = fitness > max_acc * 0.5
        fitness = fitness[good]
        obj = obj[good]
        genotypes = [g for g, is_good in zip(genotypes, good) if is_good]
        grid_num = self.alg_policy.pareto_model_num
        grid = np.linspace(obj.min(), obj.max(), num=grid_num + 1)
        selected_idx = []
        for idx in range(grid_num):
            # The first bin includes its lower edge so the smallest model is eligible.
            above_lower = obj >= grid[idx] if idx == 0 else obj > grid[idx]
            in_bin = above_lower & (obj <= grid[idx + 1])
            if not in_bin.any():
                # argmax over an empty bin would pick an unrelated model (index 0).
                continue
            selected_idx.append(int(np.where(in_bin, fitness, -np.inf).argmax()))
        selected_genotypes = [genotypes[idx] for idx in selected_idx]
        selected_acc = [fitness[idx] for idx in selected_idx]
        selected_model_sizes = [obj[idx] for idx in selected_idx]
        return selected_genotypes, selected_acc, selected_model_sizes

    def eval_model_sizes(self, alpha):
        """Calculate the model size of the architecture encoded by ``alpha``.

        The first ``model.len_alpha`` entries encode the normal cell and the
        rest the reduce cell. A standalone CARSDartsNetwork is built from the
        decoded genotype and its parameters are counted.

        :param alpha: encoding of a model
        :type alpha: torch.Tensor
        :return: The number of parameters, as reported by eval_model_parameters
        """
        normal = alpha[:self.trainer.model.len_alpha].data.cpu().numpy()
        reduce = alpha[self.trainer.model.len_alpha:].data.cpu().numpy()
        child_desc = self.search_alg.codec.calc_genotype([normal, reduce])
        child_cfg = copy.deepcopy(
            self.search_alg.codec.darts_cfg.super_network)
        child_cfg.normal.genotype = child_desc[0]
        child_cfg.reduce.genotype = child_desc[1]
        net = CARSDartsNetwork(child_cfg)
        return eval_model_parameters(net)

    def genotype_namedtuple(self, alpha):
        """Obtain genotype.

        :param alpha: alpha for cell (normal part first, then reduce part)
        :type alpha: Tensor
        :return: genotype whose concat range covers the last 4 intermediate nodes
        :rtype: Genotype
        """
        normal = alpha[:self.trainer.model.len_alpha].data.cpu().numpy()
        reduce = alpha[self.trainer.model.len_alpha:].data.cpu().numpy()
        child_desc = self.search_alg.codec.calc_genotype([normal, reduce])
        _multiplier = 4
        concat = range(2 + self.trainer.model._steps - _multiplier,
                       self.trainer.model._steps + 2)
        return Genotype(normal=child_desc[0],
                        normal_concat=concat,
                        reduce=child_desc[1],
                        reduce_concat=concat)

    def save_model_checkpoint(self, model, model_name):
        """Save checkpoint for a model.

        :param model: A model (saved whole with torch.save)
        :type model: nn.Module
        :param model_name: Path to save, relative to the local worker path
        :type model_name: string
        """
        save_path = os.path.join(self.trainer.get_local_worker_path(), model_name)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(model, save_path)
        logging.info("checkpoint saved to %s", save_path)

    def save_genotypes(self, genotypes, acc, obj, save_name):
        """Save genotypes as tab-separated ``acc  obj  genotype`` lines.

        :param genotypes: Genotype for models
        :type genotypes: list of namedtuple Genotype
        :param acc: accuracy
        :type acc: ndarray
        :param obj: objectives, etc. FLOPs or number of parameters
        :type obj: ndarray
        :param save_name: Path to save, relative to the local worker path
        :type save_name: string
        """
        save_path = os.path.join(self.trainer.get_local_worker_path(), save_name)
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        with open(save_path, "w") as f:
            for genotype, genotype_acc, genotype_obj in zip(genotypes, acc, obj):
                f.write('{}\t{}\t{}\n'.format(genotype_acc, genotype_obj, genotype))
        logging.info("genotypes saved to %s", save_path)

    def save_genotypes_to_json(self, genotypes, acc, obj, save_folder,
                               ga_epoch):
        """Save genotypes as model descriptions.

        Fills the darts template with each genotype's normal/reduce cells and
        passes it to ``trainer.output_model_desc`` under the genotype's index.

        :param genotypes: Genotype for models
        :type genotypes: list of namedtuple Genotype
        :param acc: accuracy (currently unused)
        :param obj: objectives, e.g. FLOPs or number of parameters (currently unused)
        :param save_folder: prefix of a ``<save_folder>_<ga_epoch>`` directory
            created under the worker path when a custom template file is used
        :type save_folder: string
        :param ga_epoch: index of the current GA epoch
        :type ga_epoch: int
        """
        template_file = self.trainer.cfg.darts_template_file
        if template_file == "{default_darts_cifar10_template}":
            template = DefaultConfig().data.default_darts_cifar10_template
        elif template_file == "{default_darts_imagenet_template}":
            template = DefaultConfig().data.default_darts_imagenet_template
        else:
            # NOTE(review): this directory is created but nothing is written to it here.
            worker_path = self.trainer.get_local_worker_path()
            _path = os.path.join(worker_path,
                                 save_folder + '_{}'.format(ga_epoch))
            os.makedirs(_path, exist_ok=True)
            base_file = os.path.basename(template_file)
            local_template = FileOps.join_path(self.trainer.local_output_path,
                                               base_file)
            FileOps.copy_file(template_file, local_template)
            with open(local_template, 'r') as f:
                template = json.load(f)

        for idx, genotype in enumerate(genotypes):
            template_cfg = Config(template)
            template_cfg.super_network.normal.genotype = genotype.normal
            template_cfg.super_network.reduce.genotype = genotype.reduce
            self.trainer.output_model_desc(idx, template_cfg)
Example #4
0
class DartsTrainerCallback(Callback):
    """A special callback for DartsTrainer.

    Before every training batch, it runs one architecture step of the search
    algorithm on the current train batch and the next validation batch.
    After every epoch, it logs the current genotype and learning rate. After
    training, it saves the final model description and calls
    ``trainer._backup()``.
    """

    def before_train(self, epoch, logs=None):
        """Be called before the training process.

        Disables automatic checkpoint/performance saving and caches the
        trainer's config, model, optimizer, lr scheduler and loss. It then
        builds the search algorithm and (re)initialises the train and valid
        loaders.
        """
        self.cfg = self.trainer.cfg
        self.trainer.auto_save_ckpt = False
        self.trainer.auto_save_perf = False
        self.unrolled = self.trainer.cfg.get('unrolled', True)
        self.device = self.trainer.cfg.get('device', 0)
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace())
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch: restart the validation iterator."""
        self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training.

        Pairs the current train batch (``logs['train_batch']``) with the next
        validation batch and runs one architecture step on them. The
        validation iterator is restarted when it is exhausted.
        """
        # Get the current train batch directly from logs.
        train_input, train_target = logs['train_batch']
        # Prepare a valid batch from the trainer's valid loader. Only
        # exhaustion restarts the loader; any other data-loading error
        # propagates instead of being swallowed.
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except StopIteration:
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input, valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch: log the current genotype and learning rate."""
        child_desc_temp = self.search_alg.codec.calc_genotype(self.model.arch_weights)
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        logging.info('lr = {}'.format(self.lr_scheduler.get_lr()[0]))

    def after_train(self, epoch, logs=None):
        """Be called after Training: save the final model description, then back up."""
        child_desc = self.search_alg.codec.decode(self.model.arch_weights)
        self._save_descript(child_desc)
        self.trainer._backup()

    def _train_arch_step(self, train_input, train_target, valid_input, valid_target):
        """Run one architecture step with the scheduler's current learning rate."""
        lr = self.lr_scheduler.get_lr()[0]
        self.search_alg.step(train_input, train_target, valid_input, valid_target,
                             lr, self.optimizer, self.loss, self.unrolled)

    def _set_algorithm_model(self, model):
        """Hand the supernet to the search algorithm."""
        self.search_alg.set_model(model)

    def _save_descript(self, descript):
        """Save result descript.

        Builds the model description from the darts template and the
        genotypes of the current arch weights. The result is passed to
        ``trainer.output_model_desc`` under this worker's id. A custom
        template file is first copied into the local worker path.

        :param descript: darts search result descript. NOTE(review): not used
            here; the genotypes are recomputed from ``self.model.arch_weights``.
        :type descript: dict or Config
        """
        template_file = self.cfg.darts_template_file
        genotypes = self.search_alg.codec.calc_genotype(self.model.arch_weights)
        if template_file == "{default_darts_cifar10_template}":
            template = DefaultConfig().data.default_darts_cifar10_template
        elif template_file == "{default_darts_imagenet_template}":
            template = DefaultConfig().data.default_darts_imagenet_template
        else:
            dst = FileOps.join_path(self.trainer.get_local_worker_path(), os.path.basename(template_file))
            FileOps.copy_file(template_file, dst)
            template = Config(dst)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.output_model_desc(self.trainer.worker_id, model_desc)

    def _gen_model_desc(self, genotypes, template):
        """Return a deep copy of ``template`` with the normal/reduce genotypes filled in."""
        model_desc = deepcopy(template)
        model_desc.super_network.normal.genotype = genotypes[0]
        model_desc.super_network.reduce.genotype = genotypes[1]
        return model_desc