Example #1
    def build_network(self):
        # Create value (critic) network + target network
        if train_params.USE_BATCH_NORM:
            pass # for now
            # self.critic_net = Critic_BN(self.state_ph, self.action_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_main')
            # self.critic_target_net = Critic_BN(self.state_ph, self.action_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_target')
        else:
            self.critic_net = Critic(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, name='critic')
            self.critic_target_net = Critic(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, name='critic_target')

        # Create policy (actor) network + target network
        if train_params.USE_BATCH_NORM:
            pass # for now
            # self.actor_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_main')
            # self.actor_target_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_target')
        else:
            self.actor_net = Actor(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, name='actor')
            self.actor_target_net = Actor(train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, name='actor_target')
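The NUM_ATOMS, V_MIN and V_MAX arguments indicate a categorical (C51-style) value distribution, as used in D4PG. A minimal sketch of the fixed support such a distribution would be defined over, assuming that interpretation (z_atoms and delta_z are illustrative names, not taken from the snippet):

import numpy as np

# NUM_ATOMS evenly spaced atoms between V_MIN and V_MAX, with spacing delta_z
z_atoms = np.linspace(train_params.V_MIN, train_params.V_MAX, train_params.NUM_ATOMS)
delta_z = (train_params.V_MAX - train_params.V_MIN) / (train_params.NUM_ATOMS - 1)

# The expected Q-value is the probability-weighted sum over these atoms, e.g.
# q_values = tf.reduce_sum(z_atoms * critic_output_probs, axis=1)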
Example #2
    def build_network(self):
        
        # Define input placeholders    
        self.state_ph = tf.placeholder(tf.float32, ((train_params.BATCH_SIZE,) + train_params.STATE_DIMS))
        self.action_ph = tf.placeholder(tf.float32, ((train_params.BATCH_SIZE,) + train_params.ACTION_DIMS))
        self.noise_ph = tf.placeholder(tf.float32, (train_params.BATCH_SIZE,train_params.NUM_ATOMS,train_params.NOISE_DIMS))
        self.action_grads_ph = tf.placeholder(tf.float32, ((train_params.BATCH_SIZE,) + train_params.ACTION_DIMS)) # Gradient of critic's value output wrt action input - for actor training
        self.weights_ph = tf.placeholder(tf.float32, (train_params.BATCH_SIZE)) # Batch of IS weights to weigh gradient updates based on sample priorities
        self.real_samples_ph = tf.placeholder(tf.float32, (train_params.BATCH_SIZE, train_params.NUM_ATOMS)) # samples of target network with Bellman update applied

        # Create value (critic) network + target network
        if train_params.USE_BATCH_NORM:
            self.critic_net = Critic_BN(self.state_ph, self.action_ph, self.noise_ph,train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.NOISE_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_main')
            self.critic_target_net = Critic_BN(self.state_ph, self.action_ph, self.noise_ph,train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.NOISE_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, is_training=True, scope='learner_critic_target')
        else:
            self.critic_net = Critic(self.state_ph, self.action_ph, self.noise_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.NOISE_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, scope='learner_critic_main')
            self.critic_target_net = Critic( self.state_ph, self.action_ph, self.noise_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.NOISE_DIMS, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, train_params.NUM_ATOMS, train_params.V_MIN, train_params.V_MAX, scope='learner_critic_target')
        
        # Create policy (actor) network + target network
        if train_params.USE_BATCH_NORM:
            self.actor_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_main')
            self.actor_target_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=True, scope='learner_actor_target')
        else:
            self.actor_net = Actor(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, scope='learner_actor_main')
            self.actor_target_net = Actor(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, scope='learner_actor_target')
     
        # Create training step ops
        self.critic_train_step = self.critic_net.train_step(self.real_samples_ph, self.weights_ph, train_params.CRITIC_LEARNING_RATE, train_params.CRITIC_L2_LAMBDA, train_params.NUM_ATOMS)
        self.actor_train_step = self.actor_net.train_step(self.action_grads_ph, train_params.ACTOR_LEARNING_RATE, train_params.BATCH_SIZE)
        
        # Create saver for saving model ckpts (we only save learner network vars)
        model_name = train_params.ENV + '.ckpt'
        self.checkpoint_path = os.path.join(train_params.CKPT_DIR, model_name)        
        if not os.path.exists(train_params.CKPT_DIR):
            os.makedirs(train_params.CKPT_DIR)
        saver_vars = [v for v in tf.global_variables() if 'learner' in v.name]
        self.saver = tf.train.Saver(var_list = saver_vars, max_to_keep=1001) 
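The snippet defines the placeholders and training ops but not the loop that feeds them. A hedged sketch of one learner update under stated assumptions: sess, the batch arrays, target_samples, is_weights, the Gaussian noise model, and the .output / .action_grads attributes (by analogy with the DDPG example further below) are all assumptions, not part of the snippet.

        # Hypothetical one-step learner update (inside the same class, given a tf.Session `sess`)
        noise = np.random.normal(size=(train_params.BATCH_SIZE, train_params.NUM_ATOMS, train_params.NOISE_DIMS))

        # Critic update: Bellman-updated target samples, weighted by IS weights
        sess.run(self.critic_train_step, {self.state_ph: states_batch,
                                          self.action_ph: actions_batch,
                                          self.noise_ph: noise,
                                          self.real_samples_ph: target_samples,
                                          self.weights_ph: is_weights})

        # Actor update: feed the gradient of the critic's value w.r.t. the action input
        actor_actions = sess.run(self.actor_net.output, {self.state_ph: states_batch})
        action_grads = sess.run(self.critic_net.action_grads, {self.state_ph: states_batch,
                                                               self.action_ph: actor_actions,
                                                               self.noise_ph: noise})
        sess.run(self.actor_train_step, {self.state_ph: states_batch,
                                         self.action_grads_ph: action_grads[0]})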
Example #3
File: agent.py  Project: yangminsi/D4PG
    def build_network(self, training):
        # Input placeholder
        self.state_ph = tf.placeholder(tf.float32, ((None,) + train_params.STATE_DIMS))

        if training:
            # each agent has their own var_scope
            var_scope = ('actor_agent_%02d'%self.n_agent)
        else:
            # when testing, var_scope comes from main learner policy (actor) network
            var_scope = ('learner_actor_main')

        # Create policy (actor) network
        if train_params.USE_BATCH_NORM:
            self.actor_net = Actor_BN(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, is_training=False, scope=var_scope)
            self.agent_policy_params = self.actor_net.network_params + self.actor_net.bn_params
        else:
            self.actor_net = Actor(self.state_ph, train_params.STATE_DIMS, train_params.ACTION_DIMS, train_params.ACTION_BOUND_LOW, train_params.ACTION_BOUND_HIGH, train_params.DENSE1_SIZE, train_params.DENSE2_SIZE, train_params.FINAL_LAYER_INIT, scope=var_scope)
            self.agent_policy_params = self.actor_net.network_params
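The agent_policy_params collected here are presumably used to pull the learner's latest policy weights into the agent's network. A minimal sketch of such a sync op, assuming the learner exposes a matching variable list (learner_policy_params is a hypothetical name):

        # Ops that copy learner weights into this agent's policy network
        self.update_agent_ops = [agent_var.assign(learner_var)
                                 for agent_var, learner_var in zip(self.agent_policy_params, learner_policy_params)]
        # Later, inside a session: sess.run(self.update_agent_ops)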
def play(args):
    # Create environment
    env = gym.make(args.env)
    state_dims = env.observation_space.shape
    action_dims = env.action_space.shape
    action_bound_low = env.action_space.low
    action_bound_high = env.action_space.high

    # Define input placeholders
    state_ph = tf.placeholder(tf.float32, ((None,) + state_dims))

    # Create policy (actor) network
    if args.use_batch_norm:
        actor = Actor_BN(state_ph, state_dims, action_dims, action_bound_low, action_bound_high, args, is_training=False, scope='actor_main')
    else:
        actor = Actor(state_ph, state_dims, action_dims, action_bound_low, action_bound_high, args, scope='actor_main')

    # Create session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Load ckpt file
    loader = tf.train.Saver()
    if args.ckpt_file is not None:
        ckpt = args.ckpt_dir + '/' + args.ckpt_file
    else:
        ckpt = tf.train.latest_checkpoint(args.ckpt_dir)

    loader.restore(sess, ckpt)
    print('%s restored.\n\n' % ckpt)

    # Create record directory
    if args.record_dir is not None:
        if not os.path.exists(args.record_dir):
            os.makedirs(args.record_dir)

    for ep in range(args.num_eps):
        state = env.reset()
        for step in range(args.max_ep_length):
            frame = env.render(mode='rgb_array')
            if args.record_dir is not None:
                filepath = args.record_dir + '/Ep%03d_Step%04d.jpg' % (ep, step)
                cv2.imwrite(filepath, frame)
            action = sess.run(actor.output, {state_ph:np.expand_dims(state, 0)})[0]     # Add batch dimension to single state input, and remove batch dimension from single action output
            state, _, terminal, _ = env.step(action)

            if terminal:
                break

    env.close()

    # Convert saved frames to gif
    if args.record_dir is not None:
        images = []
        for file in sorted(os.listdir(args.record_dir)):
            # Load image
            filename = args.record_dir + '/' + file
            im = cv2.imread(filename)
            images.append(im)
            # Delete static image once loaded
            os.remove(filename)

        # Save as gif
        imageio.mimsave(args.record_dir + '/%s.gif' % args.env, images, duration=0.01)
Example #5
def train(args):
    # Create environment
    env = gym.make(args.env)
    state_dims = env.observation_space.shape
    action_dims = env.action_space.shape
    action_bound_low = env.action_space.low
    action_bound_high = env.action_space.high

    # Set random seeds for reproducibility
    env.seed(args.random_seed)
    np.random.seed(args.random_seed)
    tf.set_random_seed(args.random_seed)

    # Initialise replay memory
    replay_mem = ReplayMemory(args, state_dims, action_dims)

    # Initialise Ornstein-Uhlenbeck Noise generator
    exploration_noise = OrnsteinUhlenbeckActionNoise(mu=np.zeros(action_dims))
    noise_scaling = args.noise_scale * (action_bound_high - action_bound_low)

    # Define input placeholders
    state_ph = tf.placeholder(tf.float32, ((None, ) + state_dims))
    action_ph = tf.placeholder(tf.float32, ((None, ) + action_dims))
    target_ph = tf.placeholder(
        tf.float32, (None, 1))  # Target Q-value - for critic training
    action_grads_ph = tf.placeholder(
        tf.float32, ((None, ) + action_dims)
    )  # Gradient of critic's value output wrt action input - for actor training
    is_training_ph = tf.placeholder_with_default(True, shape=None)

    # Create value (critic) network + target network
    if args.use_batch_norm:
        critic = Critic_BN(state_ph,
                           action_ph,
                           state_dims,
                           action_dims,
                           args,
                           is_training=is_training_ph,
                           scope='critic_main')
        critic_target = Critic_BN(state_ph,
                                  action_ph,
                                  state_dims,
                                  action_dims,
                                  args,
                                  is_training=is_training_ph,
                                  scope='critic_target')
    else:
        critic = Critic(state_ph,
                        action_ph,
                        state_dims,
                        action_dims,
                        args,
                        scope='critic_main')
        critic_target = Critic(state_ph,
                               action_ph,
                               state_dims,
                               action_dims,
                               args,
                               scope='critic_target')

    # Create policy (actor) network + target network
    if args.use_batch_norm:
        actor = Actor_BN(state_ph,
                         state_dims,
                         action_dims,
                         action_bound_low,
                         action_bound_high,
                         args,
                         is_training=is_training_ph,
                         scope='actor_main')
        actor_target = Actor_BN(state_ph,
                                state_dims,
                                action_dims,
                                action_bound_low,
                                action_bound_high,
                                args,
                                is_training=is_training_ph,
                                scope='actor_target')
    else:
        actor = Actor(state_ph,
                      state_dims,
                      action_dims,
                      action_bound_low,
                      action_bound_high,
                      args,
                      scope='actor_main')
        actor_target = Actor(state_ph,
                             state_dims,
                             action_dims,
                             action_bound_low,
                             action_bound_high,
                             args,
                             scope='actor_target')

    # Create training step ops
    critic_train_step = critic.train_step(target_ph)
    actor_train_step = actor.train_step(action_grads_ph)

    # Create ops to update target networks
    update_critic_target = update_target_network(critic.network_params,
                                                 critic_target.network_params,
                                                 args.tau)
    update_actor_target = update_target_network(actor.network_params,
                                                actor_target.network_params,
                                                args.tau)

    # Create session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Define saver for saving model ckpts
    model_name = args.env + '.ckpt'
    checkpoint_path = os.path.join(args.ckpt_dir, model_name)
    if not os.path.exists(args.ckpt_dir):
        os.makedirs(args.ckpt_dir)
    saver = tf.train.Saver(max_to_keep=201)

    # Load ckpt file if given
    if args.ckpt_file is not None:
        loader = tf.train.Saver()  #Restore all variables from ckpt
        ckpt = args.ckpt_dir + '/' + args.ckpt_file
        ckpt_split = ckpt.split('-')
        step_str = ckpt_split[-1]
        start_ep = int(step_str)
        loader.restore(sess, ckpt)
    else:
        start_ep = 0
        sess.run(tf.global_variables_initializer())
        # Perform hard copy (tau=1.0) of initial params to target networks
        sess.run(
            update_target_network(critic.network_params,
                                  critic_target.network_params))
        sess.run(
            update_target_network(actor.network_params,
                                  actor_target.network_params))

    # Create summary writer to write summaries to disk
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    # Create summary op to save episode reward to Tensorboard log
    ep_reward_var = tf.Variable(0.0, trainable=False)
    tf.summary.scalar("Episode Reward", ep_reward_var)
    summary_op = tf.summary.merge_all()

    ## Training

    # Initially populate replay memory by taking random actions
    sys.stdout.write('\nPopulating replay memory with random actions...\n')
    sys.stdout.flush()
    env.reset()

    for random_step in range(1, args.initial_replay_mem_size + 1):
        if args.render:
            env.render()
        action = env.action_space.sample()
        state, reward, terminal, _ = env.step(action)
        replay_mem.add(action, reward, state, terminal)

        if terminal:
            env.reset()

        sys.stdout.write('\x1b[2K\rStep {:d}/{:d}'.format(
            random_step, args.initial_replay_mem_size))
        sys.stdout.flush()

    sys.stdout.write('\n\nTraining...\n')
    sys.stdout.flush()

    for train_ep in range(start_ep + 1, args.num_eps_train + 1):
        # Reset environment and noise process
        state = env.reset()
        exploration_noise.reset()

        train_step = 0
        episode_reward = 0
        duration_values = []
        ep_done = False

        sys.stdout.write('\n')
        sys.stdout.flush()

        while not ep_done:
            train_step += 1
            start_time = time.time()
            ## Take action and store experience
            if args.render:
                env.render()
            if args.use_batch_norm:
                action = sess.run(
                    actor.output, {
                        state_ph: np.expand_dims(state, 0),
                        is_training_ph: False
                    }
                )[0]  # Add batch dimension to single state input, and remove batch dimension from single action output
            else:
                action = sess.run(actor.output,
                                  {state_ph: np.expand_dims(state, 0)})[0]
            action += exploration_noise() * noise_scaling
            # Clip the action element-wise to the allowed bounds
            action = np.clip(action, action_bound_low, action_bound_high)
            state, reward, terminal, _ = env.step(action)
            replay_mem.add(action, reward, state, terminal)

            episode_reward += reward

            ## Train networks
            # Get minibatch
            states_batch, actions_batch, rewards_batch, next_states_batch, terminals_batch = replay_mem.getMinibatch()

            # Critic training step
            # Predict actions for next states by passing next states through policy target network
            future_action = sess.run(actor_target.output,
                                     {state_ph: next_states_batch})
            # Predict target Q values by passing next states and actions through value target network
            # future_Q has shape [batch_size, 1]; drop the second dimension so it
            # lines up with terminals_batch and rewards_batch, which have shape [batch_size]
            future_Q = sess.run(critic_target.output,
                                {state_ph: next_states_batch,
                                 action_ph: future_action})[:, 0]
            # Q values of the terminal states is 0 by definition
            future_Q[terminals_batch] = 0
            targets = rewards_batch + (future_Q * args.discount_rate)
            # Train critic
            sess.run(
                critic_train_step, {
                    state_ph: states_batch,
                    action_ph: actions_batch,
                    target_ph: np.expand_dims(targets, 1)
                })

            # Actor training step
            # Get policy network's action outputs for selected states
            actor_actions = sess.run(actor.output, {state_ph: states_batch})
            # Compute gradients of critic's value output wrt actions
            action_grads = sess.run(critic.action_grads, {
                state_ph: states_batch,
                action_ph: actor_actions
            })
            # Train actor
            sess.run(actor_train_step, {
                state_ph: states_batch,
                action_grads_ph: action_grads[0]
            })

            # Update target networks
            sess.run(update_critic_target)
            sess.run(update_actor_target)

            # Display progress
            duration = time.time() - start_time
            duration_values.append(duration)
            ave_duration = sum(duration_values) / float(len(duration_values))

            #pdb.set_trace()
            sys.stdout.write(
                '\x1b[2K\rEpisode {:d}/{:d} \t Steps = {:d} \t Reward = {:.3f} \t ({:.3f} s/step)'
                .format(train_ep, args.num_eps_train, train_step,
                        episode_reward, ave_duration))
            sys.stdout.flush()

            if terminal or train_step == args.max_ep_length:
                # Log total episode reward and begin next episode
                summary_str = sess.run(summary_op,
                                       {ep_reward_var: episode_reward})
                summary_writer.add_summary(summary_str, train_ep)
                ep_done = True

        if train_ep % args.save_ckpt_step == 0:
            saver.save(sess, checkpoint_path, global_step=train_ep)
            sys.stdout.write('\n Checkpoint saved.')
            sys.stdout.flush()

    env.close()
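update_target_network itself is not shown in this example. Given how it is called above (with args.tau for the per-step soft updates and with no tau for the initial hard copy), a minimal sketch consistent with that usage might look like the following; the default tau=1.0 is an assumption:

def update_target_network(network_params, target_network_params, tau=1.0):
    # Soft update: theta_target <- tau * theta + (1 - tau) * theta_target.
    # With the default tau=1.0 this reduces to a hard copy, matching the initial sync above.
    op_holder = []
    for from_var, to_var in zip(network_params, target_network_params):
        op_holder.append(to_var.assign(tau * from_var + (1.0 - tau) * to_var))
    return op_holder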
Example #6
def play():

    if play_params.ENV == 'Pendulum-v0':
        play_env = PendulumWrapper()
    elif play_params.ENV == 'LunarLanderContinuous-v2':
        play_env = LunarLanderContinuousWrapper()
    elif play_params.ENV == 'BipedalWalker-v2':
        play_env = BipedalWalkerWrapper()
    elif play_params.ENV == 'BipedalWalkerHardcore-v2':
        play_env = BipedalWalkerWrapper(hardcore=True)
    else:
        raise Exception(
            'Chosen environment does not have an environment wrapper defined. Please choose an environment with an environment wrapper defined, or create a wrapper for this environment in utils.env_wrapper.py'
        )

    actor_net = Actor(play_params.STATE_DIMS,
                      play_params.ACTION_DIMS,
                      play_params.ACTION_BOUND_LOW,
                      play_params.ACTION_BOUND_HIGH,
                      train_params.DENSE1_SIZE,
                      train_params.DENSE2_SIZE,
                      train_params.FINAL_LAYER_INIT,
                      name='actor_play')
    critic_net = Critic(play_params.STATE_DIMS,
                        play_params.ACTION_DIMS,
                        train_params.DENSE1_SIZE,
                        train_params.DENSE2_SIZE,
                        train_params.FINAL_LAYER_INIT,
                        train_params.NUM_ATOMS,
                        train_params.V_MIN,
                        train_params.V_MAX,
                        name='critic_play')

    actor_net.load_weights(play_params.ACTOR_MODEL_DIR)
    critic_net.load_weights(play_params.CRITIC_MODEL_DIR)

    if not os.path.exists(play_params.RECORD_DIR):
        os.makedirs(play_params.RECORD_DIR)

    for ep in tqdm(range(1, play_params.NUM_EPS_PLAY + 1), desc='playing'):
        state = play_env.reset()
        state = play_env.normalise_state(state)
        step = 0
        ep_done = False

        while not ep_done:
            frame = play_env.render()
            if play_params.RECORD_DIR is not None:
                filepath = play_params.RECORD_DIR + '/Ep%03d_Step%04d.jpg' % (
                    ep, step)
                cv2.imwrite(filepath, frame)
            action = actor_net(np.expand_dims(state.astype(np.float32), 0))[0]
            state, _, terminal = play_env.step(action)
            state = play_env.normalise_state(state)

            step += 1

            # Episode can finish either by reaching terminal state or max episode steps
            if terminal or step == play_params.MAX_EP_LENGTH:
                ep_done = True

    # Convert saved frames to gif
    if play_params.RECORD_DIR is not None:
        images = []
        for file in tqdm(sorted(os.listdir(play_params.RECORD_DIR)),
                         desc='converting to gif'):
            # Load image
            filename = play_params.RECORD_DIR + '/' + file
            im = cv2.imread(filename)
            images.append(im)
            # Delete static image once loaded
            os.remove(filename)

        # Save as gif
        print("Saving to ", play_params.RECORD_DIR)
        imageio.mimsave(play_params.RECORD_DIR + '/%s.gif' % play_params.ENV,
                        images[:-1],
                        duration=0.01)

    play_env.close()
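The exception at the top of this example asks for a wrapper in utils.env_wrapper.py when an unsupported environment is chosen. A rough sketch of the interface the rest of the function relies on (reset, a three-value step, render, normalise_state, close); the class below is hypothetical, not the project's actual wrapper:

class MyEnvWrapper:
    def __init__(self, env_name):
        self.env = gym.make(env_name)

    def reset(self):
        return self.env.reset()

    def step(self, action):
        # Return three values (state, reward, terminal), as the play loop above expects
        state, reward, terminal, _ = self.env.step(action)
        return state, reward, terminal

    def render(self):
        return self.env.render(mode='rgb_array')

    def normalise_state(self, state):
        # Identity here; a real wrapper may rescale observations to a fixed range
        return state

    def close(self):
        self.env.close()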
Example #7
File: test.py  Project: LT310/DDPG-2
def test(args):
    # Create environment
    env = gym.make(args.env)
    state_dims = env.observation_space.shape
    action_dims = env.action_space.shape
    action_bound_low = env.action_space.low
    action_bound_high = env.action_space.high

    # Set random seeds for reproducibility
    env.seed(args.random_seed)
    np.random.seed(args.random_seed)
    tf.set_random_seed(args.random_seed)

    # Define input placeholder
    state_ph = tf.placeholder(tf.float32, ((None, ) + state_dims))

    # Create policy (actor) network
    if args.use_batch_norm:
        actor = Actor_BN(state_ph,
                         state_dims,
                         action_dims,
                         action_bound_low,
                         action_bound_high,
                         args,
                         is_training=False,
                         scope='actor_main')
    else:
        actor = Actor(state_ph,
                      state_dims,
                      action_dims,
                      action_bound_low,
                      action_bound_high,
                      args,
                      scope='actor_main')

    # Create session
    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)

    # Load ckpt file
    loader = tf.train.Saver()
    if args.ckpt_file is not None:
        ckpt = args.ckpt_dir + '/' + args.ckpt_file
    else:
        ckpt = tf.train.latest_checkpoint(args.ckpt_dir)

    loader.restore(sess, ckpt)
    sys.stdout.write('%s restored.\n\n' % ckpt)
    sys.stdout.flush()

    ckpt_split = ckpt.split('-')
    train_ep = ckpt_split[-1]

    # Create summary writer to write summaries to disk
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)
    summary_writer = tf.summary.FileWriter(args.log_dir, sess.graph)

    # Create summary op to save episode reward to Tensorboard log
    reward_var = tf.Variable(0.0, trainable=False)
    tf.summary.scalar("Average Test Reward", reward_var)
    summary_op = tf.summary.merge_all()

    # Start testing

    rewards = []

    for test_ep in range(args.num_eps_test):
        state = env.reset()
        ep_reward = 0
        step = 0
        ep_done = False

        while not ep_done:
            if args.render:
                env.render()
            action = sess.run(
                actor.output, {state_ph: np.expand_dims(state, 0)}
            )[0]  # Add batch dimension to single state input, and remove batch dimension from single action output
            state, reward, terminal, _ = env.step(action)

            ep_reward += reward
            step += 1

            # Episode can finish either by reaching terminal state or max episode steps
            if terminal or step == args.max_ep_length:
                sys.stdout.write('\x1b[2K\rTest episode {:d}/{:d}'.format(
                    test_ep, args.num_eps_test))
                sys.stdout.flush()
                rewards.append(ep_reward)
                ep_done = True

    mean_reward = np.mean(rewards)
    error_reward = ss.sem(rewards)

    sys.stdout.write(
        '\x1b[2K\rTesting complete \t Average reward = {:.2f} +/- {:.2f} /ep \n\n'
        .format(mean_reward, error_reward))
    sys.stdout.flush()

    # Log average episode reward for Tensorboard visualisation
    summary_str = sess.run(summary_op, {reward_var: mean_reward})
    summary_writer.add_summary(summary_str, train_ep)

    # Write results to file
    if args.results_dir is not None:
        if not os.path.exists(args.results_dir):
            os.makedirs(args.results_dir)
        output_file = open(args.results_dir + '/' + args.env + '.txt', 'a')
        output_file.write(
            'Training Episode {}: \t Average reward = {:.2f} +/- {:.2f} /ep \n\n'
            .format(train_ep, mean_reward, error_reward))
        output_file.flush()
        sys.stdout.write('Results saved to file \n\n')
        sys.stdout.flush()

    env.close()
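The snippet omits its imports; judging from the names it uses (gym, np, tf, ss, os, sys), it presumably relies on something like the following aliases (inferred, not shown in the source):

import os
import sys

import gym
import numpy as np
import scipy.stats as ss
import tensorflow as tf

# plus the project's own Actor / Actor_BN network classes (module path not shown in the snippet)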