def main():
    """Evaluate the snapshot named by ``args['snapshot']`` on every dataset in ``to_test``.

    For each ``(name, root)`` pair in ``to_test`` the function runs the
    network on every image listed in ``imgs_path``, rescales the sigmoid of
    the first ``outputs_c`` map to ``uint8`` at the original image size,
    compares it with the ground-truth mask under ``gt_root`` and accumulates
    precision/recall/MAE.  When ``args['save_results']`` is set the predicted
    maps are also written as PNG files below ``ckpt_path/exp_name``.

    The per-dataset F-measure and MAE are printed and appended to
    ``result_all.txt`` in the current working directory.

    The function takes no arguments and returns ``None``; it relies on names
    defined elsewhere in the file (``args``, ``to_test``, ``ckpt_path``,
    ``exp_name``, ``device_id``, ``imgs_path``, ``gt_root``, ...).
    """
    net = Ensemble(device_id, pretrained=False)

    print('load snapshot \'%s\' for testing' % args['snapshot'])
    snapshot_file = os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')
    net.load_state_dict(torch.load(snapshot_file, map_location='cuda:' + str(device_id)))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():
        for name, root in to_test.items():
            # One running average per threshold slot (256 slots).
            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            result_dir = os.path.join(ckpt_path, exp_name,
                                      '(%s) %s_%s' % (exp_name, name, args['snapshot']))
            if args['save_results']:
                check_mkdir(result_dir)

            # NOTE(review): the image list (imgs_path) and the ground truth
            # (gt_root) do not depend on `name`, so every entry of to_test is
            # scored against the same list -- confirm this is intended.
            # The context manager closes the list file (it used to be leaked).
            with open(imgs_path) as list_file:
                img_list = [i_id.strip() for i_id in list_file]

            # VOS and DAVSOD frames are stored as PNG, everything else as JPEG.
            img_ext = '.png' if name in ('VOS', 'DAVSOD') else '.jpg'

            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)

                img = Image.open(os.path.join(root, img_name + img_ext)).convert('RGB')
                shape = img.size  # PIL (width, height); used to restore the original size
                img = img.resize(args['input_size'])

                # Gradients are already disabled by the enclosing no_grad()
                # block, so the removed `Variable(..., volatile=True)` wrapper
                # is not needed.
                img_var = img_transform(img).unsqueeze(0).cuda()

                start = time.time()
                _, outputs_c = net(img_var)
                # Only the first map of the second output group is evaluated.
                prediction = torch.sigmoid(outputs_c[0])
                # CUDA kernels run asynchronously; wait for them so that the
                # reported time covers the actual computation.
                torch.cuda.synchronize()
                end = time.time()
                print('running time:', (end - start))

                # Back to a PIL image at the original resolution, then to a
                # float array that is stretched to the full 0..255 range.
                pred_img = to_pil(prediction.squeeze(0).cpu()).resize(shape)
                prediction = np.array(pred_img).astype('float')
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, (p, r) in enumerate(zip(precision, recall)):
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    # img_name may contain a sub-folder; mirror it in the output tree.
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(result_dir, folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print('test results:')
    print(results)

    # Append to the shared log through a single handle that is closed
    # deterministically (previously two handles were opened and never closed).
    log_path = os.path.join('result_all.txt')
    with open(log_path, 'a') as log_file:
        log_file.write(exp_name + ' ' + args['snapshot'] + '\n')
        log_file.write(str(results) + '\n\n')
# Accumulator declared here; nothing in this fragment appends to it.
validation_guess_accuracy = []

# Pair every configured split name with its dataset (train first, then val).
for split, dataset in zip(exp_config['splits'], [dataset_train, dataset_val]):
    loader_batch_size = optimizer_args['batch_size']
    # A single worker when the 'my_cpu' option is set, otherwise half the cores.
    if optimizer_args['my_cpu']:
        loader_workers = 1
    else:
        loader_workers = multiprocessing.cpu_count() // 2

    dataloader = DataLoader(
        dataset=dataset,
        batch_size=loader_batch_size,
        shuffle=True,
        num_workers=loader_workers,
        pin_memory=use_cuda,
        drop_last=False,
    )

    # Only the 'train' split runs the model in training mode.
    if split != 'train':
        model.eval()
    else:
        model.train()

    for i_batch, sample in enumerate(dataloader):
        # Debug shortcut: stop early when `breaking` is set.
        # NOTE(review): batches 0..4 (five of them) get through before this
        # triggers, while the message says 4 -- confirm which is intended.
        if i_batch > 4 and breaking:
            print('Breaking after processing 4 batch')
            break

        # Order the batch by target length, longest first; `ind` holds the
        # permutation that is applied to the other tensors below.
        sorted_len, ind = torch.sort(sample['tgt_len'], 0, descending=True)
        sample['tgt_len'] = sorted_len
        batch_size = ind.size(0)

        # Get Batch: wrap every tensor entry with to_var, reordering all but
        # the already-sorted lengths.  Non-tensor entries are left untouched.
        for key, value in sample.items():
            if key == 'tgt_len':
                sample[key] = to_var(value)
                continue
            if torch.is_tensor(value):
                sample[key] = to_var(value[ind])
class POLOAgent(MPCAgent):
    """
    MPC agent following the Plan Online, Learn Offline (POLO) framework
    (Lowrey et al. 2018): trajectory optimization is run with an ensemble of
    learned value functions supplied as the terminal value.
    """

    def __init__(self, params):
        """
        Build the value-function ensemble, its replay buffer and the
        per-step value log on top of the base MPC agent.
        """
        super(POLOAgent, self).__init__(params)

        self.H_backup = self.params['polo']['H_backup']

        # Fill in the ensemble configuration (in place, on the `params`
        # argument): state-sized input, scalar output, agent dtype/device.
        ens_params = params['polo']['ens_params']
        model_params = ens_params['model_params']
        model_params['input_size'] = self.N
        model_params['output_size'] = 1
        ens_params['dtype'] = self.dtype
        ens_params['device'] = self.device

        # Ensemble of value functions
        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Buffer the value functions are trained from
        buf_size = self.params['polo']['buf_size']
        self.polo_buf = ReplayBuffer(self.N, self.M, buf_size)

        # One row per time step: value (from forward), value mean, value std
        self.hist['vals'] = np.zeros((self.T, 3))

    def get_action(self, prior=None):
        """
        Log the ensemble's value estimates for the previous observation, then
        return the action chosen by the MPC optimization with the ensemble as
        terminal value function.
        """
        self.val_ens.eval()

        # Value of the current state from the ensemble's forward pass
        state = torch.tensor(self.prev_obs, dtype=self.dtype).to(device=self.device)
        forward_val = torch.squeeze(self.val_ens.forward(state)[0], -1)
        forward_val = forward_val.detach().cpu().numpy()

        # Prediction of every member of the ensemble
        member_preds = self.val_ens.get_preds_np(self.prev_obs)

        # Record forward value, mean and spread for this time step
        self.hist['vals'][self.time] = np.array(
            [forward_val, np.mean(member_preds), np.std(member_preds)]
        )

        # Delegate the actual action selection to the MPC base class
        return super(POLOAgent, self).get_action(terminal=self.val_ens,
                                                 prior=prior)

    def action_taken(self, prev_obs, obs, rew, done, ifo):
        """
        Store the observed transition in the value-learning buffer.
        The `ifo` argument is accepted but not used here.
        """
        self.polo_buf.update(prev_obs, obs, rew, done)

    def do_updates(self):
        """
        Run the base-class updates and, every `update_freq` time steps, train
        the value ensemble from the replay buffer.
        """
        super(POLOAgent, self).do_updates()

        # Nothing to do unless this step is a multiple of the update frequency
        if self.time % self.params['polo']['update_freq'] != 0:
            return

        self.val_ens.update_from_buf(
            self.polo_buf,
            self.params['polo']['grad_steps'],
            self.params['polo']['batch_size'],
            self.params['polo']['H_backup'],
            self.gamma,
        )

    def print_logs(self):
        """
        Print the base-class logs followed by the POLO value metrics of the
        most recent time step; returns the pair produced by the base class.
        """
        begin_idx, end_idx = super(POLOAgent, self).print_logs()

        self.print('POLO metrics', mode='head')
        # Row of the last completed step: [forward value, mean, std]
        latest = self.time - 1
        self.print('current state val', self.hist['vals'][latest][0])
        self.print('current state std', self.hist['vals'][latest][2])

        return begin_idx, end_idx