Example #1
# imports assumed for this snippet; the project-specific modules (vgg, resnet,
# densenet, Deconv, MyData, validation) and the parsed options object `opt`
# come from the surrounding repository
import gc
import os

import torch
import torch.nn.functional as F
from torch.autograd import Variable


def main():
    # tensorboard writer
    """
    os.system('rm -rf ./runs/*')
    writer = SummaryWriter('./runs/'+datetime.now().strftime('%B%d  %H:%M:%S'))
    if not os.path.exists('./runs'):
        os.mkdir('./runs')
    std = [.229, .224, .225]
    mean = [.485, .456, .406]
    """
    train_dir = opt.train_dir
    val_dir = opt.val_dir
    check_dir = opt.check_dir

    bsize = opt.b
    iter_num = opt.e  # training iterations

    if not os.path.exists(check_dir):
        os.mkdir(check_dir)

    # backbone: opt.q selects a pretrained VGG / ResNet / DenseNet feature extractor
    if opt.q == 'vgg':
        feature = vgg.vgg(pretrained=True)
    elif 'resnet' in opt.q:
        feature = getattr(resnet, opt.q)(pretrained=True)
    elif 'densenet' in opt.q:
        feature = getattr(densenet, opt.q)(pretrained=True)
    else:
        feature = None
    feature.cuda()
    deconv = Deconv(opt.q)
    deconv.cuda()
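    # data loaders: shuffled batches, 4 worker processes, pinned memory for
    # faster host-to-GPU transfer; validation uses half the training batch size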

    train_loader = torch.utils.data.DataLoader(MyData(train_dir,
                                                      transform=True,
                                                      crop=False,
                                                      hflip=False,
                                                      vflip=False),
                                               batch_size=bsize,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(MyData(val_dir,
                                                    transform=True,
                                                    crop=False,
                                                    hflip=False,
                                                    vflip=False),
                                             batch_size=bsize // 2,
                                             shuffle=True,
                                             num_workers=4,
                                             pin_memory=True)
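    # optimizer: SGD with momentum for ResNet backbones (per-group lr and weight
    # decay, see below), otherwise Adam with a smaller lr for the pretrained
    # backbone than for the decoder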
    if 'resnet' in opt.q:
        lr = 5e-3
        lr_decay = 0.9
        optimizer = torch.optim.SGD(
            [
                # decoder biases: double lr, no weight decay
                {'params': [param for name, param in deconv.named_parameters()
                            if name[-4:] == 'bias'],
                 'lr': 2 * lr},
                # decoder weights: base lr with weight decay
                {'params': [param for name, param in deconv.named_parameters()
                            if name[-4:] != 'bias'],
                 'lr': lr,
                 'weight_decay': 1e-4},
                # backbone biases: double lr, no weight decay
                {'params': [param for name, param in feature.named_parameters()
                            if name[-4:] == 'bias'],
                 'lr': 2 * lr},
                # backbone weights: base lr with weight decay
                {'params': [param for name, param in feature.named_parameters()
                            if name[-4:] != 'bias'],
                 'lr': lr,
                 'weight_decay': 1e-4},
            ],
            momentum=0.9,
            nesterov=True)
    else:
        optimizer = torch.optim.Adam([
            {
                'params': feature.parameters(),
                'lr': 1e-4
            },
            {
                'params': deconv.parameters(),
                'lr': 1e-3
            },
        ])
    min_loss = 10000.0  # lowest validation loss seen so far
    for it in range(iter_num):
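        # polynomial learning-rate decay per epoch:
        # lr * (1 - epoch / total_epochs) ** 0.9, with bias groups kept at 2x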
        if 'resnet' in opt.q:
            optimizer.param_groups[0]['lr'] = 2 * lr * (
                1 - float(it) / iter_num)**lr_decay  # bias
            optimizer.param_groups[1]['lr'] = lr * (
                1 - float(it) / iter_num)**lr_decay  # weight
            optimizer.param_groups[2]['lr'] = 2 * lr * (
                1 - float(it) / iter_num)**lr_decay  # bias
            optimizer.param_groups[3]['lr'] = lr * (
                1 - float(it) / iter_num)**lr_decay  # weight
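        # mini-batch loop: forward through backbone and decoder, compute
        # BCE-with-logits loss against the binary mask, backprop, and step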
        for ib, (data, lbl) in enumerate(train_loader):
            inputs = Variable(data).cuda()
            lbl = Variable(lbl.float().unsqueeze(1)).cuda()
            feats = feature(inputs)
            msk = deconv(feats)
            loss = F.binary_cross_entropy_with_logits(msk, lbl)

            deconv.zero_grad()
            feature.zero_grad()

            loss.backward()

            optimizer.step()
            # visualize
            """
            if ib % 100 == 0:
                # visualize
                image = make_image_grid(inputs.data[:4, :3], mean, std)
                writer.add_image('Image', torchvision.utils.make_grid(image), ib)
                msk = F.sigmoid(msk)
                mask1 = msk.data[:4]
                mask1 = mask1.repeat(1, 3, 1, 1)
                writer.add_image('Image2', torchvision.utils.make_grid(mask1), ib)
                mask1 = lbl.data[:4]
                mask1 = mask1.repeat(1, 3, 1, 1)
                writer.add_image('Label', torchvision.utils.make_grid(mask1), ib)
                writer.add_scalar('M_global', loss.item(), ib)
            """
            print('loss: %.4f (epoch: %d, step: %d)' % (loss.item(), it, ib))
            del inputs, msk, lbl, loss, feats
            gc.collect()

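        # end of epoch: validate and checkpoint both networks whenever the
        # validation loss improves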
        sb = validation(feature, deconv, val_loader)
        if sb < min_loss:
            filename = ('%s/deconv.pth' % (check_dir))
            torch.save(deconv.state_dict(), filename)
            filename = ('%s/feature.pth' % (check_dir))
            torch.save(feature.state_dict(), filename)
            print('save: (epoch: %d)' % it)
            min_loss = sb
Example #2
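    # fragment of a training loop; `feature`, `deconv`, `criterion`, `writer`,
    # `make_image_grid`, `mean`, `std`, and the two optimizers are defined elsewhere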
    for ib, (data, _, lbl) in enumerate(train_loader):
        inputs = Variable(data)
        # inputs = Variable(data).cuda()
        # lbl = Variable(lbl.unsqueeze(1)).cuda()
        lbl = Variable(lbl.unsqueeze(1))
        loss = 0

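        # forward pass: take the last three backbone feature maps, reverse their
        # order for the decoder, predict a mask, and upsample it 4x before the loss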
        feats = feature(inputs)
        feats = feats[-3:]
        feats = feats[::-1]
        msk = deconv(feats)
        msk = functional.interpolate(msk, scale_factor=4)
        prior = functional.sigmoid(msk)
        loss += criterion(msk, lbl)

        deconv.zero_grad()
        feature.zero_grad()

        loss.backward()

        optimizer_feature.step()
        optimizer_deconv.step()

        # visualize: log the input images and predicted masks to TensorBoard
        image = make_image_grid(inputs.data[:, :3], mean, std)
        writer.add_image('Image', torchvision.utils.make_grid(image), ib)
        msk = functional.sigmoid(msk)
        mask1 = msk.data  # predicted mask (the segmented output)
        mask1 = mask1.repeat(1, 3, 1, 1)
        acc = math.e**(0 - loss)  # exp(-loss), a score in (0, 1]
        writer.add_image('Image2', torchvision.utils.make_grid(mask1), ib)