コード例 #1
0
def main():
    """Evaluate a saved snapshot on every dataset listed in ``to_test``.

    Loads the checkpoint named by ``args['snapshot']`` into an ``Ensemble``
    network, runs it on every image named in ``imgs_path``, optionally saves
    each predicted map as a PNG, and accumulates precision / recall / MAE
    per dataset.  The aggregated results are printed and appended to
    ``result_all.txt``.

    Relies on names defined outside this function: ``args``, ``to_test``,
    ``ckpt_path``, ``exp_name``, ``device_id``, ``imgs_path``, ``gt_root``,
    ``img_transform``, ``to_pil`` and the metric helpers.
    """
    net = Ensemble(device_id, pretrained=False)

    print ('load snapshot \'%s\' for testing' % args['snapshot'])
    snapshot_file = os.path.join(ckpt_path, exp_name, args['snapshot'] + '.pth')
    net.load_state_dict(torch.load(snapshot_file,
                                   map_location='cuda:' + str(device_id)))
    net.eval()
    net.cuda()
    results = {}

    with torch.no_grad():

        for name, root in to_test.items():

            # One meter per threshold (256 of them) for precision and recall.
            precision_record = [AvgMeter() for _ in range(256)]
            recall_record = [AvgMeter() for _ in range(256)]
            mae_record = AvgMeter()

            # Output directory for this dataset; built once instead of per image.
            result_dir = os.path.join(ckpt_path, exp_name,
                                      '(%s) %s_%s' % (exp_name, name, args['snapshot']))
            if args['save_results']:
                check_mkdir(result_dir)

            # Close the list file deterministically instead of leaking the handle.
            with open(imgs_path) as list_file:
                img_list = [i_id.strip() for i_id in list_file]

            # These two datasets store frames as PNG, everything else as JPG.
            img_ext = '.png' if name in ('VOS', 'DAVSOD') else '.jpg'

            for idx, img_name in enumerate(img_list):
                print('predicting for %s: %d / %d' % (name, idx + 1, len(img_list)))
                print(img_name)

                img = Image.open(os.path.join(root, img_name + img_ext)).convert('RGB')
                shape = img.size  # original size, used to resize the prediction back
                img = img.resize(args['input_size'])

                # Autograd is already disabled by torch.no_grad(); the deprecated
                # Variable(..., volatile=True) wrapper is not needed.
                img_var = img_transform(img).unsqueeze(0).cuda()

                # NOTE(review): CUDA work may still be in flight when `end` is
                # taken, so this timing can under-report - confirm if it matters.
                start = time.time()
                _, outputs_c = net(img_var)
                # Only the first output of the second branch (labelled RAS in the
                # original code) is used as the final prediction.
                prediction = torch.sigmoid(outputs_c[0])
                end = time.time()
                print('running time:', (end - start))

                # Back to a PIL image at the original size, then to a float array.
                pred_img = to_pil(prediction.data.squeeze(0).cpu())
                pred_img = pred_img.resize(shape)
                prediction = np.array(pred_img).astype('float')

                # Rescale to the full 0-255 range before quantising.
                prediction = MaxMinNormalization(prediction, prediction.max(), prediction.min()) * 255.0
                prediction = prediction.astype('uint8')

                gt = np.array(Image.open(os.path.join(gt_root, img_name + '.png')).convert('L'))
                precision, recall, mae = cal_precision_recall_mae(prediction, gt)
                for pidx, (p, r) in enumerate(zip(precision, recall)):
                    precision_record[pidx].update(p)
                    recall_record[pidx].update(r)
                mae_record.update(mae)

                if args['save_results']:
                    # img_name may contain a sub-folder; mirror it under result_dir.
                    folder, sub_name = os.path.split(img_name)
                    save_path = os.path.join(result_dir, folder)
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    Image.fromarray(prediction).save(os.path.join(save_path, sub_name + '.png'))

            fmeasure = cal_fmeasure([precord.avg for precord in precision_record],
                                    [rrecord.avg for rrecord in recall_record])

            results[name] = {'fmeasure': fmeasure, 'mae': mae_record.avg}

    print ('test results:')
    print (results)

    # Append both lines through a single handle that is closed on exit.
    log_path = 'result_all.txt'
    with open(log_path, 'a') as log_file:
        log_file.write(exp_name + ' ' + args['snapshot'] + '\n')
        log_file.write(str(results) + '\n\n')
コード例 #2
0
        validation_guess_accuracy = list()

        for split, dataset in zip(exp_config['splits'], [dataset_train, dataset_val]):

            dataloader = DataLoader(
            dataset=dataset,
            batch_size=optimizer_args['batch_size'],
            shuffle=True,
            num_workers= 1 if optimizer_args['my_cpu'] else multiprocessing.cpu_count()//2,
            pin_memory= use_cuda,
            drop_last=False)

            if split == 'train':
                model.train()
            else:
                model.eval()

            for i_batch, sample in enumerate(dataloader):
                if i_batch > 4 and breaking:
                    print('Breaking after processing 4 batch')
                    break

                sample['tgt_len'], ind = torch.sort(sample['tgt_len'], 0, descending=True)
                batch_size = ind.size(0)

                # Get Batch
                for k, v in sample.items():
                    if k == 'tgt_len':
                        sample[k] = to_var(v)
                    elif torch.is_tensor(v):
                        sample[k] = to_var(v[ind])
コード例 #3
0
ファイル: POLOAgent.py プロジェクト: mohakbhardwaj/aop
class POLOAgent(MPCAgent):
    """
    Agent built on MPCAgent that follows the Plan Online, Learn Offline
    (POLO) framework (Lowrey et. al. 2018) for trajectory optimization,
    using an ensemble of value functions as the terminal value.
    """
    def __init__(self, params):
        """
        Build the value-function ensemble, the replay buffer it learns from,
        and the per-step value history.
        """
        super(POLOAgent, self).__init__(params)
        self.H_backup = self.params['polo']['H_backup']

        # Complete the ensemble configuration in place: each value function
        # takes self.N inputs and produces a single output.
        ens_cfg = params['polo']['ens_params']
        net_cfg = ens_cfg['model_params']
        net_cfg['input_size'] = self.N
        net_cfg['output_size'] = 1
        ens_cfg['dtype'] = self.dtype
        ens_cfg['device'] = self.device

        self.val_ens = Ensemble(self.params['polo']['ens_params'])

        # Buffer the value functions are trained from
        self.polo_buf = ReplayBuffer(
            self.N, self.M, self.params['polo']['buf_size']
        )

        # Per step: value from forward(), mean of ensemble predictions,
        # std of ensemble predictions
        self.hist['vals'] = np.zeros((self.T, 3))

    def get_action(self, prior=None):
        """
        Log the ensemble's value estimates for the previous observation,
        then delegate to the parent's MPC optimization with the ensemble
        passed as the terminal value function.
        """
        self.val_ens.eval()

        # Value of the current state from the ensemble's forward pass
        obs_tensor = torch.tensor(self.prev_obs, dtype=self.dtype).to(
            device=self.device
        )
        forward_val = self.val_ens.forward(obs_tensor)[0]
        forward_val = torch.squeeze(forward_val, -1).detach().cpu().numpy()

        # Individual predictions of every ensemble member
        member_preds = self.val_ens.get_preds_np(self.prev_obs)

        # Record the value statistics for this time step
        stats_row = np.array(
            [forward_val, np.mean(member_preds), np.std(member_preds)]
        )
        self.hist['vals'][self.time] = stats_row

        # The parent class performs the actual MPC optimization
        return super(POLOAgent, self).get_action(terminal=self.val_ens,
                                                 prior=prior)

    def action_taken(self, prev_obs, obs, rew, done, ifo):
        """
        Store the latest transition in the buffer used for value function
        learning. The ifo argument is not used here.
        """
        self.polo_buf.update(prev_obs, obs, rew, done)

    def do_updates(self):
        """
        Run the parent's updates, then train the value ensemble from the
        replay buffer every 'update_freq' time steps.
        """
        super(POLOAgent, self).do_updates()

        polo_cfg = self.params['polo']
        if self.time % polo_cfg['update_freq'] != 0:
            return

        self.val_ens.update_from_buf(
            self.polo_buf,
            polo_cfg['grad_steps'],
            polo_cfg['batch_size'],
            polo_cfg['H_backup'],
            self.gamma,
        )

    def print_logs(self):
        """
        Print the parent's logs followed by POLO-specific value metrics.
        Returns the pair produced by the parent's print_logs.
        """
        bi, ei = super(POLOAgent, self).print_logs()

        self.print('POLO metrics', mode='head')

        # Row logged at the previous time step: [value, mean, std]
        last_row = self.hist['vals'][self.time - 1]
        self.print('current state val', last_row[0])
        self.print('current state std', last_row[2])

        return bi, ei