Beispiel #1
0
def main():
    """Prepare the run directory and logger, then train an ``ActorCritic`` agent.

    Configuration comes from the module-level ``args`` namespace:

    * ``args.resume_dir == ""`` starts a fresh run in
      ``<args.log_root>/log_YYYYMMDD_HHMM`` and dumps every argument to
      ``hparams.json`` in that directory;
    * otherwise the run continues in ``args.resume_dir`` and its existing
      ``hparams.json`` is left untouched.

    In both cases a ``checkpoints`` sub-directory is ensured, and all
    arguments are logged through a ``Logger`` bound to ``log_train.txt``.
    """
    fresh_run = args.resume_dir == ""

    # setup logger
    if fresh_run:
        # Minute-resolution local timestamp, e.g. "20240115_1345".
        date = datetime.datetime.now().strftime("%Y%m%d_%H%M")
        log_dir = os.path.join(args.log_root, "log_" + date)
    else:
        log_dir = args.resume_dir
    hparams_file = os.path.join(log_dir, "hparams.json")
    checkpoints_dir = os.path.join(log_dir, "checkpoints")
    # One call also creates log_dir as an intermediate directory. exist_ok
    # replaces the exists()/makedirs() pairs, which could raise
    # FileExistsError if another process created the directory in between.
    os.makedirs(checkpoints_dir, exist_ok=True)
    if fresh_run:
        # write hparams
        with open(hparams_file, "w") as f:
            json.dump(vars(args), f, indent=2)
    log_file = os.path.join(log_dir, "log_train.txt")
    logger = Logger(log_file)
    logger.info("The args corresponding to training process are: ")
    for key, value in vars(args).items():
        logger.info("{key:20}: {value:}".format(key=key, value=value))

    actor_critic = ActorCritic(args, log_dir, checkpoints_dir)
    actor_critic.train()
Beispiel #2
0
def main():
    """Evaluate a trained ``ActorCritic`` agent from an existing run directory.

    The run directory is the module-level ``args.resume_dir``. Its
    ``checkpoints`` sub-directory is handed to ``ActorCritic``, and a
    ``Logger`` is bound to its ``log_train.txt``.
    """
    # setup logger
    log_dir = args.resume_dir
    checkpoints_dir = os.path.join(log_dir, "checkpoints")
    log_file = os.path.join(log_dir, "log_train.txt")
    # `logger` is not referenced below. It is still created and kept alive for
    # the whole run in case Logger's constructor sets up logging state that
    # later code relies on. TODO: confirm and drop it if it does not.
    logger = Logger(log_file)

    actor_critic = ActorCritic(args, log_dir, checkpoints_dir)
    actor_critic.evaluation()
Beispiel #3
0
    def __init__(
        self,
        lr,
        gamma,
        k_epochs,
        eps_clip,
        n_j,
        n_m,
        num_layers,
        neighbor_pooling_type,
        input_dim,
        hidden_dim,
        num_mlp_layers_feature_extract,
        num_mlp_layers_actor,
        hidden_dim_actor,
        num_mlp_layers_critic,
        hidden_dim_critic,
    ):
        """Build the learner: policy, its copy, optimiser, LR schedule and loss.

        Args:
            lr: Adam learning rate (also stored as ``self.lr``).
            gamma: discount factor, stored as ``self.gamma``.
            k_epochs: stored as ``self.k_epochs``.
            eps_clip: stored as ``self.eps_clip``.
            n_j, n_m, num_layers, neighbor_pooling_type, input_dim, hidden_dim,
            num_mlp_layers_feature_extract, num_mlp_layers_actor,
            hidden_dim_actor, num_mlp_layers_critic, hidden_dim_critic:
                forwarded unchanged to ``ActorCritic``.
        """
        # Plain hyper-parameters.
        self.lr, self.gamma = lr, gamma
        self.eps_clip, self.k_epochs = eps_clip, k_epochs

        network_kwargs = dict(
            n_j=n_j,
            n_m=n_m,
            num_layers=num_layers,
            learn_eps=False,
            neighbor_pooling_type=neighbor_pooling_type,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_mlp_layers_feature_extract=num_mlp_layers_feature_extract,
            num_mlp_layers_actor=num_mlp_layers_actor,
            hidden_dim_actor=hidden_dim_actor,
            num_mlp_layers_critic=num_mlp_layers_critic,
            hidden_dim_critic=hidden_dim_critic,
            device=device,
        )
        policy = ActorCritic(**network_kwargs)
        self.policy = policy

        # Independent copy of the freshly built network, then an explicit
        # weight sync (redundant right after deepcopy, kept for parity).
        self.policy_old = deepcopy(policy)
        self.policy_old.load_state_dict(policy.state_dict())

        # Adam over the live policy, with step-wise LR decay read from configs.
        optimizer = torch.optim.Adam(policy.parameters(), lr=lr)
        self.optimizer = optimizer
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer,
            step_size=configs.decay_step_size,
            gamma=configs.decay_ratio,
        )

        self.V_loss_2 = nn.MSELoss()
#print("Cuda: " + str(torch.cuda.is_available()))

if __name__ == '__main__':
    # Limit OpenMP-backed ops to a single thread per process.
    os.environ['OMP_NUM_THREADS'] = '1'
    torch.cuda.empty_cache()

    args = parser.parse_args()
    # Checkpoint path of the shared A3C model. The SCME weights are stored next
    # to it, with the ".pkl" suffix replaced by "_scme.pkl" (see the load below).
    SAVEPATH = os.getcwd(
    ) + '/save/scme_' + args.reward_type + '/mario_a3c_params.pkl'
    # Make sure the per-reward-type save directory exists.
    if not os.path.exists(os.getcwd() + '/save/scme_' + args.reward_type):
        os.makedirs(os.getcwd() + '/save/scme_' + args.reward_type)

    env = create_mario_env(args.env_name, args.reward_type)

    # Actor-critic and SCME models, both built from the observation-space shape
    # and len(ACTIONS). share_memory() places their parameters in shared
    # memory (presumably for worker processes started further down; not
    # visible here).
    shared_model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    shared_model.share_memory()

    shared_scme = SCME(env.observation_space.shape[0], len(ACTIONS))
    shared_scme.share_memory()

    # Resume both models when a checkpoint exists.
    if os.path.isfile(SAVEPATH):
        print('Loading A3C parametets & SCME parameters...')
        shared_model.load_state_dict(torch.load(SAVEPATH))
        shared_scme.load_state_dict(torch.load(SAVEPATH[:-4] + '_scme.pkl'))

    torch.manual_seed(args.seed)

    # One optimizer over the parameters of both models.
    #optimizer = torch.optim.Adam(list(shared_model.parameters()) + list(shared_scme.parameters()), lr=args.lr)
    optimizer = SharedAdam(list(shared_model.parameters()) +
                           list(shared_scme.parameters()),
Beispiel #5
0
def train(rank,
          args,
          shared_model,
          counter,
          lock,
          optimizer=None,
          select_sample=True):
    """A3C worker loop: collect rollouts and push gradients to ``shared_model``.

    Each iteration does the following:

    1. Sync the local model with ``shared_model``.
    2. Play up to ``args.num_steps`` environment steps.
    3. Compute the GAE policy loss plus the value loss.
    4. Copy the local gradients to the shared model and step ``optimizer``.

    The loop itself never terminates.

    Args:
        rank: worker index. Seeds torch and the environment. Rank 0 also
            renders and saves every ``args.save_interval`` iterations; rank 1
            is a backup saver every ``2.5 * args.save_interval`` iterations.
        args: parsed command-line namespace.
        shared_model: model whose parameters receive the gradients.
        counter: shared step counter (``counter.value``), guarded by ``lock``.
        lock: lock protecting ``counter``.
        optimizer: optimizer over the shared parameters. A fresh Adam over
            ``shared_model.parameters()`` is created when ``None``.
        select_sample: sample actions from the policy if True; otherwise act
            greedily (argmax).
    """
    torch.manual_seed(args.seed + rank)

    print("Process No : {} | Sampling : {}".format(rank, select_sample))

    FloatTensor = torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    env = create_mario_env(args.env_name)
    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if args.use_cuda:
        model.cuda()
    if optimizer is None:
        optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)

    model.train()

    state = env.reset()
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    for num_iter in count():

        if rank == 0:
            env.render()

            if num_iter % args.save_interval == 0 and num_iter > 0:
                print("Saving model at :" + args.save_path)
                torch.save(shared_model.state_dict(), args.save_path)

        # Second saver in case the first process crashes.
        if (num_iter % (args.save_interval * 2.5) == 0 and num_iter > 0
                and rank == 1):
            print("Saving model for process 1 at :" + args.save_path)
            torch.save(shared_model.state_dict(), args.save_path)

        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        if done:
            # New episode: zero recurrent state (hx, cx).
            cx = Variable(torch.zeros(1, 512)).type(FloatTensor)
            hx = Variable(torch.zeros(1, 512)).type(FloatTensor)
        else:
            # Same episode: keep the values but detach them (via .data) from
            # the previous rollout's graph.
            cx = Variable(cx.data).type(FloatTensor)
            hx = Variable(hx.data).type(FloatTensor)

        values = []
        log_probs = []
        rewards = []
        entropies = []

        for _ in range(args.num_steps):
            episode_length += 1
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            if select_sample:
                # num_samples is a required argument of multinomial().
                action = prob.multinomial(1).data
            else:
                action = prob.max(-1, keepdim=True)[1].data

            log_prob = log_prob.gather(-1, Variable(action))

            action_out = ACTIONS[action][0, 0]

            state, reward, done, _ = env.step(action_out)

            done = done or episode_length >= args.max_episode_length
            # Clip the environment reward to [-50, 50].
            reward = max(min(reward, 50), -50)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                env.change_level(0)
                state = env.reset()
                print("Process {} has completed.".format(rank))

            env.locked_levels = [False] + [True] * 31
            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)

            if done:
                break

        # Bootstrap the return from the critic unless the episode ended.
        R = torch.zeros(1, 1)
        if not done:
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, _, _ = model((state_inp, (hx, cx)))
            R = value.data

        values.append(Variable(R).type(FloatTensor))
        policy_loss = 0
        value_loss = 0
        R = Variable(R).type(FloatTensor)
        gae = torch.zeros(1, 1).type(FloatTensor)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * \
                values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - \
                log_probs[i] * Variable(gae).type(FloatTensor) - args.entropy_coef * entropies[i]

        total_loss = policy_loss + args.value_loss_coef * value_loss

        print("Process {} loss :".format(rank), total_loss.data)
        optimizer.zero_grad()

        total_loss.backward()
        # clip_grad_norm (no underscore) is deprecated; clip_grad_norm_ is the
        # in-place replacement.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        optimizer.step()
    print("Process {} closed.".format(rank))
Beispiel #6
0
def test(rank, args, shared_model, counter):
    """Greedy evaluation worker: replay episodes with the latest shared weights.

    Runs forever. At the start of every episode the weights are reloaded
    from ``shared_model``. After each episode the worker:

    * prints a summary line;
    * appends a row to ``save/mario_curves.csv`` under the current working
      directory;
    * sleeps 60 s.

    Args:
        rank: process index, used to seed torch and the environment.
        args: parsed command-line namespace.
        shared_model: model whose parameters are copied into the local model.
        counter: shared global step counter, read via ``counter.value``.
    """
    torch.manual_seed(args.seed + rank)

    FloatTensor = torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor

    env = create_mario_env(args.env_name)
    # TODO: need to implement Monitor wrapper with env.change_level
    # expt_dir = 'video'
    # env = wrappers.Monitor(env, expt_dir, force=True, video_callable=lambda count: count % 10 == 0)

    env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if args.use_cuda:
        model.cuda()
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True
    savefile = os.getcwd() + '/save/mario_curves.csv'

    title = ['Time', 'No. Steps', 'Total Reward', 'Episode Length']
    with open(savefile, 'a', newline='') as sfile:
        writer = csv.writer(sfile)
        writer.writerow(title)

    start_time = time.time()
    # Start of the current episode, used for the CSV 'Time' column. It is set
    # once per episode (not on every step) so the column measures the whole
    # episode.
    ep_start_time = start_time

    # A quick hack to prevent the agent from getting stuck: end the episode
    # once the deque holds 4000 copies of the same action.
    actions = deque(maxlen=4000)
    episode_length = 0
    while True:
        episode_length += 1
        # Inference only: record no autograd graph. This replaces
        # Variable(..., volatile=True), which has no effect since PyTorch 0.4.
        with torch.no_grad():
            if done:
                # Sync with the shared model and reset the recurrent state.
                model.load_state_dict(shared_model.state_dict())
                cx = torch.zeros(1, 512).type(FloatTensor)
                hx = torch.zeros(1, 512).type(FloatTensor)

            state_inp = state.unsqueeze(0).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            action = prob.max(-1, keepdim=True)[1]

        action_out = ACTIONS[action][0, 0]

        state, reward, done, _ = env.step(action_out)
        env.render()
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        # Store plain ints so the identical-action check compares numbers.
        actions.append(int(action[0, 0]))
        if actions.count(actions[0]) == actions.maxlen:
            done = True

        if done:
            print(
                "Time {}, num steps {}, FPS {:.0f}, episode reward {}, episode length {}"
                .format(
                    time.strftime("%Hh %Mm %Ss",
                                  time.gmtime(time.time() - start_time)),
                    counter.value, counter.value / (time.time() - start_time),
                    reward_sum, episode_length))

            data = [
                time.time() - ep_start_time, counter.value, reward_sum,
                episode_length
            ]

            with open(savefile, 'a', newline='') as sfile:
                writer = csv.writer(sfile)
                writer.writerows([data])

            reward_sum = 0
            episode_length = 0
            actions.clear()
            time.sleep(60)
            env.locked_levels = [False] + [True] * 31
            env.change_level(0)
            state = env.reset()
            # The next episode starts now (after the pause and the reset).
            ep_start_time = time.time()

        state = torch.from_numpy(state)
Beispiel #7
0
def train(rank, args, shared_model, shared_scme, counter, lock, optimizer=None, select_sample=True):
    """A3C + SCME worker loop: roll out, compute losses, update shared models.

    Runs forever. Each iteration does the following:

    1. Sync local copies of ``shared_model`` / ``shared_scme``.
    2. Play up to ``args.num_steps`` steps. The environment reward is
       augmented with ``args.alpha * cur_loss`` and ``args.beta * mi_loss``,
       then clipped to [-5, 50].
    3. Back-propagate ``args.lambd * (policy + value loss)`` plus the SCME
       loss summed over the rollout.
    4. Copy the gradients to the shared models and step ``optimizer``.

    Per-episode returns and normalised final x position are appended to
    ``save/scmemi_<reward_type>/train_reward_<rank>.csv``.

    Args:
        rank: worker index. Seeds torch; rank 0 saves every
            ``args.save_interval`` iterations, rank 1 every
            ``2.5 * args.save_interval`` iterations.
        args: parsed command-line namespace.
        shared_model: shared actor-critic model.
        shared_scme: shared SCME model.
        counter: shared step counter (``counter.value``), guarded by ``lock``.
        lock: lock protecting ``counter``.
        optimizer: optimizer over both shared models. A fresh Adam is created
            when ``None``.
        select_sample: sample actions if True; otherwise act greedily.
    """
    torch.manual_seed(args.seed + rank)

    print("Process No : {} | Sampling : {}".format(rank, select_sample))

    FloatTensor = torch.FloatTensor  # CPU only (was: torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor)

    savefile = os.getcwd() + '/save/scmemi_' + args.reward_type + '/train_reward.csv'
    saveweights = os.getcwd() + '/save/scmemi_' + args.reward_type + '/mario_a3c_params.pkl'

    env = create_mario_env(args.env_name, args.reward_type)
    #env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if optimizer is None:
        optimizer = optim.Adam(list(shared_model.parameters()) + list(shared_scme.parameters()), lr=args.lr)

    scme_model = SCME(env.observation_space.shape[0], len(ACTIONS))

    model.train()
    scme_model.train()

    state = env.reset()
    cum_rew = 0
    state = torch.from_numpy(state)
    done = True

    episode_length = 0
    for num_iter in count():
        if rank == 0 and num_iter % args.save_interval == 0 and num_iter > 0:
            print("Saving model at :" + saveweights)
            torch.save(shared_model.state_dict(), saveweights)
            torch.save(shared_scme.state_dict(), saveweights[:-4] + '_scme.pkl')

        # Second saver in case the first process crashes.
        if num_iter % (args.save_interval * 2.5) == 0 and num_iter > 0 and rank == 1:
            print("Saving model for process 1 at :" + saveweights)
            torch.save(shared_model.state_dict(), saveweights)
            torch.save(shared_scme.state_dict(), saveweights[:-4] + '_scme.pkl')

        # Sync with the shared model
        model.load_state_dict(shared_model.state_dict())
        scme_model.load_state_dict(shared_scme.state_dict())

        if done:
            cx = Variable(torch.zeros(1, 512)).type(FloatTensor)
            hx = Variable(torch.zeros(1, 512)).type(FloatTensor)
        else:
            # Keep the recurrent state but detach it from the previous graph.
            cx = Variable(cx.data).type(FloatTensor)
            hx = Variable(hx.data).type(FloatTensor)

        values = []
        log_probs = []
        rewards = []
        entropies = []
        vae_losses = []
        cur_losses = []
        mi_losses = []

        for _ in range(args.num_steps):
            episode_length += 1
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            log_prob = F.log_softmax(logit, dim=-1)
            entropy = -(log_prob * prob).sum(-1, keepdim=True)
            entropies.append(entropy)

            if select_sample:
                action = prob.multinomial(1).data
            else:
                action = prob.max(-1, keepdim=True)[1].data

            log_prob = log_prob.gather(-1, Variable(action))

            action_out = int(action[0, 0].data.numpy())
            state, reward, done, info = env.step(action_out)
            cum_rew = cum_rew + reward

            action_one_hot = (torch.eye(len(ACTIONS))[action_out]).view(1, -1)

            next_state_inp = Variable(torch.from_numpy(state).unsqueeze(0)).type(FloatTensor)

            pred_z, mi, mi1, actual_z, xt1_hat, xt1, xt1_mu, xt1_logvar = scme_model((state_inp, next_state_inp, action_one_hot))
            vae_loss = loss_function(xt1_hat, xt1, xt1_mu, xt1_logvar)
            cur_loss = ((pred_z - actual_z).pow(2)).sum(-1, keepdim=True)/2/50
            mi_loss = mutual(mi, mi1).sum(-1, keepdim=True)/10
            done = done or episode_length >= args.max_episode_length

            # Intrinsic bonuses as plain Python floats (assumes each loss holds a
            # single element, as for a batch of one). Without this, `.numpy()`
            # on mi_loss left a (1, 1) array, which also turned `reward` into
            # an array.
            cur_reward = (args.alpha * cur_loss).item()
            mi_reward = (args.beta * mi_loss).item()
            reward = cur_reward + reward + mi_reward
            reward = max(min(reward, 50), -5)

            with lock:
                counter.value += 1

            if done:
                episode_length = 0
                state = env.reset()
                # Per-worker log: episode return and normalised final x position.
                with open(savefile[:-4] + '_{}.csv'.format(rank), 'a', newline='') as sfile:
                    writer = csv.writer(sfile)
                    writer.writerows([[cum_rew, info['x_pos'] / x_norm]])
                cum_rew = 0

            state = torch.from_numpy(state)
            values.append(value)
            log_probs.append(log_prob)
            rewards.append(reward)
            vae_losses.append(vae_loss)
            cur_losses.append(cur_loss)
            mi_losses.append(mi_loss)

            if done:
                break

        # Bootstrap the return from the critic unless the episode ended.
        R = torch.zeros(1, 1)
        if not done:
            state_inp = Variable(state.unsqueeze(0)).type(FloatTensor)
            value, _, _ = model((state_inp, (hx, cx)))
            R = value.data

        values.append(Variable(R).type(FloatTensor))
        policy_loss = 0
        value_loss = 0
        scme_loss = 0
        R = Variable(R).type(FloatTensor)
        gae = torch.zeros(1, 1).type(FloatTensor)
        for i in reversed(range(len(rewards))):
            R = args.gamma * R + rewards[i]
            advantage = R - values[i]
            value_loss = value_loss + 0.5 * advantage.pow(2)

            # Generalized Advantage Estimation
            delta_t = rewards[i] + args.gamma * values[i + 1].data - values[i].data
            gae = gae * args.gamma * args.tau + delta_t

            policy_loss = policy_loss - log_probs[i] * Variable(gae).type(FloatTensor) - args.entropy_coef * entropies[i]

            # Accumulate the SCME loss over every step of the rollout, like the
            # policy and value losses.
            scme_loss = scme_loss + 0.01 * vae_losses[i] + cur_losses[i] - mi_losses[i]
        total_loss = args.lambd * (policy_loss + args.value_loss_coef * value_loss)

        optimizer.zero_grad()

        (total_loss + scme_loss).backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
        torch.nn.utils.clip_grad_norm_(scme_model.parameters(), args.max_grad_norm)

        ensure_shared_grads(model, shared_model)
        ensure_shared_grads(scme_model, shared_scme)

        optimizer.step()
Beispiel #8
0
def test(rank, args, shared_model, counter):
    """Greedy evaluation worker for the SCME runs (loops forever).

    Reloads ``shared_model`` at the start of every episode and acts greedily
    (argmax). After each episode the worker:

    * prints a summary;
    * appends a row to ``save/scmemi_<reward_type>/mario_curves.csv``;
    * sleeps 60 s.

    Args:
        rank: process index, used to seed torch.
        args: parsed command-line namespace.
        shared_model: model whose parameters are copied into the local model.
        counter: shared global step counter, read via ``counter.value``.
    """
    torch.manual_seed(args.seed + rank)

    FloatTensor = torch.FloatTensor  # CPU only (was: torch.cuda.FloatTensor if args.use_cuda else torch.FloatTensor)

    env = create_mario_env(args.env_name, args.reward_type)
    # TODO: need to implement Monitor wrapper with env.change_level
    # expt_dir = 'video'
    # env = wrappers.Monitor(env, expt_dir, force=True, video_callable=lambda count: count % 10 == 0)

    #env.seed(args.seed + rank)

    model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    model.eval()

    state = env.reset()
    state = torch.from_numpy(state)
    reward_sum = 0
    done = True
    savefile = os.getcwd() + '/save/scmemi_' + args.reward_type + '/mario_curves.csv'

    title = ['Time', 'No. Steps', 'Total Reward', 'final_position', 'Episode Length']
    with open(savefile, 'a', newline='') as sfile:
        writer = csv.writer(sfile)
        writer.writerow(title)

    start_time = time.time()
    # Start of the current episode, used for the CSV 'Time' column. It is set
    # once per episode (not on every step) so the column measures the whole
    # episode.
    ep_start_time = start_time

    # Quick hacks to prevent the agent from getting stuck. The episode ends if:
    # * the deque holds 400 copies of the same action, or
    # * (with args.pos_stuck) x_pos stayed strictly within 20 of its latest
    #   value over the last >= 200 recorded steps.
    actions = deque(maxlen=400)
    positions = deque(maxlen=400)
    episode_length = 0
    while True:
        episode_length += 1
        # Pure inference: run the model under no_grad() so no autograd graph
        # is recorded for the hidden state or the forward pass.
        with torch.no_grad():
            if done:
                # Sync with the shared model and reset the recurrent state.
                model.load_state_dict(shared_model.state_dict())
                cx = torch.zeros(1, 512).type(FloatTensor)
                hx = torch.zeros(1, 512).type(FloatTensor)

            state_inp = state.unsqueeze(0).type(FloatTensor)
            value, logit, (hx, cx) = model((state_inp, (hx, cx)))
            prob = F.softmax(logit, dim=-1)
            action = prob.max(-1, keepdim=True)[1]

        action_out = int(action[0, 0])
        state, reward, done, info = env.step(action_out)
        #env.render()
        done = done or episode_length >= args.max_episode_length
        reward_sum += reward

        actions.append(action_out)
        if actions.count(actions[0]) == actions.maxlen:
            done = True
            print('action')
        if args.pos_stuck:
            positions.append(info['x_pos'])
            pos_ar = np.array(positions)
            if (len(positions) >= 200) and (pos_ar < pos_ar[-1] + 20).all() and (pos_ar > pos_ar[-1] - 20).all():
                done = True

        if done:
            print("Time {}, num steps {}, FPS {:.0f}, episode reward {:.3f}, distance covered {:.3f}, episode length {}".format(
                time.strftime("%Hh %Mm %Ss",
                              time.gmtime(time.time() - start_time)),
                counter.value, counter.value / (time.time() - start_time),
                reward_sum, info['x_pos']/x_norm, episode_length))

            data = [time.time() - ep_start_time,
                    counter.value, reward_sum, info['x_pos']/x_norm, episode_length]

            with open(savefile, 'a', newline='') as sfile:
                writer = csv.writer(sfile)
                writer.writerows([data])

            reward_sum = 0
            episode_length = 0
            actions.clear()
            positions.clear()
            time.sleep(60)
            # env.locked_levels = [False] + [True] * 31
            # env.change_level(0)
            state = env.reset()
            # The next episode starts now (after the pause and the reset).
            ep_start_time = time.time()

        state = torch.from_numpy(state)
Beispiel #9
0
					help='model save interval (default: {})'.format(SAVEPATH))
parser.add_argument('--non-sample', type=int, default=1,
                    help='number of non sampling processes (default: 1)')

# Multiprocessing context using the "spawn" start method; later code uses it
# as `mp`.
mp = _mp.get_context('spawn')

print("Cuda: " + str(torch.cuda.is_available()))

if __name__ == '__main__':

    # Limit OpenMP-backed ops to a single thread per process.
    os.environ['OMP_NUM_THREADS'] = '1'

    args = parser.parse_args()
    env = create_mario_env(args.env_name)

    shared_model = ActorCritic(env.observation_space.shape[0], len(COMPLEX_MOVEMENT))
    if args.use_cuda:
        shared_model.cuda()

    shared_model.share_memory()

    # Resume from an existing checkpoint. map_location='cpu' lets a checkpoint
    # written on a GPU machine load on a CPU-only one; load_state_dict copies
    # the tensors onto the model's own device.
    if os.path.isfile(args.save_path):
        print('Loading A3C parameters ...')
        shared_model.load_state_dict(torch.load(args.save_path, map_location='cpu'))

    torch.manual_seed(args.seed)

    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    print(color.BLUE + "No of available cores : {}".format(mp.cpu_count()) + color.END)
Beispiel #10
0
parser.add_argument('--non-sample',
                    type=int,
                    default=2,
                    help='number of non sampling processes (default: 2)')

# Multiprocessing context using the "spawn" start method (presumably because
# CUDA cannot be used in forked children); later code uses it as `mp`.
mp = _mp.get_context('spawn')

# Module-level statement: runs whenever this module is executed or imported.
print("Cuda: " + str(torch.cuda.is_available()))

if __name__ == '__main__':
    # Limit OpenMP-backed ops to a single thread per process.
    os.environ['OMP_NUM_THREADS'] = '1'

    args = parser.parse_args()
    env = create_mario_env(args.env_name)

    # Shared model, built from the observation-space shape and len(ACTIONS).
    shared_model = ActorCritic(env.observation_space.shape[0], len(ACTIONS))
    if args.use_cuda:
        shared_model.cuda()

    shared_model.share_memory()

    # Resume from an existing checkpoint if one is present.
    if os.path.isfile(args.save_path):
        print('Loading A3C parametets ...')
        shared_model.load_state_dict(torch.load(args.save_path))

    torch.manual_seed(args.seed)

    # SharedAdam is a project class; share_memory() presumably moves its
    # optimizer state into shared memory as well — confirm in its definition.
    optimizer = SharedAdam(shared_model.parameters(), lr=args.lr)
    optimizer.share_memory()

    print(color.BLUE + "No of available cores : {}".format(mp.cpu_count()) +
Beispiel #11
0
class PPO:
    """Proximal Policy Optimisation learner around an ``ActorCritic`` network.

    The learner holds:

    * the current policy and a copy, ``policy_old``, that is re-synced after
      every update;
    * an Adam optimiser;
    * a StepLR scheduler.

    Hyper-parameters not passed to the constructor (loss coefficients, LR
    decay) are read from the module-level ``configs``. Tensors are moved to the
    module-level ``device``.
    """

    def __init__(
        self,
        lr,
        gamma,
        k_epochs,
        eps_clip,
        n_j,
        n_m,
        num_layers,
        neighbor_pooling_type,
        input_dim,
        hidden_dim,
        num_mlp_layers_feature_extract,
        num_mlp_layers_actor,
        hidden_dim_actor,
        num_mlp_layers_critic,
        hidden_dim_critic,
    ):
        """Build policy, policy copy, optimiser, scheduler and value loss.

        Args:
            lr: Adam learning rate.
            gamma: discount factor used for the returns in :meth:`update`.
            k_epochs: number of optimisation epochs per :meth:`update`.
            eps_clip: clipping range of the PPO probability ratio.
            n_j, n_m, num_layers, neighbor_pooling_type, input_dim, hidden_dim,
            num_mlp_layers_feature_extract, num_mlp_layers_actor,
            hidden_dim_actor, num_mlp_layers_critic, hidden_dim_critic:
                forwarded unchanged to ``ActorCritic``.
        """
        self.lr = lr
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.k_epochs = k_epochs

        self.policy = ActorCritic(
            n_j=n_j,
            n_m=n_m,
            num_layers=num_layers,
            learn_eps=False,
            neighbor_pooling_type=neighbor_pooling_type,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_mlp_layers_feature_extract=num_mlp_layers_feature_extract,
            num_mlp_layers_actor=num_mlp_layers_actor,
            hidden_dim_actor=hidden_dim_actor,
            num_mlp_layers_critic=num_mlp_layers_critic,
            hidden_dim_critic=hidden_dim_critic,
            device=device)
        self.policy_old = deepcopy(self.policy)
        # Optional warm start (disabled):
        # self.policy.load_state_dict(
        #     torch.load(path='./{}.pth'.format(str(n_j) + '_' + str(n_m) + '_' + str(1) + '_' + str(99))))

        self.policy_old.load_state_dict(self.policy.state_dict())
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=configs.decay_step_size,
            gamma=configs.decay_ratio)

        self.V_loss_2 = nn.MSELoss()

    @staticmethod
    def _discounted_returns(rewards, dones, gamma):
        """Return the discounted return of every step of one rollout.

        Walks the rollout backwards, resetting the running sum at each step
        whose terminal flag is set so returns do not leak across episodes.

        >>> PPO._discounted_returns([1.0, 1.0], [False, True], 0.5)
        [1.5, 1.0]
        """
        returns = []
        running = 0
        for reward, is_terminal in zip(reversed(rewards), reversed(dones)):
            if is_terminal:
                running = 0
            running = reward + gamma * running
            returns.append(running)
        # Built back-to-front; append + reverse avoids O(n^2) insert(0, ...).
        returns.reverse()
        return returns

    def update(self, memories, n_tasks, g_pool):
        """Run ``k_epochs`` PPO optimisation epochs over the given rollouts.

        Args:
            memories: one rollout buffer per environment. Each buffer exposes
                ``r_mb``, ``done_mb``, ``adj_mb``, ``fea_mb``,
                ``candidate_mb``, ``mask_mb``, ``a_mb`` and ``logprobs``.
            n_tasks: forwarded to ``aggr_obs`` and ``g_pool_cal``.
            g_pool: forwarded to ``g_pool_cal``.

        Returns:
            ``(mean total loss, mean value loss)`` of the last epoch, as floats.
        """
        vloss_coef = configs.vloss_coef
        ploss_coef = configs.ploss_coef
        entloss_coef = configs.entloss_coef

        rewards_all_env = []
        adj_mb_t_all_env = []
        fea_mb_t_all_env = []
        candidate_mb_t_all_env = []
        mask_mb_t_all_env = []
        a_mb_t_all_env = []
        old_logprobs_mb_t_all_env = []
        # store data for all env
        for memory in memories:
            returns = torch.tensor(
                self._discounted_returns(memory.r_mb, memory.done_mb,
                                         self.gamma),
                dtype=torch.float).to(device)
            # Standardise per environment. std() of a single element is NaN,
            # so a one-step rollout is only centred.
            if returns.numel() > 1:
                returns = (returns - returns.mean()) / (returns.std() + 1e-5)
            else:
                returns = returns - returns.mean()
            rewards_all_env.append(returns)
            # process each env data
            adj_mb_t_all_env.append(
                aggr_obs(torch.stack(memory.adj_mb).to(device), n_tasks))
            fea_mb_t = torch.stack(memory.fea_mb).to(device)
            fea_mb_t = fea_mb_t.reshape(-1, fea_mb_t.size(-1))
            fea_mb_t_all_env.append(fea_mb_t)
            candidate_mb_t_all_env.append(
                torch.stack(memory.candidate_mb).to(device).squeeze())
            mask_mb_t_all_env.append(
                torch.stack(memory.mask_mb).to(device).squeeze())
            a_mb_t_all_env.append(
                torch.stack(memory.a_mb).to(device).squeeze())
            old_logprobs_mb_t_all_env.append(
                torch.stack(memory.logprobs).to(device).squeeze().detach())

        # get batch argument for net forwarding: mb_g_pool is same for all env
        mb_g_pool = g_pool_cal(
            g_pool,
            torch.stack(memories[0].adj_mb).to(device).shape, n_tasks, device)

        # Optimize policy for K epochs:
        for _ in range(self.k_epochs):
            loss_sum = 0
            vloss_sum = 0
            for fea, adj, candidate, mask, actions, old_logprobs, returns in zip(
                    fea_mb_t_all_env, adj_mb_t_all_env, candidate_mb_t_all_env,
                    mask_mb_t_all_env, a_mb_t_all_env,
                    old_logprobs_mb_t_all_env, rewards_all_env):
                pis, vals = self.policy(x=fea,
                                        graph_pool=mb_g_pool,
                                        adj=adj,
                                        candidate=candidate,
                                        mask=mask,
                                        padded_nei=None)
                logprobs, ent_loss = eval_actions(pis.squeeze(), actions)
                ratios = torch.exp(logprobs - old_logprobs.detach())
                # Squeeze the critic output to the shape of `returns`, as the
                # value loss below does. With an (N, 1) critic output the
                # unsqueezed subtraction would broadcast to an (N, N) matrix.
                advantages = returns - vals.detach().squeeze()
                surr1 = ratios * advantages
                surr2 = torch.clamp(ratios, 1 - self.eps_clip,
                                    1 + self.eps_clip) * advantages
                v_loss = self.V_loss_2(vals.squeeze(), returns)
                p_loss = -torch.min(surr1, surr2)
                ent_loss = -ent_loss.clone()
                loss = vloss_coef * v_loss + ploss_coef * p_loss + entloss_coef * ent_loss
                loss_sum += loss
                vloss_sum += v_loss
            self.optimizer.zero_grad()
            loss_sum.mean().backward()
            self.optimizer.step()

        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
        if configs.decayflag:
            self.scheduler.step()
        return loss_sum.mean().item(), vloss_sum.mean().item()