Example #1
    def _init_model(self, args):
        print("Loading model: " + args.model_name)
        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.model_name, args.use_gpu)
        load_args.use_gpu = args.use_gpu
        self.encoder = FeatureExtractor(load_args)
        if args.zero_shot:
            self.decoder = RSIS(load_args)
        else:
            self.decoder = RSISMask(load_args)
        print(load_args)

        if args.ngpus > 1 and args.use_gpu:
            self.decoder = torch.nn.DataParallel(self.decoder,
                                                 device_ids=range(args.ngpus))
            self.encoder = torch.nn.DataParallel(self.encoder,
                                                 device_ids=range(args.ngpus))

        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        self.encoder.load_state_dict(encoder_dict)

        to_be_deleted_dec = []
        for k in decoder_dict.keys():
            if 'fc_stop' in k:
                to_be_deleted_dec.append(k)
        for k in to_be_deleted_dec:
            del decoder_dict[k]
        self.decoder.load_state_dict(decoder_dict)

        if args.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()
        if load_args.length_clip == 1:
            self.video_mode = False
            print('video mode not activated')
        else:
            self.video_mode = True
            print('video mode activated')
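The initializer above reads model_name, use_gpu, ngpus, and zero_shot from an args namespace. A minimal sketch of how such a namespace might be assembled with argparse before construction (the defaults here are illustrative assumptions, not taken from the repository):

import argparse

# Sketch: flags mirror the attributes read in _init_model above;
# default values are assumptions for illustration.
parser = argparse.ArgumentParser()
parser.add_argument('--model_name', type=str, default='my_model')  # checkpoint name for load_checkpoint
parser.add_argument('--use_gpu', action='store_true')              # move model to CUDA
parser.add_argument('--ngpus', type=int, default=1)                # wrap in DataParallel when > 1
parser.add_argument('--zero_shot', action='store_true')            # RSIS decoder instead of RSISMask
args = parser.parse_args()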
Example #2
def trainIters(args):
    epoch_resume = 0
    model_dir = os.path.join('../models/',
                             args.model_name + '_prev_inference_mask')

    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.model_name, args.use_gpu)

        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

        args = load_args

    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

    else:
        encoder = FeatureExtractor(args)
        decoder = RSISMask(args)

    # model checkpoints will be saved here
    make_dir(model_dir)

    # save parameters for future use
    pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb'))

    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params,
                            args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params,
                            args.weight_decay_cnn)

    if args.resume:
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)

    if not args.log_term:
        print("Training logs will be saved to:",
              os.path.join(model_dir, 'train.log'))
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')

    print(args)

    # objective function for mask
    mask_siou = softIoULoss()

    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        mask_siou.cuda()

    crits = mask_siou
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1

    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}
    #area_range = [[0 ** 2, 1e5 ** 2], [0 ** 2, 20 ** 2], [20 ** 2, 59 ** 2], [59 ** 2, 1e5 ** 2]]
    area_range = [[0**2, 1e5**2], [0**2, 30**2], [30**2, 90**2],
                  [90**2, 1e5**2]]  # buckets for (287, 950) inputs
    resolution = 0

    for e in range(args.max_epoch):
        print("Epoch", e + epoch_resume)
        # store losses in lists to display average since beginning
        epoch_losses = {
            'train': {
                'total': [],
                'iou': []
            },
            'val': {
                'total': [],
                'iou': []
            }
        }
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': []}

        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and args.finetune_after != -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1

        if args.loss_penalization:
            if e < 10:
                resolution = area_range[2]
            else:
                resolution = area_range[0]

        # we validate after each epoch
        for split in ['train', 'val']:
            if args.dataset in ('davis2017', 'youtube', 'kittimots'):
                loaders[split].dataset.set_epoch(e)
                for batch_idx, (inputs, targets, seq_name,
                                starting_frame) in enumerate(loaders[split]):
                    # send batch to GPU

                    prev_hidden_temporal_list = None
                    loss = None
                    last_frame = False
                    max_ii = min(len(inputs), args.length_clip)

                    for ii in range(max_ii):
                        # If we are on the last frame of a clip, we backpropagate the loss back to the beginning of the clip.
                        if ii == max_ii - 1:
                            last_frame = True

                        # x: input images (N consecutive frames from M different sequences)
                        # y_mask: ground truth annotations (some are zero-padded to a fixed number of object instances)
                        # sw_mask: indicates which masks in y_mask are valid
                        x, y_mask, sw_mask = batch_to_var(
                            args, inputs[ii], targets[ii])

                        if ii == 0:
                            prev_mask = y_mask

                        # From one frame to the following frame the prev_hidden_temporal_list is updated.
                        loss, losses, outs, hidden_temporal_list = runIter(
                            args, encoder, decoder, x, y_mask, sw_mask,
                            resolution, crits, optims, split, loss,
                            prev_hidden_temporal_list, prev_mask, last_frame)

                        # Hidden temporal state from time instant ii is saved to be used when processing next time instant ii+1
                        if not args.only_spatial:
                            prev_hidden_temporal_list = hidden_temporal_list

                        prev_mask = outs

                    # store loss values in dictionary separately
                    epoch_losses[split]['total'].append(losses[0])
                    epoch_losses[split]['iou'].append(losses[1])

                    # print after some iterations
                    if (batch_idx + 1) % args.print_every == 0:

                        mt = np.mean(epoch_losses[split]['total'])
                        mi = np.mean(epoch_losses[split]['iou'])

                        te = time.time() - start
                        print("iter %d:\ttotal:%.4f\tiou:%.4f\ttime:%.4f" %
                              (batch_idx, mt, mi, te))
                        if args.use_gpu:
                            torch.cuda.synchronize()
                        start = time.time()

            num_batches[split] = batch_idx + 1
            # compute mean val losses within epoch

            if split == 'val' and args.smooth_curves:
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(
                        epoch_losses[split]['total'])
                mt_val = mt

            else:
                mt = np.mean(epoch_losses[split]['total'])

            mi = np.mean(epoch_losses[split]['iou'])

            # save train and val losses for the epoch
            total_losses['iou'].append(mi)
            total_losses['total'].append(mt)

            args.epoch_resume = e + epoch_resume

            print("Epoch %d:\ttotal:%.4f\tiou:%.4f\t(%s)" % (e, mt, mi, split))

        if mt < (best_val_loss - args.min_delta):
            print("Saving checkpoint.")
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint_prev_inference_mask(args, encoder, decoder,
                                                enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1

        if acc_patience > args.patience and not args.update_encoder and args.finetune_after != -1:
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)

        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
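The smoothed validation loss above (args.smooth_curves) is a plain exponential moving average with a 0.9/0.1 split, where mt_val == -1 marks "no history yet". Isolated as a standalone sketch:

def smooth_val_loss(mt_val, epoch_mean, momentum=0.9):
    # Exponential moving average of the per-epoch validation loss;
    # mt_val == -1 is the sentinel for "no previous value".
    if mt_val == -1:
        return epoch_mean
    return momentum * mt_val + (1.0 - momentum) * epoch_mean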
Example #3
    def __init__(self, args):

        self.split = args.eval_split
        self.dataset = args.dataset
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        image_transforms = transforms.Compose([to_tensor, normalize])

        if args.dataset == 'youtube':
            dataset = get_dataset(args,
                                  split=self.split,
                                  image_transforms=image_transforms,
                                  target_transforms=None,
                                  augment=args.augment
                                  and self.split == 'train',
                                  inputRes=(256, 448),
                                  video_mode=True,
                                  use_prev_mask=False)
        else:
            dataset = get_dataset(args,
                                  split=self.split,
                                  image_transforms=image_transforms,
                                  target_transforms=None,
                                  augment=args.augment
                                  and self.split == 'train',
                                  video_mode=True,
                                  use_prev_mask=False)

        self.loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      drop_last=False)

        self.args = args

        print(args.model_name)
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.model_name, args.use_gpu)
        load_args.use_gpu = args.use_gpu
        self.encoder = FeatureExtractor(load_args)
        self.decoder = RSIS(load_args)

        print(load_args)

        if args.ngpus > 1 and args.use_gpu:
            self.decoder = torch.nn.DataParallel(self.decoder,
                                                 device_ids=range(args.ngpus))
            self.encoder = torch.nn.DataParallel(self.encoder,
                                                 device_ids=range(args.ngpus))

        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        self.encoder.load_state_dict(encoder_dict)
        to_be_deleted_dec = []
        for k in decoder_dict.keys():
            if 'fc_stop' in k:
                to_be_deleted_dec.append(k)
        for k in to_be_deleted_dec:
            del decoder_dict[k]
        self.decoder.load_state_dict(decoder_dict)

        if args.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()
        if load_args.length_clip == 1:
            self.video_mode = False
            print('video mode not activated')
        else:
            self.video_mode = True
            print('video mode activated')
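Both evaluator constructors (Examples #1 and #3) strip the unused stop-branch weights ('fc_stop') before loading the decoder state dict. The same filtering as a reusable sketch:

def strip_keys(state_dict, substring='fc_stop'):
    # Drop entries whose key contains the substring, so the dict
    # matches a decoder built without the stop branch.
    return {k: v for k, v in state_dict.items() if substring not in k}

# self.decoder.load_state_dict(strip_keys(decoder_dict))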
Example #4
def trainIters(args):
    print(args)

    model_dir = os.path.join('ckpt/', args.model_name)
    make_dir(model_dir)

    epoch_resume = 0
    if args.resume:
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = \
            load_checkpoint_epoch(args.model_name, args.epoch_resume,
                                  args.use_gpu)

        epoch_resume = args.epoch_resume

        encoder = Encoder()
        decoder = Decoder()

        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
    else:
        encoder = Encoder()
        decoder = Decoder()

    criterion = WeightedBCE2d()

    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        criterion.cuda()

    encoder_params = list(encoder.parameters())
    decoder_params = list(decoder.parameters())
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params,
                            args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params,
                            args.weight_decay_cnn)

    loaders = init_dataloaders(args)

    best_iou = 0

    start = time.time()
    for e in range(epoch_resume, args.max_epoch):
        print("Epoch", e)
        epoch_losses = {
            'train': {
                'total': [],
                'iou': [],
                'mask_loss': [],
                'bdry_loss': []
            },
            'val': {
                'total': [],
                'iou': [],
                'mask_loss': [],
                'bdry_loss': []
            }
        }

        for split in ['train', 'val']:
            if split == 'train':
                encoder.train(True)
                decoder.train(True)
            else:
                encoder.train(False)
                decoder.train(False)

            for batch_idx, (image, flow, mask, bdry, negative_pixels) in\
                    enumerate(loaders[split]):
                image, flow, mask, bdry, negative_pixels = \
                    image.cuda(), flow.cuda(), mask.cuda(), bdry.cuda(),\
                    negative_pixels.cuda()

                if split == 'train':
                    r5, r4, r3, r2 = encoder(image, flow)
                    mask_pred, p1, p2, p3, p4, p5 = decoder(r5, r4, r3, r2)

                    mask_loss = criterion(mask_pred, mask, negative_pixels)
                    bdry_loss = criterion(p1, bdry, negative_pixels) + \
                                criterion(p2, bdry, negative_pixels) + \
                                criterion(p3, bdry, negative_pixels) + \
                                criterion(p4, bdry, negative_pixels) + \
                                criterion(p5, bdry, negative_pixels)
                    loss = mask_loss + 0.2 * bdry_loss

                    iou = db_eval_iou_multi(mask.cpu().detach().numpy(),
                                            mask_pred.cpu().detach().numpy())

                    dec_opt.zero_grad()
                    enc_opt.zero_grad()
                    loss.backward()
                    enc_opt.step()
                    dec_opt.step()
                else:
                    with torch.no_grad():
                        r5, r4, r3, r2 = encoder(image, flow)
                        mask_pred, p1, p2, p3, p4, p5 = decoder(r5, r4, r3, r2)

                        mask_loss = criterion(mask_pred, mask, negative_pixels)
                        bdry_loss = criterion(p1, bdry, negative_pixels) + \
                                    criterion(p2, bdry, negative_pixels) + \
                                    criterion(p3, bdry, negative_pixels) + \
                                    criterion(p4, bdry, negative_pixels) + \
                                    criterion(p5, bdry, negative_pixels)
                        loss = mask_loss + 0.2 * bdry_loss

                    iou = db_eval_iou_multi(mask.cpu().detach().numpy(),
                                            mask_pred.cpu().detach().numpy())

                epoch_losses[split]['total'].append(loss.data.item())
                epoch_losses[split]['mask_loss'].append(mask_loss.data.item())
                epoch_losses[split]['bdry_loss'].append(bdry_loss.data.item())
                epoch_losses[split]['iou'].append(iou)

                if (batch_idx + 1) % args.print_every == 0:
                    mt = np.mean(epoch_losses[split]['total'])
                    mmask = np.mean(epoch_losses[split]['mask_loss'])
                    mbdry = np.mean(epoch_losses[split]['bdry_loss'])
                    miou = np.mean(epoch_losses[split]['iou'])

                    te = time.time() - start
                    print('Epoch: [{}/{}][{}/{}]\tTime {:.3f}s\tLoss: {:.4f}'
                          '\tMask Loss: {:.4f}\tBdry Loss: {:.4f}'
                          '\tIOU: {:.4f}'.format(e, args.max_epoch, batch_idx,
                                                 len(loaders[split]), te, mt,
                                                 mmask, mbdry, miou))

                    start = time.time()

        miou = np.mean(epoch_losses['val']['iou'])
        if miou > best_iou:
            best_iou = miou
            save_checkpoint_epoch(args, encoder, decoder, enc_opt, dec_opt, e,
                                  False)
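The boundary objective in Example #4 is plain deep supervision: the same criterion applied to each side output p1..p5, summed, and mixed with the mask loss at a fixed 0.2 weight. Factored out as a sketch (the signature is an assumption for illustration):

def combined_loss(criterion, mask_pred, side_outputs, mask, bdry,
                  negative_pixels, bdry_weight=0.2):
    # Deep supervision: every side output is penalized against the
    # boundary target, then mixed with the mask loss.
    mask_loss = criterion(mask_pred, mask, negative_pixels)
    bdry_loss = sum(criterion(p, bdry, negative_pixels) for p in side_outputs)
    return mask_loss + bdry_weight * bdry_loss, mask_loss, bdry_loss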
Example #5
    def __init__(self, args):

        self.split = args.eval_split
        self.display = args.display
        self.no_display_text = args.no_display_text
        self.dataset = args.dataset
        self.all_classes = args.all_classes
        self.use_cats = args.use_cats
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        image_transforms = transforms.Compose([to_tensor, normalize])

        dataset = get_dataset(args,
                              self.split,
                              image_transforms,
                              augment=False,
                              imsize=args.imsize)

        self.sample_list = dataset.get_sample_list()
        self.class_names = dataset.get_classes()
        if args.dataset == 'pascal':
            self.gt_file = pickle.load(
                open(
                    os.path.join(args.pascal_dir,
                                 'VOCGT_%s.pkl' % (self.split)), 'rb'))
            self.key_to_anns = dict()
            self.ignoremasks = {}
            for ann in self.gt_file:
                if ann['ignore'] == 1:
                    if isinstance(ann['segmentation']['counts'], list):
                        im_height = ann['segmentation']['size'][0]
                        im_width = ann['segmentation']['size'][1]
                        rle = mask.frPyObjects([ann['segmentation']],
                                               im_height, im_width)
                    else:
                        rle = [ann['segmentation']]
                    m = mask.decode(rle)
                    self.ignoremasks[ann['image_id']] = m
                if ann['image_id'] in self.key_to_anns:
                    self.key_to_anns[ann['image_id']].append(ann)
                else:
                    self.key_to_anns[ann['image_id']] = [ann]
            self.coco = create_coco_object(args, self.sample_list,
                                           self.class_names)
        self.loader = data.DataLoader(dataset,
                                      batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      drop_last=False)

        self.args = args
        self.colors = []
        palette = sequence_palette()
        inv_palette = {}
        for k, v in palette.items():
            inv_palette[v] = k
        num_colors = len(inv_palette)
        for i in range(num_colors):
            if i == 0 or i == 21:
                continue
            c = inv_palette[i]
            self.colors.append(c)

        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.model_name, args.use_gpu)
        load_args.use_gpu = args.use_gpu
        self.encoder = FeatureExtractor(load_args)
        self.decoder = RSIS(load_args)

        print(load_args)

        if args.ngpus > 1 and args.use_gpu:
            self.decoder = torch.nn.DataParallel(self.decoder,
                                                 device_ids=range(args.ngpus))
            self.encoder = torch.nn.DataParallel(self.encoder,
                                                 device_ids=range(args.ngpus))

        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        self.encoder.load_state_dict(encoder_dict)
        self.decoder.load_state_dict(decoder_dict)

        if args.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()
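The Pascal branch above leans on pycocotools' RLE helpers: uncompressed segmentations (counts stored as a Python list) are first converted with mask.frPyObjects, while compressed RLEs are decoded directly. Isolated as a sketch:

from pycocotools import mask as cocomask

def decode_ignore_mask(segm):
    # Uncompressed RLE (counts as a list) must be converted first;
    # compressed RLE can be decoded as-is.
    if isinstance(segm['counts'], list):
        h, w = segm['size']
        rle = cocomask.frPyObjects([segm], h, w)
    else:
        rle = [segm]
    return cocomask.decode(rle)  # (H, W, 1) binary array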
Example #6
def trainIters(args):

    epoch_resume = 0
    model_dir = os.path.join('../models/', args.model_name)

    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.model_name, args.use_gpu)

        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

        args = load_args

    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

    else:
        encoder = FeatureExtractor(args)
        decoder = RSIS(args)

    # model checkpoints will be saved here
    make_dir(model_dir)

    # save parameters for future use
    pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb'))

    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn)

    if args.resume or args.transfer:
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)

        # change fc layer for new classes
        if load_args.dataset != args.dataset and args.transfer:
            dim_in = decoder.fc_class.weight.size()[1]
            decoder.fc_class = nn.Linear(dim_in, args.num_classes)

    if not args.log_term:
        print "Training logs will be saved to:", os.path.join(model_dir, 'train.log')
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')

    print(args)

    # objective functions for mask and class outputs.
    # these return the average over the batch samples that
    # should be considered (those where sw is 1)
    # mask_xentropy = BalancedStableMaskedBCELoss()
    mask_siou = softIoULoss()

    class_xentropy = MaskedNLLLoss(balance_weight=None)
    stop_xentropy = MaskedBCELoss(balance_weight=args.stop_balance_weight)

    if args.ngpus > 1 and args.use_gpu:
        decoder = torch.nn.DataParallel(decoder, device_ids=range(args.ngpus))
        encoder = torch.nn.DataParallel(encoder, device_ids=range(args.ngpus))
        mask_siou = torch.nn.DataParallel(mask_siou, device_ids=range(args.ngpus))
        class_xentropy = torch.nn.DataParallel(class_xentropy, device_ids=range(args.ngpus))
        stop_xentropy = torch.nn.DataParallel(stop_xentropy, device_ids=range(args.ngpus))
    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        class_xentropy.cuda()
        mask_siou.cuda()
        stop_xentropy.cuda()

    crits = [mask_siou, class_xentropy, stop_xentropy]
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1

    # init windows to visualize, if visdom is enabled
    if args.visdom:
        import visdom
        viz = visdom.Visdom(port=args.port, server=args.server)
        lot, elot, mviz_pred, mviz_true, image_lot = init_visdom(args, viz)

    if args.curriculum_learning and epoch_resume == 0:
        args.limit_seqlen_to = 2

    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders, class_names = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}
    for e in range(args.max_epoch):
        print "Epoch", e + epoch_resume
        # store losses in lists to display average since beginning
        epoch_losses = {'train': {'total': [], 'iou': [], 'stop': [], 'class': []},
                        'val': {'total': [], 'iou': [], 'stop': [], 'class': []}}
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': [], 'stop': [], 'class': []}

        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and args.finetune_after != -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.class_loss_after and not args.use_class_loss and args.class_loss_after != -1:
            print("Starting to learn class loss")
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.stop_loss_after and not args.use_stop_loss and args.stop_loss_after != -1:
            if args.curriculum_learning:
                if args.limit_seqlen_to > args.min_steps:
                    print("Starting to learn stop loss")
                    args.use_stop_loss = True
                    best_val_loss = 1000 # reset because adding a loss term will increase the total value
                    acc_patience = 0
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                args.use_stop_loss = True
                best_val_loss = 1000 # reset because adding a loss term will increase the total value
                acc_patience = 0
                mt_val = -1

        # we validate after each epoch
        for split in ['train', 'val']:
            for batch_idx, (inputs, targets) in enumerate(loaders[split]):
                # send batch to GPU

                x, y_mask, y_class, sw_mask, sw_class = batch_to_var(args, inputs, targets)

                # we forward (and backward & update if training set)
                losses, outs, true_perm = runIter(args, encoder, decoder, x, y_mask,
                                                  y_class, sw_mask, sw_class,
                                                  crits, optims, mode=split)

                # store loss values in dictionary separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])
                epoch_losses[split]['stop'].append(losses[2])
                epoch_losses[split]['class'].append(losses[3])


                # print and display in visdom after some iterations
                if (batch_idx + 1) % args.print_every == 0:

                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])
                    mc = np.mean(epoch_losses[split]['class'])
                    mx = np.mean(epoch_losses[split]['stop'])
                    if args.visdom:

                        if split == 'train':
                            # we display batch loss values in visdom (Training only)
                            viz.line(
                                X=torch.ones((1, 4)).cpu() * (batch_idx + e * num_batches[split]),
                                Y=torch.Tensor([mi, mx, mc, mt]).unsqueeze(0).cpu(),
                                win=lot,
                                update='append')
                        w = x.size()[-1]
                        h = x.size()[-2]
                        out_masks, out_classes, y_mask, y_class = outs_perms_to_cpu(args, outs, true_perm, h, w)

                        x = x.data.cpu().numpy()
                        # send image, sample predictions and ground truths to visdom
                        for t in range(np.shape(out_masks)[1]):
                            mask_pred = out_masks[0, t]
                            mask_true = y_mask[0, t]
                            class_pred = class_names[out_classes[0, t]]
                            class_true = class_names[y_class[0, t]]
                            mask_pred = np.reshape(mask_pred, (x.shape[-2], x.shape[-1]))
                            mask_true = np.reshape(mask_true, (x.shape[-2], x.shape[-1]))

                            # heatmap displays the mask upside down
                            viz.heatmap(np.flipud(mask_pred), win=mviz_pred[t],
                                        opts=dict(title='pred mask %d %s' % (t, class_pred)))
                            viz.heatmap(np.flipud(mask_true), win=mviz_true[t],
                                        opts=dict(title='true mask %d %s' % (t, class_true)))
                            viz.image((x[0] * 0.2 + 0.5) * 256, win=image_lot,
                                      opts=dict(title='image (unnormalized)'))

                    te = time.time() - start
                    print "iter %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\ttime:%.4f" % (batch_idx, mt, mc, mi, mx, te)
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()

            num_batches[split] = batch_idx + 1
            # compute mean val losses within epoch

            if split == 'val' and args.smooth_curves:
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(epoch_losses[split]['total'])
                mt_val = mt

            else:
                mt = np.mean(epoch_losses[split]['total'])

            mi = np.mean(epoch_losses[split]['iou'])
            mc = np.mean(epoch_losses[split]['class'])
            mx = np.mean(epoch_losses[split]['stop'])


            # save train and val losses for the epoch to display in visdom
            total_losses['iou'].append(mi)
            total_losses['class'].append(mc)
            total_losses['stop'].append(mx)
            total_losses['total'].append(mt)

            args.epoch_resume = e + epoch_resume

            print "Epoch %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\t(%s)" % (e, mt, mc, mi,mx, split)

        # epoch losses
        if args.visdom:
            update = True if e == 0 else 'append'
            for l in ['total', 'iou', 'stop', 'class']:
                viz.line(X=torch.ones((1, 2)).cpu() * (e + 1),
                         Y=torch.Tensor(total_losses[l]).unsqueeze(0).cpu(),
                         win=elot[l],
                         update=update)

        if mt < (best_val_loss - args.min_delta):
            print "Saving checkpoint."
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1

        if acc_patience > args.patience and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            acc_patience = 0
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        if acc_patience > args.patience and args.curriculum_learning and args.limit_seqlen_to < args.maxseqlen:
            print("Adding one step more:")
            acc_patience = 0
            args.limit_seqlen_to += args.steps_cl
            print(args.limit_seqlen_to)
            best_val_loss = 1000
            mt_val = -1

        if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        if acc_patience > args.patience and not args.use_stop_loss and not args.stop_loss_after == -1:
            if args.curriculum_learning:
                print("Starting to learn stop loss")
                if args.limit_seqlen_to > args.min_steps:
                    acc_patience = 0
                    args.use_stop_loss = True
                    best_val_loss = 1000 # reset because adding a loss term will increase the total value
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                acc_patience = 0
                args.use_stop_loss = True
                best_val_loss = 1000 # reset because adding a loss term will increase the total value
                mt_val = -1

            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
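Whenever patience runs out in Example #6, the same recovery step repeats: switch on the next loss term, reset the early-stopping bookkeeping, and roll back to the best checkpoint. Condensed into a sketch that reuses the example's load_checkpoint call:

def reload_best(args, encoder, decoder, enc_opt, dec_opt):
    # Restore model and optimizer state from the last saved checkpoint so
    # the newly enabled loss term starts from the best weights so far.
    enc_d, dec_d, enc_o, dec_o, _ = load_checkpoint(args.model_name, args.use_gpu)
    encoder.load_state_dict(enc_d)
    decoder.load_state_dict(dec_d)
    enc_opt.load_state_dict(enc_o)
    dec_opt.load_state_dict(dec_o)
    return 1000, -1, 0  # reset best_val_loss, mt_val, acc_patience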
Example #7
use_flip = True

to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
image_transforms = transforms.Compose([to_tensor, normalize])

model_name = 'MATNet'  # specify the model name
epoch = 0  # specify the epoch number
davis_result_dir = './output/davis16'

encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args =\
    load_checkpoint_epoch(model_name, epoch, True, False)
encoder = Encoder()
decoder = Decoder()
encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
encoder.load_state_dict(encoder_dict)
decoder.load_state_dict(decoder_dict)

encoder.cuda()
decoder.cuda()

encoder.train(False)
decoder.train(False)

val_set = 'data/DAVIS2017/ImageSets/2016/val.txt'
with open(val_set) as f:
    seqs = f.readlines()
    seqs = [seq.strip() for seq in seqs]

for video in tqdm(seqs):
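    # Sketch only -- the original listing is truncated at this loop.
    # The directory layout and names below (image_dir, save_dir) are
    # assumptions for illustration, not part of the original.
    image_dir = os.path.join('data/DAVIS2017/JPEGImages/480p', video)
    save_dir = os.path.join(davis_result_dir, video)
    os.makedirs(save_dir, exist_ok=True)
    # a per-frame forward pass through encoder/decoder would follow here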