Code Example #1
def valid_func(xloader, network, criterion):
    data_time, batch_time = AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.eval()
    end = time.time()
    with torch.no_grad():
        for step, (arch_inputs, arch_targets) in enumerate(xloader):
            arch_targets = arch_targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - end)
            # prediction
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            # record
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                     arch_targets.data,
                                                     topk=(1, 5))
            arch_losses.update(arch_loss.item(), arch_inputs.size(0))
            arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
            arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return arch_losses.avg, arch_top1.avg, arch_top5.avg
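All of the examples on this page rely on an AverageMeter helper that the listing itself never defines. The class below is a minimal sketch of the conventional implementation (the value/sum/count/average pattern popularized by the PyTorch ImageNet example); the exact class in each project may differ in small ways, so read it as an assumption rather than any project's own code.

class AverageMeter:
    """Track the latest value and the running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0    # most recent value passed to update()
        self.sum = 0.0    # weighted sum of all values
        self.count = 0    # total weight, e.g. number of samples seen
        self.avg = 0.0    # running average, sum / count

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count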
Code Example #2
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == 'train': optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == 'train':
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
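Almost every example also calls an obtain_accuracy helper to turn logits into top-k precision percentages, and it is likewise absent from the listing. The sketch below assumes the standard top-k accuracy routine from the PyTorch ImageNet example with the signature obtain_accuracy(output, target, topk=(1,)); it matches how the snippets above use the result (tensors on which .item() is called), but the projects' own implementations may differ.

import torch


def obtain_accuracy(output, target, topk=(1,)):
    """Return the top-k accuracies (as percentages) for a batch of logits."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # indices of the maxk highest-scoring classes, shape (maxk, batch_size)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res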
Code Example #3
def test_contrastive(args, model, nearest_proto_model, device,
                     test_loader_creator_l, logger):
    model.eval()

    acc = AverageMeter()
    tasks_acc = [
        AverageMeter() for i in range(len(test_loader_creator_l.data_loaders))
    ]

    test_loaders_l = test_loader_creator_l.data_loaders

    with torch.no_grad():
        for task_idx, test_loader_l in enumerate(test_loaders_l):

            for batch_idx, (data, _, target) in enumerate(test_loader_l):
                data, target = data.to(device), target.to(device)
                cur_feats, _ = model(data)
                output = nearest_proto_model.predict(cur_feats)
                it_acc = (output == target).sum().item() / data.shape[0]
                acc.update(it_acc, data.size(0))
                tasks_acc[task_idx].update(it_acc, data.size(0))

    if args.acc_per_task:
        tasks_acc_str = 'Test Acc per task: '
        for i, task_acc in enumerate(tasks_acc):
            tasks_acc_str += 'Task{:2d} Acc: {acc.avg:.3f}'.format(
                (i + 1), acc=task_acc) + '\t'
        logger.info(tasks_acc_str)
    logger.info('Test Acc: {acc.avg:.3f}'.format(acc=acc))
Code Example #4
def train_or_test_epoch(self, xloader, model, loss_fn, metric_fn, is_train,
                        optimizer=None):
    if is_train:
        model.train()
    else:
        model.eval()
    score_meter, loss_meter = AverageMeter(), AverageMeter()
    for ibatch, (feats, labels) in enumerate(xloader):
        feats = feats.to(self.device, non_blocking=True)
        labels = labels.to(self.device, non_blocking=True)
        # forward the network
        preds = model(feats)
        loss = loss_fn(preds, labels)
        # record the loss and metric without tracking gradients
        with torch.no_grad():
            score = metric_fn(preds, labels)
            loss_meter.update(loss.item(), feats.size(0))
            score_meter.update(score.item(), feats.size(0))
        # optimize the network
        if is_train and optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
            optimizer.step()
    return loss_meter.avg, score_meter.avg
Code Example #5
File: test.py Project: salarim/continual-learning
def test(args, model, device, test_loader_creator, logger):
    model.eval()

    criterion = torch.nn.CrossEntropyLoss().to(device)

    with torch.no_grad():
        losses = AverageMeter()
        acc = AverageMeter()

        for test_loader in test_loader_creator.data_loaders:

            for data, target in test_loader:

                data, target = data.to(device), target.to(device)
                _, output = model(data)

                loss = criterion(output, target)

                output = output.float()
                loss = loss.float()

                it_acc = accuracy(output.data, target)[0]
                losses.update(loss.item(), data.size(0))
                acc.update(it_acc.item(), data.size(0))

    logger.info('Test set: Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Acc {acc.avg:.3f}'.format(loss=losses, acc=acc))
Code Example #6
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2: latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
Code Example #7
def train_shared_cnn(
    xloader,
    shared_cnn,
    controller,
    criterion,
    scheduler,
    optimizer,
    epoch_str,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        time.time(),
    )

    shared_cnn.train()
    controller.eval()

    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        with torch.no_grad():
            _, _, sampled_arch = controller()

        optimizer.zero_grad()
        shared_cnn.module.update_arch(sampled_arch)
        _, logits = shared_cnn(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
        optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1s.update(prec1.item(), inputs.size(0))
        top5s.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*Train-Shared-CNN* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=losses, top1=top1s, top5=top5s)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return losses.avg, top1s.avg, top5s.avg
Code Example #8
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )

    network.eval()
    network.apply(change_key("search_mode", "search"))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)

            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = ("**VALID** " + time_string() +
                        " [{:}][{:03d}/{:03d}]".format(extra_info, i,
                                                       len(xloader)))
                Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                    batch_time=batch_time, data_time=data_time)
                Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                    loss=losses, top1=top1, top5=top5)
                Istr = "Size={:}".format(list(inputs.size()))
                logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}"
        .format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        ))

    return losses.avg, top1.avg, top5.avg
Code Example #9
File: ENAS.py Project: city292/NAS-Projects
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler,
                     optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()

    shared_cnn.train()
    controller.eval()
    ne = 10

    for ni in range(ne):
        with torch.no_grad():
            _, _, sampled_arch = controller()
        shared_cnn.module.update_arch(sampled_arch)
        print(sampled_arch)
        # arch_str = op_list2str(sampled_arch)
        for step, (inputs, targets) in enumerate(xloader):
            # print(step,inputs,targets)
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()

            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 2))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            # if step + 1 == len(xloader):
        Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:03d}/{:03d}]'.format(
            ni, ne)
        Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
            batch_time=batch_time, data_time=data_time)
        Wstr = '[Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
            loss=losses, top1=top1s, top5=top5s)
        logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)
        # remember the averages for this architecture before resetting the meters
        last_loss, last_top1, last_top5 = losses.avg, top1s.avg, top5s.avg
        losses.reset()
        top1s.reset()
        top5s.reset()

    return last_loss, last_top1, last_top5
Code Example #10
    def search(self):
        self.eva_time = AverageMeter()
        init_start = time.time()
        self.init_random()
        self.logger.log('Initial_takes: %.2f' % (time.time() - init_start))

        epoch_start_time = time.time()
        epoch_time_meter = AverageMeter()
        bests_per_epoch = list()
        perform_trace = list()
        for i in range(self.max_epochs):
            self.performances = torch.Tensor(self.performances)
            top_k = torch.argsort(self.performances,
                                  descending=True)[:self.parent_num]

            if self.best_perf is None or self.performances[
                    top_k[0]] > self.best_perf:
                self.best_cand = self.candidates[top_k[0]]
                self.best_perf = self.performances[top_k[0]]
            bests_per_epoch.append(self.best_cand)
            perform_trace.append(self.performances)

            self.parents = []
            for idx in top_k:
                self.parents.append(self.candidates[idx])
            self.candidates, self.performances = list(), list()
            self.eva_time = AverageMeter()
            self.get_mutation(self.population_num // 2)
            self.get_crossover()

            self.logger.log(
                '*SEARCH* ' + time_string() +
                '||| Epoch: %2d finished, %3d models have been tested, best performance is %.2f'
                % (i, len(self.perform_dict.keys()), self.best_perf))
            self.logger.log(' - Best Cand: ' + str(self.best_cand))
            this_epoch_time = time.time() - epoch_start_time
            epoch_time_meter.update(this_epoch_time)
            epoch_start_time = time.time()
            self.logger.log('Time for Epoch %d : %.2fs' % (i, this_epoch_time))
            self.logger.log(' -- Evaluated %d models, with %.2f s in average' %
                            (self.eva_time.count, self.eva_time.avg))

        self.logger.log(
            '--------\nSearching Finished. Best Arch Found with Acc %.2f' %
            (self.best_perf))
        self.logger.log(str(self.best_cand))
        #torch.save(self.best_cand, self.save_dir+'/best_arch.pth')
        #torch.save(self.perform_dict, self.save_dir+'/perform_dict.pth')
        return bests_per_epoch, self.perform_dict, perform_trace
Code Example #11
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_inputs = base_inputs.cuda(non_blocking=True)
    arch_inputs = arch_inputs.cuda(non_blocking=True)
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
    
    # Update the weights
    network.zero_grad()
    _, logits, _ = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(),  base_inputs.size(0))
    base_top1.update  (base_prec1.item(), base_inputs.size(0))
    base_top5.update  (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.zero_grad()
    _, logits, log_probs = network(arch_inputs)
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    if algo == 'tunas':
      with torch.no_grad():
        RL_BASELINE_EMA.update(arch_prec1.item())
        rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
      rl_log_prob = sum(log_probs)
      arch_loss = - rl_advantage * rl_log_prob
    elif algo == 'tas' or algo == 'fbv2':
      arch_loss = criterion(logits, arch_targets)
    else:
      raise ValueError('invalid algorithm name: {:}'.format(algo))
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
    arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Code Example #12
    def test_archi_acc(self, arch):
        if self.train_loader is not None:
            self.model.apply(ResetRunningStats)

            self.model.train()
            for step, (data, target) in enumerate(self.train_loader):
                # print('train step: {} total: {}'.format(step,max_train_iters))
                # data, target = train_dataprovider.next()
                # print('get data',data.shape)
                #data = data.cuda()
                output = self.model.forward(data, arch)  #_with_architect
                del data, target, output

        base_top1, base_top5 = AverageMeter(), AverageMeter()
        self.model.eval()

        one_batch = None
        for step, (data, target) in enumerate(self.val_loader):
            # print('test step: {} total: {}'.format(step,max_test_iters))
            if one_batch is None:
                one_batch = data
            batchsize = data.shape[0]
            # print('get data',data.shape)
            target = target.cuda(non_blocking=True)
            #data, target = data.to(device), target.to(device)

            _, logits = self.model.forward(data, arch)  #_with_architect

            prec1, prec5 = obtain_accuracy(logits.data,
                                           target.data,
                                           topk=(1, 5))
            base_top1.update(prec1.item(), batchsize)
            base_top5.update(prec5.item(), batchsize)

            del data, target, logits, prec1, prec5

        if self.lambda_t > 0.0:
            start_time = time.time()
            len_batch = min(len(one_batch), 50)
            for i in range(len_batch):
                _, _ = self.model.forward(one_batch[i:i + 1, :, :, :], arch)
            end_time = time.time()
            time_per = (end_time - start_time) / len_batch
        else:
            time_per = 0.0

        #print('top1: {:.2f} top5: {:.2f}'.format(base_top1.avg * 100, base_top5.avg * 100))
        return base_top1.avg, base_top5.avg, time_per
Code Example #13
File: functions.py Project: vishruthb/darts-scale
def evaluate_for_seed(arch_config, config, arch, train_loader, valid_loaders, seed, logger):

  prepare_seed(seed) # random seed
  net = get_cell_based_tiny_net(dict2config({'name': 'infer.tiny',
                                             'C': arch_config['channel'], 'N': arch_config['num_cells'],
                                             'genotype': arch, 'num_classes': config.class_num}
                                            , None)
                                 )
  #net = TinyNetwork(arch_config['channel'], arch_config['num_cells'], arch, config.class_num)
  flop, param  = get_model_infos(net, config.xshape)
  logger.log('Network : {:}'.format(net.get_message()), False)
  logger.log('{:} Seed-------------------------- {:} --------------------------'.format(time_string(), seed))
  logger.log('FLOP = {:} MB, Param = {:} MB'.format(flop, param))
  # train and valid
  optimizer, scheduler, criterion = get_optim_scheduler(net.parameters(), config)
  network, criterion = torch.nn.DataParallel(net).cuda(), criterion.cuda()
  # start training
  start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
  train_losses, train_acc1es, train_acc5es, valid_losses, valid_acc1es, valid_acc5es = {}, {}, {}, {}, {}, {}
  train_times , valid_times = {}, {}
  for epoch in range(total_epoch):
    scheduler.update(epoch, 0.0)

    train_loss, train_acc1, train_acc5, train_tm = procedure(train_loader, network, criterion, scheduler, optimizer, 'train')
    train_losses[epoch] = train_loss
    train_acc1es[epoch] = train_acc1 
    train_acc5es[epoch] = train_acc5
    train_times [epoch] = train_tm
    with torch.no_grad():
      for key, xloader in valid_loaders.items():
        valid_loss, valid_acc1, valid_acc5, valid_tm = procedure(xloader, network, criterion, None, None, 'valid')
        valid_losses['{:}@{:}'.format(key,epoch)] = valid_loss
        valid_acc1es['{:}@{:}'.format(key,epoch)] = valid_acc1 
        valid_acc5es['{:}@{:}'.format(key,epoch)] = valid_acc5
        valid_times ['{:}@{:}'.format(key,epoch)] = valid_tm

    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (total_epoch-epoch-1), True) )
    logger.log('{:} {:} epoch={:03d}/{:03d} :: Train [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%] Valid [loss={:.5f}, acc@1={:.2f}%, acc@5={:.2f}%]'.format(time_string(), need_time, epoch, total_epoch, train_loss, train_acc1, train_acc5, valid_loss, valid_acc1, valid_acc5))
  info_seed = {'flop' : flop,
               'param': param,
               'channel'     : arch_config['channel'],
               'num_cells'   : arch_config['num_cells'],
               'config'      : config._asdict(),
               'total_epoch' : total_epoch ,
               'train_losses': train_losses,
               'train_acc1es': train_acc1es,
               'train_acc5es': train_acc5es,
               'train_times' : train_times,
               'valid_losses': valid_losses,
               'valid_acc1es': valid_acc1es,
               'valid_acc5es': valid_acc5es,
               'valid_times' : valid_times,
               'net_state_dict': net.state_dict(),
               'net_string'  : '{:}'.format(net),
               'finish-train': True
              }
  return info_seed
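Several of the longer loops above, Code Example #13 included, format their progress logs with time_string() and convert_secs2time(...), which are also not part of this listing. The helpers below are a plausible sketch written under the assumption that convert_secs2time(secs, True) returns a compact clock-style string; the real utilities in these projects may format their output differently.

import time


def time_string():
    """Current local time, used as a prefix for log messages."""
    return time.strftime('[%Y-%m-%d %H:%M:%S]', time.localtime(time.time()))


def convert_secs2time(epoch_time, return_str=False):
    """Split a duration in seconds into hours, minutes, and seconds."""
    need_hour = int(epoch_time // 3600)
    need_mins = int((epoch_time % 3600) // 60)
    need_secs = int(epoch_time % 60)
    if return_str:
        return '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
    return need_hour, need_mins, need_secs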
Code Example #14
File: linear_pt.py Project: Mirofil/SOTL_NAS
def train_bptt(num_epochs: int, model, dset_train, batch_size: int, T: int,
               w_checkpoint_freq: int, grad_clip: float, w_lr: float,
               logging_freq: int, sotl_order: int, hvp: str):
    model.train()
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size * T,
                                               shuffle=True)

    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        true_batch_index = 0
        for batch_idx, batch in enumerate(train_loader):
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size)

            weight_buffer = WeightBuffer(T=T,
                                         checkpoint_freq=w_checkpoint_freq)
            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys)):
                weight_buffer.add(model, intra_batch_idx)

                y_pred = model(x)
                loss = criterion(y_pred, y)
                epoch_loss.update(loss.item())

                grads = torch.autograd.grad(loss,
                                            model.weight_params(),
                                            retain_graph=True,
                                            allow_unused=True,
                                            create_graph=True)

                w_optimizer.zero_grad()

                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

                w_optimizer.step()
                true_batch_index += 1
                if true_batch_index % logging_freq == 0:
                    print("Epoch: {}, Batch: {}, Loss: {}".format(
                        epoch, true_batch_index, epoch_loss.avg))
                    wandb.log({"Train loss": epoch_loss.avg})

            total_arch_gradient = sotl_gradient(model,
                                                criterion,
                                                xs,
                                                ys,
                                                weight_buffer,
                                                w_lr=w_lr,
                                                hvp=hvp,
                                                order=sotl_order)

            a_optimizer.zero_grad()

            for g, w in zip(total_arch_gradient, model.arch_params()):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
            a_optimizer.step()
Code Example #15
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
                print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        network.module.random_genotype(True)
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Code Example #16
File: SETN.py Project: vishruthb/darts-scale
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
    
    # update the weights
    sampled_arch = network.module.dync_genotype(True)
    network.module.set_cal_mode('dynamic', sampled_arch)
    #network.module.set_cal_mode( 'urs' )
    network.zero_grad()
    _, logits = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(),  base_inputs.size(0))
    base_top1.update  (base_prec1.item(), base_inputs.size(0))
    base_top5.update  (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.module.set_cal_mode( 'joint' )
    network.zero_grad()
    _, logits = network(arch_inputs)
    arch_loss = criterion(logits, arch_targets)
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
    arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
      #print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
      #print (network.module.arch_parameters)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Code Example #17
File: xshapes.py Project: vishruthb/darts-scale
def main(save_dir: Path, workers: int, datasets: List[Text], xpaths: List[Text],
         splits: List[int], seeds: List[int], nets: List[str], opt_config: Dict[Text, Any],
         to_evaluate_indexes: tuple, cover_mode: bool):

  log_dir = save_dir / 'logs'
  log_dir.mkdir(parents=True, exist_ok=True)
  logger = Logger(str(log_dir), os.getpid(), False)

  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-' * 100)

  logger.log(
    'Start evaluating range =: {:06d} - {:06d}'.format(min(to_evaluate_indexes), max(to_evaluate_indexes))
   +'({:} in total) / {:06d} with cover-mode={:}'.format(len(to_evaluate_indexes), len(nets), cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log(
      '--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> optimization config : {:}'.format(opt_config))
  #to_evaluate_indexes = list(range(srange[0], srange[1] + 1))

  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    channelstr = nets[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}] {:}'.format(time_string(), i,
                       len(to_evaluate_indexes), index, len(nets), seeds, '-' * 15))
    logger.log('{:} {:} {:}'.format('-' * 15, channelstr, '-' * 15))

    # test this arch on different datasets with different seeds
    has_continue = False
    for seed in seeds:
      to_save_name = save_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Find existing file : {:}, remove it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Find existing file : {:}, skip this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(channelstr, datasets, xpaths, splits, opt_config, seed, workers, logger)
      torch.save(results, to_save_name)
      logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th arch [seeds={:}]  ===>>> {:}'.format(time_string(), i,
                    len(to_evaluate_indexes), index, len(nets), seeds, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format(convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes) - i - 1), True))
    logger.log('This arch costs : {:}'.format(convert_secs2time(epoch_time.val, True)))
    logger.log('{:}'.format('*' * 100))
    logger.log('{:}   {:74s}   {:}'.format('*' * 10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(
      to_evaluate_indexes), index, len(nets), need_time), '*' * 10))
    logger.log('{:}'.format('*' * 100))

  logger.close()
Code Example #18
def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Callable = None,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    if mode.lower() == "train":
        network.train()
    elif mode.lower() == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss

        if mode == "train":
            optimizer.zero_grad()

        outputs = network(inputs)
        targets = targets.to(get_device(outputs))

        if mode == "train":
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # record
        with torch.no_grad():
            results = metric(outputs, targets)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
Code Example #19
File: linear_pt.py Project: Mirofil/SOTL_NAS
def valid_func(model, val_loader, criterion):
    model.eval()
    val_meter = AverageMeter()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            y_pred = model(x)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    print("Val loss: {}".format(val_meter.avg))
    return val_meter
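For readers who want to exercise the simple valid_func in Code Example #19 on its own, the driver below is purely illustrative: the toy linear model, the random dataset, and all sizes are hypothetical and are not taken from the original project.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-ins, only to show the call pattern of valid_func
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
dset_val = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
val_loader = DataLoader(dset_val, batch_size=16)

val_meter = valid_func(model, val_loader, criterion)  # prints "Val loss: ..."
print('Average validation loss over the toy set:', val_meter.avg)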
Code Example #20
def train_shared_cnn(xloader, shared_cnn, criterion, scheduler, optimizer,
                     print_freq, logger, config, start_epoch):
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Training the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(scheduler.get_lr())))

        data_time, batch_time = AverageMeter(), AverageMeter()
        losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(), time.time()

        shared_cnn.train()

        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            if step % print_freq == 0 or step + 1 == len(xloader):
                Sstr = ('*Train-Shared-CNN* ' + time_string() +
                        ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader)))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                    loss=losses, top1=top1s, top5=top5s)
                logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

        cnn_loss, cnn_top1, cnn_top5 = losses.avg, top1s.avg, top5s.avg
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    return
Code Example #21
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        # print(111111111111111111111)
        # print(arch_inputs.size())
        # print(arch_targets.size())
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the architecture-weight
        a_optimizer.zero_grad()
        arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 2))
        arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
        arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

        # update the weights
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 2))
        base_losses.update(base_loss.item(),  base_inputs.size(0))
        base_top1.update  (base_prec1.item(), base_inputs.size(0))
        base_top5.update  (base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            # Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Code Example #22
def train_single_model(save_dir, workers, datasets, xpaths, splits, use_less, seeds, model_str, arch_config):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.deterministic = True
  #torch.backends.cudnn.benchmark = True
  torch.set_num_threads( workers )
  
  save_dir = Path(save_dir) / 'specifics' / '{:}-{:}-{:}-{:}'.format('LESS' if use_less else 'FULL', model_str, arch_config['channel'], arch_config['num_cells'])
  logger   = Logger(str(save_dir), 0, False)
  if model_str in CellArchitectures:
    arch   = CellArchitectures[model_str]
    logger.log('The model string is found in pre-defined architecture dict : {:}'.format(model_str))
  else:
    try:
      arch = CellStructure.str2structure(model_str)
    except Exception:
      raise ValueError('Invalid model string : {:}. It cannot be found or parsed.'.format(model_str))
  assert arch.check_valid_op(get_search_spaces('cell', 'full')), '{:} has the invalid op.'.format(arch)
  logger.log('Start train-evaluate {:}'.format(arch.tostr()))
  logger.log('arch_config : {:}'.format(arch_config))

  start_time, seed_time = time.time(), AverageMeter()
  for _is, seed in enumerate(seeds):
    logger.log('\nThe {:02d}/{:02d}-th seed is {:} ----------------------<.>----------------------'.format(_is, len(seeds), seed))
    to_save_name = save_dir / 'seed-{:04d}.pth'.format(seed)
    if to_save_name.exists():
      logger.log('Find the existing file {:}, directly load!'.format(to_save_name))
      checkpoint = torch.load(to_save_name)
    else:
      logger.log('Does not find the existing file {:}, train and evaluate!'.format(to_save_name))
      checkpoint = evaluate_all_datasets(arch, datasets, xpaths, splits, use_less, seed, arch_config, workers, logger)
      torch.save(checkpoint, to_save_name)
    # log information
    logger.log('{:}'.format(checkpoint['info']))
    all_dataset_keys = checkpoint['all_dataset_keys']
    for dataset_key in all_dataset_keys:
      logger.log('\n{:} dataset : {:} {:}'.format('-'*15, dataset_key, '-'*15))
      dataset_info = checkpoint[dataset_key]
      #logger.log('Network ==>\n{:}'.format( dataset_info['net_string'] ))
      logger.log('Flops = {:} MB, Params = {:} MB'.format(dataset_info['flop'], dataset_info['param']))
      logger.log('config : {:}'.format(dataset_info['config']))
      logger.log('Training State (finish) = {:}'.format(dataset_info['finish-train']))
      last_epoch = dataset_info['total_epoch'] - 1
      train_acc1es, train_acc5es = dataset_info['train_acc1es'], dataset_info['train_acc5es']
      valid_acc1es, valid_acc5es = dataset_info['valid_acc1es'], dataset_info['valid_acc5es']
      logger.log('Last Info : Train = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%, Test = Acc@1 {:.2f}% Acc@5 {:.2f}% Error@1 {:.2f}%'.format(train_acc1es[last_epoch], train_acc5es[last_epoch], 100-train_acc1es[last_epoch], valid_acc1es[last_epoch], valid_acc5es[last_epoch], 100-valid_acc1es[last_epoch]))
    # measure elapsed time
    seed_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(seed_time.avg * (len(seeds)-_is-1), True) )
    logger.log('\n<<<***>>> The {:02d}/{:02d}-th seed is {:} <finish> other procedures need {:}'.format(_is, len(seeds), seed, need_time))
  logger.close()
Code Example #23
def eval_robust_heatmap(detector, xloader, print_freq, logger):
    batch_time, NUM_PTS = AverageMeter(), xloader.dataset.NUM_PTS
    Preds, GT_locs, Distances = [], [], []
    eval_meta, end = Eval_Meta(), time.time()

    with torch.no_grad():
        detector.eval()
        for i, (inputs, heatmaps, masks, norm_points, thetas, data_index,
                nopoints, xshapes) in enumerate(xloader):
            data_index = data_index.squeeze(1).tolist()
            batch_size, iters, C, H, W = inputs.size()
            for ibatch in range(batch_size):
                xinputs, xpoints, xthetas = inputs[ibatch], norm_points[
                    ibatch].permute(0, 2, 1).contiguous(), thetas[ibatch]
                batch_features, batch_heatmaps, batch_locs, batch_scos = detector(
                    xinputs.cuda(non_blocking=True))
                batch_locs = batch_locs.cpu()[:, :-1]
                all_locs = []
                for _iter in range(iters):
                    _locs = normalize_points((H, W),
                                             batch_locs[_iter].permute(1, 0))
                    xlocs = torch.cat((_locs, torch.ones(1, NUM_PTS)), dim=0)
                    nlocs = torch.mm(xthetas[_iter, :2], xlocs)
                    rlocs = denormalize_points(xshapes[ibatch].tolist(), nlocs)
                    rlocs = torch.cat(
                        (rlocs.permute(1, 0), xpoints[_iter, :, 2:]), dim=1)
                    all_locs.append(rlocs.clone())
                GT_loc = xloader.dataset.labels[
                    data_index[ibatch]].get_points()
                norm_distance = xloader.dataset.get_normalization_distance(
                    data_index[ibatch])
                # save the results
                eval_meta.append((sum(all_locs) / len(all_locs)).numpy().T,
                                 GT_loc.numpy(),
                                 xloader.dataset.datas[data_index[ibatch]],
                                 norm_distance)
                Distances.append(norm_distance)
                Preds.append(all_locs)
                GT_locs.append(GT_loc.permute(1, 0))
            # compute time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or i + 1 == len(xloader):
                last_time = convert_secs2time(
                    batch_time.avg * (len(xloader) - i - 1), True)
                logger.log(
                    ' -->>[Robust HEATMAP-based Evaluation] [{:03d}/{:03d}] Time : {:}'
                    .format(i, len(xloader), last_time))
    # evaluate the results
    errors, valids = calculate_robust(Preds, GT_locs, Distances, NUM_PTS)
    return errors, valids, eval_meta
Code Example #24
def train_func(xargs, search_loader, valid_loader, network, operations,
               criterion, w_scheduler, w_optimizer, logger, drop_iter,
               total_epoch):
    logger.log('|=> Train, drop_iter={}, epochs={}'.format(
        drop_iter, total_epoch))
    # start training
    start_time, search_time, epoch_time, start_epoch = time.time(), AverageMeter(), AverageMeter(), 0
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Search the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, search_a_loss, search_a_top1, search_a_top5 \
                    = search_func(search_loader, network, operations, criterion, w_scheduler, w_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] search [base] : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
Code Example #25
File: luketina.py Project: Mirofil/SOTL_NAS
def valid_func(model, dset_val, criterion, print_results=True):
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)

    val_meter = AverageMeter()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            y_pred = model(x)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    if print_results:
        print("Val loss: {}".format(val_meter.avg))
    return val_meter
Code Example #26
File: train.py Project: Mirofil/SOTL_NAS
def valid_func(model, dset_val, criterion, device = 'cuda' if torch.cuda.is_available() else 'cpu', print_results=True):
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    val_acc_meter = AverageMeter()

    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)

            if isinstance(criterion, torch.nn.CrossEntropyLoss):
                predicted = torch.argmax(y_pred, dim=1)
                correct = torch.sum((predicted == y)).item()
                total = predicted.size()[0]
                val_acc_meter.update(correct/total)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    if print_results:
        print("Val loss: {}, Val acc: {}".format(val_meter.avg, val_acc_meter.avg if val_acc_meter.avg > 0 else "Not applicable"))
    return val_meter
Code Example #27
File: luketina.py Project: Mirofil/SOTL_NAS
def train_normal(num_epochs,
                 model,
                 dset_train,
                 batch_size,
                 grad_clip,
                 logging_freq,
                 optim="sgd",
                 **kwargs):
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size,
                                               shuffle=True)

    model.train()
    for epoch in range(num_epochs):

        epoch_loss = AverageMeter()
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            w_optimizer.zero_grad()

            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward(retain_graph=True)

            epoch_loss.update(loss.item())
            if optim == "newton":
                linear_weight = list(model.weight_params())[0]
                hessian_newton = torch.inverse(
                    hessian(loss * 1, linear_weight,
                            linear_weight).reshape(linear_weight.size()[1],
                                                   linear_weight.size()[1]))
                with torch.no_grad():
                    for w in model.weight_params():
                        w = w.subtract_(torch.matmul(w.grad, hessian_newton))
            elif optim == "sgd":
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)
                w_optimizer.step()
            else:
                raise NotImplementedError

            wandb.log({
                "Train loss": epoch_loss.avg,
                "Epoch": epoch,
                "Batch": batch_idx
            })

            if batch_idx % logging_freq == 0:
                print("Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                    epoch, batch_idx, epoch_loss.avg, model.fc1.alphas.data))
Code Example #28
def check_files(save_dir, meta_file, basestr):
  meta_infos     = torch.load(meta_file, map_location='cpu')
  meta_archs     = meta_infos['archs']
  meta_num_archs = meta_infos['total']
  meta_max_node  = meta_infos['max_node']
  assert meta_num_archs == len(meta_archs), 'invalid number of archs : {:} vs {:}'.format(meta_num_archs, len(meta_archs))

  sub_model_dirs = sorted(list(save_dir.glob('*-*-{:}'.format(basestr))))
  print ('{:} find {:} directories used to save checkpoints'.format(time_string(), len(sub_model_dirs)))
  
  subdir2archs, num_evaluated_arch = collections.OrderedDict(), 0
  num_seeds = defaultdict(lambda: 0)
  for index, sub_dir in enumerate(sub_model_dirs):
    xcheckpoints = list(sub_dir.glob('arch-*-seed-*.pth'))
    #xcheckpoints = list(sub_dir.glob('arch-*-seed-0777.pth')) + list(sub_dir.glob('arch-*-seed-0888.pth')) + list(sub_dir.glob('arch-*-seed-0999.pth'))
    arch_indexes = set()
    for checkpoint in xcheckpoints:
      temp_names = checkpoint.name.split('-')
      assert len(temp_names) == 4 and temp_names[0] == 'arch' and temp_names[2] == 'seed', 'invalid checkpoint name : {:}'.format(checkpoint.name)
      arch_indexes.add( temp_names[1] )
    subdir2archs[sub_dir] = sorted(list(arch_indexes))
    num_evaluated_arch   += len(arch_indexes)
    # count number of seeds for each architecture
    for arch_index in arch_indexes:
      num_seeds[ len(list(sub_dir.glob('arch-{:}-seed-*.pth'.format(arch_index)))) ] += 1
  print('There are {:5d} architectures that have been evaluated ({:} in total, {:} ckps in total).'.format(num_evaluated_arch, meta_num_archs, sum(k*v for k, v in num_seeds.items())))
  for key in sorted( list( num_seeds.keys() ) ): print ('There are {:5d} architectures that are evaluated {:} times.'.format(num_seeds[key], key))

  dir2ckps, dir2ckp_exists = dict(), dict()
  start_time, epoch_time = time.time(), AverageMeter()
  for IDX, (sub_dir, arch_indexes) in enumerate(subdir2archs.items()):
    seeds = [777, 888, 999]
    numrs = defaultdict(lambda: 0)
    all_checkpoints, all_ckp_exists = [], []
    for arch_index in arch_indexes:
      checkpoints = ['arch-{:}-seed-{:04d}.pth'.format(arch_index, seed) for seed in seeds]
      ckp_exists  = [(sub_dir/x).exists() for x in checkpoints]
      arch_index  = int(arch_index)
      assert 0 <= arch_index < len(meta_archs), 'invalid arch-index {:} (not found in meta_archs)'.format(arch_index)
      all_checkpoints += checkpoints
      all_ckp_exists  += ckp_exists
      numrs[sum(ckp_exists)] += 1
    dir2ckps[ str(sub_dir) ]       = all_checkpoints
    dir2ckp_exists[ str(sub_dir) ] = all_ckp_exists
    # measure time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()
    numrstr = ', '.join( ['{:}: {:03d}'.format(x, numrs[x]) for x in sorted(numrs.keys())] )
    print('{:} load [{:2d}/{:2d}] [{:03d} archs] [{:04d}->{:04d} ckps] {:} done, need {:}. {:}'.format(time_string(), IDX+1, len(subdir2archs), len(arch_indexes), len(all_checkpoints), sum(all_ckp_exists), sub_dir, convert_secs2time(epoch_time.avg * (len(subdir2archs)-IDX-1), True), numrstr))
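
A hypothetical invocation of `check_files`; the directory layout, meta-file name, and `basestr` below are assumptions chosen to match the `C{:}-N{:}` naming pattern used above, not values taken from the snippet:

from pathlib import Path

save_dir  = Path('./output/NAS-BENCH-201-4')   # assumed checkpoint root
meta_file = save_dir / 'meta-node-4.pth'       # assumed meta-info file
check_files(save_dir, meta_file, 'C16-N5')     # matches sub-dirs like 000000-000389-C16-N5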
Code example #29
def main(save_dir, workers, datasets, xpaths, splits, use_less, srange, arch_index, seeds, cover_mode, meta_info, arch_config):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  #torch.backends.cudnn.benchmark = True
  torch.backends.cudnn.deterministic = True
  torch.set_num_threads( workers )

  assert len(srange) == 2 and 0 <= srange[0] <= srange[1], 'invalid srange : {:}'.format(srange)
  
  if use_less:
    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}-LESS'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
  else:
    sub_dir = Path(save_dir) / '{:06d}-{:06d}-C{:}-N{:}'.format(srange[0], srange[1], arch_config['channel'], arch_config['num_cells'])
  logger  = Logger(str(sub_dir), 0, False)

  all_archs = meta_info['archs']
  assert srange[1] < meta_info['total'], 'invalid range : {:}-{:} vs. {:}'.format(srange[0], srange[1], meta_info['total'])
  assert arch_index == -1 or srange[0] <= arch_index <= srange[1], 'invalid range : {:} vs. {:} vs. {:}'.format(srange[0], arch_index, srange[1])
  if arch_index == -1:
    to_evaluate_indexes = list(range(srange[0], srange[1]+1))
  else:
    to_evaluate_indexes = [arch_index]
  logger.log('xargs : seeds      = {:}'.format(seeds))
  logger.log('xargs : arch_index = {:}'.format(arch_index))
  logger.log('xargs : cover_mode = {:}'.format(cover_mode))
  logger.log('-'*100)

  logger.log('Start evaluating range : {:06d} vs. {:06d} vs. {:06d} / {:06d} with cover-mode={:}'.format(srange[0], arch_index, srange[1], meta_info['total'], cover_mode))
  for i, (dataset, xpath, split) in enumerate(zip(datasets, xpaths, splits)):
    logger.log('--->>> Evaluate {:}/{:} : dataset={:9s}, path={:}, split={:}'.format(i, len(datasets), dataset, xpath, split))
  logger.log('--->>> architecture config : {:}'.format(arch_config))
  

  start_time, epoch_time = time.time(), AverageMeter()
  for i, index in enumerate(to_evaluate_indexes):
    arch = all_archs[index]
    logger.log('\n{:} evaluate {:06d}/{:06d} ({:06d}/{:06d})-th architecture [seeds={:}] {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seeds, '-'*15))
    #logger.log('{:} {:} {:}'.format('-'*15, arch.tostr(), '-'*15))
    logger.log('{:} {:} {:}'.format('-'*15, arch, '-'*15))
  
    # test this arch on different datasets with different seeds
    has_continue = False
    for seed in seeds:
      to_save_name = sub_dir / 'arch-{:06d}-seed-{:04d}.pth'.format(index, seed)
      if to_save_name.exists():
        if cover_mode:
          logger.log('Found existing file : {:}, removing it before evaluation'.format(to_save_name))
          os.remove(str(to_save_name))
        else:
          logger.log('Found existing file : {:}, skipping this evaluation'.format(to_save_name))
          has_continue = True
          continue
      results = evaluate_all_datasets(CellStructure.str2structure(arch), \
                                        datasets, xpaths, splits, use_less, seed, \
                                        arch_config, workers, logger)
      torch.save(results, to_save_name)
      logger.log('{:} --evaluate-- {:06d}/{:06d} ({:06d}/{:06d})-th seed={:} done, save into {:}'.format('-'*15, i, len(to_evaluate_indexes), index, meta_info['total'], seed, to_save_name))
    # measure elapsed time
    if not has_continue: epoch_time.update(time.time() - start_time)
    start_time = time.time()
    need_time = 'Time Left: {:}'.format( convert_secs2time(epoch_time.avg * (len(to_evaluate_indexes)-i-1), True) )
    logger.log('This arch costs : {:}'.format( convert_secs2time(epoch_time.val, True) ))
    logger.log('{:}'.format('*'*100))
    logger.log('{:}   {:74s}   {:}'.format('*'*10, '{:06d}/{:06d} ({:06d}/{:06d})-th done, left {:}'.format(i, len(to_evaluate_indexes), index, meta_info['total'], need_time), '*'*10))
    logger.log('{:}'.format('*'*100))

  logger.close()
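
For orientation, a hypothetical call to this evaluation entry point; in the real script every value below comes from argparse, so the paths, range, and architecture config are illustrative assumptions only:

import torch

meta_info   = torch.load('./output/meta-node-4.pth', map_location='cpu')  # assumed meta file
arch_config = {'channel': 16, 'num_cells': 5}

main(save_dir='./output/NAS-BENCH-201-4', workers=8,
     datasets=['cifar10'], xpaths=['./data/cifar.python'], splits=[0],
     use_less=False, srange=(0, 99), arch_index=-1,
     seeds=(777, 888, 999), cover_mode=False,
     meta_info=meta_info, arch_config=arch_config)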
Code example #30
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
    if xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
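    # Note: the architecture parameters (alphas) get their own Adam optimizer below,
    # separate from the weight optimizer/scheduler created above.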
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = logger.path(
        'info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(
        search_model).cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loaded checkpoint from the last-info file '{:}', resuming at epoch {:}."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch, valid_accuracies, genotypes = 0, {
            'best': -1
        }, {
            -1: search_model.genotype()
        }

    # start training
    start_time, search_time, epoch_time, total_epoch = time.time(
    ), AverageMeter(), AverageMeter(), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
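        # Anneal the Gumbel-softmax temperature linearly from tau_max down to tau_min over training.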
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))

        search_w_loss, search_w_top1, search_w_top5, valid_a_loss , valid_a_top1 , valid_a_top5 \
                  = search_func(search_loader, network, criterion, w_scheduler, w_optimizer, a_optimizer, epoch_str, xargs.print_freq, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))
        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        genotypes[epoch] = search_model.genotype()
        logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
            epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : found the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch],
                                                      '200')))
        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : ran {:} epochs, cost {:.1f} s, the final genotype is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(
            api.query_by_arch(genotypes[total_epoch - 1], '200')))
    logger.close()
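
Finally, a minimal sketch of the argparse wiring that could feed `main(xargs)`; the option names mirror the attributes read above, but the defaults shown here are assumptions, and options consumed only inside helpers (for example by `prepare_logger`) are omitted:

import argparse

parser = argparse.ArgumentParser('GDAS search - illustrative wiring only')
parser.add_argument('--dataset',             default='cifar10')
parser.add_argument('--data_path',           default='./data/cifar.python')
parser.add_argument('--search_space_name',   default='nas-bench-201')
parser.add_argument('--config_path',         default='configs/nas-benchmark/algos/GDAS.config')
parser.add_argument('--model_config',        default=None)
parser.add_argument('--max_nodes',           type=int,   default=4)
parser.add_argument('--channel',             type=int,   default=16)
parser.add_argument('--num_cells',           type=int,   default=5)
parser.add_argument('--track_running_stats', type=int,   default=0)
parser.add_argument('--arch_learning_rate',  type=float, default=3e-4)
parser.add_argument('--arch_weight_decay',   type=float, default=1e-3)
parser.add_argument('--tau_max',             type=float, default=10.0)
parser.add_argument('--tau_min',             type=float, default=0.1)
parser.add_argument('--arch_nas_dataset',    default=None)
parser.add_argument('--print_freq',          type=int,   default=200)
parser.add_argument('--workers',             type=int,   default=2)
parser.add_argument('--rand_seed',           type=int,   default=-1)
args = parser.parse_args()
main(args)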