def _init_model(self, args):
    """Load encoder/decoder weights from a checkpoint and prepare them for inference.

    Restores the architecture hyper-parameters saved with the checkpoint
    (``load_args``), builds the encoder and the appropriate decoder variant,
    loads the state dicts, and puts both networks in eval mode.
    """
    print("Loading model: " + args.model_name)
    encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
        args.model_name, args.use_gpu)
    # Override the GPU flag stored in the checkpoint with the current run's flag.
    load_args.use_gpu = args.use_gpu
    self.encoder = FeatureExtractor(load_args)
    # Zero-shot evaluation uses the RSIS decoder; otherwise the mask-conditioned
    # RSISMask decoder is used.
    if args.zero_shot:
        self.decoder = RSIS(load_args)
    else:
        self.decoder = RSISMask(load_args)
    print(load_args)
    if args.ngpus > 1 and args.use_gpu:
        self.decoder = torch.nn.DataParallel(self.decoder, device_ids=range(args.ngpus))
        self.encoder = torch.nn.DataParallel(self.encoder, device_ids=range(args.ngpus))
    # Strip any 'module.' prefixes left by DataParallel training checkpoints.
    encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
    self.encoder.load_state_dict(encoder_dict)
    # The stop-branch ('fc_stop') weights are not used at inference time;
    # remove them so load_state_dict does not fail on unexpected keys.
    to_be_deleted_dec = []
    for k in decoder_dict.keys():
        if 'fc_stop' in k:
            to_be_deleted_dec.append(k)
    for k in to_be_deleted_dec:
        del decoder_dict[k]
    self.decoder.load_state_dict(decoder_dict)
    if args.use_gpu:
        self.encoder.cuda()
        self.decoder.cuda()
    self.encoder.eval()
    self.decoder.eval()
    # A clip length of 1 means single-frame processing; longer clips enable
    # the recurrent (video) mode.
    if load_args.length_clip == 1:
        self.video_mode = False
        print('video mode not activated')
    else:
        self.video_mode = True
        print('video mode activated')
def trainIters(args):
    """Train the RSISMask encoder/decoder with previous-mask conditioning.

    Handles resuming/transfer, per-clip recurrent training (hidden temporal
    state and previous mask are threaded frame to frame), epoch-level
    train/val loops, checkpointing on best validation loss, and patience-based
    encoder fine-tuning / early stopping.

    NOTE(review): formatting reconstructed from a whitespace-mangled source;
    the indentation of the checkpoint/patience section follows the upstream
    RVOS layout (checkpoint logic runs after the split loop, i.e. on the
    validation mean) — confirm against the original file.
    """
    epoch_resume = 0
    model_dir = os.path.join('../models/', args.model_name + '_prev_inference_mask')
    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
            args.model_name, args.use_gpu)
        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
        # Continue with the exact arguments the checkpoint was trained with.
        args = load_args
    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
            args.transfer_from, args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSISMask(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
    else:
        encoder = FeatureExtractor(args)
        decoder = RSISMask(args)
    # model checkpoints will be saved here
    make_dir(model_dir)
    # save parameters for future use
    pickle.dump(args, open(os.path.join(model_dir, 'args.pkl'), 'wb'))
    encoder_params = get_base_params(args, encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn)
    if args.resume:
        # Restore optimizer state so momentum/Adam statistics survive the resume.
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)
    if not args.log_term:
        # Redirect stdout/stderr into per-model log files.
        print("Training logs will be saved to:", os.path.join(model_dir, 'train.log'))
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')
    print(args)
    # objective function for mask
    mask_siou = softIoULoss()
    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        mask_siou.cuda()
    crits = mask_siou
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()
    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1
    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}
    #area_range = [[0 ** 2, 1e5 ** 2], [0 ** 2, 20 ** 2], [20 ** 2, 59 ** 2], [59 ** 2, 1e5 ** 2]]
    # Object-area buckets (pixels^2) used by loss penalization below.
    area_range = [[0**2, 1e5**2], [0**2, 30**2], [30**2, 90**2], [90**2, 1e5**2]]  #for (287,950))
    resolution = 0
    for e in range(args.max_epoch):
        print("Epoch", e + epoch_resume)
        # store losses in lists to display average since beginning
        epoch_losses = {
            'train': {
                'total': [],
                'iou': []
            },
            'val': {
                'total': [],
                'iou': []
            }
        }
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': []}
        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1
        if args.loss_penalization:
            # Warm-up schedule: restrict the loss to mid-sized objects for the
            # first 10 epochs, then open it up to all areas.
            if e < 10:
                resolution = area_range[2]
            else:
                resolution = area_range[0]
        # we validate after each epoch
        for split in ['train', 'val']:
            if args.dataset == 'davis2017' or args.dataset == 'youtube' or args.dataset == 'kittimots':
                loaders[split].dataset.set_epoch(e)
            for batch_idx, (inputs, targets, seq_name, starting_frame) in enumerate(loaders[split]):
                # send batch to GPU
                prev_hidden_temporal_list = None
                loss = None
                last_frame = False
                max_ii = min(len(inputs), args.length_clip)
                for ii in range(max_ii):
                    # If are on the last frame from a clip, we will have to backpropagate the loss back to the beginning of the clip.
                    if ii == max_ii - 1:
                        last_frame = True
                    # x: input images (N consecutive frames from M different sequences)
                    # y_mask: ground truth annotations (some of them are zeros to have a fixed length in number of object instances)
                    # sw_mask: this mask indicates which masks from y_mask are valid
                    x, y_mask, sw_mask = batch_to_var(
                        args, inputs[ii], targets[ii])
                    # The first frame is conditioned on the ground-truth mask.
                    if ii == 0:
                        prev_mask = y_mask
                    # From one frame to the following frame the prev_hidden_temporal_list is updated.
                    loss, losses, outs, hidden_temporal_list = runIter(
                        args, encoder, decoder, x, y_mask, sw_mask, resolution,
                        crits, optims, split, loss, prev_hidden_temporal_list,
                        prev_mask, last_frame)
                    # Hidden temporal state from time instant ii is saved to be used when processing next time instant ii+1
                    if args.only_spatial == False:
                        prev_hidden_temporal_list = hidden_temporal_list
                    # Subsequent frames are conditioned on the model's own prediction.
                    prev_mask = outs
                # store loss values in dictionary separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])
                # print after some iterations
                if (batch_idx + 1) % args.print_every == 0:
                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])
                    te = time.time() - start
                    print("iter %d:\ttotal:%.4f\tiou:%.4f\ttime:%.4f" % (batch_idx, mt, mi, te))
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()
            num_batches[split] = batch_idx + 1
            # compute mean val losses within epoch
            if split == 'val' and args.smooth_curves:
                # Exponential smoothing of the validation curve.
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9 * mt_val + 0.1 * np.mean(
                        epoch_losses[split]['total'])
                mt_val = mt
            else:
                mt = np.mean(epoch_losses[split]['total'])
            mi = np.mean(epoch_losses[split]['iou'])
            # save train and val losses for the epoch
            total_losses['iou'].append(mi)
            total_losses['total'].append(mt)
            args.epoch_resume = e + epoch_resume
            print("Epoch %d:\ttotal:%.4f\tiou:%.4f\t(%s)" % (e, mt, mi, split))
        # 'val' is the last split processed, so mt holds the validation mean here.
        if mt < (best_val_loss - args.min_delta):
            print("Saving checkpoint.")
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint_prev_inference_mask(args, encoder, decoder,
                                                enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1
        if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1:
            # Patience exhausted: reload the best checkpoint and unfreeze the encoder.
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(
                args.model_name, args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
def __init__(self, args):
    """Build the evaluation data loader and restore encoder/decoder weights.

    Constructs the dataset for the requested eval split (with a fixed input
    resolution for YouTube-VOS), restores the checkpointed networks, and
    puts them in eval mode.
    """
    self.split = args.eval_split
    self.dataset = args.dataset
    # Standard ImageNet normalization, applied after tensor conversion.
    image_transforms = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    dataset_kwargs = dict(
        split=self.split,
        image_transforms=image_transforms,
        target_transforms=None,
        augment=args.augment and self.split == 'train',
        video_mode=True,
        use_prev_mask=False,
    )
    # YouTube-VOS frames are resized to a fixed resolution.
    if args.dataset == 'youtube':
        dataset_kwargs['inputRes'] = (256, 448)
    dataset = get_dataset(args, **dataset_kwargs)
    self.loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  drop_last=False)
    self.args = args
    print(args.model_name)
    encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(
        args.model_name, args.use_gpu)
    # Override the GPU flag stored with the checkpoint.
    load_args.use_gpu = args.use_gpu
    self.encoder = FeatureExtractor(load_args)
    self.decoder = RSIS(load_args)
    print(load_args)
    if args.ngpus > 1 and args.use_gpu:
        gpu_ids = range(args.ngpus)
        self.decoder = torch.nn.DataParallel(self.decoder, device_ids=gpu_ids)
        self.encoder = torch.nn.DataParallel(self.encoder, device_ids=gpu_ids)
    # Strip DataParallel 'module.' prefixes if present in the checkpoint.
    encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
    self.encoder.load_state_dict(encoder_dict)
    # The stop branch is unused at evaluation time; drop its weights so
    # load_state_dict does not complain about unexpected keys.
    decoder_dict = {k: v for k, v in decoder_dict.items()
                    if 'fc_stop' not in k}
    self.decoder.load_state_dict(decoder_dict)
    if args.use_gpu:
        self.encoder.cuda()
        self.decoder.cuda()
    self.encoder.eval()
    self.decoder.eval()
    # Clips longer than one frame activate the recurrent (video) mode.
    self.video_mode = load_args.length_clip != 1
    if self.video_mode:
        print('video mode activated')
    else:
        print('video mode not activated')
def _compute_losses(encoder, decoder, criterion, image, flow, mask, bdry,
                    negative_pixels):
    """One forward pass: return (total loss, mask loss, boundary loss, IoU).

    Extracted from trainIters, where this code was duplicated verbatim in
    the train and val branches.
    """
    r5, r4, r3, r2 = encoder(image, flow)
    mask_pred, p1, p2, p3, p4, p5 = decoder(r5, r4, r3, r2)
    mask_loss = criterion(mask_pred, mask, negative_pixels)
    # Deep supervision: the boundary loss is accumulated over all side outputs.
    bdry_loss = criterion(p1, bdry, negative_pixels) + \
                criterion(p2, bdry, negative_pixels) + \
                criterion(p3, bdry, negative_pixels) + \
                criterion(p4, bdry, negative_pixels) + \
                criterion(p5, bdry, negative_pixels)
    loss = mask_loss + 0.2 * bdry_loss
    iou = db_eval_iou_multi(mask.cpu().detach().numpy(),
                            mask_pred.cpu().detach().numpy())
    return loss, mask_loss, bdry_loss, iou


def trainIters(args):
    """Train the two-stream (image + flow) encoder/decoder.

    Runs epoch-level train/val loops, logs running means every
    ``args.print_every`` batches, and checkpoints whenever the validation
    IoU improves.
    """
    print(args)
    model_dir = os.path.join('ckpt/', args.model_name)
    make_dir(model_dir)
    epoch_resume = 0
    if args.resume:
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = \
            load_checkpoint_epoch(args.model_name, args.epoch_resume,
                                  args.use_gpu)
        epoch_resume = args.epoch_resume
        encoder = Encoder()
        decoder = Decoder()
        encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
        # NOTE(review): enc_opt_dict/dec_opt_dict are loaded but never restored
        # into the optimizers created below, so resuming resets optimizer
        # state — confirm whether this is intentional.
    else:
        encoder = Encoder()
        decoder = Decoder()
    criterion = WeightedBCE2d()
    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        criterion.cuda()
    encoder_params = list(encoder.parameters())
    decoder_params = list(decoder.parameters())
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params,
                            args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params,
                            args.weight_decay_cnn)
    loaders = init_dataloaders(args)
    best_iou = 0
    start = time.time()
    for e in range(epoch_resume, args.max_epoch):
        print("Epoch", e)
        epoch_losses = {
            'train': {'total': [], 'iou': [], 'mask_loss': [],
                      'bdry_loss': []},
            'val': {'total': [], 'iou': [], 'mask_loss': [],
                    'bdry_loss': []}
        }
        for split in ['train', 'val']:
            is_train = split == 'train'
            encoder.train(is_train)
            decoder.train(is_train)
            for batch_idx, (image, flow, mask, bdry, negative_pixels) in \
                    enumerate(loaders[split]):
                # Bug fix: tensors were previously moved to the GPU
                # unconditionally, which crashed when args.use_gpu was False
                # even though the models stayed on the CPU.
                if args.use_gpu:
                    image, flow, mask, bdry, negative_pixels = \
                        image.cuda(), flow.cuda(), mask.cuda(), bdry.cuda(), \
                        negative_pixels.cuda()
                if is_train:
                    loss, mask_loss, bdry_loss, iou = _compute_losses(
                        encoder, decoder, criterion, image, flow, mask, bdry,
                        negative_pixels)
                    dec_opt.zero_grad()
                    enc_opt.zero_grad()
                    loss.backward()
                    enc_opt.step()
                    dec_opt.step()
                else:
                    with torch.no_grad():
                        loss, mask_loss, bdry_loss, iou = _compute_losses(
                            encoder, decoder, criterion, image, flow, mask,
                            bdry, negative_pixels)
                epoch_losses[split]['total'].append(loss.data.item())
                epoch_losses[split]['mask_loss'].append(mask_loss.data.item())
                epoch_losses[split]['bdry_loss'].append(bdry_loss.data.item())
                epoch_losses[split]['iou'].append(iou)
                if (batch_idx + 1) % args.print_every == 0:
                    mt = np.mean(epoch_losses[split]['total'])
                    mmask = np.mean(epoch_losses[split]['mask_loss'])
                    mbdry = np.mean(epoch_losses[split]['bdry_loss'])
                    miou = np.mean(epoch_losses[split]['iou'])
                    te = time.time() - start
                    print('Epoch: [{}/{}][{}/{}]\tTime {:.3f}s\tLoss: {:.4f}'
                          '\tMask Loss: {:.4f}\tBdry Loss: {:.4f}'
                          '\tIOU: {:.4f}'.format(e, args.max_epoch, batch_idx,
                                                 len(loaders[split]), te, mt,
                                                 mmask, mbdry, miou))
                    start = time.time()
        # Keep only the checkpoint with the best validation IoU.
        miou = np.mean(epoch_losses['val']['iou'])
        if miou > best_iou:
            best_iou = miou
            save_checkpoint_epoch(args, encoder, decoder, enc_opt, dec_opt, e,
                                  False)
def __init__(self, args):
    """Set up Pascal-style evaluation: dataset, COCO ground truth, colors, model.

    NOTE(review): `palette.iteritems()` is Python 2 only; this chunk predates
    a Python 3 port.
    """
    self.split = args.eval_split
    self.display = args.display
    self.no_display_text = args.no_display_text
    self.dataset = args.dataset
    self.all_classes = args.all_classes
    self.use_cats = args.use_cats
    to_tensor = transforms.ToTensor()
    # Standard ImageNet normalization.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    image_transforms = transforms.Compose([to_tensor, normalize])
    dataset = get_dataset(args, self.split, image_transforms, augment=False,
                          imsize=args.imsize)
    self.sample_list = dataset.get_sample_list()
    self.class_names = dataset.get_classes()
    if args.dataset == 'pascal':
        # Pre-pickled COCO-format ground-truth annotations for this split.
        self.gt_file = pickle.load(
            open(
                os.path.join(args.pascal_dir, 'VOCGT_%s.pkl' % (self.split)),
                'rb'))
        self.key_to_anns = dict()
        self.ignoremasks = {}
        for ann in self.gt_file:
            if ann['ignore'] == 1:
                # Decode the ignore region into a binary mask. Uncompressed
                # RLE (a list of counts) must first be converted to the
                # compressed form expected by mask.decode.
                if type(ann['segmentation']['counts']) == list:
                    im_height = ann['segmentation']['size'][0]
                    im_width = ann['segmentation']['size'][1]
                    rle = mask.frPyObjects([ann['segmentation']], im_height,
                                           im_width)
                else:
                    rle = [ann['segmentation']]
                m = mask.decode(rle)
                self.ignoremasks[ann['image_id']] = m
            # Group all annotations by image id.
            if ann['image_id'] in self.key_to_anns.keys():
                self.key_to_anns[ann['image_id']].append(ann)
            else:
                self.key_to_anns[ann['image_id']] = [ann]
    # NOTE(review): placement of the COCO object creation relative to the
    # pascal-only block above was reconstructed — confirm against upstream.
    self.coco = create_coco_object(args, self.sample_list, self.class_names)
    self.loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  drop_last=False)
    self.args = args
    # Build the display color table from the sequence palette, skipping the
    # background (0) and the 21st entry.
    self.colors = []
    palette = sequence_palette()
    inv_palette = {}
    for k, v in palette.iteritems():
        inv_palette[v] = k
    num_colors = len(inv_palette.keys())
    for i in range(num_colors):
        if i == 0 or i == 21:
            continue
        c = inv_palette[i]
        self.colors.append(c)
    encoder_dict, decoder_dict, _, _, load_args = load_checkpoint(
        args.model_name, args.use_gpu)
    # Override the GPU flag stored with the checkpoint.
    load_args.use_gpu = args.use_gpu
    self.encoder = FeatureExtractor(load_args)
    self.decoder = RSIS(load_args)
    print(load_args)
    if args.ngpus > 1 and args.use_gpu:
        self.decoder = torch.nn.DataParallel(self.decoder,
                                             device_ids=range(args.ngpus))
        self.encoder = torch.nn.DataParallel(self.encoder,
                                             device_ids=range(args.ngpus))
    # Strip DataParallel 'module.' prefixes if present in the checkpoint.
    encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
    self.encoder.load_state_dict(encoder_dict)
    self.decoder.load_state_dict(decoder_dict)
    if args.use_gpu:
        self.encoder.cuda()
        self.decoder.cuda()
    self.encoder.eval()
    self.decoder.eval()
def trainIters(args):
    """Train the RSIS encoder/decoder (original Python 2 RVOS training loop).

    Supports resuming, transfer learning with a replaced classification head,
    curriculum learning over sequence length, staged activation of the
    class/stop losses, visdom monitoring, and patience-based early stopping.

    NOTE(review): formatting reconstructed from a whitespace-mangled source;
    nesting of the patience blocks follows the upstream RVOS layout.
    """
    epoch_resume = 0
    model_dir = os.path.join('../models/', args.model_name)
    if args.resume:
        # will resume training the model with name args.model_name
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.model_name,args.use_gpu)
        epoch_resume = load_args.epoch_resume
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(load_args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict,decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
        # Continue with the exact arguments the checkpoint was trained with.
        args = load_args
    elif args.transfer:
        # load model from args and replace last fc layer
        encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args = load_checkpoint(args.transfer_from,args.use_gpu)
        encoder = FeatureExtractor(load_args)
        decoder = RSIS(args)
        encoder_dict, decoder_dict = check_parallel(encoder_dict,decoder_dict)
        encoder.load_state_dict(encoder_dict)
        decoder.load_state_dict(decoder_dict)
    else:
        encoder = FeatureExtractor(args)
        decoder = RSIS(args)
    # model checkpoints will be saved here
    make_dir(model_dir)
    # save parameters for future use
    pickle.dump(args, open(os.path.join(model_dir,'args.pkl'),'wb'))
    encoder_params = get_base_params(args,encoder)
    skip_params = get_skip_params(encoder)
    decoder_params = list(decoder.parameters()) + list(skip_params)
    dec_opt = get_optimizer(args.optim, args.lr, decoder_params, args.weight_decay)
    enc_opt = get_optimizer(args.optim_cnn, args.lr_cnn, encoder_params, args.weight_decay_cnn)
    if args.resume or args.transfer:
        # Restore optimizer state so momentum/Adam statistics survive.
        enc_opt.load_state_dict(enc_opt_dict)
        dec_opt.load_state_dict(dec_opt_dict)
        from collections import defaultdict
        dec_opt.state = defaultdict(dict, dec_opt.state)
        # change fc layer for new classes
        # NOTE(review): fc_class is replaced after dec_opt was built from
        # decoder_params, so the new layer's parameters are not tracked by
        # the optimizer — confirm whether this is intentional.
        if load_args.dataset != args.dataset and args.transfer:
            dim_in = decoder.fc_class.weight.size()[1]
            decoder.fc_class = nn.Linear(dim_in,args.num_classes)
    if not args.log_term:
        # Redirect stdout/stderr into per-model log files.
        print "Training logs will be saved to:", os.path.join(model_dir, 'train.log')
        sys.stdout = open(os.path.join(model_dir, 'train.log'), 'w')
        sys.stderr = open(os.path.join(model_dir, 'train.err'), 'w')
    print args
    # objective functions for mask and class outputs.
    # these return the average across samples in batch whose value
    # needs to be considered (those where sw is 1)
    # mask_xentropy = BalancedStableMaskedBCELoss()
    mask_siou = softIoULoss()
    class_xentropy = MaskedNLLLoss(balance_weight=None)
    stop_xentropy = MaskedBCELoss(balance_weight=args.stop_balance_weight)
    if args.ngpus > 1 and args.use_gpu:
        decoder = torch.nn.DataParallel(decoder, device_ids=range(args.ngpus))
        encoder = torch.nn.DataParallel(encoder, device_ids=range(args.ngpus))
        mask_siou = torch.nn.DataParallel(mask_siou, device_ids=range(args.ngpus))
        class_xentropy = torch.nn.DataParallel(class_xentropy, device_ids=range(args.ngpus))
        stop_xentropy = torch.nn.DataParallel(stop_xentropy, device_ids=range(args.ngpus))
    if args.use_gpu:
        encoder.cuda()
        decoder.cuda()
        class_xentropy.cuda()
        mask_siou.cuda()
        stop_xentropy.cuda()
    crits = [mask_siou, class_xentropy, stop_xentropy]
    optims = [enc_opt, dec_opt]
    if args.use_gpu:
        torch.cuda.synchronize()
    start = time.time()
    # vars for early stopping
    best_val_loss = args.best_val_loss
    acc_patience = 0
    mt_val = -1
    # init windows to visualize, if visdom is enabled
    if args.visdom:
        import visdom
        viz = visdom.Visdom(port=args.port, server=args.server)
        lot, elot, mviz_pred, mviz_true, image_lot = init_visdom(args, viz)
    # Curriculum learning starts with very short sequences.
    if args.curriculum_learning and epoch_resume == 0:
        args.limit_seqlen_to = 2
    # keep track of the number of batches in each epoch for continuity when plotting curves
    loaders, class_names = init_dataloaders(args)
    num_batches = {'train': 0, 'val': 0}
    for e in range(args.max_epoch):
        print "Epoch", e + epoch_resume
        # store losses in lists to display average since beginning
        epoch_losses = {'train': {'total': [], 'iou': [], 'stop': [], 'class': []},
                        'val': {'total': [], 'iou': [], 'stop': [], 'class': []}}
        # total mean for epoch will be saved here to display at the end
        total_losses = {'total': [], 'iou': [], 'stop': [], 'class': []}
        # check if it's time to do some changes here
        if e + epoch_resume >= args.finetune_after and not args.update_encoder and not args.finetune_after == -1:
            print("Starting to update encoder")
            args.update_encoder = True
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.class_loss_after and not args.use_class_loss and not args.class_loss_after == -1:
            print("Starting to learn class loss")
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            acc_patience = 0
            mt_val = -1
        if e + epoch_resume >= args.stop_loss_after and not args.use_stop_loss and not args.stop_loss_after == -1:
            if args.curriculum_learning:
                # Only enable the stop loss once sequences are long enough.
                if args.limit_seqlen_to > args.min_steps:
                    print("Starting to learn stop loss")
                    args.use_stop_loss = True
                    best_val_loss = 1000  # reset because adding a loss term will increase the total value
                    acc_patience = 0
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                acc_patience = 0
                mt_val = -1
        # we validate after each epoch
        for split in ['train', 'val']:
            for batch_idx, (inputs, targets) in enumerate(loaders[split]):
                # send batch to GPU
                x, y_mask, y_class, sw_mask, sw_class = batch_to_var(args, inputs, targets)
                # we forward (and backward & update if training set)
                losses, outs, true_perm = runIter(args, encoder, decoder, x,
                                                  y_mask, y_class, sw_mask,
                                                  sw_class, crits, optims,
                                                  mode=split)
                # store loss values in dictionary separately
                epoch_losses[split]['total'].append(losses[0])
                epoch_losses[split]['iou'].append(losses[1])
                epoch_losses[split]['stop'].append(losses[2])
                epoch_losses[split]['class'].append(losses[3])
                # print and display in visdom after some iterations
                if (batch_idx + 1)% args.print_every == 0:
                    mt = np.mean(epoch_losses[split]['total'])
                    mi = np.mean(epoch_losses[split]['iou'])
                    mc = np.mean(epoch_losses[split]['class'])
                    mx = np.mean(epoch_losses[split]['stop'])
                    if args.visdom:
                        if split == 'train':
                            # we display batch loss values in visdom (Training only)
                            viz.line(
                                X=torch.ones((1, 4)).cpu() * (batch_idx + e * num_batches[split]),
                                Y=torch.Tensor([mi, mx, mc, mt]).unsqueeze(0).cpu(),
                                win=lot,
                                update='append')
                        w = x.size()[-1]
                        h = x.size()[-2]
                        out_masks, out_classes, y_mask, y_class = outs_perms_to_cpu(args, outs, true_perm, h, w)
                        x = x.data.cpu().numpy()
                        # send image, sample predictions and ground truths to visdom
                        for t in range(np.shape(out_masks)[1]):
                            mask_pred = out_masks[0, t]
                            mask_true = y_mask[0, t]
                            class_pred = class_names[out_classes[0, t]]
                            class_true = class_names[y_class[0, t]]
                            mask_pred = np.reshape(mask_pred, (x.shape[-2], x.shape[-1]))
                            mask_true = np.reshape(mask_true, (x.shape[-2], x.shape[-1]))
                            # heatmap displays the mask upside down
                            viz.heatmap(np.flipud(mask_pred), win=mviz_pred[t],
                                        opts=dict(title='pred mask %d %s' % (t, class_pred)))
                            viz.heatmap(np.flipud(mask_true), win=mviz_true[t],
                                        opts=dict(title='true mask %d %s' % (t, class_true)))
                        viz.image((x[0] * 0.2 + 0.5) * 256, win=image_lot,
                                  opts=dict(title='image (unnnormalized)'))
                    te = time.time() - start
                    print "iter %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\ttime:%.4f" % (batch_idx, mt, mc, mi, mx, te)
                    if args.use_gpu:
                        torch.cuda.synchronize()
                    start = time.time()
            num_batches[split] = batch_idx + 1
            # compute mean val losses within epoch
            if split == 'val' and args.smooth_curves:
                # Exponential smoothing of the validation curve.
                if mt_val == -1:
                    mt = np.mean(epoch_losses[split]['total'])
                else:
                    mt = 0.9*mt_val + 0.1*np.mean(epoch_losses[split]['total'])
                mt_val = mt
            else:
                mt = np.mean(epoch_losses[split]['total'])
            mi = np.mean(epoch_losses[split]['iou'])
            mc = np.mean(epoch_losses[split]['class'])
            mx = np.mean(epoch_losses[split]['stop'])
            # save train and val losses for the epoch to display in visdom
            total_losses['iou'].append(mi)
            total_losses['class'].append(mc)
            total_losses['stop'].append(mx)
            total_losses['total'].append(mt)
            args.epoch_resume = e + epoch_resume
            print "Epoch %d:\ttotal:%.4f\tclass:%.4f\tiou:%.4f\tstop:%.4f\t(%s)" % (e, mt, mc, mi,mx, split)
        # epoch losses
        if args.visdom:
            update = True if e == 0 else 'append'
            for l in ['total', 'iou', 'stop', 'class']:
                viz.line(X=torch.ones((1, 2)).cpu() * (e + 1),
                         Y=torch.Tensor(total_losses[l]).unsqueeze(0).cpu(),
                         win=elot[l],
                         update=update)
        # 'val' is the last split processed, so mt holds the validation mean.
        if mt < (best_val_loss - args.min_delta):
            print "Saving checkpoint."
            best_val_loss = mt
            args.best_val_loss = best_val_loss
            # saves model, params, and optimizers
            save_checkpoint(args, encoder, decoder, enc_opt, dec_opt)
            acc_patience = 0
        else:
            acc_patience += 1
        if acc_patience > args.patience and not args.use_class_loss and not args.class_loss_after == -1:
            # Patience exhausted: reload best checkpoint and enable class loss.
            print("Starting to learn class loss")
            acc_patience = 0
            args.use_class_loss = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name,args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        if acc_patience > args.patience and args.curriculum_learning and args.limit_seqlen_to < args.maxseqlen:
            # Curriculum learning: allow one more decoding step per sequence.
            print("Adding one step more:")
            acc_patience = 0
            args.limit_seqlen_to += args.steps_cl
            print(args.limit_seqlen_to)
            best_val_loss = 1000
            mt_val = -1
        if acc_patience > args.patience and not args.update_encoder and not args.finetune_after == -1:
            # Patience exhausted: reload best checkpoint and unfreeze encoder.
            print("Starting to update encoder")
            acc_patience = 0
            args.update_encoder = True
            best_val_loss = 1000  # reset because adding a loss term will increase the total value
            mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name,args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        if acc_patience > args.patience and not args.use_stop_loss and not args.stop_loss_after == -1:
            # Patience exhausted: enable the stop loss (curriculum permitting)
            # and reload the best checkpoint.
            if args.curriculum_learning:
                print("Starting to learn stop loss")
                if args.limit_seqlen_to > args.min_steps:
                    acc_patience = 0
                    args.use_stop_loss = True
                    best_val_loss = 1000  # reset because adding a loss term will increase the total value
                    mt_val = -1
            else:
                print("Starting to learn stop loss")
                acc_patience = 0
                args.use_stop_loss = True
                best_val_loss = 1000  # reset because adding a loss term will increase the total value
                mt_val = -1
            encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, _ = load_checkpoint(args.model_name,args.use_gpu)
            encoder.load_state_dict(encoder_dict)
            decoder.load_state_dict(decoder_dict)
            enc_opt.load_state_dict(enc_opt_dict)
            dec_opt.load_state_dict(dec_opt_dict)
        # early stopping after N epochs without improvement
        if acc_patience > args.patience_stop:
            break
# --- DAVIS-16 inference preamble (MATNet): restore the model and list the
# validation sequences. The per-sequence loop body continues beyond this chunk.
use_flip = True  # test-time horizontal-flip augmentation switch (consumed later, outside this view)
to_tensor = transforms.ToTensor()
# Standard ImageNet normalization.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
image_transforms = transforms.Compose([to_tensor, normalize])
model_name = 'MATNet'  # specify the model name
epoch = 0  # specify the epoch number
davis_result_dir = './output/davis16'
encoder_dict, decoder_dict, enc_opt_dict, dec_opt_dict, load_args =\
    load_checkpoint_epoch(model_name, epoch, True, False)
encoder = Encoder()
decoder = Decoder()
# Strip DataParallel 'module.' prefixes if present in the checkpoint.
encoder_dict, decoder_dict = check_parallel(encoder_dict, decoder_dict)
encoder.load_state_dict(encoder_dict)
decoder.load_state_dict(decoder_dict)
encoder.cuda()
decoder.cuda()
# Inference mode: disable dropout/batch-norm updates.
encoder.train(False)
decoder.train(False)
val_set = 'data/DAVIS2017/ImageSets/2016/val.txt'
with open(val_set) as f:
    seqs = f.readlines()
seqs = [seq.strip() for seq in seqs]
for video in tqdm(seqs):