Exemplo n.º 1
0
def create_transforms(relax_crop, zero_crop):
    """Return ``(train_tf, test_tf)`` transform pipelines for extreme-point inputs.

    Both pipelines:
    - crop 'image'/'gt' with ``tr.CropFromMask`` (``relax=relax_crop``,
      ``zero_pad=zero_crop``);
    - resize 'crop_image' and 'crop_gt' to (512, 512);
    - derive 'extreme_points' from 'crop_gt' with ``tr.ExtremePoints``;
    - run ``tr.ToImage``, concatenate 'crop_image' with 'extreme_points', and
      convert to tensors.

    The training pipeline also prepends ``tr.RandomHorizontalFlip`` and
    ``tr.ScaleNRotate`` and uses ``pert=5`` for the extreme points. The test
    pipeline uses ``pert=0``.

    The crop/resize transforms and the final tensor-conversion transforms are
    the same object instances in both pipelines.
    """
    crop_res = (512, 512)

    # Stages used identically (same instances) by the train and test pipelines.
    crop_stage = [
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=relax_crop,
                        zero_pad=zero_crop),
        tr.FixedResize(resolutions={'crop_image': crop_res,
                                    'crop_gt': crop_res}),
    ]
    tensor_stage = [
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor(),
    ]

    # Training-only augmentation applied before cropping.
    augment_stage = [
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
    ]

    train_steps = (augment_stage + crop_stage
                   + [tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt')]
                   + tensor_stage)
    test_steps = (crop_stage
                  + [tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt')]
                  + tensor_stage)
    return transforms.Compose(train_steps), transforms.Compose(test_steps)
Exemplo n.º 2
0
 def transform_tr(self, sample):
     """Run the training augmentation pipeline on ``sample`` and return the result.

     Steps, in order:
     - ``tr.RandomHorizontalFlip``;
     - ``tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25))``;
     - ``tr.CropFromMask`` on 'image'/'gt' with ``relax=20`` and zero padding;
     - resize 'crop_image'/'crop_gt' to (256, 256);
     - normalize 'crop_image';
     - ``tr.ToTensor``.

     A new pipeline is constructed on every call.
     """
     crop_res = (256, 256)
     steps = [
         tr.RandomHorizontalFlip(),
         tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
         tr.CropFromMask(crop_elems=('image', 'gt'), relax=20, zero_pad=True),
         tr.FixedResize(resolutions={'crop_image': crop_res, 'crop_gt': crop_res}),
         tr.Normalize(elems='crop_image'),
         tr.ToTensor(),
     ]
     pipeline = transforms.Compose(steps)
     return pipeline(sample)
Exemplo n.º 3
0
net.to(device)

# Training the network
# Only (re)train when the resume epoch has not already reached nEpochs.
if resume_epoch != nEpochs:
    # Logging into Tensorboard
    # NOTE(review): log_dir is built, but the SummaryWriter that would use it is
    # commented out, so log_dir is unused in the visible code.
    log_dir = os.path.join(save_dir, 'models', datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    # writer = SummaryWriter(log_dir=log_dir)

    # Use the following optimizer
    optimizer = optim.SGD(train_params, lr=p['lr'], momentum=p['momentum'], weight_decay=p['wd'])
    # Record the optimizer's string description in the parameter dict p.
    p['optimizer'] = str(optimizer)

    # Preparation of the data loaders
    # Training pipeline: random flip + scale/rotate jitter, crop around the mask,
    # resize crops to 512x512, extreme points from 'crop_gt' (pert=5), then
    # ToImage / concatenate with 'crop_image' / tensors.
    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=5, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])
    # Test pipeline: same as training, but without flip/scale-rotate
    # augmentation and with unperturbed extreme points (pert=0).
    composed_transforms_ts = transforms.Compose([
        tr.CropFromMask(crop_elems=('image', 'gt'), relax=relax_crop, zero_pad=zero_pad_crop),
        tr.FixedResize(resolutions={'crop_image': (512, 512), 'crop_gt': (512, 512)}),
        tr.ExtremePoints(sigma=10, pert=0, elem='crop_gt'),
        tr.ToImage(norm_elem='extreme_points'),
        tr.ConcatInputs(elems=('crop_image', 'extreme_points')),
        tr.ToTensor()])

    # Training split of VOC, using the augmented training pipeline.
    voc_train = pascal.VOCSegmentation(split='train', transform=composed_transforms_tr)
Exemplo n.º 4
0
        'lr': lr / 100,
        'weight_decay': wd
    },
    {
        'params': net.fuse.bias,
        'lr': 2 * lr / 100
    },
],
                      lr=lr,
                      momentum=0.9)

# Preparation of the data loaders
# Define augmentation transformations as a composition:
# random horizontal flip, scale/rotate jitter, then tensor conversion.
composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
    tr.ToTensor()
])
# Training dataset and its iterator
# (DAVIS2016 training data for seq_name, with the augmentation above).
db_train = db.DAVIS2016(train=True,
                        db_root_dir=db_root_dir,
                        transform=composed_transforms,
                        seq_name=seq_name)
# Shuffled loader; batch size comes from p['trainBatch'], one worker process.
trainloader = DataLoader(db_train,
                         batch_size=p['trainBatch'],
                         shuffle=True,
                         num_workers=1)

# Testing dataset and its iterator
db_test = db.DAVIS2016(train=False,
                       db_root_dir=db_root_dir,
Exemplo n.º 5
0
            img = cv2.resize(img, (self.inputRes[1], self.inputRes[0]))
            gt = cv2.resize(gt, (self.inputRes[1], self.inputRes[0]),
                            interpolation=cv2.INTER_NEAREST)

        sample = {'images': img, 'gts': gt}
        if self.transform is not None:
            sample = self.transform(sample)

            return sample

    def __len__(self):
        """Return the dataset size: the number of entries in ``self.imgs``."""
        entries = self.imgs
        return len(entries)


if __name__ == '__main__':
    # Smoke test: build the augmented training split and fetch a single batch.
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-10, 10), scales=(.75, 1.25)),
        tr.ToTensor()
    ])
    train_set = VOC('train',
                    inputRes=(512, 512),
                    transform=composed_transforms)
    train_loader = DataLoader(train_set,
                              batch_size=1,
                              num_workers=8,
                              shuffle=True)
    for ii, sample_batched in enumerate(train_loader):
        # Each sample is a dict with 'images' and 'gts' entries (assuming the
        # transforms keep those keys). DataLoader collates dicts into a dict of
        # batches, and tuple-unpacking that dict would bind the key strings
        # rather than the tensors, so index the entries explicitly.
        img, mask = sample_batched['images'], sample_batched['gts']
        break
Exemplo n.º 6
0
        return _img, _target

    def __str__(self):
        """Return a label of the form ``VOC2012(split=<str(self.split)>)``."""
        return 'VOC2012(split={!s})'.format(self.split)


if __name__ == '__main__':
    # Manual check: build the augmented VOC training split and iterate over it.
    # Imports are local so they only load when run as a script.
    from dataloaders import custom_transforms as tr
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    from torchvision import transforms
    import matplotlib.pyplot as plt

    # Random flip, scale/rotate jitter, FixedResize(size=512), tensor conversion.
    composed_transforms_tr = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-15, 15), scales=(.75, 1.5)),
        tr.FixedResize(size=512),
        tr.ToTensor()
    ])

    voc_train = VOCSegmentation(split='train',
                                transform=composed_transforms_tr)

    dataloader = DataLoader(voc_train,
                            batch_size=2,
                            shuffle=True,
                            num_workers=2)

    for ii, sample in enumerate(dataloader):
        # NOTE(review): assumes each batch is a dict with an 'image' tensor whose
        # first dimension is the batch; confirm against the dataset's __getitem__.
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
Exemplo n.º 7
0
# Output directory, named after the base learning rate and weight decay.
save_dir = os.path.join(Path.save_root_dir(),
                        'lr_' + str(base_lr) + '_wd_' + str(weight_decay))

if not os.path.exists(save_dir):
    os.makedirs(os.path.join(save_dir))

# Load DAVIS 2017 for the configured year/phase and look up seq_name.
davis17loader = db17.DAVISLoader(year=cfg.YEAR, phase=cfg.PHASE)
seq_data = davis17loader[seq_name]

images = seq_data.images
anno = seq_data.annotations

# Online-training augmentation: random flip, scale/rotate jitter, tensors.
composed_transforms = transforms.Compose([
    tr.RandomHorizontalFlip(),
    tr.ScaleNRotate(rots=(-30, 30), scales=(.5, 1.3)),
    tr.ToTensor()
])

# Skip training if the checkpoint for this sequence/epoch count already exists.
# NOTE(review): the object id is hard-coded to 1, and there is no separator
# between the object id and 'epoch_' in the file name. Changing either would
# stop existing checkpoints from being found.
alreadyTrained = False
file_name = os.path.join(
    save_dir, 'online_training_' + seq_name + '_object_id_' + str(1) +
    'epoch_' + str(nEpochs) + '.pth')

if os.path.exists(file_name):
    print('Training already completed! Not doing it again.')
    alreadyTrained = True

if not alreadyTrained:
    # Ensure the 'logs' sub-directory exists.
    # NOTE(review): `== False` could be written as `not os.path.exists(...)`.
    if os.path.exists(os.path.join(save_dir, 'logs')) == False:
        os.mkdir(os.path.join(save_dir, 'logs'))
Exemplo n.º 8
0
    def train(self,
              first_frame,
              n_interaction,
              obj_id,
              scribbles_data,
              scribble_iter,
              subset,
              use_previous_mask=False):
        """Fine-tune ``self.net`` online on the scribbles of one object.

        Weights are initialized as follows:
        - on the first interaction, from ``self.parent_model_state``;
        - otherwise, from the state saved for ``obj_id`` after the previous
          interaction.

        The network is then trained on the scribble-annotated frames until
        ``self.time_budget`` seconds have elapsed. The final weights are stored
        in ``self.prev_models[obj_id]``.

        Args:
            first_frame: forwarded to ``tr.CustomScribbleInteractive``.
            n_interaction: interaction index (1 means the first one); also
                caps the batch size.
            obj_id: id of the object being refined.
            scribbles_data: scribble dict; its 'scribbles' and 'sequence'
                entries are read here.
            scribble_iter: used to locate previous-interaction masks and in
                the log message.
            subset: split passed to ``db.DAVIS2017``.
            use_previous_mask: forwarded to ``tr.CustomScribbleInteractive``.
        """
        nAveGrad = 1
        num_workers = 4
        train_batch = min(n_interaction, self.train_batch)

        frames_list = interactive_utils.scribbles.annotated_frames_object(
            scribbles_data, obj_id)
        scribbles_list = scribbles_data['scribbles']
        seq_name = scribbles_data['sequence']

        # Reset the per-object model cache at the first interaction of object 1.
        if obj_id == 1 and n_interaction == 1:
            self.prev_models = {}

        # Network definition
        if n_interaction == 1:
            print('Loading weights from: {}'.format(self.parent_model))
            self.net.load_state_dict(self.parent_model_state)
            self.prev_models[obj_id] = None
        else:
            print(
                'Loading weights from previous network: objId-{}_interaction-{}_scribble-{}.pth'
                .format(obj_id, n_interaction - 1, scribble_iter))
            self.net.load_state_dict(self.prev_models[obj_id])

        # Per-group settings: biases use 2x lr, upscale weights use lr=0 (not
        # updated), fuse layer uses lr/100; weight decay only on weights.
        lr = 1e-8
        wd = 0.0002
        optimizer = optim.SGD([
            {
                'params': [
                    pr[1] for pr in self.net.stages.named_parameters()
                    if 'weight' in pr[0]
                ],
                'weight_decay':
                wd
            },
            {
                'params': [
                    pr[1] for pr in self.net.stages.named_parameters()
                    if 'bias' in pr[0]
                ],
                'lr':
                lr * 2
            },
            {
                'params': [
                    pr[1] for pr in self.net.side_prep.named_parameters()
                    if 'weight' in pr[0]
                ],
                'weight_decay':
                wd
            },
            {
                'params': [
                    pr[1] for pr in self.net.side_prep.named_parameters()
                    if 'bias' in pr[0]
                ],
                'lr':
                lr * 2
            },
            {
                'params': [
                    pr[1] for pr in self.net.upscale.named_parameters()
                    if 'weight' in pr[0]
                ],
                'lr':
                0
            },
            {
                'params': [
                    pr[1] for pr in self.net.upscale_.named_parameters()
                    if 'weight' in pr[0]
                ],
                'lr':
                0
            },
            {
                'params': self.net.fuse.weight,
                'lr': lr / 100,
                'weight_decay': wd
            },
            {
                'params': self.net.fuse.bias,
                'lr': 2 * lr / 100
            },
        ],
                              lr=lr,
                              momentum=0.9)

        prev_mask_path = os.path.join(
            self.save_res_dir, 'interaction-{}'.format(n_interaction - 1),
            'scribble-{}'.format(scribble_iter))
        composed_transforms_tr = transforms.Compose([
            tr.SubtractMeanImage(self.meanval),
            tr.CustomScribbleInteractive(scribbles_list,
                                         first_frame,
                                         use_previous_mask=use_previous_mask,
                                         previous_mask_path=prev_mask_path),
            tr.RandomHorizontalFlip(),
            tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
            tr.ToTensor()
        ])
        # Training dataset and its iterator
        db_train = db.DAVIS2017(split=subset,
                                transform=composed_transforms_tr,
                                custom_frames=frames_list,
                                seq_name=seq_name,
                                obj_id=obj_id,
                                no_gt=True,
                                retname=True)
        trainloader = DataLoader(db_train,
                                 batch_size=train_batch,
                                 shuffle=True,
                                 num_workers=num_workers)
        num_img_tr = len(trainloader)
        loss_tr = []
        aveGrad = 0

        start_time = timeit.default_timer()
        # Main training loop: keep making passes until the time budget is spent.
        epoch = 0
        while True:
            # One training epoch
            running_loss_tr = 0
            for ii, sample_batched in enumerate(trainloader):

                inputs, gts, void = sample_batched['image'], sample_batched[
                    'scribble_gt'], sample_batched['scribble_void_pixels']

                # Forward-Backward of the mini-batch
                inputs, gts, void = Variable(inputs), Variable(gts), Variable(
                    void)
                if self.gpu_id >= 0:
                    inputs, gts, void = inputs.cuda(), gts.cuda(), void.cuda()

                outputs = self.net.forward(inputs)

                # Loss on the last network output (presumably the fused
                # prediction), with void pixels passed to the loss.
                loss = class_balanced_cross_entropy_loss(outputs[-1],
                                                         gts,
                                                         size_average=False,
                                                         void_pixels=void)
                running_loss_tr += loss.item()

                # Backward the averaged gradient
                loss /= nAveGrad
                loss.backward()
                aveGrad += 1

                # Update the weights once in nAveGrad forward passes
                if aveGrad % nAveGrad == 0:
                    # writer.add_scalar('data/total_loss_iter', loss.data[0], ii + num_img_tr * epoch)
                    optimizer.step()
                    optimizer.zero_grad()
                    aveGrad = 0

            # Report the mean loss after the whole pass. Dividing once here
            # (not after every batch) keeps the mean correct when the loader
            # yields more than one batch, and prints once per logged pass.
            if epoch % 10 == 0 and num_img_tr > 0:
                running_loss_tr /= num_img_tr
                loss_tr.append(running_loss_tr)

                print('[Epoch: %d, numImages: %5d]' % (epoch + 1, num_img_tr))
                print('Loss: %f' % running_loss_tr)
                # writer.add_scalar('data/total_loss_epoch', running_loss_tr, epoch)

            epoch += train_batch
            stop_time = timeit.default_timer()
            if stop_time - start_time > self.time_budget:
                break

        # Save the model into dictionary
        self.prev_models[obj_id] = copy.deepcopy(self.net.state_dict())
Exemplo n.º 9
0
def _make_crop_transforms(crop_size, gt_size, jitters_bound, augment=False):
    """Build the mask-crop -> resize -> normalize -> tensor pipeline.

    Args:
        crop_size: side length used by ``tr.FixedResize`` for 'crop_image'.
        gt_size: side length used by ``tr.FixedResize`` for 'crop_gt'.
        jitters_bound: forwarded to ``tr.CropFromMask``.
        augment: if True, prepend ``tr.RandomHorizontalFlip`` and
            ``tr.ScaleNRotate`` (training-time augmentation).

    Returns:
        A ``transforms.Compose`` over the steps.
    """
    steps = []
    if augment:
        steps.extend([
            tr.RandomHorizontalFlip(),
            tr.ScaleNRotate(rots=(-20, 20), scales=(.75, 1.25)),
        ])
    steps.extend([
        tr.CropFromMask(crop_elems=('image', 'gt'),
                        relax=20,
                        zero_pad=True,
                        jitters_bound=jitters_bound),
        tr.FixedResize(
            resolutions={
                'crop_image': (crop_size, crop_size),
                'crop_gt': (gt_size, gt_size)
            }),
        tr.Normalize(elems='crop_image'),
        tr.ToTensor(),
    ])
    return transforms.Compose(steps)


def make_data_loader(args, **kwargs):
    """Create data loaders for the dataset named by ``args.dataset``.

    Supported values:
    - 'pascal', 'click': VOC, optionally merged with SBD;
    - 'grabcut';
    - 'bekeley' (alias 'berkeley');
    - 'cityscapes';
    - 'coco'.

    Args:
        args: options object. Always reads ``dataset``, ``crop_size``,
            ``gt_size`` and ``batch_size``. Depending on the dataset it also
            reads ``use_sbd`` or ``coco_part``, or passes ``args`` whole to
            the dataset constructor.
        **kwargs: extra keyword arguments forwarded to every ``DataLoader``.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``. Loaders a
        dataset does not provide are ``None``.

    Raises:
        NotImplementedError: if ``args.dataset`` is not supported.
    """
    crop_size = args.crop_size
    gt_size = args.gt_size
    dataset = args.dataset

    if dataset in ('pascal', 'click'):
        composed_transforms_tr = _make_crop_transforms(
            crop_size, gt_size, (40, 70), augment=True)
        composed_transforms_val = _make_crop_transforms(
            crop_size, gt_size, (50, 51))
        train_set = pascal.VOCSegmentation(split='train',
                                           transform=composed_transforms_tr)
        if dataset == 'click':
            train_set.reset_target_list(args)
        val_set = pascal.VOCSegmentation(split='val',
                                         transform=composed_transforms_val)
        if args.use_sbd:
            sbd_train = sbd.SBDSegmentation(args, split=['train', 'val'])
            # excluded=[val_set] presumably keeps VOC val images out of the
            # merged training set.
            train_set = combine_dbs.CombineDBs([train_set, sbd_train],
                                               excluded=[val_set])

        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        num_classes = 2
        return train_loader, val_loader, None, num_classes

    if dataset in ('grabcut', 'bekeley', 'berkeley'):
        # 'bekeley' is the established key and the spelling passed to
        # GrabBerkely's ``which``; 'berkeley' is accepted as an alias.
        which = 'grabcut' if dataset == 'grabcut' else 'bekeley'
        composed_transforms_val = _make_crop_transforms(
            crop_size, gt_size, (50, 51))
        val_set = grab_berkeley_eval.GrabBerkely(
            which=which, transform=composed_transforms_val)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        num_classes = 2
        return None, val_loader, None, num_classes

    if dataset == 'cityscapes':
        train_set = cityscapes.CityscapesSegmentation(args, split='train')
        val_set = cityscapes.CityscapesSegmentation(args, split='val')
        test_set = cityscapes.CityscapesSegmentation(args, split='test')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  **kwargs)
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        test_loader = DataLoader(test_set,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 **kwargs)
        return train_loader, val_loader, test_loader, num_class

    if dataset == 'coco':
        val_set = coco_eval.COCOSegmentation(split='val', cat=args.coco_part)
        num_class = 2
        val_loader = DataLoader(val_set,
                                batch_size=args.batch_size,
                                shuffle=False,
                                **kwargs)
        return None, val_loader, None, num_class

    raise NotImplementedError(
        'Unsupported dataset: {!r}'.format(dataset))
Exemplo n.º 10
0
def main(args):
    """Jointly train STCNN (segmentation + frame prediction) with a GAN critic.

    Steps:
    - Build the networks, resuming from ``resume_epoch - 1`` when
      ``resume_epoch`` is non-zero.
    - Create three optimizers: SGD for the segmentation encoder/decoder, Adam
      for the prediction encoder/decoder (generator), and Adam for ``netD``.
    - Train on ``db.DAVISDataset``, logging losses to stdout and TensorBoard.
    - Write sample images every 500 iterations.
    - Save snapshots under ``<save_root>/<modelName>/``.

    Args:
        args: parsed options; only ``args.frame_nums`` is read here.
    """
    # # Select which GPU, -1 if CPU
    gpu_id = 0
    device = torch.device("cuda:" +
                          str(gpu_id) if torch.cuda.is_available() else "cpu")

    # # Setting other parameters
    resume_epoch = 0  # Default is 0, change if want to resume
    nEpochs = 10  # Number of epochs for training (500.000/2079)
    batch_size = 1
    snapshot = 1  # Store a model every snapshot epochs
    pred_lr = 1e-8
    seg_lr = 1e-4
    lr_D = 1e-4
    wd = 5e-4
    beta = 0.001
    margin = 0.3

    updateD = True
    updateG = False
    num_frame = args.frame_nums

    modelName = 'STCNN_frame_' + str(num_frame)

    save_dir = Path.save_root_dir()
    if not os.path.exists(save_dir):
        os.makedirs(os.path.join(save_dir))
    save_model_dir = os.path.join(save_dir, modelName)
    if not os.path.exists(save_model_dir):
        os.makedirs(os.path.join(save_model_dir))

    # Network definition

    netD = Inception3(num_classes=1, aux_logits=False, transform_input=True)
    initialize_netD(
        netD,
        os.path.join(save_dir, 'FramePredModels',
                     'frame_nums_' + str(num_frame), 'NetD_epoch-90.pth'))
    seg_enc = SegEncoder()
    pred_enc = FramePredEncoder(frame_nums=num_frame)
    pred_dec = FramePredDecoder()
    j_seg_dec = JointSegDecoder()
    if resume_epoch == 0:
        initialize_model(pred_enc,
                         seg_enc,
                         pred_dec,
                         j_seg_dec,
                         save_dir,
                         num_frame=num_frame)
        net = STCNN(pred_enc, seg_enc, pred_dec, j_seg_dec)
    else:
        net = STCNN(pred_enc, seg_enc, pred_dec, j_seg_dec)
        print("Updating weights from: {}".format(
            os.path.join(
                save_model_dir,
                modelName + '_epoch-' + str(resume_epoch - 1) + '.pth')))
        net.load_state_dict(
            torch.load(os.path.join(
                save_model_dir,
                modelName + '_epoch-' + str(resume_epoch - 1) + '.pth'),
                       map_location=lambda storage, loc: storage))

    # Logging into Tensorboard
    log_dir = os.path.join(
        save_dir, 'JointPredSegNet_runs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir, comment='-parent')

    # PyTorch 0.4.0 style
    net.to(device)
    netD.to(device)

    lp_function = nn.MSELoss().to(device)
    criterion = nn.BCELoss().to(device)
    seg_criterion = nn.BCEWithLogitsLoss().to(device)

    # Use the following optimizer
    optimizer = optim.SGD([
        {
            'params':
            [param for name, param in net.seg_encoder.named_parameters()],
            'lr': seg_lr
        },
        {
            'params':
            [param for name, param in net.seg_decoder.named_parameters()],
            'lr': seg_lr
        },
    ],
                          weight_decay=wd,
                          momentum=0.9)

    optimizerG = optim.Adam([
        {
            'params':
            [param for name, param in net.pred_encoder.named_parameters()],
            'lr':
            pred_lr
        },
        {
            'params':
            [param for name, param in net.pred_decoder.named_parameters()],
            'lr':
            pred_lr
        },
    ],
                            lr=pred_lr,
                            weight_decay=wd)

    optimizerD = optim.Adam(netD.parameters(), lr=lr_D, weight_decay=wd)
    # Preparation of the data loaders
    # Define augmentation transformations as a composition
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(0.75, 1.25))
    ])

    # Training dataset and its iterator
    db_train = db.DAVISDataset(inputRes=(400, 710),
                               samples_list_file=os.path.join(
                                   Path.data_dir(), 'DAVIS16_samples_list_' +
                                   str(num_frame) + '.txt'),
                               transform=composed_transforms,
                               num_frame=num_frame)
    trainloader = DataLoader(db_train,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=4)
    num_img_tr = len(trainloader)
    iter_num = nEpochs * num_img_tr
    curr_iter = resume_epoch * num_img_tr
    print("Training Network")
    real_label = torch.ones(batch_size).float().to(device)
    fake_label = torch.zeros(batch_size).float().to(device)
    # The generator losses are only computed on iterations where the generator
    # is updated, and updateG starts out False. Keep the most recent values,
    # and None until the first generator step, so logging never touches an
    # unbound name.
    lp_loss = errG = total_loss = None
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()
        ii = -1  # stays -1 if the loader yields no batches

        for ii, sample_batched in enumerate(trainloader):
            step = ii + num_img_tr * epoch  # global iteration index

            seqs, frames, gts, pred_gts = sample_batched['images'], sample_batched['frame'],sample_batched['seg_gt'], \
                    sample_batched['pred_gt']

            # Forward-Backward of the mini-batch
            seqs.requires_grad_()
            frames.requires_grad_()

            seqs, frames, gts, pred_gts = seqs.to(device), frames.to(
                device), gts.to(device), pred_gts.to(device)

            pred_gts = F.upsample(pred_gts,
                                  size=(100, 178),
                                  mode='bilinear',
                                  align_corners=False)

            pred_gts = pred_gts.detach()
            seg_res, pred = net.forward(seqs, frames)

            D_real = netD(pred_gts)
            errD_real = criterion(D_real, real_label)
            D_fake = netD(pred.detach())
            errD_fake = criterion(D_fake, fake_label)

            optimizer.zero_grad()
            # Deep supervision: full loss on the last output, plus side outputs
            # weighted by (1 - progress), so their weight decays over training.
            seg_loss = seg_criterion(seg_res[-1], gts)
            for i in reversed(range(len(seg_res) - 1)):
                seg_loss = seg_loss + (
                    1 - curr_iter / iter_num) * seg_criterion(seg_res[i], gts)

            seg_loss.backward()
            optimizer.step()
            curr_iter += 1
            if updateD:
                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # train with real
                netD.zero_grad()
                # train with fake
                d_loss = errD_fake + errD_real
                d_loss.backward()
                optimizerD.step()

            if updateG:
                ############################
                # (2) Update G network: maximize log(D(G(z)))
                ###########################
                optimizerG.zero_grad()
                D_fake = netD(pred)
                errG = criterion(D_fake, real_label)

                lp_loss = lp_function(pred, pred_gts)
                total_loss = lp_loss + beta * errG
                total_loss.backward()
                optimizerG.step()

            # Balance D and G updates: pause the side that is "winning" by more
            # than the margin; re-enable both if both got paused.
            if (errD_fake.data < margin).all() or (errD_real.data <
                                                   margin).all():
                updateD = False
            if (errD_fake.data > (1. - margin)).all() or (errD_real.data >
                                                          (1. - margin)).all():
                updateG = False
            if not updateD and not updateG:
                updateD = True
                updateG = True

            if step % 5 == 4:
                # Report NaN for generator losses until a G step has happened.
                lp_val = lp_loss.item() if lp_loss is not None else float('nan')
                g_val = errG.item() if errG is not None else float('nan')
                print(
                    "Iters: [%2d] time: %4.4f, lp_loss: %.8f, G_loss: %.8f,seg_loss: %.8f"
                    %
                    (step, timeit.default_timer() -
                     start_time, lp_val, g_val, seg_loss.item()))
                print('updateD:', updateD, 'updateG:', updateG)
            if step % 10 == 9:
                # total_loss, lp_loss and errG are always assigned together.
                if total_loss is not None:
                    writer.add_scalar('data/loss_iter', total_loss.item(),
                                      step)
                    writer.add_scalar('data/lp_loss_iter', lp_loss.item(),
                                      step)
                    writer.add_scalar('data/G_loss_iter', errG.item(),
                                      step)
                writer.add_scalar('data/seg_loss_iter', seg_loss.item(),
                                  step)

            if step % 500 == 0:

                seg_pred = seg_res[-1][0, :, :, :].data.cpu().numpy()
                seg_pred = 1 / (1 + np.exp(-seg_pred))
                gt_sample = gts[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0]) * 255

                seg_pred = seg_pred.transpose([1, 2, 0]) * 255
                frame_sample = frames[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0])
                frame_sample = inverse_transform(frame_sample) * 255
                gt_sample3 = np.concatenate([gt_sample, gt_sample, gt_sample],
                                            axis=2)

                seg_pred3 = np.concatenate([seg_pred, seg_pred, seg_pred],
                                           axis=2)
                samples1 = np.concatenate(
                    (seg_pred3, gt_sample3, frame_sample), axis=0)

                pred_sample = pred[0, :, :, :].data.cpu().numpy().transpose(
                    [1, 2, 0])
                frame_sample = pred_gts[
                    0, :, :, :].data.cpu().numpy().transpose([1, 2, 0])
                samples2 = np.concatenate((pred_sample, frame_sample), axis=0)
                samples2 = inverse_transform(samples2) * 255
                print("Saving sample ...")
                running_res_dir = os.path.join(save_dir,
                                               modelName + '_results')
                if not os.path.exists(running_res_dir):
                    os.makedirs(running_res_dir)
                imageio.imwrite(
                    os.path.join(running_res_dir,
                                 "train_%s_s.png" % step),
                    np.uint8(samples1))
                imageio.imwrite(
                    os.path.join(running_res_dir,
                                 "train_%s_p.png" % step),
                    np.uint8(samples2))
        # Print stuff
        print('[Epoch: %d, numImages: %5d]' % (epoch, (ii + 1) * batch_size))
        stop_time = timeit.default_timer()
        print("Execution time: " + str(stop_time - start_time))
        # Save the model
        if (epoch % snapshot) == snapshot - 1 and epoch != 0:
            torch.save(
                net.state_dict(),
                os.path.join(save_model_dir,
                             modelName + '_epoch-' + str(epoch) + '.pth'))

    writer.close()
Exemplo n.º 11
0
def train(epochs_wo_avegrad):
    """Fine-tune the OSVOS parent network online on one DAVIS 2016 sequence.

    The sequence name comes from the ``SEQ_NAME`` environment variable
    (``'blackswan'`` if unset). The weights ``parent_epoch-239.pth`` are
    loaded from ``Path.save_root_dir()``, the network is trained on the
    sequence's training split with gradient accumulation, and afterwards a
    sigmoid probability map is written as a PNG for every test-split frame.

    Args:
        epochs_wo_avegrad: Training length in "effective" epochs. The loop
            runs ``epochs_wo_avegrad * nAveGrad`` passes over the training
            set because every weight update accumulates the gradients of
            ``nAveGrad`` forward passes.

    Side effects:
        * creates ``save_dir`` if it does not exist;
        * writes TensorBoard logs under ``<save_dir>/runs/``;
        * saves ``<seq_name>_epoch-<epoch>.pth`` in ``save_dir`` on the last
          training epoch;
        * saves one PNG per test frame under ``<save_dir>/Results/<seq_name>``.
    """
    # Setting of parameters
    seq_name = str(os.environ.get('SEQ_NAME', 'blackswan'))

    db_root_dir = Path.db_root_dir()
    save_dir = Path.save_root_dir()

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    vis_net = 0  # Visualize the network?
    vis_res = 0  # Visualize the results?
    nAveGrad = 5  # Average the gradient every nAveGrad iterations
    nEpochs = epochs_wo_avegrad * nAveGrad  # Number of epochs for training
    snapshot = nEpochs  # Store a model every snapshot epochs
    parentEpoch = 240
    # Report the epoch loss about 20 times per run. max(..., 1) keeps the
    # modulo below well-defined when nEpochs < 20 (nEpochs // 20 would be 0).
    log_every = max(nEpochs // 20, 1)

    # Parameters in p are used for the name of the model
    p = {
        'trainBatch': 1,  # Number of Images in each mini-batch
    }
    seed = 0

    parentModelName = 'parent'
    # GPU index to use; falls back to the CPU when CUDA is unavailable.
    gpu_id = 0
    device = torch.device("cuda:" + str(gpu_id)
                          if torch.cuda.is_available() else "cpu")

    # Network definition: load the parent weights (deserialized onto CPU
    # storage first, then moved to the target device).
    net = vo.OSVOS(pretrained=0)
    parent_path = os.path.join(
        save_dir, parentModelName + '_epoch-' + str(parentEpoch - 1) + '.pth')
    net.load_state_dict(
        torch.load(parent_path, map_location=lambda storage, loc: storage))
    net.to(device)

    # Visualize the network
    if vis_net:
        x = torch.randn(1, 3, 480, 854)
        x.requires_grad_()
        x = x.to(device)
        y = net(x)
        g = viz.make_dot(y, net.state_dict())
        g.view()

    def _named(module, kind):
        """Return the parameters of ``module`` whose name contains ``kind``."""
        return [prm for name, prm in module.named_parameters() if kind in name]

    # Use the following optimizer. Biases get twice the base learning rate,
    # the ``upscale``/``upscale_`` weights get lr 0 (never updated), and the
    # fusion layer learns 100x slower than the base rate.
    lr = 1e-8
    wd = 0.0002
    optimizer = optim.SGD([
        {'params': _named(net.stages, 'weight'), 'weight_decay': wd},
        {'params': _named(net.stages, 'bias'), 'lr': lr * 2},
        {'params': _named(net.side_prep, 'weight'), 'weight_decay': wd},
        {'params': _named(net.side_prep, 'bias'), 'lr': lr * 2},
        {'params': _named(net.upscale, 'weight'), 'lr': 0},
        {'params': _named(net.upscale_, 'weight'), 'lr': 0},
        {'params': net.fuse.weight, 'lr': lr / 100, 'weight_decay': wd},
        {'params': net.fuse.bias, 'lr': 2 * lr / 100},
    ], lr=lr, momentum=0.9)

    # Preparation of the data loaders
    # Define augmentation transformations as a composition
    composed_transforms = transforms.Compose([
        tr.RandomHorizontalFlip(),
        tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
        tr.ToTensor()
    ])
    # Training dataset and its iterator
    db_train = db.DAVIS2016(train=True,
                            db_root_dir=db_root_dir,
                            transform=composed_transforms,
                            seq_name=seq_name)
    trainloader = DataLoader(db_train,
                             batch_size=p['trainBatch'],
                             shuffle=True,
                             num_workers=1)

    # Testing dataset and its iterator
    db_test = db.DAVIS2016(train=False,
                           db_root_dir=db_root_dir,
                           transform=tr.ToTensor(),
                           seq_name=seq_name)
    testloader = DataLoader(db_test,
                            batch_size=1,
                            shuffle=False,
                            num_workers=1)

    num_img_tr = len(trainloader)

    loss_tr = []
    aveGrad = 0

    # Logging into Tensorboard
    log_dir = os.path.join(
        save_dir, 'runs',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname() +
        '-' + seq_name)
    writer = SummaryWriter(logdir=log_dir)
    try:
        print("Start of Online Training, sequence: " + seq_name)
        start_time = timeit.default_timer()
        # Main Training Loop
        for epoch in range(0, nEpochs):
            # One training epoch
            running_loss_tr = 0
            np.random.seed(seed + epoch)
            for ii, sample_batched in enumerate(trainloader):
                inputs = sample_batched['image'].to(device)
                gts = sample_batched['gt'].to(device)

                outputs = net(inputs)

                # Compute the fuse loss
                loss = class_balanced_cross_entropy_loss(outputs[-1],
                                                         gts,
                                                         size_average=False)
                running_loss_tr += loss.item()

                # Backward the averaged gradient
                loss = loss / nAveGrad
                loss.backward()
                aveGrad += 1

                # Update the weights once in nAveGrad forward passes
                if aveGrad % nAveGrad == 0:
                    writer.add_scalar('data/total_loss_iter', loss.item(),
                                      ii + num_img_tr * epoch)
                    optimizer.step()
                    optimizer.zero_grad()
                    aveGrad = 0

            # Report once per epoch, after all batches, so the running sum is
            # divided by the number of training images exactly once.
            if epoch % log_every == log_every - 1:
                running_loss_tr /= max(num_img_tr, 1)
                loss_tr.append(running_loss_tr)

                print('[Epoch: %d, numImages: %5d]' % (epoch + 1, num_img_tr))
                print('Loss: %f' % running_loss_tr)
                writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                  epoch)

            # Save the model
            if (epoch % snapshot) == snapshot - 1 and epoch != 0:
                torch.save(
                    net.state_dict(),
                    os.path.join(save_dir,
                                 seq_name + '_epoch-' + str(epoch) + '.pth'))

        stop_time = timeit.default_timer()
        print('Online training time: ' + str(stop_time - start_time))

        # Testing Phase
        if vis_res:
            import matplotlib.pyplot as plt
            plt.close("all")
            plt.ion()
            _, ax_arr = plt.subplots(1, 3)

        save_dir_res = os.path.join(save_dir, 'Results', seq_name)
        if not os.path.exists(save_dir_res):
            os.makedirs(save_dir_res)

        print('Testing Network')
        with torch.no_grad():
            # Main Testing Loop
            for ii, sample_batched in enumerate(testloader):
                img = sample_batched['image']
                gt = sample_batched['gt']
                fname = sample_batched['fname']

                # Forward of the mini-batch
                inputs = img.to(device)
                outputs = net(inputs)
                # torch.sigmoid is numerically stable; the numpy form
                # 1 / (1 + exp(-x)) overflows for very negative logits.
                probs = torch.sigmoid(outputs[-1]).cpu().numpy()

                for jj in range(int(inputs.size()[0])):
                    pred = np.squeeze(
                        np.transpose(probs[jj, :, :, :], (1, 2, 0)))

                    # Save the result, attention to the index jj
                    sm.imsave(
                        os.path.join(save_dir_res,
                                     os.path.basename(fname[jj]) + '.png'),
                        pred)

                    if vis_res:
                        img_ = np.transpose(img.numpy()[jj, :, :, :],
                                            (1, 2, 0))
                        # Squeeze this sample's mask, not the whole batch.
                        gt_ = np.squeeze(
                            np.transpose(gt.numpy()[jj, :, :, :], (1, 2, 0)))
                        # Plot the particular example
                        ax_arr[0].cla()
                        ax_arr[1].cla()
                        ax_arr[2].cla()
                        ax_arr[0].set_title('Input Image')
                        ax_arr[1].set_title('Ground Truth')
                        ax_arr[2].set_title('Detection')
                        ax_arr[0].imshow(im_normalize(img_))
                        ax_arr[1].imshow(gt_)
                        ax_arr[2].imshow(im_normalize(pred))
                        plt.pause(0.001)
    finally:
        # Flush and release the TensorBoard writer even if training fails.
        writer.close()
Exemplo n.º 12
0
    def train(self, first_frame, n_interaction, obj_id, scribbles_data, scribble_iter, subset, use_previous_mask=False):
        nAveGrad = 1
        num_workers = 4
        train_batch = min(n_interaction, self.train_batch)

        frames_list = interactive_utils.scribbles.annotated_frames_object(scribbles_data, obj_id)
        scribbles_list = scribbles_data['scribbles']
        seq_name = scribbles_data['sequence']

        if obj_id == 1 and n_interaction == 1:
            self.prev_models = {}

        # # Network definition
        # if n_interaction == 1:
        #     print('Loading weights from: {}'.format(self.parent_model))
        #     self.net.load_state_dict(self.parent_model_state)
        #     self.prev_models[obj_id] = None
        # else:
        #     print('Loading weights from previous network: objId-{}_interaction-{}_scribble-{}.pth'
        #           .format(obj_id, n_interaction-1, scribble_iter))
        #     self.net.load_state_dict(self.prev_models[obj_id])

        lr = 1e-5
        wd = 0.0002
        # optimizer = optim.SGD([
        #     {'params': [pr[1] for pr in self.net.stages.named_parameters() if 'weight' in pr[0]], 'weight_decay': wd},
        #     {'params': [pr[1] for pr in self.net.stages.named_parameters() if 'bias' in pr[0]], 'lr': lr * 2},
        #     {'params': [pr[1] for pr in self.net.side_prep.named_parameters() if 'weight' in pr[0]], 'weight_decay': wd},
        #     {'params': [pr[1] for pr in self.net.side_prep.named_parameters() if 'bias' in pr[0]], 'lr': lr * 2},
        #     {'params': [pr[1] for pr in self.net.upscale.named_parameters() if 'weight' in pr[0]], 'lr': 0},
        #     {'params': [pr[1] for pr in self.net.upscale_.named_parameters() if 'weight' in pr[0]], 'lr': 0},
        #     {'params': self.net.fuse.weight, 'lr': lr / 100, 'weight_decay': wd},
        #     {'params': self.net.fuse.bias, 'lr': 2 * lr / 100},
        # ], lr=lr, momentum=0.9)

        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=lr, momentum=0.9)

        prev_mask_path = os.path.join(self.save_res_dir, 'interaction-{}'.format(n_interaction-1),
                                      'scribble-{}'.format(scribble_iter))
        composed_transforms_tr = transforms.Compose([tr.CenterCrop((480,832)),tr.SubtractMeanImage(self.meanval),
                                                     tr.CustomScribbleInteractive(scribbles_list, first_frame,
                                                                                  use_previous_mask=use_previous_mask,
                                                                                  previous_mask_path=prev_mask_path),
                                                     tr.RandomHorizontalFlip(),
                                                     tr.ScaleNRotate(rots=(-30, 30), scales=(.75, 1.25)),
                                                     tr.ToTensor()])
        # Training dataset and its iterator
        # db_train = db.DAVIS2017(split=subset, transform=composed_transforms_tr,
        #                         custom_frames=frames_list, seq_name=seq_name,
        #                         obj_id=obj_id, no_gt=True, retname=True)
        db_train = db_scribblenet.DAVIS2017(split=subset, transform=composed_transforms_tr,
                                custom_frames=frames_list, seq_name=seq_name,
                                obj_id=obj_id, no_gt=True, retname=True)       
        trainloader = DataLoader(db_train, batch_size=train_batch, shuffle=True, num_workers=num_workers)
        num_img_tr = len(trainloader)
        loss_tr = []
        aveGrad = 0

        # List of all previous masks and aggregated features
        prev_masks = []
        prev_aggs = []

        start_time = timeit.default_timer()
        # Main Training and Testing Loop
        epoch = 0
        while 1:
            # One training epoch
            running_loss_tr = 0
            for ii, sample_batched in enumerate(trainloader):
                optimizer.zero_grad()

                # Parse from dataset loader
                inputs = sample_batched['images'].cuda()
                gts = sample_batched['scribble_gt'].cuda()
                scribbles = sample_batched['scribble_raw'].cuda()
                scribbles_idx = sample_batched['scribble_idx'].cuda()

                # Forward-Backward of the mini-batch
                # prev_masks = torch.tensor(prev_masks).unsqueeze(0)
                # prev_aggs = torch.tensor(prev_aggs).unsqueeze(0)
                masks, agg = self.net.forward(inputs, scribbles, scribble_idx, prev_masks, prev_agg)

                # Compute the fuse loss
                loss = class_balanced_cross_entropy_loss(masks, gts, scribble_idx)
                running_loss_tr += loss.item()

                # Print stuff
                if epoch % 10 == 0:
                    running_loss_tr /= num_img_tr
                    loss_tr.append(running_loss_tr)

                    print('[Epoch: %d, numImages: %5d]' % (epoch + 1, ii + 1))
                    print('Loss: %f' % running_loss_tr)

                # Backward the averaged gradient
                loss.backward()
                optimizer.step()

                # Update the current round data
                prev_masks.append(masks.detach())
                prev_aggs.append(agg.detach())

            epoch += train_batch
            stop_time = timeit.default_timer()
            if stop_time - start_time > self.time_budget:
                break

        # Save the model into dictionary
        self.prev_models[obj_id] = copy.deepcopy(self.net.state_dict())