class Model:
    def __init__(self, args):
        # Define encoder and decoder
        self.encoder = None
        self.decoder = None
        # Mode
        self.video_mode = False
        # Load model
        self._init_model(args)

    def _init_model(self, args):
        print("Loading model: " + args.model_name)
        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.model_name, args.use_gpu)
        load_args.use_gpu = args.use_gpu
        self.encoder = FeatureExtractor(load_args)
        if args.zero_shot:
            self.decoder = RSIS(load_args)
        else:
            self.decoder = RSISMask(load_args)
        print(load_args)

        if args.ngpus > 1 and args.use_gpu:
            self.decoder = torch.nn.DataParallel(self.decoder,
                                                 device_ids=range(args.ngpus))
            self.encoder = torch.nn.DataParallel(self.encoder,
                                                 device_ids=range(args.ngpus))

        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        self.encoder.load_state_dict(encoder_dict)
        to_be_deleted_dec = []
        for k in decoder_dict.keys():
            if 'fc_stop' in k:
                to_be_deleted_dec.append(k)
        for k in to_be_deleted_dec:
            del decoder_dict[k]
        self.decoder.load_state_dict(decoder_dict)

        if args.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()

        if load_args.length_clip == 1:
            self.video_mode = False
            print('video mode not activated')
        else:
            self.video_mode = True
            print('video mode activated')
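# --- Illustrative sketch (not part of the original repo code) ---------------
# `check_parallel` above is defined elsewhere in the repository. Judging by the
# inline "module." handling in one of the Evaluate classes further below, a
# minimal version plausibly strips the prefix that torch.nn.DataParallel adds
# to checkpoint keys. The helper name `check_parallel_sketch` is hypothetical.
from collections import OrderedDict


def check_parallel_sketch(encoder_dict, decoder_dict):
    """Strip the 'module.' prefix from state dicts saved with DataParallel."""
    def strip(state_dict):
        if not any(k.startswith('module.') for k in state_dict):
            return state_dict
        stripped = OrderedDict()
        for k, v in state_dict.items():
            stripped[k[len('module.'):]] = v
        return stripped
    return strip(encoder_dict), strip(decoder_dict)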
def trainIters(args): epoch_resume = 0 model_dir = os.path.join('../models/', args.model_name + '_prev_inference_mask') if args.resume: # will resume training the model with name args.model_name encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint( args.model_name, args.use_gpu) epoch_resume = load_args.epoch_resume encoder = FeatureExtractor(load_args) decoder = RSISMask(load_args) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) encoder.load_state_dict(encoder_dict) decoder.load_state_dict(decoder_dict) args = load_args elif args.transfer: # load model from args and replace last fc layer encoder_dict, decoder_dict, _, _, load_args = load_checkpoint( args.transfer_from, args.use_gpu) encoder = FeatureExtractor(load_args) decoder = RSISMask(args) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) encoder.load_state_dict(encoder_dict) decoder.load_state_dict(decoder_dict) else: encoder = FeatureExtractor(args) decoder = RSISMask(args) # model checkpoints will be saved here make_dir(model_dir) # save parameters for future use pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb')) encoder_params = get_base_params(args, encoder) skip_params = get_skip_params(encoder) decoder_params = list(decoder.parameters()) + list(skip_params) dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay) enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn) if args.resume: enc_opt.load_state_dict(enc_opt_dict) dec_opt.load_state_dict(dec_opt_dict) from collections import defaultdict dec_opt.state = defaultdict(dict, dec_opt.state) if not args.log_term: print("Training logs will be saved to:", os.path.join(model_dir, 'train.log')) sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w') sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w') print(args) # objective function for mask mask_siou = softIoULoss() if args.use_gpu: encoder.cuda() decoder.cuda() mask_siou.cuda() crits = mask_siou optims = [enc_opt, dec_opt] if args.use_gpu: torch.cuda.synchronize() start = time.time() # vars for early stopping best_val_loss = args.best_val_loss acc_patience = 0 mt_val = -1 # keep track of the number of batches in each epoch for continuity when plotting curves loaders = init_dataloaders(args) num_batches = {'train': 0, 'val': 0} #area_range = [[0 ** 2, 1e5 ** 2], [0 ** 2, 20 ** 2], [20 ** 2, 59 ** 2], [59 ** 2, 1e5 ** 2]] area_range = [[0**2, 1e5**2], [0**2, 30**2], [30**2, 90**2], [90**2, 1e5**2]] #for (287,950)) resolution = 0 for e in range(args.max_epoch): print("Epoch", e + epoch_resume) # store losses in lists to display average since beginning epoch_losses = { 'train': { 'total': [], 'iou': [] }, 'val': { 'total': [], 'iou': [] } } # total mean for epoch will be saved here to display at the end total_losses = {'total': [], 'iou': []} # check if it's time to do some changes here if e + epoch_resume >= args.finetune_after and not args.update_encoder and not args.finetune_after == -1: print("Starting to update encoder") args.update_encoder = True acc_patience = 0 mt_val = -1 if args.loss_penalization: if e < 10: resolution = area_range[2] else: resolution = area_range[0] # we validate after each epoch for split in ['train', 'val']: if args.dataset == 'davis2017' or args.dataset == 'youtube' or args.dataset == 'kittimots': loaders[split].dataset.set_epoch(e) for batch_idx, (inputs, targets, seq_name, starting_frame) in enumerate(loaders[split]): # send batch to GPU 
prev_hidden_temporal_list = None loss = None last_frame = False max_ii = min(len(inputs), args.length_clip) for ii in range(max_ii): # If are on the last frame from a clip, we will have to backpropagate the loss back to the beginning of the clip. if ii == max_ii - 1: last_frame = True # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x, y_mask, sw_mask = batch_to_var( args, inputs[ii], targets[ii]) if ii == 0: prev_mask = y_mask # From one frame to the following frame the prev_hidden_temporal_list is updated. loss, losses, outs, hidden_temporal_list = runIter( args, encoder, decoder, x, y_mask, sw_mask, resolution, crits, optims, split, loss, prev_hidden_temporal_list, prev_mask, last_frame) # Hidden temporal state from time instant ii is saved to be used when processing next time instant ii+1 if args.only_spatial == False: prev_hidden_temporal_list = hidden_temporal_list prev_mask = outs # store loss values in dictionary separately epoch_losses[split]['total'].append(losses[0]) epoch_losses[split]['iou'].append(losses[1]) # print after some iterations if (batch_idx + 1) % args.print_every == 0: mt = np.mean(epoch_losses[split]['total']) mi = np.mean(epoch_losses[split]['iou']) te = time.time() - start print("iter %d:\ttotal:%.4f\tiou:%.4f\ttime:%.4f" % (batch_idx, mt, mi, te)) if args.use_gpu: torch.cuda.synchronize() start = time.time() num_batches[split] = batch_idx + 1 # compute mean val losses within epoch if split == 'val' and args.smooth_curves: if mt_val == -1: mt = np.mean(epoch_losses[split]['total']) else: mt = 0.9 * mt_val + 0.1 * np.mean( epoch_losses[split]['total']) mt_val = mt else: mt = np.mean(epoch_losses[split]['total']) mi = np.mean(epoch_losses[split]['iou']) # save train and val losses for the epoch total_losses['iou'].append(mi) total_losses['total'].append(mt) args.epoch_resume = e + epoch_resume print("Epoch %d:\ttotal:%.4f\tiou:%.4f\t(%s)" % (e, mt, mi, split)) if mt < (best_val_loss - args.min_delta): print("Saving checkpoint.") best_val_loss = mt args.best_val_loss = best_val_loss # saves model, params, and optimizers save_checkpoint_prev_inference_mask(args, encoder, decoder, enc_opt, dec_opt) acc_patience = 0 else: acc_patience += 1 if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1: print("Starting to update encoder") acc_patience = 0 args.update_encoder = True best_val_loss = 1000 # reset because adding a loss term will increase the total value mt_val = -1 encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint( args.model_name, args.use_gpu) encoder.load_state_dict(encoder_dict) decoder.load_state_dict(decoder_dict) enc_opt.load_state_dict(enc_opt_dict) dec_opt.load_state_dict(dec_opt_dict) # early stopping after N epochs without improvement if acc_patience > args.patience_stop: break
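# --- Illustrative sketch (not part of the original repo code) ---------------
# `softIoULoss`, used above as the mask objective, is defined elsewhere in the
# repository. A standard soft-IoU formulation looks like the following; the
# repo's exact reduction and handling of the sw_mask weights may differ.
import torch
import torch.nn as nn


class SoftIoULossSketch(nn.Module):
    def forward(self, pred, target, sw=None, eps=1e-6):
        # pred, target: (batch, instances, pixels), values in [0, 1]
        intersection = (pred * target).sum(dim=-1)
        union = pred.sum(dim=-1) + target.sum(dim=-1) - intersection
        loss = 1.0 - (intersection + eps) / (union + eps)
        if sw is not None:
            # zero out padded instance slots, mirroring how sw_mask is used
            loss = loss * sw
        return loss.mean()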
class Evaluate(): def __init__(self, args): self.split = args.eval_split self.dataset = args.dataset to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transforms = transforms.Compose([to_tensor, normalize]) if args.dataset == 'youtube': dataset = get_dataset(args, split=self.split, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', inputRes=(256, 448), video_mode=True, use_prev_mask=False) else: dataset = get_dataset(args, split=self.split, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', video_mode=True, use_prev_mask=False) self.loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) self.args = args print(args.model_name) encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint( args.model_name, args.use_gpu) load_args.use_gpu = args.use_gpu self.encoder = FeatureExtractor(load_args) self.decoder = RSIS(load_args) print(load_args) if args.ngpus > 1 and args.use_gpu: self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus)) self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus)) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) self.encoder.load_state_dict(encoder_dict) to_be_deleted_dec = [] for k in decoder_dict.keys(): if 'fc_stop' in k: to_be_deleted_dec.append(k) for k in to_be_deleted_dec: del decoder_dict[k] self.decoder.load_state_dict(decoder_dict) if args.use_gpu: self.encoder.cuda() self.decoder.cuda() self.encoder.eval() self.decoder.eval() if load_args.length_clip == 1: self.video_mode = False print('video mode not activated') else: self.video_mode = True print('video mode activated') def run_eval(self): print("Dataset is %s" % (self.dataset)) print("Split is %s" % (self.split)) if args.overlay_masks: colors = [] palette = sequence_palette() inv_palette = {} for k, v in palette.items(): inv_palette[v] = k num_colors = len(inv_palette.keys()) for id_color in range(num_colors): if id_color == 0 or id_color == 21: continue c = inv_palette[id_color] colors.append(c) if self.split == 'val': if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results') make_dir(results_dir) json_data = open( '../../databases/YouTubeVOS/train/train-val-meta.json') data = json.load(json_data) else: masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results-davis') make_dir(results_dir) for batch_idx, (inputs, targets, seq_name, starting_frame) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) for ii in range(max_ii): # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x, y_mask, sw_mask = batch_to_var(args, inputs[ii], targets[ii]) 
print(seq_name[0] + '/' + '%05d' % (starting_frame[0] + ii)) #from one frame to the following frame the prev_hidden_temporal_list is updated. outs, hidden_temporal_list = test( args, self.encoder, self.decoder, x, prev_hidden_temporal_list) if args.dataset == 'youtube': num_instances = len( data['videos'][seq_name[0]]['objects']) else: num_instances = 1 #int(torch.sum(sw_mask.data).data.cpu().numpy()) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] for t in range(10): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 toimage(mask2assess, cmin=0, cmax=255).save(base_dir_masks_sep + '%05d_instance_%02d.png' % (starting_frame[0] + ii, t)) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure() plt.axis('off') plt.figure() plt.axis('off') plt.imshow(frame_img) for t in range(num_instances): mask_pred = (torch.squeeze( outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones( (mask_pred.shape[0], mask_pred.shape[1], 3)) color_mask = np.array(colors[t]) / 255.0 for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) figname = base_dir + 'frame_%02d.png' % ( starting_frame[0] + ii) plt.savefig(figname, bbox_inches='tight') plt.close() if self.video_mode: prev_hidden_temporal_list = hidden_temporal_list else: if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val') make_dir(results_dir) json_data = open('../../databases/YouTubeVOS/val/meta.json') data = json.load(json_data) else: masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val_davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val_davis') make_dir(results_dir) for batch_idx, (inputs, seq_name, starting_frame) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) for ii in range(max_ii): # x: input images (N consecutive frames from M different sequences) x = batch_to_var_test(args, inputs[ii]) print(seq_name[0] + '/' + '%05d' % (starting_frame[0] + ii)) if ii == 0: if args.dataset == 'youtube': num_instances = len( data['videos'][seq_name[0]]['objects']) else: annotation = Image.open( '../../databases/DAVIS2017/Annotations/480p/' + seq_name[0] + '/00000.png') instance_ids = sorted(np.unique(annotation)) instance_ids = instance_ids if instance_ids[ 0] else instance_ids[1:] if len(instance_ids) > 0: instance_ids = instance_ids[:-1] if instance_ids[ -1] == 255 else instance_ids num_instances = len(instance_ids) #from one frame to the following frame the prev_hidden_temporal_list is updated. 
outs, hidden_temporal_list = test( args, self.encoder, self.decoder, x, prev_hidden_temporal_list) base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] for t in range(10): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 toimage(mask2assess, cmin=0, cmax=255).save(base_dir_masks_sep + '%05d_instance_%02d.png' % (starting_frame[0] + ii, t)) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure() plt.axis('off') plt.figure() plt.axis('off') plt.imshow(frame_img) for t in range(num_instances): mask_pred = (torch.squeeze( outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones( (mask_pred.shape[0], mask_pred.shape[1], 3)) color_mask = np.array(colors[t]) / 255.0 for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) figname = base_dir + 'frame_%02d.png' % ( starting_frame[0] + ii) plt.savefig(figname, bbox_inches='tight') plt.close() if self.video_mode: prev_hidden_temporal_list = hidden_temporal_list
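# --- Illustrative sketch (not part of the original repo code) ---------------
# The overlay code above undoes the ImageNet normalization and binarizes the
# soft masks inline; the same two steps as small helpers, for reference only.
import numpy as np


def denormalize_frame(x_np, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo ImageNet normalization for one frame of shape (3, H, W)."""
    img = np.transpose(x_np, (1, 2, 0))
    img = np.array(std) * img + np.array(mean)
    return np.clip(img, 0, 1)


def binarize_mask(mask_pred, height, width, threshold=0.5):
    """Reshape a flat soft mask and threshold it to a 0/255 uint8 image."""
    mask = np.reshape(mask_pred, (height, width))
    return np.where(mask > threshold, 255, 0).astype(np.uint8)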
class Evaluate(): def __init__(self, args): self.split = args.eval_split self.display = args.display self.dataset = args.dataset self.all_classes = args.all_classes self.T = args.maxseqlen self.batch_size = args.batch_size to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transforms = transforms.Compose([to_tensor, normalize]) dataset = get_dataset(args, self.split, image_transforms, augment=False) self.loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) self.sample_list = dataset.get_sample_list() self.args = args encoder_dict, decoder_dict, _, _, load_args = load_checkpoint( args.model_name) self.args.use_feedback = load_args.use_feedback self.args.base_model = load_args.base_model self.hidden_size = load_args.hidden_size self.args.nconvlstm = load_args.nconvlstm self.encoder = FeatureExtractor(load_args) self.decoder = RSIS(load_args) if args.ngpus > 1 and args.use_gpu: self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus)) self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus)) # check if the model was trained using multiple gpus trained_parallel = False for k, v in encoder_dict.items(): if k[:7] == "module.": trained_parallel = True break if trained_parallel and not args.ngpus > 1: # create new OrderedDict that does not contain "module." new_encoder_state_dict = OrderedDict() new_decoder_state_dict = OrderedDict() for k, v in encoder_dict.items(): name = k[7:] # remove "module." new_encoder_state_dict[name] = v for k, v in decoder_dict.items(): name = k[7:] # remove "module." new_decoder_state_dict[name] = v encoder_dict = new_encoder_state_dict decoder_dict = new_decoder_state_dict self.encoder.load_state_dict(encoder_dict) self.decoder.load_state_dict(decoder_dict) if args.use_gpu: self.encoder.cuda() self.decoder.cuda() self.encoder.eval() self.decoder.eval() def create_figures(self): acc_samples = 0 results_dir = os.path.join('../models', args.model_name, args.model_name + '_results') make_dir(results_dir) masks_dir = os.path.join(args.model_name + '_masks') abs_masks_dir = os.path.join(results_dir, masks_dir) make_dir(abs_masks_dir) print "Creating annotations for cityscapes validation..." 
for batch_idx, (inputs, targets) in enumerate(self.loader): x, y_mask, y_class, sw_mask, sw_class = batch_to_var( self.args, inputs, targets) out_masks, out_scores, stop_probs = test(self.args, self.encoder, self.decoder, x) class_ids = [24, 25, 26, 27, 28, 31, 32, 33] for sample in range(self.batch_size): sample_idx = self.sample_list[sample + acc_samples] image_dir = os.path.join(sample_idx.split('.')[0] + '.png') im = scipy.misc.imread(image_dir) h = im.shape[0] w = im.shape[1] sample_idx = sample_idx.split('/')[-1].split('.')[0] results_file = open( os.path.join(results_dir, sample_idx + '.txt'), 'w') img_masks = out_masks[sample] instance_id = 0 class_scores = out_scores[sample] stop_scores = stop_probs[sample] for time_step in range(self.T): mask = img_masks[time_step].cpu().numpy() mask = (mask > args.mask_th) h_mask = mask.shape[0] w_mask = mask.shape[1] mask = (mask > 0) labeled_blobs = measure.label(mask, background=0).flatten() # find the biggest one count = Counter(labeled_blobs) s = [] max_num = 0 for v, k in count.iteritems(): if v == 0: continue if k > max_num: max_num = k max_label = v # build mask from the largest connected component segmentation = (labeled_blobs == max_label).astype("uint8") mask = segmentation.reshape([h_mask, w_mask]) * 255 mask = scipy.misc.imresize(mask, [h, w]) class_scores_mask = class_scores[time_step].cpu().numpy() stop_scores_mask = stop_scores[time_step].cpu().numpy() class_score = np.argmax(class_scores_mask) for i in range(len(class_scores_mask) - 1): name_instance = sample_idx + '_' + str( instance_id) + '.png' pred_class_score = class_scores_mask[i + 1] objectness = stop_scores_mask[0] pred_class_score *= objectness scipy.misc.imsave( os.path.join(abs_masks_dir, name_instance), mask) results_file.write(masks_dir + '/' + name_instance + ' ' + str(class_ids[i]) + ' ' + str(pred_class_score) + '\n') instance_id += 1 results_file.close() acc_samples += self.batch_size
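# --- Illustrative sketch (not part of the original repo code) ---------------
# The Counter-based loop above keeps the largest connected component of each
# thresholded mask (and its .iteritems() call is Python 2 only). An equivalent,
# Python 3 friendly variant using np.bincount:
import numpy as np
from skimage import measure


def largest_component(binary_mask):
    """Return a 0/255 uint8 mask containing only the largest foreground blob."""
    labeled = measure.label(binary_mask, background=0)
    counts = np.bincount(labeled.flatten())
    if len(counts) <= 1:
        return np.zeros_like(binary_mask, dtype="uint8")
    counts[0] = 0  # ignore the background label
    return ((labeled == counts.argmax()).astype("uint8")) * 255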
class Evaluate(): def __init__(self, args): self.split = args.eval_split self.dataset = args.dataset to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transforms = transforms.Compose([to_tensor, normalize]) if args.dataset == 'davis2017': dataset = get_dataset(args, split=self.split, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', inputRes=(240, 427), video_mode=True, use_prev_mask=True) else: #args.dataset == 'youtube' dataset = get_dataset(args, split=self.split, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', inputRes=(256, 448), video_mode=True, use_prev_mask=True) self.loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) self.args = args print(args.model_name) encoder_dict, decoder_dict, _, _, load_args = load_checkpoint( args.model_name, args.use_gpu) load_args.use_gpu = args.use_gpu self.encoder = FeatureExtractor(load_args) self.decoder = RSISMask(load_args) print(load_args) if args.ngpus > 1 and args.use_gpu: self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus)) self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus)) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) self.encoder.load_state_dict(encoder_dict) to_be_deleted_dec = [] for k in decoder_dict.keys(): if 'fc_stop' in k: to_be_deleted_dec.append(k) for k in to_be_deleted_dec: del decoder_dict[k] self.decoder.load_state_dict(decoder_dict) if args.use_gpu: self.encoder.cuda() self.decoder.cuda() self.encoder.eval() self.decoder.eval() if load_args.length_clip == 1: self.video_mode = False print('video mode not activated') else: self.video_mode = True print('video mode activated') def run_eval(self): print("Dataset is %s" % (self.dataset)) print("Split is %s" % (self.split)) if args.overlay_masks: colors = [] palette = sequence_palette() inv_palette = {} for k, v in palette.items(): inv_palette[v] = k num_colors = len(inv_palette.keys()) for id_color in range(num_colors): if id_color == 0 or id_color == 21: continue c = inv_palette[id_color] colors.append(c) if self.split == 'val': if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results') make_dir(results_dir) json_data = open( '../../databases/YouTubeVOS/train/train-val-meta.json') data = json.load(json_data) else: #args.dataset == 'davis2017' import lmdb from misc.config import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results-davis') make_dir(results_dir) lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None for batch_idx, (inputs, targets, seq_name, starting_frame) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) if args.dataset == 'davis2017': key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get( 
key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] frame_names = sorted(_files) for ii in range(max_ii): #start_time = time.time() # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x, y_mask, sw_mask = batch_to_var(args, inputs[ii], targets[ii]) if ii == 0: prev_mask = y_mask #from one frame to the following frame the prev_hidden_temporal_list is updated. outs, hidden_temporal_list = test_prev_mask( args, self.encoder, self.decoder, x, prev_hidden_temporal_list, prev_mask) #end_inference_time = time.time() #print("inference time: %.3f" %(end_inference_time-start_time)) if args.dataset == 'youtube': num_instances = len( data['videos'][seq_name[0]]['objects']) else: num_instances = int( torch.sum(sw_mask.data).data.cpu().numpy()) base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] for t in range(num_instances): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 if args.dataset == 'youtube': toimage(mask2assess, cmin=0, cmax=255).save(base_dir_masks_sep + '%05d_instance_%02d.png' % (starting_frame[0] + ii, t)) else: toimage(mask2assess, cmin=0, cmax=255).save(base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (t)) #end_saving_masks_time = time.time() #print("inference + saving masks time: %.3f" %(end_saving_masks_time - start_time)) if args.dataset == 'youtube': print(seq_name[0] + '/' + '%05d' % (starting_frame[0] + ii)) else: print(seq_name[0] + '/' + frame_names[ii]) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure() plt.axis('off') plt.figure() plt.axis('off') plt.imshow(frame_img) for t in range(num_instances): mask_pred = (torch.squeeze( outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones( (mask_pred.shape[0], mask_pred.shape[1], 3)) color_mask = np.array(colors[t]) / 255.0 for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) if args.dataset == 'youtube': figname = base_dir + 'frame_%02d.png' % ( starting_frame[0] + ii) else: figname = base_dir + frame_names[ii] + '.png' plt.savefig(figname, bbox_inches='tight') plt.close() if self.video_mode: if args.only_spatial == False: prev_hidden_temporal_list = hidden_temporal_list if ii > 0: prev_mask = outs else: prev_mask = y_mask del outs, hidden_temporal_list, x, y_mask, sw_mask else: if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val') make_dir(results_dir) json_data = open('../../databases/YouTubeVOS/valid/meta.json') data = json.load(json_data) else: 
#args.dataset == 'davis2017' import lmdb from misc.config import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val_davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val_davis') make_dir(results_dir) lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None for batch_idx, (inputs, seq_name, starting_frame) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) if args.dataset == 'youtube': seq_data = data['videos'][seq_name[0]]['objects'] frame_names = [] frame_names_with_new_objects = [] instance_ids = [] for obj_id in seq_data.keys(): instance_ids.append(int(obj_id)) frame_names_with_new_objects.append( seq_data[obj_id]['frames'][0]) for frame_name in seq_data[obj_id]['frames']: if frame_name not in frame_names: frame_names.append(frame_name) frame_names.sort() frame_names_with_new_objects_idxs = [] for kk in range(len(frame_names_with_new_objects)): new_frame_idx = frame_names.index( frame_names_with_new_objects[kk]) frame_names_with_new_objects_idxs.append(new_frame_idx) else: #davis2017 key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get( key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] frame_names = sorted(_files) for ii in range(max_ii): # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x = batch_to_var_test(args, inputs[ii]) print(seq_name[0] + '/' + frame_names[ii]) if ii == 0: frame_name = frame_names[0] if args.dataset == 'youtube': annotation = Image.open( '../../databases/YouTubeVOS/valid/Annotations/' + seq_name[0] + '/' + frame_name + '.png') annot = imresize(annotation, (256, 448), interp='nearest') else: #davis2017 annotation = Image.open( '../../databases/DAVIS2017/Annotations/480p/' + seq_name[0] + '/' + frame_name + '.png') instance_ids = sorted(np.unique(annotation)) instance_ids = instance_ids if instance_ids[ 0] else instance_ids[1:] if len(instance_ids) > 0: instance_ids = instance_ids[:-1] if instance_ids[ -1] == 255 else instance_ids annot = imresize(annotation, (240, 427), interp='nearest') annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = annot.float() annot = annot.numpy().squeeze() annot = annot_from_mask(annot, instance_ids) prev_mask = annot prev_mask = np.expand_dims(prev_mask, axis=0) prev_mask = torch.from_numpy(prev_mask) y_mask = Variable(prev_mask.float(), requires_grad=False) prev_mask = y_mask.cuda() del annot if args.dataset == 'youtube': if ii > 0 and ii in frame_names_with_new_objects_idxs: frame_name = frame_names[ii] annotation = Image.open( '../../databases/YouTubeVOS/valid/Annotations/' + seq_name[0] + '/' + frame_name + '.png') annot = imresize(annotation, (256, 448), interp='nearest') annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = annot.float() annot = annot.numpy().squeeze() new_instance_ids = 
np.unique(annot)[1:] annot = annot_from_mask(annot, new_instance_ids) annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = Variable(annot.float(), requires_grad=False) annot = annot.cuda() for kk in new_instance_ids: prev_mask[:, int(kk - 1), :] = annot[:, int(kk - 1), :] del annot #from one frame to the following frame the prev_hidden_temporal_list is updated. outs, hidden_temporal_list = test_prev_mask( args, self.encoder, self.decoder, x, prev_hidden_temporal_list, prev_mask) base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] for t in range(len(instance_ids)): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 toimage(mask2assess, cmin=0, cmax=255).save(base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (t)) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure() plt.axis('off') plt.figure() plt.axis('off') plt.imshow(frame_img) for t in range(len(instance_ids)): mask_pred = (torch.squeeze( outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones( (mask_pred.shape[0], mask_pred.shape[1], 3)) color_mask = np.array(colors[t]) / 255.0 for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) figname = base_dir + frame_names[ii] + '.png' plt.savefig(figname, bbox_inches='tight') plt.close() if self.video_mode: if args.only_spatial == False: prev_hidden_temporal_list = hidden_temporal_list if ii > 0: prev_mask = outs del x, hidden_temporal_list, outs
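# --- Illustrative sketch (not part of the original repo code) ---------------
# `annot_from_mask` is defined elsewhere in the repository. Its output is later
# indexed as prev_mask[:, instance_id - 1, :], so it plausibly converts an
# (H, W) label image into one flattened binary mask per instance slot. The
# output shape and the max_instances default below are assumptions.
import numpy as np


def annot_from_mask_sketch(annot, instance_ids, max_instances=10):
    """Convert a label image into a (max_instances, H*W) stack of binary masks."""
    h, w = annot.shape
    masks = np.zeros((max_instances, h * w), dtype=np.float32)
    for obj_id in instance_ids:
        masks[int(obj_id) - 1] = (annot == obj_id).astype(np.float32).flatten()
    return masks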
class Evaluate(): def __init__(self, args): self.split = args.eval_split self.display = args.display self.no_display_text = args.no_display_text self.dataset = args.dataset self.all_classes = args.all_classes self.use_cats = args.use_cats to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transforms = transforms.Compose([to_tensor, normalize]) dataset = get_dataset(args, self.split, image_transforms, augment=False, imsize=args.imsize) self.sample_list = dataset.get_sample_list() self.class_names = dataset.get_classes() if args.dataset == 'pascal': self.gt_file = pickle.load( open( os.path.join(args.pascal_dir, 'VOCGT_%s.pkl' % (self.split)), 'rb')) self.key_to_anns = dict() self.ignoremasks = {} for ann in self.gt_file: if ann['ignore'] == 1: if type(ann['segmentation']['counts']) == list: im_height = ann['segmentation']['size'][0] im_width = ann['segmentation']['size'][1] rle = mask.frPyObjects([ann['segmentation']], im_height, im_width) else: rle = [ann['segmentation']] m = mask.decode(rle) self.ignoremasks[ann['image_id']] = m if ann['image_id'] in self.key_to_anns.keys(): self.key_to_anns[ann['image_id']].append(ann) else: self.key_to_anns[ann['image_id']] = [ann] self.coco = create_coco_object(args, self.sample_list, self.class_names) self.loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) self.args = args self.colors = [] palette = sequence_palette() inv_palette = {} for k, v in palette.iteritems(): inv_palette[v] = k num_colors = len(inv_palette.keys()) for i in range(num_colors): if i == 0 or i == 21: continue c = inv_palette[i] self.colors.append(c) encoder_dict, decoder_dict, _, _, load_args = load_checkpoint( args.model_name, args.use_gpu) load_args.use_gpu = args.use_gpu self.encoder = FeatureExtractor(load_args) self.decoder = RSIS(load_args) print(load_args) if args.ngpus > 1 and args.use_gpu: self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus)) self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus)) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) self.encoder.load_state_dict(encoder_dict) self.decoder.load_state_dict(decoder_dict) if args.use_gpu: self.encoder.cuda() self.decoder.cuda() self.encoder.eval() self.decoder.eval() def _create_json(self): predictions = list() acc_samples = 0 print "Creating annotations..." 
for batch_idx, (inputs, targets) in enumerate(self.loader): x, y_mask, y_class, sw_mask, sw_class = batch_to_var( self.args, inputs, targets) num_objects = np.sum(sw_mask.data.float().cpu().numpy(), axis=-1) out_masks, out_scores, stop_probs = test(self.args, self.encoder, self.decoder, x) out_scores = out_scores.cpu().numpy() stop_scores = stop_probs.cpu().numpy() out_masks = out_masks.cpu().numpy() out_classes = np.argmax(out_scores, axis=-1) w = x.size()[-1] h = x.size()[-2] #out_masks, out_classes, y_mask, y_class = outs_perms_to_cpu(self.args,outs,true_perms,h,w) for s in range(out_masks.shape[0]): this_pred = list() sample_idx = self.sample_list[s + acc_samples] if self.args.dataset == 'pascal': ignore_mask = self.ignoremasks[sample_idx] else: ignore_mask = None if self.dataset == 'pascal': image_dir = os.path.join(args.pascal_dir, 'JPEGImages', sample_idx + '.jpg') elif self.dataset == 'cityscapes': sample_idx = sample_idx.split('.')[0] image_dir = sample_idx + '.png' elif self.dataset == 'leaves': image_dir = sample_idx im = imread(image_dir) h = im.shape[0] w = im.shape[1] objectness_scores = [] class_scores = [] reached_end = False for i in range(out_masks.shape[1]): if reached_end: break objectness = stop_scores[s][i][0] if objectness < args.stop_th: continue pred_mask = out_masks[s][i] # store class with max confidence for display if args.class_th == 0.0: max_class = 1 else: max_class = out_classes[s][i] # process mask to create annotation pred_mask, is_valid, raw_pred_mask = resize_mask( args, pred_mask, h, w, ignore_mask) # for evaluation we repeat the mask with all its class probs for cls_id in range(len(self.class_names)): if cls_id == 0: # ignore eos continue pred_class_score = out_scores[s][i][cls_id] pred_class_score_mod = pred_class_score * objectness ann = create_annotation(self.args, sample_idx, pred_mask, cls_id, pred_class_score_mod, self.class_names, is_valid) if ann is not None: if self.dataset == 'leaves': if objectness > args.stop_th: this_pred.append(ann) else: # for display we only take the mask with max confidence if cls_id == max_class and pred_class_score_mod >= self.args.class_th: ann_save = create_annotation( self.args, sample_idx, raw_pred_mask, cls_id, pred_class_score_mod, self.class_names, is_valid) this_pred.append(ann_save) predictions.append(ann) if self.display: figures_dir = os.path.join( '../models', args.model_name, args.model_name + '_figs_' + args.eval_split) make_dir(figures_dir) plt.figure() plt.axis('off') plt.figure() plt.axis('off') plt.imshow(im) display_masks(this_pred, self.colors, im_height=im.shape[0], im_width=im.shape[1], no_display_text=self.args.no_display_text) if self.dataset == 'cityscapes': sample_idx = sample_idx.split('/')[-1] if self.dataset == 'leaves': sample_idx = sample_idx.split('/')[-1] figname = os.path.join(figures_dir, sample_idx) plt.savefig(figname, bbox_inches='tight') plt.close() acc_samples += np.shape(out_masks)[0] return predictions def run_eval(self): print "Dataset is %s" % (self.dataset) print "Split is %s" % (self.split) print "Evaluating for %d images" % (len(self.sample_list)) print "Number of classes is %d" % (len(self.class_names)) if self.dataset == 'pascal': cocoGT = self.coco.loadRes(self.gt_file) predictions = self._create_json() if not args.no_run_coco_eval: cocoP = self.coco.loadRes(predictions) cocoEval = COCOeval(cocoGT, cocoP, 'segm') cocoEval.params.maxDets = [1, args.max_dets, 100] cocoEval.params.useCats = args.use_cats if not args.cat_id == -1: cocoEval.params.catIds = [args.cat_id] 
cocoEval.params.imgIds = sorted(self.sample_list) cocoEval.params.catIds = range(1, len(self.class_names)) print("Results for all the classes together") cocoEval.evaluate() cocoEval.accumulate() cocoEval.summarize() if self.all_classes: for actual_class in cocoEval.params.catIds: print("Testing class dataset_id: " + str(actual_class)) print("Which corresponds to name: " + self.class_names[actual_class]) cocoEval.params.catIds = [actual_class] cocoEval.evaluate() cocoEval.accumulate() cocoEval.summarize()
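# --- Illustrative sketch (not part of the original repo code) ---------------
# `create_annotation` above is defined elsewhere in the repository. A COCO-style
# 'segm' result entry, as consumed by COCOeval, is typically built with
# pycocotools as follows; the helper name and field values here are placeholders.
import numpy as np
from pycocotools import mask as mask_utils


def encode_prediction(image_id, binary_mask, category_id, score):
    """Build one COCO segmentation result dict from a binary (H, W) mask."""
    rle = mask_utils.encode(np.asfortranarray(binary_mask.astype(np.uint8)))
    rle['counts'] = rle['counts'].decode('ascii')  # make the RLE JSON-serializable
    return {'image_id': image_id,
            'category_id': int(category_id),
            'segmentation': rle,
            'score': float(score)}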
class Evaluate(): def __init__(self, args): self.split = args.eval_split self.dataset = args.dataset to_tensor = transforms.ToTensor() normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) image_transforms = transforms.Compose([to_tensor, normalize]) if args.dataset == 'davis2017': dataset = get_dataset(args, split=self.split, e=0, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', inputRes=(240, 427), video_mode=True, use_prev_mask=True, eval=True) else: # args.dataset == 'youtube' or kittimots dataset = get_dataset(args, split=self.split, e=0, image_transforms=image_transforms, target_transforms=None, augment=args.augment and self.split == 'train', #inputRes=(256, 448), inputRes=(287, 950), #inputRes=(412,723), #inputRes=(178,590), video_mode=True, use_prev_mask=True, eval=True) self.loader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) self.args = args print(args.model_name) encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(args.model_name, args.use_gpu) load_args.use_gpu = args.use_gpu self.encoder = FeatureExtractor(load_args) self.decoder = RSISMask(load_args) print(load_args) if args.ngpus > 1 and args.use_gpu: self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus)) self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus)) encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict) self.encoder.load_state_dict(encoder_dict) to_be_deleted_dec = [] for k in decoder_dict.keys(): if 'fc_stop' in k: to_be_deleted_dec.append(k) for k in to_be_deleted_dec: del decoder_dict[k] self.decoder.load_state_dict(decoder_dict) if args.use_gpu: self.encoder.cuda() self.decoder.cuda() self.encoder.eval() self.decoder.eval() if load_args.length_clip == 1: self.video_mode = False print('video mode not activated') else: self.video_mode = True print('video mode activated') def run_eval(self): print("Dataset is %s" % (self.dataset)) print("Split is %s" % (self.split)) if args.overlay_masks: colors = [] palette = sequence_palette() inv_palette = {} for k, v in palette.items(): inv_palette[v] = k num_colors = len(inv_palette.keys()) for id_color in range(num_colors): if id_color == 0 or id_color == 21: continue c = inv_palette[id_color] colors.append(c) if self.split == 'val-inference': if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results') make_dir(results_dir) json_data = open('../../databases/YouTubeVOS/train/train-val-meta.json') data = json.load(json_data) elif args.dataset == 'davis2017': import lmdb from misc.config import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results-davis') make_dir(results_dir) lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None else: import lmdb from misc.config_kittimots import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-kitti') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results-kitti') make_dir(results_dir) lmdb_env_seq_dir = 
osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None for batch_idx, (inputs, targets, seq_name, starting_frame, frames_with_new_ids) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) frames_with_new_ids = np.array(frames_with_new_ids) #print('Variable max_ii') #print(max_ii) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) if args.dataset == 'davis2017': key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get(key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] frame_names = sorted(_files) if args.dataset == 'kittimots': key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get(key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] frame_names = sorted(_files) # list of frames in a video sequence dict_outs = {} #print("NEW OBJECTS FRAMES", frames_with_new_ids) # make a dir of results for each instance '''for t in range(args.maxseqlen): base_dir_2 = results_dir + '/' + seq_name[0] + '/' + str(t) make_dir(base_dir_2)''' for ii in range(max_ii): # iteration over the frames/clips, each of length length_clip # start_time = time.time() # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x, y_mask, sw_mask = batch_to_var(args, inputs[ii], targets[ii]) #print(seq_name[0] + '/' + frame_names[ii]) if ii == 0: #one-shot approach, information about the first frame of the sequence prev_mask = y_mask #list of the first instances that appear on the sequence and update of the dictionary annotation = Image.open( '../../databases/KITTIMOTS/Annotations/' + seq_name[0] + '/' + frame_names[ii] + '.png').convert('P') annot = np.expand_dims(annotation, axis=0) annot = torch.from_numpy(annot) annot = annot.float() instance_ids = np.unique(annot) for i in instance_ids[1:]: dict_outs.update({str(int(i-1)):int(i)}) #instances = len(instance_ids)-1 #one-shot approach, add GT information when a new instance appears on the video sequence if args.dataset == 'kittimots': if ii > 0 and ii in frames_with_new_ids: frame_name = frame_names[ii] annotation = Image.open( '../../databases/KITTIMOTS/Annotations/' + seq_name[0] + '/' + frame_name + '.png').convert('P') #annot = imresize(annotation, (256, 448), interp='nearest') annot = imresize(annotation, (287,950), interp='nearest') #annot = imresize(annotation, (412, 723), interp='nearest') #annot = imresize(annotation, (178, 590), interp='nearest') annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = annot.float() annot = annot.numpy().squeeze() new_instance_ids = np.setdiff1d(np.unique(annot), instance_ids) annot = annot_from_mask(annot, new_instance_ids) annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = Variable(annot.float(), requires_grad=False) annot = annot.cuda() #adding only the
information of the new instance after the last active branch for kk in new_instance_ids: if dict_outs: last = int(list(dict_outs.keys())[-1]) else: #if the dictionary is empty last = -1 prev_mask[:, int(last+1), :] = annot[:, int(kk - 1), :] dict_outs.update({str(last+1):int(kk)}) del annot #update the list of instances that have appeared on the video sequence if len(new_instance_ids) > 0: instance_ids = np.append(instance_ids, new_instance_ids) #instances = instances + len(new_instance_ids) # from one frame to the following frame the prev_hidden_temporal_list is updated. outs, hidden_temporal_list = test_prev_mask(args, self.encoder, self.decoder, x, prev_hidden_temporal_list, prev_mask) # end_inference_time = time.time() # print("inference time: %.3f" %(end_inference_time-start_time)) if args.dataset == 'youtube': num_instances = len(data['videos'][seq_name[0]]['objects']) else: num_instances = int(torch.sum(sw_mask.data).data.cpu().numpy()) num_instances = args.maxseqlen base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] '''out_stop = outs[1] outs = outs[0] for m in range(len(out_stop[0])): print(m) print(out_stop[0][m])''' #print("OUT STOP: ", out_stop[0]) outs_masks = np.zeros((args.maxseqlen,), dtype=int) for t in range(num_instances): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) '''indxs_instance = np.where((0.6 > mask_pred) & (mask_pred>= 0.5)) indxs_instance_1 = np.where((0.7 > mask_pred) & (mask_pred>= 0.6)) indxs_instance_2 = np.where((0.8 > mask_pred) & (mask_pred>= 0.7)) indxs_instance_3 = np.where((0.9 > mask_pred) & (mask_pred>= 0.8)) indxs_instance_4 = np.where((0.999> mask_pred) & (mask_pred>= 0.9)) indxs_instance_5 = np.where((0.9999 > mask_pred) & (mask_pred>= 0.999))''' mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 ''' mask2assess[indxs_instance] = 40 mask2assess[indxs_instance_1] = 80 mask2assess[indxs_instance_2] = 120 mask2assess[indxs_instance_3] = 160 mask2assess[indxs_instance_4] = 200 mask2assess[indxs_instance_5] = 255''' if str(t) in dict_outs: i = dict_outs[str(t)] else: break if args.dataset == 'youtube': toimage(mask2assess, cmin=0, cmax=255).save( base_dir_masks_sep + '%05d_instance_%02d.png' % (starting_frame[0] + ii, i)) else: toimage(mask2assess, cmin=0, cmax=255).save( base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (i)) #create vector of predictions, gives information about which branches are active if len(indxs_instance[0]) != 0: outs_masks[t] = 1 else: outs_masks[t] = 0 outs = outs.cpu().numpy() #print("INS: ", outs_masks) #print(json.dumps(dict_outs)) #delete spurious branches last_position = last_ocurrence(outs_masks, 1) + 1 while len(dict_outs) < last_position: for n in range(args.maxseqlen): if outs_masks[n] == 1 and str(n) not in dict_outs: outs = np.delete(outs, n, axis=1) outs_masks = np.delete(outs_masks, n) del hidden_temporal_list[n] z = np.zeros((height * width)) outs = np.insert(outs, args.maxseqlen - 1, z, axis=1) hidden_temporal_list.append(None) outs_masks = np.append(outs_masks, 0) last_position = last_ocurrence(outs_masks, 1) + 1 instances = sum(outs_masks) # number of active branches #delete branches of instances that disappear and rearrange for n in range(args.maxseqlen): while outs_masks[n] == 0 and n < instances: outs = np.delete(outs, n, axis=1 ) outs_masks = 
np.delete(outs_masks, n) del hidden_temporal_list[n] z = np.zeros((height * width)) outs = np.insert(outs, args.maxseqlen-1, z, axis=1) hidden_temporal_list.append(None) outs_masks = np.append(outs_masks, 0) #update dictionary by shifting entries for m in range(len(dict_outs)-(n+1)): value = dict_outs[str(m + n + 1)] dict_outs.update({str(n + m): value}) #print(json.dumps(dict_outs)) last = int(list(dict_outs.keys())[-1]) del dict_outs[str(last)] #an instance has disappeared, update dictionary while len(dict_outs) > sum(outs_masks): last = int(list(dict_outs.keys())[-1]) del dict_outs[str(last)] outs = torch.from_numpy(outs) outs = outs.cuda() #print("OUTS: ", outs_masks) #print(json.dumps(dict_outs)) # end_saving_masks_time = time.time() # print("inference + saving masks time: %.3f" %(end_saving_masks_time - start_time)) if args.dataset == 'youtube': print(seq_name[0] + '/' + '%05d' % (starting_frame[0] + ii)) else: print(seq_name[0] + '/' + frame_names[ii]) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure(); plt.axis('off') plt.figure(); plt.axis('off') plt.imshow(frame_img) #print("INSTANCES: ", instances) for t in range(instances): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3)) if str(t) in dict_outs: color_mask = np.array(colors[dict_outs[str(t)]]) / 255.0 else: #color_mask = np.array(colors[0]) / 255.0 break for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) if args.dataset == 'youtube': figname = base_dir + 'frame_%02d.png' % (starting_frame[0] + ii) else: figname = base_dir + frame_names[ii] + '.png' plt.savefig(figname, bbox_inches='tight') plt.close() #Print a video for each instance '''for t in range(instances): if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure(); plt.axis('off') plt.figure(); plt.axis('off') plt.imshow(frame_img) mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3)) if str(t) in dict_outs: color_mask = np.array(colors[dict_outs[str(t)]]) / 255.0 else: #color_mask = np.array(colors[0]) / 255.0 break for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) figname = base_dir + '/' + str(dict_outs[str(t)]) + '/' + frame_names[ii] + '.png' plt.savefig(figname, bbox_inches='tight') plt.close()''' if self.video_mode: if args.only_spatial == False: prev_hidden_temporal_list = hidden_temporal_list if ii > 0: prev_mask = outs else: prev_mask = y_mask del outs, hidden_temporal_list, x, y_mask, sw_mask else: if args.dataset == 'youtube': masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val') make_dir(results_dir) json_data = open('../../databases/YouTubeVOS/val/meta.json') data = json.load(json_data) elif 
args.dataset == 'davis2017': import lmdb from misc.config import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess_val_davis') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results_val_davis') make_dir(results_dir) lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None else: import lmdb from misc.config_kittimots import cfg masks_sep_dir = os.path.join('../models', args.model_name, 'masks_sep_2assess-kitti') make_dir(masks_sep_dir) if args.overlay_masks: results_dir = os.path.join('../models', args.model_name, 'results-kitti') make_dir(results_dir) lmdb_env_seq_dir = osp.join(cfg.PATH.DATA, 'lmdb_seq') if osp.isdir(lmdb_env_seq_dir): lmdb_env_seq = lmdb.open(lmdb_env_seq_dir) else: lmdb_env_seq = None for batch_idx, (inputs, seq_name, starting_frame) in enumerate(self.loader): prev_hidden_temporal_list = None max_ii = min(len(inputs), args.length_clip) if args.overlay_masks: base_dir = results_dir + '/' + seq_name[0] + '/' make_dir(base_dir) if args.dataset == 'youtube': seq_data = data['videos'][seq_name[0]]['objects'] frame_names = [] frame_names_with_new_objects = [] instance_ids = [] for obj_id in seq_data.keys(): instance_ids.append(int(obj_id)) frame_names_with_new_objects.append(seq_data[obj_id]['frames'][0]) for frame_name in seq_data[obj_id]['frames']: if frame_name not in frame_names: frame_names.append(frame_name) frame_names.sort() frame_names_with_new_objects_idxs = [] for kk in range(len(frame_names_with_new_objects)): new_frame_idx = frame_names.index(frame_names_with_new_objects[kk]) frame_names_with_new_objects_idxs.append(new_frame_idx) elif args.dataset == 'davis2017': key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get(key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] frame_names = sorted(_files) else: key_db = osp.basename(seq_name[0]) if not lmdb_env_seq == None: with lmdb_env_seq.begin() as txn: _files_vec = txn.get(key_db.encode()).decode().split('|') _files = [osp.splitext(f)[0] for f in _files_vec] else: seq_dir = osp.join(cfg['PATH']['SEQUENCES'], key_db) _files_vec = os.listdir(seq_dir) _files = [osp.splitext(f)[0] for f in _files_vec] # frame_names_with_new_objects_idxs = [3,6,9] frame_names = sorted(_files) for ii in range(max_ii): # x: input images (N consecutive frames from M different sequences) # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances) # sw_mask: this mask indicates which masks from y_mask are valid x = batch_to_var_test(args, inputs[ii]) print(seq_name[0] + '/' + frame_names[ii]) if ii == 0: frame_name = frame_names[0] if args.dataset == 'youtube': annotation = Image.open( '../../databases/YouTubeVOS/val/Annotations/' + seq_name[0] + '/' + frame_name + '.png') annot = imresize(annotation, (256, 448), interp='nearest') elif args.dataset == 'davis2017': annotation = Image.open( '../../databases/DAVIS2017/Annotations/480p/' + seq_name[0] + '/' + frame_name + '.png') instance_ids = sorted(np.unique(annotation)) instance_ids = instance_ids if instance_ids[0] else instance_ids[1:] if len(instance_ids) > 0: instance_ids = instance_ids[:-1] if instance_ids[-1] == 
255 else instance_ids annot = imresize(annotation, (240, 427), interp='nearest') else: # kittimots annotation = Image.open( '../../databases/KITTIMOTS/Annotations/' + seq_name[0] + '/' + frame_name + '.png') instance_ids = sorted(np.unique(annotation)) instance_ids = instance_ids if instance_ids[0] else instance_ids[1:] print("IDS instances: ", instance_ids) if len(instance_ids) > 0: instance_ids = instance_ids[:-1] if instance_ids[-1] == 255 else instance_ids annot = imresize(annotation, (256, 448), interp='nearest') annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = annot.float() annot = annot.numpy().squeeze() annot = annot_from_mask(annot, instance_ids) prev_mask = annot prev_mask = np.expand_dims(prev_mask, axis=0) prev_mask = torch.from_numpy(prev_mask) y_mask = Variable(prev_mask.float(), requires_grad=False) prev_mask = y_mask.cuda() del annot if args.dataset == 'youtube': if ii > 0 and ii in frame_names_with_new_objects_idxs: frame_name = frame_names[ii] annotation = Image.open( '../../databases/YouTubeVOS/val/Annotations/' + seq_name[0] + '/' + frame_name + '.png') annot = imresize(annotation, (256, 448), interp='nearest') annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = annot.float() annot = annot.numpy().squeeze() new_instance_ids = np.unique(annot)[1:] annot = annot_from_mask(annot, new_instance_ids) annot = np.expand_dims(annot, axis=0) annot = torch.from_numpy(annot) annot = Variable(annot.float(), requires_grad=False) annot = annot.cuda() for kk in new_instance_ids: prev_mask[:, int(kk - 1), :] = annot[:, int(kk - 1), :] del annot # from one frame to the following frame the prev_hidden_temporal_list is updated. outs, hidden_temporal_list = test_prev_mask(args, self.encoder, self.decoder, x, prev_hidden_temporal_list, prev_mask) base_dir_masks_sep = masks_sep_dir + '/' + seq_name[0] + '/' make_dir(base_dir_masks_sep) x_tmp = x.data.cpu().numpy() height = x_tmp.shape[-2] width = x_tmp.shape[-1] for t in range(len(instance_ids)): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) indxs_instance = np.where(mask_pred > 0.5) mask2assess = np.zeros((height, width)) mask2assess[indxs_instance] = 255 toimage(mask2assess, cmin=0, cmax=255).save( base_dir_masks_sep + frame_names[ii] + '_instance_%02d.png' % (t)) if args.overlay_masks: frame_img = x.data.cpu().numpy()[0, :, :, :].squeeze() frame_img = np.transpose(frame_img, (1, 2, 0)) mean = np.array([0.485, 0.456, 0.406]) std = np.array([0.229, 0.224, 0.225]) frame_img = std * frame_img + mean frame_img = np.clip(frame_img, 0, 1) plt.figure(); plt.axis('off') plt.figure(); plt.axis('off') plt.imshow(frame_img) for t in range(len(instance_ids)): mask_pred = (torch.squeeze(outs[0, t, :])).cpu().numpy() mask_pred = np.reshape(mask_pred, (height, width)) ax = plt.gca() tmp_img = np.ones((mask_pred.shape[0], mask_pred.shape[1], 3)) color_mask = np.array(colors[t]) / 255.0 for i in range(3): tmp_img[:, :, i] = color_mask[i] ax.imshow(np.dstack((tmp_img, mask_pred * 0.7))) figname = base_dir + frame_names[ii] + '.png' plt.savefig(figname, bbox_inches='tight') plt.close() if self.video_mode: if args.only_spatial == False: prev_hidden_temporal_list = hidden_temporal_list if ii > 0: prev_mask = outs del x, hidden_temporal_list, outs
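# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the original pipeline).
#
# 1) Frame-name lookup. The loop above resolves the frame list of a sequence
#    either from an LMDB index (one '|'-separated record per sequence key) or,
#    if no index exists, from the sequence directory itself. A minimal
#    stand-alone sketch of that pattern; `sequences_root` is an assumed
#    argument standing in for cfg['PATH']['SEQUENCES'].
import os
import os.path as osp

def list_frame_names(seq_name, lmdb_env_seq, sequences_root):
    key_db = osp.basename(seq_name)
    if lmdb_env_seq is not None:
        with lmdb_env_seq.begin() as txn:
            files_vec = txn.get(key_db.encode()).decode().split('|')
    else:
        files_vec = os.listdir(osp.join(sequences_root, key_db))
    return sorted(osp.splitext(f)[0] for f in files_vec)

# 2) First-frame mask initialisation. prev_mask is built from the first
#    annotation by splitting the label map into one binary mask per instance
#    id, which is what annot_from_mask is used for above. The real helper may
#    pad and order instances differently; this is only a hedged approximation.
import numpy as np

def instance_masks_from_annotation(annot, instance_ids, max_instances=10):
    h, w = annot.shape
    masks = np.zeros((max_instances, h * w), dtype=np.float32)
    for slot, obj_id in enumerate(instance_ids[:max_instances]):
        masks[slot] = (annot == obj_id).reshape(-1).astype(np.float32)
    return masks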
class Evaluate():

    def __init__(self, args):
        self.split = args.eval_split
        self.display = args.display
        self.dataset = args.dataset
        self.all_classes = args.all_classes
        self.T = args.maxseqlen
        self.batch_size = args.batch_size

        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        image_transforms = transforms.Compose([to_tensor, normalize])

        dataset = get_dataset(args, self.split, image_transforms,
                              augment=False, imsize=args.imsize)

        self.loader = data.DataLoader(dataset, batch_size=args.batch_size,
                                      shuffle=False,
                                      num_workers=args.num_workers,
                                      drop_last=False)

        self.sample_list = dataset.get_sample_list()
        self.args = args

        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.model_name)
        self.args.use_feedback = load_args.use_feedback
        self.args.base_model = load_args.base_model
        self.hidden_size = load_args.hidden_size
        self.args.nconvlstm = load_args.nconvlstm
        self.encoder = FeatureExtractor(load_args)
        self.decoder = RSIS(load_args)

        if args.ngpus > 1 and args.use_gpu:
            self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus))
            self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus))

        # check if the model was trained using multiple gpus
        trained_parallel = False
        for k, v in encoder_dict.items():
            if k[:7] == "module.":
                trained_parallel = True
            break

        if trained_parallel and not args.ngpus > 1:
            # create new OrderedDict that does not contain "module."
            new_encoder_state_dict = OrderedDict()
            new_decoder_state_dict = OrderedDict()
            for k, v in encoder_dict.items():
                name = k[7:]  # remove "module."
                new_encoder_state_dict[name] = v
            for k, v in decoder_dict.items():
                name = k[7:]  # remove "module."
                new_decoder_state_dict[name] = v
            encoder_dict = new_encoder_state_dict
            decoder_dict = new_decoder_state_dict

        self.encoder.load_state_dict(encoder_dict)
        self.decoder.load_state_dict(decoder_dict)

        if args.use_gpu:
            self.encoder.cuda()
            self.decoder.cuda()

        self.encoder.eval()
        self.decoder.eval()

    def create_figures(self):
        acc_samples = 0
        results_root_dir = os.path.join('../models', args.model_name, args.model_name + '_results')
        make_dir(results_root_dir)
        results_dir = os.path.join(results_root_dir, 'A1')
        make_dir(results_dir)
        print("Creating annotations for leaves validation...")
        for batch_idx, (inputs, targets) in enumerate(self.loader):
            x, y_mask, y_class, sw_mask, sw_class = batch_to_var(
                self.args, inputs, targets)
            out_masks, _, stop_probs = test(self.args, self.encoder, self.decoder, x)

            for sample in range(self.batch_size):
                sample_idx = self.sample_list[sample + acc_samples]
                image_dir = os.path.join(sample_idx.split('.')[0] + '.png')
                im = scipy.misc.imread(image_dir)
                h = im.shape[0]
                w = im.shape[1]

                mask_sample = np.zeros([h, w])
                sample_idx = sample_idx.split('/')[-1].split('.')[0]
                img_masks = out_masks[sample]
                instance_id = 0
                class_scores = stop_probs[sample]

                for time_step in range(self.T):
                    mask = img_masks[time_step].cpu().numpy()
                    mask = scipy.misc.imresize(mask, [h, w])
                    class_scores_mask = class_scores[time_step].cpu().numpy()
                    class_score = class_scores_mask[0]
                    if class_score > args.class_th:
                        mask_sample[mask > args.mask_th * 255] = time_step
                        instance_id += 1

                file_name = os.path.join(results_dir, sample_idx + '.png')
                file_name_prediction = file_name.replace('rgb.png', 'label.png')
                im = Image.fromarray(mask_sample).convert('L')
                im.save(file_name_prediction)

            acc_samples += self.batch_size
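# Illustrative sketch (not part of the original code): the "module." handling
# in Evaluate.__init__ exists because weights saved from a
# torch.nn.DataParallel model carry a "module." prefix that a single-GPU
# model will not accept. A generic helper doing the same cleanup could look
# like this:
from collections import OrderedDict

def strip_module_prefix(state_dict):
    cleaned = OrderedDict()
    for k, v in state_dict.items():
        cleaned[k[7:] if k.startswith("module.") else k] = v
    return cleaned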
def trainIters(args):
    epoch_resume = 0
    model_dir = os.path.join('../models/', args.model_name)

    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.model_name, args.use_gpu)
        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
        args = load_args
    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
    else:
        encoder = FeatureExtractor(args)
        decoder = RSIS(args)

    # model checkpoints will be saved here
    make_dir(model_dir)

    # save parameters for future use
    pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb'))

    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)

    dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn)

    if args.resume or args.transfer:
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)

    # change fc layer for new classes
    if load_args.dataset != args.dataset and args.transfer:
        dim_in = decoder.fc_class.weight.size()[1]
        decoder.fc_class = nn.Linear(dim_in, args.num_classes)

    if not args.log_term:
        print("Training logs will be saved to:", os.path.join(model_dir, 'train.log'))
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')

    print(args)

    # objective functions for mask and class outputs.
    # these return the average across samples in batch whose value
    # needs to be considered (those where sw is 1)
    # mask_xentropy = BalancedStableMaskedBCELoss()
    mask_siou = softIoULoss()
    class_xentropy = MaskedNLLLoss(balance_weight=None)
    stop_xentropy = MaskedBCELoss(balance_weight=args.stop_balance_weight)

    if args.ngpus > 1 and args.use_gpu:
        decoder = torch.nn.DataParallel(decoder, device_ids=range(args.ngpus))
        encoder = torch.nn.DataParallel(encoder, device_ids=range(args.ngpus))
        mask_siou = torch.nn.DataParallel(mask_siou, device_ids=range(args.ngpus))
        class_xentropy = torch.nn.DataParallel(class_xentropy, device_ids=range(args.ngpus))
        stop_xentropy = torch.nn.DataParallel(stop_xentropy, device_ids=range(args.ngpus))

    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        class_xentropy.cuda()
        mask_siou.cuda()
        stop_xentropy.cuda()

    crits = [mask_siou, class_xentropy, stop_xentropy]
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()

    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1

    # init windows to visualize, if visdom is enabled
    if args.visdom:
        import visdom
        viz = visdom.Visdom(port=args.port, server=args.server)
        lot, elot, mviz_pred, mviz_true, image_lot = init_visdom(args, viz)

    if args.curriculum_learning and epoch_resume == 0:
        args.limit_seqlen_to = 2

    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders, class_names = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}

    for e in range(args.max_epoch):
        print("Epoch", e + epoch_resume)
        # store losses in lists to display average since beginning
        epoch_losses = {'train': {'total': [], 'iou': [], 'stop': [], 'class': []},
                        'val': {'total': [], 'iou': [], 'stop': [], 'class': []}}
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': [], 'stop': [], 'class': []}

        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1

        if e + epoch_resume >= args.class_loss_after and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            acc_patience = 0
            mt_val = -1

        if e + epoch_resume >= args.stop_loss_after and not args.use_stop_loss and not args.stop_loss_after == -1:
            if args.curriculum_learning:
                if args.limit_seqlen_to > args.min_steps:
                    print("Starting to learn stop loss")
                    args.use_stop_loss = True
                    best_val_loss = 1000  # reset because adding a loss term will increase the total value
                    acc_patience = 0
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                acc_patience = 0
                mt_val = -1

        # we validate after each epoch
        for split in ['train', 'val']:
            for batch_idx, (inputs, targets) in enumerate(loaders[split]):
                # send batch to GPU
                x, y_mask, y_class, sw_mask, sw_class = batch_to_var(args, inputs, targets)

                # we forward (and backward & update if training set)
                losses, outs, true_perm = runIter(args, encoder, decoder, x, y_mask,
                                                  y_class, sw_mask, sw_class,
                                                  crits, optims, mode=split)

                # store loss values in dictionary separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])
                epoch_losses[split]['stop'].append(losses[2])
                epoch_losses[split]['class'].append(losses[3])

                # print and display in visdom after some iterations
                if (batch_idx + 1) % args.print_every == 0:

                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])
                    mc = np.mean(epoch_losses[split]['class'])
                    mx = np.mean(epoch_losses[split]['stop'])

                    if args.visdom:
                        if split == 'train':
                            # we display batch loss values in visdom (Training only)
                            viz.line(
                                X=torch.ones((1, 4)).cpu() * (batch_idx + e * num_batches[split]),
                                Y=torch.Tensor([mi, mx, mc, mt]).unsqueeze(0).cpu(),
                                win=lot,
                                update='append')
                        w = x.size()[-1]
                        h = x.size()[-2]

                        out_masks, out_classes, y_mask, y_class = outs_perms_to_cpu(args, outs, true_perm, h, w)

                        x = x.data.cpu().numpy()
                        # send image, sample predictions and ground truths to visdom
                        for t in range(np.shape(out_masks)[1]):
                            mask_pred = out_masks[0, t]
                            mask_true = y_mask[0, t]
                            class_pred = class_names[out_classes[0, t]]
                            class_true = class_names[y_class[0, t]]

                            mask_pred = np.reshape(mask_pred, (x.shape[-2], x.shape[-1]))
                            mask_true = np.reshape(mask_true, (x.shape[-2], x.shape[-1]))

                            # heatmap displays the mask upside down
                            viz.heatmap(np.flipud(mask_pred), win=mviz_pred[t],
                                        opts=dict(title='pred mask %d %s' % (t, class_pred)))
                            viz.heatmap(np.flipud(mask_true), win=mviz_true[t],
                                        opts=dict(title='true mask %d %s' % (t, class_true)))
                        viz.image((x[0] * 0.2 + 0.5) * 256, win=image_lot,
                                  opts=dict(title='image (unnormalized)'))

                    te = time.time() - start
                    print("iter %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\ttime:%.4f" % (batch_idx, mt, mc, mi, mx, te))
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()

            num_batches[split] = batch_idx + 1
            # compute mean val losses within epoch

            if split == 'val' and args.smooth_curves:
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(epoch_losses[split]['total'])
                mt_val = mt
            else:
                mt = np.mean(epoch_losses[split]['total'])

            mi = np.mean(epoch_losses[split]['iou'])
            mc = np.mean(epoch_losses[split]['class'])
            mx = np.mean(epoch_losses[split]['stop'])

            # save train and val losses for the epoch to display in visdom
            total_losses['iou'].append(mi)
            total_losses['class'].append(mc)
            total_losses['stop'].append(mx)
            total_losses['total'].append(mt)

            args.epoch_resume = e + epoch_resume

            print("Epoch %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\t(%s)" % (e, mt, mc, mi, mx, split))

        # epoch losses
        if args.visdom:
            update = True if e == 0 else 'append'
            for l in ['total', 'iou', 'stop', 'class']:
                viz.line(X=torch.ones((1, 2)).cpu() * (e + 1),
                         Y=torch.Tensor(total_losses[l]).unsqueeze(0).cpu(),
                         win=elot[l],
                         update=update)

        if mt < (best_val_loss - args.min_delta):
            print("Saving checkpoint.")
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1

        if acc_patience > args.patience and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            acc_patience = 0
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)

        if acc_patience > args.patience and args.curriculum_learning and args.limit_seqlen_to < args.maxseqlen:
            print("Adding one step more:")
            acc_patience = 0
            args.limit_seqlen_to += args.steps_cl
            print(args.limit_seqlen_to)
            best_val_loss = 1000
            mt_val = -1

        if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)

        if acc_patience > args.patience and not args.use_stop_loss and not args.stop_loss_after == -1:
            if args.curriculum_learning:
                print("Starting to learn stop loss")
                if args.limit_seqlen_to > args.min_steps:
                    acc_patience = 0
                    args.use_stop_loss = True
                    best_val_loss = 1000  # reset because adding a loss term will increase the total value
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                acc_patience = 0
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)

        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
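# ---------------------------------------------------------------------------
# Illustrative sketches (not part of the original training script).
#
# 1) Soft IoU objective. trainIters optimizes softIoULoss for the masks; a
#    hedged, minimal version of such a loss follows (the repo's exact masking
#    with sw and its reduction may differ):
import torch
import torch.nn as nn

class SoftIoU(nn.Module):
    def forward(self, pred, target, sw=None, eps=1e-6):
        # pred, target: (batch, num_pixels), soft values in [0, 1]
        inter = (pred * target).sum(dim=1)
        union = (pred + target - pred * target).sum(dim=1) + eps
        loss = 1.0 - inter / union
        if sw is not None:
            loss = loss * sw  # keep only samples flagged as valid
        return loss.mean()

# 2) Patience bookkeeping. The block above reuses one counter (acc_patience)
#    both to trigger the next curriculum stage (class loss, longer clips,
#    encoder fine-tuning, stop loss) and to stop training after
#    args.patience_stop epochs without improvement. The same mechanism,
#    condensed into a small hypothetical helper:
class PatienceMeter:
    def __init__(self, min_delta=0.0):
        self.best = float('inf')
        self.count = 0
        self.min_delta = min_delta

    def step(self, val_loss):
        """Return the number of consecutive epochs without improvement."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.count = 0
        else:
            self.count += 1
        return self.count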