Code example #1
def validate(val_loader, model, criterion):
    if is_main_process():
        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.eval()
    end = time.time()
    for i, (input, target) in enumerate(val_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        output = model(input)
        loss = criterion(output, target)

        top1, top5 = cal_accuracy(output, target, topk=(1, 5))
        n = input.size(0)
        if CONFIG.multiprocessing_distributed:
            with torch.no_grad():
                loss, top1, top5 = loss.detach() * n, top1 * n, top5 * n
                count = target.new_tensor([n], dtype=torch.long)
                distributed.all_reduce(loss)
                distributed.all_reduce(top1)
                distributed.all_reduce(top5)
                distributed.all_reduce(count)
                n = count.item()
                loss, top1, top5 = loss / n, top1 / n, top5 / n
        loss_meter.update(loss.item(), n)
        top1_meter.update(top1.item(), n)
        top5_meter.update(top5.item(), n)

        output = output.max(1)[1]
        intersection, union, target = intersection_and_union_gpu(
            output, target, val_loader.dataset.response_shape[0],
            CONFIG.ignore_label)
        if CONFIG.multiprocessing_distributed:
            distributed.all_reduce(intersection)
            distributed.all_reduce(union)
            distributed.all_reduce(target)
        intersection, union, target = (
            intersection.cpu().numpy(),
            union.cpu().numpy(),
            target.cpu().numpy(),
        )
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        batch_time.update(time.time() - end)
        end = time.time()

        if ((i + 1) % CONFIG.print_freq == 0) and is_main_process():
            logger.info(
                f"Test: [{i + 1}/{len(val_loader)}] Data {data_time.val:.3f} ({data_time.avg:.3f}) Batch "
                f"{batch_time.val:.3f} ({batch_time.avg:.3f}) Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) "
                f"Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) Acc@5 "
                f"{top5_meter.val:.3f} ({top5_meter.avg:.3f}).")

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = numpy.mean(iou_class)
    mAcc = numpy.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    if is_main_process():
        logger.info(
            f"Val result: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/"
            f"{top5_meter.avg:.4f}.")
        for i in range(val_loader.dataset.response_shape[0]):
            if target_meter.sum[i] > 0:
                logger.info(
                    f"Class_{i} Result: iou/accuracy {iou_class[i]:.4f}/{accuracy_class[i]:.4f} Count:{target_meter.sum[i]}"
                )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
    return loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg, top5_meter.avg
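
Note: `AverageMeter` and `cal_accuracy` are not part of these excerpts. A minimal sketch of what they are assumed to do (a running-average meter and standard top-k accuracy; the project's own definitions may differ):

import torch


class AverageMeter:
    """Tracks the latest value plus a running sum, count and average."""

    def __init__(self):
        self.val = self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.val = val
        self.sum = self.sum + val * n  # also works when val is a numpy array
        self.count += n
        self.avg = self.sum / self.count


def cal_accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent, one tensor per requested k."""
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [
        correct[:k].reshape(-1).float().sum() * (100.0 / target.size(0))
        for k in topk
    ]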
Code example #2
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    loss_meter = AverageMeter()
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    model.train()
    end = time.time()
    max_iter = CONFIG.epochs * len(train_loader)
    for i, (input, target) in enumerate(train_loader):
        data_time.update(time.time() - end)
        input = input.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        if CONFIG.mixup_alpha:
            eps = CONFIG.label_smoothing if CONFIG.label_smoothing else 0.0
            input, target_a, target_b, lam = mixup_data(
                input, target, CONFIG.mixup_alpha)
            output = model(input)
            loss = mixup_loss(output, target_a, target_b, lam, eps)
        else:
            output = model(input)
            loss = (smooth_loss(output, target, CONFIG.label_smoothing)
                    if CONFIG.label_smoothing else criterion(output, target))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        top1, top5 = cal_accuracy(output, target, topk=(1, 5))
        n = input.size(0)
        if CONFIG.multiprocessing_distributed:
            with torch.no_grad():
                loss, top1, top5 = loss.detach() * n, top1 * n, top5 * n
                count = target.new_tensor([n], dtype=torch.long)
                distributed.all_reduce(loss)
                distributed.all_reduce(top1)
                distributed.all_reduce(top5)
                distributed.all_reduce(count)
                n = count.item()
                loss, top1, top5 = loss / n, top1 / n, top5 / n
        loss_meter.update(loss.item(), n)
        top1_meter.update(top1.item(), n)
        top5_meter.update(top5.item(), n)

        output = output.max(1)[1]
        intersection, union, target = intersection_and_union_gpu(
            output, target, train_loader.dataset.response_shape[0],
            CONFIG.ignore_label)
        if CONFIG.multiprocessing_distributed:
            distributed.all_reduce(intersection)
            distributed.all_reduce(union)
            distributed.all_reduce(target)
        intersection, union, target = (
            intersection.cpu().numpy(),
            union.cpu().numpy(),
            target.cpu().numpy(),
        )
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        accuracy = sum(
            intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        batch_time.update(time.time() - end)
        end = time.time()

        # calculate remain time
        current_iter = epoch * len(train_loader) + i + 1
        remain_iter = max_iter - current_iter
        remain_time = remain_iter * batch_time.avg
        t_m, t_s = divmod(remain_time, 60)
        t_h, t_m = divmod(t_m, 60)
        remain_time = f"{int(t_h):02d}:{int(t_m):02d}:{int(t_s):02d}"

        if ((i + 1) % CONFIG.print_freq == 0) and is_main_process():
            logger.info(
                f"Epoch: [{epoch + 1}/{CONFIG.epochs}][{i + 1}/{len(train_loader)}] Data {data_time.val:.3f} ("
                f"{data_time.avg:.3f}) Batch {batch_time.val:.3f} ({batch_time.avg:.3f}) Remain {remain_time} Loss "
                f"{loss_meter.val:.4f} Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) "
                f"Acc@5 {top5_meter.val:.3f} ({top5_meter.avg:.3f}).")
        if is_main_process():
            writer.scalar("loss_train_batch", loss_meter.val, current_iter)
            writer.scalar(
                "mIoU_train_batch",
                numpy.mean(intersection / (union + 1e-10)),
                current_iter,
            )
            writer.scalar(
                "mAcc_train_batch",
                numpy.mean(intersection / (target + 1e-10)),
                current_iter,
            )
            writer.scalar("allAcc_train_batch", accuracy, current_iter)
            writer.scalar("top1_train_batch", top1, current_iter)
            writer.scalar("top5_train_batch", top5, current_iter)

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = numpy.mean(iou_class)
    mAcc = numpy.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)
    if is_main_process():
        logger.info(
            f"Train result at epoch [{epoch + 1}/{CONFIG.epochs}]: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/"
            f"{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/{top5_meter.avg:.4f}."
        )
    return loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg, top5_meter.avg
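
Note: the mixup branch above calls `mixup_data`, `mixup_loss` and `smooth_loss`, which are not shown here. A hedged sketch following the standard mixup and label-smoothing recipes (the project's versions may differ in detail; `smooth_loss` below simply delegates to PyTorch's built-in `label_smoothing` argument):

import numpy
import torch
import torch.nn.functional as F


def mixup_data(input, target, alpha):
    """Blend each sample with a randomly permuted partner."""
    lam = numpy.random.beta(alpha, alpha)
    index = torch.randperm(input.size(0), device=input.device)
    mixed = lam * input + (1.0 - lam) * input[index]
    return mixed, target, target[index], lam


def mixup_loss(output, target_a, target_b, lam, eps=0.0):
    """Convex combination of the two cross-entropy terms; `eps` is the
    label-smoothing factor (0.0 disables smoothing)."""
    loss_a = F.cross_entropy(output, target_a, label_smoothing=eps)
    loss_b = F.cross_entropy(output, target_b, label_smoothing=eps)
    return lam * loss_a + (1.0 - lam) * loss_b


def smooth_loss(output, target, eps):
    """Cross-entropy with label smoothing."""
    return F.cross_entropy(output, target, label_smoothing=eps)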
Code example #3
    def validate(self, epoch):
        """Evaluate the RAM model on the validation set.
"""
        losses = AverageMeter()
        accs = AverageMeter()

        for i, (x, y) in enumerate(self.valid_loader):
            x, y = x.to(self.device), y.to(self.device)

            # duplicate M times
            x = x.repeat(self.M, 1, 1, 1)

            # initialize location vector and hidden state
            self.batch_size = x.shape[0]
            h_t, l_t = self.reset()

            # extract the glimpses
            log_pi = []
            baselines = []
            for t in range(self.num_glimpses - 1):
                # forward pass through model
                h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                # store
                baselines.append(b_t)
                log_pi.append(p)

            # last iteration
            h_t, l_t, b_t, log_probas, p = self.model(x, l_t, h_t, last=True)
            log_pi.append(p)
            baselines.append(b_t)

            # convert list to tensors and reshape
            baselines = torch.stack(baselines).transpose(1, 0)
            log_pi = torch.stack(log_pi).transpose(1, 0)

            # average
            log_probas = log_probas.view(self.M, -1, log_probas.shape[-1])
            log_probas = torch.mean(log_probas, dim=0)

            baselines = baselines.contiguous().view(self.M, -1,
                                                    baselines.shape[-1])
            baselines = torch.mean(baselines, dim=0)

            log_pi = log_pi.contiguous().view(self.M, -1, log_pi.shape[-1])
            log_pi = torch.mean(log_pi, dim=0)

            # calculate reward
            predicted = torch.max(log_probas, 1)[1]
            R = (predicted.detach() == y).float()
            R = R.unsqueeze(1).repeat(1, self.num_glimpses)

            # compute losses for differentiable modules
            loss_action = F.nll_loss(log_probas, y)
            loss_baseline = F.mse_loss(baselines, R)

            # compute reinforce loss
            adjusted_reward = R - baselines.detach()
            loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
            loss_reinforce = torch.mean(loss_reinforce, dim=0)

            # sum up into a hybrid loss
            loss = loss_action + loss_baseline + loss_reinforce * 0.01

            # compute accuracy
            correct = (predicted == y).float()
            acc = 100 * (correct.sum() / len(y))

            # store
            losses.update(loss.item(), x.size()[0])
            accs.update(acc.item(), x.size()[0])

            # log to tensorboard
            if self.use_tensorboard:
                iteration = epoch * len(self.valid_loader) + i
                log_value("valid_loss", losses.avg, iteration)
                log_value("valid_acc", accs.avg, iteration)

        return losses.avg, accs.avg
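
Note: both RAM loops (this `validate` and `train_one_epoch` in example #5) begin each batch with `self.reset()`. A sketch of what it is assumed to do: zero the recurrent hidden state and sample an initial glimpse location uniformly in [-1, 1] (attribute names and shapes are assumptions, not the project's code):

    def reset(self):
        """Re-initialise the hidden state and glimpse location for a new batch."""
        h_t = torch.zeros(
            self.batch_size,
            self.hidden_size,
            dtype=torch.float,
            device=self.device,
            requires_grad=True,
        )
        l_t = torch.empty(self.batch_size, 2, device=self.device).uniform_(-1, 1)
        l_t.requires_grad = True
        return h_t, l_t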
Code example #4
File: test_san.py  Project: sintefneodroid/vision
    def main(shuffle: bool = True, how_many_batches=10, batch_size=1):
        def get_logger():
            logger_name = "main-logger"
            logger = logging.getLogger(logger_name)
            logger.setLevel(logging.INFO)
            handler = logging.StreamHandler()
            fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s"
            handler.setFormatter(logging.Formatter(fmt))
            logger.addHandler(handler)
            return logger

        from samples.classification.san.configs.imagenet_san10_patchwise import (
            SAN_CONFIG,
        )

        dataset = SAN_CONFIG.dataset_type(SAN_CONFIG.dataset_path,
                                          Split.Validation)

        logger = get_logger()
        logger.info(SAN_CONFIG)
        logger.info("=> creating model ...")
        logger.info(f"Classes: {dataset.response_shape[0]}")
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
            str(x) for x in SAN_CONFIG.test_gpu)
        model = make_san(
            self_attention_type=SelfAttentionTypeEnum(
                SAN_CONFIG.self_attention_type),
            layers=SAN_CONFIG.layers,
            kernels=SAN_CONFIG.kernels,
            num_classes=dataset.response_shape[0],
        )
        logger.info(model)
        model = torch.nn.DataParallel(model.cuda())

        if os.path.isfile(SAN_CONFIG.model_path):
            logger.info(f"=> loading checkpoint '{SAN_CONFIG.model_path}'")
            checkpoint = torch.load(SAN_CONFIG.model_path)
            model.load_state_dict(checkpoint["state_dict"], strict=True)
            logger.info(f"=> loaded checkpoint '{SAN_CONFIG.model_path}'")
        else:
            raise RuntimeError(
                f"=> no checkpoint found at '{SAN_CONFIG.model_path}'")

        criterion = nn.CrossEntropyLoss(ignore_index=SAN_CONFIG.ignore_label)

        val_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=SAN_CONFIG.test_workers,
            pin_memory=True,
        )

        logger.info(">>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>")
        batch_time = AverageMeter()
        data_time = AverageMeter()
        loss_meter = AverageMeter()
        intersection_meter = AverageMeter()
        union_meter = AverageMeter()
        target_meter = AverageMeter()
        top1_meter = AverageMeter()
        top5_meter = AverageMeter()

        model.eval()
        end = time.time()

        if how_many_batches:
            T = range(how_many_batches)  # evaluate only the first few batches
        else:
            T = count()  # itertools.count(): iterate over the entire loader

        for i, (input, target) in zip(T, val_loader):
            data_time.update(time.time() - end)
            input = input.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            with torch.no_grad():
                output = model(input)
                pyplot.imshow(dataset.inverse_base_transform(input[0].cpu()))
                pyplot.title(
                    f"pred:{dataset.category_names[output.max(1)[1][0].item()]} truth:{dataset.category_names[target[0].item()]}"
                )
                pyplot.show()

            loss = criterion(output, target)
            top1, top5 = cal_accuracy(output, target, topk=(1, 5))
            n = input.size(0)
            loss_meter.update(loss.item(), n)
            top1_meter.update(top1.item(), n)
            top5_meter.update(top5.item(), n)

            intersection, union, target = intersection_and_union_gpu(
                output.max(1)[1],
                target,
                dataset.response_shape[0],
                SAN_CONFIG.ignore_label,
            )
            intersection, union, target = (
                intersection.cpu().numpy(),
                union.cpu().numpy(),
                target.cpu().numpy(),
            )
            intersection_meter.update(intersection)
            union_meter.update(union)
            target_meter.update(target)

            accuracy = sum(
                intersection_meter.val) / (sum(target_meter.val) + 1e-10)
            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % SAN_CONFIG.print_freq == 0:
                logger.info(
                    f"Test: [{i + 1}/{len(val_loader)}] Data {data_time.val:.3f} ({data_time.avg:.3f}) Batch "
                    f"{batch_time.val:.3f} ({batch_time.avg:.3f}) Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f}) "
                    f"Accuracy {accuracy:.4f} Acc@1 {top1_meter.val:.3f} ({top1_meter.avg:.3f}) Acc@5 "
                    f"{top5_meter.val:.3f} ({top5_meter.avg:.3f}).")

        iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
        accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
        mIoU = numpy.mean(iou_class)
        mAcc = numpy.mean(accuracy_class)
        allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

        logger.info(
            f"Val result: mIoU/mAcc/allAcc/top1/top5 {mIoU:.4f}/{mAcc:.4f}/{allAcc:.4f}/{top1_meter.avg:.4f}/"
            f"{top5_meter.avg:.4f}.")
        for i in range(dataset.response_shape[0]):
            if target_meter.sum[i] > 0:
                logger.info(
                    f"Class_{i} Result: iou/accuracy {iou_class[i]:.4f}/{accuracy_class[i]:.4f} Count:{target_meter.sum[i]}"
                )
        logger.info("<<<<<<<<<<<<<<<<< End Evaluation <<<<<<<<<<<<<<<<<")
        print(loss_meter.avg, mIoU, mAcc, allAcc, top1_meter.avg,
              top5_meter.avg)
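
Note: examples #1, #2 and #4 compute their per-class statistics with `intersection_and_union_gpu`. A sketch of the usual histogram-based GPU implementation (the project's version may differ):

import torch


def intersection_and_union_gpu(output, target, num_classes, ignore_index):
    """Per-class intersection, union and target counts, computed on the GPU."""
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index  # mask ignored labels
    intersection = output[output == target]
    area_intersection = torch.histc(
        intersection.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_output = torch.histc(
        output.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_target = torch.histc(
        target.float(), bins=num_classes, min=0, max=num_classes - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target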
Code example #5
    def train_one_epoch(self, epoch):
        """
Train the model for 1 epoch of the training set.

An epoch corresponds to one full pass through the entire
training set in successive mini-batches.

This is used by train() and should not be called manually.
"""
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()

        tic = time.time()
        with tqdm(total=self.num_train) as pbar:
            for i, (x, y) in enumerate(self.train_loader):
                self.optimizer.zero_grad()

                x, y = x.to(self.device), y.to(self.device)

                plot = False
                if (epoch % self.plot_freq == 0) and (i == 0):
                    plot = True

                # initialize location vector and hidden state
                self.batch_size = x.shape[0]
                h_t, l_t = self.reset()

                # save images
                imgs = []
                imgs.append(x[0:9])

                # extract the glimpses
                locs = []
                log_pi = []
                baselines = []
                for t in range(self.num_glimpses - 1):
                    # forward pass through model
                    h_t, l_t, b_t, p = self.model(x, l_t, h_t)

                    # store
                    locs.append(l_t[0:9])
                    baselines.append(b_t)
                    log_pi.append(p)

                # last iteration
                h_t, l_t, b_t, log_probas, p = self.model(x,
                                                          l_t,
                                                          h_t,
                                                          last=True)
                log_pi.append(p)
                baselines.append(b_t)
                locs.append(l_t[0:9])

                # convert list to tensors and reshape
                baselines = torch.stack(baselines).transpose(1, 0)
                log_pi = torch.stack(log_pi).transpose(1, 0)

                # calculate reward
                predicted = torch.max(log_probas, 1)[1]
                R = (predicted.detach() == y).float()
                R = R.unsqueeze(1).repeat(1, self.num_glimpses)

                # compute losses for differentiable modules
                loss_action = F.nll_loss(log_probas, y)
                loss_baseline = F.mse_loss(baselines, R)

                # compute reinforce loss
                # summed over timesteps and averaged across batch
                adjusted_reward = R - baselines.detach()
                loss_reinforce = torch.sum(-log_pi * adjusted_reward, dim=1)
                loss_reinforce = torch.mean(loss_reinforce, dim=0)

                # sum up into a hybrid loss
                loss = loss_action + loss_baseline + loss_reinforce * 0.01

                # compute accuracy
                correct = (predicted == y).float()
                acc = 100 * (correct.sum() / len(y))

                # store
                losses.update(loss.item(), x.size()[0])
                accs.update(acc.item(), x.size()[0])

                # compute gradients and update SGD
                loss.backward()
                self.optimizer.step()

                # measure elapsed time
                toc = time.time()
                batch_time.update(toc - tic)

                pbar.set_description((
                    f"{(toc - tic):.1f}s - loss: {loss.item():.3f} - acc: {acc.item():.3f}"
                ))
                pbar.update(self.batch_size)

                # dump the glimpses and locs
                if plot:
                    imgs = [g.cpu().data.numpy().squeeze() for g in imgs]
                    locs = [l.cpu().data.numpy() for l in locs]
                    with open(self.plot_dir / f"g_{epoch + 1}.p", "wb") as f:
                        pickle.dump(imgs, f)
                    with open(self.plot_dir / f"l_{epoch + 1}.p", "wb") as f:
                        pickle.dump(locs, f)

                # log to tensorboard
                if self.use_tensorboard:
                    iteration = epoch * len(self.train_loader) + i
                    log_value("train_loss", losses.avg, iteration)
                    log_value("train_acc", accs.avg, iteration)

            return losses.avg, accs.avg
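
Note: a minimal sketch of how `train_one_epoch` and `validate` are typically wired together by an outer `train` method (`self.epochs` and `self.best_valid_acc` are assumed attributes for illustration, not the project's API):

    def train(self):
        """Drive the epoch loop; a sketch of the usual wiring, not the project's code."""
        for epoch in range(self.epochs):
            train_loss, train_acc = self.train_one_epoch(epoch)
            valid_loss, valid_acc = self.validate(epoch)

            # track the best validation accuracy seen so far
            is_best = valid_acc > self.best_valid_acc
            self.best_valid_acc = max(valid_acc, self.best_valid_acc)
            print(
                f"epoch {epoch + 1}/{self.epochs} - "
                f"train loss {train_loss:.3f} acc {train_acc:.3f} - "
                f"val loss {valid_loss:.3f} acc {valid_acc:.3f}"
                + (" [*]" if is_best else "")
            )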