Exemple #1
0
def test_func(
    rank,
    E,
    T,
    args,
    test_q,
    device,
    tensorboard_dir,
):
    """Evaluation worker: repeatedly test the latest policy against one opponent.

    Blocks on ``test_q`` for ``(state_dict, step)`` pairs, loads each state
    dict into a local ``MLPActorCritic``, evaluates it with ``test_proc`` in a
    freshly created ``SoccerPLUS`` environment and logs the results through
    ``test_summary`` to the TensorBoard directory ``<tensorboard_dir>/test_<opp>``.
    Returns after evaluating a model whose step ``t >= args.episode``.

    Args:
        rank: index of this tester; selects the opponent ``args.opp_list[rank]``
            and offsets the random seeds.
        E, T: not used by this worker (kept for a uniform process signature).
        args: experiment configuration (reads ``seed``, ``opp_list``, ``hid``,
            ``l`` and ``episode`` here, plus whatever ``test_proc`` and
            ``test_summary`` read).
        test_q: queue yielding ``(state_dict, step)`` tuples.
        device: passed through unchanged to ``test_proc``.
        tensorboard_dir: root directory for TensorBoard logs.
    """
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    print("set up Test process env")
    opp = args.opp_list[rank]

    # A throwaway env is built only to read the observation/action sizes.
    env = SoccerPLUS()
    obs_dim = env.n_features
    act_dim = env.n_actions

    ac_kwargs = dict(hidden_sizes=[args.hid] * args.l)
    local_ac = MLPActorCritic(obs_dim, act_dim, **ac_kwargs)
    env.close()
    del env

    temp_dir = os.path.join(tensorboard_dir, "test_{}".format(opp))
    # exist_ok=True avoids the race of a separate exists() check.
    os.makedirs(temp_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=temp_dir)
    try:
        while True:
            test_model, t = test_q.get()
            print("TEST Process {} loaded new mode at {} step".format(rank, t))
            # Defensive copy before loading the received state dict.
            local_ac.load_state_dict(deepcopy(test_model))
            del test_model

            env = SoccerPLUS()
            try:
                print("TESTING process {} start to test, opp: {}".format(rank, opp))
                m_score, win_rate, steps = test_proc(local_ac, env, opp, args, device)
                test_summary(opp, steps, m_score, win_rate, writer, args, t)
                print("TESTING process {} finished, opp: {}".format(rank, opp))
            finally:
                # Release the evaluation env even if testing fails.
                env.close()
                del env

            if t >= args.episode:
                break
    finally:
        # Flush pending TensorBoard events; the writer was previously never closed.
        writer.close()
    print("Process {}\tTester Ended".format(rank))
Exemple #2
0
def main(max_episodes=None):
    """Train an ``ActorCritic`` agent on ``SoccerPLUS`` with short rollouts.

    The opponent follows the scripted ``Policy`` selected by ``policy_type``.
    After every ``n_rollout`` environment steps (or at episode end) the stored
    transitions are used for one ``model.train_net()`` call. The average
    episode score is printed every ``print_interval`` episodes.

    Args:
        max_episodes: stop after this many episodes. ``None`` (the default)
            keeps the original behaviour of training until interrupted.
    """
    env = SoccerPLUS(visual=False)
    obs_dim = env.n_features
    act_dim = env.n_actions
    learning_rate = 0.0001
    gamma = 0.98
    hidden = 256
    n_rollout = 10
    policy_type = 1
    opp_policy = Policy(game=env, player_num=False)
    model = ActorCritic(obs_dim, act_dim, hidden, learning_rate, gamma)

    # Training loop
    print_interval = 100
    score = 0.0
    n_epi = 0
    try:
        while max_episodes is None or n_epi < max_episodes:
            n_epi += 1
            done = False
            s = env.reset()
            while not done:
                for _ in range(n_rollout):
                    prob = model.pi(torch.from_numpy(s).float())
                    a = Categorical(prob).sample().item()
                    s_prime, r, done, _info = env.step(
                        a, opp_policy.get_actions(policy_type))
                    env.render()
                    model.put_data((s, a, r, s_prime, done))

                    s = s_prime
                    score += r

                    if done:
                        break

                model.train_net()

            # n_epi is incremented at the top of the loop, so it is never 0 here.
            if n_epi % print_interval == 0:
                print("# of episode :{}, avg score : {:.1f}".format(
                    n_epi, score / print_interval))
                score = 0.0
    finally:
        # With the old `while True` loop this close() was unreachable.
        env.close()
Exemple #3
0
def sac(env_fn,
        actor_critic=MLPActorCritic,
        ac_kwargs=None,
        seed=0,
        steps_per_epoch=4000,
        epochs=100,
        replay_size=int(1e6),
        gamma=0.99,
        polyak=0.995,
        lr=1e-3,
        alpha=0.2,
        batch_size=100,
        start_steps=10000,
        update_after=1000,
        update_every=50,
        num_test_episodes=10,
        max_ep_len=1000,
        policy_type=1,
        logger_kwargs=None,
        save_freq=1000,
        save_dir=None):
    """Train a discrete-action Soft Actor-Critic agent on ``SoccerPLUS``.

    The learner plays against a scripted opponent whose policy id is drawn per
    episode by ``get_opp_policy(args.p)`` (4 with probability ``args.p``,
    otherwise 5). Transitions are stored in a ``ReplayBuffer``. Once
    ``t >= update_after``, ``update_every`` gradient updates run every
    ``update_every`` steps. Every ``save_freq`` steps (after ``update_after``)
    the greedy policy is evaluated against opponents 4 and 5 and against the
    ``args.p`` mixture.

    Args:
        env_fn: zero-argument factory for the training environment.
        actor_critic: module class built as
            ``actor_critic(obs_dim, act_dim, **ac_kwargs)``; must expose
            ``pi``, ``q1``, ``q2`` and ``act``.
        ac_kwargs: extra keyword arguments for ``actor_critic``
            (``None`` means none).
        seed: seed for ``torch`` and ``numpy``.
        steps_per_epoch, epochs: total environment steps are their product.
        replay_size: replay buffer capacity.
        gamma: discount factor.
        polyak: target averaging coefficient,
            ``targ <- polyak * targ + (1 - polyak) * online``.
        lr: Adam learning rate for the policy and both critics.
        alpha: fixed entropy temperature.
        batch_size: minibatch size per gradient update.
        start_steps: number of initial steps that use uniformly random actions.
        update_after: step after which updates and evaluations start.
        update_every: update period, and number of updates per period.
        num_test_episodes: episodes per evaluation (0 disables evaluation).
        max_ep_len: episode length cap for training and evaluation.
        policy_type, logger_kwargs, save_dir: accepted but unused here.
        save_freq: evaluation period in environment steps.

    NOTE(review): besides its parameters this function reads the module-level
    names ``args`` (``args.p``), ``logger``, ``writer``, ``writer_1``,
    ``writer_3`` and ``writer_opp``; they must exist before it is called.
    """
    ac_kwargs = {} if ac_kwargs is None else ac_kwargs

    torch.manual_seed(seed)
    np.random.seed(seed)
    device = (torch.device("cuda")
              if torch.cuda.is_available() else torch.device("cpu"))

    env = env_fn()
    opp_policy = Policy(game=env, player_num=False)
    # Separate env/opponent pair so evaluation does not disturb the training
    # episode in progress.
    test_env = SoccerPLUS(visual=False)
    test_opp_policy = Policy(game=test_env, player_num=False)
    obs_dim = env.n_features
    act_dim = env.n_actions

    # Create actor-critic module and target networks.
    ac = actor_critic(obs_dim, act_dim, **ac_kwargs)
    ac_targ = deepcopy(ac)
    if torch.cuda.is_available():
        ac.cuda()
        ac_targ.cuda()

    # Target networks are only changed by polyak averaging, never by an optimizer.
    for p in ac_targ.parameters():
        p.requires_grad = False

    # Per-loss parameter groups, so each step clips only its own gradients.
    # Assumes pi, q1 and q2 do not share parameters, as the separate
    # optimizers below suggest.
    pi_params = list(ac.pi.parameters())
    q_params = list(ac.q1.parameters()) + list(ac.q2.parameters())

    # Experience buffer
    replay_buffer = ReplayBuffer(obs_dim=obs_dim, size=replay_size)

    # Optimizers for the policy and the two Q-functions.
    pi_optimizer = Adam(ac.pi.parameters(), lr=lr)
    q1_optimizer = Adam(ac.q1.parameters(), lr=lr)
    q2_optimizer = Adam(ac.q2.parameters(), lr=lr)

    def get_actions_info(a_prob):
        """Return (probs, log-probs, sampled action, argmax action) for ``a_prob``.

        Exact zeros get 1e-20 added before the log so log(0) is not -inf.
        """
        a_dis = Categorical(a_prob)
        max_a = torch.argmax(a_prob)
        sample_a = a_dis.sample().cpu()
        z = a_prob == 0.0
        z = z.float() * 1e-20
        log_a_prob = torch.log(a_prob + z)
        return a_prob, log_a_prob, sample_a, max_a

    def compute_loss_q(data):
        """Sum of the two critics' MSE losses against the soft Bellman backup."""
        o, a, r, o2, d = data['obs'], data['act'], data['rew'], data[
            'obs2'], data['done']

        # Bellman backup for Q functions
        with torch.no_grad():
            # Target actions come from the *current* policy; for discrete
            # actions the expectation over actions is computed exactly.
            a_prob, log_a_prob, _, _ = get_actions_info(ac.pi(o2))

            q1_pi_targ = ac_targ.q1(o2)
            q2_pi_targ = ac_targ.q2(o2)
            q_pi_targ = torch.min(q1_pi_targ, q2_pi_targ)
            backup = r + gamma * (1 - d) * torch.sum(
                a_prob * (q_pi_targ - alpha * log_a_prob), dim=1)

        # MSE loss against Bellman backup, on the Q-values of the taken actions.
        q1 = ac.q1(o).gather(1, a.unsqueeze(-1).long())
        q2 = ac.q2(o).gather(1, a.unsqueeze(-1).long())
        loss_q1 = F.mse_loss(q1, backup.unsqueeze(-1))
        loss_q2 = F.mse_loss(q2, backup.unsqueeze(-1))
        return loss_q1 + loss_q2

    def compute_loss_pi(data):
        """Return (entropy-regularized policy loss, per-sample policy entropy)."""
        o = data['obs']
        a_prob, log_a_prob, _, _ = get_actions_info(ac.pi(o))
        q_pi = torch.min(ac.q1(o), ac.q2(o))

        loss_pi = torch.sum(a_prob * (alpha * log_a_prob - q_pi),
                            dim=1,
                            keepdim=True).mean()
        # Entropy is H = -sum(p * log p). The previous code returned
        # sum(p * log p), i.e. the negative entropy, and logged it as entropy.
        entropy = -torch.sum(log_a_prob * a_prob, dim=1).detach()
        return loss_pi, entropy

    def update(data):
        """Run one critic step, one policy step and a target-network update.

        Losses are logged to the module-level ``writer`` at the enclosing
        loop's current step ``t``.
        """
        # Critic step.
        q1_optimizer.zero_grad()
        q2_optimizer.zero_grad()
        loss_q = compute_loss_q(data)
        loss_q.backward()
        # Clip only the critics. Clipping over ac.parameters() also counted
        # the policy's leftover gradients from the previous policy step,
        # which distorted the clip scale.
        nn.utils.clip_grad_norm_(q_params, max_norm=10, norm_type=2)
        q1_optimizer.step()
        q2_optimizer.step()

        # Policy step. loss_pi also back-propagates into q1/q2; those
        # gradients are cleared by the next critic zero_grad(), so clip the
        # policy parameters only.
        pi_optimizer.zero_grad()
        loss_pi, entropy = compute_loss_pi(data)
        loss_pi.backward()
        nn.utils.clip_grad_norm_(pi_params, max_norm=10, norm_type=2)
        pi_optimizer.step()

        if t >= update_after:
            # For lr >= 1e-10, max(lr, 1e-10) == lr, so this keeps the rate
            # constant; it is the hook for a decay schedule such as:
            # lr = max(args.lr * 2 ** (-(t-update_after) * 0.0001), 1e-10)
            _adjust_learning_rate(q1_optimizer, max(lr, 1e-10))
            _adjust_learning_rate(q2_optimizer, max(lr, 1e-10))
            _adjust_learning_rate(pi_optimizer, max(lr, 1e-10))

        # Finally, update target networks by polyak averaging.
        with torch.no_grad():
            for p, p_targ in zip(ac.parameters(), ac_targ.parameters()):
                p_targ.data.copy_((1 - polyak) * p.data + polyak * p_targ.data)

        writer.add_scalar("training/pi_loss", loss_pi.detach().item(), t)
        writer.add_scalar("training/q_loss", loss_q.detach().item(), t)
        writer.add_scalar("training/entropy",
                          entropy.detach().mean().item(), t)
        writer.add_scalar("training/lr", lr, t)

    def get_action(o, greedy=False):
        """Return an int action for observation ``o`` (argmax if ``greedy``, else sampled)."""
        if len(o.shape) == 1:
            o = np.expand_dims(o, axis=0)
        a_prob = ac.act(torch.as_tensor(o, dtype=torch.float32, device=device),
                        greedy)
        _, _, sample_a, max_a = get_actions_info(a_prob)
        action = max_a if greedy else sample_a
        return action.item()

    def run_test_episode(t_opp):
        """Play one greedy episode on ``test_env`` against scripted policy ``t_opp``.

        Returns ``(ep_ret, ep_len)``, where each reward is multiplied by 10.
        """
        o, d, ep_ret, ep_len = test_env.reset(), False, 0, 0
        while not (d or (ep_len == max_ep_len)):
            o2, r, d, _ = test_env.step(
                get_action(o, True),
                test_opp_policy.get_actions(t_opp))
            r *= 10
            o = o2
            ep_ret += r
            ep_len += 1
        return ep_ret, ep_len

    def test_agent(epoch, t_opp, writer):
        """Evaluate the greedy policy against a fixed opponent and log the results."""
        if num_test_episodes == 0:
            return
        with torch.no_grad():
            win = 0
            total_ret = 0
            total_len = 0
            for _ in range(num_test_episodes):
                ep_ret, ep_len = run_test_episode(t_opp)
                total_ret += ep_ret
                total_len += ep_len
                # An episode counts as a win when its scaled return is exactly 50.
                if ep_ret == 50:
                    win += 1
            mean_score = total_ret / num_test_episodes
            win_rate = win / num_test_episodes
            mean_len = total_len / num_test_episodes
            logger.info(
                "opponent:\t{}\ntest epoch:\t{}\nmean score:\t{:.1f}\nwin_rate:\t{}\nmean len:\t{}"
                .format(t_opp, epoch, mean_score, win_rate, mean_len))
            writer.add_scalar("test/mean_score", mean_score, epoch)
            writer.add_scalar("test/win_rate", win_rate, epoch)
            writer.add_scalar("test/mean_len", mean_len, epoch)

    def test_opp(epoch, p, writer):
        """Evaluate the greedy policy against opponents drawn per episode with ``get_opp_policy(p)``."""
        if num_test_episodes == 0:
            return
        with torch.no_grad():
            win = 0
            total_ret = 0
            for _ in range(num_test_episodes):
                ep_ret, _ = run_test_episode(get_opp_policy(p))
                total_ret += ep_ret
                if ep_ret == 50:
                    win += 1
            mean_score = total_ret / num_test_episodes
            win_rate = win / num_test_episodes
            logger.info(
                "p:\t{}\ntest epoch:\t{}\nmean score:\t{:.1f}\nwin_rate:\t{}\n"
                .format(p, epoch, mean_score, win_rate))
            writer.add_scalar("test/mean_score", mean_score, epoch)
            writer.add_scalar("test/win_rate", win_rate, epoch)

    def get_opp_policy(p):
        """Return opponent policy id 4 with probability ``p``, else 5."""
        p_sample = np.random.rand()
        if p_sample < p:
            return 4
        else:
            return 5

    # Prepare for interaction with environment
    total_steps = steps_per_epoch * epochs
    scores = []
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False
    episode = 0
    opp = get_opp_policy(args.p)

    # Main loop: collect experience in env and update/log periodically.
    for t in range(total_steps):

        # Uniform random actions until start_steps, then the learned policy.
        with torch.no_grad():
            if t >= start_steps:
                a = get_action(o)
            else:
                a = np.random.randint(act_dim)

        # Step the env
        o2, r, d, info = env.step(a, opp_policy.get_actions(opp))
        if info.get('no_data_receive', False):
            discard = True
        env.render()
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time horizon
        # (an artificial terminal) or from a discarded episode.
        d = False if ep_len == max_ep_len or discard else d

        # Store experience to replay buffer
        replay_buffer.store(o, a, r, o2, d)
        writer.add_scalar("learner/buffer_size", replay_buffer.size, t)

        # Advance to the most recent observation.
        o = o2

        # End of trajectory handling
        if d or (ep_len == max_ep_len) or discard:
            scores.append(ep_ret)
            logger.info(
                "total_step {}, episode: {}, opp:{}, round len:{}, round score: {}, 100 mean score: {}, 10 mean Score: {}"
                .format(t, episode, opp, ep_len, ep_ret,
                        np.mean(scores[-100:]), np.mean(scores[-10:])))
            writer.add_scalar("metrics/round_score", ep_ret, t)
            writer.add_scalar("metrics/round_step", ep_len, t)
            writer.add_scalar("metrics/alpha", alpha, t)
            o, ep_ret, ep_len = env.reset(), 0, 0
            opp = get_opp_policy(args.p)
            episode += 1
            discard = False

        # Update handling
        if t >= update_after and t % update_every == 0:
            for _ in range(update_every):
                batch = replay_buffer.sample_batch(batch_size, device=device)
                update(data=batch)

        # Periodic evaluation of the deterministic (greedy) policy.
        if t >= update_after and t % save_freq == 0:
            test_agent(t, 4, writer_1)
            test_agent(t, 5, writer_3)
            test_opp(t, args.p, writer_opp)
Exemple #4
0
    # Make sure the experiment log file exists before the logger attaches to it.
    if not os.path.isfile(filename):
        with open(filename, mode='w') as f:
            pass
    logger = get_logger(filename)

    # Dump every argument to "<save_dir>.args", one "name : value" per line.
    argument_file = save_dir + '.args'
    argsDict = args.__dict__
    with open(argument_file, 'w') as f:
        f.write('------------------ start ------------------' + '\n')
        for arg_name, arg_value in argsDict.items():
            f.write(arg_name + ' : ' + str(arg_value) + '\n')
        f.write('------------------- end -------------------' + '\n')

    # Re-applies the current thread count (effectively a no-op).
    torch.set_num_threads(torch.get_num_threads())

    sac(lambda: SoccerPLUS(visual=False),
        actor_critic=MLPActorCritic,
        ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
        gamma=args.gamma,
        seed=args.seed,
        epochs=args.epochs,
        policy_type=args.policy_type,
        replay_size=args.replay_size,
        lr=args.lr,
        alpha=args.alpha,
        batch_size=args.batch_size,
        start_steps=10000,
        steps_per_epoch=1000,
        polyak=0.995,
        update_after=10000,
        update_every=1,
Exemple #5
0
        action = self.score_to_index(valid_actions, actions_score)
        return action

    def get_actions(self, policy_num):
        """Pick this player's action using scripted policy ``policy_num``.

        ``policy_num`` may be an int or a numeric string. ``-1`` means a
        uniformly random index into ``self.action_map``. Any other value
        indexes ``self.policy_type``, whose entry is called with the full
        action list returned by ``self.validActionAll()``. ``update_status()``
        is always called first.
        """
        policy_index = int(policy_num) if isinstance(policy_num, str) else policy_num
        self.update_status()
        if policy_index == -1:
            n_choices = len(self.action_map)
            return np.random.randint(n_choices)
        _, candidate_actions = self.validActionAll()
        chosen_policy = self.policy_type[policy_index]
        return chosen_policy(candidate_actions)


if __name__ == "__main__":
    # Demo: two scripted players (policy 5 on both sides) play 1000 rendered
    # rounds at 0.5 s per step; the reward of each round's final step is printed.
    policy_types = list(range(5))
    env = SoccerPLUS(visual=True)
    env.reset()
    my_policy = Policy(game=env, player_num=True)
    opp_policy = Policy(game=env, player_num=False)
    for _round in range(1000):
        env.reset()
        done = False
        while not done:
            my_action = my_policy.get_actions(5)
            their_action = opp_policy.get_actions(5)
            s_, reward, done, _ = env.step(my_action, their_action)
            env.render()
            time.sleep(0.5)
        print(reward)
Exemple #6
0
def sac(
    rank,
    E,
    T,
    args,
    model_q,
    buffer_q,
    device=None,
    tensorboard_dir=None,
):
    """Actor process: roll out the local policy and ship finished episodes.

    Runs ``SoccerPLUS`` episodes against a schedule of scripted opponents
    (each entry of ``args.opp_list`` repeated ``args.opp_freq`` times, cycled
    per episode). While ``E.value() <= args.update_after`` actions are
    uniformly random; afterwards ``local_ac.get_action`` is used. Each
    finished episode is put on ``buffer_q`` as ``(trajectory, meta)``, and new
    weights are loaded from ``model_q`` when one is queued. Per-episode stats
    are logged to TensorBoard under ``<tensorboard_dir>/<rank>``. The loop
    runs while ``T.value() < args.episode``.

    Args:
        rank: actor index; offsets the seeds and names the log directory.
        E, T: shared counters exposing ``.value()``.
        args: experiment configuration.
        model_q: queue of ``state_dict`` objects for ``local_ac``.
        buffer_q: queue receiving ``(trajectory, meta)`` per episode.
        device: passed to ``local_ac.get_action``.
        tensorboard_dir: root log directory. Although it defaults to
            ``None``, a path is required (``os.path.join`` needs a string).
    """
    torch.manual_seed(args.seed + rank)
    np.random.seed(args.seed + rank)
    tensorboard_dir = os.path.join(tensorboard_dir, str(rank))
    # exist_ok=True avoids the race of a separate exists() check.
    os.makedirs(tensorboard_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=tensorboard_dir)

    env = SoccerPLUS()
    opp_policy = Policy(game=env, player_num=False)
    # Opponent schedule: every opponent repeated args.opp_freq times in a row.
    opps = [name for name in args.opp_list for _ in range(args.opp_freq)]
    opp = opps[0]
    obs_dim = env.n_features
    act_dim = env.n_actions
    ac_kwargs = dict(hidden_sizes=[args.hid] * args.l)
    local_ac = MLPActorCritic(obs_dim, act_dim, **ac_kwargs)
    print("set up child process env")

    def _load_model_from_queue():
        """Pop one state dict from model_q and load a copy of it into local_ac."""
        received_obj = model_q.get()
        local_ac.load_state_dict(deepcopy(received_obj))
        del received_obj

    # Prepare for interaction with environment
    scores, wins = [], []
    # meta data is purely for experiment analysis
    trajectory, meta = [], []
    o, ep_ret, ep_len = env.reset(), 0, 0
    discard = False
    local_e = 0
    if not model_q.empty():
        print("Process {}\t Initially LOADING...".format(rank))
        _load_model_from_queue()
        print("Process {}\t Initially Loading FINISHED!!!".format(rank))

    try:
        while T.value() < args.episode:
            with torch.no_grad():
                if E.value() <= args.update_after:
                    a = np.random.randint(act_dim)
                else:
                    a = local_ac.get_action(o, device=device)

            # Step the env
            o2, r, d, info = env.step(a, opp_policy.get_actions(opp))
            env.render()
            if info.get('no_data_receive', False):
                discard = True
            ep_ret += r
            ep_len += 1

            # Time-limit and discarded episodes are not true terminals.
            d = False if (ep_len == args.max_ep_len) or discard else d
            trajectory.append((o, a, r, o2, d))
            meta.append([opp, rank, local_e, ep_len, r, a])
            o = o2

            # End of trajectory handling: ship the episode to the learner.
            if d or (ep_len == args.max_ep_len) or discard:
                e = E.value()
                buffer_q.put((trajectory, meta))
                local_e += 1
                wins.append(1 if info.get('win', False) else 0)
                scores.append(ep_ret)
                m_score = np.mean(scores[-100:])
                win_rate = np.mean(wins[-100:])
                writer.add_scalar("actor/round_score", ep_ret, local_e)
                writer.add_scalar("actor/mean_score", m_score.item(), local_e)
                writer.add_scalar("actor/win_rate", win_rate.item(), local_e)
                writer.add_scalar("actor/round_step", ep_len, local_e)
                writer.add_scalar("actor/learner_actor_speed", e, local_e)

                o, ep_ret, ep_len = env.reset(), 0, 0
                opp = opps[local_e % len(opps)]
                discard = False
                trajectory, meta = [], []
                if not model_q.empty():
                    _load_model_from_queue()
    finally:
        # Previously neither the writer nor the env was ever closed.
        writer.close()
        env.close()
    print("Process {}\tActor Ended".format(rank))
Exemple #7
0
    # One TensorBoard run directory per evaluation opponent (args.opp1, args.opp2),
    # tagged with p1/p2 and the seed. Presumably these writers are read as
    # module-level writer_1 / writer_3 by the SAC test loop -- confirm.
    tensorboard_dir_1 = os.path.join(test_save_dir, "test_" + str(args.opp1) + "_" + str(args.p1) + '_' + str(args.p2) + '_' + str(args.seed))
    tensorboard_dir_3 = os.path.join(test_save_dir, "test_" + str(args.opp2) + "_" + str(args.p1) + '_' + str(args.p2) + '_' + str(args.seed))
    # exist_ok=True replaces the racy "exists() then makedirs()" check.
    os.makedirs(tensorboard_dir_1, exist_ok=True)
    os.makedirs(tensorboard_dir_3, exist_ok=True)
    writer_1 = SummaryWriter(log_dir=tensorboard_dir_1)
    writer_3 = SummaryWriter(log_dir=tensorboard_dir_3)
    # NOTE(review): the SAC learner in this file also reads module-level
    # `writer` and `writer_opp`, which are not created here -- confirm they
    # are defined elsewhere before sac() runs.

    # Create an empty experiment log file if needed, then attach the logger.
    filename = save_dir + '_exp.log'
    if not os.path.isfile(filename):
        with open(filename, mode='w') as f:
            pass
    logger = get_logger(filename)

    # Record every argument in "<save_dir>.args" as "name : value" lines.
    argument_file = save_dir + '.args'
    argsDict = args.__dict__
    with open(argument_file, 'w') as f:
        f.write('------------------ start ------------------' + '\n')
        for eachArg, value in argsDict.items():
            f.write(eachArg + ' : ' + str(value) + '\n')
        f.write('------------------- end -------------------' + '\n')

    # Re-applies the current thread count (effectively a no-op).
    torch.set_num_threads(torch.get_num_threads())

    sac(lambda: SoccerPLUS(visual=False), actor_critic=MLPActorCritic,
        ac_kwargs=dict(hidden_sizes=[args.hid] * args.l),
        gamma=args.gamma, seed=args.seed, epochs=args.epochs, policy_type=args.policy_type, replay_size=args.replay_size,
        lr=args.lr, alpha=args.alpha, batch_size=args.batch_size, start_steps=10000, steps_per_epoch=1000, polyak=0.995,
        update_after=10000, update_every=1, num_test_episodes=args.test_episodes, max_ep_len=1000, save_freq=500,
        logger_kwargs=dict(), save_dir=save_dir)
Exemple #8
0
    if not os.path.exists(main_dir):
        os.makedirs(main_dir)
    writer = SummaryWriter(log_dir=main_dir)
    with open(os.path.join(experiment_dir, "arguments"), 'w') as f:
        json.dump(args.__dict__, f, indent=2)
    device = torch.device("cuda") if args.cuda else torch.device("cpu")
    # env and model setup
    ac_kwargs = dict(hidden_sizes=[args.hid] * args.l)

    # if args.exp_name == "test":
    #     env = gym.make("CartPole-v0")
    # elif args.non_station:
    #     env = make_ftg_ram_nonstation(args.env, p2_list=args.opp_list, total_episode=args.opp_freq,stable=args.stable)
    # else:
    #     env = make_ftg_ram(args.env, p2=args.p2)
    env = SoccerPLUS()
    obs_dim = env.n_features
    act_dim = env.n_actions
    # create model
    global_ac = MLPActorCritic(obs_dim, act_dim, **ac_kwargs)
    if args.cpc:
        global_cpc = CPC(timestep=args.timestep,
                         obs_dim=obs_dim,
                         hidden_sizes=[args.hid] * args.l,
                         z_dim=args.z_dim,
                         c_dim=args.c_dim,
                         device=device)
    else:
        global_cpc = None
    # create shared model for actor
    global_ac_targ = deepcopy(global_ac)