def main(args):
    """Train an FCN segmentation model, then plot the loss and save the model.

    Optionally crops the input/ground-truth images to disk first, trains with
    SGD and ``BCEWithLogitsLoss``, and on every epoch where ``epoch % 5 == 0``
    writes a decoded prediction for the first sample of each batch to
    ``args.output_path``.

    Args:
        args: Namespace providing ``save_cropped``, ``crop_images``,
            ``n_classes``, ``phase``, ``batch_size``, ``lr``, ``momentum``,
            ``weight_decay``, ``n_epoch``, ``output_path`` and ``model_path``.
    """
    torch.manual_seed(1)

    # crop input image and ground truth and save on disk
    cropped_input_images_path = os.path.join(args.save_cropped, 'input_images')
    cropped_gt_images_path = os.path.join(args.save_cropped, 'gt_images')

    if args.crop_images:
        crop_and_save(args, cropped_input_images_path, cropped_gt_images_path)

    seg_dataset = SegmentationData(cropped_input_images_path, cropped_gt_images_path,
                                   args.n_classes, args.phase)
    train_loader = DataLoader(seg_dataset, shuffle=True, num_workers=4,
                              batch_size=args.batch_size)

    model = FCN(args.n_classes)
    use_gpu = torch.cuda.is_available()
    # Batches follow the model: GPU when available, otherwise CPU. The
    # original moved batches with .cuda() unconditionally and crashed on
    # CPU-only machines.
    device = torch.device('cuda' if use_gpu else 'cpu')
    if use_gpu:
        gpu_ids = list(range(torch.cuda.device_count()))
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    criterion = nn.BCEWithLogitsLoss()
    losses = []
    for epoch in range(args.n_epoch):
        for i, (image, segment_im) in enumerate(train_loader):
            images = image.float().to(device)
            labels = segment_im.to(device).float()

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Keep the first loss of every epoch for plotting. Detach and move
            # to CPU so the list does not pin the autograd graph or GPU
            # memory; the entry is still a 0-dim tensor for plot().
            if i == 0:
                losses.append(loss.detach().cpu())
            print("epoch{} iteration {} loss: {}".format(epoch, i, loss.item()))

            if epoch % 5 == 0:
                # Per-pixel argmax over the class dimension for the first
                # sample of the batch.
                pred = outputs.detach().max(1)[1][0].cpu().numpy()
                decoded = Image.fromarray(decode_segmap(pred))
                path = os.path.join(args.output_path,
                                    'output_%d_%d.png' % (epoch, i))
                decoded.save(path)

    # plot loss
    plot(losses, args)

    # save model
    # NOTE: this pickles the whole model object (the DataParallel wrapper when
    # a GPU was used), kept as-is so existing loading code keeps working.
    os.makedirs(args.model_path, exist_ok=True)
    model_name = os.path.join(args.model_path, 'fcn.pt')
    torch.save(model, model_name)
    # NOTE(review): this fragment begins inside an if/elif chain on args.clf;
    # the opening `if` header is not visible here, so this first branch is
    # presumably the FCN case -- confirm against the full file.
    print('Initializing FCN...')
    model = FCN(args.input_size, args.output_size)
elif args.clf == 'svm':
    print('Initializing SVM...')
    model = SVM(args.input_size, args.output_size)
elif args.clf == 'resnet18':
    print('Initializing ResNet18...')
    model = resnet.resnet18(num_channels=args.num_channels,
                            num_classes=args.output_size)

# Load the initial weights from init_path (defined outside this fragment),
# then move the model to `device` and wrap it in DataParallel.
model.load_state_dict(torch.load(init_path))
print('Load init: {}'.format(init_path))
model = nn.DataParallel(model.to(device), device_ids=args.device_id)

# Choose the optimizer by substring match on args.paradigm.
# NOTE(review): the assignment rebinds `optim`, the same name used as the
# module on the right-hand side, so later `optim.<...>` lookups resolve
# against the optimizer instance. If neither 'sgd' nor 'adam' matches,
# `optim` is left unchanged.
if 'sgd' in args.paradigm:
    optim = optim.SGD(params=model.parameters(), lr=args.lr)
elif 'adam' in args.paradigm:
    optim = optim.Adam(params=model.parameters(), lr=args.lr)

# The SVM classifier uses the multi-class hinge loss; every other classifier
# uses F.nll_loss.
# NOTE(review): F.nll_loss takes log-probabilities, so the non-SVM models
# presumably end in log_softmax -- confirm.
loss_fn = multiClassHingeLoss() if args.clf == 'svm' else F.nll_loss
loss_type = 'hinge' if args.clf == 'svm' else 'nll'

print('+' * 80)

# Presumably the best evaluation score seen so far -- confirm where it is
# updated.
best = 0

# Empty accumulators; judging by the names, these hold the x-axis values and
# the train/test accuracy and loss histories -- confirm against the code that
# fills them.
x_ax = []
acc_train = []
acc_test = []
l_train = []
l_test = []
# Example #3
# 0
def main():
    """Train an FCN on VOC2012 and checkpoint the best model by validation mIoU.

    Builds the train/val datasets and loaders, trains with SGD using a 10x
    smaller learning rate for the encoder (``encode1`` .. ``encode5``) than
    for the remaining parameters, logs per-epoch metrics to TensorBoard, and
    saves the model's ``state_dict`` to ``./model/best.pth`` whenever the
    validation mIoU improves.

    NOTE(review): the validation pipeline uses ``RandomResizedCrop``, so
    validation metrics are presumably not deterministic between runs.
    """
    # Local import keeps this block self-contained; `os` is only needed to
    # create the checkpoint directory.
    import os

    root = "./data/VOCdevkit/VOC2012"
    batch_size = 4
    num_workers = 4
    num_classes = 21
    lr = 0.0025  # the original author suggested 5e-4 for fine-tuning
    epoches = 100
    checkpoint_path = './model/best.pth'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_transform = A.Compose([
        A.HorizontalFlip(),  # note: mind the order of these transforms
        A.VerticalFlip(),
        A.RandomRotate90(),
        A.RandomResizedCrop(320, 480),
    ])
    val_transform = A.Compose([A.RandomResizedCrop(320, 480)])
    train_set = VOCdataset(root, mode="train", transform=train_transform)
    val_set = VOCdataset(root, mode="val", transform=val_transform)

    train_loader = data.DataLoader(train_set, batch_size=batch_size,
                                   shuffle=True, num_workers=num_workers)
    val_loader = data.DataLoader(val_set, batch_size=batch_size,
                                 shuffle=False, num_workers=num_workers)

    model = FCN(num_classes).to(device)
    # To resume from a checkpoint, load a state_dict here, e.g.
    # model.load_state_dict(torch.load("./model/best.pth")).
    criteria = nn.CrossEntropyLoss()

    # Split parameters into encoder and everything else so the two groups
    # can use different learning rates.
    encoders = (model.encode1, model.encode2, model.encode3,
                model.encode4, model.encode5)
    encode_parameters = [p for enc in encoders for p in enc.parameters()]
    # A set gives O(1) membership tests; a concrete list (rather than a
    # one-shot filter iterator) can be inspected or reused safely.
    encoder_ids = {id(p) for p in encode_parameters}
    decode_parameters = [p for p in model.parameters()
                         if id(p) not in encoder_ids]

    optimizer = optim.SGD([{'params': encode_parameters, 'lr': 0.1 * lr},
                           {'params': decode_parameters, 'lr': lr}],
                          momentum=0.9,
                          weight_decay=2e-3)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.85)

    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    best_miou = 0.0
    metrics = ('loss', 'pixel acc', 'mean acc', 'miou')
    string = "loss: {}, pixel acc: {}, mean acc: {} miou: {}"
    writer = SummaryWriter(comment="-fcn")
    try:
        for epoch in range(1, epoches + 1):
            print("Epoch: ", epoch)
            train_info = train(train_loader, model,
                               criteria, optimizer, device, batch_size)
            val_info = validate(val_loader, model,
                                criteria, device, batch_size)

            print("train", end=' ')
            print(string.format(*(train_info[m] for m in metrics)))
            print("val", end=' ')
            print(string.format(*(val_info[m] for m in metrics)))

            # Logged before scheduler.step(), so this is the learning rate
            # that was actually used during this epoch (encoder group).
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'], epoch)
            for m in metrics:
                writer.add_scalar('train/' + m, train_info[m], epoch)
            for m in metrics:
                writer.add_scalar('val/' + m, val_info[m], epoch)

            if val_info['miou'] > best_miou:
                best_miou = val_info['miou']
                torch.save(model.state_dict(), checkpoint_path)
                print("best model find at {} epoch".format(epoch))

            # Step the scheduler after the epoch's optimizer steps. Calling
            # it first (as the original did) shifts the schedule one epoch
            # early on PyTorch >= 1.1 and triggers a warning.
            scheduler.step()
    finally:
        # Flush and release the TensorBoard event file even on failure.
        writer.close()