def infer(valid_queue, model, criterion):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                logging.info('//// WARNING: FAST MODE')
                break

    return top1.avg, objs.avg
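
# `AvgrageMeter` (sic, the repo's spelling) keeps a running sample-weighted
# average; a minimal sketch of the semantics assumed above (values are made up):
#
#   meter = ig_utils.AvgrageMeter()
#   meter.update(0.5, n=64)   # batch average 0.5 over 64 samples
#   meter.update(1.0, n=32)   # meter.avg == (0.5 * 64 + 1.0 * 32) / 96
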
def infer(valid_queue, model, log=True, _eval=True, weights_dict=None):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval() if _eval else model.train()  # disable running stats for projection

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            if weights_dict is None:
                loss, logits = model._loss(input, target, return_logits=True)
            else:
                logits = model(input, weights_dict=weights_dict)
                loss = model._criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % args.report_freq == 0 and log:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

    return top1.avg, objs.avg
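
# Illustrative call into the masked-evaluation path above (a sketch; the
# `get_projected_weights()` helper and its indexing are assumptions borrowed
# from the projection code later in this file):
#
#   weights = model.get_projected_weights()
#   weights[eid][opid] = 0  # zero out one candidate op on edge `eid`
#   valid_acc, valid_loss = infer(valid_queue, model, log=False,
#                                 _eval=False, weights_dict=weights)
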
def train(train_queue, model, criterion, optimizer):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        if args.fast:
            logging.info('//// WARNING: FAST MODE')
            break

    return top1.avg, objs.avg
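
# Usage sketch of the evaluation-phase loop (hypothetical wiring; the epoch
# loop and `scheduler` are assumptions, not part of this file): `train` above
# is paired with the plain `infer` once per epoch.
#
#   criterion = nn.CrossEntropyLoss().cuda()
#   for epoch in range(args.epochs):
#       train_acc, train_obj = train(train_queue, model, criterion, optimizer)
#       scheduler.step()
#       valid_acc, valid_obj = infer(valid_queue, model, criterion)
#       logging.info('epoch %d train_acc %f valid_acc %f', epoch, train_acc, valid_acc)
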
def infer(valid_queue, model, criterion, log=True, _eval=True, weights=None,
          double=False, bn_est=False):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval() if _eval else model.train()  # disable running stats for projection

    if bn_est:
        # re-estimate batchnorm running statistics with a forward pass over the data
        _data_loader = deepcopy(valid_queue)
        for step, (input, target) in enumerate(_data_loader):
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            with torch.no_grad():
                logits = model(input)
        model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            if double:
                input = input.double()
                target = target.double()

            logits = model(input) if weights is None else model(input, weights=weights)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if log and step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

    return top1.avg, objs.avg
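
# Sketch of the `bn_est` path's intended use (an assumption about calling
# context, not code from this repo): after projecting/discretizing the
# supernet, the stored batchnorm running stats no longer match the final
# architecture, so they can be re-estimated before the final evaluation:
#
#   valid_acc, valid_obj = infer(valid_queue, model, criterion,
#                                log=False, bn_est=True)
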
def train(train_queue, valid_queue, model, architect, optimizer, lr, epoch,
          perturb_alpha, epsilon_alpha):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    for step in range(len(train_queue)):
        model.train()

        ## data (a fresh iterator each step draws a newly shuffled batch)
        input, target = next(iter(train_queue))
        input = input.cuda()
        target = target.cuda(non_blocking=True)
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda()
        target_search = target_search.cuda(non_blocking=True)

        ## train alpha
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        architect.step(input, target, input_search, target_search, lr, optimizer)

        ## sdarts: transform arch_parameters to prob (for perturbation)
        if perturb_alpha:
            model.softmax_arch_parameters()
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            perturb_alpha(model, input, target, epsilon_alpha)

        ## train weights
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        logits, loss = model.step(input, target, args)

        ## sdarts: restore alpha to unperturbed arch_parameters
        if perturb_alpha:
            model.restore_arch_parameters()

        ## logging
        n = input.size(0)
        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        if args.fast:
            break

    return top1.avg, objs.avg
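
# A minimal sketch of a random-smoothing `perturb_alpha` in the spirit of
# SDARTS-RS, written for illustration only. This is NOT the repo's own
# implementation; it assumes `model.arch_parameters()` yields the softmaxed
# alpha tensors installed by `softmax_arch_parameters()` above.
def random_perturb_alpha(model, input, target, epsilon_alpha):
    with torch.no_grad():
        for alpha in model.arch_parameters():
            # add uniform noise in [-epsilon_alpha, epsilon_alpha] to each alpha
            alpha.add_(torch.empty_like(alpha).uniform_(-epsilon_alpha, epsilon_alpha))
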
def train(train_queue, valid_queue, model, architect, optimizer, lr, epoch):
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    for step in range(len(train_queue)):
        model.train()

        ## data (a fresh iterator each step draws a newly shuffled batch)
        input, target = next(iter(train_queue))
        input = input.cuda()
        target = target.cuda(non_blocking=True)
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda()
        target_search = target_search.cuda(non_blocking=True)

        ## train alpha
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        shared = architect.step(input, target, input_search, target_search,
                                eta=lr, network_optimizer=optimizer)

        ## train weight
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        logits, loss = model.step(input, target, args, shared=shared)

        ## logging
        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        if args.fast:
            break

    return top1.avg, objs.avg
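
# Per-epoch wiring sketch for the search phase (hypothetical; the epoch loop
# and `scheduler` are assumptions): the architect's alpha step and the weight
# step above run once per batch, driven by the current weight learning rate.
#
#   for epoch in range(args.epochs):
#       lr = scheduler.get_last_lr()[0]
#       train_acc, train_obj = train(train_queue, valid_queue, model,
#                                    architect, optimizer, lr, epoch)
#       scheduler.step()
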
def pt_project(train_queue, valid_queue, model, architect, optimizer,
               epoch, args, infer, perturb_alpha, epsilon_alpha):
    model.train()
    model.printing(logging)

    train_acc, train_obj = infer(train_queue, model, log=False)
    logging.info('train_acc %f', train_acc)
    logging.info('train_loss %f', train_obj)
    valid_acc, valid_obj = infer(valid_queue, model, log=False)
    logging.info('valid_acc %f', valid_acc)
    logging.info('valid_loss %f', valid_obj)

    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    #### macros
    ## -1 because we project at both epoch 0 and the final epoch
    num_projs = model.num_edges + len(model.nid2eids.keys()) - 1
    tune_epochs = args.proj_intv * num_projs + 1
    proj_intv = args.proj_intv
    args.proj_crit = {'normal': args.proj_crit_normal, 'reduce': args.proj_crit_reduce}
    proj_queue = valid_queue

    #### reset optimizer
    model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)

    #### load projection checkpoints
    start_epoch = 0
    if args.dev_resume_epoch >= 0:
        filename = os.path.join(args.dev_resume_checkpoint_dir,
                                'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
        if os.path.isfile(filename):
            logging.info("=> loading projection checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            start_epoch = checkpoint['epoch']
            model.set_state_dict(architect, scheduler, checkpoint)
            model.set_arch_parameters(checkpoint['alpha'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            model.optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            logging.info("=> no checkpoint found at '{}'".format(filename))
            exit(0)

    #### projecting and tuning
    for epoch in range(start_epoch, tune_epochs):
        logging.info('epoch %d', epoch)

        ## project
        if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
            ## save at every projection step
            save_state_dict = model.get_state_dict(epoch, architect, scheduler)
            ig_utils.save_checkpoint(save_state_dict, False,
                                     args.dev_save_checkpoint_dir, per_epoch=True)

            if epoch < proj_intv * model.num_edges:
                logging.info('project op')

                selected_eid_normal, best_opid_normal = project_op(
                    model, proj_queue, args, infer, cell_type='normal')
                model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
                selected_eid_reduce, best_opid_reduce = project_op(
                    model, proj_queue, args, infer, cell_type='reduce')
                model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')

                model.printing(logging)
            else:
                logging.info('project edge')

                selected_nid_normal, eids_normal = project_edge(
                    model, proj_queue, args, infer, cell_type='normal')
                model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
                selected_nid_reduce, eids_reduce = project_edge(
                    model, proj_queue, args, infer, cell_type='reduce')
                model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')

                model.printing(logging)

        ## tune
        for step, (input, target) in enumerate(train_queue):
            model.train()
            n = input.size(0)

            ## fetch data
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)

            ## train alpha
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            architect.step(input, target, input_search, target_search, return_logits=True)

            ## sdarts: transform arch_parameters to prob (for perturbation)
            if perturb_alpha:
                model.softmax_arch_parameters()
                optimizer.zero_grad()
                architect.optimizer.zero_grad()
                perturb_alpha(model, input, target, epsilon_alpha)

            ## train weight
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            logits, loss = model.step(input, target, args)

            ## sdarts: restore alpha to unperturbed arch_parameters
            if perturb_alpha:
                model.restore_arch_parameters()

            ## logging
            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

        ## end of epoch
        model.printing(logging)
        train_acc, train_obj = infer(train_queue, model, log=False)
        logging.info('train_acc %f', train_acc)
        logging.info('train_loss %f', train_obj)
        valid_acc, valid_obj = infer(valid_queue, model, log=False)
        logging.info('valid_acc %f', valid_acc)
        logging.info('valid_loss %f', valid_obj)

    logging.info('projection finished')
    model.printing(logging)
    num_params = ig_utils.count_parameters_in_Compact(model)
    genotype = model.genotype()
    logging.info('param size = %f', num_params)
    logging.info('genotype = %s', genotype)
    return
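
# Worked example of the schedule macros above (numbers are illustrative for a
# standard DARTS cell: 14 edges and 4 intermediate nodes in `nid2eids`):
#
#   num_projs   = 14 + 4 - 1            # = 17 projection events
#   tune_epochs = args.proj_intv * 17 + 1
#
# so with proj_intv=1 the model is tuned for 18 epochs, making one op or edge
# projection decision per epoch.
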
def pt_project(train_queue, valid_queue, model, architect, criterion,
               optimizer, epoch, args, infer, query):

    def project(model, args):
        ## macros
        num_edge, num_op = model.num_edge, model.num_op

        ## select an edge among those not yet projected
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]

        ## select the best operation on that edge
        if args.proj_crit == 'loss':
            crit_idx = 1
            compare = lambda x, y: x > y
        if args.proj_crit == 'acc':
            crit_idx = 0
            compare = lambda x, y: x < y

        best_opid = 0
        crit_extrema = None
        for opid in range(num_op):
            ## projection: mask out one candidate op on the selected edge
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask

            ## proj evaluation: the best op is the one whose removal hurts
            ## validation performance the most
            valid_stats = infer(valid_queue, model, criterion,
                                log=False, _eval=False, weights=weights)
            crit = valid_stats[crit_idx]

            if crit_extrema is None or compare(crit, crit_extrema):
                crit_extrema = crit
                best_opid = opid

            logging.info('valid_acc %f', valid_stats[0])
            logging.info('valid_loss %f', valid_stats[1])

        logging.info('best opid %d', best_opid)
        return selected_eid, best_opid

    ## query
    if not args.fast:
        api = API('../data/NAS-Bench-201-v1_0-e61699.pth')

    model.train()
    model.printing(logging)

    train_acc, train_obj = infer(train_queue, model, criterion, log=False)
    logging.info('train_acc %f', train_acc)
    logging.info('train_loss %f', train_obj)
    valid_acc, valid_obj = infer(valid_queue, model, criterion, log=False)
    logging.info('valid_acc %f', valid_acc)
    logging.info('valid_loss %f', valid_obj)

    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    num_edges = model.arch_parameters()[0].shape[0]
    proj_intv = args.proj_intv
    tune_epochs = proj_intv * (num_edges - 1)
    model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)

    for epoch in range(tune_epochs):
        logging.info('epoch %d', epoch)

        if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
            logging.info('project')
            selected_eid, best_opid = project(model, args)
            model.project_op(selected_eid, best_opid)
            model.printing(logging)

        for step, (input, target) in enumerate(train_queue):
            model.train()
            n = input.size(0)

            ## fetch data
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)

            ## train alpha
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            shared = architect.step(input, target, input_search, target_search,
                                    return_logits=True)

            ## train weight
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            logits, loss = model.step(input, target, args, shared=shared)

            ## logging
            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

        ## end of epoch
        model.printing(logging)
        train_acc, train_obj = infer(train_queue, model, criterion, log=False)
        logging.info('train_acc %f', train_acc)
        logging.info('train_loss %f', train_obj)
        valid_acc, valid_obj = infer(valid_queue, model, criterion, log=False)
        logging.info('valid_acc %f', valid_acc)
        logging.info('valid_loss %f', valid_obj)

    ## nasbench201
    if not args.fast:
        query(api, model.genotype(), logging)

    return
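
# Sketch of the benchmark hand-off at the end of projection (the API path
# mirrors the one loaded above; `query` is the caller-supplied helper that
# logs the looked-up ground-truth statistics):
#
#   api = API('../data/NAS-Bench-201-v1_0-e61699.pth')
#   query(api, model.genotype(), logging)
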