Example No. 1
    def __init__(self, ag_params, policyparams, gpu_id, ngpu):
        """
        :param ag_params:
        :param policyparams:
        """
        self._hp = self._default_hparams()
        self._override_defaults(policyparams)

        self.agentparams = ag_params

        if self._hp.logging_dir:
            self._logger = Logger(
                self._hp.logging_dir,
                'cem{}log.txt'.format(self.agentparams['gpu_id']))
        else:
            self._logger = Logger(printout=True)

        self._logger.log('init inverse model controller')

        #action dimensions:
        self._adim = self.agentparams['adim']
        self._sdim = self.agentparams['sdim']
        predictor_hparams = {}
        self.predictor = self._hp.predictor_class(self._hp.model_params_path,
                                                  predictor_hparams,
                                                  n_gpus=ngpu,
                                                  first_gpu=gpu_id)
        self.predictor.restore()

        self.action_counter = 0
        self.actions = None
        self.context_actions = [None] * self._hp.num_context
        self.context_frames = [None] * self._hp.num_context
Example No. 3
    def __init__(self, config, gpu_id=0, ngpu=1, logger=None):
        self._hyperparams = config
        self.agent = config['agent']['type'](config['agent'])
        self.agentparams = config['agent']
        self.policyparams = config['policy']
        if logger is None:
            self.logger = Logger(printout=True)
        else:
            self.logger = logger
        self.logger.log('started sim')
        self.agentparams['gpu_id'] = gpu_id

        self.policy = config['policy']['type'](self.agent._hyperparams,
                                               config['policy'], gpu_id, ngpu)

        self._record_queue = config.pop('record_saver', None)
        self._counter = config.pop('counter', None)

        self.trajectory_list = []
        self.im_score_list = []
        try:
            os.remove(self._hyperparams['agent']['image_dir'])
        except Exception:  # image_dir may be absent or already removed
            pass
        self.task_mode = 'train'
Example No. 4
class InvModelBaseController(Policy):
    """
    Inverse model policy
    """
    def __init__(self, ag_params, policyparams, gpu_id, ngpu):
        """
        :param ag_params:
        :param policyparams:
        """
        self._hp = self._default_hparams()
        self._override_defaults(policyparams)

        self.agentparams = ag_params

        if self._hp.logging_dir:
            self._logger = Logger(
                self._hp.logging_dir,
                'cem{}log.txt'.format(self.agentparams['gpu_id']))
        else:
            self._logger = Logger(printout=True)

        self._logger.log('init inverse model controller')

        #action dimensions:
        self._adim = self.agentparams['adim']
        self._sdim = self.agentparams['sdim']
        predictor_hparams = {}
        self.predictor = self._hp.predictor_class(self._hp.model_params_path,
                                                  predictor_hparams,
                                                  n_gpus=ngpu,
                                                  first_gpu=gpu_id)
        self.predictor.restore()

        self.action_counter = 0
        self.actions = None
        self.context_actions = [None] * self._hp.num_context
        self.context_frames = [None] * self._hp.num_context

    def _default_hparams(self):
        default_dict = {
            'T': 15,  # planning horizon
            'predictor_class': ActionInferenceInterface,
            'model_params_path': '',
            'model_restore_path': '',
            'logging_dir': '',
            'load_T': 7,
            'num_context': 2,
            'replan_every': 2,
            'context_action_weight': [1, 1, 1, 1],
            'initial_action_low': [-0.025, -0.025, -0.025, 0],
            'initial_action_high': [0.025, 0.025, 0.025, 0],
        }

        parent_params = super(InvModelBaseController, self)._default_hparams()
        for k in default_dict.keys():
            parent_params.add_hparam(k, default_dict[k])
        return parent_params

    def reset(self):
        self.plan_stat = {}  #planning statistics
        self.action_counter = 0
        self.actions = None
        self.context_actions = [None] * self._hp.num_context
        self.context_frames = [None] * self._hp.num_context

    def _sample_initial_action(self):
        return np.random.uniform(self._hp.initial_action_low,
                                 self._hp.initial_action_high)

    def act(self, t=None, i_tr=None, images=None, goal_image=None):

        if t < self._hp.num_context:
            action = self._sample_initial_action() * self._hp.context_action_weight
        else:
            if (t - self._hp.num_context) % self._hp.replan_every == 0:
                # Perform replanning here.
                float_ctx = [
                    frame[None, None] for frame in self.context_frames
                ]
                prepped_ctx_im = np.concatenate(float_ctx, axis=1)
                prepped_ctx_act = np.array(self.context_actions)[None]
                self.actions = self.predictor(
                    convert_to_float(images[-1, 0]), goal_image[-1, 0],
                    prepped_ctx_act,
                    prepped_ctx_im)  # select last-image and 0-th camera
                # action_counter counts the number of steps since the last replan
                self.action_counter = 0
            print('t {} action counter {}'.format(t, self.action_counter))
            assert self.actions.shape[1] > self.action_counter, \
                'Tried to take action {} of plan containing {}. ' \
                'Maybe re-planning is not occurring often enough?'.format(self.action_counter, self.actions.shape[1])

            action = self.actions[0, self.action_counter]
            self.action_counter += 1

        print('action ', action)
        new_context_image = convert_to_float(np.copy(images[-1, 0]))
        self.update_context(new_context_image, action)
        return {'actions': action, 'plan_stat': self.plan_stat}

    def update_context(self, new_image, new_action):
        self.context_frames.append(new_image)
        self.context_actions.append(new_action)
        if len(self.context_frames) > self._hp.num_context:
            # Maintain newest num_context context frames & actions
            self.context_frames.pop(0)
            self.context_actions.pop(0)
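
The context tensors fed to the predictor in act() above follow a fixed layout. The short sketch below is illustrative only; the 48x64 image size and adim=4 are assumed values, not taken from this example.

import numpy as np

num_context, adim = 2, 4
context_frames = [np.zeros((48, 64, 3), dtype=np.float32) for _ in range(num_context)]
context_actions = [np.zeros(adim) for _ in range(num_context)]

# same prep as in act(): add batch and time axes, then stack along the time axis
prepped_ctx_im = np.concatenate([f[None, None] for f in context_frames], axis=1)
prepped_ctx_act = np.array(context_actions)[None]
print(prepped_ctx_im.shape)   # (1, 2, 48, 64, 3) -> (batch, num_context, H, W, C)
print(prepped_ctx_act.shape)  # (1, 2, 4)         -> (batch, num_context, adim)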
Example No. 5
class Sim(object):
    """ Main class to run algorithms and experiments. """
    def __init__(self, config, gpu_id=0, ngpu=1, logger=None):
        self._hyperparams = config
        self.agent = config['agent']['type'](config['agent'])
        self.agentparams = config['agent']
        self.policyparams = config['policy']
        if logger is None:
            self.logger = Logger(printout=True)
        else:
            self.logger = logger
        self.logger.log('started sim')
        self.agentparams['gpu_id'] = gpu_id

        self.policy = config['policy']['type'](self.agent._hyperparams,
                                               config['policy'], gpu_id, ngpu)

        self._record_queue = config.pop('record_saver', None)
        self._counter = config.pop('counter', None)

        self.trajectory_list = []
        self.im_score_list = []
        try:
            os.remove(self._hyperparams['agent']['image_dir'])
        except Exception:  # image_dir may be absent or already removed
            pass
        self.task_mode = 'train'

    def run(self):
        if self._counter is None:
            for i in range(self._hyperparams['start_index'],
                           self._hyperparams['end_index'] + 1):
                self.take_sample(i)
        else:
            itr = self._counter.ret_increment()
            while itr < self._hyperparams['ntraj']:
                print('taking sample {} of {}'.format(
                    itr, self._hyperparams['ntraj']))
                self.take_sample(itr)
                itr = self._counter.ret_increment()

    def take_sample(self, sample_index):
        self.policy.reset()
        agent_data, obs_dict, policy_out = self.agent.sample(
            self.policy, sample_index)
        if self._hyperparams.get('save_data', True):
            self.save_data(sample_index, agent_data, obs_dict, policy_out)
        return agent_data

    def save_data(self, itr, agent_data, obs_dict, policy_outputs):
        if self._hyperparams.get('save_only_good',
                                 False) and not agent_data['goal_reached']:
            return

        if self._hyperparams.get('save_raw_images', False):
            self._save_raw_data(itr, agent_data, obs_dict, policy_outputs)
        elif self._record_queue is not None:
            self._record_queue.put((agent_data, obs_dict, policy_outputs))
        else:
            raise ValueError('Saving neither raw data nor records')

    def _save_raw_data(self, itr, agent_data, obs_dict, policy_outputs):
        data_save_dir = self.agentparams['data_save_dir']

        ngroup = self._hyperparams.get('ngroup', 1000)
        igrp = itr // ngroup
        group_folder = data_save_dir + '/{}/traj_group{}'.format(
            self.task_mode, igrp)
        if not os.path.exists(group_folder):
            os.makedirs(group_folder)

        traj_folder = group_folder + '/traj{}'.format(itr)
        if os.path.exists(traj_folder):
            print('trajectory folder {} already exists, deleting the folder'.
                  format(traj_folder))
            shutil.rmtree(traj_folder)

        os.makedirs(traj_folder)
        print('writing: ', traj_folder)
        if 'images' in obs_dict:
            images = obs_dict.pop('images')
            T, n_cams = images.shape[:2]
            for i in range(n_cams):
                os.mkdir(traj_folder + '/images{}'.format(i))
            for t in range(T):
                for i in range(n_cams):
                    cv2.imwrite(
                        '{}/images{}/im_{}.png'.format(traj_folder, i, t),
                        images[t, i, :, :, ::-1])
        with open('{}/agent_data.pkl'.format(traj_folder), 'wb') as file:
            pkl.dump(agent_data, file)
        with open('{}/obs_dict.pkl'.format(traj_folder), 'wb') as file:
            pkl.dump(obs_dict, file)
        with open('{}/policy_out.pkl'.format(traj_folder), 'wb') as file:
            pkl.dump(policy_outputs, file)
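
_save_raw_data above buckets trajectories into groups of ngroup folders. A minimal, self-contained sketch of the path computation follows; the directory names and numbers are hypothetical.

import os

itr, ngroup = 2437, 1000
data_save_dir, task_mode = '/tmp/data', 'train'
igrp = itr // ngroup  # trajectories are bucketed into groups of `ngroup`
traj_folder = os.path.join(data_save_dir, task_mode, 'traj_group{}'.format(igrp), 'traj{}'.format(itr))
print(traj_folder)  # /tmp/data/train/traj_group2/traj2437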
class CEM_Controller_Base(Policy):
    """
    Cross Entropy Method Stochastic Optimizer
    """
    def __init__(self, ag_params, policyparams):
        """
        :param ag_params:
        :param policyparams:
        :param predictor:
        :param save_subdir:
        :param gdnet: goal-distance network
        """
        self._hp = self._default_hparams()
        self.override_defaults(policyparams)

        self.agentparams = ag_params
        if 'logging_dir' in self.agentparams:
            self.logger = Logger(
                self.agentparams['logging_dir'],
                'cem{}log.txt'.format(self.agentparams['gpu_id']))
        else:
            self.logger = Logger(printout=True)
        self.logger.log('init CEM controller')

        self.t = None

        if self._hp.verbose:
            self.verbose = True
            if isinstance(self._hp.verbose, int):
                self.verbose_freq = self._hp.verbose
            else:
                self.verbose_freq = 1
        else:
            self.verbose = False
            self.verbose_freq = 1

        self.niter = self._hp.iterations

        self.action_list = []
        self.naction_steps = self._hp.nactions
        self.repeat = self._hp.repeat

        if isinstance(self._hp.num_samples, list):
            self.M = self._hp.num_samples[0]
        else:
            self.M = self._hp.num_samples

        if self._hp.selection_frac != -1:
            self.K = int(np.ceil(self.M * self._hp.selection_frac))
        else:
            self.K = 10  # only consider K best samples for refitting

        #action dimensions:
        # deltax, deltay, goup_nstep, delta_rot, close_nstep
        self.adim = self.agentparams['adim']
        self.sdim = self.agentparams['sdim']  # state dimension

        self.indices = []
        self.mean = None
        self.sigma = None
        self.state = None

        self.dict_ = collections.OrderedDict()

        self.plan_stat = {}  #planning statistics

        self.warped_image_goal, self.warped_image_start = None, None

        if self._hp.stochastic_planning:
            self.smp_peract = self._hp.stochastic_planning[0]
        else:
            self.smp_peract = 1

        self.ncam = 1
        self.ndesig = 1
        self.ncontxt = 0
        self.len_pred = self.repeat * self.naction_steps - self.ncontxt
        self.best_cost_perstep = np.zeros(
            [self.ncam, self.ndesig, self.len_pred])
        self._close_override = False

    def _default_hparams(self):
        default_dict = {
            'verbose': False,
            'verbose_every_itr': False,
            'niter': 3,
            'num_samples': [200],
            'selection_frac':
            -1.,  # fraction of the best samples used to compute the mean and variance for the next CEM iteration
            'discrete_ind': None,
            'reuse_mean': False,
            'reuse_cov': False,
            'stochastic_planning': False,
            'rejection_sampling': True,
            'cov_blockdiag': False,
            'smooth_cov': False,
            'iterations': 3,
            'nactions': 5,
            'repeat': 3,
            'action_bound': True,
            'action_order': [
                None
            ],  # [None] implies default order; otherwise specify the order of action dims (e.g. ['x', 'y', ...])
            'initial_std': 0.05,  # std dev. in x/y
            'initial_std_lift': 0.15,  # std dev. in z (lift)
            'initial_std_rot': np.pi / 18,
            'initial_std_grasp': 2,
            'finalweight': 10,
            'use_first_plan': False,
            'custom_sampler': None,
            'replan_interval': -1,
            'type': None,
            'add_zero_action':
            False,  # add one all-zero action sample; this can prevent random walks at the end
            'reduce_std_dev':
            1.,  # reduce the standard deviation at later timesteps when reusing actions
            'visualize_best':
            True,  # visualizer selects K best if True (random K trajectories otherwise)
        }

        parent_params = super(CEM_Controller_Base, self)._default_hparams()
        for k in default_dict.keys():
            parent_params.add_hparam(k, default_dict[k])
        return parent_params

    def reset(self):
        self.plan_stat = {}  #planning statistics
        self.indices = []
        self.action_list = []

    def perform_CEM(self):
        self.logger.log('starting cem at t{}...'.format(self.t))
        timings = OrderedDict()
        t = time.time()

        if not self._hp.reuse_cov or self.t < 2:
            self.sigma = construct_initial_sigma(self._hp, self.adim, self.t)
            self.sigma_prev = self.sigma
        else:
            self.sigma = reuse_cov(self.sigma, self.adim, self._hp)

        if not self._hp.reuse_mean or self.t < 2:
            self.mean = np.zeros(self.adim * self.naction_steps)
        else:
            self.mean = reuse_action(self.bestaction, self._hp)

        if (self._hp.reuse_mean or self._hp.reuse_cov) and self.t >= 2:
            self.M = self._hp.num_samples[1]
            self.K = int(np.ceil(self.M * self._hp.selection_frac))

        self.bestindices_of_iter = np.zeros((self.niter, self.K))
        self.cost_perstep = np.zeros([
            self.M, self.ncam, self.ndesig,
            self.repeat * self.naction_steps - self.ncontxt
        ])

        self.logger.log('M {}, K {}'.format(self.M, self.K))
        self.logger.log('------------------------------------------------')
        self.logger.log('starting CEM cycle')
        timings['pre_itr'] = time.time() - t

        if self._hp.custom_sampler:
            sampler = self._hp.custom_sampler(self.sigma, self.mean, self._hp,
                                              self.repeat, self.adim)

        for itr in range(self.niter):
            itr_times = OrderedDict()
            self.logger.log('------------')
            self.logger.log('iteration: ', itr)
            t_startiter = time.time()
            if self._hp.custom_sampler is None:
                if self._hp.rejection_sampling:
                    actions = self.sample_actions_rej()
                else:
                    actions = self.sample_actions(self.mean, self.sigma,
                                                  self._hp, self.M)

            else:
                actions = sampler.sample(itr, self.M, self.state, self.mean,
                                         self.sigma, self._close_override)

            itr_times['action_sampling'] = time.time() - t_startiter
            t_start = time.time()

            scores = self.get_rollouts(actions, itr, itr_times)
            itr_times['vid_pred_total'] = time.time() - t_start
            t = time.time()
            self.logger.log(
                'overall time for evaluating actions {}'.format(time.time() -
                                                                t_start))

            if self._hp.stochastic_planning:
                actions, scores = self.action_preselection(actions, scores)

            self.indices = scores.argsort()[:self.K]
            self.bestindices_of_iter[itr] = self.indices

            self.bestaction_withrepeat = actions[self.indices[0]]
            self.plan_stat['scores_itr{}'.format(itr)] = scores
            self.plan_stat['bestscore_itr{}'.format(itr)] = scores[
                self.indices[0]]
            if hasattr(self, 'best_cost_perstep'):
                self.plan_stat['best_cost_perstep'] = self.best_cost_perstep

            actions_flat = self.post_process_actions(actions)

            self.fit_gaussians(actions_flat)

            self.logger.log('iter {0}, bestscore {1}'.format(
                itr, scores[self.indices[0]]))
            self.logger.log(
                'overall time for iteration {}'.format(time.time() -
                                                       t_startiter))
            itr_times['post_pred'] = time.time() - t
            timings['itr{}'.format(itr)] = itr_times

        # pkl.dump(timings, open('{}/timings_CEM_{}.pkl'.format(self.agentparams['record'], self.t), 'wb'))

    def sample_actions(self, mean, sigma, hp, M):
        actions = np.random.multivariate_normal(mean, sigma, M)
        actions = actions.reshape(M, self.naction_steps, self.adim)
        if hp.discrete_ind is not None:
            actions = discretize(actions, M, self.naction_steps,
                                 hp.discrete_ind)

        if hp.action_bound:
            actions = truncate_movement(actions, hp)
        actions = np.repeat(actions, hp.repeat, axis=1)

        if hp.add_zero_action:
            actions[0] = 0

        return actions

    def fit_gaussians(self, actions_flat):
        arr_best_actions = actions_flat[
            self.indices]  # only take the K best actions
        self.sigma = np.cov(arr_best_actions, rowvar=False, bias=False)
        if self._hp.cov_blockdiag:
            self.sigma = make_blockdiagonal(self.sigma, self.naction_steps,
                                            self.adim)
        if self._hp.smooth_cov:
            self.sigma = 0.5 * self.sigma + 0.5 * self.sigma_prev
            self.sigma_prev = self.sigma
        self.mean = np.mean(arr_best_actions, axis=0)

    def post_process_actions(self, actions):
        num_ex = self.M // self.smp_peract
        actions = actions.reshape(num_ex, self.naction_steps, self.repeat,
                                  self.adim)
        actions = actions[:, :,
                          -1, :]  # taking only one of the repeated actions
        actions_flat = actions.reshape(num_ex, self.naction_steps * self.adim)
        self.bestaction = actions[self.indices[0]]
        return actions_flat

    def sample_actions_rej(self):
        """
        Perform rejection sampling
        :return:
        """
        runs = []
        actions = []

        if self._hp.stochastic_planning:
            num_distinct_actions = self.M // self.smp_peract
        else:
            num_distinct_actions = self.M

        for _ in range(num_distinct_actions):
            ok = False
            num_trials = 0  # number of rejection-sampling trials for this action sequence
            while not ok:
                num_trials += 1
                action_seq = np.random.multivariate_normal(
                    self.mean, self.sigma, 1)

                action_seq = action_seq.reshape(self.naction_steps, self.adim)
                xy_std = self._hp.initial_std
                lift_std = self._hp.initial_std_lift

                std_fac = 1.5
                if np.any(action_seq[:, :2] > xy_std*std_fac) or \
                        np.any(action_seq[:, :2] < -xy_std*std_fac) or \
                        np.any(action_seq[:, 2] > lift_std*std_fac) or \
                        np.any(action_seq[:, 2] < -lift_std*std_fac):
                    ok = False
                else:
                    ok = True

            runs.append(num_trials)
            actions.append(action_seq)
        actions = np.stack(actions, axis=0)

        if self._hp.stochastic_planning:
            actions = np.repeat(actions, self._hp.stochastic_planning[0], 0)

        self.logger.log('rejection smp max trials', max(runs))
        if self._hp.discrete_ind is not None:
            actions = self.discretize(actions)
        actions = np.repeat(actions, self.repeat, axis=1)

        self.logger.log('max action val xy', np.max(actions[:, :, :2]))
        self.logger.log('max action val z', np.max(actions[:, :, 2]))
        return actions

    def action_preselection(self, actions, scores):
        actions = actions.reshape((self.M // self.smp_peract, self.smp_peract,
                                   self.naction_steps, self.repeat, self.adim))
        scores = scores.reshape((self.M // self.smp_peract, self.smp_peract))
        if self._hp.stochastic_planning[1] == 'optimistic':
            inds = np.argmax(scores, axis=1)
            scores = np.max(scores, axis=1)
        elif self._hp.stochastic_planning[1] == 'pessimistic':
            inds = np.argmin(scores, axis=1)
            scores = np.min(scores, axis=1)
        else:
            raise ValueError

        actions = [
            actions[b, inds[b]] for b in range(self.M // self.smp_peract)
        ]
        return np.stack(actions, 0), scores

    def get_rollouts(self, actions, cem_itr, itr_times):
        raise NotImplementedError

    def act(self, t=None, i_tr=None):
        """
        Return a random action for a state.
        Args:
                if performing highres tracking images is highres image
            t: the current controller's Time step
            goal_pix: in coordinates of small image
            desig_pix: in coordinates of small image
        """
        self.i_tr = i_tr
        self.t = t

        if t == 0:
            action = np.zeros(self.agentparams['adim'])
            self._close_override = False
        else:
            if self._hp.use_first_plan:
                self.logger.log('using actions of first plan, no replanning!!')
                if t == 1:
                    self.perform_CEM()
                action = self.bestaction_withrepeat[t]
            elif self._hp.replan_interval != -1:
                if (t - 1) % self._hp.replan_interval == 0:
                    self.last_replan = t
                    self.perform_CEM()
                self.logger.log('last replan', self.last_replan)
                self.logger.log('taking action of ', t - self.last_replan)
                action = self.bestaction_withrepeat[t - self.last_replan]
            else:
                self.perform_CEM()
                action = self.bestaction[0]
                self.logger.log('########')
                self.logger.log('best action sequence: ')
                for i in range(self.bestaction.shape[0]):
                    self.logger.log("t{}: {}".format(i, self.bestaction[i]))
                self.logger.log('########')

        self.action_list.append(action)

        self.logger.log("applying action  {}".format(action))

        if self.agentparams['adim'] == 5 and action[-1] > 0:
            self._close_override = True
        else:
            self._close_override = False

        return {'actions': action, 'plan_stat': self.plan_stat}
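
perform_CEM above implements the usual sample / score / select-K-best / refit loop. The following self-contained toy uses a one-step quadratic cost to show that core loop in isolation; all numbers are illustrative and unrelated to the controller's real cost.

import numpy as np

adim, M, K, niter = 2, 200, 20, 3
target = np.array([0.3, -0.1])
mean, sigma = np.zeros(adim), np.eye(adim) * 0.05

for itr in range(niter):
    actions = np.random.multivariate_normal(mean, sigma, M)  # sample M candidates
    scores = np.sum((actions - target) ** 2, axis=1)         # lower cost is better
    best = scores.argsort()[:K]                              # keep the K best, as in fit_gaussians
    mean = actions[best].mean(axis=0)                        # refit the Gaussian
    sigma = np.cov(actions[best], rowvar=False) + 1e-6 * np.eye(adim)  # small jitter keeps sigma positive definite
print(mean)  # converges toward `target` over the iterations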
Example No. 8
def setup_predictor(hyperparams, conf, gpu_id=0, ngpu=1, logger=None):
    """
    Setup up the network for control
    :param hyperparams: general hyperparams, can include control flags
    :param conf_file for network
    :param ngpu number of gpus to use
    :return: function which predicts a batch of whole trajectories
    conditioned on the actions
    """
    assert conf['batch_size'] % ngpu == 0, "ngpu should perfectly divide batch_size"

    conf['ngpu'] = ngpu
    if logger is None:
        logger = Logger(printout=True)

    if 'ncam' in conf:
        ncam = conf['ncam']
    else:
        ncam = 1

    logger.log('making graph')
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
    g_predictor = tf.Graph()
    logger.log('making session')
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True), graph=g_predictor)
    logger.log('done making session.')
    with sess.as_default():
        with g_predictor.as_default():
            logger.log('Constructing multi gpu model for control...')

            if 'float16' in conf:
                use_dtype = tf.float16
            else:
                use_dtype = tf.float32

            orig_size = conf['orig_size']
            images_pl = tf.placeholder(use_dtype, name='images',
                                       shape=(1, conf['context_frames'], ncam, orig_size[0], orig_size[1], 3))
            sdim = conf['sdim']
            adim = conf['adim']
            logger.log('adim', adim)
            logger.log('sdim', sdim)

            actions_pl = tf.placeholder(use_dtype, name='actions',
                                        shape=(conf['batch_size'], conf['sequence_length'], adim))
            states_pl = tf.placeholder(use_dtype, name='states',
                                       shape=(1, conf['context_frames'], sdim))

            if 'use_goal_image' in conf or 'no_pix_distrib' in conf:
                pix_distrib = None
            else:
                pix_distrib = tf.placeholder(
                    use_dtype,
                    shape=(1, conf['context_frames'], ncam, orig_size[0], orig_size[1], conf['ndesig']))

            # making the towers
            towers = []
            for i_gpu in range(gpu_id, ngpu + gpu_id):
                with tf.device('/device:GPU:{}'.format(i_gpu)):
                    with tf.name_scope('tower_%d' % (i_gpu)):
                        logger.log(('creating tower %d: in scope %s' % (i_gpu, tf.get_variable_scope())))
                        towers.append(Tower(conf, i_gpu, images_pl, actions_pl, states_pl, pix_distrib))
                        tf.get_variable_scope().reuse_variables()

            sess.run(tf.global_variables_initializer())

            vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            vars = filter_vars(vars)

            if 'load_latest' in hyperparams:
                conf['pretrained_model'] = get_maxiter_weights('/result/modeldata')
                logger.log('loading {}'.format(conf['pretrained_model']))
                if conf['pred_model'] == VPred_Model_Interface:
                    towers[0].model.m.restore(sess, conf['pretrained_model'])
                else:
                    vars = variable_checkpoint_matcher(conf, vars, conf['pretrained_model'])
                    saver = tf.train.Saver(vars, max_to_keep=0)
                    saver.restore(sess, conf['pretrained_model'])
            else:
                if conf['pred_model'] == VPred_Model_Interface:
                    towers[0].model.m.restore(sess, conf['pretrained_model'])
                else:
                    vars = variable_checkpoint_matcher(conf, vars, conf['pretrained_model'])
                    saver = tf.train.Saver(vars, max_to_keep=0)
                    saver.restore(sess, conf['pretrained_model'])

            logger.log('restore done. ')

            logger.log('-------------------------------------------------------------------')
            logger.log('verify current settings!! ')
            for key in list(conf.keys()):
                logger.log(key, ': ', conf[key])
            logger.log('-------------------------------------------------------------------')

            comb_gen_img = tf.concat([to.model.gen_images for to in towers], axis=0)
            if towers[0].model.gen_states is not None:
                comb_gen_states = tf.concat([to.model.gen_states for to in towers], axis=0)
            else:
                comb_gen_states = None

            if 'no_pix_distrib' not in conf:
                comb_pix_distrib = tf.concat([to.model.gen_distrib for to in towers], axis=0)

            def predictor_func(input_images=None, input_one_hot_images=None, input_state=None, input_actions=None):
                """
                :param one_hot_images: the first two frames
                :param pixcoord: the coords of the disgnated pixel in images coord system
                :return: the predicted pixcoord at the end of sequence
                """

                feed_dict = {}
                for t in towers:
                    if hasattr(t.model, 'iter_num'):
                        feed_dict[t.model.iter_num] = 0

                feed_dict[images_pl] = input_images
                feed_dict[states_pl] = input_state
                feed_dict[actions_pl] = input_actions

                if input_one_hot_images is None:
                    if comb_gen_states is None:
                        gen_images = sess.run(comb_gen_img, feed_dict)
                        gen_states = None
                    else:
                        gen_images, gen_states = sess.run([comb_gen_img,
                                                       comb_gen_states],
                                                      feed_dict)
                    gen_distrib = None
                elif comb_gen_states is None:
                    feed_dict[pix_distrib] = input_one_hot_images
                    gen_images, gen_distrib = sess.run([comb_gen_img, comb_pix_distrib], feed_dict)
                    gen_states = None
                else:
                    feed_dict[pix_distrib] = input_one_hot_images
                    gen_images, gen_distrib, gen_states = sess.run([comb_gen_img,
                                                                    comb_pix_distrib,
                                                                    comb_gen_states],
                                                                   feed_dict)

                return gen_images, gen_distrib, gen_states

            return predictor_func
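
The returned predictor_func expects inputs shaped exactly like the placeholders created above. A shape-only sketch follows; the numeric values for context_frames, batch size, image size, adim, sdim and ndesig are assumptions, not values from this code.

import numpy as np

context_frames, ncam, H, W = 2, 1, 48, 64
batch_size, sequence_length, adim, sdim, ndesig = 200, 15, 4, 5, 1

input_images = np.zeros((1, context_frames, ncam, H, W, 3), dtype=np.float32)
input_state = np.zeros((1, context_frames, sdim), dtype=np.float32)
input_actions = np.zeros((batch_size, sequence_length, adim), dtype=np.float32)
input_one_hot_images = np.zeros((1, context_frames, ncam, H, W, ndesig), dtype=np.float32)
# With a restored model one would then call:
# gen_images, gen_distrib, gen_states = predictor_func(input_images, input_one_hot_images,
#                                                      input_state, input_actions)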
class CEMBaseController(Policy):
    """
    Cross Entropy Method Stochastic Optimizer
    """
    def __init__(self, ag_params, policyparams):
        """
        :param ag_params:
        :param policyparams:
        """
        self._hp = self._default_hparams()
        self._override_defaults(policyparams)

        self.agentparams = ag_params

        if self._hp.logging_dir:
            self._logger = Logger(
                self._hp.logging_dir,
                'cem{}log.txt'.format(self.agentparams['gpu_id']))
        else:
            self._logger = Logger(printout=True)

        self._logger.log('init CEM controller')

        self._t_since_replan = None
        self._t = None
        self._n_iter = self._hp.iterations

        #action dimensions:
        self._adim = self.agentparams['adim']
        self._sdim = self.agentparams['sdim']  # state dimension

        self._sampler = None
        self._best_indices, self._best_actions = None, None

        self._state = None
        assert self._hp.minimum_selection > 0, "must take at least 1 sample for refitting"

    def _default_hparams(self):
        default_dict = {
            'append_action': None,
            'verbose': True,
            'verbose_every_iter': False,
            'logging_dir': '',
            'hard_coded_start_action': None,
            'context_action_weight': [0.5, 0.5, 0.05, 1],
            'zeros_for_start_frames': True,
            'replan_interval': 0,
            'sampler': GaussianCEMSampler,
            'T': 15,  # planning horizon
            'iterations': 3,
            'num_samples': 200,
            'selection_frac':
            0.,  # fraction of the best samples used to compute the mean and variance for the next CEM iteration
            'start_planning': 0,
            'minimum_selection': 10
        }

        parent_params = super(CEMBaseController, self)._default_hparams()
        for k in default_dict.keys():
            parent_params.add_hparam(k, default_dict[k])
        return parent_params

    def _override_defaults(self, policyparams):
        sampler_class = policyparams.get('sampler', GaussianCEMSampler)
        for name, value in sampler_class.get_default_hparams().items():
            if name in self._hp:
                print('Warning default value for {} already set!'.format(name))
                self._hp.set_hparam(name, value)
            else:
                self._hp.add_hparam(name, value)

        super(CEMBaseController, self)._override_defaults(policyparams)
        self._hp.sampler = sampler_class

    def reset(self):
        self._best_indices = None
        self._best_actions = None
        self._t_since_replan = None
        self._sampler = self._hp.sampler(self._hp, self._adim, self._sdim)
        self.plan_stat = {}  #planning statistics

    def perform_CEM(self, state):
        self._logger.log('starting cem at t{}...'.format(self._t))
        self._logger.log('------------------------------------------------')

        K = self._hp.minimum_selection
        if self._hp.selection_frac:
            K = max(int(self._hp.selection_frac * self._hp.num_samples),
                    self._hp.minimum_selection)
        actions = self._sampler.sample_initial_actions(self._t,
                                                       self._hp.num_samples,
                                                       state[-1])
        for itr in range(self._n_iter):
            if self._hp.append_action:
                act_append = np.tile(
                    np.array(self._hp.append_action)[None, None],
                    [self._hp.num_samples, actions.shape[1], 1])
                actions = np.concatenate((actions, act_append), axis=-1)

            self._logger.log('------------')
            self._logger.log('iteration: ', itr)

            scores = self.evaluate_rollouts(actions, itr)
            assert scores.shape == (
                actions.shape[0], ), "score shape should be (n_actions,)"

            self._best_indices = scores.argsort()[:K]
            self._best_actions = actions[self._best_indices]

            self.plan_stat['scores_itr{}'.format(itr)] = scores
            if itr < self._n_iter - 1:
                re_sample_act = self._best_actions.copy()
                if self._hp.append_action:
                    re_sample_act = re_sample_act[:, :, :-len(self._hp.append_action)]

                actions = self._sampler.sample_next_actions(
                    self._hp.num_samples, re_sample_act,
                    scores[self._best_indices].copy())

        self._t_since_replan = 0

    def evaluate_rollouts(self, actions, cem_itr):
        raise NotImplementedError

    def _verbose_condition(self, cem_itr):
        if self._hp.verbose:
            if self._hp.verbose_every_iter or cem_itr == self._n_iter - 1:
                return True
        return False

    def act(self, t=None, i_tr=None, state=None):
        """
        Return a random action for a state.
        Args:
            t: the current controller's Time step
        """
        self._state = state
        self.i_tr = i_tr
        self._t = t

        if t < self._hp.start_planning:
            if self._hp.zeros_for_start_frames:
                assert self._hp.hard_coded_start_action is None
                action = np.zeros(self.agentparams['adim'])
            elif self._hp.hard_coded_start_action:
                action = np.array(self._hp.hard_coded_start_action)
            else:
                initial_sampler = self._hp.sampler(self._hp, self._adim,
                                                   self._sdim)
                action = initial_sampler.sample_initial_actions(
                    t, 1, state[-1])[0, 0] * self._hp.context_action_weight
                if self._hp.append_action:
                    action = np.concatenate((action, self._hp.append_action),
                                            axis=0)

        else:
            if self._hp.replan_interval:
                if self._t_since_replan is None or self._t_since_replan + 1 >= self._hp.replan_interval:
                    self.perform_CEM(state)
                else:
                    self._t_since_replan += 1
            else:
                self.perform_CEM(state)
            action = self._best_actions[0, self._t_since_replan]

        assert action.shape == (
            self.agentparams['adim'], ), "action shape does not match adim!"

        self._logger.log('time {}, action - {}'.format(t, action))

        if self._best_actions is not None:
            action_plan_slice = self._best_actions[:, min(self._t_since_replan + 1, self._hp.T - 1):]
            self._sampler.log_best_action(action, action_plan_slice)
        else:
            self._sampler.log_best_action(action, None)

        return {'actions': action, 'plan_stat': self.plan_stat}
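
CEMBaseController leaves evaluate_rollouts abstract; a subclass only has to return one cost per sampled action sequence (shape (num_samples,), lower is better). Below is a minimal hypothetical subclass sketch, assuming the repository's Policy and GaussianCEMSampler imports are available; the cost function is purely illustrative.

import numpy as np

class ToyCEMController(CEMBaseController):
    """Illustrative controller that prefers action sequences with small total displacement."""
    def evaluate_rollouts(self, actions, cem_itr):
        # actions: (num_samples, T, adim); cost = squared norm of the summed action sequence
        return np.sum(np.sum(actions, axis=1) ** 2, axis=1)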