Example #1
def test(dataset, epoch):
    dataset_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])

    test_dataset = datasets.TUMOR_IMG(dataset,
                                      train=False,
                                      transform=dataset_transform)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=test_batch_size,
                                              shuffle=True)
    # Load the trained model from a checkpoint.
    checkpoint = None
    checkpoint_file = 'model.pt'
    output_path = "./outputs"
    checkpoint_path = os.path.join(output_path, checkpoint_file)
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
    if checkpoint is not None:
        print("test and load from ckpt...")
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        print("failed to load ckpt...")
    model.eval()
    test_loss = 0
    correct = 0

    for batch_idx, (images, corp_images, target) in enumerate(test_loader):

        batch_size = images.size(0)

        target_indices = target.long().cpu()
        target_one_hot = to_one_hot(batch_size,
                                    target,
                                    length=model.digits.num_units)

        images, corp_images, target = images.float().cuda(), \
                                      corp_images.float().cuda(), \
                                      target_one_hot.cuda()

        output = model(images, corp_images)

        # Accumulate the summed (not averaged) batch loss.
        test_loss += model.loss(images, output, target,
                                size_average=False).sum().item()

        # The length (L2 norm) of each output capsule encodes class confidence.
        v_mag = torch.sqrt((output**2).sum(dim=2, keepdim=True))

        # Predict the class whose capsule vector is longest.
        pred = v_mag.max(1, keepdim=True)[1].cpu()

        correct += pred.eq(target_indices.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    # Log test metrics to TensorBoard.
    writer.add_scalar("Test/acc", correct / len(test_loader.dataset), epoch)
    writer.add_scalar("Test/loss", test_loss, epoch)
    print(
        '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, len(test_loader.dataset),
            100. * correct / len(test_loader.dataset)))
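
Both examples call a to_one_hot helper that is not defined in these snippets. Below is a minimal sketch, assuming the call signature seen above (batch size, integer class labels, and the number of output capsules as the one-hot length); the body is an assumption, not the original code:

def to_one_hot(batch_size, labels, length):
    # Assumed helper: one row per sample, with a 1 at the sample's label index.
    one_hot = torch.zeros(batch_size, length)
    one_hot.scatter_(1, labels.long().view(-1, 1), 1.0)
    return one_hot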
Example #2
def train(dataset, model, optimizer, start_epoch, output_path=None):
    dataset_transform = transforms.Compose([
        transforms.ToTensor(),
        # transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.TUMOR_IMG(dataset,
                                       train=True,
                                       transform=dataset_transform)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=False)
    # Resume training from start_epoch + 1.
    for epoch in range(start_epoch + 1, MAX_EPOCH):
        last_loss = None
        log_interval = 5
        model.train()
        for batch_idx, (images, corp_images,
                        labels) in enumerate(train_loader):
            # print("target1:", labels)
            # images = images.transpose(1,3).transpose(2,3)
            # corp_images = corp_images.transpose(1,3).transpose(2,3)
            # print("images[0]:", images[0])
            origin_labels = labels.long().cuda()
            # images.type(): torch.DoubleTensor
            # images.size(): torch.Size([20, 1, 512, 512])
            # labels.type(): torch.IntTensor
            # labels.size(): torch.Size([20])
            # print("images-max-min:", torch.max(images[0][0][0][100]), torch.min(images))

            target_one_hot = to_one_hot(images.size(0),
                                        labels,
                                        length=model.digits.num_units)

            images, corp_images, labels = images.float().cuda(), \
                                          corp_images.float().cuda(), \
                                          target_one_hot.cuda()

            # print("target2:", labels)

            optimizer.zero_grad()

            output = model(images, corp_images)

            # Global iteration index across epochs, for per-batch TensorBoard logging.
            n_iter = (epoch - 1) * len(train_loader) + batch_idx

            loss = model.loss(images, output, labels)
            _, _, acc = model.acc(output, origin_labels)
            loss.backward()
            last_loss = loss.item()
            optimizer.step()

            # Per-batch logging.
            writer.add_scalar("Train/acc(batch)", acc, n_iter)
            writer.add_scalar("Train/loss(batch)", loss.item(), n_iter)

            # Log a histogram of each parameter; the dotted name is split into
            # layer/attribute so TensorBoard groups them by layer.
            for name, param in model.named_parameters():
                layer, attr = os.path.splitext(name)
                attr = attr[1:]
                writer.add_histogram("{}/{}".format(layer, attr), param, epoch)

            if batch_idx % log_interval == 0:
                print(
                    'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f}'
                    .format(epoch, batch_idx * len(images),
                            len(train_loader.dataset),
                            100. * batch_idx / len(train_loader), loss.item(),
                            acc.item()))

            # Leave the batch loop once the loss drops below the early-stop threshold.
            if last_loss < early_stop_loss:
                break

        # Save a checkpoint at the end of each epoch.
        if not os.path.exists(output_path):
            os.makedirs(output_path)
        # test(dataset, epoch)
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch
        }
        torch.save(checkpoint, os.path.join(output_path, checkpoint_file))

        # Stop training entirely once the early-stop threshold is reached.
        if last_loss is not None and last_loss < early_stop_loss:
            break
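
train() accepts a start_epoch but never loads a checkpoint itself, so restoring state is presumably left to the caller. A minimal resume sketch, assuming the checkpoint layout saved above; the wiring below is an assumption, not code from the source:

start_epoch = 0
checkpoint_path = os.path.join(output_path, checkpoint_file)
if os.path.exists(checkpoint_path):
    # Assumed resume logic: restore model/optimizer state and the saved epoch.
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch']
train(dataset, model, optimizer, start_epoch, output_path=output_path)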
Example #3
import torch
from torchvision import transforms

import datasets  # project-local module that defines TUMOR_IMG

batch_size = 20  # assumed value; the snippet uses batch_size without defining it
test_batch_size = 20

# Stop training if loss goes below this threshold.
early_stop_loss = 0.0001
dataset = "/media/disk/lds/dataset/brain_tumor/512+128/1"

# load the data

# Transforms for the TUMOR dataset (normalization is currently commented out).
dataset_transform = transforms.Compose([
    transforms.ToTensor(),
    # transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.TUMOR_IMG(dataset,
                                   train=True,
                                   transform=dataset_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)

test_dataset = datasets.TUMOR_IMG(dataset,
                                  train=False,
                                  transform=dataset_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=test_batch_size,
                                          shuffle=True)

#
# Create capsule network.
#
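
The snippet stops where the capsule network would be constructed. Below is a hypothetical wiring sketch connecting it to the train() and test() functions above; CapsuleNetwork, the Adam learning rate, and the globals (writer, checkpoint_file, output_path) are assumptions, not shown in the source:

from torch.utils.tensorboard import SummaryWriter

model = CapsuleNetwork().cuda()  # placeholder class; the real constructor is not shown
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # assumed optimizer and rate
writer = SummaryWriter()  # TensorBoard writer used by train() and test()
checkpoint_file = 'model.pt'
output_path = "./outputs"

train(dataset, model, optimizer, start_epoch=0, output_path=output_path)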