Example #1
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_inputs = base_inputs.cuda(non_blocking=True)
    arch_inputs = arch_inputs.cuda(non_blocking=True)
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
    
    # Update the weights
    network.zero_grad()
    _, logits, _ = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(),  base_inputs.size(0))
    base_top1.update  (base_prec1.item(), base_inputs.size(0))
    base_top5.update  (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.zero_grad()
    _, logits, log_probs = network(arch_inputs)
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    if algo == 'tunas':
      with torch.no_grad():
        RL_BASELINE_EMA.update(arch_prec1.item())
        rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
      rl_log_prob = sum(log_probs)
      arch_loss = - rl_advantage * rl_log_prob
    elif algo == 'tas' or algo == 'fbv2':
      arch_loss = criterion(logits, arch_targets)
    else:
      raise ValueError('invalid algorithm name: {:}'.format(algo))
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
    arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
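The examples in this section all lean on a few helpers defined elsewhere in the project: AverageMeter, obtain_accuracy, time_string, and (in Example #1) a global RL_BASELINE_EMA baseline. The snippet below is a minimal sketch of what such helpers could look like, assuming the usual val/avg meter interface, a top-k accuracy helper that returns percentages, and a scalar exponential-moving-average baseline; the project's actual implementations may differ.

import time

class AverageMeter:
    # Tracks the most recent value and a running average (assumed interface: val, avg, sum, count).
    def __init__(self):
        self.reset()
    def reset(self):
        self.val, self.sum, self.count, self.avg = 0.0, 0.0, 0, 0.0
    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / max(self.count, 1)

def obtain_accuracy(output, target, topk=(1,)):
    # Top-k accuracy in percent; returns one 1-element tensor per requested k.
    maxk, batch_size = max(topk), target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum().mul_(100.0 / batch_size) for k in topk]

def time_string():
    # Timestamp prefix used by the log lines above.
    return time.strftime('[%Y-%m-%d %H:%M:%S]', time.localtime())

class ExponentialMovingAverage:
    # A guess at the RL_BASELINE_EMA object from Example #1: a scalar EMA exposing .value.
    def __init__(self, momentum=0.95):
        self._momentum, self._init, self.value = momentum, False, 0.0
    def update(self, x):
        if not self._init:
            self.value, self._init = float(x), True
        else:
            self.value = self._momentum * self.value + (1.0 - self._momentum) * float(x)

RL_BASELINE_EMA = ExponentialMovingAverage()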
Example #2
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    # pdb.set_trace()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        w_optimizer.zero_grad()
        _, logits, _ = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        a_optimizer.zero_grad()
        _, logits, activation_terms = network(inputs=arch_inputs)
        arch_loss = criterion(logits, arch_targets)

        arch_loss = arch_loss + activation_terms.mean()
        arch_loss.backward()
        a_optimizer.step()

        arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    # pdb.set_trace()
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #3
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
    
    # update the weights
    sampled_arch = network.module.dync_genotype(True)
    network.module.set_cal_mode('dynamic', sampled_arch)
    #network.module.set_cal_mode( 'urs' )
    network.zero_grad()
    _, logits = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(),  base_inputs.size(0))
    base_top1.update  (base_prec1.item(), base_inputs.size(0))
    base_top5.update  (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.module.set_cal_mode( 'joint' )
    network.zero_grad()
    _, logits = network(arch_inputs)
    arch_loss = criterion(logits, arch_targets)
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
    arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
      #print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
      #print (network.module.arch_parameters)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #4
def valid_func(xloader, network, criterion):
    data_time, batch_time = AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    with torch.no_grad():
        for step, (X, Y) in enumerate(xloader):
            X = X.cuda(non_blocking=True)
            Y = Y.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - end)
            # prediction

            network.random_genotype(True)
            _, logits = network(X)
            arch_loss = criterion(logits, Y)
            # record
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                     Y.data,
                                                     topk=(1, 5))
            arch_losses.update(arch_loss.item(), Y.size(0))
            arch_top1.update(arch_prec1.item(), Y.size(0))
            arch_top5.update(arch_prec5.item(), Y.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #5
def get_best_arch(controller, shared_cnn, xloader, n_samples=10):
    with torch.no_grad():
        controller.eval()
        shared_cnn.eval()
        archs, valid_accs = [], []
        loader_iter = iter(xloader)
        for i in range(n_samples):
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)

            _, _, sampled_arch = controller()
            arch = shared_cnn.module.update_arch(sampled_arch)
            _, logits = shared_cnn(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                                 targets.data,
                                                 topk=(1, 5))

            archs.append(arch)
            valid_accs.append(val_top1.item())

        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
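Examples #5 through #7 (and several later ones) repeat the same try/except idiom: draw one validation batch per sampled architecture and restart the iterator when the loader is exhausted. A small helper makes the intent explicit; the following is a sketch of a possible refactor, not part of the original code:

def next_batch(loader_iter, xloader):
    # Return (batch, iterator), re-creating the iterator once the loader runs out.
    try:
        batch = next(loader_iter)
    except StopIteration:
        loader_iter = iter(xloader)
        batch = next(loader_iter)
    return batch, loader_iter

# usage inside the sampling loop, replacing the bare try/except:
# (inputs, targets), loader_iter = next_batch(loader_iter, xloader)

Note that the bare except in the originals also hides the case where the iterator has not been created yet; with this helper the iterator must be initialised once with loader_iter = iter(xloader) before the loop, as Examples #5 to #7 already do.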
Example #6
def search_find_best(xloader, network, n_samples):
    with torch.no_grad():
        network.eval()
        archs, valid_accs = [], []
        #print ('obtain the top-{:} architectures'.format(n_samples))
        loader_iter = iter(xloader)
        for i in range(n_samples):
            arch = network.module.random_genotype(True)
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)

            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                                 targets.data,
                                                 topk=(1, 5))

            archs.append(arch)
            valid_accs.append(val_top1.item())

        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
Example #7
def get_best_arch(xloader, network, n_samples):
    with torch.no_grad():
        network.eval()
        archs, valid_accs = network.module.return_topK(n_samples), []
        #print ('obtain the top-{:} architectures'.format(n_samples))
        loader_iter = iter(xloader)
        for i, sampled_arch in enumerate(archs):
            network.module.set_cal_mode('dynamic', sampled_arch)
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)

            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                                 targets.data,
                                                 topk=(1, 5))

            valid_accs.append(val_top1.item())
            #print ('--- {:}/{:} : {:} : {:}'.format(i, len(archs), sampled_arch, val_top1))

        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
Example #8
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2: latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
Example #9
def train_shared_cnn(xloader, shared_cnn, criterion, scheduler, optimizer,
                     print_freq, logger, config, start_epoch):
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(
    ), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Training the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(scheduler.get_lr())))

        data_time, batch_time = AverageMeter(), AverageMeter()
        losses, top1s, top5s, xend = AverageMeter(), AverageMeter(
        ), AverageMeter(), time.time()

        shared_cnn.train()

        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            if step % print_freq == 0 or step + 1 == len(xloader):
                Sstr = '*Train-Shared-CNN* ' + time_string(
                ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step,
                                                   len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                    loss=losses, top1=top1s, top5=top5s)
                logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

        cnn_loss, cnn_top1, cnn_top5 = losses.avg, top1s.avg, top5s.avg
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    return
Example #10
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == 'train': optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == 'train':
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
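The shared procedure() above handles both modes; a typical pair of calls would look like the sketch below. The loader, optimizer and scheduler names are placeholders, and the no_grad wrapper for evaluation is added by the caller since the function itself does not disable gradients (torch and the loaders are assumed to be in scope).

train_loss, train_top1, train_top5, train_time = procedure(
    train_loader, network, criterion, scheduler, optimizer, mode='train')
with torch.no_grad():
    valid_loss, valid_top1, valid_top5, valid_time = procedure(
        valid_loader, network, criterion, None, None, mode='valid')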
Example #11
def get_best_arch(xloader, network, n_samples):
    with torch.no_grad():
        network.eval()
        archs, valid_accs = [], []
        loader_iter = iter(xloader)
        for i in range(n_samples):
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)

            sampled_arch = network.module.dync_genotype(False)
            network.module.set_cal_mode('dynamic', sampled_arch)
            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                                 targets.data,
                                                 topk=(1, 5))

            archs.append(sampled_arch)
            valid_accs.append(val_top1.item())

        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
Example #12
def get_best_arch(xloader, network, n_samples, algo):
    with torch.no_grad():
        network.eval()
        if algo == 'random':
            archs, valid_accs = network.return_topK(n_samples, True), []
        elif algo == 'setn':
            archs, valid_accs = network.return_topK(n_samples, False), []
        elif algo.startswith('darts') or algo == 'gdas':
            arch = network.genotype
            archs, valid_accs = [arch], []
        elif algo == 'enas':
            archs, valid_accs = [], []
            for _ in range(n_samples):
                _, _, sampled_arch = network.controller()
                archs.append(sampled_arch)
        else:
            raise ValueError('Invalid algorithm name : {:}'.format(algo))
        loader_iter = iter(xloader)
        for i, sampled_arch in enumerate(archs):
            network.set_cal_mode('dynamic', sampled_arch)
            try:
                inputs, targets = next(loader_iter)
            except:
                loader_iter = iter(xloader)
                inputs, targets = next(loader_iter)
            _, logits = network(inputs.cuda(non_blocking=True))
            val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                                 targets.data,
                                                 topk=(1, 5))
            valid_accs.append(val_top1.item())
        best_idx = np.argmax(valid_accs)
        best_arch, best_valid_acc = archs[best_idx], valid_accs[best_idx]
        return best_arch, best_valid_acc
Example #13
def valid_func(xloader, network, criterion):
    data_time, batch_time = AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.eval()
    end = time.time()
    with torch.no_grad():
        for step, (arch_inputs, arch_targets) in enumerate(xloader):
            arch_targets = arch_targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - end)
            # prediction
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            # record
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                     arch_targets.data,
                                                     topk=(1, 5))
            arch_losses.update(arch_loss.item(), arch_inputs.size(0))
            arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
            arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #14
def train_shared_cnn(
    xloader,
    shared_cnn,
    controller,
    criterion,
    scheduler,
    optimizer,
    epoch_str,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        time.time(),
    )

    shared_cnn.train()
    controller.eval()

    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        with torch.no_grad():
            _, _, sampled_arch = controller()

        optimizer.zero_grad()
        shared_cnn.module.update_arch(sampled_arch)
        _, logits = shared_cnn(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
        optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1s.update(prec1.item(), inputs.size(0))
        top5s.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*Train-Shared-CNN* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=losses, top1=top1s, top5=top5s)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return losses.avg, top1s.avg, top5s.avg
Example #15
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
    )

    network.eval()
    network.apply(change_key("search_mode", "search"))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)

            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = ("**VALID** " + time_string() +
                        " [{:}][{:03d}/{:03d}]".format(extra_info, i,
                                                       len(xloader)))
                Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                    batch_time=batch_time, data_time=data_time)
                Lstr = "Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})".format(
                    loss=losses, top1=top1, top5=top5)
                Istr = "Size={:}".format(list(inputs.size()))
                logger.log(Sstr + " " + Tstr + " " + Lstr + " " + Istr)

    logger.log(
        " **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}"
        .format(
            top1=top1,
            top5=top5,
            error1=100 - top1.avg,
            error5=100 - top5.avg,
            loss=losses.avg,
        ))

    return losses.avg, top1.avg, top5.avg
Example #16
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler,
                     optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(
    ), time.time()

    shared_cnn.train()
    controller.eval()
    ne = 10

    for ni in range(ne):
        with torch.no_grad():
            _, _, sampled_arch = controller()
        shared_cnn.module.update_arch(sampled_arch)
        print(sampled_arch)
        # arch_str = op_list2str(sampled_arch)
        for step, (inputs, targets) in enumerate(xloader):
            # print(step,inputs,targets)
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()

            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 2))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            # if step + 1 == len(xloader):
        Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:03d}/{:03d}]'.format(
            ni, ne)
        Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
            batch_time=batch_time, data_time=data_time)
        Wstr = '[Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
            loss=losses, top1=top1s, top5=top5s)
        losses.reset()
        top1s.reset()
        top5s.reset()
        logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

    return losses.avg, top1s.avg, top5s.avg
Example #17
    def valid_fn(network, arch, args):
        # note: xloader, train_loader and bn_iter come from the enclosing scope; the bare except
        # below also covers the first call, before loader_iter has been assigned
        network.arch_cache = arch
        train_bn(iter(train_loader), network, bn_iter)
        network.eval()
        try:
            inputs, targets = next(loader_iter)
        except:
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)

        inputs, targets = inputs.cuda(non_blocking=True), targets.cuda(
            non_blocking=True)
        _, logits = network(inputs)
        val_top1 = obtain_accuracy(logits, targets.data)[0]
        return val_top1
Example #18
    def test_archi_acc(self, arch):
        if self.train_loader is not None:
            self.model.apply(ResetRunningStats)

            self.model.train()
            for step, (data, target) in enumerate(self.train_loader):
                # print('train step: {} total: {}'.format(step,max_train_iters))
                # data, target = train_dataprovider.next()
                # print('get data',data.shape)
                #data = data.cuda()
                output = self.model.forward(data, arch)  #_with_architect
                del data, target, output

        base_top1, base_top5 = AverageMeter(), AverageMeter()
        self.model.eval()

        one_batch = None
        for step, (data, target) in enumerate(self.val_loader):
            # print('test step: {} total: {}'.format(step,max_test_iters))
            if one_batch is None:
                one_batch = data
            batchsize = data.shape[0]
            # print('get data',data.shape)
            target = target.cuda(non_blocking=True)
            #data, target = data.to(device), target.to(device)

            _, logits = self.model.forward(data, arch)  #_with_architect

            prec1, prec5 = obtain_accuracy(logits.data,
                                           target.data,
                                           topk=(1, 5))
            base_top1.update(prec1.item(), batchsize)
            base_top5.update(prec5.item(), batchsize)

            del data, target, logits, prec1, prec5

        if self.lambda_t > 0.0:
            start_time = time.time()
            len_batch = min(len(one_batch), 50)
            for i in range(len_batch):
                _, _ = self.model.forward(one_batch[i:i + 1, :, :, :], arch)
            end_time = time.time()
            time_per = (end_time - start_time) / len_batch
        else:
            time_per = 0.0

        #print('top1: {:.2f} top5: {:.2f}'.format(base_top1.avg * 100, base_top5.avg * 100))
        return base_top1.avg, base_top5.avg, time_per
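test_archi_acc recalibrates batch-norm statistics by applying a ResetRunningStats function before the calibration pass over self.train_loader. That helper is not shown; a plausible sketch, assuming it simply resets the running statistics of every BatchNorm module, is:

import torch.nn as nn

def ResetRunningStats(m):
    # Clear BN running mean/var so they are re-estimated during the calibration forward passes.
    if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        m.reset_running_stats()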
Example #19
def get_arch_acc(xloader, network, sampled_arch):
    with torch.no_grad():
        network.eval()
        loader_iter = iter(xloader)
        network.module.set_cal_mode('dynamic', sampled_arch)
        try:
            inputs, targets = next(loader_iter)
        except:
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)

        _, logits = network(inputs)
        val_top1, val_top5 = obtain_accuracy(logits.cpu().data,
                                             targets.data,
                                             topk=(1, 5))
        return val_top1.item()
Example #20
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
                print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        network.module.random_genotype(True)
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Example #21
def train_supernet(xloader, network, criterion, w_optimizer, epoch_str,
                   print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (X, Y) in enumerate(xloader):
        # measure data loading time
        X, Y = X.cuda(non_blocking=True), Y.cuda(non_blocking=True)
        data_time.update(time.time() - end)

        # update the weights
        network.random_genotype(True)
        w_optimizer.zero_grad()
        _, logits = network(X)
        base_loss = criterion(logits, Y)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 Y.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), X.size(0))
        base_top1.update(base_prec1.item(), X.size(0))
        base_top5.update(base_prec5.item(), X.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.info(Sstr + ' ' + Tstr + ' ' + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Example #22
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
              config, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter()
    Ttop1, Ttop5 = AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    teacher.eval()

    logger.log(
        '[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'
        .format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1,
                config.KD_alpha, config.KD_temperature))
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == 'train': optimizer.zero_grad()

        student_f, logits = network(inputs)
        if isinstance(logits, list):
            assert len(
                logits
            ) == 2, 'logits must has {:} items instead of {:}'.format(
                2, len(logits))
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        with torch.no_grad():
            teacher_f, teacher_logits = teacher(inputs)

        loss = loss_KD_fn(criterion, logits, teacher_logits, student_f,
                          teacher_f, targets, config.KD_alpha,
                          config.KD_temperature)
        if config is not None and hasattr(
                config, 'auxiliary') and config.auxiliary > 0:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == 'train':
            loss.backward()
            optimizer.step()

        # record
        sprec1, sprec5 = obtain_accuracy(logits.data,
                                         targets.data,
                                         topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(sprec1.item(), inputs.size(0))
        top5.update(sprec5.item(), inputs.size(0))
        # teacher
        tprec1, tprec5 = obtain_accuracy(teacher_logits.data,
                                         targets.data,
                                         topk=(1, 5))
        Ttop1.update(tprec1.item(), inputs.size(0))
        Ttop5.update(tprec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = ' {:5s} '.format(
                mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                    extra_info, i, len(xloader))
            if scheduler is not None:
                Sstr += ' {:}'.format(scheduler.get_min_info())
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=losses, top1=top1, top5=top5)
            Lstr += ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(
                Ttop1.avg, Ttop5.avg)
            Istr = 'Size={:}'.format(list(inputs.size()))
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)

    logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(
        mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
    logger.log(
        ' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'
        .format(mode=mode.upper(),
                top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                loss=losses.avg))
    return losses.avg, top1.avg, top5.avg
Example #23
def search_func(xloader,
                network,
                criterion,
                scheduler,
                w_optimizer,
                a_optimizer,
                epoch_str,
                xargs,
                logger,
                ood_loader=None):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        if xargs.adv_outer:
            arch_inputs.requires_grad = True
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)

        if xargs.ood_inner or xargs.ood_outer:
            try:
                ood_input, _ = next(ood_loader_iter)
            except:
                ood_loader_iter = iter(ood_loader)
                ood_input, _ = next(ood_loader_iter)
            ood_input = ood_input.cuda(non_blocking=True)

        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        w_optimizer.zero_grad()
        _, logits, _, _ = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        if xargs.ood_inner and ood_loader is not None:
            _, ood_logits, _, _ = network(ood_input)
            ood_loss = F.kl_div(input=F.log_softmax(ood_logits, dim=-1),
                                target=torch.ones_like(ood_logits) /
                                ood_logits.size()[-1])
            base_loss += ood_loss
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        a_optimizer.zero_grad()
        grads = {}
        loss_data = {}
        # ---- acc loss ----
        _, acc_logits, nop_loss, flp_loss = network(arch_inputs)
        acc_loss = criterion(acc_logits, arch_targets)
        loss_data['acc'] = acc_loss.item()
        grads['acc'] = list(
            torch.autograd.grad(acc_loss,
                                network.get_alphas(),
                                retain_graph=True))
        # del acc_logits
        # ---- end ----

        # ---- nop loss ----
        if xargs.nop_outer:
            if xargs.nop_constrain == 'abs':
                nop_loss = torch.abs(xargs.nop_constrain_min - nop_loss)
            loss_data['nop'] = nop_loss.item()
            grads['nop'] = list(
                torch.autograd.grad(nop_loss,
                                    network.get_alphas(),
                                    retain_graph=True))
        # ---- end ----

        # ---- flp loss ----
        if xargs.flp_outer:
            if xargs.flp_constrain == 'abs':
                flp_loss = torch.abs(xargs.flp_constrain_min - flp_loss)
            loss_data['flp'] = flp_loss.item()
            grads['flp'] = list(
                torch.autograd.grad(flp_loss,
                                    network.get_alphas(),
                                    retain_graph=True))
        # ---- end ----

        # ---- ood loss ----
        if xargs.ood_outer and ood_loader is not None:
            _, ood_logits, _, _ = network(ood_input)
            ood_loss = F.kl_div(input=F.log_softmax(ood_logits, dim=-1),
                                target=torch.ones_like(ood_logits) /
                                ood_logits.size()[-1])
            loss_data['ood'] = ood_loss.item()
            grads['ood'] = list(
                torch.autograd.grad(ood_loss,
                                    network.get_alphas(),
                                    retain_graph=True))
            del ood_logits
        # ---- end ----

        # ---- adv loss ----
        if xargs.adv_outer:
            if xargs.dataset == 'cifar10':
                mean = (0.4914, 0.4822, 0.4465)
                std = (0.2471, 0.2435, 0.2616)
            elif xargs.dataset == 'cifar100':
                mean = (0.5071, 0.4867, 0.4408)
                std = (0.2675, 0.2565, 0.2761)
            mean = torch.FloatTensor(mean).view(3, 1, 1)
            std = torch.FloatTensor(std).view(3, 1, 1)
            upper_limit = ((1 - mean) / std).cuda()
            lower_limit = ((0 - mean) / std).cuda()
            epsilon = ((xargs.epsilon / 255.) / std).cuda()
            step_size = epsilon * 1.25
            delta = (
                (torch.rand(arch_inputs.size()) - 0.5) * 2).cuda() * epsilon
            adv_grad = torch.autograd.grad(acc_loss,
                                           arch_inputs,
                                           retain_graph=True,
                                           create_graph=False)[0]
            adv_grad = adv_grad.detach().data
            delta = clamp(delta + step_size * torch.sign(adv_grad), -epsilon,
                          epsilon)
            delta = clamp(delta, lower_limit - arch_inputs.data,
                          upper_limit - arch_inputs.data)
            adv_input = (arch_inputs.data + delta).cuda()
            _, adv_logits, _, _ = network(adv_input)
            adv_loss = criterion(adv_logits, arch_targets)
            loss_data['adv'] = adv_loss.item()
            grads['adv'] = list(
                torch.autograd.grad(adv_loss,
                                    network.get_alphas(),
                                    retain_graph=True))
            del mean, std, upper_limit, lower_limit, epsilon, step_size, delta, adv_grad, adv_input, adv_logits
        # ---- end ----

        # ---- MGDA ----
        gn = gradient_normalizers(
            grads, loss_data,
            normalization_type=xargs.grad_norm)  # loss+, loss, l2

        for t in grads:
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / (gn[t] + 1e-7)

        if xargs.MGDA and (len(grads) > 1):
            sol, _ = MinNormSolver.find_min_norm_element(
                [grads[t] for t in grads])
            print(sol)  # acc, adv, nop
        else:
            sol = [1] * len(grads)

        arch_loss = 0
        for kk, t in enumerate(grads):
            if t == 'acc':
                arch_loss += float(sol[kk]) * acc_loss
            elif t == 'adv':
                arch_loss += float(sol[kk]) * adv_loss
            elif t == 'nop':
                arch_loss += float(sol[kk]) * nop_loss
            elif t == 'ood':
                arch_loss += float(sol[kk]) * ood_loss
            elif t == 'flp':
                arch_loss += float(sol[kk]) * flp_loss
        # ---- end ----

        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(acc_logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % xargs.print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
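The adversarial branch of Example #23 relies on a clamp(x, lower, upper) helper with tensor-valued bounds (the per-channel epsilon and normalised image limits), plus the MGDA utilities gradient_normalizers and MinNormSolver defined elsewhere. A minimal sketch of the clamp helper, under the assumption that it is the element-wise clamp commonly used in fast adversarial-training code:

import torch

def clamp(x, lower, upper):
    # Element-wise clamp where lower/upper may be tensors broadcastable to x.
    return torch.max(torch.min(x, upper), lower)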
Example #24
def train_controller(xloader, network, criterion, optimizer, prev_baseline,
                     epoch_str, print_freq, logger):
    # config. (containing some necessary arg)
    #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
    data_time, batch_time = AverageMeter(), AverageMeter()
    GradnormMeter, LossMeter, ValAccMeter, EntropyMeter, BaselineMeter, RewardMeter, xend = AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter(
    ), AverageMeter(), time.time()

    controller_num_aggregate = 20
    controller_train_steps = 50
    controller_bl_dec = 0.99
    controller_entropy_weight = 0.0001

    network.eval()
    network.controller.train()
    network.controller.zero_grad()
    loader_iter = iter(xloader)
    for step in range(controller_train_steps * controller_num_aggregate):
        try:
            inputs, targets = next(loader_iter)
        except:
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        log_prob, entropy, sampled_arch = network.controller()
        with torch.no_grad():
            network.set_cal_mode('dynamic', sampled_arch)
            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.data,
                                                 targets.data,
                                                 topk=(1, 5))
            val_top1 = val_top1.view(-1) / 100
        reward = val_top1 + controller_entropy_weight * entropy
        if prev_baseline is None:
            baseline = val_top1
        else:
            baseline = prev_baseline - (1 - controller_bl_dec) * (
                prev_baseline - reward)

        loss = -1 * log_prob * (reward - baseline)

        # account
        RewardMeter.update(reward.item())
        BaselineMeter.update(baseline.item())
        ValAccMeter.update(val_top1.item() * 100)
        LossMeter.update(loss.item())
        EntropyMeter.update(entropy.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / controller_num_aggregate
        loss.backward(retain_graph=True)

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if (step + 1) % controller_num_aggregate == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                network.controller.parameters(), 5.0)
            GradnormMeter.update(grad_norm)
            optimizer.step()
            network.controller.zero_grad()

        if step % print_freq == 0:
            Sstr = '*Train-Controller* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step,
                controller_train_steps * controller_num_aggregate)
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(
                loss=LossMeter,
                top1=ValAccMeter,
                reward=RewardMeter,
                basel=BaselineMeter)
            Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val,
                                                    EntropyMeter.avg)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)

    return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
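Example #24 is the standard ENAS-style controller update: a REINFORCE loss of -log_prob * (reward - baseline), with an entropy bonus folded into the reward and an exponentially decayed baseline, accumulated over controller_num_aggregate sampled architectures before each optimizer step. The sketch below restates just that accumulation pattern in isolation; controller and get_reward are placeholders, and the hyper-parameters mirror the constants above.

import torch

def reinforce_update(controller, optimizer, get_reward, baseline,
                     num_aggregate=20, bl_dec=0.99, entropy_weight=0.0001):
    # One aggregated policy-gradient step: average the loss over num_aggregate samples.
    controller.zero_grad()
    for _ in range(num_aggregate):
        log_prob, entropy, arch = controller()
        reward = get_reward(arch) + entropy_weight * entropy
        baseline = reward if baseline is None else baseline - (1 - bl_dec) * (baseline - reward)
        loss = -log_prob * (reward - baseline) / num_aggregate
        loss.backward(retain_graph=True)
    torch.nn.utils.clip_grad_norm_(controller.parameters(), 5.0)
    optimizer.step()
    return baseline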
Example #25
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, algo, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # Update the weights
        if algo == 'setn':
            sampled_arch = network.dync_genotype(True)
            network.set_cal_mode('dynamic', sampled_arch)
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo == 'enas':
            with torch.no_grad():
                network.controller.eval()
                _, _, sampled_arch = network.controller()
            network.set_cal_mode('dynamic', sampled_arch)
        else:
            raise ValueError('Invalid algo name : {:}'.format(algo))

        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        if algo == 'setn':
            network.set_cal_mode('joint')
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo != 'enas':
            raise ValueError('Invalid algo name : {:}'.format(algo))
        network.zero_grad()
        if algo == 'darts-v2':
            arch_loss, logits = backward_step_unrolled(
                network, criterion, base_inputs, base_targets, w_optimizer,
                arch_inputs, arch_targets)
            a_optimizer.step()
        elif algo == 'random' or algo == 'enas':
            with torch.no_grad():
                _, logits = network(arch_inputs)
                arch_loss = criterion(logits, arch_targets)
        else:
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            arch_loss.backward()
            a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #26
0
def search_train(search_loader, network, criterion, scheduler, base_optimizer,
                 arch_optimizer, optim_config, extra_info, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = \
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str, flop_need = extra_info['epoch-str'], extra_info['FLOP-exp']
    flop_weight, flop_tolerant = extra_info['FLOP-weight'], extra_info['FLOP-tolerant']

    network.train()
    logger.log(
        '[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(
            epoch_str, flop_need, flop_weight))
    end = time.time()
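    # put the searchable modules into 'search' mode so the forward pass also returns the
    # expected FLOPs that are used by the FLOP loss below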
    network.apply(change_key('search_mode', 'search'))
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data,
                                       base_targets.data,
                                       topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
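        # arch loss = classification loss + weighted FLOP penalty that steers the expected
        # FLOPs of the searched network towards the FLOP budget (flop_need)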
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop('genotype', None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur,
                                                   flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(search_loader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=base_losses, top1=top1, top5=top5)
            Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                aloss=arch_cls_losses,
                floss=arch_flop_losses,
                loss=arch_losses)
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])

    logger.log(
        ' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'
        .format(top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                baseloss=base_losses.avg,
                archloss=arch_losses.avg))
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
Example #27
0
arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
time_start = time.time()
time_pre = time.time()

search_model.eval()
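# initial validation pass: measure the supernet's top-1 accuracy on the held-out set
# before the search loop starts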
for step, (base_inputs, base_targets) in enumerate(valid_loader):
    base_targets = base_targets.cuda(non_blocking=True)
    # print('in',base_inputs[0])

    # optim.zero_grad()
    with torch.no_grad():
        _, logits = search_model(base_inputs.cuda())

        arch_prec1 = obtain_accuracy(logits.data, base_targets.data)[0]
    arch_top1.update(arch_prec1.item(), base_targets.size(0))

print('val_acc %.2f used %.2fs' % (arch_top1.avg, time.time() - time_pre))
time_pre = time.time()

reg_lambda = 0.001
for epoch in range(10000):
    search_model.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(search_loader):
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # print(torch.mean(base_targets.float()))

        optim.zero_grad()
Example #28
0
def search_find_best(xloader, train_loader, network, predictor, optimizer,
                     logger, args):
    bn_iter = 0 if not args.track_running_stat else args.bn_train_iters
    arch_pool = [Architect() for _ in range(args.init_pool_size)]
    valid_accs = {}
    best_arch, best_acc = None, 100.
    # network.eval()
    loader_iter = iter(xloader)
    while True:
        logger.info('arch pool: %d' % len(arch_pool))
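        # one search round: (1) estimate the validation error of each unseen architecture
        # on a single mini-batch, (2) fit the predictor on the pool, (3) take
        # predictor-guided gradient steps to propose new candidates; note that best_acc
        # actually tracks the lowest validation error seen so far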
        # evaluate unseen architectures
        with torch.no_grad():
            for arch in arch_pool:
                # arch_str = arch.struct.to_unique_str(True)
                arch_str = arch.struct.tostr()
                if arch_str in valid_accs:
                    continue
                network.set_genotype(arch.struct)

                if bn_iter > 0:
                    train_bn(train_loader, network, bn_iter)
                network.eval()

                arch_top1 = AverageMeter()
                try:
                    inputs, targets = next(loader_iter)
                except StopIteration:
                    loader_iter = iter(xloader)
                    inputs, targets = next(loader_iter)

                inputs, targets = inputs.cuda(), targets.cuda()
                _, logits = network(inputs)
                val_top1 = obtain_accuracy(logits, targets.data)[0]
                arch_top1.update(val_top1.item(), targets.size(0))

                arch_err = 100. - arch_top1.avg
                valid_accs[arch_str] = arch_err
                if arch_err < best_acc:
                    best_arch = arch.struct.tostr()
                    best_acc = arch_err

        logger.info("best arch err ever: %2.2f" % best_acc)
        logger.info(best_arch)
        if len(arch_pool) >= args.max_samples:
            break

        # train predictor
        p_train_data = [
            arch2data(
                a.struct.tostr(),
                #   valid_accs[a.struct.to_unique_str(True)])
                valid_accs[a.struct.tostr()]) for a in arch_pool
        ]
        p_train_queue = gd.DataListLoader(p_train_data,
                                          args.p_batch_size,
                                          shuffle=True)
        for epoch in range(args.p_epochs):
            # train_predictor(predictor, p_train_data, optimizer, p_batch_size)
            predictor.fit(p_train_queue, optimizer, 0, None)

        # grad search
        checker = lambda arch: arch.struct.tostr() not in valid_accs
        new_trace = predictor.grad_step_on_archs(arch_pool,
                                                 args.step_batch_size,
                                                 args.step_size, checker)
        new_trace = sorted(new_trace, key=itemgetter(1))[:args.new_pop_limit]

        # new_trace = sorted(new_trace, key=itemgetter(1))[:args.max_samples]
        arch_pool = (arch_pool +
                     list(map(itemgetter(0), new_trace)))[:args.max_samples]

    return best_arch, best_acc
Example #29
0
def train_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
               print_freq, archs, arch_iter, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    val_losses, val_top1, val_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, val_inputs,
               val_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        try:
            arch = next(arch_iter)
        except StopIteration:
            arch_iter = iter(archs)
            arch = next(arch_iter)
        base_inputs = base_inputs.cuda()
        base_targets = base_targets.cuda(non_blocking=True)
        val_inputs = val_inputs.cuda()
        val_targets = val_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        w_optimizer.zero_grad()
        _, logits, _ = network(base_inputs)  #, arch)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # validate arch
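        # forward-only evaluation of the sampled architecture on the validation split;
        # no optimizer step is taken for this loss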
        _, logits, _ = network(val_inputs, arch)
        val_loss = criterion(logits, val_targets)
        # record
        val_prec1, val_prec5 = obtain_accuracy(logits.data,
                                               val_targets.data,
                                               topk=(1, 5))
        val_losses.update(val_loss.item(), val_inputs.size(0))
        val_top1.update(val_prec1.item(), val_inputs.size(0))
        val_top5.update(val_prec5.item(), val_inputs.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*TRAIN* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Val  [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=val_losses, top1=val_top1, top5=val_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, val_losses.avg, val_top1.avg, val_top5.avg, arch_iter
Example #30
0
def main(xargs):
    assert torch.cuda.is_available(), 'CUDA is not available.'
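    # reproducibility: disable cuDNN benchmarking, force deterministic kernels, and fix
    # the random seed before building the data and the model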
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.set_num_threads(xargs.workers)
    prepare_seed(xargs.rand_seed)
    logger = prepare_logger(xargs)

    train_data, valid_data, xshape, class_num = get_datasets(
        xargs.dataset, xargs.data_path, -1)
    #config_path = 'configs/nas-benchmark/algos/GDAS.config'
    config = load_config(xargs.config_path, {
        'class_num': class_num,
        'xshape': xshape
    }, logger)
    search_loader, _, valid_loader = get_nas_search_loaders(
        train_data, valid_data, xargs.dataset, 'configs/nas-benchmark/',
        config.batch_size, xargs.workers)
    logger.log(
        '||||||| {:10s} ||||||| Search-Loader-Num={:}, batch size={:}'.format(
            xargs.dataset, len(search_loader), config.batch_size))
    logger.log('||||||| {:10s} ||||||| Config={:}'.format(
        xargs.dataset, config))

    search_space = get_search_spaces('cell', xargs.search_space_name)
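    # build the GDAS supernet config, either from the CLI arguments (the 'constrain'
    # flag appears to toggle the inp_size field) or from an explicit model-config file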
    if xargs.model_config is None and not xargs.constrain:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 0,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    elif xargs.model_config is None:
        model_config = dict2config(
            {
                'name': 'GDAS',
                'C': xargs.channel,
                'N': xargs.num_cells,
                'max_nodes': xargs.max_nodes,
                'num_classes': class_num,
                'space': search_space,
                'inp_size': 32,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    else:
        model_config = load_config(
            xargs.model_config, {
                'num_classes': class_num,
                'space': search_space,
                'affine': False,
                'track_running_stats': bool(xargs.track_running_stats)
            }, None)
    search_model = get_cell_based_tiny_net(model_config)
    #logger.log('search-model :\n{:}'.format(search_model))
    logger.log('model-config : {:}'.format(model_config))

    w_optimizer, w_scheduler, criterion = get_optim_scheduler(
        search_model.get_weights(), config)
    a_optimizer = torch.optim.Adam(search_model.get_alphas(),
                                   lr=xargs.arch_learning_rate,
                                   betas=(0.5, 0.999),
                                   weight_decay=xargs.arch_weight_decay)
    logger.log('w-optimizer : {:}'.format(w_optimizer))
    logger.log('a-optimizer : {:}'.format(a_optimizer))
    logger.log('w-scheduler : {:}'.format(w_scheduler))
    logger.log('criterion   : {:}'.format(criterion))
    flop, param = get_model_infos(search_model, xshape)
    #logger.log('{:}'.format(search_model))
    logger.log('FLOP = {:.2f} M, Params = {:.2f} MB'.format(flop, param))
    logger.log('search-space [{:} ops] : {:}'.format(len(search_space),
                                                     search_space))
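    # the optional benchmark API (when a NAS dataset path is given) provides ground-truth
    # accuracies, so searched genotypes can be scored later without retraining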
    if xargs.arch_nas_dataset is None:
        api = None
    else:
        api = API(xargs.arch_nas_dataset)
    logger.log('{:} create API = {:} done'.format(time_string(), api))

    last_info, model_base_path, model_best_path = \
        logger.path('info'), logger.path('model'), logger.path('best')
    network, criterion = torch.nn.DataParallel(search_model).cuda(), criterion.cuda()
    #network, criterion = search_model.cuda(), criterion.cuda()

    if last_info.exists():  # automatically resume from previous checkpoint
        logger.log("=> loading checkpoint of the last-info '{:}' start".format(
            last_info))
        last_info = torch.load(last_info)
        start_epoch = last_info['epoch']
        checkpoint = torch.load(last_info['last_checkpoint'])
        genotypes = checkpoint['genotypes']
        valid_accuracies = checkpoint['valid_accuracies']
        search_model.load_state_dict(checkpoint['search_model'])
        w_scheduler.load_state_dict(checkpoint['w_scheduler'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        a_optimizer.load_state_dict(checkpoint['a_optimizer'])
        logger.log(
            "=> loading checkpoint of the last-info '{:}' start with {:}-th epoch."
            .format(last_info, start_epoch))
    else:
        logger.log("=> do not find the last-info file : {:}".format(last_info))
        start_epoch = 0
        valid_accuracies = {'best': -1}
        genotypes = {-1: search_model.genotype()}

    # start training
    start_time, search_time, epoch_time = time.time(), AverageMeter(), AverageMeter()
    total_epoch = config.epochs + config.warmup
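    # architectures (weight samples) drawn periodically during the search; they are
    # evaluated after the search loop finishes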
    sampled_weights = []
    for epoch in range(start_epoch, total_epoch + config.t_epochs):
        w_scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(
                epoch_time.val * (total_epoch - epoch + config.t_epochs),
                True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        search_model.set_tau(xargs.tau_max -
                             (xargs.tau_max - xargs.tau_min) * epoch /
                             (total_epoch - 1))
        logger.log('\n[Search the {:}-th epoch] {:}, tau={:}, LR={:}'.format(
            epoch_str, need_time, search_model.get_tau(),
            min(w_scheduler.get_lr())))
        if epoch < total_epoch:
            search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5 = \
                search_func(search_loader, network, criterion, w_scheduler, w_optimizer,
                            a_optimizer, epoch_str, xargs.print_freq, logger, xargs.bilevel)
        else:
            search_w_loss, search_w_top1, search_w_top5, valid_a_loss, valid_a_top1, valid_a_top5, arch_iter = \
                train_func(search_loader, network, criterion, w_scheduler, w_optimizer, epoch_str,
                           xargs.print_freq, sampled_weights[0], arch_iter, logger)
        search_time.update(time.time() - start_time)
        logger.log(
            '[{:}] searching : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%, time-cost={:.1f} s'
            .format(epoch_str, search_w_loss, search_w_top1, search_w_top5,
                    search_time.sum))
        logger.log(
            '[{:}] evaluate  : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, valid_a_loss, valid_a_top1, valid_a_top5))

        if (epoch + 1) % 50 == 0 and not config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
        elif (epoch + 1) == total_epoch and config.t_epochs:
            weights = search_model.sample_weights(100)
            sampled_weights.append(weights)
            arch_iter = iter(weights)
        # validate with single arch
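        # quick sanity check: sample one architecture and estimate its top-1 accuracy on
        # ten validation mini-batches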
        single_weight = search_model.sample_weights(1)[0]
        single_valid_acc = AverageMeter()
        network.eval()
        for i in range(10):
            try:
                val_input, val_target = next(valid_iter)
            except Exception as e:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=single_weight)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                single_valid_acc.update(val_acc1.item(), n_val)
        logger.log('[{:}] valid : accuracy = {:.2f}'.format(
            epoch_str, single_valid_acc.avg))

        # check the best accuracy
        valid_accuracies[epoch] = valid_a_top1
        if valid_a_top1 > valid_accuracies['best']:
            valid_accuracies['best'] = valid_a_top1
            genotypes['best'] = search_model.genotype()
            find_best = True
        else:
            find_best = False

        if epoch < total_epoch:
            genotypes[epoch] = search_model.genotype()
            logger.log('<<<--->>> The {:}-th epoch : {:}'.format(
                epoch_str, genotypes[epoch]))
        # save checkpoint
        save_path = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'search_model': search_model.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'a_optimizer': a_optimizer.state_dict(),
                'w_scheduler': w_scheduler.state_dict(),
                'genotypes': genotypes,
                'valid_accuracies': valid_accuracies
            }, model_base_path, logger)
        last_info = save_checkpoint(
            {
                'epoch': epoch + 1,
                'args': deepcopy(xargs),
                'last_checkpoint': save_path,
            }, logger.path('info'), logger)
        if find_best:
            logger.log(
                '<<<--->>> The {:}-th epoch : find the highest validation accuracy : {:.2f}%.'
                .format(epoch_str, valid_a_top1))
            copy_checkpoint(model_base_path, model_best_path, logger)
        with torch.no_grad():
            logger.log('{:}'.format(search_model.show_alphas()))
        if api is not None and epoch < total_epoch:
            logger.log('{:}'.format(api.query_by_arch(genotypes[epoch])))

        # measure elapsed time
        epoch_time.update(time.time() - start_time)
        start_time = time.time()

    network.eval()
    # Evaluate the architectures sampled throughout the search
    for i in range(len(sampled_weights) - 1):
        logger.log('Sample eval : epoch {}'.format((i + 1) * 50 - 1))
        for w in sampled_weights[i]:
            sample_valid_acc = AverageMeter()
            for _ in range(10):
                try:
                    val_input, val_target = next(valid_iter)
                except Exception as e:
                    valid_iter = iter(valid_loader)
                    val_input, val_target = next(valid_iter)
                n_val = val_input.size(0)
                with torch.no_grad():
                    val_target = val_target.cuda(non_blocking=True)
                    _, logits, _ = network(val_input, weights=w)
                    val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                         val_target.data,
                                                         topk=(1, 5))
                    sample_valid_acc.update(val_acc1.item(), n_val)
            w_gene = search_model.genotype(w)
            if api is not None:
                ind = api.query_index_by_arch(w_gene)
                info = api.query_meta_info_by_index(ind)
                metrics = info.get_metrics('cifar10', 'ori-test')
                acc = metrics['accuracy']
            else:
                acc = 0.0
            logger.log(
                'sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
                    sample_valid_acc.avg, acc))
    # Evaluate the final sampling separately to find the top 10 architectures
    logger.log('Final sample eval')
    final_archs = []
    for w in sampled_weights[-1]:
        sample_valid_acc = AverageMeter()
        for i in range(10):
            try:
                val_input, val_target = next(valid_iter)
            except Exception as e:
                valid_iter = iter(valid_loader)
                val_input, val_target = next(valid_iter)
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                sample_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log('sample valid : val_acc = {:.2f} test_acc = {:.2f}'.format(
            sample_valid_acc.avg, acc))
        final_archs.append((w, sample_valid_acc.avg))
    top_10 = sorted(final_archs, key=lambda x: x[1], reverse=True)[:10]
    # Evaluate the top 10 architectures on the entire validation set
    logger.log('Evaluating top archs')
    for w, prev_acc in top_10:
        full_valid_acc = AverageMeter()
        for val_input, val_target in valid_loader:
            n_val = val_input.size(0)
            with torch.no_grad():
                val_target = val_target.cuda(non_blocking=True)
                _, logits, _ = network(val_input, weights=w)
                val_acc1, val_acc5 = obtain_accuracy(logits.data,
                                                     val_target.data,
                                                     topk=(1, 5))
                full_valid_acc.update(val_acc1.item(), n_val)
        w_gene = search_model.genotype(w)
        logger.log('genotype {}'.format(w_gene))
        if api is not None:
            ind = api.query_index_by_arch(w_gene)
            info = api.query_meta_info_by_index(ind)
            metrics = info.get_metrics('cifar10', 'ori-test')
            acc = metrics['accuracy']
        else:
            acc = 0.0
        logger.log(
            'full valid : val_acc = {:.2f} test_acc = {:.2f} pval_acc = {:.2f}'
            .format(full_valid_acc.avg, acc, prev_acc))

    logger.log('\n' + '-' * 100)
    # check the performance from the architecture dataset
    logger.log(
        'GDAS : run {:} epochs, cost {:.1f} s, last-geno is {:}.'.format(
            total_epoch, search_time.sum, genotypes[total_epoch - 1]))
    if api is not None:
        logger.log('{:}'.format(api.query_by_arch(genotypes[total_epoch - 1])))
    logger.close()