def augment(out_dir, chkpt_path, train_loader, valid_loader, model, writer,
            logger, device, config):
    """Train a fixed-genotype (augmented) model to convergence.

    Args:
        out_dir: directory checkpoints are written to.
        chkpt_path: optional checkpoint path to resume from (None = fresh run).
        train_loader / valid_loader: training / validation data loaders.
        model: network to train; must expose weights() and drop_path_prob().
        writer: summary writer for scalar logging.
        logger: logging.Logger instance.
        device: torch device (unused directly here; forwarded to train/validate).
        config: run configuration (epochs, w_optim, drop_path_prob, save_freq).
    """
    w_optim = utils.get_optim(model.weights(), config.w_optim)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_optim.lr_min)
    init_epoch = -1

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    # BUGFIX: the two format arguments were swapped — param_size() was printed
    # in the "count" slot and param_count() in the "size" slot. get_model()
    # uses the consistent order (count first, size second).
    logger.info("Model params count: {:.3f} M, size: {:.3f} MB".format(
        utils.param_count(model), utils.param_size(model)))

    # training loop
    logger.info('begin training')
    best_top1 = 0.
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch == tot_epochs:
            break
        # linearly ramp drop-path probability over the run
        drop_prob = config.drop_path_prob * epoch / tot_epochs
        model.drop_path_prob(drop_prob)
        lr = lr_scheduler.get_lr()[0]

        # training (no architect / alpha optimizer in augment mode)
        train(train_loader, None, model, writer, logger, None, w_optim, None,
              lr, epoch, tot_epochs, device, config)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch, tot_epochs,
                        cur_step, device, config)

        # track best accuracy (the old unused `is_best` local was removed)
        if best_top1 < top1:
            best_top1 = top1
        # periodic checkpointing; save_freq == 0 disables saving
        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, None, lr_scheduler,
                            epoch, logger)
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
def get_model(config, device, dev_list, genotype=None):
    """Build a NAS controller model from the configured model type.

    Args:
        config: model configuration; config.type selects the model_creator entry.
        device: torch device the model is moved to.
        dev_list: device id list registered with NASModule.
        genotype: optional fixed genotype; when given, the model is built in
            augment mode from that genotype.

    Returns:
        (model, arch) tuple as produced by the model creator.

    Raises:
        ValueError: if config.type is not a registered model type. ValueError
            subclasses Exception, so existing `except Exception` callers
            still work.
    """
    mtype = config.type
    configure_ops(config)
    # guard clause instead of wrapping the whole body in an `if`
    if mtype not in model_creator:
        raise ValueError("invalid model type: {}".format(mtype))

    # idiomatic `is not None` (was `not genotype is None`)
    config.augment = genotype is not None
    net, arch = model_creator[mtype](config)
    crit = get_net_crit(config).to(device)
    prim = gt.get_primitives()
    model = NASController(config, net, crit, prim, dev_list).to(device)

    if config.augment:
        print("genotype = {}".format(genotype))
        model.build_from_genotype(genotype)
        model.to(device=device)

    if config.verbose:
        print(model)
        mb_params = param_size(model)
        n_params = param_count(model)
        print("Model params count: {:.3f} M, size: {:.3f} MB".format(
            n_params, mb_params))

    NASModule.set_device(dev_list)
    return model, arch
def save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler, epoch,
                    logger):
    """Serialize training state to <out_dir>/chkpt_NNN.pt (best effort).

    Args:
        out_dir: destination directory.
        model: network whose state_dict is saved.
        w_optim: weights optimizer.
        a_optim: architecture optimizer, or None (augment runs pass None).
        lr_scheduler: LR scheduler whose state_dict is saved.
        epoch: current epoch index (file is named after epoch + 1).
        logger: logging.Logger for success/failure messages.

    Failures are logged, not raised, so a bad save never kills training.
    """
    try:
        save_path = os.path.join(out_dir, 'chkpt_%03d.pt' % (epoch + 1))
        torch.save(
            {
                'model': model.state_dict(),
                'arch': NASModule.nasmod_state_dict(),
                'w_optim': w_optim.state_dict(),
                # BUGFIX: a_optim is None for augment runs; calling
                # .state_dict() on None raised AttributeError, which the
                # broad except below swallowed — augment checkpoints were
                # silently never written.
                'a_optim': None if a_optim is None else a_optim.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
            }, save_path)
        logger.info("Saved checkpoint to: %s" % save_path)
    except Exception as e:
        # deliberate best-effort: log and continue
        logger.error("Save checkpoint failed: " + str(e))
def __init__(self, config, n_nodes, chn_in, stride, shared_a, allocator,
             merger_state, merger_out, enumerator, preproc, aggregate,
             edge_cls, edge_kwargs):
    """Build a DAG layer of n_nodes intermediate states connected by edges.

    In search mode (config.augment False) every candidate edge enumerated by
    `enumerator` is instantiated eagerly; in augment mode construction is
    deferred (edges are built later from a genotype) and only bookkeeping
    state is stored.
    """
    super().__init__()
    # each DAGLayer instance gets a unique, monotonically increasing id
    global edge_id
    self.edge_id = edge_id
    edge_id += 1
    self.n_nodes = n_nodes
    self.stride = stride
    self.chn_in = chn_in
    self.n_input = len(chn_in)                    # number of layer inputs
    self.n_states = self.n_input + self.n_nodes   # inputs + internal nodes
    self.n_input_e = len(edge_kwargs['chn_in'])   # inputs per edge
    self.shared_a = shared_a
    if shared_a:
        NASModule.add_shared_param()
    self.allocator = allocator(self.n_input, self.n_nodes)
    self.merger_state = merger_state()
    self.merger_out = merger_out(start=self.n_input)
    self.merge_out_range = self.merger_out.merge_range(self.n_states)
    self.enumerator = enumerator()

    # chn_states[i] = channel count of state i (inputs first, then nodes)
    chn_states = []
    if not preproc:
        self.preprocs = None
        chn_states.extend(chn_in)
    else:
        # preprocess every input down to the edge input channel count
        chn_cur = edge_kwargs['chn_in'][0]
        self.preprocs = nn.ModuleList()
        for i in range(self.n_input):
            self.preprocs.append(preproc[i](chn_in[i], chn_cur, False))
            chn_states.append(chn_cur)

    if not config.augment:
        # search mode: instantiate every candidate edge now
        self.fixed = False
        self.dag = nn.ModuleList()
        self.edges = []
        self.num_edges = 0
        for i in range(n_nodes):
            cur_state = self.n_input + i
            self.dag.append(nn.ModuleList())
            num_edges = self.enumerator.len_enum(cur_state, self.n_input_e)
            for sidx in self.enumerator.enum(cur_state, self.n_input_e):
                e_chn_in = self.allocator.chn_in(
                    [chn_states[s] for s in sidx], sidx, cur_state)
                # NOTE: edge_kwargs is mutated in place for each edge
                edge_kwargs['chn_in'] = e_chn_in
                # stride only applies to edges fed purely by layer inputs
                edge_kwargs['stride'] = stride if all(
                    s < self.n_input for s in sidx) else 1
                edge_kwargs['shared_a'] = shared_a
                e = edge_cls(**edge_kwargs)
                self.dag[i].append(e)
                self.edges.append(e)
            self.num_edges += num_edges
            chn_states.append(
                self.merger_state.chn_out(
                    [ei.chn_out for ei in self.dag[i]]))
        # NOTE(review): chn_out is re-assigned (same value) at the end of
        # this method as well — redundant but harmless in this branch
        self.chn_out = self.merger_out.chn_out(chn_states)
        # print('DAGLayer: etype:{} chn_in:{} chn:{} #n:{} #e:{}'.format(str(edge_cls), self.chn_in, edge_kwargs['chn_in'][0],self.n_nodes, self.num_edges))
        # print('DAGLayer param count: {:.6f}'.format(param_count(self)))
    else:
        # augment mode: keep enough state to build edges from a genotype later
        self.chn_states = chn_states
        self.edge_cls = edge_cls
        self.edge_kwargs = edge_kwargs
        self.fixed = True

    if aggregate is not None:
        self.merge_filter = aggregate(n_in=self.n_input + self.n_nodes,
                                      n_out=self.n_input + self.n_nodes // 2)
    else:
        self.merge_filter = None
    # NOTE(review): in augment mode chn_states holds only input/preproc
    # channels at this point — confirm merger_out handles that correctly
    self.chn_out = self.merger_out.chn_out(chn_states)
def __init__(self, config, n_nodes, chn_in, stride, shared_a, allocator,
             merger_out, preproc, aggregate, child_cls, child_kwargs,
             edge_cls, edge_kwargs, children=None, edges=None):
    """Build a hierarchical layer: per node, one (optional) edge feeding one
    (optional) child subnet.

    For each of the n_nodes nodes, the edge is taken from `edges` if given,
    else built from `edge_cls`, else omitted (None); the subnet is taken from
    `children` if given, else built from `child_cls`, else omitted (None).
    """
    super().__init__()
    self.edges = nn.ModuleList()
    self.subnets = nn.ModuleList()
    # normalize a scalar channel count to a 1-tuple
    chn_in = (chn_in, ) if isinstance(chn_in, int) else chn_in
    self.n_input = len(chn_in)
    self.n_nodes = n_nodes
    self.n_states = self.n_input + self.n_nodes
    self.allocator = allocator(self.n_input, self.n_nodes)
    self.merger_out = merger_out(start=self.n_input)
    self.merge_out_range = self.merger_out.merge_range(self.n_states)
    if shared_a:
        NASModule.add_shared_param()

    # chn_states[i] = channel count of input state i
    chn_states = []
    if not preproc:
        self.preprocs = None
        chn_states.extend(chn_in)
    else:
        chn_cur = edge_kwargs['chn_in'][0]
        self.preprocs = nn.ModuleList()
        for i in range(self.n_input):
            self.preprocs.append(
                preproc(chn_in[i], chn_cur, 1, 1, 0, False))
            chn_states.append(chn_cur)

    # every node reads from all layer inputs
    sidx = range(self.n_input)
    for i in range(self.n_nodes):
        e_chn_in = self.allocator.chn_in([chn_states[s] for s in sidx],
                                         sidx, i)
        # pick or build the edge for this node
        if not edges is None:
            self.edges.append(edges[i])
            c_chn_in = edges[i].chn_out
        elif not edge_cls is None:
            # NOTE: edge_kwargs is mutated in place per node
            edge_kwargs['chn_in'] = e_chn_in
            edge_kwargs['stride'] = stride
            if 'shared_a' in edge_kwargs:
                edge_kwargs['shared_a'] = shared_a
            e = edge_cls(**edge_kwargs)
            self.edges.append(e)
            c_chn_in = e.chn_out
        else:
            # no edge: child consumes the allocator's channels directly
            self.edges.append(None)
            c_chn_in = e_chn_in
        # pick or build the child subnet for this node
        if not children is None:
            self.subnets.append(children[i])
        elif not child_cls is None:
            child_kwargs['chn_in'] = c_chn_in
            self.subnets.append(child_cls(**child_kwargs))
        else:
            self.subnets.append(None)

    if aggregate is not None:
        self.merge_filter = aggregate(n_in=self.n_states,
                                      n_out=self.n_states // 2)
    else:
        self.merge_filter = None
def step(self, trn_X, trn_y, val_X, val_y, xi, w_optim, a_optim):
    """Compute the architecture gradient on validation data and update alphas.

    Args:
        trn_X, trn_y: training batch (unused here; kept for interface parity).
        val_X, val_y: validation batch the alpha loss is computed on.
        xi: learning rate for virtual gradient step (same as net lr).
        w_optim: weights optimizer - for virtual step.
        a_optim: architecture (alpha) optimizer.
    """
    a_optim.zero_grad()
    # sample k candidate ops per module, if sampling is enabled
    if self.sample:
        NASModule.param_module_call('sample_ops', n_samples=self.n_samples)
    loss = self.net.loss(val_X, val_y)

    # gather the per-device module outputs, then backprop loss -> alphas
    m_out_dev = []
    for dev_id in NASModule.get_device():
        m_out = [
            m.get_state('m_out' + dev_id) for m in NASModule.modules()
        ]
        m_len = len(m_out)  # same module count on every device
        m_out_dev.extend(m_out)
    m_grad = torch.autograd.grad(loss, m_out_dev)
    for i, dev_id in enumerate(NASModule.get_device()):
        NASModule.param_backward_from_grad(
            m_grad[i * m_len:(i + 1) * m_len], dev_id)

    if not self.renorm:
        a_optim.step()
    else:
        # renormalization: preserve the probability mass ratio between
        # selected and unselected ops across the optimizer step.
        # Pass 1: record the pre-step ratio k for each param module.
        prev_pw = []
        for p, m in NASModule.param_modules():
            s_op = m.get_state('s_op')
            pdt = p.detach()
            pp = pdt.index_select(-1, s_op)
            if pp.size() == pdt.size():
                # all ops selected: nothing to renormalize.
                # BUGFIX: a placeholder is appended so pass 2 stays aligned
                # with param_modules(); previously a plain `continue` here
                # shifted every later (kprev, module) pairing by one.
                prev_pw.append(None)
                continue
            k = torch.sum(torch.exp(pdt)) / torch.sum(torch.exp(pp)) - 1
            prev_pw.append(k)

        a_optim.step()

        # Pass 2: shift the selected logits so the ratio matches kprev.
        for kprev, (p, m) in zip(prev_pw, NASModule.param_modules()):
            if kprev is None:
                continue  # was skipped in pass 1
            s_op = m.get_state('s_op')
            pdt = p.detach()
            pp = pdt.index_select(-1, s_op)
            k = torch.sum(torch.exp(pdt)) / torch.sum(torch.exp(pp)) - 1
            for i in s_op:
                # NOTE(review): in-place update of p — assumes this is safe
                # outside autograd here (p was stepped above); confirm p does
                # not require grad tracking at this point
                p[i] += (torch.log(k) - torch.log(kprev))

    NASModule.module_call('reset_ops')
def search(out_dir, chkpt_path, w_train_loader, a_train_loader, model, arch,
           writer, logger, device, config):
    """Run architecture search: optional warmup phase, then joint w/alpha
    bi-level training, tracking and saving the best genotype.

    Args:
        out_dir: directory for checkpoints and genotype files.
        chkpt_path: optional checkpoint path to resume from.
        w_train_loader: loader for weight training.
        a_train_loader: loader for alpha training (also used for validation).
        model: NAS controller model.
        arch: architect class; instantiated as arch(config, model).
        writer, logger, device, config: as in train()/validate().
    """
    valid_loader = a_train_loader
    w_optim = utils.get_optim(model.weights(), config.w_optim)
    a_optim = utils.get_optim(model.alphas(), config.a_optim)
    init_epoch = -1
    lr_scheduler_state = None
    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        NASModule.nasmod_load_state_dict(checkpoint['arch'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        a_optim.load_state_dict(checkpoint['a_optim'])
        # BUGFIX: lr_scheduler is only created AFTER the warmup phase below;
        # calling lr_scheduler.load_state_dict() here raised NameError on
        # every resume. Stash the state and load it once the scheduler exists.
        lr_scheduler_state = checkpoint['lr_scheduler']
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    architect = arch(config, model)

    # warmup training loop (weights only; interruptible with Ctrl-C)
    logger.info('begin warmup training')
    try:
        if config.warmup_epochs > 0:
            warmup_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                w_optim, config.warmup_epochs, eta_min=config.w_optim.lr_min)
            last_epoch = 0
        else:
            last_epoch = -1
        tot_epochs = config.warmup_epochs
        for epoch in itertools.count(init_epoch + 1):
            if epoch == tot_epochs:
                break
            lr = warmup_lr_scheduler.get_lr()[0]
            # training
            train(w_train_loader, None, model, writer, logger, architect,
                  w_optim, a_optim, lr, epoch, tot_epochs, device, config)
            # validation
            cur_step = (epoch + 1) * len(w_train_loader)
            top1 = validate(valid_loader, model, writer, logger, epoch,
                            tot_epochs, cur_step, device, config)
            warmup_lr_scheduler.step()
            print("")
    except KeyboardInterrupt:
        # deliberate: allow skipping the rest of warmup interactively
        print('skipped')

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_optim.lr_min,
        last_epoch=last_epoch)
    if lr_scheduler_state is not None:
        lr_scheduler.load_state_dict(lr_scheduler_state)
    save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler,
                    init_epoch, logger)
    save_genotype(out_dir, model.genotype(), init_epoch, logger)

    # main bi-level training loop
    logger.info('begin w/a training')
    best_top1 = 0.
    best_genotype = None  # robustness: defined even if no epoch improves
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch == tot_epochs:
            break
        lr = lr_scheduler.get_lr()[0]
        model.print_alphas(logger)
        # training
        train(w_train_loader, a_train_loader, model, writer, logger,
              architect, w_optim, a_optim, lr, epoch, tot_epochs, device,
              config)
        # validation
        cur_step = (epoch + 1) * len(w_train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch,
                        tot_epochs, cur_step, device, config)
        # genotype
        genotype = model.genotype()
        save_genotype(out_dir, genotype, epoch, logger)
        # genotype as image
        if config.plot:
            for i, dag in enumerate(model.dags()):
                plot_path = os.path.join(config.plot_path,
                                         "EP{:02d}".format(epoch + 1))
                caption = "Epoch {} - DAG {}".format(epoch + 1, i)
                plot(genotype.dag[i], dag, plot_path + "-dag_{}".format(i),
                     caption)
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler,
                            epoch, logger)
        lr_scheduler.step()
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
    gt.to_file(best_genotype, os.path.join(out_dir, 'best.gt'))
def train(train_loader, valid_loader, model, writer, logger, architect,
          w_optim, a_optim, lr, epoch, tot_epochs, device, config):
    """Run one training epoch: weight steps, interleaved with architect
    (alpha) steps on validation batches when valid_loader is provided.

    Args:
        train_loader: loader for weight updates.
        valid_loader: loader for architect updates, or None (weights only).
        model: network exposing loss_logits() and weights().
        writer: summary writer for scalar logging.
        logger: logging.Logger instance.
        architect: object with step(...) for the alpha update (may be unused
            when valid_loader is None).
        w_optim / a_optim: weights / alpha optimizers.
        lr: current learning rate (logged and passed to architect.step).
        epoch, tot_epochs: current / total epoch count for logging.
        device: torch device batches are moved to.
        config: run configuration (aux_weight, w_grad_clip, print_freq).
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()
    cur_step = epoch * len(train_loader)
    writer.add_scalar('train/lr', lr, cur_step)
    model.train()

    if not valid_loader is None:
        # architect steps every tr_ratio weight steps.
        # BUGFIX: the ratio is clamped to >= 1 — when valid_loader is longer
        # than train_loader the floor division yielded 0 and `step % tr_ratio`
        # raised ZeroDivisionError.
        tr_ratio = max(1, len(train_loader) // len(valid_loader))
        val_iter = iter(valid_loader)

    eta_m = utils.ETAMeter(tot_epochs, epoch, len(train_loader))
    eta_m.start()
    for step, (trn_X, trn_y) in enumerate(train_loader):
        trn_X, trn_y = trn_X.to(device, non_blocking=True), trn_y.to(
            device, non_blocking=True)
        N = trn_X.size(0)

        # phase 1. child network step (w)
        w_optim.zero_grad()
        tprof.timer_start('train')
        loss, logits = model.loss_logits(trn_X, trn_y, config.aux_weight)
        tprof.timer_stop('train')
        loss.backward()
        # gradient clipping
        if config.w_grad_clip > 0:
            nn.utils.clip_grad_norm_(model.weights(), config.w_grad_clip)
        w_optim.step()

        # phase 2. architect step (alpha)
        if not valid_loader is None and step % tr_ratio == 0:
            try:
                val_X, val_y = next(val_iter)
            except StopIteration:
                # BUGFIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt and real errors; only iterator
                # exhaustion should restart the validation iterator.
                val_iter = iter(valid_loader)
                val_X, val_y = next(val_iter)
            val_X, val_y = val_X.to(device, non_blocking=True), val_y.to(
                device, non_blocking=True)
            tprof.timer_start('arch')
            architect.step(trn_X, trn_y, val_X, val_y, lr, w_optim, a_optim)
            tprof.timer_stop('arch')

        prec1, prec5 = utils.accuracy(logits, trn_y, topk=(1, 5))
        losses.update(loss.item(), N)
        top1.update(prec1.item(), N)
        top5.update(prec5.item(), N)

        if step != 0 and step % config.print_freq == 0 or step == len(
                train_loader) - 1:
            eta = eta_m.step(step)
            logger.info(
                "Train: [{:2d}/{}] Step {:03d}/{:03d} LR {:.3f} Loss {losses.avg:.3f} "
                "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%}) | ETA: {eta}".
                format(epoch + 1,
                       tot_epochs,
                       step,
                       len(train_loader) - 1,
                       lr,
                       losses=losses,
                       top1=top1,
                       top5=top5,
                       eta=utils.format_time(eta)))
        writer.add_scalar('train/loss', loss.item(), cur_step)
        writer.add_scalar('train/top1', prec1.item(), cur_step)
        writer.add_scalar('train/top5', prec5.item(), cur_step)
        cur_step += 1

    logger.info("Train: [{:2d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, tot_epochs, top1.avg))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
    tprof.print_stat('train')
    tprof.print_stat('arch')