Example #1
def infer(valid_queue, model, criterion):
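    """Evaluate the model on valid_queue without gradients.

    Returns (average top-1 accuracy, average loss); top-5 accuracy is tracked
    for logging only.
    """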
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)

            logits, _ = model(input)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)

            if args.fast:
                logging.info('//// WARNING: FAST MODE')
                break

    return top1.avg, objs.avg
Example #2
def infer(valid_queue, model, log=True, _eval=True, weights_dict=None):
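    """Evaluate the model on valid_queue and return (top-1 accuracy, average loss).

    If weights_dict is given, the forward pass uses those architecture weights
    instead of the model's own; _eval=False keeps the model in train mode so
    batch-norm uses batch statistics during projection (see inline comment).
    """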
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval() if _eval else model.train() # disable running stats for projection

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            
            if weights_dict is None:
                loss, logits = model._loss(input, target, return_logits=True)
            else:
                logits = model(input, weights_dict=weights_dict)
                loss = model._criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data.item(), n)
            top1.update(prec1.data.item(), n)
            top5.update(prec5.data.item(), n)

            if step % args.report_freq == 0 and log:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

    return top1.avg, objs.avg
Example #3
def train(train_queue, model, criterion, optimizer):
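    """Train the model for one epoch on train_queue.

    Uses cross-entropy on the main logits plus an optional auxiliary-head loss
    weighted by args.auxiliary_weight, with gradient clipping before each
    optimizer step. Returns (top-1 accuracy, average loss).
    """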
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.train()

    for step, (input, target) in enumerate(train_queue):
        input = input.cuda()
        target = target.cuda(non_blocking=True)

        optimizer.zero_grad()
        logits, logits_aux = model(input)
        loss = criterion(logits, target)
        if args.auxiliary:
            loss_aux = criterion(logits_aux, target)
            loss += args.auxiliary_weight * loss_aux
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data, n)
        top1.update(prec1.data, n)
        top5.update(prec5.data, n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                         top5.avg)

        if args.fast:
            logging.info('//// WARNING: FAST MODE')
            break

    return top1.avg, objs.avg
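
For context, a minimal sketch of how a train/infer pair like the one above is typically driven per epoch. The epoch loop, scheduler, and args.epochs below are illustrative assumptions, not code taken from these examples:

# Hypothetical per-epoch driver (assumes args.epochs and an LR scheduler).
for epoch in range(args.epochs):
    logging.info('epoch %d', epoch)

    train_acc, train_obj = train(train_queue, model, criterion, optimizer)
    logging.info('train_acc %f', train_acc)

    scheduler.step()

    valid_acc, valid_obj = infer(valid_queue, model, criterion)
    logging.info('valid_acc %f', valid_acc)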
Example #4
def infer(valid_queue,
          model,
          criterion,
          log=True,
          eval=True,
          weights=None,
          double=False,
          bn_est=False):
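    """Evaluate the model on valid_queue and return (top-1 accuracy, average loss).

    weights optionally supplies explicit architecture weights for the forward
    pass, double=True casts inputs and targets to double precision, and
    bn_est=True first runs a no-grad pass over a copy of the loader to
    re-estimate batch-norm statistics before evaluation.
    """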
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()
    model.eval() if eval else model.train()  # disable running stats for projection

    if bn_est:
        _data_loader = deepcopy(valid_queue)
        for step, (input, target) in enumerate(_data_loader):
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            with torch.no_grad():
                logits = model(input)
        model.eval()

    with torch.no_grad():
        for step, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            if double:
                input = input.double()
                target = target.double()

            logits = model(input) if weights is None else model(
                input, weights=weights)
            loss = criterion(logits, target)

            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            n = input.size(0)
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if log and step % args.report_freq == 0:
                logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)

            if args.fast:
                break

    return top1.avg, objs.avg
Example #5
def train(train_queue, valid_queue, model, architect, optimizer, lr, epoch,
          perturb_alpha, epsilon_alpha):
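    """One epoch of bi-level search over train_queue / valid_queue.

    Each step updates the architecture parameters via architect.step, then
    (optionally, SDARTS-style) perturbs the softmaxed alphas with perturb_alpha
    before the weight update in model.step, restoring the unperturbed alphas
    afterwards. Returns (top-1 accuracy, average loss).
    """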
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    for step in range(len(train_queue)):
        model.train()

        ## data
        input, target = next(iter(train_queue))
        input = input.cuda(); target = target.cuda(non_blocking=True)
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda(); target_search = target_search.cuda(non_blocking=True)

        ## train alpha
        optimizer.zero_grad(); architect.optimizer.zero_grad()
        architect.step(input, target, input_search, target_search, lr, optimizer)

        ## sdarts
        if perturb_alpha:
            # transform arch_parameters to prob (for perturbation)
            model.softmax_arch_parameters()
            optimizer.zero_grad(); architect.optimizer.zero_grad()
            perturb_alpha(model, input, target, epsilon_alpha)

        ## train weights
        optimizer.zero_grad(); architect.optimizer.zero_grad()
        logits, loss = model.step(input, target, args)
        
        ## sdarts
        if perturb_alpha:
            ## restore alpha to unperturbed arch_parameters
            model.restore_arch_parameters()

        ## logging
        n = input.size(0)
        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        objs.update(loss.data.item(), n)
        top1.update(prec1.data.item(), n)
        top5.update(prec5.data.item(), n)
        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

        if args.fast:
            break

    return top1.avg, objs.avg
Example #6
def train(train_queue, valid_queue, model, architect, optimizer, lr, epoch):
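    """One epoch of bi-level search over train_queue / valid_queue.

    architect.step updates the architecture parameters and returns a shared
    object that model.step reuses for the weight update.
    Returns (top-1 accuracy, average loss).
    """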
    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    for step in range(len(train_queue)):
        model.train()

        ## data
        input, target = next(iter(train_queue))
        input = input.cuda()
        target = target.cuda(non_blocking=True)
        input_search, target_search = next(iter(valid_queue))
        input_search = input_search.cuda()
        target_search = target_search.cuda(non_blocking=True)

        ## train alpha
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        shared = architect.step(input,
                                target,
                                input_search,
                                target_search,
                                eta=lr,
                                network_optimizer=optimizer)

        ## train weight
        optimizer.zero_grad()
        architect.optimizer.zero_grad()
        logits, loss = model.step(input, target, args, shared=shared)

        ## logging
        prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
        n = input.size(0)
        objs.update(loss.data, n)
        top1.update(prec1.data, n)
        top5.update(prec5.data, n)

        if step % args.report_freq == 0:
            logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                         top5.avg)

        if args.fast:
            break

    return top1.avg, objs.avg
Example #7
def pt_project(train_queue, valid_queue, model, architect, optimizer,
               epoch, args, infer, perturb_alpha, epsilon_alpha):
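    """Projection phase: alternately project and tune the supernet.

    Every args.proj_intv epochs one operation (and, once every edge has been
    decided, one node's incoming edges) is projected for both the normal and
    reduce cells, with the remaining epochs spent fine-tuning weights and
    alphas; a checkpoint is saved before each projection and can be resumed
    via args.dev_resume_epoch. Logs the final parameter count and genotype.
    """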
    model.train()
    model.printing(logging)

    train_acc, train_obj = infer(train_queue, model, log=False)
    logging.info('train_acc  %f', train_acc)
    logging.info('train_loss %f', train_obj)

    valid_acc, valid_obj = infer(valid_queue, model, log=False)
    logging.info('valid_acc  %f', valid_acc)
    logging.info('valid_loss %f', valid_obj)

    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()


    #### macros
    num_projs = model.num_edges + len(model.nid2eids.keys()) - 1 ## -1 because we project at both epoch 0 and -1
    tune_epochs = args.proj_intv * num_projs + 1
    proj_intv = args.proj_intv
    args.proj_crit = {'normal':args.proj_crit_normal, 'reduce':args.proj_crit_reduce}
    proj_queue = valid_queue


    #### reset optimizer
    model.reset_optimizer(args.learning_rate / 10, args.momentum, args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        model.optimizer, float(tune_epochs), eta_min=args.learning_rate_min)


    #### load proj checkpoints
    start_epoch = 0
    if args.dev_resume_epoch >= 0:
        filename = os.path.join(args.dev_resume_checkpoint_dir, 'checkpoint_{}.pth.tar'.format(args.dev_resume_epoch))
        if os.path.isfile(filename):
            logging.info("=> loading projection checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename, map_location='cpu')
            start_epoch = checkpoint['epoch']
            model.set_state_dict(architect, scheduler, checkpoint)
            model.set_arch_parameters(checkpoint['alpha'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            model.optimizer.load_state_dict(checkpoint['optimizer']) # optimizer
        else:
            logging.info("=> no checkpoint found at '{}'".format(filename))
            exit(0)


    #### projecting and tuning
    for epoch in range(start_epoch, tune_epochs):
        logging.info('epoch %d', epoch)
        
        ## project
        if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
            ## saving every projection
            save_state_dict = model.get_state_dict(epoch, architect, scheduler)
            ig_utils.save_checkpoint(save_state_dict, False, args.dev_save_checkpoint_dir, per_epoch=True)

            if epoch < proj_intv * model.num_edges:
                logging.info('project op')
                
                selected_eid_normal, best_opid_normal = project_op(model, proj_queue, args, infer, cell_type='normal')
                model.project_op(selected_eid_normal, best_opid_normal, cell_type='normal')
                selected_eid_reduce, best_opid_reduce = project_op(model, proj_queue, args, infer, cell_type='reduce')
                model.project_op(selected_eid_reduce, best_opid_reduce, cell_type='reduce')

                model.printing(logging)
            else:
                logging.info('project edge')
                
                selected_nid_normal, eids_normal = project_edge(model, proj_queue, args, infer, cell_type='normal')
                model.project_edge(selected_nid_normal, eids_normal, cell_type='normal')
                selected_nid_reduce, eids_reduce = project_edge(model, proj_queue, args, infer, cell_type='reduce')
                model.project_edge(selected_nid_reduce, eids_reduce, cell_type='reduce')

                model.printing(logging)

        ## tune
        for step, (input, target) in enumerate(train_queue):
            model.train()
            n = input.size(0)

            ## fetch data
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)

            ## train alpha
            optimizer.zero_grad(); architect.optimizer.zero_grad()
            architect.step(input, target, input_search, target_search,
                           return_logits=True)

            ## sdarts
            if perturb_alpha:
                # transform arch_parameters to prob (for perturbation)
                model.softmax_arch_parameters()
                optimizer.zero_grad(); architect.optimizer.zero_grad()
                perturb_alpha(model, input, target, epsilon_alpha)

            ## train weight
            optimizer.zero_grad(); architect.optimizer.zero_grad()
            logits, loss = model.step(input, target, args)

            ## sdarts
            if perturb_alpha:
                ## restore alpha to unperturbed arch_parameters
                model.restore_arch_parameters()

            ## logging
            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)
            if step % args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)

            if args.fast:
                break

        ## one epoch end
        model.printing(logging)

        train_acc, train_obj = infer(train_queue, model, log=False)
        logging.info('train_acc  %f', train_acc)
        logging.info('train_loss %f', train_obj)

        valid_acc, valid_obj = infer(valid_queue, model, log=False)
        logging.info('valid_acc  %f', valid_acc)
        logging.info('valid_loss %f', valid_obj)


    logging.info('projection finished')
    model.printing(logging)
    num_params = ig_utils.count_parameters_in_Compact(model)
    genotype = model.genotype()
    logging.info('param size = %f', num_params)
    logging.info('genotype = %s', genotype)

    return
Example #8
def pt_project(train_queue, valid_queue, model, architect, criterion,
               optimizer, epoch, args, infer, query):
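    """Projection phase for a NAS-Bench-201-style search space.

    Every args.proj_intv epochs, project() picks a remaining edge (randomly
    when args.edge_decision == 'random') and scores each candidate operation
    by masking it out of the projected weights and evaluating on valid_queue
    (args.proj_crit selects loss or accuracy as the criterion); the supernet
    is fine-tuned between projections, and the final genotype is queried
    against NAS-Bench-201 unless args.fast is set.
    """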
    def project(model, args):
        ## macros
        num_edge, num_op = model.num_edge, model.num_op

        ## select an edge
        remain_eids = torch.nonzero(model.candidate_flags).cpu().numpy().T[0]
        if args.edge_decision == "random":
            selected_eid = np.random.choice(remain_eids, size=1)[0]

        ## select the best operation
        if args.proj_crit == 'loss':
            crit_idx = 1
            compare = lambda x, y: x > y
        if args.proj_crit == 'acc':
            crit_idx = 0
            compare = lambda x, y: x < y

        best_opid = 0
        crit_extrema = None
        for opid in range(num_op):
            ## projection
            weights = model.get_projected_weights()
            proj_mask = torch.ones_like(weights[selected_eid])
            proj_mask[opid] = 0
            weights[selected_eid] = weights[selected_eid] * proj_mask

            ## proj evaluation
            valid_stats = infer(valid_queue,
                                model,
                                criterion,
                                log=False,
                                eval=False,
                                weights=weights)
            crit = valid_stats[crit_idx]

            if crit_extrema is None or compare(crit, crit_extrema):
                crit_extrema = crit
                best_opid = opid
            logging.info('valid_acc %f', valid_stats[0])
            logging.info('valid_loss %f', valid_stats[1])

        logging.info('best opid %d', best_opid)
        return selected_eid, best_opid

    ## query
    if not args.fast:
        api = API('../data/NAS-Bench-201-v1_0-e61699.pth')

    model.train()
    model.printing(logging)

    train_acc, train_obj = infer(train_queue, model, criterion, log=False)
    logging.info('train_acc  %f', train_acc)
    logging.info('train_loss %f', train_obj)

    valid_acc, valid_obj = infer(valid_queue, model, criterion, log=False)
    logging.info('valid_acc  %f', valid_acc)
    logging.info('valid_loss %f', valid_obj)

    objs = ig_utils.AvgrageMeter()
    top1 = ig_utils.AvgrageMeter()
    top5 = ig_utils.AvgrageMeter()

    num_edges = model.arch_parameters()[0].shape[0]
    proj_intv = args.proj_intv
    tune_epochs = proj_intv * (num_edges - 1)
    model.reset_optimizer(args.learning_rate / 10, args.momentum,
                          args.weight_decay)

    for epoch in range(tune_epochs):
        logging.info('epoch %d', epoch)

        if epoch % proj_intv == 0 or epoch == tune_epochs - 1:
            logging.info('project')
            selected_eid, best_opid = project(model, args)
            model.project_op(selected_eid, best_opid)
            model.printing(logging)

        for step, (input, target) in enumerate(train_queue):
            model.train()
            n = input.size(0)

            ## fetch data
            input = input.cuda()
            target = target.cuda(non_blocking=True)
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda(non_blocking=True)

            ## train alpha
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            shared = architect.step(input,
                                    target,
                                    input_search,
                                    target_search,
                                    return_logits=True)

            ## train weight
            optimizer.zero_grad()
            architect.optimizer.zero_grad()
            logits, loss = model.step(input, target, args, shared=shared)

            ## logging
            prec1, prec5 = ig_utils.accuracy(logits, target, topk=(1, 5))
            objs.update(loss.data, n)
            top1.update(prec1.data, n)
            top5.update(prec5.data, n)

            if step % args.report_freq == 0:
                logging.info('train %03d %e %f %f', step, objs.avg, top1.avg,
                             top5.avg)

            if args.fast:
                break

        ## one epoch end
        model.printing(logging)

        train_acc, train_obj = infer(train_queue, model, criterion, log=False)
        logging.info('train_acc  %f', train_acc)
        logging.info('train_loss %f', train_obj)

        valid_acc, valid_obj = infer(valid_queue, model, criterion, log=False)
        logging.info('valid_acc  %f', valid_acc)
        logging.info('valid_loss %f', valid_obj)

    # nasbench201
    if not args.fast:
        query(api, model.genotype(), logging)

    return
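
All of the examples above run inside a module that imports torch, torch.nn as nn, numpy as np, logging, os, and copy.deepcopy, reads an args namespace (global in most examples, passed in for the projection routines), and relies on two small helpers from ig_utils. A minimal sketch of what those helpers likely look like, based on how they are used here and on the standard DARTS utilities (an assumption; the actual implementations live in the repo's ig_utils module):

class AvgrageMeter:
    """Running average of a scalar (used above for loss and accuracy)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0.0
        self.sum = 0.0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def accuracy(output, target, topk=(1,)):
    """Top-k precision (in percent) of logits output against target."""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res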