def train_process(self):
    """Train process of parameter sharing."""
    self.train_loader = Dataset(mode='train').dataloader
    self.valid_loader = Dataset(mode='val').dataloader
    self.model = self.model.to(self.device)
    self.search_alg = SearchAlgorithm(SearchSpace())
    self.set_algorithm_model(self.model)
    self.optimizer = self._init_optimizer()
    self.lr_scheduler = self._init_lr_scheduler()
    self.loss_fn = self._init_loss()
    np.random.seed(self.cfg.seed)
    torch.manual_seed(self.cfg.seed)
    for i in range(self.cfg.epochs):
        self._train(self.model)
        train_top1, train_top5 = self._valid(self.model, self.train_loader)
        valid_top1, valid_top5 = self._valid(self.model, self.valid_loader)
        self.lr_scheduler.step()
        child_desc_temp = self.search_alg.codec.calc_genotype(self.model.arch_weights)
        logging.info(F.softmax(self.model.alphas_normal, dim=-1))
        logging.info(F.softmax(self.model.alphas_reduce, dim=-1))
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        logging.info('Epoch {}: train top1: {}, train top5: {}'.format(i, train_top1, train_top5))
        logging.info('Epoch {}: valid top1: {}, valid top5: {}'.format(i, valid_top1, valid_top5))
    child_desc = self.search_alg.codec.decode(self.model.arch_weights)
    self._save_descript(child_desc)
class CARSTrainerCallback(Callback):
    """A special callback for CARSTrainer."""

    disable_callbacks = ["ModelStatistics"]

    def __init__(self):
        super(CARSTrainerCallback, self).__init__()
        self.alg_policy = None

    def before_train(self, logs=None):
        """Be called before the training process."""
        # Use zero valid_interval to suppress the default valid step
        self.trainer.valid_interval = 0
        self.trainer.config.report_on_epoch = True
        if vega.is_torch_backend():
            cudnn.benchmark = True
            cudnn.enabled = True
        self.search_alg = SearchAlgorithm(SearchSpace().search_space)
        self.alg_policy = self.search_alg.config.policy
        self.set_algorithm_model(self.trainer.model)
        # setup alphas
        n_individual = self.alg_policy.num_individual
        self.alphas = np.stack([self.search_alg.random_sample_path()
                                for i in range(n_individual)], axis=0)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch."""
        self.epoch = epoch

    def train_step(self, batch):
        """Replace the default train_step function."""
        self.trainer.model.train()
        input, target = batch
        self.trainer.optimizer.zero_grad()
        alphas = torch.from_numpy(self.alphas).cuda()
        for j in range(self.alg_policy.num_individual_per_iter):
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                alpha = torch.from_numpy(self.search_alg.random_sample_path()).cuda()
                # logits = self.trainer.model.forward_random(input)
            else:
                alpha = alphas[i]
            logits = self.trainer.model(input, alpha)
            loss = self.trainer.loss(logits, target)
            loss.backward(retain_graph=True)
            if self.epoch < self.alg_policy.warmup:
                break
        nn.utils.clip_grad_norm_(
            self.trainer.model.parameters(), self.trainer.config.grad_clip)
        self.trainer.optimizer.step()
        return {'loss': loss.item(), 'train_batch_output': logits}

    def model_fn(self, features, labels, mode):
        """Define cars model_fn used by TensorFlow Estimator."""
        logging.info('Cars model function action')
        self.trainer.loss = Loss()()
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            global_step = tf.train.get_global_step()
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.trainer.train_loader), tf.float32)
            self.trainer.lr_scheduler = LrScheduler()()
            self.trainer.optimizer = Optimizer()(lr_scheduler=self.trainer.lr_scheduler, epoch=epoch,
                                                 distributed=self.trainer.distributed)
            alphas = tf.convert_to_tensor(self.alphas)
            for j in range(self.alg_policy.num_individual_per_iter):
                i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
                if self.epoch < self.alg_policy.warmup:
                    alpha = tf.convert_to_tensor(self.search_alg.random_sample_path())
                else:
                    alpha = alphas[i]
                logits = self.trainer.model(features, alpha, True)
                logits = tf.cast(logits, tf.float32)
                loss = self.trainer.loss(logits=logits, labels=labels)
                grads, vars = zip(*self.trainer.optimizer.compute_gradients(loss))
                if j == 0:
                    accum_grads = [tf.Variable(tf.zeros_like(grad), trainable=False) for grad in grads]
                accum_grads = [accum_grads[k] + grads[k] for k in range(len(grads))]
                if self.epoch < self.alg_policy.warmup:
                    break
            clipped_grads, _ = tf.clip_by_global_norm(accum_grads, self.trainer.config.grad_clip)
            minimize_op = self.trainer.optimizer.apply_gradients(list(zip(clipped_grads, vars)), global_step)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            train_op = tf.group(minimize_op, update_ops)
        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            alpha = tf.convert_to_tensor(self.trainer.valid_alpha)
            logits = self.trainer.model(features, alpha, False)
            logits = tf.cast(logits, tf.float32)
            loss = self.trainer.loss(logits=logits, labels=labels)
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        self.alphas = self.search_alg.search_evol_arch(epoch, self.alg_policy, self.trainer, self.alphas)

    def set_algorithm_model(self, model):
        """Set model to algorithm.

        :param model: network model
        :type model: torch.nn.Module
        """
        self.search_alg.set_model(model)
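# Illustrative sketch: in `after_epoch` above, the evolutionary update is
# delegated to `self.search_alg.search_evol_arch`, which keeps individuals that
# trade off validation accuracy against model size. The self-contained toy below
# shows that kind of two-objective (Pareto) selection on synthetic numbers; it
# is not the actual CARS_NSGA / SortAndSelectPopulation implementation, and
# `pareto_keep` is a hypothetical name used only for this example.
import numpy as np

def pareto_keep(accuracy, model_size, num_keep):
    """Return indices of up to `num_keep` non-dominated individuals
    (prefer high accuracy and small model size)."""
    n = len(accuracy)
    dominated = np.zeros(n, dtype=bool)
    for i in range(n):
        for j in range(n):
            # j dominates i if it is no worse in both objectives and strictly
            # better in at least one of them
            if (accuracy[j] >= accuracy[i] and model_size[j] <= model_size[i]
                    and (accuracy[j] > accuracy[i] or model_size[j] < model_size[i])):
                dominated[i] = True
                break
    front = np.where(~dominated)[0]
    # keep at most `num_keep` individuals, favouring the most accurate ones
    return front[np.argsort(-accuracy[front])][:num_keep]

# Toy population: 6 architectures with (accuracy, #parameters in millions).
acc = np.array([0.90, 0.91, 0.88, 0.93, 0.89, 0.92])
size = np.array([3.1, 2.8, 2.0, 4.5, 3.9, 3.0])
print(pareto_keep(acc, size, num_keep=3))  # -> [3 5 1]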
class CARSTrainerCallback(Callback):
    """A special callback for CARSTrainer."""

    def __init__(self):
        super(CARSTrainerCallback, self).__init__()
        self.alg_policy = ClassFactory.__configs__.search_algorithm.policy

    def before_train(self, epoch, logs=None):
        """Be called before the training process."""
        # Use zero valid_freq to suppress the default valid step
        self.trainer.auto_save_ckpt = False
        self.trainer.auto_save_perf = False
        self.trainer.valid_freq = 0
        cudnn.benchmark = True
        cudnn.enabled = True
        self.search_alg = SearchAlgorithm(SearchSpace())
        self.set_algorithm_model(self.trainer.model)
        # setup alphas
        n_individual = self.alg_policy.num_individual
        self.alphas = torch.cat([self.trainer.model.random_single_path().unsqueeze(0)
                                 for i in range(n_individual)], dim=0)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch."""
        self.epoch = epoch
        self.trainer.lr_scheduler.step()

    def train_step(self, batch):
        """Replace the default train_step function."""
        self.trainer.model.train()
        input, target = batch
        self.trainer.optimizer.zero_grad()
        for j in range(self.alg_policy.num_individual_per_iter):
            i = np.random.randint(0, self.alg_policy.num_individual, 1)[0]
            if self.epoch < self.alg_policy.warmup:
                logits = self.trainer.model.forward_random(input)
            else:
                logits = self.trainer.model(input, self.alphas[i])
            loss = self.trainer.loss(logits, target)
            loss.backward(retain_graph=True)
            if self.epoch < self.alg_policy.warmup:
                break
        nn.utils.clip_grad_norm_(self.trainer.model.parameters(), self.trainer.cfg.grad_clip)
        self.trainer.optimizer.step()
        return {'loss': loss.item(), 'train_batch_output': logits}

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        self.search_evol_arch(epoch)

    def set_algorithm_model(self, model):
        """Set model to algorithm.

        :param model: network model
        :type model: torch.nn.Module
        """
        self.search_alg.set_model(model)

    def search_evol_arch(self, epoch):
        """Update architectures with the evolutionary algorithm.

        :param epoch: The current epoch
        :type epoch: int
        """
        if epoch >= self.alg_policy.start_ga_epoch and \
                (epoch - self.alg_policy.start_ga_epoch) % self.alg_policy.ga_interval == 0:
            self.save_model_checkpoint(self.trainer.model, 'weights_{}.pt'.format(epoch))
            for generation in range(self.alg_policy.num_generation):
                fitness = np.zeros(
                    int(self.alg_policy.num_individual * (1 + self.alg_policy.expand)))
                model_sizes = np.zeros(
                    int(self.alg_policy.num_individual * (1 + self.alg_policy.expand)))
                genotypes = []
                # generate offsprings using mutation and cross-over
                offsprings = self.search_alg.gen_offspring(self.alphas)
                self.alphas = torch.cat((self.alphas, offsprings), dim=0)
                # calculate fitness (accuracy) and #parameters
                for i in range(
                        int(self.alg_policy.num_individual * (1 + self.alg_policy.expand))):
                    fitness[i], _ = self.search_infer_step(self.alphas[i])
                    genotypes.append(self.genotype_namedtuple(self.alphas[i]))
                    model_sizes[i] = self.eval_model_sizes(self.alphas[i])
                    logging.info('Valid_acc for individual %d: %f, size: %f',
                                 i, fitness[i], model_sizes[i])
                # update population using pNSGA-III (CARS_NSGA)
                logging.info('############## Begin update alpha ############')
                if self.alg_policy.nsga_method == 'nsga3':
                    _, _, keep = SortAndSelectPopulation(
                        np.vstack((1 / fitness, model_sizes)), self.alg_policy.num_individual)
                elif self.alg_policy.nsga_method == 'cars_nsga':
                    nsga_objs = [model_sizes]
                    keep = CARS_NSGA(fitness, nsga_objs, self.alg_policy.num_individual)
                drop = list(
                    set(range(int(self.alg_policy.num_individual * (1 + self.alg_policy.expand)))) -
                    set(keep.tolist()))
                logging.info('############## KEEP ############')
                fitness_keep = []
                size_keep = []
                genotype_keep = []
                for i in keep:
                    logging.info('KEEP Valid_acc for individual %d: %f, size: %f, genotype: %s',
                                 i, fitness[i], model_sizes[i], genotypes[i])
                    fitness_keep.append(fitness[i])
                    size_keep.append(model_sizes[i])
                    genotype_keep.append(genotypes[i])
                logging.info('############## DROP ############')
                for i in drop:
                    logging.info('DROP Valid_acc for individual %d: %f, size: %f, genotype: %s',
                                 i, fitness[i], model_sizes[i], genotypes[i])
                if self.alg_policy.select_method == 'uniform':
                    selected_genotypes, selected_acc, selected_model_sizes = \
                        self.select_uniform_pareto_front(
                            np.array(fitness_keep), np.array(size_keep), genotype_keep)
                else:  # default: first
                    selected_genotypes, selected_acc, selected_model_sizes = \
                        self.select_first_pareto_front(
                            np.array(fitness_keep), np.array(size_keep), genotype_keep)
                ga_epoch = int((epoch - self.alg_policy.start_ga_epoch) / self.alg_policy.ga_interval)
                self.save_genotypes(selected_genotypes, selected_acc, selected_model_sizes,
                                    'genotype_selected_{}.txt'.format(ga_epoch))
                self.save_genotypes(genotype_keep, np.array(fitness_keep), np.array(size_keep),
                                    'genotype_keep_{}.txt'.format(ga_epoch))
                self.save_genotypes_to_json(genotype_keep, np.array(fitness_keep), np.array(size_keep),
                                            'genotype_keep_jsons', ga_epoch)
                self.alphas = self.alphas[keep].clone()
                logging.info('############## End update alpha ############')

    def search_infer_step(self, alpha):
        """Infer in search stage.

        :param alpha: encoding of a model
        :type alpha: array
        :return: Average top1 acc and loss
        :rtype: nn.Tensor
        """
        metrics = Metrics(self.trainer.cfg.metric)
        self.trainer.model.eval()
        with torch.no_grad():
            for step, (input, target) in enumerate(self.trainer.valid_loader):
                input = input.cuda()
                target = target.cuda(non_blocking=True)
                logits = self.trainer.model(input, alpha)
                metrics(logits, target)
        top1 = metrics.results[0]
        return top1

    def select_first_pareto_front(self, fitness, obj, genotypes):
        """Select models in the first pareto front.

        :param fitness: fitness, e.g. accuracy
        :type fitness: ndarray
        :param obj: objectives (model sizes, FLOPS, latency etc)
        :type obj: ndarray
        :param genotypes: genotypes for searched models
        :type genotypes: list
        :return: The selected samples
        :rtype: list
        """
        F, _, selected_idx = SortAndSelectPopulation(
            np.vstack((1 / fitness, obj)), self.alg_policy.pareto_model_num)
        selected_genotypes = []
        selected_acc = []
        selected_model_sizes = []
        for idx in selected_idx:
            selected_genotypes.append(genotypes[idx])
            selected_acc.append(fitness[idx])
            selected_model_sizes.append(obj[idx])
        return selected_genotypes, selected_acc, selected_model_sizes

    def select_uniform_pareto_front(self, fitness, obj, genotypes):
        """Select models uniformly along the objective axis of the pareto front.

        :param fitness: fitness, e.g. accuracy
        :type fitness: ndarray
        :param obj: objectives (model sizes, FLOPS, latency etc)
        :type obj: ndarray
        :param genotypes: genotypes for searched models
        :type genotypes: list
        :return: The selected samples
        :rtype: list
        """
        # preprocess: drop candidates whose fitness is far below the best one
        max_acc = fitness.max()
        keep = (fitness > max_acc * 0.5)
        fitness = fitness[keep]
        obj = obj[keep]
        genotypes = [i for (i, v) in zip(genotypes, keep) if v]
        max_obj = obj.max()
        min_obj = obj.min()
        grid_num = self.alg_policy.pareto_model_num
        grid = np.linspace(min_obj, max_obj, num=grid_num + 1)
        selected_idx = []
        for idx in range(grid_num):
            keep = (obj <= grid[idx]) | (obj > grid[idx + 1])
            sub_fitness = np.array(fitness)
            sub_fitness[keep] = 0
            selected_idx.append(sub_fitness.argmax())
        selected_genotypes = []
        selected_acc = []
        selected_model_sizes = []
        for idx in selected_idx:
            selected_genotypes.append(genotypes[idx])
            selected_acc.append(fitness[idx])
            selected_model_sizes.append(obj[idx])
        return selected_genotypes, selected_acc, selected_model_sizes

    def eval_model_sizes(self, alpha):
        """Calculate model size for the model encoded by an alpha.

        :param alpha: alpha for a model
        :type alpha: Tensor
        :return: The number of parameters
        :rtype: Float
        """
        normal = alpha[:self.trainer.model.len_alpha].data.cpu().numpy()
        reduce = alpha[self.trainer.model.len_alpha:].data.cpu().numpy()
        child_desc = self.search_alg.codec.calc_genotype([normal, reduce])
        child_cfg = copy.deepcopy(self.search_alg.codec.darts_cfg.super_network)
        child_cfg.normal.genotype = child_desc[0]
        child_cfg.reduce.genotype = child_desc[1]
        net = CARSDartsNetwork(child_cfg)
        model_size = eval_model_parameters(net)
        return model_size

    def genotype_namedtuple(self, alpha):
        """Obtain genotype.

        :param alpha: alpha for cell
        :type alpha: Tensor
        :return: genotype
        :rtype: Genotype
        """
        normal = alpha[:self.trainer.model.len_alpha].data.cpu().numpy()
        reduce = alpha[self.trainer.model.len_alpha:].data.cpu().numpy()
        child_desc = self.search_alg.codec.calc_genotype([normal, reduce])
        _multiplier = 4
        concat = range(2 + self.trainer.model._steps - _multiplier,
                       self.trainer.model._steps + 2)
        genotype = Genotype(normal=child_desc[0], normal_concat=concat,
                            reduce=child_desc[1], reduce_concat=concat)
        return genotype

    def save_model_checkpoint(self, model, model_name):
        """Save checkpoint for a model.

        :param model: A model
        :type model: nn.Module
        :param model_name: Path to save
        :type model_name: string
        """
        worker_path = self.trainer.get_local_worker_path()
        save_path = os.path.join(worker_path, model_name)
        _path, _ = os.path.split(save_path)
        if not os.path.isdir(_path):
            os.makedirs(_path)
        torch.save(model, save_path)
        logging.info("checkpoint saved to %s", save_path)

    def save_genotypes(self, genotypes, acc, obj, save_name):
        """Save genotypes.

        :param genotypes: Genotype for models
        :type genotypes: namedtuple Genotype
        :param acc: accuracy
        :type acc: ndarray
        :param obj: objectives, e.g. FLOPs or number of parameters
        :type obj: ndarray
        :param save_name: Path to save
        :type save_name: string
        """
        worker_path = self.trainer.get_local_worker_path()
        save_path = os.path.join(worker_path, save_name)
        _path, _ = os.path.split(save_path)
        if not os.path.isdir(_path):
            os.makedirs(_path)
        with open(save_path, "w") as f:
            for idx in range(len(genotypes)):
                f.write('{}\t{}\t{}\n'.format(acc[idx], obj[idx], genotypes[idx]))
        logging.info("genotypes saved to %s", save_path)

    def save_genotypes_to_json(self, genotypes, acc, obj, save_folder, ga_epoch):
        """Save genotypes as json model descriptions.

        :param genotypes: Genotype for models
        :type genotypes: namedtuple Genotype
        :param acc: accuracy
        :type acc: ndarray
        :param obj: objectives, e.g. FLOPs or number of parameters
        :type obj: ndarray
        :param save_folder: folder to save
        :type save_folder: string
        :param ga_epoch: current ga epoch index
        :type ga_epoch: int
        """
        if self.trainer.cfg.darts_template_file == "{default_darts_cifar10_template}":
            template = DefaultConfig().data.default_darts_cifar10_template
        elif self.trainer.cfg.darts_template_file == "{default_darts_imagenet_template}":
            template = DefaultConfig().data.default_darts_imagenet_template
        else:
            worker_path = self.trainer.get_local_worker_path()
            _path = os.path.join(worker_path, save_folder + '_{}'.format(ga_epoch))
            if not os.path.isdir(_path):
                os.makedirs(_path)
            base_file = os.path.basename(self.trainer.cfg.darts_template_file)
            local_template = FileOps.join_path(self.trainer.local_output_path, base_file)
            FileOps.copy_file(self.trainer.cfg.darts_template_file, local_template)
            with open(local_template, 'r') as f:
                template = json.load(f)
        for idx in range(len(genotypes)):
            template_cfg = Config(template)
            template_cfg.super_network.normal.genotype = genotypes[idx].normal
            template_cfg.super_network.reduce.genotype = genotypes[idx].reduce
            self.trainer.output_model_desc(idx, template_cfg)
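# Illustrative sketch: `select_uniform_pareto_front` above splits the model-size
# range into `pareto_model_num` equal-width buckets and keeps the fittest
# individual per bucket (after dropping very low-accuracy candidates). The
# standalone toy below mirrors only the bucketing step on synthetic arrays;
# `uniform_bucket_select` is a hypothetical name, not a vega API, and the
# fitness-threshold preprocessing of the real method is omitted.
import numpy as np

def uniform_bucket_select(fitness, obj, num_buckets):
    """Pick the highest-fitness index within each equal-width objective bucket."""
    grid = np.linspace(obj.min(), obj.max(), num=num_buckets + 1)
    selected = []
    for k in range(num_buckets):
        # Mask out everything outside (grid[k], grid[k+1]]; the lower edge is
        # exclusive, matching the comparison used in the method above.
        outside = (obj <= grid[k]) | (obj > grid[k + 1])
        masked = fitness.copy()
        masked[outside] = 0
        selected.append(int(masked.argmax()))
    return selected

# Example: 5 candidate architectures, accuracy vs. parameter count (in millions).
fitness = np.array([0.90, 0.92, 0.88, 0.93, 0.91])
sizes = np.array([2.0, 2.6, 3.1, 4.4, 3.8])
print(uniform_bucket_select(fitness, sizes, num_buckets=3))  # -> [1, 2, 3]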
class DartsTrainerCallback(Callback):
    """A special callback for DartsTrainer."""

    def before_train(self, logs=None):
        """Be called before the training process."""
        self.config = self.trainer.config
        self.unrolled = self.trainer.config.unrolled
        self.device = self.trainer.config.device
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace().search_space)
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch."""
        if vega.is_torch_backend():
            self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training."""
        # Get current train batch directly from logs
        train_batch = logs['train_batch']
        train_input, train_target = train_batch
        # Prepare valid batch data by using valid loader from trainer
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except Exception:
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input, valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        child_desc_temp = self.search_alg.codec.calc_genotype(self._get_arch_weights())
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        self._save_descript()

    def after_train(self, logs=None):
        """Be called after training."""
        self.trainer._backup()

    def _train_arch_step(self, train_input, train_target, valid_input, valid_target):
        lr = self.lr_scheduler.get_lr()[0]
        self.search_alg.step(train_input, train_target, valid_input, valid_target,
                             lr, self.optimizer, self.loss, self.unrolled)

    def _set_algorithm_model(self, model):
        self.search_alg.set_model(model)

    def train_input_fn(self):
        """Input function for search."""
        def map_to_dict(td, vd):
            return {'train': td[0], 'valid': vd[0]}, {'train': td[1], 'valid': vd[1]}
        dataset = tf.data.Dataset.zip((self.trainer.train_loader.input_fn(),
                                       self.trainer.valid_loader.input_fn()))
        dataset = dataset.map(lambda td, vd: map_to_dict(td, vd))
        dataset = dataset.prefetch(buffer_size=tf.contrib.data.AUTOTUNE)
        return dataset

    def model_fn(self, features, labels, mode):
        """Darts model_fn used by TensorFlow Estimator."""
        logging.info('Darts model function action')
        global_step = tf.train.get_global_step()
        train_op = None
        if mode == tf.estimator.ModeKeys.TRAIN:
            features, valid_features = features['train'], features['valid']
            labels, valid_labels = labels['train'], labels['valid']
            # update arch
            epoch = tf.cast(global_step, tf.float32) / tf.cast(len(self.trainer.train_loader), tf.float32)
            self.trainer.lr_scheduler = LrScheduler()()
            self.trainer.optimizer = Optimizer()(lr_scheduler=self.trainer.lr_scheduler, epoch=epoch,
                                                 distributed=self.trainer.distributed)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            arch_minimize_op = self.search_alg.step(valid_x=valid_features, valid_y=valid_labels,
                                                    lr=self.trainer.lr_scheduler.get_lr()[0])
            train_op = tf.group(arch_minimize_op, update_ops)
        logits = self.model(features, mode == tf.estimator.ModeKeys.TRAIN)
        logits = tf.cast(logits, tf.float32)
        self.trainer.loss = Loss()()
        loss = self.trainer.loss(logits=logits, labels=labels)
        if mode == tf.estimator.ModeKeys.TRAIN:
            with tf.control_dependencies([train_op]):
                weight_ops = self.model.get_weight_ops()
                train_op = self.trainer._init_minimize_op(loss, global_step, weight_ops)
        eval_metric_ops = None
        if mode == tf.estimator.ModeKeys.EVAL:
            eval_metric_ops = self.trainer.valid_metrics(logits, labels)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                          eval_metric_ops=eval_metric_ops)

    def _get_arch_weights(self):
        if vega.is_torch_backend():
            arch_weights = self.model.arch_weights
        elif vega.is_tf_backend():
            sess_config = self.trainer._init_session_config()
            with tf.Session(config=sess_config) as sess:
                # tf.reset_default_graph()
                checkpoint_file = tf.train.latest_checkpoint(self.trainer.get_local_worker_path())
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
                saver.restore(sess, checkpoint_file)
                arch_weights = self.model.arch_weights
                arch_weights = [weight.eval() for weight in arch_weights]
        return arch_weights

    def _save_descript(self):
        """Save result descript."""
        template_file = self.config.darts_template_file
        genotypes = self.search_alg.codec.calc_genotype(self._get_arch_weights())
        if template_file == "{default_darts_cifar10_template}":
            template = DartsNetworkTemplateConfig.cifar10
        elif template_file == "{default_darts_imagenet_template}":
            template = DartsNetworkTemplateConfig.imagenet
        else:
            dst = FileOps.join_path(self.trainer.get_local_worker_path(), os.path.basename(template_file))
            FileOps.copy_file(template_file, dst)
            template = Config(dst)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.config.codec = model_desc

    def _gen_model_desc(self, genotypes, template):
        model_desc = deepcopy(template)
        model_desc.super_network.normal.genotype = genotypes[0]
        model_desc.super_network.reduce.genotype = genotypes[1]
        return model_desc
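# Illustrative sketch: the callback above alternates two updates per batch --
# architecture parameters are updated on a validation batch (via
# `self.search_alg.step`), then the network weights on the training batch. The
# toy below reproduces only that first-order alternation with a scalar
# "supernet"; it is not vega's SearchAlgorithm.step and omits the unrolled
# (second-order) variant selected by `unrolled`. All names here are made up
# for the example.
import torch

w = torch.nn.Parameter(torch.randn(3))        # stands in for network weights
alpha = torch.nn.Parameter(torch.zeros(3))    # stands in for architecture parameters
w_opt = torch.optim.SGD([w], lr=0.1)
a_opt = torch.optim.Adam([alpha], lr=3e-4)

def model(x):
    # softmax over alpha mixes three candidate "operations" (here, entries of w)
    mix = torch.softmax(alpha, dim=-1)
    return x * (mix * w).sum()

for step in range(10):
    train_x, train_y = torch.randn(()), torch.randn(())
    valid_x, valid_y = torch.randn(()), torch.randn(())
    # 1) architecture step on the validation batch (first-order approximation)
    a_opt.zero_grad()
    ((model(valid_x) - valid_y) ** 2).backward()
    a_opt.step()
    # 2) weight step on the training batch
    w_opt.zero_grad()
    ((model(train_x) - train_y) ** 2).backward()
    w_opt.step()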
class DartsTrainerCallback(Callback):
    """A special callback for DartsTrainer."""

    def before_train(self, epoch, logs=None):
        """Be called before the training process."""
        self.cfg = self.trainer.cfg
        self.trainer.auto_save_ckpt = False
        self.trainer.auto_save_perf = False
        self.unrolled = self.trainer.cfg.get('unrolled', True)
        self.device = self.trainer.cfg.get('device', 0)
        self.model = self.trainer.model
        self.optimizer = self.trainer.optimizer
        self.lr_scheduler = self.trainer.lr_scheduler
        self.loss = self.trainer.loss
        self.search_alg = SearchAlgorithm(SearchSpace())
        self._set_algorithm_model(self.model)
        self.trainer.train_loader = self.trainer._init_dataloader(mode='train')
        self.trainer.valid_loader = self.trainer._init_dataloader(mode='val')

    def before_epoch(self, epoch, logs=None):
        """Be called before each epoch."""
        self.valid_loader_iter = iter(self.trainer.valid_loader)

    def before_train_step(self, epoch, logs=None):
        """Be called before a batch training."""
        # Get current train batch directly from logs
        train_batch = logs['train_batch']
        train_input, train_target = train_batch
        # Prepare valid batch data by using valid loader from trainer
        try:
            valid_input, valid_target = next(self.valid_loader_iter)
        except Exception:
            self.valid_loader_iter = iter(self.trainer.valid_loader)
            valid_input, valid_target = next(self.valid_loader_iter)
        valid_input, valid_target = valid_input.to(self.device), valid_target.to(self.device)
        # Call arch search step
        self._train_arch_step(train_input, train_target, valid_input, valid_target)

    def after_epoch(self, epoch, logs=None):
        """Be called after each epoch."""
        child_desc_temp = self.search_alg.codec.calc_genotype(self.model.arch_weights)
        logging.info('normal = %s', child_desc_temp[0])
        logging.info('reduce = %s', child_desc_temp[1])
        logging.info('lr = {}'.format(self.lr_scheduler.get_lr()[0]))

    def after_train(self, epoch, logs=None):
        """Be called after training."""
        child_desc = self.search_alg.codec.decode(self.model.arch_weights)
        self._save_descript(child_desc)
        self.trainer._backup()

    def _train_arch_step(self, train_input, train_target, valid_input, valid_target):
        lr = self.lr_scheduler.get_lr()[0]
        self.search_alg.step(train_input, train_target, valid_input, valid_target,
                             lr, self.optimizer, self.loss, self.unrolled)

    def _set_algorithm_model(self, model):
        self.search_alg.set_model(model)

    def _save_descript(self, descript):
        """Save result descript.

        :param descript: darts search result descript
        :type descript: dict or Config
        """
        template_file = self.cfg.darts_template_file
        genotypes = self.search_alg.codec.calc_genotype(self.model.arch_weights)
        if template_file == "{default_darts_cifar10_template}":
            template = DefaultConfig().data.default_darts_cifar10_template
        elif template_file == "{default_darts_imagenet_template}":
            template = DefaultConfig().data.default_darts_imagenet_template
        else:
            dst = FileOps.join_path(self.trainer.get_local_worker_path(), os.path.basename(template_file))
            FileOps.copy_file(template_file, dst)
            template = Config(dst)
        model_desc = self._gen_model_desc(genotypes, template)
        self.trainer.output_model_desc(self.trainer.worker_id, model_desc)

    def _gen_model_desc(self, genotypes, template):
        model_desc = deepcopy(template)
        model_desc.super_network.normal.genotype = genotypes[0]
        model_desc.super_network.reduce.genotype = genotypes[1]
        return model_desc
class Generator(object):
    """Convert search space and search algorithm, sample a new model."""

    def __init__(self):
        self.step_name = General.step_name
        self.search_space = SearchSpace()
        self.search_alg = SearchAlgorithm(self.search_space.search_space)
        self.report = Report()
        self.record = ReportRecord()
        self.record.step_name = self.step_name
        if hasattr(self.search_alg.config, 'objective_keys'):
            self.record.objective_keys = self.search_alg.config.objective_keys

    @property
    def is_completed(self):
        """Define a property to determine whether the search algorithm is completed."""
        return self.search_alg.is_completed

    def sample(self):
        """Sample a worker id and model description from the search algorithm."""
        res = self.search_alg.search()
        if not res:
            return None
        if not isinstance(res, list):
            res = [res]
        out = []
        for sample in res:
            if isinstance(sample, tuple):
                sample = dict(worker_id=sample[0], desc=sample[1])
            record = self.record.load_dict(sample)
            logging.debug("Broadcast Record=%s", str(record))
            desc = self._decode_hps(record.desc)
            record.desc = desc
            Report().broadcast(record)
            out.append((record.worker_id, desc))
        return out

    def update(self, step_name, worker_id):
        """Update the search algorithm according to the record reported by a worker.

        :param step_name: step name
        :param worker_id: current worker id
        :return:
        """
        report = Report()
        record = report.receive(step_name, worker_id)
        logging.debug("Get Record=%s", str(record))
        self.search_alg.update(record.serialize())
        report.dump_report(record.step_name, record)
        logging.info("Update Success. step_name=%s, worker_id=%s", step_name, worker_id)
        logging.info("Best values: %s", Report().pareto_front(step_name=General.step_name))

    @staticmethod
    def _decode_hps(hps):
        """Decode hps from dotted format, e.g. `trainer.optim.lr: 0.1`, into a nested dict.

        The result is converted to a `vega.core.common.config.Config` object.
        This Config is used to override settings in the Trainer or Dataset classes.
        The override priority is: input hps > user configuration > default configuration.

        :param hps: hyper params
        :return: dict
        """
        hps_dict = {}
        if hps is None:
            return None
        if isinstance(hps, tuple):
            return hps
        for hp_name, value in hps.items():
            hp_dict = {}
            for key in list(reversed(hp_name.split('.'))):
                if hp_dict:
                    hp_dict = {key: hp_dict}
                else:
                    hp_dict = {key: value}
            # update cfg with hps
            hps_dict = update_dict(hps_dict, hp_dict, [])
        return Config(hps_dict)
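# Illustrative sketch: `Generator._decode_hps` above expands dotted
# hyperparameter names such as `trainer.optim.lr: 0.1` into a nested dict before
# wrapping it in a Config. The standalone toy below performs the same expansion
# without the vega `update_dict`/`Config` helpers (it builds the nesting
# outer-to-inner instead of inner-to-outer, with the same result);
# `decode_dotted` is a hypothetical name used only for this example.
def decode_dotted(hps):
    """Expand {'trainer.optim.lr': 0.1} into {'trainer': {'optim': {'lr': 0.1}}}."""
    out = {}
    for name, value in hps.items():
        keys = name.split('.')
        node = out
        for key in keys[:-1]:
            node = node.setdefault(key, {})   # walk/create intermediate levels
        node[keys[-1]] = value
    return out

print(decode_dotted({'trainer.optim.lr': 0.1,
                     'trainer.optim.momentum': 0.9,
                     'trainer.epochs': 50}))
# -> {'trainer': {'optim': {'lr': 0.1, 'momentum': 0.9}, 'epochs': 50}}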