Esempio n. 1
0
    def create_figures(self):
        """Write one instance-label PNG per validation image (leaves, A1).

        Every batch from ``self.loader`` is run through ``test``. For each
        sample, each of the ``self.T`` predicted masks whose score (first entry
        of the corresponding ``stop_probs`` row) exceeds ``args.class_th`` is
        resized to the original image size, thresholded at
        ``args.mask_th * 255`` and painted into a label map with its own
        1-based instance id; later masks overwrite earlier ones where they
        overlap. Uncovered pixels stay 0 (background).

        The map is saved as an 8-bit grayscale PNG under
        ``../models/<model_name>/<model_name>_results/A1/``, with ``rgb.png``
        in the file name replaced by ``label.png``.

        NOTE(review): reads the module-level ``args`` (not ``self.args``) for
        ``model_name``, ``class_th`` and ``mask_th``, as before.
        """
        acc_samples = 0
        results_root_dir = os.path.join('../models', args.model_name,
                                        args.model_name + '_results')
        make_dir(results_root_dir)
        results_dir = os.path.join(results_root_dir, 'A1')
        make_dir(results_dir)
        # Function-call form prints the same text on Python 2 and 3.
        print("Creating annotations for leaves validation...")
        for inputs, targets in self.loader:
            x, _, _, _, _ = batch_to_var(self.args, inputs, targets)
            out_masks, _, stop_probs = test(self.args, self.encoder,
                                            self.decoder, x)

            # The final batch may be shorter than batch_size: never index past
            # the end of sample_list.
            num_in_batch = min(self.batch_size,
                               len(self.sample_list) - acc_samples)
            for sample in range(num_in_batch):
                sample_path = self.sample_list[sample + acc_samples]
                # The image is read only to recover the resolution the
                # predicted masks are resized to.
                im = scipy.misc.imread(sample_path.split('.')[0] + '.png')
                h, w = im.shape[0], im.shape[1]

                mask_sample = np.zeros([h, w])
                sample_idx = sample_path.split('/')[-1].split('.')[0]
                img_masks = out_masks[sample]
                class_scores = stop_probs[sample]

                instance_id = 0
                for time_step in range(self.T):
                    class_score = class_scores[time_step].cpu().numpy()[0]
                    if class_score > args.class_th:
                        mask = img_masks[time_step].cpu().numpy()
                        mask = scipy.misc.imresize(mask, [h, w])
                        # Ids start at 1. Writing time_step directly made the
                        # first predicted instance equal to background (0).
                        instance_id += 1
                        mask_sample[mask > args.mask_th * 255] = instance_id

                file_name = os.path.join(results_dir, sample_idx + '.png')
                file_name_prediction = file_name.replace(
                    'rgb.png', 'label.png')

                Image.fromarray(mask_sample).convert('L').save(
                    file_name_prediction)

            acc_samples += self.batch_size
Esempio n. 2
0
def trainIters(args):
    """Train the encoder/decoder pair on clips, conditioning on the previous mask.

    Models are built in one of three ways:
      * ``args.resume``: restore weights, optimizer states and the stored
        argument namespace of ``args.model_name``. The stored namespace then
        replaces ``args``.
      * ``args.transfer``: load encoder/decoder weights from
        ``args.transfer_from``. The encoder is built from the source model's
        arguments and the decoder from the current ``args``.
      * otherwise: build fresh models from ``args``.

    Each epoch runs a ``train`` and a ``val`` pass. Every clip of up to
    ``args.length_clip`` frames is fed frame by frame to ``runIter``, together
    with the previous frame's output (ground truth for the first frame) and,
    unless ``args.only_spatial`` is set, the previous temporal hidden state.

    A checkpoint is written whenever the (optionally smoothed) validation loss
    drops by more than ``args.min_delta``. ``args.update_encoder`` is switched
    on after ``args.finetune_after`` epochs, or after ``args.patience`` epochs
    without improvement (both disabled when ``finetune_after == -1``).
    Training stops after ``args.patience_stop`` epochs without improvement.

    Raises:
        ValueError: if ``args.dataset`` is not one of the clip-based datasets
            this loop iterates ('davis2017', 'youtube', 'kittimots').
    """
    epoch_resume = 0
    # Computed from the command-line model name, before a resumed run
    # replaces ``args``.
    model_dir = os.path.join('../models/',
                             args.model_name + '_prev_inference_mask')

    if args.resume:
        # Resume training of args.model_name from its saved checkpoint.
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.model_name, args.use_gpu)

        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

        args = load_args

    elif args.transfer:
        # Initialise both networks from the weights of args.transfer_from.
        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

    else:
        encoder = FeatureExtractor(args)
        decoder = RSISMask(args)

    # Only clip-based datasets are handled by the loop below. Fail here with
    # a clear message rather than with a NameError after all the setup.
    clip_datasets = ('davis2017', 'youtube', 'kittimots')
    if args.dataset not in clip_datasets:
        raise ValueError("trainIters: unsupported dataset %r (expected one of %s)"
                         % (args.dataset, ', '.join(clip_datasets)))

    # model checkpoints will be saved here
    make_dir(model_dir)

    # Save the (possibly reloaded) parameters for future use.
    with open(os.path.join(model_dir, 'args.pkl'), 'wb') as args_file:
        pickle.dump(args, args_file)

    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    # The encoder's skip-connection parameters are optimised with the decoder.
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params,
                            args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params,
                            args.weight_decay_cnn)

    if args.resume:
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        # NOTE(review): re-wraps the restored decoder optimizer state as a
        # defaultdict, presumably so parameters without saved state get an
        # empty dict on first access - confirm.
        dec_opt.state = defaultdict(dict, dec_opt.state)

    if not args.log_term:
        print("Training logs will be saved to:",
              os.path.join(model_dir, 'train.log'))
        # Deliberately left open: all further output goes to these files.
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')

    print(args)

    # objective function for mask
    mask_siou = softIoULoss()

    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        mask_siou.cuda()

    crits = mask_siou
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1

    loaders = init_dataloaders(args)
    # Number of batches seen per split in the most recent epoch.
    num_batches = {'train': 0, 'val': 0}
    # Area ranges passed to runIter as ``resolution`` when
    # args.loss_penalization is set (tuned for (287, 950) inputs per the
    # original note).
    area_range = [[0**2, 1e5**2], [0**2, 30**2], [30**2, 90**2],
                  [90**2, 1e5**2]]
    resolution = 0

    for e in range(args.max_epoch):
        print("Epoch", e + epoch_resume)
        # store losses in lists to display average since beginning
        epoch_losses = {
            'train': {'total': [], 'iou': []},
            'val': {'total': [], 'iou': []},
        }
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': []}

        # Scheduled switch to encoder updates (finetune_after == -1 disables).
        if (e + epoch_resume >= args.finetune_after
                and not args.update_encoder and args.finetune_after != -1):
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1

        if args.loss_penalization:
            # Medium area range for the first 10 epochs, full range after.
            resolution = area_range[2] if e < 10 else area_range[0]

        # we validate after each epoch
        for split in ['train', 'val']:
            loaders[split].dataset.set_epoch(e)
            for batch_idx, (inputs, targets, seq_name,
                            starting_frame) in enumerate(loaders[split]):
                prev_hidden_temporal_list = None
                loss = None
                max_ii = min(len(inputs), args.length_clip)

                for ii in range(max_ii):
                    # Flag the clip's last frame, where the loss accumulated
                    # since the start of the clip is back-propagated.
                    last_frame = ii == max_ii - 1

                    # x: input frames; y_mask: ground-truth masks (zero-padded
                    # to a fixed number of instances); sw_mask: which y_mask
                    # entries are valid.
                    x, y_mask, sw_mask = batch_to_var(args, inputs[ii],
                                                      targets[ii])

                    # First frame is conditioned on ground truth, later
                    # frames on the previous prediction.
                    if ii == 0:
                        prev_mask = y_mask

                    loss, losses, outs, hidden_temporal_list = runIter(
                        args, encoder, decoder, x, y_mask, sw_mask,
                        resolution, crits, optims, split, loss,
                        prev_hidden_temporal_list, prev_mask, last_frame)

                    # Carry the temporal hidden state to the next frame.
                    if args.only_spatial == False:
                        prev_hidden_temporal_list = hidden_temporal_list

                    prev_mask = outs

                # store the clip's (last-frame) loss values separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])

                # print after some iterations
                if (batch_idx + 1) % args.print_every == 0:
                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])

                    te = time.time() - start
                    print("iter %d:\ttotal:%.4f\tiou:%.4f\ttime:%.4f" %
                          (batch_idx, mt, mi, te))
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()

            # One entry is appended per batch, so this equals batch_idx + 1
            # and stays defined for an empty loader.
            num_batches[split] = len(epoch_losses[split]['total'])

            # Mean losses within the epoch. The val loss may be smoothed with
            # an exponential moving average across epochs.
            if split == 'val' and args.smooth_curves:
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(
                        epoch_losses[split]['total'])
                mt_val = mt
            else:
                mt = np.mean(epoch_losses[split]['total'])

            mi = np.mean(epoch_losses[split]['iou'])

            # save train and val losses for the epoch
            total_losses['iou'].append(mi)
            total_losses['total'].append(mt)

            args.epoch_resume = e + epoch_resume

            print("Epoch %d:\ttotal:%.4f\tiou:%.4f\t(%s)" %
                  (e + epoch_resume, mt, mi, split))

        # ``mt`` now holds the (possibly smoothed) validation loss.
        if mt < (best_val_loss - args.min_delta):
            print("Saving checkpoint.")
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint_prev_inference_mask(args, encoder, decoder,
                                                enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1

        if (acc_patience > args.patience and not args.update_encoder
                and args.finetune_after != -1):
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            # NOTE(review): checkpoints in this loop are written by
            # save_checkpoint_prev_inference_mask but reloaded with
            # load_checkpoint(args.model_name) - confirm both use one location.
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)

        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
Esempio n. 3
0
    def run_eval(self):
        """Run the model over ``self.loader`` and save its predicted masks.

        The outputs depend on ``self.split``:
          * ``'val'``: the loader yields ``(inputs, targets, seq_name,
            starting_frame)``. Masks go to ``masks_sep_2assess`` (YouTube-VOS)
            or ``masks_sep_2assess-davis``.
          * any other split: the loader yields ``(inputs, seq_name,
            starting_frame)``. Masks go to ``masks_sep_2assess_val`` or
            ``masks_sep_2assess_val_davis``.

        All output directories live under ``../models/<model_name>/``.

        For every frame, the first 10 predicted masks are binarised at 0.5 and
        written as ``<frame>_instance_<t>.png``. If ``args.overlay_masks`` is
        set, the first ``num_instances`` masks are also drawn over the
        de-normalised input frame and saved as ``frame_<frame>.png`` in the
        matching ``results*`` directory. With ``self.video_mode`` the temporal
        hidden state is carried from one frame to the next.
        """
        print("Dataset is %s" % (self.dataset))
        print("Split is %s" % (self.split))

        if args.overlay_masks:
            # Overlay colours ordered by palette index. Indices 0 and 21 are
            # skipped (presumably background / void - TODO confirm).
            palette = sequence_palette()
            inv_palette = {v: k for k, v in palette.items()}
            colors = [inv_palette[id_color]
                      for id_color in range(len(inv_palette))
                      if id_color != 0 and id_color != 21]

        def make_output_dirs(masks_sep_name, results_name):
            # Create the per-split output dirs; results dir only for overlays.
            masks_sep_dir = os.path.join('../models', args.model_name,
                                         masks_sep_name)
            make_dir(masks_sep_dir)
            results_dir = None
            if args.overlay_masks:
                results_dir = os.path.join('../models', args.model_name,
                                           results_name)
                make_dir(results_dir)
            return masks_sep_dir, results_dir

        def load_json(path):
            # Parse a JSON file and close it (the handle used to leak).
            with open(path) as json_file:
                return json.load(json_file)

        def davis_num_instances(seq):
            # Count object ids in the sequence's first DAVIS 2017 annotation,
            # excluding id 0 and id 255.
            with Image.open('../../databases/DAVIS2017/Annotations/480p/' +
                            seq + '/00000.png') as annotation:
                instance_ids = sorted(np.unique(annotation))
            if not instance_ids[0]:
                instance_ids = instance_ids[1:]
            if instance_ids and instance_ids[-1] == 255:
                instance_ids = instance_ids[:-1]
            return len(instance_ids)

        def save_instance_masks(outs, height, width, out_dir, frame_idx):
            # Binarise (> 0.5) the first 10 predicted masks and save each.
            # NOTE(review): 10 is hard-coded; presumably the maximum number of
            # predicted instances - confirm against outs.size(1).
            for t in range(10):
                mask_pred = torch.squeeze(outs[0, t, :]).cpu().numpy()
                mask_pred = np.reshape(mask_pred, (height, width))
                mask2assess = np.zeros((height, width))
                mask2assess[mask_pred > 0.5] = 255
                toimage(mask2assess, cmin=0, cmax=255).save(
                    out_dir + '%05d_instance_%02d.png' % (frame_idx, t))

        def save_overlay(x, outs, height, width, num_instances, figname):
            # Draw the first num_instances masks over the input frame.
            # Undo the per-channel mean/std input normalisation first.
            frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
            frame_img = np.transpose(frame_img, (1, 2, 0))
            mean = np.array([0.485, 0.456, 0.406])
            std = np.array([0.229, 0.224, 0.225])
            frame_img = np.clip(std * frame_img + mean, 0, 1)
            # One figure per frame. Previously two were created and only one
            # closed, leaking a matplotlib figure for every saved frame.
            plt.figure()
            plt.axis('off')
            plt.imshow(frame_img)
            for t in range(num_instances):
                mask_pred = torch.squeeze(outs[0, t, :]).cpu().numpy()
                mask_pred = np.reshape(mask_pred, (height, width))
                color_img = np.ones((height, width, 3))
                color_mask = np.array(colors[t]) / 255.0
                for c in range(3):
                    color_img[:, :, c] = color_mask[c]
                # The mask values scaled by 0.7 form the alpha channel.
                plt.gca().imshow(np.dstack((color_img, mask_pred * 0.7)))
            plt.savefig(figname, bbox_inches='tight')
            plt.close()

        if self.split == 'val':

            if args.dataset == 'youtube':
                masks_sep_dir, results_dir = make_output_dirs(
                    'masks_sep_2assess', 'results')
                data = load_json(
                    '../../databases/YouTubeVOS/train/train-val-meta.json')
            else:
                masks_sep_dir, results_dir = make_output_dirs(
                    'masks_sep_2assess-davis', 'results-davis')

            for inputs, targets, seq_name, starting_frame in self.loader:

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)

                base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                make_dir(base_dir_masks_sep)

                if args.overlay_masks:
                    base_dir = results_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir)

                for ii in range(max_ii):
                    # x: input frames (ground-truth masks are not used here).
                    x, _, _ = batch_to_var(args, inputs[ii], targets[ii])
                    frame_idx = starting_frame[0] + ii

                    print(seq_name[0] + '/' + '%05d' % frame_idx)

                    outs, hidden_temporal_list = test(
                        args, self.encoder, self.decoder, x,
                        prev_hidden_temporal_list)

                    if args.dataset == 'youtube':
                        num_instances = len(
                            data['videos'][seq_name[0]]['objects'])
                    else:
                        # NOTE(review): fixed to 1; a commented-out
                        # alternative counted the valid entries of sw_mask.
                        num_instances = 1

                    height, width = x.data.size()[-2:]
                    save_instance_masks(outs, height, width,
                                        base_dir_masks_sep, frame_idx)

                    if args.overlay_masks:
                        save_overlay(x, outs, height, width, num_instances,
                                     base_dir + 'frame_%02d.png' % frame_idx)

                    if self.video_mode:
                        prev_hidden_temporal_list = hidden_temporal_list

        else:

            if args.dataset == 'youtube':
                masks_sep_dir, results_dir = make_output_dirs(
                    'masks_sep_2assess_val', 'results_val')
                data = load_json('../../databases/YouTubeVOS/val/meta.json')
            else:
                masks_sep_dir, results_dir = make_output_dirs(
                    'masks_sep_2assess_val_davis', 'results_val_davis')

            for inputs, seq_name, starting_frame in self.loader:

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)

                for ii in range(max_ii):
                    # x: input frames (no ground truth for this split).
                    x = batch_to_var_test(args, inputs[ii])
                    frame_idx = starting_frame[0] + ii

                    print(seq_name[0] + '/' + '%05d' % frame_idx)

                    # The instance count is computed once, at the first frame.
                    if ii == 0:
                        if args.dataset == 'youtube':
                            num_instances = len(
                                data['videos'][seq_name[0]]['objects'])
                        else:
                            num_instances = davis_num_instances(seq_name[0])

                    outs, hidden_temporal_list = test(
                        args, self.encoder, self.decoder, x,
                        prev_hidden_temporal_list)

                    base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir_masks_sep)

                    if args.overlay_masks:
                        base_dir = results_dir + '/' + seq_name[0] + '/'
                        make_dir(base_dir)

                    height, width = x.data.size()[-2:]
                    save_instance_masks(outs, height, width,
                                        base_dir_masks_sep, frame_idx)

                    if args.overlay_masks:
                        save_overlay(x, outs, height, width, num_instances,
                                     base_dir + 'frame_%02d.png' % frame_idx)

                    if self.video_mode:
                        prev_hidden_temporal_list = hidden_temporal_list
Esempio n. 4
0
    def create_figures(self):
        """Export Cityscapes-style instance predictions for the validation set.

        For each image, each of the ``self.T`` predicted masks is thresholded
        at ``args.mask_th`` and reduced to its largest connected component.
        The result is resized to the image resolution and saved once per
        candidate class as ``<image>_<k>.png`` under
        ``<results_dir>/<model_name>_masks``.

        A ``<image>.txt`` file in the results directory gets one line per
        saved mask: ``<mask path> <class id> <score>``. The class id comes
        from ``class_ids`` and the score is ``class_scores[i + 1]`` (index 0
        skipped, presumably background - TODO confirm) multiplied by the
        first entry of the time step's ``stop_probs`` row.

        Time steps whose thresholded mask is empty produce no entry.
        Previously they raised NameError at the first step, or silently
        reused the previous step's component label.
        """
        acc_samples = 0
        results_dir = os.path.join('../models', args.model_name,
                                   args.model_name + '_results')

        make_dir(results_dir)
        # Relative to results_dir; written into the txt files as-is.
        masks_dir = args.model_name + '_masks'
        abs_masks_dir = os.path.join(results_dir, masks_dir)
        make_dir(abs_masks_dir)
        # Function-call form prints the same text on Python 2 and 3.
        print("Creating annotations for cityscapes validation...")

        # Cityscapes label id reported for class score index i + 1.
        class_ids = [24, 25, 26, 27, 28, 31, 32, 33]

        for inputs, targets in self.loader:
            x, _, _, _, _ = batch_to_var(self.args, inputs, targets)
            out_masks, out_scores, stop_probs = test(self.args, self.encoder,
                                                     self.decoder, x)

            # The final batch may be shorter than batch_size: never index past
            # the end of sample_list.
            num_in_batch = min(self.batch_size,
                               len(self.sample_list) - acc_samples)
            for sample in range(num_in_batch):

                sample_path = self.sample_list[sample + acc_samples]
                # Read only to recover the original image resolution.
                im = scipy.misc.imread(sample_path.split('.')[0] + '.png')
                h, w = im.shape[0], im.shape[1]

                sample_idx = sample_path.split('/')[-1].split('.')[0]

                img_masks = out_masks[sample]
                class_scores = out_scores[sample]
                stop_scores = stop_probs[sample]
                instance_id = 0

                with open(os.path.join(results_dir, sample_idx + '.txt'),
                          'w') as results_file:
                    for time_step in range(self.T):
                        binary = img_masks[time_step].cpu().numpy() > args.mask_th
                        h_mask, w_mask = binary.shape[0], binary.shape[1]
                        labeled_blobs = measure.label(binary,
                                                      background=0).flatten()

                        # Pixel count per connected component, label 0 excluded.
                        blob_sizes = Counter(labeled_blobs)
                        blob_sizes.pop(0, None)
                        if not blob_sizes:
                            continue
                        # First component of maximal size, in Counter order.
                        max_label = max(blob_sizes, key=blob_sizes.get)

                        # build mask from the largest connected component
                        segmentation = (labeled_blobs == max_label).astype("uint8")
                        mask = segmentation.reshape([h_mask, w_mask]) * 255
                        mask = scipy.misc.imresize(mask, [h, w])

                        class_scores_mask = class_scores[time_step].cpu().numpy()
                        objectness = stop_scores[time_step].cpu().numpy()[0]

                        for i in range(len(class_scores_mask) - 1):
                            name_instance = sample_idx + '_' + str(
                                instance_id) + '.png'
                            pred_class_score = class_scores_mask[i + 1] * objectness
                            scipy.misc.imsave(
                                os.path.join(abs_masks_dir, name_instance), mask)
                            results_file.write(masks_dir + '/' + name_instance +
                                               ' ' + str(class_ids[i]) + ' ' +
                                               str(pred_class_score) + '\n')
                            instance_id += 1

            acc_samples += self.batch_size
Esempio n. 5
0
    def run_eval(self):
        print("Dataset is %s" % (self.dataset))
        print("Split is %s" % (self.split))

        if args.overlay_masks:

            colors = []
            palette = sequence_palette()
            inv_palette = {}
            for k, v in palette.items():
                inv_palette[v] = k
            num_colors = len(inv_palette.keys())
            for id_color in range(num_colors):
                if id_color == 0 or id_color == 21:
                    continue
                c = inv_palette[id_color]
                colors.append(c)

        if self.split == 'val':

            if args.dataset == 'youtube':

                masks_sep_dir = os.path.join('../models', args.model_name,
                                             'masks_sep_2assess')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name,
                                               'results')
                    make_dir(results_dir)

                json_data = open(
                    '../../databases/YouTubeVOS/train/train-val-meta.json')
                data = json.load(json_data)

            else:  #args.dataset == 'davis2017'

                import lmdb
                from misc.config import cfg

                masks_sep_dir = os.path.join('../models', args.model_name,
                                             'masks_sep_2assess-davis')
                make_dir(masks_sep_dir)

                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name,
                                               'results-davis')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None

            for batch_idx, (inputs, targets, seq_name,
                            starting_frame) in enumerate(self.loader):

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)

                if args.overlay_masks:
                    base_dir = results_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir)

                if args.dataset == 'davis2017':
                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(
                                key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    frame_names = sorted(_files)

                for ii in range(max_ii):

                    #start_time = time.time()
                    #                x: input images (N consecutive frames from M different sequences)
                    #                y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances)
                    #                sw_mask: this mask indicates which masks from y_mask are valid
                    x, y_mask, sw_mask = batch_to_var(args, inputs[ii],
                                                      targets[ii])

                    if ii == 0:
                        prev_mask = y_mask

                    #from one frame to the following frame the prev_hidden_temporal_list is updated.
                    outs, hidden_temporal_list = test_prev_mask(
                        args, self.encoder, self.decoder, x,
                        prev_hidden_temporal_list, prev_mask)

                    #end_inference_time = time.time()
                    #print("inference time: %.3f" %(end_inference_time-start_time))

                    if args.dataset == 'youtube':
                        num_instances = len(
                            data['videos'][seq_name[0]]['objects'])
                    else:
                        num_instances = int(
                            torch.sum(sw_mask.data).data.cpu().numpy())

                    base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir_masks_sep)

                    x_tmp = x.data.cpu().numpy()
                    height = x_tmp.shape[-2]
                    width = x_tmp.shape[-1]
                    for t in range(num_instances):
                        mask_pred = (torch.squeeze(outs[0,
                                                        t, :])).cpu().numpy()
                        mask_pred = np.reshape(mask_pred, (height, width))
                        indxs_instance = np.where(mask_pred > 0.5)
                        mask2assess = np.zeros((height, width))
                        mask2assess[indxs_instance] = 255
                        if args.dataset == 'youtube':
                            toimage(mask2assess, cmin=0,
                                    cmax=255).save(base_dir_masks_sep +
                                                   '%05d_instance_%02d.png' %
                                                   (starting_frame[0] + ii, t))
                        else:
                            toimage(mask2assess, cmin=0,
                                    cmax=255).save(base_dir_masks_sep +
                                                   frame_names[ii] +
                                                   '_instance_%02d.png' % (t))

                    #end_saving_masks_time = time.time()
                    #print("inference + saving masks time: %.3f" %(end_saving_masks_time - start_time))
                    if args.dataset == 'youtube':
                        print(seq_name[0] + '/' + '%05d' %
                              (starting_frame[0] + ii))
                    else:
                        print(seq_name[0] + '/' + frame_names[ii])

                    if args.overlay_masks:

                        frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
                        frame_img = np.transpose(frame_img, (1, 2, 0))
                        mean = np.array([0.485, 0.456, 0.406])
                        std = np.array([0.229, 0.224, 0.225])
                        frame_img = std * frame_img + mean
                        frame_img = np.clip(frame_img, 0, 1)
                        plt.figure()
                        plt.axis('off')
                        plt.figure()
                        plt.axis('off')
                        plt.imshow(frame_img)

                        for t in range(num_instances):

                            mask_pred = (torch.squeeze(
                                outs[0, t, :])).cpu().numpy()
                            mask_pred = np.reshape(mask_pred, (height, width))

                            ax = plt.gca()
                            tmp_img = np.ones(
                                (mask_pred.shape[0], mask_pred.shape[1], 3))
                            color_mask = np.array(colors[t]) / 255.0
                            for i in range(3):
                                tmp_img[:, :, i] = color_mask[i]
                            ax.imshow(np.dstack((tmp_img, mask_pred * 0.7)))

                        if args.dataset == 'youtube':
                            figname = base_dir + 'frame_%02d.png' % (
                                starting_frame[0] + ii)
                        else:
                            figname = base_dir + frame_names[ii] + '.png'

                        plt.savefig(figname, bbox_inches='tight')
                        plt.close()

                    if self.video_mode:
                        if args.only_spatial == False:
                            prev_hidden_temporal_list = hidden_temporal_list
                        if ii > 0:
                            prev_mask = outs
                        else:
                            prev_mask = y_mask

                    del outs, hidden_temporal_list, x, y_mask, sw_mask

        else:

            if args.dataset == 'youtube':

                masks_sep_dir = os.path.join('../models', args.model_name,
                                             'masks_sep_2assess_val')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name,
                                               'results_val')
                    make_dir(results_dir)

                json_data = open('../../databases/YouTubeVOS/valid/meta.json')
                data = json.load(json_data)

            else:  #args.dataset == 'davis2017'

                import lmdb
                from misc.config import cfg

                masks_sep_dir = os.path.join('../models', args.model_name,
                                             'masks_sep_2assess_val_davis')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name,
                                               'results_val_davis')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None

            for batch_idx, (inputs, seq_name,
                            starting_frame) in enumerate(self.loader):

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)

                if args.overlay_masks:
                    base_dir = results_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir)

                if args.dataset == 'youtube':

                    seq_data = data['videos'][seq_name[0]]['objects']
                    frame_names = []
                    frame_names_with_new_objects = []
                    instance_ids = []

                    for obj_id in seq_data.keys():
                        instance_ids.append(int(obj_id))
                        frame_names_with_new_objects.append(
                            seq_data[obj_id]['frames'][0])
                        for frame_name in seq_data[obj_id]['frames']:
                            if frame_name not in frame_names:
                                frame_names.append(frame_name)

                    frame_names.sort()
                    frame_names_with_new_objects_idxs = []
                    for kk in range(len(frame_names_with_new_objects)):
                        new_frame_idx = frame_names.index(
                            frame_names_with_new_objects[kk])
                        frame_names_with_new_objects_idxs.append(new_frame_idx)

                else:  #davis2017

                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(
                                key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    frame_names = sorted(_files)

                for ii in range(max_ii):

                    #                x: input images (N consecutive frames from M different sequences)
                    #                y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances)
                    #                sw_mask: this mask indicates which masks from y_mask are valid
                    x = batch_to_var_test(args, inputs[ii])

                    print(seq_name[0] + '/' + frame_names[ii])

                    if ii == 0:

                        frame_name = frame_names[0]
                        if args.dataset == 'youtube':
                            annotation = Image.open(
                                '../../databases/YouTubeVOS/valid/Annotations/'
                                + seq_name[0] + '/' + frame_name + '.png')
                            annot = imresize(annotation, (256, 448),
                                             interp='nearest')
                        else:  #davis2017
                            annotation = Image.open(
                                '../../databases/DAVIS2017/Annotations/480p/' +
                                seq_name[0] + '/' + frame_name + '.png')
                            instance_ids = sorted(np.unique(annotation))
                            instance_ids = instance_ids if instance_ids[
                                0] else instance_ids[1:]
                            if len(instance_ids) > 0:
                                instance_ids = instance_ids[:-1] if instance_ids[
                                    -1] == 255 else instance_ids
                            annot = imresize(annotation, (240, 427),
                                             interp='nearest')

                        annot = np.expand_dims(annot, axis=0)
                        annot = torch.from_numpy(annot)
                        annot = annot.float()
                        annot = annot.numpy().squeeze()
                        annot = annot_from_mask(annot, instance_ids)
                        prev_mask = annot
                        prev_mask = np.expand_dims(prev_mask, axis=0)
                        prev_mask = torch.from_numpy(prev_mask)
                        y_mask = Variable(prev_mask.float(),
                                          requires_grad=False)
                        prev_mask = y_mask.cuda()
                        del annot

                    if args.dataset == 'youtube':
                        if ii > 0 and ii in frame_names_with_new_objects_idxs:

                            frame_name = frame_names[ii]
                            annotation = Image.open(
                                '../../databases/YouTubeVOS/valid/Annotations/'
                                + seq_name[0] + '/' + frame_name + '.png')
                            annot = imresize(annotation, (256, 448),
                                             interp='nearest')
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = annot.float()
                            annot = annot.numpy().squeeze()
                            new_instance_ids = np.unique(annot)[1:]
                            annot = annot_from_mask(annot, new_instance_ids)
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = Variable(annot.float(),
                                             requires_grad=False)
                            annot = annot.cuda()
                            for kk in new_instance_ids:
                                prev_mask[:, int(kk - 1), :] = annot[:,
                                                                     int(kk -
                                                                         1), :]
                            del annot

                    #from one frame to the following frame the prev_hidden_temporal_list is updated.
                    outs, hidden_temporal_list = test_prev_mask(
                        args, self.encoder, self.decoder, x,
                        prev_hidden_temporal_list, prev_mask)

                    base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir_masks_sep)

                    x_tmp = x.data.cpu().numpy()
                    height = x_tmp.shape[-2]
                    width = x_tmp.shape[-1]

                    for t in range(len(instance_ids)):
                        mask_pred = (torch.squeeze(outs[0,
                                                        t, :])).cpu().numpy()
                        mask_pred = np.reshape(mask_pred, (height, width))
                        indxs_instance = np.where(mask_pred > 0.5)
                        mask2assess = np.zeros((height, width))
                        mask2assess[indxs_instance] = 255
                        toimage(mask2assess, cmin=0,
                                cmax=255).save(base_dir_masks_sep +
                                               frame_names[ii] +
                                               '_instance_%02d.png' % (t))

                    if args.overlay_masks:

                        frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
                        frame_img = np.transpose(frame_img, (1, 2, 0))
                        mean = np.array([0.485, 0.456, 0.406])
                        std = np.array([0.229, 0.224, 0.225])
                        frame_img = std * frame_img + mean
                        frame_img = np.clip(frame_img, 0, 1)
                        plt.figure()
                        plt.axis('off')
                        plt.figure()
                        plt.axis('off')
                        plt.imshow(frame_img)

                        for t in range(len(instance_ids)):

                            mask_pred = (torch.squeeze(
                                outs[0, t, :])).cpu().numpy()
                            mask_pred = np.reshape(mask_pred, (height, width))
                            ax = plt.gca()
                            tmp_img = np.ones(
                                (mask_pred.shape[0], mask_pred.shape[1], 3))
                            color_mask = np.array(colors[t]) / 255.0
                            for i in range(3):
                                tmp_img[:, :, i] = color_mask[i]
                            ax.imshow(np.dstack((tmp_img, mask_pred * 0.7)))

                        figname = base_dir + frame_names[ii] + '.png'
                        plt.savefig(figname, bbox_inches='tight')
                        plt.close()

                    if self.video_mode:
                        if args.only_spatial == False:
                            prev_hidden_temporal_list = hidden_temporal_list
                        if ii > 0:
                            prev_mask = outs
                        del x, hidden_temporal_list, outs
Esempio n. 6
0
    def _create_json(self):
        """Run the model over ``self.loader`` and collect instance annotations.

        For every sample, each predicted mask whose stop/objectness score is
        not below ``self.args.stop_th`` is resized to the original image size
        (``resize_mask``). It is then turned into one annotation per non-EOS
        class (``create_annotation``), scored as ``class_prob * objectness``.
        When ``self.display`` is set, the masks selected for display are
        overlaid on the image and saved under
        ``../models/<model_name>/<model_name>_figs_<eval_split>/``.

        Returns:
            list: every non-``None`` annotation produced by
            ``create_annotation`` across all samples, masks and classes.

        Raises:
            ValueError: if ``self.dataset`` is not ``'pascal'``,
                ``'cityscapes'`` or ``'leaves'`` (no image path can be built).
        """
        predictions = list()
        acc_samples = 0
        # Parenthesized form prints the same text under Python 2 and 3.
        print("Creating annotations...")

        figures_dir = None
        if self.display:
            # Loop-invariant: build (and create) the figure directory once.
            figures_dir = os.path.join(
                '../models', self.args.model_name,
                self.args.model_name + '_figs_' + self.args.eval_split)
            make_dir(figures_dir)

        for inputs, targets in self.loader:

            x, _, _, _, _ = batch_to_var(self.args, inputs, targets)

            out_masks, out_scores, stop_probs = test(self.args, self.encoder,
                                                     self.decoder, x)

            out_scores = out_scores.cpu().numpy()
            stop_scores = stop_probs.cpu().numpy()
            out_masks = out_masks.cpu().numpy()
            out_classes = np.argmax(out_scores, axis=-1)

            for s in range(out_masks.shape[0]):
                this_pred = list()
                sample_idx = self.sample_list[s + acc_samples]

                if self.args.dataset == 'pascal':
                    ignore_mask = self.ignoremasks[sample_idx]
                else:
                    ignore_mask = None

                if self.dataset == 'pascal':
                    image_dir = os.path.join(self.args.pascal_dir,
                                             'JPEGImages',
                                             sample_idx + '.jpg')
                elif self.dataset == 'cityscapes':
                    sample_idx = sample_idx.split('.')[0]
                    image_dir = sample_idx + '.png'
                elif self.dataset == 'leaves':
                    image_dir = sample_idx
                else:
                    raise ValueError(
                        "Unsupported dataset for _create_json: %r" %
                        (self.dataset,))

                im = imread(image_dir)
                h = im.shape[0]
                w = im.shape[1]

                for i in range(out_masks.shape[1]):
                    objectness = stop_scores[s][i][0]
                    if objectness < self.args.stop_th:
                        continue
                    pred_mask = out_masks[s][i]
                    # class with max confidence, used to pick the mask to display
                    if self.args.class_th == 0.0:
                        max_class = 1
                    else:
                        max_class = out_classes[s][i]

                    # resize the mask to image size to create the annotation
                    pred_mask, is_valid, raw_pred_mask = resize_mask(
                        self.args, pred_mask, h, w, ignore_mask)

                    # For evaluation the mask is repeated with every class
                    # probability; class 0 is EOS and is skipped.
                    for cls_id in range(1, len(self.class_names)):

                        pred_class_score = out_scores[s][i][cls_id]
                        pred_class_score_mod = pred_class_score * objectness

                        ann = create_annotation(self.args, sample_idx,
                                                pred_mask, cls_id,
                                                pred_class_score_mod,
                                                self.class_names, is_valid)
                        if ann is None:
                            continue

                        if self.dataset == 'leaves':
                            # strict '>' kept on purpose: a mask exactly at
                            # stop_th is evaluated but not displayed
                            if objectness > self.args.stop_th:
                                this_pred.append(ann)
                        elif (cls_id == max_class and
                              pred_class_score_mod >= self.args.class_th):
                            # for display keep only the max-confidence class,
                            # built from the raw (unfiltered) resized mask
                            ann_save = create_annotation(
                                self.args, sample_idx, raw_pred_mask,
                                cls_id, pred_class_score_mod,
                                self.class_names, is_valid)
                            this_pred.append(ann_save)

                        predictions.append(ann)

                if self.display:
                    # Open exactly one figure so the plt.close() below
                    # releases it; an extra figure would never be closed.
                    plt.figure()
                    plt.axis('off')
                    plt.imshow(im)
                    display_masks(this_pred,
                                  self.colors,
                                  im_height=im.shape[0],
                                  im_width=im.shape[1],
                                  no_display_text=self.args.no_display_text)

                    if self.dataset in ('cityscapes', 'leaves'):
                        # keep only the file name; the directory is figures_dir
                        sample_idx = sample_idx.split('/')[-1]
                    figname = os.path.join(figures_dir, sample_idx)
                    plt.savefig(figname, bbox_inches='tight')
                    plt.close()

            acc_samples += out_masks.shape[0]

        return predictions
Esempio n. 7
0
    def run_eval(self):
        print("Dataset is %s" % (self.dataset))
        print("Split is %s" % (self.split))

        if args.overlay_masks:

            colors = []
            palette = sequence_palette()
            inv_palette = {}
            for k, v in palette.items():
                inv_palette[v] = k
            num_colors = len(inv_palette.keys())
            for id_color in range(num_colors):
                if id_color == 0 or id_color == 21:
                    continue
                c = inv_palette[id_color]
                colors.append(c)

        if self.split == 'val-inference':

            if args.dataset == 'youtube':

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results')
                    make_dir(results_dir)

                json_data = open('../../databases/YouTubeVOS/train/train-val-meta.json')
                data = json.load(json_data)

            elif args.dataset == 'davis2017':

                import lmdb
                from misc.config import cfg

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-davis')
                make_dir(masks_sep_dir)

                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results-davis')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None
            else:

                import lmdb
                from misc.config_kittimots import cfg

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-kitti')
                make_dir(masks_sep_dir)

                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results-kitti')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None

            for batch_idx, (inputs, targets, seq_name, starting_frame, frames_with_new_ids) in enumerate(self.loader):

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)
                frames_with_new_ids = np.array(frames_with_new_ids)
                #print('Variable max_ii')
                #print(max_ii)

                if args.overlay_masks:
                    base_dir = results_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir)

                if args.dataset == 'davis2017':
                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    frame_names = sorted(_files)

                if args.dataset == 'kittimots':
                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    frame_names = sorted(_files)  # llistat de frames d'una seqüència de video

                dict_outs = {}
                #print("NEW OBJECTS FRAMES", frames_with_new_ids)

                # make a dir of results for each instance
                '''for t in range(args.maxseqlen):
                    base_dir_2 = results_dir + '/' + seq_name[0] + '/' + str(t)
                    make_dir(base_dir_2)'''

                for ii in range(max_ii):  # iteració sobre els frames/clips amb dimensio lenght_clip

                    # start_time = time.time()
                    #                x: input images (N consecutive frames from M different sequences)
                    #                y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances)
                    #                sw_mask: this mask indicates which masks from y_mask are valid
                    x, y_mask, sw_mask = batch_to_var(args, inputs[ii], targets[ii])

                    #print(seq_name[0] + '/' + frame_names[ii])

                    if ii == 0:
                        #one-shot approach, information about the first frame of the sequende
                        prev_mask = y_mask

                        #list of the first instances that appear on the sequence and update of the dictionary
                        annotation = Image.open(
                            '../../databases/KITTIMOTS/Annotations/' + seq_name[0] + '/' + frame_names[
                                ii] + '.png').convert('P')
                        annot = np.expand_dims(annotation, axis=0)
                        annot = torch.from_numpy(annot)
                        annot = annot.float()
                        instance_ids = np.unique(annot)
                        for i in instance_ids[1:]:
                            dict_outs.update({str(int(i-1)):int(i)})
                        #instances = len(instance_ids)-1


                    #one-shot approach, add GT information when a new instance appears on the video sequence
                    if args.dataset == 'kittimots':
                        if ii > 0 and ii in frames_with_new_ids:
                            frame_name = frame_names[ii]
                            annotation = Image.open(
                                '../../databases/KITTIMOTS/Annotations/' + seq_name[
                                    0] + '/' + frame_name + '.png').convert('P')

                            #annot = imresize(annotation, (256, 448), interp='nearest')
                            annot = imresize(annotation, (287,950), interp='nearest')
                            #annot = imresize(annotation, (412, 723), interp='nearest')
                            #annot = imresize(annotation, (178, 590), interp='nearest')
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = annot.float()
                            annot = annot.numpy().squeeze()

                            new_instance_ids = np.setdiff1d(np.unique(annot), instance_ids)
                            annot = annot_from_mask(annot, new_instance_ids)
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = Variable(annot.float(), requires_grad=False)
                            annot = annot.cuda()
                            #adding only the information of the new instance after the last active branch
                            for kk in new_instance_ids:
                                if dict_outs:
                                    last = int(list(dict_outs.keys())[-1])
                                else:
                                    #if the dictionary is empty
                                    last = -1
                                prev_mask[:, int(last+1), :] = annot[:, int(kk - 1), :]
                                dict_outs.update({str(last+1):int(kk)})
                            del annot
                            #update the list of instances that have appeared on the video sequence
                            if len(new_instance_ids) > 0:
                                instance_ids = np.append(instance_ids, new_instance_ids)
                            #instances = instances + len(new_instance_ids)

                    # from one frame to the following frame the prev_hidden_temporal_list is updated.
                    outs, hidden_temporal_list = test_prev_mask(args, self.encoder, self.decoder, x,
                                                                prev_hidden_temporal_list, prev_mask)

                    # end_inference_time = time.time()
                    # print("inference time: %.3f" %(end_inference_time-start_time))


                    if args.dataset == 'youtube':
                        num_instances = len(data['videos'][seq_name[0]]['objects'])
                    else:
                        num_instances = int(torch.sum(sw_mask.data).data.cpu().numpy())
                    num_instances = args.maxseqlen

                    base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir_masks_sep)

                    x_tmp = x.data.cpu().numpy()
                    height = x_tmp.shape[-2]
                    width = x_tmp.shape[-1]

                    '''out_stop = outs[1]
                    outs = outs[0]
                    for m in range(len(out_stop[0])):
                            print(m)
                            print(out_stop[0][m])'''


                    #print("OUT STOP: ", out_stop[0])

                    outs_masks = np.zeros((args.maxseqlen,), dtype=int)

                    for t in range(num_instances):

                        mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy()
                        mask_pred = np.reshape(mask_pred, (height, width))
                        indxs_instance = np.where(mask_pred > 0.5)
                        '''indxs_instance = np.where((0.6 > mask_pred) & (mask_pred>= 0.5))
                        indxs_instance_1 = np.where((0.7 > mask_pred) & (mask_pred>= 0.6))
                        indxs_instance_2 = np.where((0.8 > mask_pred) & (mask_pred>= 0.7))
                        indxs_instance_3 = np.where((0.9 > mask_pred) & (mask_pred>= 0.8))
                        indxs_instance_4 = np.where((0.999> mask_pred) & (mask_pred>= 0.9))
                        indxs_instance_5 = np.where((0.9999 > mask_pred) & (mask_pred>= 0.999))'''



                        mask2assess = np.zeros((height, width))
                        mask2assess[indxs_instance] = 255

                        ''' mask2assess[indxs_instance] = 40
                        mask2assess[indxs_instance_1] = 80
                        mask2assess[indxs_instance_2] = 120
                        mask2assess[indxs_instance_3] = 160
                        mask2assess[indxs_instance_4] = 200
                        mask2assess[indxs_instance_5] = 255'''


                        if str(t) in dict_outs:
                            i = dict_outs[str(t)]
                        else:
                            break

                        if args.dataset == 'youtube':
                            toimage(mask2assess, cmin=0, cmax=255).save(
                                base_dir_masks_sep + '%05d_instance_%02d.png' % (starting_frame[0] + ii, i))
                        else:
                            toimage(mask2assess, cmin=0, cmax=255).save(
                                base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (i))
                        #create vector of predictions, gives information about which branches are active
                        if len(indxs_instance[0]) != 0:
                            outs_masks[t] = 1
                        else:
                            outs_masks[t] = 0

                    outs = outs.cpu().numpy()
                    #print("INS: ", outs_masks)
                    #print(json.dumps(dict_outs))

                    #delete spurious branches
                    last_position = last_ocurrence(outs_masks, 1) + 1
                    while len(dict_outs) < last_position:
                        for n in range(args.maxseqlen):
                            if outs_masks[n] == 1 and str(n) not in dict_outs:
                                outs = np.delete(outs, n, axis=1)
                                outs_masks = np.delete(outs_masks, n)
                                del hidden_temporal_list[n]
                                z = np.zeros((height * width))
                                outs = np.insert(outs, args.maxseqlen - 1, z, axis=1)
                                hidden_temporal_list.append(None)
                                outs_masks = np.append(outs_masks, 0)
                        last_position = last_ocurrence(outs_masks, 1) + 1

                    instances = sum(outs_masks)  # number of active branches
                    #delete branches of instances that disappear and rearrange
                    for n in range(args.maxseqlen):

                        while outs_masks[n] == 0 and n < instances:

                            outs = np.delete(outs, n, axis=1 )
                            outs_masks = np.delete(outs_masks, n)
                            del hidden_temporal_list[n]
                            z = np.zeros((height * width))
                            outs = np.insert(outs, args.maxseqlen-1, z, axis=1)
                            hidden_temporal_list.append(None)
                            outs_masks = np.append(outs_masks, 0)


                            #update dictionary by shifting entries
                            for m in range(len(dict_outs)-(n+1)):
                                value = dict_outs[str(m + n + 1)]
                                dict_outs.update({str(n + m): value})
                            #print(json.dumps(dict_outs))
                            last = int(list(dict_outs.keys())[-1])
                            del dict_outs[str(last)]

                    #an instance has disappeared, update dictionary
                    while len(dict_outs) > sum(outs_masks):
                        last = int(list(dict_outs.keys())[-1])
                        del dict_outs[str(last)]

                    outs = torch.from_numpy(outs)
                    outs = outs.cuda()
                    #print("OUTS: ", outs_masks)
                    #print(json.dumps(dict_outs))



                    # end_saving_masks_time = time.time()
                    # print("inference + saving masks time: %.3f" %(end_saving_masks_time - start_time))
                    if args.dataset == 'youtube':
                        print(seq_name[0] + '/' + '%05d' % (starting_frame[0] + ii))
                    else:
                        print(seq_name[0] + '/' + frame_names[ii])

                    if args.overlay_masks:

                        frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
                        frame_img = np.transpose(frame_img, (1, 2, 0))
                        mean = np.array([0.485, 0.456, 0.406])
                        std = np.array([0.229, 0.224, 0.225])
                        frame_img = std * frame_img + mean
                        frame_img = np.clip(frame_img, 0, 1)
                        plt.figure();
                        plt.axis('off')
                        plt.figure();
                        plt.axis('off')
                        plt.imshow(frame_img)
                        #print("INSTANCES: ", instances)

                        for t in range(instances):

                            mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy()
                            mask_pred = np.reshape(mask_pred, (height, width))
                            ax = plt.gca()
                            tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3))
                            if str(t) in dict_outs:
                                color_mask = np.array(colors[dict_outs[str(t)]]) / 255.0
                            else:
                                #color_mask = np.array(colors[0]) / 255.0
                                break
                            for i in range(3):
                                tmp_img[:, :, i] = color_mask[i]
                            ax.imshow(np.dstack((tmp_img, mask_pred * 0.7)))

                        if args.dataset == 'youtube':
                            figname = base_dir + 'frame_%02d.png' % (starting_frame[0] + ii)
                        else:
                            figname = base_dir + frame_names[ii] + '.png'

                        plt.savefig(figname, bbox_inches='tight')
                        plt.close()

                    #Print a video for each instance
                    '''for t in range(instances):

                        if args.overlay_masks:

                            frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
                            frame_img = np.transpose(frame_img, (1, 2, 0))
                            mean = np.array([0.485, 0.456, 0.406])
                            std = np.array([0.229, 0.224, 0.225])
                            frame_img = std * frame_img + mean
                            frame_img = np.clip(frame_img, 0, 1)
                            plt.figure();
                            plt.axis('off')
                            plt.figure();
                            plt.axis('off')
                            plt.imshow(frame_img)


                            mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy()
                            mask_pred = np.reshape(mask_pred, (height, width))
                            ax = plt.gca()
                            tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3))
                            if str(t) in dict_outs:
                                color_mask = np.array(colors[dict_outs[str(t)]]) / 255.0
                            else:
                                #color_mask = np.array(colors[0]) / 255.0
                                break
                            for i in range(3):
                                tmp_img[:, :, i] = color_mask[i]
                            ax.imshow(np.dstack((tmp_img, mask_pred * 0.7)))

                            figname = base_dir + '/' + str(dict_outs[str(t)]) + '/' + frame_names[ii] + '.png'

                            plt.savefig(figname, bbox_inches='tight')
                            plt.close()'''

                    if self.video_mode:
                        if args.only_spatial == False:
                            prev_hidden_temporal_list = hidden_temporal_list
                        if ii > 0:
                            prev_mask = outs
                        else:
                            prev_mask = y_mask

                    del outs, hidden_temporal_list, x, y_mask, sw_mask

        else:

            if args.dataset == 'youtube':

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results_val')
                    make_dir(results_dir)

                json_data = open('../../databases/YouTubeVOS/val/meta.json')
                data = json.load(json_data)

            elif args.dataset == 'davis2017':

                import lmdb
                from misc.config import cfg

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val_davis')
                make_dir(masks_sep_dir)
                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results_val_davis')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None
            else:

                import lmdb
                from misc.config_kittimots import cfg

                masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-kitti')
                make_dir(masks_sep_dir)

                if args.overlay_masks:
                    results_dir = os.path.join('../models', args.model_name, 'results-kitti')
                    make_dir(results_dir)

                lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq')

                if osp.isdir(lmdb_env_seq_dir):
                    lmdb_env_seq = lmdb.open(lmdb_env_seq_dir)
                else:
                    lmdb_env_seq = None

            for batch_idx, (inputs, seq_name, starting_frame) in enumerate(self.loader):

                prev_hidden_temporal_list = None
                max_ii = min(len(inputs), args.length_clip)

                if args.overlay_masks:
                    base_dir = results_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir)

                if args.dataset == 'youtube':

                    seq_data = data['videos'][seq_name[0]]['objects']
                    frame_names = []
                    frame_names_with_new_objects = []
                    instance_ids = []

                    for obj_id in seq_data.keys():
                        instance_ids.append(int(obj_id))
                        frame_names_with_new_objects.append(seq_data[obj_id]['frames'][0])
                        for frame_name in seq_data[obj_id]['frames']:
                            if frame_name not in frame_names:
                                frame_names.append(frame_name)

                    frame_names.sort()
                    frame_names_with_new_objects_idxs = []
                    for kk in range(len(frame_names_with_new_objects)):
                        new_frame_idx = frame_names.index(frame_names_with_new_objects[kk])
                        frame_names_with_new_objects_idxs.append(new_frame_idx)

                elif args.dataset == 'davis2017':

                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    frame_names = sorted(_files)
                else:

                    key_db = osp.basename(seq_name[0])

                    if not lmdb_env_seq == None:
                        with lmdb_env_seq.begin() as txn:
                            _files_vec = txn.get(key_db.encode()).decode().split('|')
                            _files = [osp.splitext(f)[0] for f in _files_vec]
                    else:
                        seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db)
                        _files_vec = os.listdir(seq_dir)
                        _files = [osp.splitext(f)[0] for f in _files_vec]

                    # frame_names_with_new_objects_idxs = [3,6,9]
                    frame_names = sorted(_files)

                for ii in range(max_ii):

                    #                x: input images (N consecutive frames from M different sequences)
                    #                y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances)
                    #                sw_mask: this mask indicates which masks from y_mask are valid
                    x = batch_to_var_test(args, inputs[ii])

                    print(seq_name[0] + '/' + frame_names[ii])

                    if ii == 0:

                        frame_name = frame_names[0]
                        if args.dataset == 'youtube':
                            annotation = Image.open(
                                '../../databases/YouTubeVOS/val/Annotations/' + seq_name[0] + '/' + frame_name + '.png')
                            annot = imresize(annotation, (256, 448), interp='nearest')
                        elif args.dataset == 'davis2017':
                            annotation = Image.open(
                                '../../databases/DAVIS2017/Annotations/480p/' + seq_name[0] + '/' + frame_name + '.png')
                            instance_ids = sorted(np.unique(annotation))
                            instance_ids = instance_ids if instance_ids[0] else instance_ids[1:]
                            if len(instance_ids) > 0:
                                instance_ids = instance_ids[:-1] if instance_ids[-1] == 255 else instance_ids
                            annot = imresize(annotation, (240, 427), interp='nearest')
                        else:  # kittimots
                            annotation = Image.open(
                                '../../databases/KITTIMOTS/Annotations/' + seq_name[0] + '/' + frame_name + '.png')
                            instance_ids = sorted(np.unique(annotation))
                            instance_ids = instance_ids if instance_ids[0] else instance_ids[1:]
                            print("IDS instances: ", instance_ids)
                            if len(instance_ids) > 0:
                                instance_ids = instance_ids[:-1] if instance_ids[-1] == 255 else instance_ids
                            annot = imresize(annotation, (256, 448), interp='nearest')

                        annot = np.expand_dims(annot, axis=0)
                        annot = torch.from_numpy(annot)
                        annot = annot.float()
                        annot = annot.numpy().squeeze()
                        annot = annot_from_mask(annot, instance_ids)
                        prev_mask = annot
                        prev_mask = np.expand_dims(prev_mask, axis=0)
                        prev_mask = torch.from_numpy(prev_mask)
                        y_mask = Variable(prev_mask.float(), requires_grad=False)
                        prev_mask = y_mask.cuda()
                        del annot

                    if args.dataset == 'youtube':
                        if ii > 0 and ii in frame_names_with_new_objects_idxs:

                            frame_name = frame_names[ii]
                            annotation = Image.open(
                                '../../databases/YouTubeVOS/val/Annotations/' + seq_name[0] + '/' + frame_name + '.png')
                            annot = imresize(annotation, (256, 448), interp='nearest')
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = annot.float()
                            annot = annot.numpy().squeeze()
                            new_instance_ids = np.unique(annot)[1:]
                            annot = annot_from_mask(annot, new_instance_ids)
                            annot = np.expand_dims(annot, axis=0)
                            annot = torch.from_numpy(annot)
                            annot = Variable(annot.float(), requires_grad=False)
                            annot = annot.cuda()
                            for kk in new_instance_ids:
                                prev_mask[:, int(kk - 1), :] = annot[:, int(kk - 1), :]
                            del annot

                    # from one frame to the following frame the prev_hidden_temporal_list is updated.
                    outs, hidden_temporal_list = test_prev_mask(args, self.encoder, self.decoder, x,
                                                                prev_hidden_temporal_list, prev_mask)

                    base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/'
                    make_dir(base_dir_masks_sep)

                    x_tmp = x.data.cpu().numpy()
                    height = x_tmp.shape[-2]
                    width = x_tmp.shape[-1]

                    for t in range(len(instance_ids)):
                        mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy()
                        mask_pred = np.reshape(mask_pred, (height, width))
                        indxs_instance = np.where(mask_pred > 0.5)
                        mask2assess = np.zeros((height, width))
                        mask2assess[indxs_instance] = 255
                        toimage(mask2assess, cmin=0, cmax=255).save(
                            base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (t))

                    if args.overlay_masks:

                        frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze()
                        frame_img = np.transpose(frame_img, (1, 2, 0))
                        mean = np.array([0.485, 0.456, 0.406])
                        std = np.array([0.229, 0.224, 0.225])
                        frame_img = std * frame_img + mean
                        frame_img = np.clip(frame_img, 0, 1)
                        plt.figure();
                        plt.axis('off')
                        plt.figure();
                        plt.axis('off')
                        plt.imshow(frame_img)

                        for t in range(len(instance_ids)):

                            mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy()
                            mask_pred = np.reshape(mask_pred, (height, width))
                            ax = plt.gca()
                            tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3))
                            color_mask = np.array(colors[t]) / 255.0
                            for i in range(3):
                                tmp_img[:, :, i] = color_mask[i]
                            ax.imshow(np.dstack((tmp_img, mask_pred * 0.7)))

                        figname = base_dir + frame_names[ii] + '.png'
                        plt.savefig(figname, bbox_inches='tight')
                        plt.close()

                    if self.video_mode:
                        if args.only_spatial == False:
                            prev_hidden_temporal_list = hidden_temporal_list
                        if ii > 0:
                            prev_mask = outs
                        del x, hidden_temporal_list, outs
Esempio n. 8
0
def _restore_best_checkpoint(args, encoder, decoder, enc_opt, dec_opt):
    """Reload the saved checkpoint of ``args.model_name`` into the live models and optimizers.

    Checkpoints are only written when the validation loss improves, so this
    restarts the next training phase from the best weights saved so far
    instead of the current ones.
    """
    encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name, args.use_gpu)
    encoder.load_state_dict(encoder_dict)
    decoder.load_state_dict(decoder_dict)
    enc_opt.load_state_dict(enc_opt_dict)
    dec_opt.load_state_dict(dec_opt_dict)


def trainIters(args):
    """Train the FeatureExtractor encoder + RSIS decoder described by ``args``.

    Start modes:
      * ``args.resume``   -- continue training ``args.model_name`` from its saved
                             checkpoint; ``args`` is replaced by the options stored
                             with that checkpoint.
      * ``args.transfer`` -- initialise weights/optimizers from ``args.transfer_from``.
      * otherwise         -- build fresh models from ``args``.

    Each epoch makes one pass over the 'train' and 'val' loaders. A checkpoint
    is saved whenever the (optionally smoothed) validation loss improves by more
    than ``args.min_delta``. Training phases (class loss, stop loss, encoder
    fine-tuning, curriculum sequence length) are switched on either at a fixed
    epoch (``*_after`` options) or after ``args.patience`` epochs without
    improvement. Training stops early after ``args.patience_stop`` epochs without
    improvement.

    Args:
        args: options object; the attributes read are listed in the body.
    """
    epoch_resume = 0
    model_dir = os.path.join('../models/', args.model_name)

    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.model_name, args.use_gpu)

        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

        # continue with the options stored alongside the checkpoint
        args = load_args

    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)

    else:
        encoder = FeatureExtractor(args)
        decoder = RSIS(args)

    # model checkpoints will be saved here
    make_dir(model_dir)

    # save parameters for future use (file is closed as soon as it is written)
    with open(os.path.join(model_dir, 'args.pkl'), 'wb') as args_file:
        pickle.dump(args, args_file)

    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn)

    if args.resume or args.transfer:
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)

        # change fc layer for new classes
        # NOTE(review): dec_opt was built above from decoder.parameters(), so the
        # parameters of this new fc_class layer are not in dec_opt's param groups
        # and will not be updated by it -- confirm whether that is intended.
        if load_args.dataset != args.dataset and args.transfer:
            dim_in = decoder.fc_class.weight.size()[1]
            decoder.fc_class = nn.Linear(dim_in, args.num_classes)

    if not args.log_term:
        print("Training logs will be saved to: " + os.path.join(model_dir, 'train.log'))
        # Redirect all further output; these handles replace sys.stdout/stderr
        # and therefore must stay open.
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')

    print(args)

    # objective functions for mask and class outputs.
    # these return the average across samples in batch whose value
    # needs to be considered (those where sw is 1)
    # mask_xentropy = BalancedStableMaskedBCELoss()
    mask_siou = softIoULoss()

    class_xentropy = MaskedNLLLoss(balance_weight=None)
    stop_xentropy = MaskedBCELoss(balance_weight=args.stop_balance_weight)

    if args.ngpus > 1 and args.use_gpu:
        decoder = torch.nn.DataParallel(decoder, device_ids=range(args.ngpus))
        encoder = torch.nn.DataParallel(encoder, device_ids=range(args.ngpus))
        mask_siou = torch.nn.DataParallel(mask_siou, device_ids=range(args.ngpus))
        class_xentropy = torch.nn.DataParallel(class_xentropy, device_ids=range(args.ngpus))
        stop_xentropy = torch.nn.DataParallel(stop_xentropy, device_ids=range(args.ngpus))
    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        class_xentropy.cuda()
        mask_siou.cuda()
        stop_xentropy.cuda()

    crits = [mask_siou, class_xentropy, stop_xentropy]
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1  # smoothed val loss; -1 means "no value yet"

    # init windows to visualize, if visdom is enabled
    if args.visdom:
        import visdom
        viz = visdom.Visdom(port=args.port, server=args.server)
        lot, elot, mviz_pred, mviz_true, image_lot = init_visdom(args, viz)

    if args.curriculum_learning and epoch_resume == 0:
        args.limit_seqlen_to = 2

    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders, class_names = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}
    for e in range(args.max_epoch):
        print("Epoch %s" % (e + epoch_resume,))
        # store losses in lists to display average since beginning
        epoch_losses = {'train': {'total': [], 'iou': [], 'stop': [], 'class': []},
                        'val': {'total': [], 'iou': [], 'stop': [], 'class': []}}
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': [], 'stop': [], 'class': []}

        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.class_loss_after and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.stop_loss_after and not args.use_stop_loss and not args.stop_loss_after == -1:
            # with curriculum learning, the stop loss is only enabled once the
            # sequence length limit exceeds min_steps
            if not args.curriculum_learning or args.limit_seqlen_to > args.min_steps:
                print("Starting to learn stop loss")
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                acc_patience = 0
                mt_val = -1

        # we validate after each epoch
        for split in ['train', 'val']:
            batch_idx = -1  # so num_batches[split] is 0 (not stale/undefined) for an empty loader
            for batch_idx, (inputs, targets) in enumerate(loaders[split]):
                # send batch to GPU
                x, y_mask, y_class, sw_mask, sw_class = batch_to_var(args, inputs, targets)

                # we forward (and backward & update if training set)
                losses, outs, true_perm = runIter(args, encoder, decoder, x, y_mask,
                                                  y_class, sw_mask, sw_class,
                                                  crits, optims, mode=split)

                # store loss values in dictionary separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])
                epoch_losses[split]['stop'].append(losses[2])
                epoch_losses[split]['class'].append(losses[3])

                # print and display in visdom after some iterations
                if (batch_idx + 1) % args.print_every == 0:

                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])
                    mc = np.mean(epoch_losses[split]['class'])
                    mx = np.mean(epoch_losses[split]['stop'])
                    if args.visdom:

                        if split == 'train':
                            # we display batch loss values in visdom (Training only)
                            viz.line(
                                X=torch.ones((1, 4)).cpu() * (batch_idx + e * num_batches[split]),
                                Y=torch.Tensor([mi, mx, mc, mt]).unsqueeze(0).cpu(),
                                win=lot,
                                update='append')
                        w = x.size()[-1]
                        h = x.size()[-2]
                        out_masks, out_classes, y_mask, y_class = outs_perms_to_cpu(args, outs, true_perm, h, w)

                        x = x.data.cpu().numpy()
                        # the input image does not depend on the step: send it once
                        viz.image((x[0] * 0.2 + 0.5) * 256, win=image_lot,
                                  opts=dict(title='image (unnnormalized)'))
                        # send sample predictions and ground truths to visdom
                        for t in range(np.shape(out_masks)[1]):
                            mask_pred = out_masks[0, t]
                            mask_true = y_mask[0, t]
                            class_pred = class_names[out_classes[0, t]]
                            class_true = class_names[y_class[0, t]]
                            mask_pred = np.reshape(mask_pred, (x.shape[-2], x.shape[-1]))
                            mask_true = np.reshape(mask_true, (x.shape[-2], x.shape[-1]))

                            # heatmap displays the mask upside down
                            viz.heatmap(np.flipud(mask_pred), win=mviz_pred[t],
                                        opts=dict(title='pred mask %d %s' % (t, class_pred)))
                            viz.heatmap(np.flipud(mask_true), win=mviz_true[t],
                                        opts=dict(title='true mask %d %s' % (t, class_true)))

                    te = time.time() - start
                    print("iter %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\ttime:%.4f" % (batch_idx, mt, mc, mi, mx, te))
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()

            num_batches[split] = batch_idx + 1

            # compute mean losses within epoch (val total optionally smoothed with an EMA)
            if split == 'val' and args.smooth_curves:
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(epoch_losses[split]['total'])
                mt_val = mt

            else:
                mt = np.mean(epoch_losses[split]['total'])

            mi = np.mean(epoch_losses[split]['iou'])
            mc = np.mean(epoch_losses[split]['class'])
            mx = np.mean(epoch_losses[split]['stop'])

            # save train and val losses for the epoch to display in visdom
            total_losses['iou'].append(mi)
            total_losses['class'].append(mc)
            total_losses['stop'].append(mx)
            total_losses['total'].append(mt)

            args.epoch_resume = e + epoch_resume

            print("Epoch %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\t(%s)" % (e, mt, mc, mi, mx, split))

        # epoch losses
        if args.visdom:
            update = True if e == 0 else 'append'
            for l in ['total', 'iou', 'stop', 'class']:
                viz.line(X=torch.ones((1, 2)).cpu() * (e + 1),
                         Y=torch.Tensor(total_losses[l]).unsqueeze(0).cpu(),
                         win=elot[l],
                         update=update)

        # mt here is the (possibly smoothed) loss of the last split, i.e. 'val'
        if mt < (best_val_loss - args.min_delta):
            print("Saving checkpoint.")
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1

        # Patience-triggered phase switches. Each one that fires resets
        # acc_patience, so (assuming args.patience >= 0) later checks are skipped.
        if acc_patience > args.patience and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            acc_patience = 0
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            _restore_best_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
        if acc_patience > args.patience and args.curriculum_learning and args.limit_seqlen_to < args.maxseqlen:
            print("Adding one step more:")
            acc_patience = 0
            args.limit_seqlen_to += args.steps_cl
            print(args.limit_seqlen_to)
            best_val_loss = 1000
            mt_val = -1

        if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            _restore_best_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
        if acc_patience > args.patience and not args.use_stop_loss and not args.stop_loss_after == -1:
            # Same gating as the scheduled switch: with curriculum learning the stop
            # loss is only enabled once limit_seqlen_to exceeds min_steps. Log, reset
            # and reload only when the switch actually happens.
            if not args.curriculum_learning or args.limit_seqlen_to > args.min_steps:
                print("Starting to learn stop loss")
                acc_patience = 0
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                mt_val = -1
                _restore_best_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break