Example #1
#Code Reused from https://github.com/stanfordnlp/mac-network/
from __future__ import division

import tensorflow as tf

from experiments.config_args import parse_args
from sarnet_td3.common.config_ops import config
from sarnet_td3.common.mi_gru_cell import MiGRUCell
from sarnet_td3.common.mi_lstm_cell import MiLSTMCell

arglist = parse_args()

eps = 1e-20  # small constant for numerical stability
inf = 1e30  # large finite value used as a stand-in for infinity
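# Illustrative note (assumption, not from the original file): constants like eps and
# inf are conventionally used for numerically safe logs and for masking logits, e.g.
#   safe_log = tf.math.log(probs + eps)            # avoid log(0)
#   masked_logits = logits - (1.0 - mask) * inf    # push masked entries toward -inf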

####################################### variables ########################################
'''
Initializes a weight matrix variable given a shape and a name.
Uses Xavier (Glorot) initialization; the random_normal branch for 1-D shapes
is left commented out below.
'''


def getWeight(shape, name=""):
    with tf.compat.v1.variable_scope("weights"):
        # tf.contrib was removed in TF 2.x; use the compat.v1 equivalent of
        # tf.contrib.layers.xavier_initializer() (Glorot uniform) instead.
        initializer = tf.compat.v1.glorot_uniform_initializer()
        # if len(shape) == 1: # good?
        #     initializer = tf.random_normal_initializer()
        W = tf.compat.v1.get_variable("weight" + name,
                                      shape=shape,
                                      initializer=initializer)
    return W
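
For reference, here is a minimal, hypothetical usage sketch (not part of the original repository; the tensor names and shapes below are made up) showing how getWeight might back a simple linear projection:

# Hypothetical usage sketch: build a 64 -> 32 linear projection with getWeight.
tf.compat.v1.disable_eager_execution()                                 # v1-style graph mode
x_in = tf.compat.v1.placeholder(tf.float32, [None, 64], name="x_in")  # [batch, 64] input
W_proj = getWeight([64, 32], name="Proj")                              # Xavier-initialized [64, 32] weight
y_out = tf.matmul(x_in, W_proj)                                        # [batch, 32] projection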
Example #2
File: train.py  Project: caslab-vt/SARNet
def train():
    # Set up random seeds and argument parameters
    args = parse_args()
    if args.benchmark or args.display:
        args.random_seed = int(time.time())
        args.memory_dropout = 1.0
        args.read_dropout = 1.0
        args.write_dropout = 1.0
        args.output_dropout = 1.0
    np.random.seed(args.random_seed)
    tf.compat.v1.set_random_seed(args.random_seed)
    """" 
    --------------------------------------------------------------------------- 
    Set experiment directory structure and files to read/write data to
    ---------------------------------------------------------------------------
    """

    # exp_name, exp_itr, tboard_dir, data_file = nutil.create_dir(args)
    is_bench_dis = args.benchmark or args.display
    is_train = not is_bench_dis
    """" 
    --------------------------------------------------------------------------- 
    Create the number of environments, num_env == 1 for benchmark and display
    ---------------------------------------------------------------------------
    """
    cpu_proc_envs, num_env, num_agents, num_adversaries, obs_shape_n, action_space = create_env(
        args)
    args.num_gpu_threads = int(num_agents + 1)
    """" 
    --------------------------------------------------------------------------- 
    Load/Create Model
    ---------------------------------------------------------------------------
    """

    trainers, sess = load_model(num_agents, obs_shape_n, action_space, args,
                                num_env, is_train)
    # Initialize a replay buffer
    buffer_op = BufferOp(args, num_agents)
    # Get GPU Trainer Threads
    gpu_threads_train = get_gputhreads(trainers, args, buffer_op, num_env,
                                       num_agents, num_adversaries)
    # Initialize action/train calls
    if args.policy_grad == "reinforce":
        train_act_op = ActionOPVPG(trainers, args, num_env, num_agents,
                                   cpu_proc_envs, gpu_threads_train, is_train)
    elif args.policy_grad == "maddpg":
        train_act_op = ActionOPTD3(trainers, args, num_env, num_agents,
                                   cpu_proc_envs, gpu_threads_train, is_train)
    else:
        raise NotImplementedError
    U.initialize()

    # Load previous results, if necessary
    if args.load_dir == "":
        dirname = os.path.dirname(__file__)
        args.load_dir = os.path.join(
            dirname, 'exp_data/' + train_act_op.exp_name + '/' +
            train_act_op.exp_itr + args.save_dir + args.policy_file)
    if args.display or args.restore or args.benchmark:
        print('Loading previous state...')
        U.load_state(args.load_dir)
    """" 
    --------------------------------------------------------------------------- 
    Initialize environment and reward data structures
    ---------------------------------------------------------------------------
    """
    # Create a saver for checkpointing model parameters
    saver = tf.compat.v1.train.Saver()

    print('Starting iterations...')
    main_run_time = time.time()
    # print([x.name for x in tf.global_variables()])

    # CPU: Reset all environments and initialize all hidden states
    train_act_op.reset_states()
    start_time = time.time()
    while True:
        """ 
        Perform following steps: 
        1. Queue and receive action from GPU session
        2. Queue critic hidden states to GPU session
        3. Queue and receive environment steps
        4. Receive critic hidden state from GPU
        5. Queue (and move on) buffer additions
        6. Prepare updated hidden states for next step (non multi thread)
        
        """
        # GPU: Queue and wait for all actions
        # Stores actions in self.action_n_t
        train_act_op.queue_recv_actor()
        # GPU: Queue for all critic states
        if args.policy_grad == "maddpg":
            train_act_op.queue_critic()
        # Stores new observation, reward, done and benchmark
        train_act_op.get_env_act()
        if args.display:
            train_act_op.display_env()
        # Get all critic states and store them in self.q1/2_h_n_t1 for the next step's DDPG updates
        if args.policy_grad == "maddpg":
            train_act_op.recv_critic()
        # Queue values to be saved into the buffer; also compute done status before feeding data
        update_status = train_act_op.save_buffer()
        # Queue rewards to be saved
        train_act_op.save_rew_info()
        # Prepare inputs for next step
        train_act_op.update_states()

        # Update the actors and critics by sampling from the buffer; also write to TensorBoard
        if not (args.benchmark or args.display):
            if update_status:
                train_act_op.get_loss()

        if train_act_op.terminal:
            # eps_completed = train_act_op.train_step * num_env / args.max_episode_len
            if not (args.benchmark or args.display):
                train_act_op.save_model_rew_disk(saver,
                                                 time.time() - start_time)
            else:
                done_bench = train_act_op.save_benchmark()
                if done_bench:
                    print("Finished Benchmarking")
                    cpu_proc_envs.cancel()
                    close_gputhreads(gpu_threads_train)
                    tf.compat.v1.InteractiveSession.close(sess)
                    time.sleep(60)
                    break
            train_act_op.reset_rew_info()
            start_time = time.time()

        # Terminate once the frame (maddpg) or episode (reinforce) budget is exhausted and record the total runtime
        if args.policy_grad == "maddpg":
            if not args.benchmark:
                if train_act_op.train_step * num_env > args.num_total_frames:
                    eps_completed = train_act_op.train_step * num_env / args.max_episode_len
                    wutil.write_runtime(train_act_op.data_file, eps_completed,
                                        main_run_time)
                    cpu_proc_envs.cancel()
                    close_gputhreads(gpu_threads_train)
                    tf.compat.v1.InteractiveSession.close(sess)
                    time.sleep(60)
                    break

        elif args.policy_grad == "reinforce":
            if not args.benchmark:
                if train_act_op.train_step * num_env / args.max_episode_len > args.num_episodes:
                    eps_completed = train_act_op.train_step * num_env / args.max_episode_len
                    wutil.write_runtime(train_act_op.data_file, eps_completed,
                                        main_run_time)
                    cpu_proc_envs.cancel()
                    close_gputhreads(gpu_threads_train)
                    tf.compat.v1.InteractiveSession.close(sess)
                    time.sleep(60)