def __init__(self, model: DeterministicModel, dim_state: int, dim_action: int, normalizers, datasets, *,
             loss, batch_size):
    super().__init__()
    self.model = model
    self.iterators = {
        'train': datasets['train'].iterator(batch_size, n_epochs=-1),
        'dev': datasets['dev'].sample_iterator(1024),
    }
    self._n_updates = 0
    criterion_map = {
        'L1': nn.L1Loss(),
        'L2': nn.L2Loss(),
        'MSE': nn.MSELoss(),
        # 'G': DescLoss(vfn, normalizers, dim_state),
    }
    self.normalizers = normalizers
    self.criterion = criterion_map[loss]
    with self.scope:
        self.op_states = self.model.op_states
        self.op_actions = self.model.op_actions
        self.op_next_states = self.model.op_next_states
        # observed next states, fed in as the regression target for the model's predictions
        self.op_next_states_ = tf.placeholder(tf.float32, shape=[None, dim_state])
        self.build()
    self.train_loss_meter = AverageMeter()
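
# Hedged sketch (illustrative, not from the source): one way the 'dev' iterator and the
# ops set up above could be used to report a held-out loss. The names `evaluate_dev` and
# `self.op_loss` are assumptions (`self.op_loss` stands for whatever scalar loss
# `self.build()` defines), and the batch is assumed to expose state/action/next_state fields.
def evaluate_dev(self):
    batch = next(self.iterators['dev'])
    return tf.get_default_session().run(self.op_loss, feed_dict={
        self.op_states: batch.state,
        self.op_actions: batch.action,
        self.op_next_states_: batch.next_state,
    })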
def compute_vf(self, states, actions, returns):
    # Value-function regression: MSE against empirical returns, minimized with Adam.
    vf_loss = nn.MSELoss()(self.vf(states, actions), returns).reduce_mean()
    optimizer = tf.train.AdamOptimizer(self.vf_lr)
    train_vf = optimizer.minimize(vf_loss)
    return vf_loss, train_vf
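
# Hedged usage sketch (not from the source): compute_vf only adds ops to the graph, so they
# would typically be built once against placeholders and then run per minibatch. The names
# `fit_vf`, `op_returns`, and the `n_iters` default are hypothetical, used for illustration.
def fit_vf(self, states, actions, returns, n_iters=5):
    if not hasattr(self, '_vf_ops'):
        self._vf_ops = self.compute_vf(self.op_states, self.op_actions, self.op_returns)
    op_vf_loss, op_train_vf = self._vf_ops
    sess = tf.get_default_session()
    losses = []
    for _ in range(n_iters):
        loss, _ = sess.run([op_vf_loss, op_train_vf], feed_dict={
            self.op_states: states,
            self.op_actions: actions,
            self.op_returns: returns,
        })
        losses.append(loss)
    return float(np.mean(losses))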
def main():
    FLAGS.set_seed()
    FLAGS.freeze()

    env = make_env(FLAGS.env.id)
    dim_state = int(np.prod(env.observation_space.shape))
    dim_action = int(np.prod(env.action_space.shape))
    env.verify()

    normalizers = Normalizers(dim_action=dim_action, dim_state=dim_state)

    dtype = gen_dtype(env, 'state action next_state reward done timeout')
    train_set = Dataset(dtype, FLAGS.rollout.max_buf_size)
    dev_set = Dataset(dtype, FLAGS.rollout.max_buf_size)

    policy = GaussianMLPPolicy(dim_state, dim_action, normalizer=normalizers.state, **FLAGS.policy.as_dict())
    # batched noises
    noise = OUNoise(env.action_space, theta=FLAGS.OUNoise.theta, sigma=FLAGS.OUNoise.sigma,
                    shape=(1, dim_action))
    vfn = MLPVFunction(dim_state, [64, 64], normalizers.state)
    model = DynamicsModel(dim_state, dim_action, normalizers, FLAGS.model.hidden_sizes)

    virt_env = VirtualEnv(model, make_env(FLAGS.env.id), FLAGS.plan.n_envs, opt_model=FLAGS.slbo.opt_model)
    virt_runner = Runner(virt_env, **{**FLAGS.runner.as_dict(), 'max_steps': FLAGS.plan.max_steps})

    criterion_map = {
        'L1': nn.L1Loss(),
        'L2': nn.L2Loss(),
        'MSE': nn.MSELoss(),
    }
    criterion = criterion_map[FLAGS.model.loss]
    loss_mod = MultiStepLoss(model, normalizers, dim_state, dim_action, criterion, FLAGS.model.multi_step)
    loss_mod.build_backward(FLAGS.model.lr, FLAGS.model.weight_decay)
    algo = TRPO(vfn=vfn, policy=policy, dim_state=dim_state, dim_action=dim_action, **FLAGS.TRPO.as_dict())

    tf.get_default_session().run(tf.global_variables_initializer())

    runners = {
        'test': make_real_runner(4),
        'collect': make_real_runner(1),
        'dev': make_real_runner(1),
        'train': make_real_runner(FLAGS.plan.n_envs) if FLAGS.algorithm == 'MF' else virt_runner,
    }
    settings = [(runners['test'], policy, 'Real Env'), (runners['train'], policy, 'Virt Env')]

    saver = nn.ModuleDict({'policy': policy, 'model': model, 'vfn': vfn})
    print(saver)

    if FLAGS.ckpt.model_load:
        saver.load_state_dict(np.load(FLAGS.ckpt.model_load)[()])
        logger.warning('Load model from %s', FLAGS.ckpt.model_load)

    if FLAGS.ckpt.buf_load:
        n_samples = 0
        for i in range(FLAGS.ckpt.buf_load_index):
            data = pickle.load(open(f'{FLAGS.ckpt.buf_load}/stage-{i}.inc-buf.pkl', 'rb'))
            add_multi_step(data, train_set)
            n_samples += len(data)
        logger.warning('Loading %d samples from %s', n_samples, FLAGS.ckpt.buf_load)

    max_ent_coef = FLAGS.TRPO.ent_coef

    for T in range(FLAGS.slbo.n_stages):
        logger.info('------ Starting Stage %d --------', T)
        evaluate(settings, 'episode')

        # optionally discard data collected in previous stages
        if not FLAGS.use_prev:
            train_set.clear()
            dev_set.clear()

        # collect data
        recent_train_set, ep_infos = runners['collect'].run(noise.make(policy), FLAGS.rollout.n_train_samples)
        add_multi_step(recent_train_set, train_set)
        add_multi_step(
            runners['dev'].run(noise.make(policy), FLAGS.rollout.n_dev_samples)[0],
            dev_set,
        )

        returns = np.array([ep_info['return'] for ep_info in ep_infos])
        if len(returns) > 0:
            logger.info("episode: %s", np.mean(returns))

        if T == 0:  # check: consecutive states in multi-step samples must be consistent
            samples = train_set.sample_multi_step(100, 1, FLAGS.model.multi_step)
            for i in range(FLAGS.model.multi_step - 1):
                masks = 1 - (samples.done[i] | samples.timeout[i])[..., np.newaxis]
                assert np.allclose(samples.state[i + 1] * masks, samples.next_state[i] * masks)

        # recent_states = obsvs
        # ref_actions = policy.eval('actions_mean actions_std', states=recent_states)
        if FLAGS.rollout.normalizer == 'policy' or (FLAGS.rollout.normalizer == 'uniform' and T == 0):
            normalizers.state.update(recent_train_set.state)
            normalizers.action.update(recent_train_set.action)
            normalizers.diff.update(recent_train_set.next_state - recent_train_set.state)

        if T == 50:
            max_ent_coef = 0.

        for i in range(FLAGS.slbo.n_iters):
            if i % FLAGS.slbo.n_evaluate_iters == 0 and i != 0:
                # cur_actions = policy.eval('actions_mean actions_std', states=recent_states)
                # kl_old_new = gaussian_kl(*ref_actions, *cur_actions).sum(axis=1).mean()
                # logger.info('KL(old || cur) = %.6f', kl_old_new)
                evaluate(settings, 'iteration')

            # model fitting
            losses = deque(maxlen=FLAGS.slbo.n_model_iters)
            grad_norm_meter = AverageMeter()
            n_model_iters = FLAGS.slbo.n_model_iters
            for _ in range(n_model_iters):
                samples = train_set.sample_multi_step(FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                _, train_loss, grad_norm = loss_mod.get_loss(
                    samples.state, samples.next_state, samples.action,
                    ~samples.done & ~samples.timeout,
                    fetch='train loss grad_norm')
                losses.append(train_loss.mean())
                grad_norm_meter.update(grad_norm)
                # ideally, we should define an Optimizer class, which takes parameters as inputs.
                # The `update` method of `Optimizer` will invalidate all parameters during updates.
                for param in model.parameters():
                    param.invalidate()

            if i % FLAGS.model.validation_freq == 0:
                # note: this "dev" figure is computed on a fresh sample from the training buffer, not dev_set
                samples = train_set.sample_multi_step(FLAGS.model.train_batch_size, 1, FLAGS.model.multi_step)
                loss = loss_mod.get_loss(samples.state, samples.next_state, samples.action,
                                         ~samples.done & ~samples.timeout)
                loss = loss.mean()
                if np.isnan(loss) or np.isnan(np.mean(losses)):
                    logger.info('nan! %s %s', np.isnan(loss), np.isnan(np.mean(losses)))
                logger.info('# Iter %3d: Loss = [train = %.3f, dev = %.3f], after %d steps, grad_norm = %.6f',
                            i, np.mean(losses), loss, n_model_iters, grad_norm_meter.get())

            # policy optimization on the (virtual or real) training runner
            for n_updates in range(FLAGS.slbo.n_policy_iters):
                if FLAGS.algorithm != 'MF' and FLAGS.slbo.start == 'buffer':
                    runners['train'].set_state(train_set.sample(FLAGS.plan.n_envs).state)
                else:
                    runners['train'].reset()

                data, ep_infos = runners['train'].run(policy, FLAGS.plan.n_trpo_samples)
                advantages, values = runners['train'].compute_advantage(vfn, data)
                dist_mean, dist_std, vf_loss = algo.train(max_ent_coef, data, advantages, values)
                returns = [info['return'] for info in ep_infos]
                # the logged 'std' is the standard error of the mean return
                logger.info('[TRPO] # %d: n_episodes = %d, returns: {mean = %.0f, std = %.0f}, '
                            'dist std = %.10f, dist mean = %.10f, vf_loss = %.3f',
                            n_updates, len(returns), np.mean(returns),
                            np.std(returns) / np.sqrt(len(returns)), dist_std, dist_mean, vf_loss)

        if T % FLAGS.ckpt.n_save_stages == 0:
            np.save(f'{FLAGS.log_dir}/stage-{T}', saver.state_dict())
            np.save(f'{FLAGS.log_dir}/final', saver.state_dict())
        if FLAGS.ckpt.n_save_stages == 1:
            pickle.dump(recent_train_set, open(f'{FLAGS.log_dir}/stage-{T}.inc-buf.pkl', 'wb'))
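
# Hedged entry-point sketch (assumption, not from the source): main() relies on
# tf.get_default_session(), so a default session must be installed before it runs.
# The ConfigProto settings below are illustrative.
if __name__ == '__main__':
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)).as_default():
        main()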