Example #1
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # define simple model
            with tf.variable_scope('last_layer'):
                self.z = tf.layers.dense(inputs=self.x, units=self.num_classes)

            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # Saver for the last-layer weights when running the bootstrap algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars)

            self.gd_opt = tf.train.GradientDescentOptimizer(self.step_size)

            # SGD optimizer for the last layer
            grads_vars_sgd = self.gd_opt.compute_gradients(self.loss)
            self.train_op = self.gd_opt.apply_gradients(grads_vars_sgd)

            for g, v in grads_vars_sgd:
                if g is not None:
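                    # v.name ends in ':0'; swap the ':' for '_' below, since
                    # ':' is not allowed in summary names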
                    s = list(v.name)
                    s[v.name.rindex(':')] = '_'
                    tf.summary.histogram(''.join(s) + '/grad_hist_boot_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
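
A minimal usage sketch for the graph above (hypothetical names: `model` is an instance whose build_graph() has run, and `x_batch` is a NumPy array of shape [batch, dim_input]):

    # forward pass through the restored last layer, fetching class probabilities
    probs = model.sess.run(model.output_probs, feed_dict={model.x: x_batch})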
Example #2
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # linear layer (xW + b)
            with tf.variable_scope('last_layer/dense') as scope:
                weights = tf.get_variable('kernel',
                                          [self.dim_input, self.num_classes],
                                          dtype=tf.float32)
                biases = tf.get_variable('bias', [self.num_classes],
                                         dtype=tf.float32)
                wb = tf.concat([weights, tf.expand_dims(biases, axis=0)], 0)
                wb_renorm = tf.matmul(self.sigma_half_inv, wb)
                weights_renorm = wb_renorm[:self.dim_input, :]
                biases_renorm = wb_renorm[-1, :]
                self.z = tf.add(tf.matmul(self.x, weights_renorm),
                                biases_renorm,
                                name=scope.name)
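                # Note: for the matmul above to be well-formed, self.sigma_half_inv
                # must be a (dim_input + 1) x (dim_input + 1) preconditioning matrix,
                # since wb stacks kernel and bias into (dim_input + 1) x num_classes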

            # Gaussian prior
            # prior = tf.nn.l2_loss(weights) + tf.nn.l2_loss(biases)

            # Non-normalized loss, because of the preconditioning
            self.loss = self.n * tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            # Bayesian loss
            self.bayesian_loss = self.loss  # + prior

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # Saver for the last-layer weights when running the SGLD/SGD/MCMC algorithms
            self.saver = tf.train.Saver(var_list=self.ll_vars,
                                        max_to_keep=self.num_samples)

            self.gd_opt = tf.train.GradientDescentOptimizer(self.step_size)
            # SGLD optimizer for the last layer
            if self.sampler in ['sgld', 'lmc']:
                grads_vars = self.gd_opt.compute_gradients(self.bayesian_loss)
                grads_vars_sgld = []

                for g, v in grads_vars:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        # Adding Gaussian noise to the gradient
                        gaussian_noise = (np.sqrt(2. / self.step_size) *
                                          tf.random_normal(tf.shape(g)))
                        g_sgld = g + gaussian_noise
                        tf.summary.histogram(''.join(s) + '/grad_hist_mcmc', g)
                        tf.summary.histogram(
                            ''.join(s) + '/gaussian_noise_hist_mcmc',
                            gaussian_noise)
                        tf.summary.histogram(
                            ''.join(s) + '/grad_total_hist_mcmc', g_sgld)
                        grads_vars_sgld.append((g_sgld, v))

                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgld)

            # SGD optimizer for the last layer
            if self.sampler == 'sgd':
                grads_vars_sgd = self.gd_opt.compute_gradients(self.loss)
                self.train_op = self.gd_opt.apply_gradients(grads_vars_sgd)

                for g, v in grads_vars_sgd:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        tf.summary.histogram(''.join(s) + '/grad_hist_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
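
For reference, the noise scale in the SGLD branch above matches the unadjusted Langevin update with step size $\epsilon$ and potential $U$ (here the Bayesian loss):

    $\theta_{t+1} = \theta_t - \epsilon \nabla U(\theta_t) + \sqrt{2\epsilon}\,\xi_t, \qquad \xi_t \sim \mathcal{N}(0, I)$

Because GradientDescentOptimizer multiplies the supplied "gradient" by $\epsilon$, injecting noise with standard deviation $\sqrt{2/\epsilon}$ (the np.sqrt(2. / self.step_size) factor) yields exactly the $\sqrt{2\epsilon}\,\xi_t$ term after the step-size scaling.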
Example #3
def train_eval(

        ##############################################
        # types of params:
        # 0: specific to algorithm (gin file 0)
        # 1: specific to environment (gin file 1)
        # 2: specific to experiment (gin file 2 + command line)

        # Note: there are other important params
        # in eg ModelDistributionNetwork that the gin files specify
        # like sparse vs dense rewards, latent dimensions, etc.
        ##############################################

        # basic params for running/logging experiment
        root_dir,  # 2
        experiment_name,  # 2
        num_iterations=int(1e7),  # 2
        seed=1,  # 2
        gpu_allow_growth=False,  # 2
        gpu_memory_limit=None,  # 2
        verbose=True,  # 2
        policy_checkpoint_freq_in_iter=100,  # policies needed for future eval  # 2
        train_checkpoint_freq_in_iter=0,  # default: don't save  # 2
        rb_checkpoint_freq_in_iter=0,  # default: don't save  # 2
        logging_freq_in_iter=10,  # printing to terminal  # 2
        summary_freq_in_iter=10,  # saving to tb  # 2
        num_images_per_summary=2,  # 2
        summaries_flush_secs=10,  # 2
        max_episode_len_override=None,  # 2
        num_trials_to_render=1,  # 2

        # environment, action mode, etc.
        env_name='HalfCheetah-v2',  # 1
        action_repeat=1,  # 1
        action_mode='joint_position',  # joint_position or joint_delta_position  # 1
        double_camera=False,  # camera input  # 1
        universe='gym',  # default
        task_reward_dim=1,  # default

        # dims for all networks
        actor_fc_layers=(256, 256),  # 1
        critic_obs_fc_layers=None,  # 1
        critic_action_fc_layers=None,  # 1
        critic_joint_fc_layers=(256, 256),  # 1
        num_repeat_when_concatenate=None,  # 1

        # networks
        critic_input='state',  # 0
        actor_input='state',  # 0

        # specifying tasks and eval
        episodes_per_trial=1,  # 2
        num_train_tasks=10,  # 2
        num_eval_tasks=10,  # 2
        num_eval_trials=10,  # 2
        eval_interval=10,  # 2
        eval_on_holdout_tasks=True,  # 2

        # data collection/buffer
        init_collect_trials_per_task=None,  # 2
        collect_trials_per_task=None,  # 2
        num_tasks_to_collect_per_iter=5,  # 2
        replay_buffer_capacity=int(1e5),  # 2

        # training
        init_model_train_ratio=0.8,  # 2
        model_train_ratio=1,  # 2
        model_train_freq=1,  # 2
        ac_train_ratio=1,  # 2
        ac_train_freq=1,  # 2
        num_tasks_per_train=5,  # 2
        train_trials_per_task=5,  # 2
        model_bs_in_steps=256,  # 2
        ac_bs_in_steps=128,  # 2

        # default AC learning rates, gamma, etc.
        target_update_tau=0.005,
        target_update_period=1,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        model_learning_rate=1e-4,
        td_errors_loss_fn=functools.partial(
            tf.compat.v1.losses.mean_squared_error, weights=0.5),
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        log_image_strips=False,
        stop_model_training=1E10,
        eval_only=False,  # evaluate checkpoints ONLY
        log_image_observations=False,
        load_offline_data=False,  # whether to use offline data
        offline_data_dir=None,  # replay buffer's dir
        offline_episode_len=None,  # episode len of episodes stored in rb
        offline_ratio=0,  # ratio of data that is from offline buffer
):
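    # The "types of params" tags above refer to which gin file typically binds
    # each argument. Hypothetical example bindings (actual gin files not shown):
    #   train_eval.num_iterations = 1000000       # type 2: experiment / CLI
    #   train_eval.env_name = 'SawyerReach-v0'    # type 1: environment
    #   train_eval.critic_input = 'latentSample'  # type 0: algorithm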

    g = tf.Graph()

    # register all gym envs
    max_steps_dict = {
        "HalfCheetahVel-v0": 50,
        "SawyerReach-v0": 40,
        "SawyerReachMT-v0": 40,
        "SawyerPeg-v0": 40,
        "SawyerPegMT-v0": 40,
        "SawyerPegMT4box-v0": 40,
        "SawyerShelfMT-v0": 40,
        "SawyerKitchenMT-v0": 40,
        "SawyerShelfMT-v2": 40,
        "SawyerButtons-v0": 40,
    }
    if max_episode_len_override:
        max_steps_dict[env_name] = max_episode_len_override
    register_all_gym_envs(max_steps_dict)

    # set max_episode_len based on our env
    max_episode_len = max_steps_dict[env_name]

    ######################################################
    # Calculate additional params
    ######################################################

    # convert to number of steps
    env_steps_per_trial = episodes_per_trial * max_episode_len
    real_env_steps_per_trial = episodes_per_trial * (max_episode_len + 1)
    env_steps_per_iter = num_tasks_to_collect_per_iter * collect_trials_per_task * env_steps_per_trial
    per_task_collect_steps = collect_trials_per_task * env_steps_per_trial

    # initial collect + train
    init_collect_env_steps = num_train_tasks * init_collect_trials_per_task * env_steps_per_trial
    init_model_train_steps = int(init_collect_env_steps *
                                 init_model_train_ratio)

    # collect + train
    collect_env_steps_per_iter = num_tasks_to_collect_per_iter * per_task_collect_steps
    model_train_steps_per_iter = int(env_steps_per_iter * model_train_ratio)
    ac_train_steps_per_iter = int(env_steps_per_iter * ac_train_ratio)

    # other
    global_steps_per_iter = collect_env_steps_per_iter + model_train_steps_per_iter + ac_train_steps_per_iter
    sample_episodes_per_task = train_trials_per_task * episodes_per_trial  # number of episodes to sample from each replay
    model_bs_in_trials = model_bs_in_steps // real_env_steps_per_trial

    # assertions that make sure parameters make sense
    assert model_bs_in_trials > 0, "model batch size needs to be at least one full real trial"
    assert num_tasks_to_collect_per_iter <= num_train_tasks, "cannot sample more tasks than exist when replace=False"
    assert num_tasks_per_train * train_trials_per_task >= model_bs_in_trials, "not enough data for one model training batch"
    assert num_tasks_per_train * train_trials_per_task * env_steps_per_trial >= ac_bs_in_steps, "not enough data for one AC training batch"
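
    # Worked example under assumed settings: with episodes_per_trial=1 and
    # max_episode_len=40, real_env_steps_per_trial = 1 * (40 + 1) = 41, so
    # model_bs_in_steps=256 gives model_bs_in_trials = 256 // 41 = 6 full
    # trials per model batch, satisfying the first assertion above.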

    ######################################################
    # Print a summary of params
    ######################################################
    MELD_summary_string = f"""\n\n\n
==============================================================
==============================================================
  \n
  MELD algorithm summary:

  * each trial consists of {episodes_per_trial} episodes
  * episode length: {max_episode_len}, trial length: {env_steps_per_trial}
  * {num_train_tasks} train tasks, {num_eval_tasks} eval tasks, hold-out: {eval_on_holdout_tasks}
  * environment: {env_name}
  
  For each of {num_train_tasks} tasks:
    Do {init_collect_trials_per_task} trials of initial collect
  (total {init_collect_env_steps} env steps)
  
  Do {init_model_train_steps} steps of initial model training
    
  For i in range(inf):
    For each of {num_tasks_to_collect_per_iter} randomly selected tasks:
      Do {collect_trials_per_task} trials of collect
    (which is {collect_trials_per_task*env_steps_per_trial} env steps per task)
    (for a total of {num_tasks_to_collect_per_iter*collect_trials_per_task*env_steps_per_trial} env steps in the iteration)
    
    if i % model_train_freq (={model_train_freq}) == 0:
      Do {model_train_steps_per_iter} steps of model training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - pick randomly {model_bs_in_trials} trials, train model on whole trials.
    
    if i % ac_train_freq (={ac_train_freq}) == 0:
      Do {ac_train_steps_per_iter} steps of ac training
        - select {sample_episodes_per_task} episodes from each of {num_tasks_per_train} random train_tasks, combine into {num_tasks_per_train*train_trials_per_task} total trials.
        - randomly pick {ac_bs_in_steps} transitions (excluding between-trial transitions)
          to train the ac.
  
  
  * Other important params:
  Evaluate policy every {eval_interval} iters, equivalent to {global_steps_per_iter*eval_interval/1000:.1f}k global steps
  Average evaluation across {num_eval_trials} trials
  Save summary to tensorboard every {summary_freq_in_iter} iters, equivalent to {global_steps_per_iter*summary_freq_in_iter/1000:.1f}k global steps
  Checkpoint:
   - training checkpoint every {train_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*train_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
   - policy checkpoint every {policy_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*policy_checkpoint_freq_in_iter//1000}k global steps, keep all checkpoints
   - replay buffer checkpoint every {rb_checkpoint_freq_in_iter} iters, equivalent to {global_steps_per_iter*rb_checkpoint_freq_in_iter//1000}k global steps, keep 1 checkpoint
    
  \n
=============================================================
=============================================================
  """

    print(MELD_summary_string)
    time.sleep(1)

    ######################################################
    # Seed + name + GPU configs + directories for saving
    ######################################################
    np.random.seed(int(seed))
    experiment_name += "_seed" + str(seed)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpu_allow_growth:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    if gpu_memory_limit:
        for gpu in gpus:
            tf.config.experimental.set_virtual_device_configuration(
                gpu, [
                    tf.config.experimental.VirtualDeviceConfiguration(
                        memory_limit=gpu_memory_limit)
                ])

    train_eval_dir = get_train_eval_dir(root_dir, universe, env_name,
                                        experiment_name)
    train_dir = os.path.join(train_eval_dir, 'train')
    eval_dir = os.path.join(train_eval_dir, 'eval')
    eval_dir_2 = os.path.join(train_eval_dir, 'eval2')

    ######################################################
    # Train and Eval Summary Writers
    ######################################################
    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_summary_flush_op = eval_summary_writer.flush()

    eval_logger = Logger(eval_dir_2)

    ######################################################
    # Train and Eval metrics
    ######################################################
    eval_buffer_size = num_eval_trials * episodes_per_trial * max_episode_len  # across all eval trials in each evaluation
    eval_metrics = []
    # have metrics for each episode position, to track whether it is learning
    for position in range(episodes_per_trial):
        eval_metrics_pos = [
            py_metrics.AverageReturnMetric(name='c_AverageReturnEval_' +
                                           str(position),
                                           buffer_size=eval_buffer_size),
            py_metrics.AverageEpisodeLengthMetric(
                name='f_AverageEpisodeLengthEval_' + str(position),
                buffer_size=eval_buffer_size),
            custom_metrics.AverageScoreMetric(
                name="d_AverageScoreMetricEval_" + str(position),
                buffer_size=eval_buffer_size),
        ]
        eval_metrics.extend(eval_metrics_pos)

    train_buffer_size = num_train_tasks * episodes_per_trial
    train_metrics = [
        tf_metrics.NumberOfEpisodes(name='NumberOfEpisodes'),
        tf_metrics.EnvironmentSteps(name='EnvironmentSteps'),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageReturnMetric(name="a_AverageReturnTrain",
                                           buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            py_metrics.AverageEpisodeLengthMetric(
                name="e_AverageEpisodeLengthTrain",
                buffer_size=train_buffer_size)),
        tf_py_metric.TFPyMetric(
            custom_metrics.AverageScoreMetric(name="b_AverageScoreTrain",
                                              buffer_size=train_buffer_size)),
    ]

    # used to record number of model grad steps + ac grad steps + env steps
    global_step = tf.compat.v1.train.get_or_create_global_step()

    log_cond = get_log_condition_tensor(
        global_step, init_collect_trials_per_task, env_steps_per_trial,
        num_train_tasks, init_model_train_steps, collect_trials_per_task,
        num_tasks_to_collect_per_iter, model_train_steps_per_iter,
        ac_train_steps_per_iter, summary_freq_in_iter, eval_interval)

    with tf.compat.v2.summary.record_if(log_cond):

        ######################################################
        # Create env
        ######################################################
        py_env, eval_py_env, train_tasks, eval_tasks = load_environments(
            universe,
            action_mode,
            env_name=env_name,
            observations_whitelist=['state', 'pixels', "env_info"],
            action_repeat=action_repeat,
            num_train_tasks=num_train_tasks,
            num_eval_tasks=num_eval_tasks,
            eval_on_holdout_tasks=eval_on_holdout_tasks,
            return_multiple_tasks=True,
        )
        override_reward_func = None
        if load_offline_data:
            py_env.set_task_dict(train_tasks)
            override_reward_func = py_env.override_reward_func

        tf_env = tf_py_environment.TFPyEnvironment(py_env, isolation=True)

        # Get data specs from env
        time_step_spec = tf_env.time_step_spec()
        observation_spec = time_step_spec.observation
        action_spec = tf_env.action_spec()
        original_control_timestep = get_control_timestep(eval_py_env)

        # fps
        control_timestep = original_control_timestep * float(action_repeat)
        render_fps = int(np.round(1.0 / original_control_timestep))
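
        # e.g. (assumed numbers) a control timestep of 0.05 s gives render_fps = 20;
        # with action_repeat=2, the effective control_timestep doubles to 0.1 s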

        ######################################################
        # Latent variable model
        ######################################################
        if verbose:
            print("-- start constructing model networks --")

        model_net = ModelDistributionNetwork(
            double_camera=double_camera,
            observation_spec=observation_spec,
            num_repeat_when_concatenate=num_repeat_when_concatenate,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            max_episode_len=max_episode_len
        )  # rest of arguments provided via gin

        if verbose:
            print("-- finish constructing AC networks --")

        ######################################################
        # Compressor Network for Actor/Critic
        # The model's compressor is also used by the AC
        # compressor function: images --> features
        ######################################################

        compressor_net = model_net.compressor

        ######################################################
        # Specs for Actor and Critic
        ######################################################
        if actor_input == 'state':
            actor_state_size = observation_spec['state'].shape[0]
        elif actor_input == 'latentSample':
            actor_state_size = model_net.state_size
        elif actor_input == "latentDistribution":
            actor_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of the Gaussian, for the two latents
        else:
            raise NotImplementedError
        actor_input_spec = tensor_spec.TensorSpec((actor_state_size, ),
                                                  dtype=tf.float32)

        if critic_input == 'state':
            critic_state_size = observation_spec['state'].shape[0]
        elif critic_input == 'latentSample':
            critic_state_size = model_net.state_size
        elif critic_input == "latentDistribution":
            critic_state_size = 2 * model_net.state_size  # mean and (diagonal) variance of the Gaussian, for the two latents
        else:
            raise NotImplementedError
        critic_input_spec = tensor_spec.TensorSpec((critic_state_size, ),
                                                   dtype=tf.float32)

        ######################################################
        # Actor and Critic Networks
        ######################################################
        if verbose:
            print("-- start constructing Actor and Critic networks --")

        actor_net = actor_distribution_network.ActorDistributionNetwork(
            actor_input_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
        )

        critic_net = critic_network.CriticNetwork(
            (critic_input_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers)

        if verbose:
            print("-- finish constructing AC networks --")
            print("-- start constructing agent --")

        ######################################################
        # Create the agent
        ######################################################

        which_posterior_overwrite = None
        which_reward_overwrite = None

        meld_agent = MeldAgent(
            # specs
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            # step counter
            train_step_counter=global_step,  # will count number of model training steps
            # networks
            actor_network=actor_net,
            critic_network=critic_net,
            model_network=model_net,
            compressor_network=compressor_net,
            # optimizers
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            model_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=model_learning_rate),
            # target update
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            # inputs
            critic_input=critic_input,
            actor_input=actor_input,
            # bs stuff
            model_batch_size=model_bs_in_steps,
            ac_batch_size=ac_bs_in_steps,
            # other
            num_tasks_per_train=num_tasks_per_train,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            control_timestep=control_timestep,
            num_images_per_summary=num_images_per_summary,
            task_reward_dim=task_reward_dim,
            episodes_per_trial=episodes_per_trial,
            # offline data
            override_reward_func=override_reward_func,
            offline_ratio=offline_ratio,
        )

        if verbose:
            print("-- finish constructing agent --")

        ######################################################
        # Replay buffers + observers to add data to them
        ######################################################
        replay_buffers = []
        replay_observers = []
        for _ in range(num_train_tasks):
            replay_buffer_episodic = episodic_replay_buffer.EpisodicReplayBuffer(
                meld_agent.collect_policy.trajectory_spec,  # spec of each item stored here (i.e. Trajectory)
                capacity=replay_buffer_capacity,
                completed_only=True,  # in as_dataset, if num_steps is None, return full episodes
                # device='GPU:0',  # gpu not supported for some reason
                begin_episode_fn=lambda traj: traj.is_first()[0],  # first step of an added seq should be is_first
                end_episode_fn=lambda traj: traj.is_last()[0],  # last step of an added seq should be is_last
                dataset_drop_remainder=True,  # `as_dataset` drops the final batch if it has fewer than `sample_batch_size` items
            )
            replay_buffer = StatefulEpisodicReplayBuffer(
                replay_buffer_episodic)  # adding num_episodes here is bad
            replay_buffers.append(replay_buffer)
            replay_observers.append([replay_buffer.add_sequence])

        if load_offline_data:
            # for each task, has a separate replay buffer for relabeled data
            replay_buffers_withRelabel = []
            replay_observers_withRelabel = []
            for _ in range(num_train_tasks):
                replay_buffer_episodic_withRelabel = episodic_replay_buffer.EpisodicReplayBuffer(
                    meld_agent.collect_policy.trajectory_spec,  # spec of each item stored here (i.e. Trajectory)
                    capacity=replay_buffer_capacity,
                    completed_only=True,  # in as_dataset, if num_steps is None, return full episodes
                    # device='GPU:0',  # gpu not supported for some reason
                    begin_episode_fn=lambda traj: traj.is_first()[0],  # first step of an added seq should be is_first
                    end_episode_fn=lambda traj: traj.is_last()[0],  # last step of an added seq should be is_last
                    dataset_drop_remainder=True,  # `as_dataset` drops the final batch if it has fewer than `sample_batch_size` items
                )
                replay_buffer_withRelabel = StatefulEpisodicReplayBuffer(
                    replay_buffer_episodic_withRelabel
                )  # adding num_episodes here is bad
                replay_buffers_withRelabel.append(replay_buffer_withRelabel)
                replay_observers_withRelabel.append(
                    [replay_buffer_withRelabel.add_sequence])

        if verbose:
            print("-- finish constructing replay buffers --")
            print("-- start constructing policies and collect ops --")

        ######################################################
        # Policies
        #####################################################

        # init collect policy (random)
        init_collect_policy = random_tf_policy.RandomTFPolicy(
            time_step_spec, action_spec)

        # eval
        eval_py_policy = py_tf_policy.PyTFPolicy(meld_agent.policy)

        ################################################################################
        # Collect ops : use policies to get data + have the observer put data into corresponding RB
        ################################################################################

        # init collection (with random policy)
        init_collect_ops = []
        for task_idx in range(num_train_tasks):
            # put init data into the rb + track with the train metric
            observers = replay_observers[task_idx] + train_metrics

            # initial collect op
            init_collect_op = DynamicTrialDriver(
                tf_env,
                init_collect_policy,
                num_trials_to_collect=init_collect_trials_per_task,
                observers=observers,
                episodes_per_trial=episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            init_collect_ops.append(init_collect_op)

        # data collection for training (with collect policy)
        collect_ops = []
        for task_idx in range(num_train_tasks):
            collect_op = DynamicTrialDriver(
                tf_env,
                meld_agent.collect_policy,
                num_trials_to_collect=collect_trials_per_task,
                observers=replay_observers[task_idx] + train_metrics,  # put data into 1st RB + track with 1st pol metrics
                episodes_per_trial=episodes_per_trial,  # policy state will not be reset within these episodes
                max_episode_len=max_episode_len,
            ).run()  # collect one trial
            collect_ops.append(collect_op)

        if verbose:
            print("-- finish constructing policies and collect ops --")
            print("-- start constructing replay buffer->training pipeline --")

        ######################################################
        # replay buffer --> dataset --> iterate to get trajecs for training
        ######################################################

        # get some data from all task replay buffers (even though won't actually train on all of them)
        dataset_iterators = []
        all_tasks_trajectories_fromdense = []
        for task_idx in range(num_train_tasks):
            # +1 to include the last state: a trajectory with n transitions has n+1 states
            dataset = replay_buffers[task_idx].as_dataset(
                sample_batch_size=sample_episodes_per_task,  # number of episodes to sample
                num_steps=max_episode_len + 1).prefetch(3)
            # iterator to go through the data
            dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                dataset)
            dataset_iterators.append(dataset_iterator)
            # get sample_episodes_per_task sequences, each of length num_steps
            trajectories_task_i, _ = dataset_iterator.get_next()
            all_tasks_trajectories_fromdense.append(trajectories_task_i)

        if load_offline_data:
            # have separate dataset for relabel data
            dataset_iterators_withRelabel = []
            all_tasks_trajectories_fromdense_withRelabel = []
            for task_idx in range(num_train_tasks):
                # +1 to include the last state: a trajectory with n transitions has n+1 states
                dataset = replay_buffers_withRelabel[task_idx].as_dataset(
                    sample_batch_size=sample_episodes_per_task,  # number of episodes to sample
                    num_steps=offline_episode_len + 1).prefetch(3)
                # iterator to go through the data
                dataset_iterator = tf.compat.v1.data.make_initializable_iterator(
                    dataset)
                dataset_iterators_withRelabel.append(dataset_iterator)
                # get sample_episodes_per_task sequences, each of length num_steps
                trajectories_task_i, _ = dataset_iterator.get_next()
                all_tasks_trajectories_fromdense_withRelabel.append(
                    trajectories_task_i)

        if verbose:
            print("-- finish constructing replay buffer->training pipeline --")
            print("-- start constructing model and AC training ops --")

        ######################################
        # Decoding latent samples into rewards
        ######################################

        latent_samples_1_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent1_size))
        latent_samples_2_ph = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=(None, None, meld_agent._model_network.latent2_size))
        decode_rews_op = meld_agent._model_network.decode_latents_into_reward(
            latent_samples_1_ph, latent_samples_2_ph)

        ######################################
        # Model/Actor/Critic train + summary ops
        ######################################

        # train AC on data from replay buffer
        if load_offline_data:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            ac_train_op = meld_agent.train_ac_meld(
                all_tasks_trajectories_fromdense)

        summary_ops = []
        for train_metric in train_metrics:
            summary_ops.append(
                train_metric.tf_summaries(train_step=global_step,
                                          step_metrics=train_metrics[:2]))

        if verbose:
            print("-- finish constructing AC training ops --")

        ############################
        # Model train + summary ops
        ############################

        # train model on data from replay buffer
        if load_offline_data:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense,
                all_tasks_trajectories_fromdense_withRelabel)
        else:
            model_train_op, check_step_types = meld_agent.train_model_meld(
                all_tasks_trajectories_fromdense)

        model_summary_ops, model_summary_ops_2 = [], []
        for summary_op in tf.compat.v1.summary.all_v2_summary_ops():
            if summary_op not in summary_ops:
                model_summary_ops.append(summary_op)

        if verbose:
            print("-- finish constructing model training ops --")
            print("-- start constructing checkpointers --")

        ########################
        # Eval metrics
        ########################

        with eval_summary_writer.as_default(), \
             tf.compat.v2.summary.record_if(True):
            for eval_metric in eval_metrics:
                eval_metric.tf_summaries(train_step=global_step,
                                         step_metrics=train_metrics[:2])

        ########################
        # Create savers
        ########################
        train_config_saver = gin.tf.GinConfigSaverHook(train_dir,
                                                       summarize_config=False)
        eval_config_saver = gin.tf.GinConfigSaverHook(eval_dir,
                                                      summarize_config=False)

        ########################
        # Create checkpointers
        ########################

        train_checkpointer = common.Checkpointer(
            ckpt_dir=train_dir,
            agent=meld_agent,
            global_step=global_step,
            metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'),
            max_to_keep=1)
        policy_checkpointer = common.Checkpointer(
            ckpt_dir=os.path.join(train_dir, 'policy'),
            policy=meld_agent.policy,
            global_step=global_step,
            max_to_keep=99999999999
        )  # keep many policy checkpoints, in case of future eval
        rb_checkpointers = []
        for buffer_idx in range(len(replay_buffers)):
            rb_checkpointer = common.Checkpointer(
                ckpt_dir=os.path.join(train_dir, 'replay_buffers/',
                                      "task" + str(buffer_idx)),
                max_to_keep=1,
                replay_buffer=replay_buffers[buffer_idx])
            rb_checkpointers.append(rb_checkpointer)

        if load_offline_data:  # for LOADING data not for checkpointing. No new data going in anyways
            rb_checkpointers_withRelabel = []
            for buffer_idx in range(len(replay_buffers_withRelabel)):
                ckpt_dir = os.path.join(offline_data_dir,
                                        "task" + str(buffer_idx))
                rb_checkpointer = common.Checkpointer(
                    ckpt_dir=ckpt_dir,
                    max_to_keep=99999999999,
                    replay_buffer=replay_buffers_withRelabel[buffer_idx])
                rb_checkpointers_withRelabel.append(rb_checkpointer)
            # Notice: these replay buffers need to follow the same sequence of tasks as the current one

        if verbose:
            print("-- finish constructing checkpointers --")
            print("-- start main training loop --")

        with tf.compat.v1.Session() as sess:

            ########################
            # Initialize
            ########################

            if eval_only:
                sess.run(eval_summary_writer.init())
                load_eval_log(
                    train_eval_dir=train_eval_dir,
                    meld_agent=meld_agent,
                    global_step=global_step,
                    sess=sess,
                    eval_metrics=eval_metrics,
                    eval_py_env=eval_py_env,
                    eval_py_policy=eval_py_policy,
                    num_eval_trials=num_eval_trials,
                    max_episode_len=max_episode_len,
                    episodes_per_trial=episodes_per_trial,
                    log_image_strips=log_image_strips,
                    num_trials_to_render=num_trials_to_render,
                    train_tasks=train_tasks,  # in case we want to eval on a train task
                    eval_tasks=eval_tasks,
                    model_net=model_net,
                    render_fps=render_fps,
                    decode_rews_op=decode_rews_op,
                    latent_samples_1_ph=latent_samples_1_ph,
                    latent_samples_2_ph=latent_samples_2_ph,
                )
                return

            # Initialize checkpointing
            train_checkpointer.initialize_or_restore(sess)
            for rb_checkpointer in rb_checkpointers:
                rb_checkpointer.initialize_or_restore(sess)

            if load_offline_data:
                for rb_checkpointer in rb_checkpointers_withRelabel:
                    rb_checkpointer.initialize_or_restore(sess)

            # Initialize dataset iterators
            for dataset_iterator in dataset_iterators:
                sess.run(dataset_iterator.initializer)

            if load_offline_data:
                for dataset_iterator in dataset_iterators_withRelabel:
                    sess.run(dataset_iterator.initializer)

            # Initialize variables
            common.initialize_uninitialized_variables(sess)

            # Initialize summary writers
            sess.run(train_summary_writer.init())
            sess.run(eval_summary_writer.init())

            # Initialize savers
            train_config_saver.after_create_session(sess)
            eval_config_saver.after_create_session(sess)
            # Get value of step counter
            global_step_val = sess.run(global_step)

            if verbose:
                print("====== finished initialization ======")

            ################################################################
            # If this is start of new exp (i.e., 1st step) and not continuing old exp
            # eval rand policy + do initial data collection
            ################################################################
            fresh_start = (global_step_val == 0)

            if fresh_start:

                ########################
                # Evaluate initial policy
                ########################

                if eval_interval:
                    logging.info(
                        '\n\nDoing evaluation of initial policy on %d trials with randomly sampled tasks',
                        num_eval_trials)
                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=num_eval_tasks,
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_val,
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )
                    sess.run(eval_summary_flush_op)
                    logging.info(
                        'Done with evaluation of initial (random) policy.\n\n')

                ########################
                # Initial data collection
                ########################

                logging.info(
                    '\n\nGlobal step %d: Beginning init collect op with random policy. Collecting %dx {%d, %d} trials for each task',
                    global_step_val, init_collect_trials_per_task,
                    max_episode_len, episodes_per_trial)

                init_increment_global_step_op = global_step.assign_add(
                    env_steps_per_trial * init_collect_trials_per_task)

                for task_idx in range(num_train_tasks):
                    logging.info('on task %d / %d', task_idx + 1,
                                 num_train_tasks)
                    py_env.set_task_for_env(train_tasks[task_idx])
                    sess.run([
                        init_collect_ops[task_idx],
                        init_increment_global_step_op
                    ])  # incremented gs in granularity of task

                # save all task replay buffers after the initial collect
                for rb_checkpointer in rb_checkpointers:
                    rb_checkpointer.save(global_step=global_step_val)
                logging.info('Finished init collect.\n\n')

            else:
                logging.info(
                    '\n\nGlobal step %d from loaded experiment: Skipping init collect op.\n\n',
                    global_step_val)

            #########################
            # Create calls
            #########################

            # [1] calls for running the policies to collect training data
            collect_calls = []
            increment_global_step_op = global_step.assign_add(
                env_steps_per_trial * collect_trials_per_task)
            for task_idx in range(num_train_tasks):
                collect_calls.append(
                    sess.make_callable(
                        [collect_ops[task_idx], increment_global_step_op]))

            # [2] call for doing a training step (A + C)
            ac_train_step_call = sess.make_callable([ac_train_op, summary_ops])

            # [3] call for doing a training step (model)
            model_train_step_call = sess.make_callable(
                [model_train_op, check_step_types, model_summary_ops])

            # [4] call for evaluating what global_step number we're on
            global_step_call = sess.make_callable(global_step)

            # reset keeping track of steps/time
            timed_at_step = global_step_call()
            time_acc = 0
            steps_per_second_ph = tf.compat.v1.placeholder(
                tf.float32, shape=(), name='steps_per_sec_ph')
            with train_summary_writer.as_default(), \
                    tf.compat.v2.summary.record_if(True):
                steps_per_second_summary = tf.compat.v2.summary.scalar(
                    name='global_steps_per_sec',
                    data=steps_per_second_ph,
                    step=global_step)

            #################################
            # init model training
            #################################
            if fresh_start:
                logging.info(
                    '\n\nPerforming %d steps of init model training, each step on %d random tasks',
                    init_model_train_steps, num_tasks_per_train)
                for i in range(init_model_train_steps):

                    temp_start = time.time()
                    if i % 100 == 0:
                        print(".... init model training ", i, "/",
                              init_model_train_steps)

                    # init model training
                    total_loss_value_model, check_step_types, _ = model_train_step_call()

                    if PRINT_TIMING:
                        print("single model train step: ",
                              time.time() - temp_start)

            if verbose:
                print("\n\n\n-- start training loop --\n")

            #################################
            # Training Loop
            #################################
            start_time = time.time()
            for iteration in range(num_iterations):

                if iteration > 0:
                    g.finalize()

                # print("\n\n\niter", iteration, sess.run(curr_iter))
                print("global step", global_step_call())

                logging.info("Iteration: %d, Global step: %d\n", iteration,
                             global_step_val)

                ####################
                # collect data
                ####################
                logging.info(
                    '\nStarting batch data collection. Collecting %d {%d, %d} trials for each of %d tasks',
                    collect_trials_per_task, max_episode_len,
                    episodes_per_trial, num_tasks_to_collect_per_iter)

                # randomly select tasks to collect this iteration
                list_of_collect_task_idxs = np.random.choice(
                    len(train_tasks),
                    num_tasks_to_collect_per_iter,
                    replace=False)
                for count, task_idx in enumerate(list_of_collect_task_idxs):
                    logging.info('on randomly selected task %d / %d',
                                 count + 1, num_tasks_to_collect_per_iter)

                    # set task for the env
                    py_env.set_task_for_env(train_tasks[task_idx])

                    # collect data with collect policy
                    _, _ = collect_calls[task_idx]()  # second fetch is the incremented global step

                logging.info('Finish data collection. Global step: %d\n',
                             global_step_call())

                ####################
                # train model
                ####################
                if (iteration
                        == 0) or ((iteration % model_train_freq == 0) and
                                  (global_step_val < stop_model_training)):
                    logging.info(
                        '\n\nPerforming %d steps of model training, each on %d random tasks',
                        model_train_steps_per_iter, num_tasks_per_train)
                    for model_iter in range(model_train_steps_per_iter):
                        temp_start_2 = time.time()

                        # train model
                        total_loss_value_model, _, _ = model_train_step_call()

                        # print("is logging step", model_iter, sess.run(is_logging_step))
                        if PRINT_TIMING:
                            print("2: single model train step: ",
                                  time.time() - temp_start_2)
                    logging.info('Finish model training. Global step: %d\n',
                                 global_step_call())
                else:
                    print("SKIPPING MODEL TRAINING")

                ####################
                # train actor critic
                ####################
                if iteration % ac_train_freq == 0:
                    logging.info(
                        '\n\nPerforming %d steps of AC training, each on %d random tasks \n\n',
                        ac_train_steps_per_iter, num_tasks_per_train)
                    for ac_iter in range(ac_train_steps_per_iter):
                        temp_start_2_ac = time.time()

                        # train ac
                        total_loss_value_ac, _ = ac_train_step_call()
                        if PRINT_TIMING:
                            print("2: single AC train step: ",
                                  time.time() - temp_start_2_ac)
                logging.info('Finish AC training. Global step: %d\n',
                             global_step_call())

                # add up time
                time_acc += time.time() - start_time

                ####################
                # logging/summaries
                ####################

                ### Eval
                if eval_interval and (iteration % eval_interval == 0):
                    logging.info(
                        '\n\nDoing evaluation of trained policy on %d trials with randomly sampled tasks',
                        num_eval_trials)

                    perform_eval_and_summaries_meld(
                        eval_metrics,
                        eval_py_env,
                        eval_py_policy,
                        num_eval_trials,
                        max_episode_len,
                        episodes_per_trial,
                        log_image_strips=log_image_strips,
                        num_trials_to_render=num_trials_to_render,  # hardcoded, or the gif gets too long
                        eval_tasks=eval_tasks,
                        latent1_size=model_net.latent1_size,
                        latent2_size=model_net.latent2_size,
                        logger=eval_logger,
                        global_step_val=global_step_call(),
                        render_fps=render_fps,
                        decode_rews_op=decode_rews_op,
                        latent_samples_1_ph=latent_samples_1_ph,
                        latent_samples_2_ph=latent_samples_2_ph,
                        log_image_observations=log_image_observations,
                    )

                ### steps_per_second_summary
                global_step_val = global_step_call()
                if logging_freq_in_iter and (iteration % logging_freq_in_iter
                                             == 0):
                    # log step number + speed (steps/sec)
                    logging.info(
                        'step = %d, loss = %f', global_step_val,
                        total_loss_value_ac.loss + total_loss_value_model.loss)
                    steps_per_sec = (global_step_val -
                                     timed_at_step) / time_acc
                    logging.info('%.3f env_steps/sec', steps_per_sec)
                    sess.run(steps_per_second_summary,
                             feed_dict={steps_per_second_ph: steps_per_sec})

                    # reset keeping track of steps/time
                    timed_at_step = global_step_val
                    time_acc = 0

                ### train_checkpoint
                if train_checkpoint_freq_in_iter and (
                        iteration % train_checkpoint_freq_in_iter == 0):
                    train_checkpointer.save(global_step=global_step_val)

                ### policy_checkpointer
                if policy_checkpoint_freq_in_iter and (
                        iteration % policy_checkpoint_freq_in_iter == 0):
                    policy_checkpointer.save(global_step=global_step_val)

                ### rb_checkpointer
                if rb_checkpoint_freq_in_iter and (
                        iteration % rb_checkpoint_freq_in_iter == 0):
                    for rb_checkpointer in rb_checkpointers:
                        rb_checkpointer.save(global_step=global_step_val)
Example #4
  def __init__(self,
               num_actions=None,
               observation_size=None,
               num_players=None,
               num_atoms=51,
               vmax=25.,
               gamma=0.99,
               update_horizon=1,
               min_replay_history=500,
               update_period=4,
               target_update_period=500,
               epsilon_train=0.0,
               epsilon_eval=0.0,
               epsilon_decay_period=1000,
               learning_rate=0.000025,
               optimizer_epsilon=0.00003125,
               tf_device='/cpu:*'):
    """Initializes the agent and constructs its graph.

    Args:
      num_actions: int, number of actions the agent can take at any state.
      observation_size: int, size of observation vector.
      num_players: int, number of players playing this game.
      num_atoms: int, the number of buckets for the value function distribution.
      vmax: float, maximum return predicted by a value distribution.
      gamma: float, discount factor as commonly used in the RL literature.
      update_horizon: int, horizon at which updates are performed, the 'n' in
        n-step update.
      min_replay_history: int, number of stored transitions before training.
      update_period: int, period between DQN updates.
      target_update_period: int, update period for the target network.
      epsilon_train: float, final epsilon for training.
      epsilon_eval: float, epsilon during evaluation.
      epsilon_decay_period: int, number of steps for epsilon to decay.
      learning_rate: float, learning rate for the optimizer.
      optimizer_epsilon: float, epsilon for Adam optimizer.
      tf_device: str, Tensorflow device on which to run computations.
    """
    self.graph = tf.Graph()

    with self.graph.as_default():
      # We need this because some tools convert round floats into ints.
      vmax = float(vmax)
      self.num_atoms = num_atoms
      # Using -vmax as the minimum return is wasteful, because all rewards
      # are positive -- but it does not unduly affect performance.
      self.support = tf.linspace(-vmax, vmax, num_atoms)
      self.learning_rate = learning_rate
      self.optimizer_epsilon = optimizer_epsilon

      graph_template = functools.partial(rainbow_template, num_atoms=num_atoms)
      super(RainbowAgent, self).__init__(
          num_actions=num_actions,
          observation_size=observation_size,
          num_players=num_players,
          gamma=gamma,
          update_horizon=update_horizon,
          min_replay_history=min_replay_history,
          update_period=update_period,
          target_update_period=target_update_period,
          epsilon_train=epsilon_train,
          epsilon_eval=epsilon_eval,
          epsilon_decay_period=epsilon_decay_period,
          graph_template=graph_template,
          tf_device=tf_device)
      tf.logging.info('\t learning_rate: %f', learning_rate)
      tf.logging.info('\t optimizer_epsilon: %f', optimizer_epsilon)
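
For context, `self.support` above is just `num_atoms` evenly spaced return values between -vmax and vmax; a categorical value distribution over that grid is reduced to a scalar Q-value by a probability-weighted sum. A NumPy sketch of that reduction (NumPy stands in for the TF ops, and the probability vector is made up for illustration):

import numpy as np

num_atoms, vmax = 51, 25.0
support = np.linspace(-vmax, vmax, num_atoms)  # same grid as tf.linspace above

# A made-up categorical value distribution over the support (sums to 1).
probs = np.ones(num_atoms) / num_atoms

# Expected return under the distribution -- the scalar Q-value used for
# action selection in distributional agents.
q_value = np.sum(probs * support)
print(q_value)  # ~0.0 (up to float error) for the uniform case, by symmetry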
Example #5
    def build_graph(self):
        """Builds the neural network graph."""

        # define graph
        self.g = tf.Graph()
        with self.g.as_default():

            # create and store a new session for the graph
            self.sess = tf.Session()

            # define placeholders
            self.x = tf.placeholder(shape=[None, self.dim_input],
                                    dtype=tf.float32)
            self.y = tf.placeholder(shape=[None, self.num_classes],
                                    dtype=tf.float32)

            # define simple model
            with tf.variable_scope('last_layer'):
                self.z = tf.layers.dense(inputs=self.x, units=self.num_classes)

            self.loss = tf.reduce_mean(
                tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y,
                                                           logits=self.z))

            self.output_probs = tf.nn.softmax(self.z)

            # Variables of the last layer
            self.ll_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
            self.ll_vars_concat = tf.concat(
                [self.ll_vars[0],
                 tf.expand_dims(self.ll_vars[1], axis=0)], 0)

            # Summary
            _variable_summaries(self.ll_vars_concat)

            # add regularization that acts as a unit Gaussian prior on the last layer
            regularizer = tf.contrib.layers.l2_regularizer(1.0)

            # regularization
            prior = tf.contrib.layers.apply_regularization(
                regularizer, self.ll_vars)
            self.bayesian_loss = self.n * self.loss + prior

            # saving the weights of last layer when running SGLD/SGD/MCMC algorithm
            self.saver = tf.train.Saver(var_list=self.ll_vars,
                                        max_to_keep=self.num_samples)

            # SGLD optimizer for the last layer
            if self.sampler in ['sgld', 'lmc']:
                step = self.step_size / self.n
                gd_opt = tf.train.GradientDescentOptimizer(step)
                grads_vars = gd_opt.compute_gradients(self.bayesian_loss)
                grads_vars_sgld = []

                for g, v in grads_vars:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        # Adding Gaussian noise to the gradient
                        gaussian_noise = (np.sqrt(2. / step) *
                                          tf.random_normal(tf.shape(g)))
                        g_sgld = g + gaussian_noise
                        tf.summary.histogram(''.join(s) + '/grad_hist_mcmc',
                                             g / self.n)
                        tf.summary.histogram(
                            ''.join(s) + '/gaussian_noise_hist_mcmc',
                            gaussian_noise / self.n)
                        tf.summary.histogram(
                            ''.join(s) + '/grad_total_hist_mcmc',
                            g_sgld / self.n)
                        grads_vars_sgld.append((g_sgld, v))

                self.train_op = gd_opt.apply_gradients(grads_vars_sgld)

            # SGD optimizer for the last layer
            if self.sampler == 'sgd':
                gd_opt = tf.train.GradientDescentOptimizer(self.step_size)
                grads_vars_sgd = gd_opt.compute_gradients(self.loss)
                self.train_op = gd_opt.apply_gradients(grads_vars_sgd)

                for g, v in grads_vars_sgd:
                    if g is not None:
                        s = list(v.name)
                        s[v.name.rindex(':')] = '_'
                        tf.summary.histogram(''.join(s) + '/grad_hist_sgd', g)

            # Merge all the summaries and write them out
            self.all_summaries = tf.summary.merge_all()
            location = os.path.join(self.working_dir, 'logs')
            self.writer = tf.summary.FileWriter(location, graph=self.g)

            saver_network = tf.train.Saver(var_list=self.ll_vars)
            print('Loading the network...')
            # Restores from checkpoint
            saver_network.restore(self.sess, self.model_dir)
            print('Graph successfully loaded.')
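
The `sgld`/`lmc` branch above is stochastic gradient Langevin dynamics: each gradient of the Bayesian loss is perturbed with Gaussian noise of standard deviation sqrt(2/step) before the gradient-descent update, so the iterates (approximately) sample the posterior rather than converge to its mode. A NumPy sketch of a single such step under those assumptions (the toy loss and step size are placeholders):

import numpy as np

def sgld_step(params, grad, step_size, rng=np.random):
    """One SGLD update: params <- params - step * (grad + noise), where
    noise ~ N(0, 2 / step), matching g_sgld = g + sqrt(2/step) * xi above."""
    noise = np.sqrt(2.0 / step_size) * rng.standard_normal(params.shape)
    return params - step_size * (grad + noise)

# Hypothetical usage on a unit Gaussian negative log-density 0.5 * ||w||^2,
# whose gradient is w; the iterates wander around the posterior mode.
w = np.zeros(3)
for _ in range(1000):
    w = sgld_step(w, grad=w, step_size=1e-2)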
Example #6
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_path="",
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        num_eval_examples=None,
        get_components_fn=None,
        compute_metrics_from_file_fn=None,
        learning_rate_schedule=None,
        optimizer=None):
    """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or
      (inputs_vocabulary, targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=train
    eval_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode=continuous_eval
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_path: a string - which checkpoint to load for inference
    mode: string, train/continuous_eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is
      the global batch size and not the per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    num_eval_examples: maximum number of examples per task to use for continuous
      eval.
    get_components_fn: an optional function that returns a list of tuples of
      (metric_names, component) for each component.
      Required if mode is "continuous_eval."
    compute_metrics_from_file_fn: an optional function that takes in: component,
      a predictions filename (str), a targets filename (str), dataset_split
      (str), tb_summary_dir (str), and the checkpoint path, runs metrics on
      the written-out targets and predictions, and returns a dictionary of
      metrics and their computed values. Required if mode is "continuous_eval."
    learning_rate_schedule: an optional function taking the scalar name
      argument `step` and the numeric argument `total_train_steps` and return
      the scalar learning rate
    optimizer: a class extending optimize.Optimizer, required for training
  """
    if not isinstance(batch_size, int):
        batch_size = compute_batch_size(sequence_length, mesh_shape,
                                        layout_rules, batch_size)

    if not isinstance(train_steps, int):
        train_steps = train_steps(batch_size, sequence_length)

    if callable(learning_rate_schedule):
        learning_rate_schedule = functools.partial(
            learning_rate_schedule, total_train_steps=train_steps)

    tf.logging.info("model_type=%s" % model_type, )
    tf.logging.info("mode=%s" % mode, )
    tf.logging.info("sequence_length=%s" % sequence_length, )
    tf.logging.info("batch_size=%s" % batch_size, )
    tf.logging.info("train_steps=%s" % train_steps, )
    tf.logging.info("mesh_shape=%s" % mesh_shape, )
    tf.logging.info("layout_rules=%s" % layout_rules, )

    if mode == "train" and dataset_split != "train":
        raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

    mesh_shape = mtf.convert_to_shape(mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(layout_rules)

    cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
        tpu if tpu else "", zone=tpu_zone, project=gcp_project)

    tf.logging.info(
        "Building TPUConfig with tpu_job_name={}".format(tpu_job_name))
    my_tpu_config = tpu_config.TPUConfig(
        tpu_job_name=tpu_job_name,
        iterations_per_loop=iterations_per_loop,
        num_cores_per_replica=1,
        per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
    )

    run_config = tpu_config.RunConfig(
        cluster=cluster,
        model_dir=model_dir,
        tpu_config=my_tpu_config,
        # We use a saver hook, so disable checkpoints here to prevent double
        # saving.
        save_checkpoints_steps=None,
        save_checkpoints_secs=None)

    transformer_model = build_model(
        model_type=model_type,
        input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
        output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
        layout_rules=layout_rules,
        mesh_shape=mesh_shape)

    model_fn = tpu_estimator_model_fn(
        model_type=model_type,
        transformer_model=transformer_model,
        model_dir=model_dir,
        use_tpu=tpu,
        mesh_shape=mesh_shape,
        layout_rules=layout_rules,
        batch_size=batch_size,
        sequence_length=sequence_length,
        autostack=autostack,
        learning_rate_schedule=learning_rate_schedule,
        keep_checkpoint_max=keep_checkpoint_max,
        save_checkpoints_steps=save_checkpoints_steps,
        optimizer=optimizer)

    estimator = tpu_estimator.TPUEstimator(model_fn=model_fn,
                                           config=run_config,
                                           train_batch_size=batch_size,
                                           eval_batch_size=batch_size,
                                           predict_batch_size=batch_size,
                                           use_tpu=tpu,
                                           export_to_tpu=False,
                                           params={})

    if mode == "train":
        if train_dataset_fn is None:
            raise ValueError(
                "Must provide train_dataset_fn through gin for train.")

        def input_fn(params):
            del params
            dataset = train_dataset_fn(batch_size=batch_size,
                                       sequence_length=sequence_length,
                                       vocabulary=vocabulary,
                                       dataset_split=dataset_split)
            return dataset

        estimator.train(input_fn=input_fn, max_steps=train_steps)
    elif mode == "continuous_eval":
        if eval_dataset_fn is None:
            raise ValueError(
                "Must provide eval_dataset_fn through gin for eval.")
        if get_components_fn is None:
            raise ValueError(
                "Must provide get_components_fn through gin for eval.")
        if compute_metrics_from_file_fn is None:
            raise ValueError(
                "Must provide compute_metrics_from_file_fn through gin for eval."
            )

        metrics_inputs = get_components_fn()
        for ckpt in tf.contrib.training.checkpoints_iterator(
                estimator.model_dir):
            for metric_names, component in metrics_inputs:
                if not metric_names:
                    tf.logging.info("Skipping %s", component.__dict__)
                    continue
                tf.logging.info("Evaluating %s on metrics %s",
                                component.tfds_name, component.metric_names)
                tf.logging.info("on split %s", dataset_split)

                # Regenerate the estimator
                model_fn = tpu_estimator_model_fn(
                    model_type=model_type,
                    transformer_model=transformer_model,
                    model_dir=model_dir,
                    use_tpu=tpu,
                    mesh_shape=mesh_shape,
                    layout_rules=layout_rules,
                    batch_size=batch_size,
                    sequence_length=sequence_length,
                    autostack=autostack,
                    keep_checkpoint_max=keep_checkpoint_max,
                    save_checkpoints_steps=save_checkpoints_steps)
                estimator = tpu_estimator.TPUEstimator(
                    model_fn=model_fn,
                    config=run_config,
                    train_batch_size=batch_size,
                    eval_batch_size=batch_size,
                    predict_batch_size=batch_size,
                    use_tpu=tpu,
                    export_to_tpu=False,
                    params={})

                # Extra eval_dataset_fn call to get the dataset_size and an extra
                # dataset object to write out targets. We need to use a separate graph
                # because estimator finalizes the default graph after iterating over the
                # dataset.
                dataset_graph = tf.Graph()
                with dataset_graph.as_default():
                    dataset, dataset_size, padded_dataset_size = eval_dataset_fn(
                        component,  # pylint: disable=cell-var-from-loop
                        batch_size=batch_size,
                        sequence_length=sequence_length,
                        vocabulary=vocabulary,
                        dataset_split=dataset_split,
                        pack=False,
                        max_dataset_size=num_eval_examples)

                def input_fn(params):
                    del params
                    dataset, _, _ = eval_dataset_fn(
                        component,  # pylint: disable=cell-var-from-loop
                        batch_size=batch_size,
                        sequence_length=sequence_length,
                        vocabulary=vocabulary,
                        dataset_split=dataset_split,
                        pack=False,
                        max_dataset_size=num_eval_examples)
                    return dataset

                dataset_name = component.tfds_name.replace("/", "-").replace(
                    ":", "-")
                output_filename = os.path.join(
                    model_dir, "{}-{}-decoded".format(dataset_name,
                                                      dataset_split))
                pred_output_filename = output_filename + "-preds-test"
                target_output_filename = output_filename + "-targets-test"
                decodes = decode(estimator,
                                 input_fn,
                                 dataset_size,
                                 padded_dataset_size,
                                 batch_size,
                                 vocabulary,
                                 checkpoint_path=checkpoint_path)
                with dataset_graph.as_default():
                    log_pred_target(
                        decodes,
                        dataset,
                        dataset_size,
                        vocabulary,
                        pred_output_filename=pred_output_filename,
                        target_output_filename=target_output_filename)
                tf.logging.info("Evaluating metrics: {}".format(metric_names))
                tb_summary_dir = os.path.join(
                    model_dir,
                    "{}_eval".format("eval" if dataset_split ==
                                     "validation" else dataset_split))
                summary_writer = tf.summary.FileWriter(tb_summary_dir)
                _ = compute_metrics_from_file_fn(component,
                                                 pred_output_filename,
                                                 target_output_filename,
                                                 dataset_split,
                                                 tb_summary_dir,
                                                 ckpt,
                                                 summary_writer=summary_writer)

    elif mode == "infer":
        decode_from_file(estimator,
                         vocabulary=vocabulary,
                         model_type=model_type,
                         batch_size=batch_size,
                         sequence_length=sequence_length,
                         checkpoint_path=checkpoint_path)
    else:
        raise ValueError(
            "unknown mode %s - must be train/continuous_eval/infer" % mode)
Example #7
File: utils.py Project: masak1112/mesh
def run(tpu_job_name,
        tpu,
        gcp_project,
        tpu_zone,
        model_dir,
        model_type="bitransformer",
        vocabulary=gin.REQUIRED,
        train_dataset_fn=None,
        eval_dataset_fn=None,
        dataset_split="train",
        autostack=True,
        checkpoint_step=None,
        mode="train",
        iterations_per_loop=100,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=10,
        eval_summary_dir=None,
        batch_size=("tokens_per_replica", 2048),
        train_steps=auto_train_steps,
        sequence_length=gin.REQUIRED,
        mesh_shape=gin.REQUIRED,
        layout_rules=gin.REQUIRED,
        learning_rate_schedule=None,
        optimizer=None,
        predict_fn=None):
  """Run training/eval/inference.

  Args:
    tpu_job_name: string, name of TPU worker binary
    tpu: string, the Cloud TPU to use for training
    gcp_project: string, project name for the Cloud TPU-enabled project
    tpu_zone: string, GCE zone where the Cloud TPU is located in
    model_dir: string, estimator model_dir
    model_type: a string - either "bitransformer", "bi_student_teacher", "lm"
      or "aligned"
    vocabulary: a vocabulary.Vocabulary or (inputs_vocabulary,
      targets_vocabulary) tuple.
    train_dataset_fn: A function returning a tf.data.Dataset. Must be provided
      for mode="train". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
    eval_dataset_fn: A function returning a list of dataset.EvalDataset tuples.
      Must be provided for mode="eval". Should accept the following arguments:
        - batch_size: int, number of entries in each batch.
        - sequence_length: int, length of each packed or padded sequence.
        - vocabulary: Vocabulary instance to use for encoding.
        - dataset_split: str, which dataset split to load.
      dataset.EvalDataset tuples are namedtuples with the following fields:
        - name: string, the task name
        - dataset_fn: function which returns a tf.data.Dataset of tokenized and
          padded examples. Must not require any arguments and must include the
          feature keys 'inputs' and 'targets_plaintext'.
        - postprocess_fn: function which converts model outputs to evalable str
        - list_of_metric_fns: list of metric functions with the call signature
          `metric_fn(targets, predictions)` which return either a scalar value
          or a dict mapping submetric names to scalar values. TensorBoard
          summaries and other tags will be written out using
          `metric_fn.__name__`.
        - dataset_size: number of entries in the dataset.
        - padded_dataset_size: number of entries in the dataset after padding.
    dataset_split: a string
    autostack: boolean, internally combine variables
    checkpoint_step: int, list of ints, or None. Only used when mode="eval" or
      mode="infer". If an int or list of ints, evaluation or inference will be
      run on the checkpoint files in `model_dir` whose global steps are closest
      to the global steps provided. If None and mode="eval", run eval
      continuously waiting for new checkpoints via
      `tf.contrib.training.checkpoints_iterator`.
    mode: string, train/eval/infer
    iterations_per_loop: integer, steps per train loop
    save_checkpoints_steps: integer, steps per checkpoint
    keep_checkpoint_max: an integer, keep up to this many checkpoints
    eval_summary_dir: str, path to write TensorBoard events file summaries for
      eval. If None, use model_dir/eval_{split}.
    batch_size: An integer or a (method, value) pair to pass to
      compute_batch_size(). Note that this is the global batch size and not the
      per-shard batch size.
    train_steps: An integer or a function with the same signature as
      auto_train_steps().  Total number of training steps.
    sequence_length: an integer
    mesh_shape: an input to mtf.convert_to_shape()
    layout_rules: an input to mtf.convert_to_layout_rules()
    learning_rate_schedule: an optional function taking the scalar name argument
      `step` and the numeric argument `total_train_steps` and return the scalar
      learning rate
    optimizer: a class extending optimize.Optimizer, required for training
    predict_fn: an optional function that can be used to override the default
      transformer prediction behavior. Must return a tensor of shape [batch_dim,
      length_dim] that will be the prediction for each example. Must accept the
      following arguments:
        - model: a Unitransformer or Bitransformer
        - features: a dict representing an example. Every value will be an
          mtf.Tensor with shape [batch_dim, length_dim].
        - variable_dtype: an mtf.VariableDType
  """
  if not isinstance(batch_size, int):
    batch_size = compute_batch_size(
        sequence_length, mesh_shape, layout_rules, batch_size)

  if not isinstance(train_steps, int):
    train_steps = train_steps(batch_size, sequence_length)

  if callable(learning_rate_schedule):
    learning_rate_schedule = functools.partial(
        learning_rate_schedule, total_train_steps=train_steps)

  tf.logging.info("model_type=%s" % model_type,)
  tf.logging.info("mode=%s" % mode,)
  tf.logging.info("sequence_length=%s" % sequence_length,)
  tf.logging.info("batch_size=%s" % batch_size,)
  tf.logging.info("train_steps=%s" % train_steps,)
  tf.logging.info("mesh_shape=%s" % mesh_shape,)
  tf.logging.info("layout_rules=%s" % layout_rules,)

  if mode == "train" and dataset_split != "train":
    raise ValueError("mode==\"train\" requires dataset_split==\"train\"")

  mesh_shape = mtf.convert_to_shape(mesh_shape)
  layout_rules = mtf.convert_to_layout_rules(layout_rules)

  cluster = tf.contrib.cluster_resolver.TPUClusterResolver(
      tpu if tpu else "", zone=tpu_zone, project=gcp_project)

  tf.logging.info(
      "Building TPUConfig with tpu_job_name={}".format(tpu_job_name)
  )
  my_tpu_config = tpu_config.TPUConfig(
      tpu_job_name=tpu_job_name,
      iterations_per_loop=iterations_per_loop,
      num_cores_per_replica=1,
      per_host_input_for_training=tpu_config.InputPipelineConfig.BROADCAST,
  )

  run_config = tpu_config.RunConfig(
      cluster=cluster,
      model_dir=model_dir,
      tpu_config=my_tpu_config,
      # We use a saver hook, so disable checkpoints here to prevent double
      # saving.
      save_checkpoints_steps=None,
      save_checkpoints_secs=None)

  transformer_model = build_model(
      model_type=model_type,
      input_vocab_size=inputs_vocabulary(vocabulary).vocab_size,
      output_vocab_size=targets_vocabulary(vocabulary).vocab_size,
      layout_rules=layout_rules,
      mesh_shape=mesh_shape)

  model_fn = tpu_estimator_model_fn(
      model_type=model_type,
      transformer_model=transformer_model,
      model_dir=model_dir,
      use_tpu=tpu,
      mesh_shape=mesh_shape,
      layout_rules=layout_rules,
      batch_size=batch_size,
      sequence_length=sequence_length,
      autostack=autostack,
      learning_rate_schedule=learning_rate_schedule,
      keep_checkpoint_max=keep_checkpoint_max,
      save_checkpoints_steps=save_checkpoints_steps,
      optimizer=optimizer,
      predict_fn=predict_fn)

  estimator = tpu_estimator.TPUEstimator(
      model_fn=model_fn,
      config=run_config,
      train_batch_size=batch_size,
      eval_batch_size=batch_size,
      predict_batch_size=batch_size,
      use_tpu=tpu,
      export_to_tpu=False,
      params={})

  if mode == "train":
    if train_dataset_fn is None:
      raise ValueError("Must provide train_dataset_fn through gin for train.")
    def input_fn(params):
      del params
      dataset = train_dataset_fn(batch_size=batch_size,
                                 sequence_length=sequence_length,
                                 vocabulary=vocabulary,
                                 dataset_split=dataset_split)
      return dataset

    estimator.train(input_fn=input_fn, max_steps=train_steps)

  elif mode == "eval":
    if eval_dataset_fn is None:
      raise ValueError("Must provide eval_dataset_fn through gin for eval.")

    eval_datasets = eval_dataset_fn(
        batch_size=batch_size,
        sequence_length=sequence_length,
        vocabulary=vocabulary,
        dataset_split=dataset_split,
    )

    # Pre-load all of the targets once before entering the continuous eval loop
    cached_targets = {}
    # Need to create a separate graph for loading in plaintext targets
    # or else TF will complain that we modified the graph
    with tf.Graph().as_default():
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        # Only cache targets for those tasks with eval functions provided
        if eval_dataset.metric_fns:
          ds = eval_dataset.dataset_fn()
          # De-batch the dataset
          ds = ds.flat_map(tf.data.Dataset.from_tensor_slices)
          ds = tfds.as_numpy(ds)
          targets = [
              eval_dataset.postprocess_fn(d["targets_plaintext"]) for d in ds
          ]
          targets = targets[:eval_dataset.dataset_size]
          cached_targets[eval_dataset.name] = targets

    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      for eval_dataset in eval_datasets:
        eval_dataset = transformer_dataset.EvalDataset(*eval_dataset)
        if not eval_dataset.metric_fns:
          tf.logging.info(
              "Skipping %s because metric_fns is empty", eval_dataset.name
          )
          continue
        metric_names = [metric.__name__ for metric in eval_dataset.metric_fns]
        tf.logging.info(
            "Evaluating %s on metrics %s", eval_dataset.name, metric_names
        )
        tf.logging.info("on split %s", dataset_split)

        def input_fn(params):
          del params
          ds = eval_dataset.dataset_fn()
          # Only pass those variables which will be used for decoding
          ds = ds.map(
              lambda x: {k: v for k, v in x.items() if k in _INPUT_FEATURES}
          )
          return ds

        decodes = decode(
            estimator,
            input_fn,
            eval_dataset.dataset_size,
            eval_dataset.padded_dataset_size,
            batch_size,
            vocabulary,
            checkpoint_path=checkpoint_path,
        )
        predictions = [eval_dataset.postprocess_fn(d) for d in decodes]
        # TODO(craffel): Log predictions and targets.

        eval_summary_dir = eval_summary_dir or os.path.join(
            model_dir, "{}_eval".format(dataset_split)
        )
        summary_writer = tf.summary.FileWriter(eval_summary_dir)
        global_step = int(get_step_from_checkpoint_path(checkpoint_path))
        for metric_fn in eval_dataset.metric_fns:
          summary = tf.Summary()
          tag = "eval/{}/{}/{}".format(
              eval_dataset.name, dataset_split, metric_fn.__name__
          )
          targets = cached_targets[eval_dataset.name]
          metric_result = metric_fn(targets, predictions)
          if isinstance(metric_result, dict):
            tags = ["{}.{}".format(tag, key) for key in metric_result]
            metric_values = metric_result.values()
          else:
            tags, metric_values = [tag], [metric_result]
          for tag, metric_value in zip(tags, metric_values):
            tf.logging.info(
                "%s at step %d: %.3f", tag, global_step, metric_value
            )
            summary.value.add(tag=tag, simple_value=metric_value)
            summary_writer.add_summary(summary, global_step)
        summary_writer.flush()

  elif mode == "infer":
    for checkpoint_path in get_checkpoint_iterator(checkpoint_step, model_dir):
      decode_from_file(
          estimator,
          vocabulary=vocabulary,
          model_type=model_type,
          batch_size=batch_size,
          sequence_length=sequence_length,
          checkpoint_path=checkpoint_path)
  else:
    raise ValueError(
        "unknown mode %s - must be train/eval/infer" % mode)