Exemplo n.º 1
0
    def _validate(self):
        """
        Evaluate the landmark model on the validation loader.

        ``self.model`` and ``self.auxiliarynet`` are put in eval mode and the
        loop runs under ``torch.no_grad()``. Only the first output of the
        model (the landmarks) is used; the angle labels yielded by the loader
        are ignored.

        Returns
        -------
        float, float
            mean of the per-batch l2 losses, mean of the nme values yielded
            by ``accuracy``
        """
        # test on validation set under eval mode
        self.model.eval()
        self.auxiliarynet.eval()

        losses, nme = [], []
        with torch.no_grad():
            for img, land_gt, _ in self.valid_loader:
                img = img.to(self.device, non_blocking=True)
                landmark_gt = land_gt.to(self.device, non_blocking=True)

                landmark, _ = self.model(img)

                # Bring the prediction into the ground-truth layout. A bare
                # ``squeeze()`` would also drop the batch dimension when the
                # batch holds a single sample, which breaks the
                # ``shape[0]``-based reshape below.
                # NOTE(review): assumes prediction and ground truth hold the
                # same number of elements per sample - confirm against model.
                landmark = landmark.reshape(landmark_gt.shape)

                # compute the l2 loss
                l2_diff = torch.sum((landmark_gt - landmark) ** 2, dim=1)
                loss = torch.mean(l2_diff)
                losses.append(loss.cpu().detach().numpy())

                # compute the accuracy on (batch, n_points, 2) numpy arrays
                landmark = landmark.cpu().detach().numpy()
                landmark = landmark.reshape(landmark.shape[0], -1, 2)
                landmark_gt = landmark_gt.cpu().detach().numpy()
                landmark_gt = landmark_gt.reshape(landmark_gt.shape[0], -1, 2)
                _, nme_i = accuracy(landmark, landmark_gt)
                nme.extend(nme_i)

        # compute each mean once and reuse it for logging and the return value
        mean_loss = np.mean(losses)
        mean_nme = np.mean(nme)
        self.logger.info("===> Evaluate:")
        self.logger.info("Eval set: Average loss: {:.4f} nme: {:.4f}".format(
            mean_loss, mean_nme))
        return mean_loss, mean_nme
Exemplo n.º 2
0
def validate(config, valid_loader, model, criterion, epoch, cur_step):
    """Run one gradient-free evaluation pass over ``valid_loader``.

    The loss and the top-1 / top-5 accuracies of every batch are fed into
    running meters together with the batch size. Progress is logged every
    ``config.log_frequency`` steps and on the last step; the final meter
    averages are written to the module-level ``writer`` at ``cur_step``.

    Returns
    -------
    The running average of the top-1 accuracy.
    """
    top1_meter = AverageMeter("top1")
    top5_meter = AverageMeter("top5")
    loss_meter = AverageMeter("losses")

    step_template = (
        "Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})"
    )

    model.eval()

    with torch.no_grad():
        for step, (inputs, targets) in enumerate(valid_loader):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            batch_size = inputs.size(0)

            outputs = model(inputs)
            batch_loss = criterion(outputs, targets)

            scores = utils.accuracy(outputs, targets, topk=(1, 5))
            loss_meter.update(batch_loss.item(), batch_size)
            top1_meter.update(scores["acc1"], batch_size)
            top5_meter.update(scores["acc5"], batch_size)

            # skip the progress line unless this is a logging step
            if (step % config.log_frequency != 0
                    and step != len(valid_loader) - 1):
                continue
            logger.info(step_template.format(
                epoch + 1,
                config.epochs,
                step,
                len(valid_loader) - 1,
                losses=loss_meter,
                top1=top1_meter,
                top5=top5_meter))

    summary = (
        ("loss/test", loss_meter),
        ("acc1/test", top1_meter),
        ("acc5/test", top5_meter),
    )
    for tag, meter in summary:
        writer.add_scalar(tag, meter.avg, global_step=cur_step)

    final_line = "Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1_meter.avg)
    logger.info(final_line)

    return top1_meter.avg
Exemplo n.º 3
0
def train(config, train_loader, model, optimizer, criterion, epoch):
    """Train ``model`` for one epoch over ``train_loader``.

    The model is expected to return ``(logits, aux_logits)``. When
    ``config.aux_weight`` is positive, the criterion on the auxiliary logits
    is added to the loss with that weight. Gradients are clipped to
    ``config.grad_clip`` before each optimizer step. Per-step loss and
    accuracies go to the module-level ``writer``; progress is logged every
    ``config.log_frequency`` steps and on the last step.
    """
    top1_meter = AverageMeter("top1")
    top5_meter = AverageMeter("top5")
    loss_meter = AverageMeter("losses")

    # global step of the first batch of this epoch
    base_step = epoch * len(train_loader)
    current_lr = optimizer.param_groups[0]["lr"]
    logger.info("Epoch %d LR %.6f", epoch, current_lr)
    writer.add_scalar("lr", current_lr, global_step=base_step)

    step_template = (
        "Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
        "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})"
    )

    model.train()

    for step, (inputs, targets) in enumerate(train_loader):
        global_step = base_step + step
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        batch_size = inputs.size(0)

        optimizer.zero_grad()
        outputs, aux_outputs = model(inputs)
        loss = criterion(outputs, targets)
        if config.aux_weight > 0.:
            loss += config.aux_weight * criterion(aux_outputs, targets)
        loss.backward()
        # gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()

        scores = utils.accuracy(outputs, targets, topk=(1, 5))
        loss_value = loss.item()
        loss_meter.update(loss_value, batch_size)
        top1_meter.update(scores["acc1"], batch_size)
        top5_meter.update(scores["acc5"], batch_size)
        writer.add_scalar("loss/train", loss_value, global_step=global_step)
        writer.add_scalar("acc1/train", scores["acc1"],
                          global_step=global_step)
        writer.add_scalar("acc5/train", scores["acc5"],
                          global_step=global_step)

        if step % config.log_frequency == 0 or step == len(train_loader) - 1:
            logger.info(step_template.format(
                epoch + 1,
                config.epochs,
                step,
                len(train_loader) - 1,
                losses=loss_meter,
                top1=top1_meter,
                top5=top5_meter))

    final_line = "Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(
        epoch + 1, config.epochs, top1_meter.avg)
    logger.info(final_line)
Exemplo n.º 4
0
    def train(self, epoch):
        """
        Train ``self.model`` for one epoch over ``self.train_loader``.

        When the model returns a tuple it is unpacked as
        ``(logits, aux_logits)`` and the criterion on the auxiliary logits is
        added to the loss with a fixed weight of 0.4. Gradients are clipped
        to a norm of 5 before every optimizer step.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index, used only for logging.

        Returns
        -------
        The running average of the loss over the epoch.
        """
        f1 = AverageMeter("f1")
        acc = AverageMeter("acc")
        losses = AverageMeter("losses")

        self.model.train()
        cur_lr = self.optimizer.param_groups[0]["lr"]
        logger.info("Epoch %d LR %.6f", epoch, cur_lr)
        for x, y in self.train_loader:
            bs = x.size(0)
            self.optimizer.zero_grad()
            logits = self.model(x)

            if isinstance(logits, tuple):
                logits, aux_logits = logits
                # use the trainer's own criterion, as for the main loss below
                aux_loss = self.criterion(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = accuracy_metrics(logits, y)
            loss = self.criterion(logits, y)
            loss = loss + 0.4 * aux_loss
            loss.backward()
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()
            losses.update(loss.item(), bs)
            acc.update(metrics["acc_score"], bs)
            f1.update(metrics["f1_score"], bs)
        logger.info("Train: [{:3d}/{}] Loss {losses.avg:.3f} "
                    "acc {acc.avg:.2%}, f1 {f1.avg:.2%}".format(epoch + 1,
                                                                100,
                                                                losses=losses,
                                                                acc=acc,
                                                                f1=f1))
        return losses.avg
Exemplo n.º 5
0
    def validate(self, epoch):
        """
        Evaluate ``self.model`` on ``self.test_loader`` without gradients.

        If the model returns a tuple, only its first element is scored.

        Parameters
        ----------
        epoch : int
            Zero-based epoch index, used only for logging.

        Returns
        -------
        tuple
            (running average of the f1 score, running average of the loss)
        """
        f1_meter = AverageMeter("f1")
        acc_meter = AverageMeter("acc")
        loss_meter = AverageMeter("losses")

        self.model.eval()
        with torch.no_grad():
            for inputs, targets in self.test_loader:
                batch_size = inputs.size(0)
                outputs = self.model(inputs)
                if isinstance(outputs, tuple):
                    # the second element (auxiliary output) is not evaluated
                    outputs, _ = outputs
                scores = accuracy_metrics(outputs, targets)
                batch_loss = self.criterion(outputs, targets)
                loss_meter.update(batch_loss.item(), batch_size)
                acc_meter.update(scores["acc_score"], batch_size)
                f1_meter.update(scores["f1_score"], batch_size)

        summary_template = ("Valid: [{:3d}/{}] Loss {losses.avg:.3f} "
                            "acc {acc.avg:.2%}, f1 {f1.avg:.2%}")
        logger.info(summary_template.format(
            epoch + 1, 100, losses=loss_meter, acc=acc_meter, f1=f1_meter))
        return f1_meter.avg, loss_meter.avg
Exemplo n.º 6
0
    def _train(self):
        """
        Train the model: both model weights and architecture weights.

        Model weights are updated on every batch. From the second epoch on
        (``epoch > 0``), architecture weights are updated according to the
        schedule returned by ``self._get_update_schedule``. ``requires_grad``
        is enabled on the architecture weights right before such an update
        and disabled again afterwards, so that they are not updated while
        training model weights.

        Validation runs every ``self.arch_valid_frequency`` epochs, and a
        checkpoint is saved after every epoch.
        """
        nBatch = len(self.train_loader)
        arch_param_num = self.mutator.num_arch_params()
        # NOTE(review): both counters come from num_arch_params(); confirm
        # whether a dedicated binary-gate counter was intended here.
        binary_gates_num = self.mutator.num_arch_params()
        logger.info('#arch_params: %d\t#binary_gates: %d', arch_param_num,
                    binary_gates_num)

        update_schedule = self._get_update_schedule(nBatch)

        for epoch in range(self.train_curr_epoch, self.n_epochs):
            logger.info('\n--------Train epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            for i, (images, labels) in enumerate(self.train_loader):
                data_time.update(time.time() - end)
                lr = self._adjust_learning_rate(self.model_optim,
                                                epoch,
                                                batch=i,
                                                nBatch=nBatch)
                # train weight parameters
                images, labels = images.to(self.device), labels.to(self.device)
                self.mutator.reset_binary_gates()
                self.mutator.unused_modules_off()
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(
                        output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Record a detached copy: feeding the graph-attached loss to
                # the meter would chain every step's autograd history into
                # the running sum.
                losses.update(loss.detach(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                self.mutator.unused_modules_back()
                if epoch > 0:
                    for _ in range(update_schedule.get(i, 0)):
                        start_time = time.time()
                        # GradientArchSearchConfig
                        self.mutator.arch_requires_grad()
                        arch_loss, exp_value = self._gradient_step()
                        self.mutator.arch_disable_grad()
                        used_time = time.time() - start_time
                        logger.info(
                            'Architecture [%d-%d]\t Time %.4f\t Loss %.4f\t null %s',
                            epoch + 1, i, used_time, arch_loss, exp_value)
                batch_time.update(time.time() - end)
                end = time.time()
                # training log
                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data Time {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=lr)
                    logger.info(batch_log)
            # validate
            if (epoch + 1) % self.arch_valid_frequency == 0:
                val_loss, val_top1, val_top5 = self._validate()
                val_log = 'Valid [{0}]\tloss {1:.3f}\ttop-1 acc {2:.3f} \ttop-5 acc {3:.3f}\t' \
                          'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}'. \
                    format(epoch + 1, val_loss, val_top1, val_top5, top1=top1, top5=top5)
                logger.info(val_log)
            self.save_checkpoint()
            self.train_curr_epoch += 1
Exemplo n.º 7
0
    def _warm_up(self):
        """
        Warm up the model; architecture weights are not trained here.

        For every warm-up epoch the learning rate of all parameter groups of
        ``self.model_optim`` follows a half-cosine decay from ``lr_max``
        (0.05) over the total number of warm-up batches. Binary gates are
        re-sampled for each batch. After each epoch the model is validated
        and a checkpoint is saved.
        """
        lr_max = 0.05
        data_loader = self.train_loader
        nBatch = len(data_loader)
        T_total = self.warmup_epochs * nBatch  # total num of batches

        for epoch in range(self.warmup_curr_epoch, self.warmup_epochs):
            logger.info('\n--------Warmup epoch: %d--------\n', epoch + 1)
            batch_time = AverageMeter('batch_time')
            data_time = AverageMeter('data_time')
            losses = AverageMeter('losses')
            top1 = AverageMeter('top1')
            top5 = AverageMeter('top5')
            # switch to train mode
            self.model.train()

            end = time.time()
            logger.info('warm_up epoch: %d', epoch)
            for i, (images, labels) in enumerate(data_loader):
                data_time.update(time.time() - end)
                # cosine-annealed learning rate for the current global batch
                T_cur = epoch * nBatch + i
                warmup_lr = 0.5 * lr_max * (
                    1 + math.cos(math.pi * T_cur / T_total))
                for param_group in self.model_optim.param_groups:
                    param_group['lr'] = warmup_lr
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output
                self.mutator.reset_binary_gates()  # random sample binary gates
                self.mutator.unused_modules_off()  # remove unused module for speedup
                output = self.model(images)
                if self.label_smoothing > 0:
                    loss = cross_entropy_with_label_smoothing(
                        output, labels, self.label_smoothing)
                else:
                    loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                # Record a detached copy: feeding the graph-attached loss to
                # the meter would chain every step's autograd history into
                # the running sum.
                losses.update(loss.detach(), images.size(0))
                top1.update(acc1[0], images.size(0))
                top5.update(acc5[0], images.size(0))
                # compute gradient and do SGD step
                self.model.zero_grad()
                loss.backward()
                self.model_optim.step()
                # unused modules back
                self.mutator.unused_modules_back()
                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if i % 10 == 0 or i + 1 == nBatch:
                    batch_log = 'Warmup Train [{0}][{1}/{2}]\t' \
                                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
                                'Loss {losses.val:.4f} ({losses.avg:.4f})\t' \
                                'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t' \
                                'Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\tlr {lr:.5f}'. \
                        format(epoch + 1, i, nBatch - 1, batch_time=batch_time, data_time=data_time,
                               losses=losses, top1=top1, top5=top5, lr=warmup_lr)
                    logger.info(batch_log)
            val_loss, val_top1, val_top5 = self._validate()
            val_log = 'Warmup Valid [{0}/{1}]\tloss {2:.3f}\ttop-1 acc {3:.3f}\ttop-5 acc {4:.3f}\t' \
                      'Train top-1 {top1.avg:.3f}\ttop-5 {top5.avg:.3f}M'. \
                format(epoch + 1, self.warmup_epochs, val_loss, val_top1, val_top5, top1=top1, top5=top5)
            logger.info(val_log)
            self.save_checkpoint()
            self.warmup_curr_epoch += 1
Exemplo n.º 8
0
    def _validate(self):
        """
        Run validation with the chosen op activated on the mutator.

        The batch sampler of ``self.valid_loader`` is reconfigured in place
        (``batch_size = self.valid_batch_size``, ``drop_last = False``).
        Unused modules are switched off for the pass and switched back on
        afterwards. The model is kept in train mode while evaluating, under
        ``torch.no_grad()``.

        Returns
        -------
        tuple
            (running average loss, running average top1 accuracy,
            running average top5 accuracy)
        """
        sampler = self.valid_loader.batch_sampler
        sampler.batch_size = self.valid_batch_size
        sampler.drop_last = False

        self.mutator.set_chosen_op_active()
        # remove unused modules to save memory
        self.mutator.unused_modules_off()
        # test on validation set under train mode
        self.model.train()

        time_meter = AverageMeter('batch_time')
        loss_meter = AverageMeter('losses')
        top1_meter = AverageMeter('top1')
        top5_meter = AverageMeter('top5')

        head_template = (': [{0}/{1}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})')
        top5_template = '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'

        tick = time.time()
        with torch.no_grad():
            for idx, (images, labels) in enumerate(self.valid_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                output = self.model(images)
                batch_loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                n_samples = images.size(0)
                loss_meter.update(batch_loss, n_samples)
                top1_meter.update(acc1[0], images.size(0))
                top5_meter.update(acc5[0], images.size(0))
                # measure elapsed time
                time_meter.update(time.time() - tick)
                tick = time.time()

                # only every 10th batch and the last batch are logged
                if idx % 10 != 0 and idx + 1 != len(self.valid_loader):
                    continue
                message = 'Valid' + head_template.format(
                    idx, len(self.valid_loader) - 1,
                    batch_time=time_meter, loss=loss_meter, top1=top1_meter)
                message += top5_template.format(top5=top5_meter)
                logger.info(message)
        self.mutator.unused_modules_back()
        return loss_meter.avg, top1_meter.avg, top5_meter.avg
Exemplo n.º 9
0
    def _train_epoch(self, epoch, optimizer, arch_train=False):
        """
        Train for one epoch with the given optimizer.

        Batches are drawn from ``self.valid_loader`` when ``arch_train`` is
        true and from ``self.train_loader`` otherwise. The criterion loss is
        combined with the regularisation term computed from the performance
        cost: multiplied when ``self.mode`` starts with "mul", added when it
        starts with "add", and left unchanged for any other mode. A progress
        line is logged every 10 batches.
        """
        time_meter = AverageMeter("batch_time")
        load_meter = AverageMeter("data_time")
        loss_meter = AverageMeter("losses")
        top1_meter = AverageMeter("top1")
        top5_meter = AverageMeter("top5")

        log_template = (
            "Warmup Train [{0}][{1}]\t"
            "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
            "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
            "Loss {losses.val:.4f} ({losses.avg:.4f})\t"
            "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
            "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})\t"
        )

        # switch to train mode
        self.model.train()

        if arch_train:
            data_loader = self.valid_loader
        else:
            data_loader = self.train_loader

        tick = time.time()
        for idx, (images, labels) in enumerate(data_loader):
            load_meter.update(time.time() - tick)
            images = images.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            output = self.model(images)
            loss = self.criterion(output, labels)

            # hardware-aware loss
            regu_loss = self.reg_loss(self._get_perf_cost(requires_grad=True))
            if self.mode.startswith("mul"):
                loss = loss * regu_loss
            elif self.mode.startswith("add"):
                loss = loss + regu_loss

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            loss_meter.update(loss.item(), images.size(0))
            top1_meter.update(acc1[0].item(), images.size(0))
            top5_meter.update(acc5[0].item(), images.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            time_meter.update(time.time() - tick)
            tick = time.time()

            if idx % 10 != 0:
                continue
            self.logger.info(log_template.format(
                epoch + 1,
                idx,
                batch_time=time_meter,
                data_time=load_meter,
                losses=loss_meter,
                top1=top1_meter,
                top5=top5_meter,
            ))
Exemplo n.º 10
0
    def _validate(self):
        """
        Evaluate the model on the validation loader in eval mode.

        ``drop_last`` is switched off on the loader's batch sampler before
        the pass. The loop runs under ``torch.no_grad()`` and logs every 10th
        batch as well as the last one.

        Returns
        -------
        tuple
            (running average loss, running average top1 accuracy,
            running average top5 accuracy)
        """
        self.valid_loader.batch_sampler.drop_last = False

        time_meter = AverageMeter("batch_time")
        loss_meter = AverageMeter("losses")
        top1_meter = AverageMeter("top1")
        top5_meter = AverageMeter("top5")

        log_template = (
            "Valid: [{0}/{1}]\t"
            "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
            "Loss {loss.val:.4f} ({loss.avg:.4f})\t"
            "Top-1 acc {top1.val:.3f} ({top1.avg:.3f})\t"
            "Top-5 acc {top5.val:.3f} ({top5.avg:.3f})"
        )

        # test on validation set under eval mode
        self.model.eval()

        tick = time.time()
        with torch.no_grad():
            for idx, (images, labels) in enumerate(self.valid_loader):
                images = images.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                output = self.model(images)
                batch_loss = self.criterion(output, labels)
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))

                loss_meter.update(batch_loss, images.size(0))
                top1_meter.update(acc1[0], images.size(0))
                top5_meter.update(acc5[0], images.size(0))

                # measure elapsed time
                time_meter.update(time.time() - tick)
                tick = time.time()

                if idx % 10 != 0 and idx + 1 != len(self.valid_loader):
                    continue
                self.logger.info(log_template.format(
                    idx,
                    len(self.valid_loader) - 1,
                    batch_time=time_meter,
                    loss=loss_meter,
                    top1=top1_meter,
                    top5=top5_meter,
                ))

        return loss_meter.avg, top1_meter.avg, top5_meter.avg
Exemplo n.º 11
0
def train_epoch(
    model,
    auxiliarynet,
    criterion,
    train_loader,
    device,
    epoch,
    optimizer,
    logger,
):
    """Train ``model`` and ``auxiliarynet`` for one epoch.

    Each batch yields an image, landmark labels and angle labels. The model
    returns landmarks and features; the features are fed to ``auxiliarynet``
    to predict the angle. The first value returned by ``criterion`` (the
    weighted loss) is back-propagated. A progress line is logged every 10
    batches.
    """
    model.train()
    auxiliarynet.train()

    time_meter = AverageMeter("batch_time")
    load_meter = AverageMeter("data_time")
    loss_meter = AverageMeter("losses")

    log_template = ("Train [{0}][{1}]\t"
                    "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
                    "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
                    "Loss {losses.val:.4f} ({losses.avg:.4f})")

    tick = time.time()
    for idx, (img, landmark_gt, angle_gt) in enumerate(train_loader):
        load_meter.update(time.time() - tick)
        img = img.to(device, non_blocking=True)
        landmark_gt = landmark_gt.to(device, non_blocking=True)
        angle_gt = angle_gt.to(device, non_blocking=True)

        raw_landmarks, features = model(img)
        landmarks = raw_landmarks.squeeze()
        angle = auxiliarynet(features)

        # task loss: only the weighted term is optimised
        loss, _ = criterion(landmark_gt, angle_gt, angle, landmarks)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        time_meter.update(time.time() - tick)
        tick = time.time()

        # record loss
        loss_value = np.squeeze(loss.cpu().detach().numpy())
        loss_meter.update(loss_value, img.size(0))

        if idx % 10 != 0:
            continue
        logger.info(log_template.format(
            epoch + 1,
            idx,
            batch_time=time_meter,
            data_time=load_meter,
            losses=loss_meter,
        ))
Exemplo n.º 12
0
    def train_one_epoch(self,
                        adjust_lr_func,
                        train_log_func,
                        label_smoothing=0.1):
        """
        Train ``self.model`` for one epoch over ``self.train_loader``.

        Parameters
        ----------
        adjust_lr_func : callable
            Called with the batch index before every step; its return value
            is forwarded to ``train_log_func`` as the learning rate.
        train_log_func : callable
            Called as ``train_log_func(i, batch_time, data_time, losses,
            top1, top5, new_lr)`` every 10 batches and on the last batch;
            its return value is printed.
        label_smoothing : float
            When positive, the loss is computed with
            ``cross_entropy_with_label_smoothing``; otherwise
            ``self.criterion`` is used.

        Returns
        -------
        tuple
            the top1 and top5 meters (not their averages)
        """
        batch_time = AverageMeter('batch_time')
        data_time = AverageMeter('data_time')
        losses = AverageMeter('losses')
        top1 = AverageMeter('top1')
        top5 = AverageMeter('top5')
        self.model.train()
        end = time.time()
        for i, (images, labels) in enumerate(self.train_loader):
            data_time.update(time.time() - end)
            new_lr = adjust_lr_func(i)
            images, labels = images.to(self.device), labels.to(self.device)
            output = self.model(images)
            if label_smoothing > 0:
                loss = cross_entropy_with_label_smoothing(
                    output, labels, label_smoothing)
            else:
                loss = self.criterion(output, labels)
            acc1, acc5 = accuracy(output, labels, topk=(1, 5))
            # Record a detached copy: feeding the graph-attached loss to the
            # meter would chain every step's autograd history into the
            # running sum. The value stays a tensor, so the log callback
            # sees the same type as before.
            losses.update(loss.detach(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # compute gradient and do SGD step
            self.model.zero_grad()  # or self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0 or i + 1 == len(self.train_loader):
                batch_log = train_log_func(i, batch_time, data_time, losses,
                                           top1, top5, new_lr)
                print(batch_log)
        return top1, top5
Exemplo n.º 13
0
    def validate(self, is_test=True):
        """
        Evaluate ``self.model`` in eval mode without gradients.

        Parameters
        ----------
        is_test : bool
            Use ``self.test_loader`` and the log prefix 'Test' when true,
            ``self.valid_loader`` and the prefix 'Valid' otherwise.

        Returns
        -------
        tuple
            (running average loss, running average top1 accuracy,
            running average top5 accuracy)
        """
        data_loader = self.test_loader if is_test else self.valid_loader
        prefix = 'Test' if is_test else 'Valid'

        head_template = (': [{0}/{1}]\t'
                         'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                         'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                         'Top-1 acc {top1.val:.3f} ({top1.avg:.3f})')
        top5_template = '\tTop-5 acc {top5.val:.3f} ({top5.avg:.3f})'

        self.model.eval()
        time_meter = AverageMeter('batch_time')
        loss_meter = AverageMeter('losses')
        top1_meter = AverageMeter('top1')
        top5_meter = AverageMeter('top5')

        tick = time.time()
        with torch.no_grad():
            for idx, (images, labels) in enumerate(data_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                # compute output
                output = self.model(images)
                batch_loss = self.criterion(output, labels)
                # measure accuracy and record loss
                acc1, acc5 = accuracy(output, labels, topk=(1, 5))
                loss_meter.update(batch_loss, images.size(0))
                top1_meter.update(acc1[0], images.size(0))
                top5_meter.update(acc5[0], images.size(0))
                # measure elapsed time
                time_meter.update(time.time() - tick)
                tick = time.time()

                # only every 10th batch and the last batch are printed
                if idx % 10 != 0 and idx + 1 != len(data_loader):
                    continue
                message = prefix + head_template.format(
                    idx, len(data_loader) - 1,
                    batch_time=time_meter, loss=loss_meter, top1=top1_meter)
                message += top5_template.format(top5=top5_meter)
                print(message)
        return loss_meter.avg, top1_meter.avg, top5_meter.avg
Exemplo n.º 14
0
    def _train_epoch(self, epoch, optimizer, arch_train=False):
        """
        Train for one epoch with the given optimizer.

        Batches come from ``self.valid_loader`` when ``arch_train`` is true
        and from ``self.train_loader`` otherwise. The criterion returns a
        weighted loss and an l2 loss; the l2 loss is optimised when
        ``arch_train`` is true, the weighted loss otherwise. The chosen loss
        is then combined with the regularisation term derived from the
        performance cost: multiplied when ``self.mode`` starts with "mul",
        added when it starts with "add". A progress line is logged every 10
        batches.
        """
        # switch to train mode
        self.model.train()
        self.auxiliarynet.train()

        time_meter = AverageMeter("batch_time")
        load_meter = AverageMeter("data_time")
        loss_meter = AverageMeter("losses")

        log_template = (
            "Train [{0}][{1}]\t"
            "Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t"
            "Data {data_time.val:.3f} ({data_time.avg:.3f})\t"
            "Loss {losses.val:.4f} ({losses.avg:.4f})"
        )

        if arch_train:
            data_loader = self.valid_loader
        else:
            data_loader = self.train_loader

        tick = time.time()
        for idx, (img, landmark_gt, angle_gt) in enumerate(data_loader):
            load_meter.update(time.time() - tick)
            img = img.to(self.device, non_blocking=True)
            landmark_gt = landmark_gt.to(self.device, non_blocking=True)
            angle_gt = angle_gt.to(self.device, non_blocking=True)

            raw_landmarks, features = self.model(img)
            landmarks = raw_landmarks.squeeze()
            angle = self.auxiliarynet(features)

            # task loss
            weighted_loss, l2_loss = self.criterion(landmark_gt, angle_gt,
                                                    angle, landmarks)
            if arch_train:
                loss = l2_loss
            else:
                loss = weighted_loss

            # hardware-aware loss
            regu_loss = self.reg_loss(self._get_perf_cost(requires_grad=True))
            if self.mode.startswith("mul"):
                loss = loss * regu_loss
            elif self.mode.startswith("add"):
                loss = loss + regu_loss

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            time_meter.update(time.time() - tick)
            tick = time.time()

            # record loss
            loss_value = np.squeeze(loss.cpu().detach().numpy())
            loss_meter.update(loss_value, img.size(0))

            if idx % 10 != 0:
                continue
            self.logger.info(log_template.format(
                epoch + 1,
                idx,
                batch_time=time_meter,
                data_time=load_meter,
                losses=loss_meter,
            ))