Example #1
def test(args, model, device, test_loader_creator, logger):
    model.eval()

    criterion = torch.nn.CrossEntropyLoss().to(device)

    with torch.no_grad():
        losses = AverageMeter()
        acc = AverageMeter()

        for test_loader in test_loader_creator.data_loaders:

            for data, target in test_loader:

                data, target = data.to(device), target.to(device)
                _, output = model(data)

                loss = criterion(output, target)

                output = output.float()
                loss = loss.float()

                it_acc = accuracy(output.data, target)[0]
                losses.update(loss.item(), data.size(0))
                acc.update(it_acc.item(), data.size(0))

    logger.info('Test set: Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Acc {acc.avg:.3f}'.format(loss=losses, acc=acc))
Example #2
def valid_func(xloader, network, criterion):
    data_time, batch_time = AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.eval()
    end = time.time()
    with torch.no_grad():
        for step, (arch_inputs, arch_targets) in enumerate(xloader):
            arch_targets = arch_targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - end)
            # prediction
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            # record
            arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                     arch_targets.data,
                                                     topk=(1, 5))
            arch_losses.update(arch_loss.item(), arch_inputs.size(0))
            arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
            arch_top5.update(arch_prec5.item(), arch_inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    return arch_losses.avg, arch_top1.avg, arch_top5.avg
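All of these snippets assume an AverageMeter helper that is imported elsewhere and not shown on this page. A minimal sketch consistent with how it is used here (the val, avg, sum and count attributes, plus update(val, n=1) and reset()) could look like this:

class AverageMeter:
    # Track the latest value and a running (weighted) average of a scalar.
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted sum of all values
        self.count = 0   # total weight (e.g. number of samples)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count if self.count > 0 else 0.0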
Example #3
def train_shared_cnn(xloader, shared_cnn, criterion, scheduler, optimizer,
                     print_freq, logger, config, start_epoch):
    # start training
    start_time, epoch_time, total_epoch = time.time(), AverageMeter(
    ), config.epochs + config.warmup
    for epoch in range(start_epoch, total_epoch):
        scheduler.update(epoch, 0.0)
        need_time = 'Time Left: {:}'.format(
            convert_secs2time(epoch_time.val * (total_epoch - epoch), True))
        epoch_str = '{:03d}-{:03d}'.format(epoch, total_epoch)
        logger.log('\n[Training the {:}-th epoch] {:}, LR={:}'.format(
            epoch_str, need_time, min(scheduler.get_lr())))

        data_time, batch_time = AverageMeter(), AverageMeter()
        losses, top1s, top5s, xend = AverageMeter(), AverageMeter(
        ), AverageMeter(), time.time()

        shared_cnn.train()

        for step, (inputs, targets) in enumerate(xloader):
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()
            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            if step % print_freq == 0 or step + 1 == len(xloader):
                Sstr = '*Train-Shared-CNN* ' + time_string(
                ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step,
                                                   len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                    loss=losses, top1=top1s, top5=top5s)
                logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

        cnn_loss, cnn_top1, cnn_top5 = losses.avg, top1s.avg, top5s.avg
        logger.log(
            '[{:}] shared-cnn : loss={:.2f}, accuracy@1={:.2f}%, accuracy@5={:.2f}%'
            .format(epoch_str, cnn_loss, cnn_top1, cnn_top5))
        epoch_time.update(time.time() - start_time)
        start_time = time.time()
    return
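Most of the examples also call obtain_accuracy(logits, targets, topk=(1, 5)) (Example #1 uses a similarly shaped accuracy helper). Neither is defined on this page; a standard top-k precision routine consistent with these call sites, returning one single-element percentage tensor per k so that .item() works, would be:

def obtain_accuracy(output, target, topk=(1,)):
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the top-k predictions, shape (maxk, batch)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res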
Example #4
    def train_or_test_epoch(self,
                            xloader,
                            model,
                            loss_fn,
                            metric_fn,
                            is_train,
                            optimizer=None):
        if is_train:
            model.train()
        else:
            model.eval()
        score_meter, loss_meter = AverageMeter(), AverageMeter()
        for ibatch, (feats, labels) in enumerate(xloader):
            feats = feats.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)
            # forward the network
            preds = model(feats)
            loss = loss_fn(preds, labels)
            with torch.no_grad():
                score = metric_fn(preds, labels)
                loss_meter.update(loss.item(), feats.size(0))
                score_meter.update(score.item(), feats.size(0))
            # optimize the network
            if is_train and optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_value_(model.parameters(), 3.0)
                optimizer.step()
        return loss_meter.avg, score_meter.avg
Example #5
def pure_evaluate(xloader, network, criterion=torch.nn.CrossEntropyLoss()):
    data_time, batch_time, batch = AverageMeter(), AverageMeter(), None
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    latencies, device = [], torch.cuda.current_device()
    network.eval()
    with torch.no_grad():
        end = time.time()
        for i, (inputs, targets) in enumerate(xloader):
            targets = targets.cuda(device=device, non_blocking=True)
            inputs = inputs.cuda(device=device, non_blocking=True)
            data_time.update(time.time() - end)
            # forward
            features, logits = network(inputs)
            loss = criterion(logits, targets)
            batch_time.update(time.time() - end)
            if batch is None or batch == inputs.size(0):
                batch = inputs.size(0)
                latencies.append(batch_time.val - data_time.val)
            # record loss and accuracy
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            end = time.time()
    if len(latencies) > 2: latencies = latencies[1:]
    return losses.avg, top1.avg, top5.avg, latencies
Example #6
def test_contrastive(args, model, nearest_proto_model, device,
                     test_loader_creator_l, logger):
    model.eval()

    acc = AverageMeter()
    tasks_acc = [
        AverageMeter() for i in range(len(test_loader_creator_l.data_loaders))
    ]

    test_loaders_l = test_loader_creator_l.data_loaders

    with torch.no_grad():
        for task_idx, test_loader_l in enumerate(test_loaders_l):

            for batch_idx, (data, _, target) in enumerate(test_loader_l):
                data, target = data.to(device), target.to(device)
                cur_feats, _ = model(data)
                output = nearest_proto_model.predict(cur_feats)
                it_acc = (output == target).sum().item() / data.shape[0]
                acc.update(it_acc, data.size(0))
                tasks_acc[task_idx].update(it_acc, data.size(0))

    if args.acc_per_task:
        tasks_acc_str = 'Test Acc per task: '
        for i, task_acc in enumerate(tasks_acc):
            tasks_acc_str += 'Task{:2d} Acc: {acc.avg:.3f}'.format(
                (i + 1), acc=task_acc) + '\t'
        logger.info(tasks_acc_str)
    logger.info('Test Acc: {acc.avg:.3f}'.format(acc=acc))
Example #7
def procedure(xloader, network, criterion, scheduler, optimizer, mode: str):
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    device = torch.cuda.current_device()
    data_time, batch_time, end = AverageMeter(), AverageMeter(), time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))

        targets = targets.cuda(device=device, non_blocking=True)
        if mode == 'train': optimizer.zero_grad()
        # forward
        features, logits = network(inputs)
        loss = criterion(logits, targets)
        # backward
        if mode == 'train':
            loss.backward()
            optimizer.step()
        # record loss and accuracy
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # count time
        batch_time.update(time.time() - end)
        end = time.time()
    return losses.avg, top1.avg, top5.avg, batch_time.sum
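The training snippets above and below also rely on small timing helpers, time_string and convert_secs2time, which are not defined on this page. Sketches that match how they are called (time_string() returns a timestamp string; convert_secs2time(secs, True) returns a formatted duration) might be:

import time

def time_string():
    # e.g. '[2024-05-01 13:37:00]'
    return '[{:}]'.format(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))

def convert_secs2time(epoch_time, return_str=False):
    need_hour = int(epoch_time // 3600)
    need_mins = int((epoch_time % 3600) // 60)
    need_secs = int(epoch_time % 60)
    if return_str:
        return '[{:02d}:{:02d}:{:02d}]'.format(need_hour, need_mins, need_secs)
    return need_hour, need_mins, need_secs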
Example #8
def train_bptt(num_epochs: int, model, dset_train, batch_size: int, T: int,
               w_checkpoint_freq: int, grad_clip: float, w_lr: float,
               logging_freq: int, sotl_order: int, hvp: str):
    model.train()
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size * T,
                                               shuffle=True)
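    # NOTE: criterion, w_optimizer, a_optimizer, wandb and sotl_gradient are not
    # defined in this snippet; in the original codebase they are presumably
    # module-level globals.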

    for epoch in range(num_epochs):
        epoch_loss = AverageMeter()
        true_batch_index = 0
        for batch_idx, batch in enumerate(train_loader):
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size)

            weight_buffer = WeightBuffer(T=T,
                                         checkpoint_freq=w_checkpoint_freq)
            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys)):
                weight_buffer.add(model, intra_batch_idx)

                y_pred = model(x)
                loss = criterion(y_pred, y)
                epoch_loss.update(loss.item())

                grads = torch.autograd.grad(loss,
                                            model.weight_params(),
                                            retain_graph=True,
                                            allow_unused=True,
                                            create_graph=True)

                w_optimizer.zero_grad()

                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)

                w_optimizer.step()
                true_batch_index += 1
                if true_batch_index % logging_freq == 0:
                    print("Epoch: {}, Batch: {}, Loss: {}".format(
                        epoch, true_batch_index, epoch_loss.avg))
                    wandb.log({"Train loss": epoch_loss.avg})

            total_arch_gradient = sotl_gradient(model,
                                                criterion,
                                                xs,
                                                ys,
                                                weight_buffer,
                                                w_lr=w_lr,
                                                hvp=hvp,
                                                order=sotl_order)

            a_optimizer.zero_grad()

            for g, w in zip(total_arch_gradient, model.arch_params()):
                w.grad = g
            torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
            a_optimizer.step()
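Examples #8 and #24 both use a WeightBuffer to snapshot the network's weights along the unrolled trajectory. The real class is not shown on this page; a hypothetical minimal version matching the WeightBuffer(T=..., checkpoint_freq=...) and .add(model, step) call sites could be:

class WeightBuffer:
    # Hypothetical sketch: store detached copies of the weight parameters every
    # checkpoint_freq steps so a later SoTL-gradient pass can look back at them.
    def __init__(self, T, checkpoint_freq):
        self.T = T
        self.checkpoint_freq = checkpoint_freq
        self.weights = []

    def add(self, model, step):
        if step % self.checkpoint_freq == 0:
            self.weights.append(
                [w.detach().clone() for w in model.weight_params()])

    def __getitem__(self, index):
        return self.weights[index]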
Example #9
def train_shared_cnn(
    xloader,
    shared_cnn,
    controller,
    criterion,
    scheduler,
    optimizer,
    epoch_str,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = (
        AverageMeter(),
        AverageMeter(),
        AverageMeter(),
        time.time(),
    )

    shared_cnn.train()
    controller.eval()

    for step, (inputs, targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        with torch.no_grad():
            _, _, sampled_arch = controller()

        optimizer.zero_grad()
        shared_cnn.module.update_arch(sampled_arch)
        _, logits = shared_cnn(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
        optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1s.update(prec1.item(), inputs.size(0))
        top5s.update(prec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*Train-Shared-CNN* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=losses, top1=top1s, top5=top5s)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return losses.avg, top1s.avg, top5s.avg
Example #10
def train_shared_cnn(xloader, shared_cnn, controller, criterion, scheduler,
                     optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1s, top5s, xend = AverageMeter(), AverageMeter(), AverageMeter(
    ), time.time()

    shared_cnn.train()
    controller.eval()
    ne = 10

    for ni in range(ne):
        with torch.no_grad():
            _, _, sampled_arch = controller()
        shared_cnn.module.update_arch(sampled_arch)
        print(sampled_arch)
        # arch_str = op_list2str(sampled_arch)
        for step, (inputs, targets) in enumerate(xloader):
            # print(step,inputs,targets)
            scheduler.update(None, 1.0 * step / len(xloader))
            targets = targets.cuda(non_blocking=True)
            # measure data loading time
            data_time.update(time.time() - xend)

            optimizer.zero_grad()

            _, logits = shared_cnn(inputs)
            loss = criterion(logits, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(shared_cnn.parameters(), 5)
            optimizer.step()
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 2))
            losses.update(loss.item(), inputs.size(0))
            top1s.update(prec1.item(), inputs.size(0))
            top5s.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - xend)
            xend = time.time()

            # if step + 1 == len(xloader):
        Sstr = '*Train-Shared-CNN* ' + time_string() + ' [{:03d}/{:03d}]'.format(
            ni, ne)
        Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
            batch_time=batch_time, data_time=data_time)
        Wstr = '[Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(
            loss=losses, top1=top1s, top5=top5s)
        # capture this architecture's running averages before resetting the meters
        last_loss, last_top1, last_top5 = losses.avg, top1s.avg, top5s.avg
        losses.reset()
        top1s.reset()
        top5s.reset()
        logger.log(Sstr + ' ' + Tstr + ' ' + Wstr)

    return last_loss, last_top1, last_top5
Example #11
def valid_func(model, val_loader, criterion):
    model.eval()
    val_meter = AverageMeter()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            y_pred = model(x)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    print("Val loss: {}".format(val_meter.avg))
    return val_meter
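For context, a minimal self-contained invocation of this valid_func (the toy model and data below are made up for illustration):

import torch

model = torch.nn.Linear(8, 3)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 8),
                                         torch.randint(0, 3, (64,)))
val_loader = torch.utils.data.DataLoader(dataset, batch_size=16)
criterion = torch.nn.CrossEntropyLoss()
meter = valid_func(model, val_loader, criterion)  # prints 'Val loss: ...'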
Example #12
def eval_robust_heatmap(detector, xloader, print_freq, logger):
    batch_time, NUM_PTS = AverageMeter(), xloader.dataset.NUM_PTS
    Preds, GT_locs, Distances = [], [], []
    eval_meta, end = Eval_Meta(), time.time()

    with torch.no_grad():
        detector.eval()
        for i, (inputs, heatmaps, masks, norm_points, thetas, data_index,
                nopoints, xshapes) in enumerate(xloader):
            data_index = data_index.squeeze(1).tolist()
            batch_size, iters, C, H, W = inputs.size()
            for ibatch in range(batch_size):
                xinputs, xpoints, xthetas = inputs[ibatch], norm_points[
                    ibatch].permute(0, 2, 1).contiguous(), thetas[ibatch]
                batch_features, batch_heatmaps, batch_locs, batch_scos = detector(
                    xinputs.cuda(non_blocking=True))
                batch_locs = batch_locs.cpu()[:, :-1]
                all_locs = []
                for _iter in range(iters):
                    _locs = normalize_points((H, W),
                                             batch_locs[_iter].permute(1, 0))
                    xlocs = torch.cat((_locs, torch.ones(1, NUM_PTS)), dim=0)
                    nlocs = torch.mm(xthetas[_iter, :2], xlocs)
                    rlocs = denormalize_points(xshapes[ibatch].tolist(), nlocs)
                    rlocs = torch.cat(
                        (rlocs.permute(1, 0), xpoints[_iter, :, 2:]), dim=1)
                    all_locs.append(rlocs.clone())
                GT_loc = xloader.dataset.labels[
                    data_index[ibatch]].get_points()
                norm_distance = xloader.dataset.get_normalization_distance(
                    data_index[ibatch])
                # save the results
                eval_meta.append((sum(all_locs) / len(all_locs)).numpy().T,
                                 GT_loc.numpy(),
                                 xloader.dataset.datas[data_index[ibatch]],
                                 norm_distance)
                Distances.append(norm_distance)
                Preds.append(all_locs)
                GT_locs.append(GT_loc.permute(1, 0))
            # compute time
            batch_time.update(time.time() - end)
            end = time.time()
            if i % print_freq == 0 or i + 1 == len(xloader):
                last_time = convert_secs2time(
                    batch_time.avg * (len(xloader) - i - 1), True)
                logger.log(
                    ' -->>[Robust HEATMAP-based Evaluation] [{:03d}/{:03d}] Time : {:}'
                    .format(i, len(xloader), last_time))
    # evaluate the results
    errors, valids = calculate_robust(Preds, GT_locs, Distances, NUM_PTS)
    return errors, valids, eval_meta
Example #13
def valid_func(model, dset_val, criterion, print_results=True):
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)

    val_meter = AverageMeter()
    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            y_pred = model(x)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    if print_results:
        print("Val loss: {}".format(val_meter.avg))
    return val_meter
Example #14
    def search(self):
        self.eva_time = AverageMeter()
        init_start = time.time()
        self.init_random()
        self.logger.log('Initial_takes: %.2f' % (time.time() - init_start))

        epoch_start_time = time.time()
        epoch_time_meter = AverageMeter()
        bests_per_epoch = list()
        perform_trace = list()
        for i in range(self.max_epochs):
            self.performances = torch.Tensor(self.performances)
            top_k = torch.argsort(self.performances,
                                  descending=True)[:self.parent_num]

            if self.best_perf is None or self.performances[
                    top_k[0]] > self.best_perf:
                self.best_cand = self.candidates[top_k[0]]
                self.best_perf = self.performances[top_k[0]]
            bests_per_epoch.append(self.best_cand)
            perform_trace.append(self.performances)

            self.parents = []
            for idx in top_k:
                self.parents.append(self.candidates[idx])
            self.candidates, self.performances = list(), list()
            self.eva_time = AverageMeter()
            self.get_mutation(self.population_num // 2)
            self.get_crossover()

            self.logger.log(
                '*SEARCH* ' + time_string() +
                '||| Epoch: %2d finished, %3d models have been tested, best performance is %.2f'
                % (i, len(self.perform_dict.keys()), self.best_perf))
            self.logger.log(' - Best Cand: ' + str(self.best_cand))
            this_epoch_time = time.time() - epoch_start_time
            epoch_time_meter.update(this_epoch_time)
            epoch_start_time = time.time()
            self.logger.log('Time for Epoch %d : %.2fs' % (i, this_epoch_time))
            self.logger.log(' -- Evaluated %d models, with %.2f s in average' %
                            (self.eva_time.count, self.eva_time.avg))

        self.logger.log(
            '--------\nSearching Finished. Best Arch Found with Acc %.2f' %
            (self.best_perf))
        self.logger.log(str(self.best_cand))
        #torch.save(self.best_cand, self.save_dir+'/best_arch.pth')
        #torch.save(self.perform_dict, self.save_dir+'/perform_dict.pth')
        return bests_per_epoch, self.perform_dict, perform_trace
Example #15
def train_normal(num_epochs,
                 model,
                 dset_train,
                 batch_size,
                 grad_clip,
                 logging_freq,
                 optim="sgd",
                 **kwargs):
    train_loader = torch.utils.data.DataLoader(dset_train,
                                               batch_size=batch_size,
                                               shuffle=True)
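    # NOTE: w_optimizer, criterion, wandb and the hessian helper are assumed to
    # be module-level globals, as in the other snippets from this codebase.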

    model.train()
    for epoch in range(num_epochs):

        epoch_loss = AverageMeter()
        for batch_idx, batch in enumerate(train_loader):
            x, y = batch
            w_optimizer.zero_grad()

            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward(retain_graph=True)

            epoch_loss.update(loss.item())
            if optim == "newton":
                linear_weight = list(model.weight_params())[0]
                hessian_newton = torch.inverse(
                    hessian(loss * 1, linear_weight,
                            linear_weight).reshape(linear_weight.size()[1],
                                                   linear_weight.size()[1]))
                with torch.no_grad():
                    for w in model.weight_params():
                        w = w.subtract_(torch.matmul(w.grad, hessian_newton))
            elif optim == "sgd":
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)
                w_optimizer.step()
            else:
                raise NotImplementedError

            wandb.log({
                "Train loss": epoch_loss.avg,
                "Epoch": epoch,
                "Batch": batch_idx
            })

            if batch_idx % logging_freq == 0:
                print("Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                    epoch, batch_idx, epoch_loss.avg, model.fc1.alphas.data))
Example #16
def search_valid(xloader, network, criterion, extra_info, print_freq, logger):
    data_time, batch_time, losses, top1, top5 = AverageMeter(), AverageMeter(
    ), AverageMeter(), AverageMeter(), AverageMeter()

    network.eval()
    network.apply(change_key('search_mode', 'search'))
    end = time.time()
    # logger.log('Starting evaluating {:}'.format(epoch_info))
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(xloader):
            # measure data loading time
            data_time.update(time.time() - end)
            # calculate prediction and loss
            targets = targets.cuda(non_blocking=True)

            logits, expected_flop = network(inputs)
            loss = criterion(logits, targets)
            # record
            prec1, prec5 = obtain_accuracy(logits.data,
                                           targets.data,
                                           topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0 or (i + 1) == len(xloader):
                Sstr = '**VALID** ' + time_string(
                ) + ' [{:}][{:03d}/{:03d}]'.format(extra_info, i, len(xloader))
                Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                    batch_time=batch_time, data_time=data_time)
                Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                    loss=losses, top1=top1, top5=top5)
                Istr = 'Size={:}'.format(list(inputs.size()))
                logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)

    logger.log(
        ' **VALID** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'
        .format(top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                loss=losses.avg))

    return losses.avg, top1.avg, top5.avg
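Example #16 calls network.apply(change_key('search_mode', 'search')) to flip a mode flag on every sub-module. change_key is not shown here; a plausible sketch, assuming it simply sets an attribute wherever it exists, would be:

def change_key(key, value):
    # Returns a closure suitable for nn.Module.apply: set `key` to `value`
    # on every sub-module that has that attribute.
    def closure(module):
        if hasattr(module, key):
            setattr(module, key, value)
    return closure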
Example #17
    def test_archi_acc(self, arch):
        if self.train_loader is not None:
            self.model.apply(ResetRunningStats)

            self.model.train()
            for step, (data, target) in enumerate(self.train_loader):
                # print('train step: {} total: {}'.format(step,max_train_iters))
                # data, target = train_dataprovider.next()
                # print('get data',data.shape)
                #data = data.cuda()
                output = self.model.forward(data, arch)  #_with_architect
                del data, target, output

        base_top1, base_top5 = AverageMeter(), AverageMeter()
        self.model.eval()

        one_batch = None
        for step, (data, target) in enumerate(self.val_loader):
            # print('test step: {} total: {}'.format(step,max_test_iters))
            if one_batch is None:
                one_batch = data
            batchsize = data.shape[0]
            # print('get data',data.shape)
            target = target.cuda(non_blocking=True)
            #data, target = data.to(device), target.to(device)

            _, logits = self.model.forward(data, arch)  #_with_architect

            prec1, prec5 = obtain_accuracy(logits.data,
                                           target.data,
                                           topk=(1, 5))
            base_top1.update(prec1.item(), batchsize)
            base_top5.update(prec5.item(), batchsize)

            del data, target, logits, prec1, prec5

        if self.lambda_t > 0.0:
            start_time = time.time()
            len_batch = min(len(one_batch), 50)
            for i in range(len_batch):
                _, _ = self.model.forward(one_batch[i:i + 1, :, :, :], arch)
            end_time = time.time()
            time_per = (end_time - start_time) / len_batch
        else:
            time_per = 0.0

        #print('top1: {:.2f} top5: {:.2f}'.format(base_top1.avg * 100, base_top5.avg * 100))
        return base_top1.avg, base_top5.avg, time_per
Example #18
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        network.module.set_cal_mode('urs')
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string(
            ) + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)

    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #19
def search_func(xloader, network, criterion, scheduler, w_optimizer, epoch_str,
                print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        network.module.random_genotype(True)
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Example #20
def procedure(
    xloader,
    network,
    criterion,
    optimizer,
    metric,
    mode: Text,
    logger_fn: Callable = None,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    if mode.lower() == "train":
        network.train()
    elif mode.lower() == "valid":
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))

    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss

        if mode == "train":
            optimizer.zero_grad()

        outputs = network(inputs)
        targets = targets.to(get_device(outputs))

        if mode == "train":
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        # record
        with torch.no_grad():
            results = metric(outputs, targets)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
    return metric.get_info()
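Example #20 uses get_device(outputs) to move the targets onto the same device as the network outputs. The helper is not defined on this page; a small sketch, assuming outputs is a tensor or a tuple/list of tensors, might be:

def get_device(tensors):
    # Recursively find the device of the first tensor in a (possibly nested)
    # tuple or list of outputs.
    if isinstance(tensors, (list, tuple)):
        return get_device(tensors[0])
    return tensors.device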
Example #21
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, algo, epoch_str, print_freq, logger):
  data_time, batch_time = AverageMeter(), AverageMeter()
  base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
  end = time.time()
  network.train()
  for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
    scheduler.update(None, 1.0 * step / len(xloader))
    base_inputs = base_inputs.cuda(non_blocking=True)
    arch_inputs = arch_inputs.cuda(non_blocking=True)
    base_targets = base_targets.cuda(non_blocking=True)
    arch_targets = arch_targets.cuda(non_blocking=True)
    # measure data loading time
    data_time.update(time.time() - end)
    
    # Update the weights
    network.zero_grad()
    _, logits, _ = network(base_inputs)
    base_loss = criterion(logits, base_targets)
    base_loss.backward()
    w_optimizer.step()
    # record
    base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 5))
    base_losses.update(base_loss.item(),  base_inputs.size(0))
    base_top1.update  (base_prec1.item(), base_inputs.size(0))
    base_top5.update  (base_prec5.item(), base_inputs.size(0))

    # update the architecture-weight
    network.zero_grad()
    _, logits, log_probs = network(arch_inputs)
    arch_prec1, arch_prec5 = obtain_accuracy(logits.data, arch_targets.data, topk=(1, 5))
    if algo == 'tunas':
      with torch.no_grad():
        RL_BASELINE_EMA.update(arch_prec1.item())
        rl_advantage = arch_prec1 - RL_BASELINE_EMA.value
      rl_log_prob = sum(log_probs)
      arch_loss = - rl_advantage * rl_log_prob
    elif algo == 'tas' or algo == 'fbv2':
      arch_loss = criterion(logits, arch_targets)
    else:
      raise ValueError('invalid algorithm name: {:}'.format(algo))
    arch_loss.backward()
    a_optimizer.step()
    # record
    arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
    arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
    arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

    # measure elapsed time
    batch_time.update(time.time() - end)
    end = time.time()

    if step % print_freq == 0 or step + 1 == len(xloader):
      Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
      Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
      Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
      Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
      logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
  return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
Example #22
def valid_func(model, dset_val, criterion, device = 'cuda' if torch.cuda.is_available() else 'cpu', print_results=True):
    model.eval()
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=32)
    val_meter = AverageMeter()
    val_acc_meter = AverageMeter()

    with torch.no_grad():
        for batch in val_loader:
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            y_pred = model(x)

            if isinstance(criterion, torch.nn.CrossEntropyLoss):
                predicted = torch.argmax(y_pred, dim=1)
                correct = torch.sum((predicted == y)).item()
                total = predicted.size()[0]
                val_acc_meter.update(correct/total)
            val_loss = criterion(y_pred, y)
            val_meter.update(val_loss.item())
    if print_results:
        print("Val loss: {}, Val acc: {}".format(val_meter.avg, val_acc_meter.avg if val_acc_meter.avg > 0 else "Not applicable"))
    return val_meter
Example #23
def search_func(xloader, network, criterion, scheduler, w_optimizer, a_optimizer, epoch_str, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    network.train()
    end = time.time()
    for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(xloader):
        # print(111111111111111111111)
        # print(arch_inputs.size())
        # print(arch_targets.size())
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the architecture-weight
        a_optimizer.zero_grad()
        arch_loss, arch_logits = backward_step_unrolled(network, criterion, base_inputs, base_targets, w_optimizer, arch_inputs, arch_targets)
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(arch_logits.data, arch_targets.data, topk=(1, 2))
        arch_losses.update(arch_loss.item(),  arch_inputs.size(0))
        arch_top1.update  (arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update  (arch_prec5.item(), arch_inputs.size(0))

        # update the weights
        w_optimizer.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        torch.nn.utils.clip_grad_norm_(network.parameters(), 5)
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data, base_targets.data, topk=(1, 2))
        base_losses.update(base_loss.item(),  base_inputs.size(0))
        base_top1.update  (base_prec1.item(), base_inputs.size(0))
        base_top5.update  (base_prec5.item(), base_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            # Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.avg:.3f}  Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f}]'.format(loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg
Example #24
def train_bptt(
    num_epochs: int,
    model,
    criterion,
    w_optimizer,
    a_optimizer,
    dset_train,
    dset_val,
    batch_size: int,
    T: int,
    w_checkpoint_freq: int,
    grad_clip: float,
    w_lr: float,
    logging_freq: int,
    grad_inner_loop_order: int,
    grad_outer_loop_order:int,
    hvp: str,
    arch_train_data:str,
    normalize_a_lr:bool,
    log_grad_norm:bool,
    log_alphas:bool,
    w_warm_start:int,
    extra_weight_decay:float,
    train_arch:bool,
    device:str
):
    train_loader = torch.utils.data.DataLoader(
        dset_train, batch_size=batch_size * T, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(dset_val, batch_size=batch_size)
    grad_compute_speed = AverageMeter()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    for epoch in range(num_epochs):
        model.train()

        epoch_loss = AverageMeter()
        true_batch_index = 0
        
        val_iter = iter(val_loader)
        for batch_idx, batch in enumerate(train_loader):

            
            xs, ys = torch.split(batch[0], batch_size), torch.split(
                batch[1], batch_size
            )

            weight_buffer = WeightBuffer(T=T, checkpoint_freq=w_checkpoint_freq)
            weight_buffer.add(model, 0)

            for intra_batch_idx, (x, y) in enumerate(zip(xs, ys),1):
                x = x.to(device)
                y = y.to(device)

                # weight_buffer.add(model, intra_batch_idx) # TODO Should it be added here?

                y_pred = model(x)

                param_norm = 0
                if extra_weight_decay is not None and extra_weight_decay != 0:
                    for n,weight in model.named_weight_params():
                        if 'weight' in n:
                            param_norm = param_norm + torch.pow(weight.norm(2), 2)
                    param_norm = torch.multiply(model.alpha_weight_decay, param_norm)
                # print(param_norm)
                
                
                loss = criterion(y_pred, y) + param_norm
                epoch_loss.update(loss.item())

                grads = torch.autograd.grad(
                    loss,
                    model.weight_params()
                )

                with torch.no_grad():
                    for g, w in zip(grads, model.weight_params()):
                        w.grad = g
                torch.nn.utils.clip_grad_norm_(model.weight_params(), 1)

                w_optimizer.step()
                w_optimizer.zero_grad()
                weight_buffer.add(model, intra_batch_idx)

                true_batch_index += 1
                wandb.log(
                    {
                        "Train loss": epoch_loss.avg,
                        "Epoch": epoch,
                        "Batch": true_batch_index,
                    }
                )

                if true_batch_index % logging_freq == 0:
                    print(
                        "Epoch: {}, Batch: {}, Loss: {}, Alphas: {}".format(
                            epoch,
                            true_batch_index,
                            epoch_loss.avg,
                            [x.data for x in model.arch_params()],
                        )
                    )

            if train_arch:
                val_xs = None
                val_ys = None
                if arch_train_data == "val":
                    try:
                        val_batch = next(val_iter)
                        val_xs, val_ys = torch.split(val_batch[0], batch_size), torch.split(
                            val_batch[1], batch_size
                        )

                    except StopIteration:
                        val_iter = iter(val_loader)
                        val_batch = next(val_iter)
                        val_xs, val_ys = torch.split(val_batch[0], batch_size), torch.split(
                            val_batch[1], batch_size
                        )


                if epoch >= w_warm_start:
                    start_time = time.time()
                    total_arch_gradient = sotl_gradient(
                        model=model,
                        criterion=criterion,
                        xs=xs,
                        ys=ys,
                        weight_buffer=weight_buffer,
                        w_lr=w_lr,
                        hvp=hvp,
                        grad_inner_loop_order=grad_inner_loop_order,
                        grad_outer_loop_order=grad_outer_loop_order,
                        T=T,
                        normalize_a_lr=normalize_a_lr,
                        weight_decay_term=None,
                        val_xs=val_xs,
                        val_ys=val_ys
                    )
                    grad_compute_speed.update(time.time() - start_time)


                    if log_grad_norm:
                        norm = 0
                        for g in total_arch_gradient:
                            norm = norm + g.data.norm(2).item()
                        wandb.log({"Arch grad norm": norm})

                    if log_alphas:
                        if hasattr(model, "fc1") and hasattr(model.fc1, "degree"):
                            wandb.log({"Alpha":model.fc1.degree.item()})
                        if hasattr(model,"alpha_weight_decay"):
                            wandb.log({"Alpha": model.alpha_weight_decay.item()})

                    a_optimizer.zero_grad()

                    for g, w in zip(total_arch_gradient, model.arch_params()):
                        w.grad = g
                    torch.nn.utils.clip_grad_norm_(model.arch_params(), 1)
                    a_optimizer.step()

        val_results = valid_func(
            model=model, dset_val=dset_val, criterion=criterion, device=device, print_results=False
        )
        print("Epoch: {}, Val Loss: {}".format(epoch, val_results.avg))
        wandb.log({"Val loss": val_results.avg, "Epoch": epoch})
        wandb.run.summary["Grad compute speed"] = grad_compute_speed.avg

        print(f"Grad compute speed: {grad_compute_speed.avg}s")
Example #25
def main(args):
  assert torch.cuda.is_available(), 'CUDA is not available.'
  torch.backends.cudnn.enabled   = True
  torch.backends.cudnn.benchmark = True
  prepare_seed(args.rand_seed)

  logstr = 'seed-{:}-time-{:}'.format(args.rand_seed, time_for_file())
  logger = Logger(args.save_path, logstr)
  logger.log('Main Function with logger : {:}'.format(logger))
  logger.log('Arguments : -------------------------------')
  for name, value in args._get_kwargs():
    logger.log('{:16} : {:}'.format(name, value))
  logger.log("Python  version : {}".format(sys.version.replace('\n', ' ')))
  logger.log("Pillow  version : {}".format(PIL.__version__))
  logger.log("PyTorch version : {}".format(torch.__version__))
  logger.log("cuDNN   version : {}".format(torch.backends.cudnn.version()))

  # General Data Augmentation
  mean_fill   = tuple( [int(x*255) for x in [0.485, 0.456, 0.406] ] )
  normalize   = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                      std=[0.229, 0.224, 0.225])
  assert args.arg_flip == False, 'The flip is : {}, rotate is {}'.format(args.arg_flip, args.rotate_max)
  train_transform  = [transforms.PreCrop(args.pre_crop_expand)]
  train_transform += [transforms.TrainScale2WH((args.crop_width, args.crop_height))]
  train_transform += [transforms.AugScale(args.scale_prob, args.scale_min, args.scale_max)]
  #if args.arg_flip:
  #  train_transform += [transforms.AugHorizontalFlip()]
  if args.rotate_max:
    train_transform += [transforms.AugRotate(args.rotate_max)]
  train_transform += [transforms.AugCrop(args.crop_width, args.crop_height, args.crop_perturb_max, mean_fill)]
  train_transform += [transforms.ToTensor(), normalize]
  train_transform  = transforms.Compose( train_transform )

  eval_transform  = transforms.Compose([transforms.PreCrop(args.pre_crop_expand), transforms.TrainScale2WH((args.crop_width, args.crop_height)),  transforms.ToTensor(), normalize])
  assert (args.scale_min+args.scale_max) / 2 == args.scale_eval, 'The scale is not ok : {},{} vs {}'.format(args.scale_min, args.scale_max, args.scale_eval)
  
  # Model Configure Load
  model_config = load_configure(args.model_config, logger)
  args.sigma   = args.sigma * args.scale_eval
  logger.log('Real Sigma : {:}'.format(args.sigma))

  # Training Dataset
  train_data   = Dataset(train_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
  train_data.load_list(args.train_lists, args.num_pts, True)
  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True, num_workers=args.workers, pin_memory=True)


  # Evaluation Dataloader
  eval_loaders = []
  if args.eval_vlists is not None:
    for eval_vlist in args.eval_vlists:
      eval_vdata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_vdata.load_list(eval_vlist, args.num_pts, True)
      eval_vloader = torch.utils.data.DataLoader(eval_vdata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_vloader, True))

  if args.eval_ilists is not None:
    for eval_ilist in args.eval_ilists:
      eval_idata = Dataset(eval_transform, args.sigma, model_config.downsample, args.heatmap_type, args.data_indicator)
      eval_idata.load_list(eval_ilist, args.num_pts, True)
      eval_iloader = torch.utils.data.DataLoader(eval_idata, batch_size=args.batch_size, shuffle=False,
                                                 num_workers=args.workers, pin_memory=True)
      eval_loaders.append((eval_iloader, False))

  # Define network
  logger.log('configure : {:}'.format(model_config))
  net = obtain_model(model_config, args.num_pts + 1)
  assert model_config.downsample == net.downsample, 'downsample is not correct : {} vs {}'.format(model_config.downsample, net.downsample)
  logger.log("=> network :\n {}".format(net))

  logger.log('Training-data : {:}'.format(train_data))
  for i, eval_loader in enumerate(eval_loaders):
    eval_loader, is_video = eval_loader
    logger.log('The [{:2d}/{:2d}]-th testing-data [{:}] = {:}'.format(i, len(eval_loaders), 'video' if is_video else 'image', eval_loader.dataset))
    
  logger.log('arguments : {:}'.format(args))

  opt_config = load_configure(args.opt_config, logger)

  if hasattr(net, 'specify_parameter'):
    net_param_dict = net.specify_parameter(opt_config.LR, opt_config.Decay)
  else:
    net_param_dict = net.parameters()

  optimizer, scheduler, criterion = obtain_optimizer(net_param_dict, opt_config, logger)
  logger.log('criterion : {:}'.format(criterion))
  net, criterion = net.cuda(), criterion.cuda()
  net = torch.nn.DataParallel(net)

  last_info = logger.last_info()
  if last_info.exists():
    logger.log("=> loading checkpoint of the last-info '{:}' start".format(last_info))
    last_info = torch.load(last_info)
    start_epoch = last_info['epoch'] + 1
    checkpoint  = torch.load(last_info['last_checkpoint'])
    assert last_info['epoch'] == checkpoint['epoch'], 'Last-Info is not right {:} vs {:}'.format(last_info, checkpoint['epoch'])
    net.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])
    logger.log("=> load-ok checkpoint '{:}' (epoch {:}) done" .format(logger.last_info(), checkpoint['epoch']))
  else:
    logger.log("=> do not find the last-info file : {:}".format(last_info))
    start_epoch = 0


  if args.eval_once:
    logger.log("=> only evaluate the model once")
    eval_results = eval_all(args, eval_loaders, net, criterion, 'eval-once', logger, opt_config)
    logger.close() ; return


  # Main Training and Evaluation Loop
  start_time = time.time()
  epoch_time = AverageMeter()
  for epoch in range(start_epoch, opt_config.epochs):

    scheduler.step()
    need_time = convert_secs2time(epoch_time.avg * (opt_config.epochs-epoch), True)
    epoch_str = 'epoch-{:03d}-{:03d}'.format(epoch, opt_config.epochs)
    LRs       = scheduler.get_lr()
    logger.log('\n==>>{:s} [{:s}], [{:s}], LR : [{:.5f} ~ {:.5f}], Config : {:}'.format(time_string(), epoch_str, need_time, min(LRs), max(LRs), opt_config))

    # train for one epoch
    train_loss, train_nme = train(args, train_loader, net, criterion, optimizer, epoch_str, logger, opt_config)
    # log the results    
    logger.log('==>>{:s} Train [{:}] Average Loss = {:.6f}, NME = {:.2f}'.format(time_string(), epoch_str, train_loss, train_nme*100))

    # remember best prec@1 and save checkpoint
    save_path = save_checkpoint({
          'epoch': epoch,
          'args' : deepcopy(args),
          'arch' : model_config.arch,
          'state_dict': net.state_dict(),
          'scheduler' : scheduler.state_dict(),
          'optimizer' : optimizer.state_dict(),
          }, logger.path('model') / '{:}-{:}.pth'.format(model_config.arch, epoch_str), logger)

    last_info = save_checkpoint({
          'epoch': epoch,
          'last_checkpoint': save_path,
          }, logger.last_info(), logger)

    eval_results = eval_all(args, eval_loaders, net, criterion, epoch_str, logger, opt_config)
    
    # measure elapsed time
    epoch_time.update(time.time() - start_time)
    start_time = time.time()

  logger.close()
Example #26
def search_func(
    xloader,
    network,
    criterion,
    scheduler,
    w_optimizer,
    a_optimizer,
    epoch_str,
    print_freq,
    logger,
):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(
    ), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        sampled_arch = network.module.dync_genotype(True)
        network.module.set_cal_mode("dynamic", sampled_arch)
        # network.module.set_cal_mode( 'urs' )
        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        network.module.set_cal_mode("joint")
        network.zero_grad()
        _, logits = network(arch_inputs)
        arch_loss = criterion(logits, arch_targets)
        arch_loss.backward()
        a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = (
                "*SEARCH* " + time_string() +
                " [{:}][{:03d}/{:03d}]".format(epoch_str, step, len(xloader)))
            Tstr = "Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})".format(
                batch_time=batch_time, data_time=data_time)
            Wstr = "Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = "Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]".format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + " " + Tstr + " " + Wstr + " " + Astr)
            # print (nn.functional.softmax(network.module.arch_parameters, dim=-1))
            # print (network.module.arch_parameters)
    return (
        base_losses.avg,
        base_top1.avg,
        base_top5.avg,
        arch_losses.avg,
        arch_top1.avg,
        arch_top5.avg,
    )
Example #27
def train_controller(xloader, network, criterion, optimizer, prev_baseline,
                     epoch_str, print_freq, logger):
    # config. (containing some necessary arg)
    #   baseline: The baseline score (i.e. average val_acc) from the previous epoch
    data_time, batch_time = AverageMeter(), AverageMeter()
    GradnormMeter, LossMeter, ValAccMeter = AverageMeter(), AverageMeter(), AverageMeter()
    EntropyMeter, BaselineMeter, RewardMeter = AverageMeter(), AverageMeter(), AverageMeter()
    xend = time.time()

    controller_num_aggregate = 20
    controller_train_steps = 50
    controller_bl_dec = 0.99
    controller_entropy_weight = 0.0001

    network.eval()
    network.controller.train()
    network.controller.zero_grad()
    loader_iter = iter(xloader)
    for step in range(controller_train_steps * controller_num_aggregate):
        try:
            inputs, targets = next(loader_iter)
        except StopIteration:  # the loader is exhausted; restart it
            loader_iter = iter(xloader)
            inputs, targets = next(loader_iter)
        inputs = inputs.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - xend)

        log_prob, entropy, sampled_arch = network.controller()
        with torch.no_grad():
            network.set_cal_mode('dynamic', sampled_arch)
            _, logits = network(inputs)
            val_top1, val_top5 = obtain_accuracy(logits.data,
                                                 targets.data,
                                                 topk=(1, 5))
            val_top1 = val_top1.view(-1) / 100
        reward = val_top1 + controller_entropy_weight * entropy
        if prev_baseline is None:
            baseline = val_top1
        else:
            baseline = prev_baseline - (1 - controller_bl_dec) * (
                prev_baseline - reward)

        loss = -1 * log_prob * (reward - baseline)

        # account
        RewardMeter.update(reward.item())
        BaselineMeter.update(baseline.item())
        ValAccMeter.update(val_top1.item() * 100)
        LossMeter.update(loss.item())
        EntropyMeter.update(entropy.item())

        # Average gradient over controller_num_aggregate samples
        loss = loss / controller_num_aggregate
        loss.backward(retain_graph=True)

        # measure elapsed time
        batch_time.update(time.time() - xend)
        xend = time.time()
        if (step + 1) % controller_num_aggregate == 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(
                network.controller.parameters(), 5.0)
            GradnormMeter.update(grad_norm)
            optimizer.step()
            network.controller.zero_grad()

        if step % print_freq == 0:
            Sstr = '*Train-Controller* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                epoch_str, step, controller_train_steps * controller_num_aggregate)
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = '[Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Reward {reward.val:.2f} ({reward.avg:.2f})] Baseline {basel.val:.2f} ({basel.avg:.2f})'.format(
                loss=LossMeter,
                top1=ValAccMeter,
                reward=RewardMeter,
                basel=BaselineMeter)
            Estr = 'Entropy={:.4f} ({:.4f})'.format(EntropyMeter.val,
                                                    EntropyMeter.avg)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Estr)

    return LossMeter.avg, ValAccMeter.avg, BaselineMeter.avg, RewardMeter.avg
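
The baseline update above is an exponential moving average of the reward with decay controller_bl_dec. A tiny worked example with illustrative numbers:

# baseline <- prev - (1 - bl_dec) * (prev - reward)
#          == bl_dec * prev + (1 - bl_dec) * reward
bl_dec = 0.99
prev_baseline, reward = 0.50, 0.70
baseline = prev_baseline - (1 - bl_dec) * (prev_baseline - reward)
assert abs(baseline - (bl_dec * prev_baseline + (1 - bl_dec) * reward)) < 1e-12
print(baseline)  # 0.502 -- the baseline drifts slowly toward the current reward

Subtracting this slowly-moving baseline from the reward reduces the variance of the REINFORCE gradient without changing its expectation.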
Example #28
def search_func(xloader, network, criterion, scheduler, w_optimizer,
                a_optimizer, epoch_str, print_freq, algo, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, base_top1, base_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    arch_losses, arch_top1, arch_top5 = AverageMeter(), AverageMeter(), AverageMeter()
    end = time.time()
    network.train()
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(xloader):
        scheduler.update(None, 1.0 * step / len(xloader))
        base_inputs = base_inputs.cuda(non_blocking=True)
        arch_inputs = arch_inputs.cuda(non_blocking=True)
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # Update the weights
        if algo == 'setn':
            sampled_arch = network.dync_genotype(True)
            network.set_cal_mode('dynamic', sampled_arch)
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo == 'enas':
            with torch.no_grad():
                network.controller.eval()
                _, _, sampled_arch = network.controller()
            network.set_cal_mode('dynamic', sampled_arch)
        else:
            raise ValueError('Invalid algo name : {:}'.format(algo))

        network.zero_grad()
        _, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        w_optimizer.step()
        # record
        base_prec1, base_prec5 = obtain_accuracy(logits.data,
                                                 base_targets.data,
                                                 topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        base_top1.update(base_prec1.item(), base_inputs.size(0))
        base_top5.update(base_prec5.item(), base_inputs.size(0))

        # update the architecture-weight
        if algo == 'setn':
            network.set_cal_mode('joint')
        elif algo == 'gdas':
            network.set_cal_mode('gdas', None)
        elif algo.startswith('darts'):
            network.set_cal_mode('joint', None)
        elif algo == 'random':
            network.set_cal_mode('urs', None)
        elif algo != 'enas':
            raise ValueError('Invalid algo name : {:}'.format(algo))
        network.zero_grad()
        if algo == 'darts-v2':
            arch_loss, logits = backward_step_unrolled(
                network, criterion, base_inputs, base_targets, w_optimizer,
                arch_inputs, arch_targets)
            a_optimizer.step()
        elif algo == 'random' or algo == 'enas':
            with torch.no_grad():
                _, logits = network(arch_inputs)
                arch_loss = criterion(logits, arch_targets)
        else:
            _, logits = network(arch_inputs)
            arch_loss = criterion(logits, arch_targets)
            arch_loss.backward()
            a_optimizer.step()
        # record
        arch_prec1, arch_prec5 = obtain_accuracy(logits.data,
                                                 arch_targets.data,
                                                 topk=(1, 5))
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_top1.update(arch_prec1.item(), arch_inputs.size(0))
        arch_top5.update(arch_prec5.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if step % print_freq == 0 or step + 1 == len(xloader):
            Sstr = '*SEARCH* ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(xloader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Wstr = 'Base [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=base_losses, top1=base_top1, top5=base_top5)
            Astr = 'Arch [Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})]'.format(
                loss=arch_losses, top1=arch_top1, top5=arch_top5)
            logger.log(Sstr + ' ' + Tstr + ' ' + Wstr + ' ' + Astr)
    return base_losses.avg, base_top1.avg, base_top5.avg, arch_losses.avg, arch_top1.avg, arch_top5.avg
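
obtain_accuracy is called throughout but defined elsewhere. A minimal sketch, assuming the common top-k accuracy implementation that returns one percentage tensor per requested k:

import torch

def obtain_accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in topk."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk largest logits per sample, shaped [maxk, batch]
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res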
Example #29
def search_train(search_loader, network, criterion, scheduler, base_optimizer,
                 arch_optimizer, optim_config, extra_info, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    base_losses, arch_losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    arch_cls_losses, arch_flop_losses = AverageMeter(), AverageMeter()
    epoch_str = extra_info['epoch-str']
    flop_need, flop_weight = extra_info['FLOP-exp'], extra_info['FLOP-weight']
    flop_tolerant = extra_info['FLOP-tolerant']

    network.train()
    logger.log(
        '[Search] : {:}, FLOP-Require={:.2f} MB, FLOP-WEIGHT={:.2f}'.format(
            epoch_str, flop_need, flop_weight))
    end = time.time()
    network.apply(change_key('search_mode', 'search'))
    for step, (base_inputs, base_targets, arch_inputs,
               arch_targets) in enumerate(search_loader):
        scheduler.update(None, 1.0 * step / len(search_loader))
        # calculate prediction and loss
        base_targets = base_targets.cuda(non_blocking=True)
        arch_targets = arch_targets.cuda(non_blocking=True)
        # measure data loading time
        data_time.update(time.time() - end)

        # update the weights
        base_optimizer.zero_grad()
        logits, expected_flop = network(base_inputs)
        # network.apply( change_key('search_mode', 'basic') )
        # features, logits = network(base_inputs)
        base_loss = criterion(logits, base_targets)
        base_loss.backward()
        base_optimizer.step()
        # record
        prec1, prec5 = obtain_accuracy(logits.data,
                                       base_targets.data,
                                       topk=(1, 5))
        base_losses.update(base_loss.item(), base_inputs.size(0))
        top1.update(prec1.item(), base_inputs.size(0))
        top5.update(prec5.item(), base_inputs.size(0))

        # update the architecture
        arch_optimizer.zero_grad()
        logits, expected_flop = network(arch_inputs)
        flop_cur = network.module.get_flop('genotype', None, None)
        flop_loss, flop_loss_scale = get_flop_loss(expected_flop, flop_cur,
                                                   flop_need, flop_tolerant)
        acls_loss = criterion(logits, arch_targets)
        arch_loss = acls_loss + flop_loss * flop_weight
        arch_loss.backward()
        arch_optimizer.step()

        # record
        arch_losses.update(arch_loss.item(), arch_inputs.size(0))
        arch_flop_losses.update(flop_loss_scale, arch_inputs.size(0))
        arch_cls_losses.update(acls_loss.item(), arch_inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if step % print_freq == 0 or (step + 1) == len(search_loader):
            Sstr = '**TRAIN** ' + time_string() + ' [{:}][{:03d}/{:03d}]'.format(epoch_str, step, len(search_loader))
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Base-Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=base_losses, top1=top1, top5=top5)
            Vstr = 'Acls-loss {aloss.val:.3f} ({aloss.avg:.3f}) FLOP-Loss {floss.val:.3f} ({floss.avg:.3f}) Arch-Loss {loss.val:.3f} ({loss.avg:.3f})'.format(
                aloss=arch_cls_losses,
                floss=arch_flop_losses,
                loss=arch_losses)
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr)
            # Istr = 'Bsz={:} Asz={:}'.format(list(base_inputs.size()), list(arch_inputs.size()))
            # logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Vstr + ' ' + Istr)
            # print(network.module.get_arch_info())
            # print(network.module.width_attentions[0])
            # print(network.module.width_attentions[1])

    logger.log(
        ' **TRAIN** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Base-Loss:{baseloss:.3f}, Arch-Loss={archloss:.3f}'
        .format(top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                baseloss=base_losses.avg,
                archloss=arch_losses.avg))
    return base_losses.avg, arch_losses.avg, top1.avg, top5.avg
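
get_flop_loss is also external to this excerpt; it returns both a loss and a flop_loss_scale used for logging. A hypothetical sketch of a tolerant FLOP penalty in the same spirit (the band logic here is an assumption, not the original implementation, and the logging scale is omitted):

def flop_loss_sketch(expected_flop, flop_need, flop_tolerant):
    # expected_flop : differentiable expected cost predicted by the network
    # flop_need     : target budget; flop_tolerant : acceptable slack around it
    if expected_flop > flop_need + flop_tolerant:
        return expected_flop - flop_need    # too large: push the architecture smaller
    elif expected_flop < flop_need - flop_tolerant:
        return flop_need - expected_flop    # too small: push it larger
    return expected_flop * 0.0              # inside the band: zero loss, keeps the graph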
def basic_train(args, loader, net, criterion, optimizer, epoch_str, logger, opt_config):
  args = deepcopy(args)
  batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  visible_points, losses = AverageMeter(), AverageMeter()
  eval_meta = Eval_Meta()
  cpu = torch.device('cpu')

  # switch to train mode
  net.train()
  criterion.train()

  end = time.time()
  for i, (inputs, target, mask, points, image_index, nopoints, cropped_size) in enumerate(loader):
    # inputs : Batch, Channel, Height, Width

    target = target.cuda(non_blocking=True)

    image_index = image_index.numpy().squeeze(1).tolist()
    batch_size, num_pts = inputs.size(0), args.num_pts
    visible_point_num   = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size
    visible_points.update(visible_point_num, batch_size)
    nopoints    = nopoints.numpy().squeeze(1).tolist()
    annotated_num = batch_size - sum(nopoints)

    # measure data loading time
    mask = mask.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # batch_heatmaps is a list for stage-predictions, each element should be [Batch, C, H, W]
    batch_heatmaps, batch_locs, batch_scos = net(inputs)
    forward_time.update(time.time() - end)

    loss, each_stage_loss_value = compute_stage_loss(criterion, target, batch_heatmaps, mask)

    if opt_config.lossnorm:
      loss, each_stage_loss_value = loss / annotated_num / 2, [x/annotated_num/2 for x in each_stage_loss_value]

    # measure accuracy and record loss
    losses.update(loss.item(), batch_size)

    # compute gradient and do SGD step
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    eval_time.update(time.time() - end)

    np_batch_locs, np_batch_scos = batch_locs.detach().to(cpu).numpy(), batch_scos.detach().to(cpu).numpy()
    cropped_size = cropped_size.numpy()
    # evaluate the training data
    for ibatch, (imgidx, nopoint) in enumerate(zip(image_index, nopoints)):
      if nopoint == 1: continue
      locations, scores = np_batch_locs[ibatch,:-1,:], np.expand_dims(np_batch_scos[ibatch,:-1], -1)
      xpoints = loader.dataset.labels[imgidx].get_points()
      assert cropped_size[ibatch,0] > 0 and cropped_size[ibatch,1] > 0, 'The ibatch={:}, imgidx={:} is not right : {:}'.format(ibatch, imgidx, cropped_size[ibatch])
      scale_h, scale_w = cropped_size[ibatch,0] * 1. / inputs.size(-2) , cropped_size[ibatch,1] * 1. / inputs.size(-1)
      locations[:, 0], locations[:, 1] = locations[:, 0] * scale_w + cropped_size[ibatch,2], locations[:, 1] * scale_h + cropped_size[ibatch,3]
      assert xpoints.shape[1] == num_pts and locations.shape[0] == num_pts and scores.shape[0] == num_pts, 'The number of points is {} vs {} vs {} vs {}'.format(num_pts, xpoints.shape, locations.shape, scores.shape)
      # recover the original resolution
      prediction = np.concatenate((locations, scores), axis=1).transpose(1,0)
      image_path = loader.dataset.datas[imgidx]
      face_size  = loader.dataset.face_sizes[imgidx]
      eval_meta.append(prediction, xpoints, image_path, face_size)

    # measure elapsed time
    batch_time.update(time.time() - end)
    last_time = convert_secs2time(batch_time.avg * (len(loader)-i-1), True)
    end = time.time()

    if i % args.print_freq == 0 or i+1 == len(loader):
      logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] '
                'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                'Loss {loss.val:7.4f} ({loss.avg:7.4f})  '.format(
                    epoch_str, i, len(loader), batch_time=batch_time,
                    data_time=data_time, forward_time=forward_time, loss=losses)
                  + last_time + show_stage_loss(each_stage_loss_value) \
                  + ' In={:} Tar={:}'.format(list(inputs.size()), list(target.size())) \
                  + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg))
  nme, _, _ = eval_meta.compute_mse(logger)
  return losses.avg, nme
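
compute_stage_loss and show_stage_loss come from the surrounding codebase. A hypothetical sketch of a masked multi-stage heatmap loss that matches how basic_train consumes the return values (the exact masking scheme is an assumption):

def compute_stage_loss_sketch(criterion, target, batch_heatmaps, mask):
  # batch_heatmaps : list of per-stage predictions, each [Batch, C, H, W]
  # mask           : 1 for annotated keypoints, 0 for missing annotations
  total, per_stage = 0, []
  for stage_heatmap in batch_heatmaps:
    # only annotated keypoints contribute to the regression loss
    stage_loss = criterion(stage_heatmap * mask, target * mask)
    total = total + stage_loss
    per_stage.append(stage_loss)
  return total, per_stage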
def procedure(xloader, teacher, network, criterion, scheduler, optimizer, mode,
              config, extra_info, print_freq, logger):
    data_time, batch_time = AverageMeter(), AverageMeter()
    losses, top1, top5 = AverageMeter(), AverageMeter(), AverageMeter()
    Ttop1, Ttop5 = AverageMeter(), AverageMeter()
    if mode == 'train':
        network.train()
    elif mode == 'valid':
        network.eval()
    else:
        raise ValueError("The mode is not right : {:}".format(mode))
    teacher.eval()

    logger.log(
        '[{:5s}] config :: auxiliary={:}, KD :: [alpha={:.2f}, temperature={:.2f}]'
        .format(mode, config.auxiliary if hasattr(config, 'auxiliary') else -1,
                config.KD_alpha, config.KD_temperature))
    end = time.time()
    for i, (inputs, targets) in enumerate(xloader):
        if mode == 'train': scheduler.update(None, 1.0 * i / len(xloader))
        # measure data loading time
        data_time.update(time.time() - end)
        # calculate prediction and loss
        targets = targets.cuda(non_blocking=True)

        if mode == 'train': optimizer.zero_grad()

        student_f, logits = network(inputs)
        if isinstance(logits, list):
            assert len(logits) == 2, 'logits must have {:} items instead of {:}'.format(2, len(logits))
            logits, logits_aux = logits
        else:
            logits, logits_aux = logits, None
        with torch.no_grad():
            teacher_f, teacher_logits = teacher(inputs)

        loss = loss_KD_fn(criterion, logits, teacher_logits, student_f,
                          teacher_f, targets, config.KD_alpha,
                          config.KD_temperature)
        if config is not None and hasattr(config, 'auxiliary') \
                and config.auxiliary > 0 and logits_aux is not None:
            loss_aux = criterion(logits_aux, targets)
            loss += config.auxiliary * loss_aux

        if mode == 'train':
            loss.backward()
            optimizer.step()

        # record
        sprec1, sprec5 = obtain_accuracy(logits.data,
                                         targets.data,
                                         topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(sprec1.item(), inputs.size(0))
        top5.update(sprec5.item(), inputs.size(0))
        # teacher
        tprec1, tprec5 = obtain_accuracy(teacher_logits.data,
                                         targets.data,
                                         topk=(1, 5))
        Ttop1.update(tprec1.item(), inputs.size(0))
        Ttop5.update(tprec5.item(), inputs.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0 or (i + 1) == len(xloader):
            Sstr = ' {:5s} '.format(
                mode.upper()) + time_string() + ' [{:}][{:03d}/{:03d}]'.format(
                    extra_info, i, len(xloader))
            if scheduler is not None:
                Sstr += ' {:}'.format(scheduler.get_min_info())
            Tstr = 'Time {batch_time.val:.2f} ({batch_time.avg:.2f}) Data {data_time.val:.2f} ({data_time.avg:.2f})'.format(
                batch_time=batch_time, data_time=data_time)
            Lstr = 'Loss {loss.val:.3f} ({loss.avg:.3f})  Prec@1 {top1.val:.2f} ({top1.avg:.2f}) Prec@5 {top5.val:.2f} ({top5.avg:.2f})'.format(
                loss=losses, top1=top1, top5=top5)
            Lstr += ' Teacher : acc@1={:.2f}, acc@5={:.2f}'.format(
                Ttop1.avg, Ttop5.avg)
            Istr = 'Size={:}'.format(list(inputs.size()))
            logger.log(Sstr + ' ' + Tstr + ' ' + Lstr + ' ' + Istr)

    logger.log(' **{:5s}** accuracy drop :: @1={:.2f}, @5={:.2f}'.format(
        mode.upper(), Ttop1.avg - top1.avg, Ttop5.avg - top5.avg))
    logger.log(
        ' **{mode:5s}** Prec@1 {top1.avg:.2f} Prec@5 {top5.avg:.2f} Error@1 {error1:.2f} Error@5 {error5:.2f} Loss:{loss:.3f}'
        .format(mode=mode.upper(),
                top1=top1,
                top5=top5,
                error1=100 - top1.avg,
                error5=100 - top5.avg,
                loss=losses.avg))
    return losses.avg, top1.avg, top5.avg
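
loss_KD_fn is not defined here. A minimal sketch of the classic soft-target distillation objective it presumably wraps (the student_f / teacher_f feature arguments are ignored in this sketch, which is an assumption):

import torch.nn.functional as F

def kd_loss_sketch(criterion, logits, teacher_logits, targets, alpha, temperature):
    # soften both distributions with the same temperature T
    log_p_student = F.log_softmax(logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # the T^2 factor keeps gradient magnitudes comparable across temperatures
    kd = F.kl_div(log_p_student, p_teacher, reduction='batchmean') * (temperature ** 2)
    ce = criterion(logits, targets)
    return alpha * kd + (1.0 - alpha) * ce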
Example #32
def stm_main_heatmap(args, loader, net, criterion, optimizer, epoch_str,
                     logger, opt_config, stm_config, use_stm, mode):
    assert mode == 'train' or mode == 'test', 'invalid mode : {:}'.format(mode)
    args = copy.deepcopy(args)
    batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    visible_points = AverageMeter()
    DetLosses, TemporalLosses, MultiviewLosses, TotalLosses = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    alk_points, a3d_points = AverageMeter(), AverageMeter()
    annotate_index = loader.dataset.video_L
    eval_meta = Eval_Meta()
    cpu = torch.device('cpu')

    if args.debug:
        save_dir = Path(args.save_path) / 'DEBUG' / ('{:}-'.format(mode) + epoch_str)
    else:
        save_dir = None

    # switch to train mode
    if mode == 'train':
        logger.log('STM-Main-REG : training : {:} .. STM = {:}'.format(
            stm_config, use_stm))
        print_freq = args.print_freq
        net.train()
        criterion.train()
    else:
        logger.log('STM-Main-REG : evaluation mode.')
        print_freq = args.print_freq_eval
        net.eval()
        criterion.eval()

    i_batch_size, v_batch_size, m_batch_size = args.i_batch_size, args.v_batch_size, args.m_batch_size
    iv_size = i_batch_size + v_batch_size
    end = time.time()
    for i, (frames, Fflows, Bflows, targets, masks, normpoints, transthetas,
            MV_Tensors, MV_Thetas, MV_Shapes, MV_KRT, torch_is_3D, torch_is_images,
            image_index, nopoints, shapes, MultiViewPaths) in enumerate(loader):
        # frames : IBatch+VBatch+MBatch, Frame, Channel, Height, Width
        # Fflows : IBatch+VBatch+MBatch, Frame-1, Height, Width, 2
        # Bflows : IBatch+VBatch+MBatch, Frame-1, Height, Width, 2

        # information
        MV_Mask = masks[iv_size:]
        frames, Fflows, Bflows = frames[:iv_size], Fflows[:iv_size], Bflows[:iv_size]
        targets, masks = targets[:iv_size], masks[:iv_size]
        normpoints, transthetas = normpoints[:iv_size], transthetas[:iv_size]
        nopoints, shapes, torch_is_images = nopoints[:iv_size], shapes[:iv_size], torch_is_images[:iv_size]
        MV_Tensors, MV_Thetas, MV_Shapes, MV_KRT, torch_is_3D = \
            MV_Tensors[iv_size:], MV_Thetas[iv_size:], MV_Shapes[iv_size:], MV_KRT[iv_size:], torch_is_3D[iv_size:]
        assert torch.sum(torch_is_images[:i_batch_size]).item() == i_batch_size, \
            'Image Check Fail : {:} vs. {:}'.format(torch_is_images[:i_batch_size], i_batch_size)
        assert v_batch_size == 0 or torch.sum(torch_is_images[i_batch_size:]).item() == 0, \
            'Video Check Fail : {:} vs. {:}'.format(torch_is_images[i_batch_size:], v_batch_size)
        assert torch_is_3D.sum().item() == m_batch_size, \
            'Multiview Check Fail : {:} vs. {:}'.format(torch_is_3D, m_batch_size)
        image_index = image_index.squeeze(1).tolist()
        (batch_size, frame_length, C, H, W) = frames.size()
        num_pts, num_views = args.num_pts, stm_config.max_views
        visible_point_num = float(np.sum(
            masks.numpy()[:, :-1, :, :])) / batch_size
        visible_points.update(visible_point_num, batch_size)

        normpoints = normpoints.permute(0, 2, 1)
        target_heats = targets.cuda(non_blocking=True)
        target_points = normpoints[:, :, :2].contiguous().cuda(non_blocking=True)
        target_scores = normpoints[:, :, 2:].contiguous().cuda(non_blocking=True)
        det_masks = (1 - nopoints).view(batch_size, 1, 1, 1) * masks
        have_det_loss = det_masks.sum().item() > 0
        det_masks = det_masks.cuda(non_blocking=True)
        nopoints = nopoints.squeeze(1).tolist()

        # measure data loading time
        data_time.update(time.time() - end)

        # batch_heatmaps is a list for stage-predictions, each element should be [Batch, Sequence, PTS, H/Down, W/Down]
        batch_heatmaps, batch_locs, batch_scos, batch_past2now, batch_future2now, batch_FBcheck, multiview_heatmaps, multiview_locs = net(
            frames, Fflows, Bflows, MV_Tensors, torch_is_images)
        annot_heatmaps = [x[:, annotate_index] for x in batch_heatmaps]
        forward_time.update(time.time() - end)

        # detection loss
        if have_det_loss:
            det_loss, each_stage_loss_value = compute_stage_loss(
                criterion, target_heats, annot_heatmaps, det_masks)
            DetLosses.update(det_loss.item(), batch_size)
            each_stage_loss_value = show_stage_loss(each_stage_loss_value)
        else:
            det_loss, each_stage_loss_value = 0, 'no-det-loss'

        # temporal loss
        if use_stm[0]:
            video_batch_locs = batch_locs[i_batch_size:, :, :num_pts]
            video_past2now = batch_past2now[i_batch_size:, :, :num_pts]
            video_future2now = batch_future2now[i_batch_size:, :, :num_pts]
            video_FBcheck = batch_FBcheck[i_batch_size:, :num_pts]
            video_mask = masks[i_batch_size:, :num_pts].contiguous().cuda(non_blocking=True)
            video_heatmaps = [x[i_batch_size:, :, :num_pts] for x in batch_heatmaps]
            sbr_loss, available_nums, loss_string = calculate_temporal_loss(
                criterion, video_heatmaps, video_batch_locs, video_past2now,
                video_future2now, video_FBcheck, video_mask, stm_config)
            alk_points.update(
                float(available_nums) / v_batch_size, v_batch_size)
            if available_nums > stm_config.available_sbr_thresh:
                TemporalLosses.update(sbr_loss.item(), v_batch_size)
            else:
                sbr_loss, sbr_loss_string = 0, 'non-sbr-loss'
        else:
            sbr_loss, sbr_loss_string = 0, 'non-sbr-loss'

        # multiview loss
        if use_stm[1]:
            MV_Mask_G = MV_Mask[:, :-1].view(
                m_batch_size, 1, -1, 1).contiguous().cuda(non_blocking=True)
            MV_Thetas_G = MV_Thetas.to(multiview_locs.device)
            MV_Shapes_G = MV_Shapes.to(multiview_locs.device).view(
                m_batch_size, num_views, 1, 2)
            MV_KRT_G = MV_KRT.to(multiview_locs.device)
            mv_norm_locs_trs = torch.cat(
                (multiview_locs[:, :, :num_pts].permute(0, 1, 3, 2),
                 torch.ones(m_batch_size,
                            num_views,
                            1,
                            num_pts,
                            device=multiview_locs.device)),
                dim=2)
            mv_norm_locs_ori = torch.matmul(MV_Thetas_G[:, :, :2],
                                            mv_norm_locs_trs)
            mv_norm_locs_ori = mv_norm_locs_ori.permute(0, 1, 3, 2)
            mv_real_locs_ori = denormalize_L(mv_norm_locs_ori, MV_Shapes_G)
            mv_3D_locs_ori = TriangulateDLT_BatchCam(MV_KRT_G,
                                                     mv_real_locs_ori)
            mv_proj_locs_ori = ProjectKRT_Batch(
                MV_KRT_G, mv_3D_locs_ori.view(m_batch_size, 1, num_pts, 3))
            mv_pnorm_locs_ori = normalize_L(mv_proj_locs_ori, MV_Shapes_G)
            mv_pnorm_locs_trs = convert_theta(mv_pnorm_locs_ori, MV_Thetas_G)
            MV_locs = multiview_locs[:, :, :num_pts].contiguous()
            MV_heatmaps = [x[:, :, :num_pts] for x in multiview_heatmaps]

            if args.debug:
                with torch.no_grad():
                    for ims in range(m_batch_size):
                        x_index = image_index[iv_size + ims]
                        x_paths = [
                            xlist[iv_size + ims] for xlist in MultiViewPaths
                        ]
                        x_mv_locs, p_mv_locs = mv_real_locs_ori[ims], mv_proj_locs_ori[ims]
                        multiview_debug_save(save_dir, '{:}'.format(x_index),
                                             x_paths,
                                             x_mv_locs.cpu().numpy(),
                                             p_mv_locs.cpu().numpy())
                        y_mv_locs = denormalize_points_batch((H, W),
                                                             MV_locs[ims])
                        q_mv_locs = denormalize_points_batch(
                            (H, W), mv_pnorm_locs_trs[ims])
                        temp_tensors = MV_Tensors[ims]
                        temp_images = [
                            args.tensor2imageF(x) for x in temp_tensors
                        ]
                        temp_names = [Path(x).name for x in x_paths]
                        multiview_debug_save_v2(save_dir,
                                                '{:}'.format(x_index),
                                                temp_names, temp_images,
                                                y_mv_locs.cpu().numpy(),
                                                q_mv_locs.cpu().numpy())

            stm_loss, available_nums = calculate_multiview_loss(
                criterion, MV_heatmaps, MV_locs, mv_pnorm_locs_trs, MV_Mask_G,
                stm_config)
            a3d_points.update(
                float(available_nums) / m_batch_size, m_batch_size)
            if available_nums > stm_config.available_stm_thresh:
                MultiviewLosses.update(stm_loss.item(), m_batch_size)
            else:
                stm_loss = 0
        else:
            stm_loss = 0

        # measure accuracy and record loss
        if use_stm[0]:
            total_loss = det_loss + sbr_loss * stm_config.sbr_weights + stm_loss * stm_config.stm_weights
        else:
            total_loss = det_loss + stm_loss * stm_config.stm_weights
        if isinstance(total_loss, numbers.Number):
            warnings.warn(
                'The {:}-th iteration has no detection loss and no lk loss'.
                format(i))
        else:
            TotalLosses.update(total_loss.item(), batch_size)
            # compute gradient and do SGD step
            if mode == 'train':  # training mode
                optimizer.zero_grad()
                total_loss.backward()
                optimizer.step()

        eval_time.update(time.time() - end)

        with torch.no_grad():
            batch_locs = batch_locs.detach().to(cpu)[:, annotate_index, :num_pts]
            batch_scos = batch_scos.detach().to(cpu)[:, annotate_index, :num_pts]
            # evaluate the training data
            for ibatch in range(iv_size):
                imgidx, nopoint = image_index[ibatch], nopoints[ibatch]
                if nopoint == 1: continue
                norm_locs = torch.cat(
                    (batch_locs[ibatch].permute(1, 0), torch.ones(1, num_pts)),
                    dim=0)
                transtheta = transthetas[ibatch][:2, :]
                norm_locs = torch.mm(transtheta, norm_locs)
                real_locs = denormalize_points(shapes[ibatch].tolist(),
                                               norm_locs)
                real_locs = torch.cat(
                    (real_locs, batch_scos[ibatch].view(1, num_pts)), dim=0)

                image_path = loader.dataset.datas[imgidx][annotate_index]
                normDistce = loader.dataset.NormDistances[imgidx]
                xpoints = loader.dataset.labels[imgidx].get_points()
                eval_meta.append(real_locs.numpy(), xpoints.numpy(),
                                 image_path, normDistce)
                if save_dir:
                    pro_debug_save(save_dir, Path(image_path).name,
                                   frames[ibatch, annotate_index], targets[ibatch],
                                   normpoints[ibatch], transthetas[ibatch],
                                   batch_heatmaps[-1][ibatch, annotate_index],
                                   args.tensor2imageF)

        # measure elapsed time
        batch_time.update(time.time() - end)
        last_time = convert_secs2time(batch_time.avg * (len(loader) - i - 1),
                                      True)
        end = time.time()

        if i % print_freq == 0 or i + 1 == len(loader):
            logger.log(' -->>[{:}]: [{:}][{:03d}/{:03d}] '
                      'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                      'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                      'F-time {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                      'Det {dloss.val:7.4f} ({dloss.avg:7.4f}) '
                      'SBR {sloss.val:7.6f} ({sloss.avg:7.6f}) '
                      'STM {mloss.val:7.6f} ({mloss.avg:7.6f}) '
                      'Loss {loss.val:7.4f} ({loss.avg:7.4f})  '.format(
                          mode, epoch_str, i, len(loader), batch_time=batch_time,
                          data_time=data_time, forward_time=forward_time, \
                          dloss=DetLosses, sloss=TemporalLosses, mloss=MultiviewLosses, loss=TotalLosses)
                        + last_time + each_stage_loss_value \
                        + ' I={:}'.format(list(frames.size())) \
                        + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg) \
                        + ' Ava-PTS : {:.1f} ({:.1f})'.format(alk_points.val, alk_points.avg) \
                        + ' A3D-PTS : {:.1f} ({:.1f})'.format(a3d_points.val, a3d_points.avg) )
            if args.debug:
                logger.log('  -->>Indexes : {:}'.format(image_index))
    nme, _, _ = eval_meta.compute_mse(loader.dataset.dataset_name, logger)
    return TotalLosses.avg, nme
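
TriangulateDLT_BatchCam and ProjectKRT_Batch are helpers from the surrounding codebase. As a self-contained illustration of the triangulation step itself, here is a plain two-view DLT in NumPy (the batching and KRT handling of the real helpers are omitted):

import numpy as np

def triangulate_dlt(P_list, uv_list):
    """Recover a 3D point from its 2D projections under 3x4 camera matrices."""
    A = []
    for P, (u, v) in zip(P_list, uv_list):
        # each view contributes two linear constraints on the homogeneous point
        A.append(u * P[2] - P[0])
        A.append(v * P[2] - P[1])
    _, _, Vt = np.linalg.svd(np.asarray(A))
    X = Vt[-1]           # null-space vector of the stacked constraints
    return X[:3] / X[3]  # de-homogenize

# sanity check: project a known point with two cameras and recover it
X_true = np.array([0.2, -0.1, 3.0, 1.0])
P1 = np.hstack([np.eye(3), np.zeros((3, 1))])
P2 = np.hstack([np.eye(3), np.array([[-0.5], [0.0], [0.0]])])  # translated camera
uvs = [((P @ X_true)[0] / (P @ X_true)[2], (P @ X_true)[1] / (P @ X_true)[2]) for P in (P1, P2)]
assert np.allclose(triangulate_dlt([P1, P2], uvs), X_true[:3])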
def lk_train(args, loader, net, criterion, optimizer, epoch_str, logger, opt_config, lk_config, use_lk):
  args = deepcopy(args)
  batch_time, data_time, forward_time, eval_time = AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
  visible_points, detlosses, lklosses = AverageMeter(), AverageMeter(), AverageMeter()
  alk_points, losses = AverageMeter(), AverageMeter()
  cpu = torch.device('cpu')
  
  annotate_index = loader.dataset.center_idx

  # switch to train mode
  net.train()
  criterion.train()

  end = time.time()
  for i, (inputs, target, mask, points, image_index, nopoints, video_or_not, cropped_size) in enumerate(loader):
    # inputs : Batch, Sequence, Channel, Height, Width

    target = target.cuda(non_blocking=True)

    image_index = image_index.numpy().squeeze(1).tolist()
    batch_size, sequence, num_pts = inputs.size(0), inputs.size(1), args.num_pts
    mask_np = mask.numpy().squeeze(-1).squeeze(-1)
    visible_point_num   = float(np.sum(mask.numpy()[:,:-1,:,:])) / batch_size
    visible_points.update(visible_point_num, batch_size)
    nopoints    = nopoints.numpy().squeeze(1).tolist()
    video_or_not= video_or_not.numpy().squeeze(1).tolist()
    annotated_num = batch_size - sum(nopoints)

    # measure data loading time
    mask = mask.cuda(non_blocking=True)
    data_time.update(time.time() - end)

    # batch_heatmaps is a list for stage-predictions, each element should be [Batch, Sequence, PTS, H/Down, W/Down]
    batch_heatmaps, batch_locs, batch_scos, batch_next, batch_fback, batch_back = net(inputs)
    annot_heatmaps = [x[:, annotate_index] for x in batch_heatmaps]
    forward_time.update(time.time() - end)

    if annotated_num > 0:
      # have the detection loss
      detloss, each_stage_loss_value = compute_stage_loss(criterion, target, annot_heatmaps, mask)
      if opt_config.lossnorm:
        detloss, each_stage_loss_value = detloss / annotated_num / 2, [x/annotated_num/2 for x in each_stage_loss_value]
      # measure accuracy and record loss
      detlosses.update(detloss.item(), batch_size)
      each_stage_loss_value = show_stage_loss(each_stage_loss_value)
    else:
      detloss, each_stage_loss_value = 0, 'no-det-loss'

    if use_lk:
      lkloss, available = lk_target_loss(batch_locs, batch_scos, batch_next, batch_fback, batch_back, lk_config, video_or_not, mask_np, nopoints)
      if lkloss is not None:
        lklosses.update(lkloss.item(), available)
      else: lkloss = 0
      alk_points.update(float(available)/batch_size, batch_size)
    else: lkloss = 0
     
    loss = detloss + lkloss * lk_config.weight

    if isinstance(loss, numbers.Number):
      warnings.warn('The {:}-th iteration has no detection loss and no lk loss'.format(i))
    else:
      losses.update(loss.item(), batch_size)
      # compute gradient and do SGD step
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    eval_time.update(time.time() - end)

    # measure elapsed time
    batch_time.update(time.time() - end)
    last_time = convert_secs2time(batch_time.avg * (len(loader)-i-1), True)
    end = time.time()

    if i % args.print_freq == 0 or i+1 == len(loader):
      logger.log(' -->>[Train]: [{:}][{:03d}/{:03d}] '
                'Time {batch_time.val:4.2f} ({batch_time.avg:4.2f}) '
                'Data {data_time.val:4.2f} ({data_time.avg:4.2f}) '
                'Forward {forward_time.val:4.2f} ({forward_time.avg:4.2f}) '
                'Loss {loss.val:7.4f} ({loss.avg:7.4f}) [LK={lk.val:7.4f} ({lk.avg:7.4f})] '.format(
                    epoch_str, i, len(loader), batch_time=batch_time,
                    data_time=data_time, forward_time=forward_time, loss=losses, lk=lklosses)
                  + each_stage_loss_value + ' ' + last_time \
                  + ' Vis-PTS : {:2d} ({:.1f})'.format(int(visible_points.val), visible_points.avg) \
                  + ' Ava-PTS : {:.1f} ({:.1f})'.format(alk_points.val, alk_points.avg))

  return losses.avg