def train(opt):
    torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    if not os.path.isdir(opt.saved_path):
        os.makedirs(opt.saved_path)
    mp = _mp.get_context("spawn")
    env, num_states, num_actions = create_train_env(opt.world, opt.stage,
                                                    opt.action_type)
    global_model = ActorCritic(num_states, num_actions)
    if opt.use_gpu:
        global_model.cuda()
    global_model.share_memory()
    if opt.load_from_previous_stage:
        if opt.stage == 1:
            previous_world = opt.world - 1
            previous_stage = 4
        else:
            previous_world = opt.world
            previous_stage = opt.stage - 1
        file_ = "{}/a3c_super_mario_bros_{}_{}".format(opt.saved_path,
                                                       previous_world,
                                                       previous_stage)
        if os.path.isfile(file_):
            global_model.load_state_dict(torch.load(file_))

    optimizer = GlobalAdam(global_model.parameters(), lr=opt.lr)
    local_train(0, opt, global_model, optimizer, True)
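
# These train() examples read a number of attributes from the `opt` namespace
# (world, stage, action_type, lr, gamma, tau, beta, num_local_steps,
# num_global_steps, num_processes, save_interval, log_path, saved_path,
# use_gpu, load_from_previous_stage). A minimal, hypothetical argparse sketch
# for building such an `opt` and launching training; the flag defaults below
# are illustrative, not taken from the original repository.
import argparse


def get_args():
    parser = argparse.ArgumentParser("A3C for Super Mario Bros")
    parser.add_argument("--world", type=int, default=1)
    parser.add_argument("--stage", type=int, default=1)
    parser.add_argument("--action_type", type=str, default="complex")
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--gamma", type=float, default=0.9)
    parser.add_argument("--tau", type=float, default=1.0)
    parser.add_argument("--beta", type=float, default=0.01)
    parser.add_argument("--num_local_steps", type=int, default=50)
    parser.add_argument("--num_global_steps", type=int, default=5000000)
    parser.add_argument("--num_processes", type=int, default=6)
    parser.add_argument("--save_interval", type=int, default=500)
    parser.add_argument("--log_path", type=str, default="tensorboard/a3c")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--use_gpu", action="store_true")
    parser.add_argument("--load_from_previous_stage", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    opt = get_args()
    train(opt)
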
Example #2
def train(opt):
    torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    if not os.path.isdir(opt.saved_path):
        os.makedirs(opt.saved_path)
    mp = _mp.get_context("spawn")
    global_model = ActorCritic(num_inputs=3, num_actions=90)
    global_icm = IntrinsicCuriosityModule(num_inputs=3, num_actions=90)
    if opt.use_gpu:
        global_model.cuda()
        global_icm.cuda()
    global_model.share_memory()
    global_icm.share_memory()

    optimizer = GlobalAdam(list(global_model.parameters()) + list(global_icm.parameters()), lr=opt.lr)
    processes = []
    for index in range(opt.num_processes):
        if index == 0:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer, True))
        else:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer))
        process.start()
        processes.append(process)
    for process in processes:
        process.join()
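
# GlobalAdam is not defined in these snippets. The A3C repositories they are
# based on typically subclass torch.optim.Adam and move the optimizer state
# into shared memory so that all worker processes update the same statistics.
# A rough sketch of that pattern; newer PyTorch releases may expect the step
# counter to be a tensor, so treat this as illustrative rather than exact.
import torch


class GlobalAdam(torch.optim.Adam):
    def __init__(self, params, lr):
        super(GlobalAdam, self).__init__(params, lr=lr)
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                # pre-create the Adam state and place it in shared memory
                state['step'] = 0
                state['exp_avg'] = torch.zeros_like(p.data)
                state['exp_avg_sq'] = torch.zeros_like(p.data)
                state['exp_avg'].share_memory_()
                state['exp_avg_sq'].share_memory_()
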
Example #3
def train(opt):
    torch.manual_seed(123)
    opt.log_path = opt.log_path + "/" + opt.exp
    opt.saved_path = opt.saved_path + "/" + opt.exp
    opt.output_path = opt.output_path + "/" + opt.exp
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    mp = _mp.get_context("spawn")
    global_model = ActorCritic(num_inputs=3, num_actions=opt.num_actions)
    global_icm = IntrinsicCuriosityModule(num_inputs=3, num_actions=opt.num_actions)

    if opt.resume_path:
        print("Load model from checkpoint: {}".format(opt.resume_path))
        global_model.load_state_dict(torch.load("{}/a3c".format(opt.resume_path)))
        global_icm.load_state_dict(torch.load("{}/icm".format(opt.resume_path)))

    if opt.use_gpu:
        global_model.cuda()
        global_icm.cuda()
    global_model.share_memory()
    global_icm.share_memory()

    optimizer = GlobalAdam(list(global_model.parameters()) + list(global_icm.parameters()), lr=opt.lr)
    processes = []
    for index in range(opt.num_processes):
        if index == 0:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer, True))
        else:
            process = mp.Process(target=local_train, args=(index, opt, global_model, global_icm, optimizer))
        process.start()
        processes.append(process)
    for process in processes:
        process.join()
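
# IntrinsicCuriosityModule is also not shown here. A minimal sketch of the
# usual ICM layout (Pathak et al., 2017) with a feature encoder, an inverse
# head, and a forward head; the 84x84 input size and layer widths below are
# assumptions for illustration and may differ from the module used above.
import torch
import torch.nn as nn


class IntrinsicCuriosityModuleSketch(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super().__init__()
        # feature encoder phi(s)
        self.feature = nn.Sequential(
            nn.Conv2d(num_inputs, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Conv2d(32, 32, 3, stride=2, padding=1), nn.ELU(),
            nn.Flatten())
        feature_size = 32 * 6 * 6  # for 84x84 inputs
        # inverse model: (phi(s), phi(s')) -> predicted action logits
        self.inverse_net = nn.Sequential(
            nn.Linear(feature_size * 2, 256), nn.ReLU(),
            nn.Linear(256, num_actions))
        # forward model: (phi(s), one-hot action) -> predicted phi(s')
        self.forward_net = nn.Sequential(
            nn.Linear(feature_size + num_actions, 256), nn.ReLU(),
            nn.Linear(256, feature_size))

    def forward(self, state, next_state, action_oh):
        phi_s = self.feature(state)
        phi_next = self.feature(next_state)
        pred_logits = self.inverse_net(torch.cat((phi_s, phi_next), dim=1))
        pred_phi = self.forward_net(torch.cat((phi_s, action_oh), dim=1))
        return pred_logits, pred_phi, phi_next
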
Example #4
def train(opt):
    torch.manual_seed(SEED)

    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.saved_path):
        os.makedirs(opt.saved_path)

    mp = _mp.get_context("spawn")
    env, num_states, num_actions = create_train_env(opt.world, opt.stage,
                                                    opt.action_type)
    global_model = ActorCritic(num_states, num_actions)
    global_model.share_memory()

    if opt.load_from_previous_stage:
        if opt.stage == 1:
            previous_world = opt.world - 1
            previous_stage = 4
        else:
            previous_world = opt.world
            previous_stage = opt.stage - 1

        file_ = f"{opt.saved_path}/a3c_super_mario_bros_{previous_world}_{previous_stage}"
        if os.path.isfile(file_):
            global_model.load_state_dict(torch.load(file_))

    optimizer = GlobalAdam(global_model.parameters(), lr=opt.lr)
    processes = []

    for index in range(opt.num_processes):
        if index == 0:
            process = mp.Process(target=local_train,
                                 args=(index, opt, global_model, optimizer,
                                       True))
        else:
            process = mp.Process(target=local_train,
                                 args=(index, opt, global_model, optimizer))

        process.start()
        processes.append(process)

    process = mp.Process(target=local_test,
                         args=(opt.num_processes, opt, global_model))
    process.start()
    processes.append(process)

    for process in processes:
        process.join()
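
# create_train_env is not included in these snippets. A simplified sketch of
# what it usually does for the Mario examples, built on gym-super-mario-bros
# and nes-py; the real helper also adds frame preprocessing and reward
# shaping, so this is only an illustration of the interface.
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace


def create_train_env_sketch(world, stage, action_type):
    env = gym_super_mario_bros.make(
        "SuperMarioBros-{}-{}-v0".format(world, stage))
    if action_type == "right":
        actions = RIGHT_ONLY
    elif action_type == "simple":
        actions = SIMPLE_MOVEMENT
    else:
        actions = COMPLEX_MOVEMENT
    env = JoypadSpace(env, actions)
    # the original helper wraps env further (grayscale, resize, frame skip);
    # num_states is the number of stacked frames fed to the network
    num_states = 4
    return env, num_states, len(actions)
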
def local_train(index, opt, global_model, optimizer, save=False):
    torch.manual_seed(123 + index)
    info = {}
    info["flag_get"] = False
    if save:
        start_time = timeit.default_timer()
    writer = SummaryWriter(opt.log_path)
    env, num_states, num_actions = create_train_env(opt.world, opt.stage,
                                                    opt.action_type)
    local_model = ActorCritic(num_states, num_actions)
    if opt.use_gpu:
        local_model.cuda()
    local_model.train()
    state = torch.from_numpy(env.reset())
    if opt.use_gpu:
        state = state.cuda()
    done = True
    curr_step = 0
    curr_episode = 0
    while True:
        if save:
            # if curr_episode % opt.save_interval == 0 and curr_episode > 0:
            #     torch.save(global_model.state_dict(),
            #                "{}/a3c_super_mario_bros_{}_{}".format(opt.saved_path, opt.world, opt.stage))
            print("Process {}. Episode {}".format(index, curr_episode))
        curr_episode += 1
        local_model.load_state_dict(global_model.state_dict())
        if done:
            h_0 = torch.zeros((1, 512), dtype=torch.float)
            c_0 = torch.zeros((1, 512), dtype=torch.float)
        else:
            h_0 = h_0.detach()
            c_0 = c_0.detach()
        if opt.use_gpu:
            h_0 = h_0.cuda()
            c_0 = c_0.cuda()

        log_policies = []
        values = []
        rewards = []
        entropies = []

        for _ in range(opt.num_local_steps):
            curr_step += 1
            logits, value, h_0, c_0 = local_model(state, h_0, c_0)
            policy = F.softmax(logits, dim=1)
            log_policy = F.log_softmax(logits, dim=1)
            entropy = -(policy * log_policy).sum(1, keepdim=True)

            m = Categorical(policy)
            action = m.sample().item()

            state, reward, done, info = env.step(action)
            state = torch.from_numpy(state)
            if opt.use_gpu:
                state = state.cuda()
            if curr_step > opt.num_global_steps:
                done = True

            if done:
                curr_step = 0
                state = torch.from_numpy(env.reset())
                if opt.use_gpu:
                    state = state.cuda()

            values.append(value)
            log_policies.append(log_policy[0, action])
            rewards.append(reward)
            entropies.append(entropy)

            if done:
                break

        R = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            R = R.cuda()
        if not done:
            _, R, _, _ = local_model(state, h_0, c_0)

        gae = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            gae = gae.cuda()
        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        next_value = R

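        # Walk backwards through the rollout: accumulate the generalized
        # advantage estimate (gae) for the policy loss, the discounted
        # n-step return R for the value loss, and the entropy bonus.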
        for value, log_policy, reward, entropy in list(
                zip(values, log_policies, rewards, entropies))[::-1]:
            gae = gae * opt.gamma * opt.tau
            gae = gae + reward + opt.gamma * next_value.detach(
            ) - value.detach()
            next_value = value
            actor_loss = actor_loss + log_policy * gae
            R = R * opt.gamma + reward
            critic_loss = critic_loss + (R - value)**2 / 2
            entropy_loss = entropy_loss + entropy

        total_loss = -actor_loss + critic_loss - opt.beta * entropy_loss
        writer.add_scalar("Train_{}/Loss".format(index), total_loss,
                          curr_episode)
        optimizer.zero_grad()
        total_loss.backward()

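        # Copy the local gradients into the shared global model (Hogwild-style
        # A3C update); if gradients are already attached to the global
        # parameters for this step, skip the copy.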
        for local_param, global_param in zip(local_model.parameters(),
                                             global_model.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad

        optimizer.step()

        if curr_episode == int(opt.num_global_steps / opt.num_local_steps):
            print("Training process {} terminated".format(index))
            if save:
                end_time = timeit.default_timer()
                print('The code runs for %.2f s ' % (end_time - start_time))
            return

        if save and curr_episode % opt.save_interval == 0:
            # if info["flag_get"]:
            if local_test(opt.num_processes, opt, global_model, start_time,
                          curr_episode):
                break
Example #6
def local_train(index, opt, global_model, global_icm, optimizer, save=False):
    torch.manual_seed(123 + index)
    if save:
        start_time = timeit.default_timer()
    writer = SummaryWriter(opt.log_path)
    env, num_states, num_actions = create_train_env(index + 1)
    local_model = ActorCritic(num_states, num_actions)
    local_icm = IntrinsicCuriosityModule(num_states, num_actions)
    if opt.use_gpu:
        local_model.cuda()
        local_icm.cuda()
    local_model.train()
    local_icm.train()
    inv_criterion = nn.CrossEntropyLoss()
    fwd_criterion = nn.MSELoss()
    state = torch.from_numpy(env.reset(False, False, True))
    if opt.use_gpu:
        state = state.cuda()
    round_done, stage_done, game_done = False, False, True
    curr_step = 0
    curr_episode = 0
    while True:
        if save:
            if curr_episode % opt.save_interval == 0 and curr_episode > 0:
                torch.save(global_model.state_dict(),
                           "{}/a3c_street_fighter".format(opt.saved_path))
                torch.save(global_icm.state_dict(),
                           "{}/icm_street_fighter".format(opt.saved_path))
        curr_episode += 1
        local_model.load_state_dict(global_model.state_dict())
        if round_done or stage_done or game_done:
            h_0 = torch.zeros((1, 1024), dtype=torch.float)
            c_0 = torch.zeros((1, 1024), dtype=torch.float)
        else:
            h_0 = h_0.detach()
            c_0 = c_0.detach()
        if opt.use_gpu:
            h_0 = h_0.cuda()
            c_0 = c_0.cuda()

        log_policies = []
        values = []
        rewards = []
        entropies = []
        inv_losses = []
        fwd_losses = []

        for _ in range(opt.num_local_steps):
            curr_step += 1
            logits, value, h_0, c_0 = local_model(state, h_0, c_0)
            policy = F.softmax(logits, dim=1)
            log_policy = F.log_softmax(logits, dim=1)
            entropy = -(policy * log_policy).sum(1, keepdim=True)

            m = Categorical(policy)
            action = m.sample().item()

            next_state, reward, round_done, stage_done, game_done = env.step(
                action)
            next_state = torch.from_numpy(next_state)
            if opt.use_gpu:
                next_state = next_state.cuda()
            action_oh = torch.zeros((1, num_actions))  # one-hot action
            action_oh[0, action] = 1
            if opt.use_gpu:
                action_oh = action_oh.cuda()
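            # ICM: the inverse head predicts the taken action from the encoded
            # pair (phi(state), phi(next_state)); the forward head predicts
            # phi(next_state) from phi(state) and the one-hot action. The
            # detached forward-model error is used as the intrinsic reward.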
            pred_logits, pred_phi, phi = local_icm(state, next_state,
                                                   action_oh)
            if opt.use_gpu:
                inv_loss = inv_criterion(pred_logits,
                                         torch.tensor([action]).cuda())
            else:
                inv_loss = inv_criterion(pred_logits, torch.tensor([action]))
            fwd_loss = fwd_criterion(pred_phi, phi) / 2
            intrinsic_reward = opt.eta * fwd_loss.detach()
            reward += intrinsic_reward

            if curr_step > opt.num_global_steps:
                round_done, stage_done, game_done = False, False, True

            if round_done or stage_done or game_done:
                curr_step = 0
                next_state = torch.from_numpy(
                    env.reset(round_done, stage_done, game_done))
                if opt.use_gpu:
                    next_state = next_state.cuda()

            values.append(value)
            log_policies.append(log_policy[0, action])
            rewards.append(reward)
            entropies.append(entropy)
            inv_losses.append(inv_loss)
            fwd_losses.append(fwd_loss)
            state = next_state
            if round_done or stage_done or game_done:
                break

        R = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            R = R.cuda()
        if not (round_done or stage_done or game_done):
            _, R, _, _ = local_model(state, h_0, c_0)

        gae = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            gae = gae.cuda()
        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        curiosity_loss = 0
        next_value = R

        for value, log_policy, reward, entropy, inv, fwd in list(
                zip(values, log_policies, rewards, entropies, inv_losses,
                    fwd_losses))[::-1]:
            gae = gae * opt.gamma * opt.tau
            gae = gae + reward + opt.gamma * next_value.detach(
            ) - value.detach()
            next_value = value
            actor_loss = actor_loss + log_policy * gae
            R = R * opt.gamma + reward
            critic_loss = critic_loss + (R - value)**2 / 2
            entropy_loss = entropy_loss + entropy
            curiosity_loss = curiosity_loss + (1 -
                                               opt.beta) * inv + opt.beta * fwd

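        # Combine the A3C objective (weighted by lambda_) with the ICM
        # curiosity loss, which mixes the inverse and forward losses via beta.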
        total_loss = opt.lambda_ * (-actor_loss + critic_loss -
                                    opt.sigma * entropy_loss) + curiosity_loss
        writer.add_scalar("Train_{}/Loss".format(index), total_loss,
                          curr_episode)
        if save:
            print("Process {}. Episode {}. Loss: {}".format(
                index, curr_episode, total_loss))
        optimizer.zero_grad()
        total_loss.backward()

        for local_param, global_param in zip(local_model.parameters(),
                                             global_model.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad
        for local_param, global_param in zip(local_icm.parameters(),
                                             global_icm.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad

        optimizer.step()

        if curr_episode == int(opt.num_global_steps / opt.num_local_steps):
            print("Training process {} terminated".format(index))
            if save:
                end_time = timeit.default_timer()
                print('The code runs for %.2f s ' % (end_time - start_time))
            return
def train_a3c(index,
              args,
              A3C_optimizer,
              A3C_shared_model,
              CAE_shared_model,
              CAE_optimizer,
              save=True,
              new_stage=False):

    # Load the weights of the pretrained CAE model. A checkpoint saved on GPU
    # must be remapped with map_location when it is loaded on a CPU-only machine.
    if args.use_cuda:
        CAE_shared_model.load_state_dict(torch.load(
            args.pretrained_model_weights_path),
                                         strict=False)
    else:
        CAE_shared_model.load_state_dict(torch.load(
            args.pretrained_model_weights_path, map_location='cpu'),
                                         strict=False)
    CAE_shared_model.eval()
    torch.manual_seed(123 + index)

    if save:
        start_time = timeit.default_timer()
    # TensorBoard writer
    writer = SummaryWriter(args.sum_path)
    # initialize the environment; returns the env, the number of states,
    # and the number of actions (12)
    env, num_states, num_actions = build_environment(args.world, args.stage)
    print('Num of states: {}'.format(num_states))
    # CAE worker model
    CAE_local_model = Convolutional_AutoEncoder()
    # A3C worker model
    a3c_local_model = ActorCritic(num_states, num_actions)
    # move the models to GPU if requested
    if args.use_cuda:
        a3c_local_model.cuda()
        CAE_local_model.cuda()
    # train the A3C model but not the CAE, which must keep frozen weights
    a3c_local_model.train()
    state = torch.from_numpy(env.reset())

    if args.use_cuda:
        state = state.cuda()
    done = True
    curr_step = 0
    curr_episode = 0
    #Runs for each episode until the max episode limit is reached
    while True:
        #save model every 500 episodes
        if save:
            if curr_episode % args.save_interval == 0 and curr_episode > 0:

                torch.save(
                    CAE_shared_model.state_dict(),
                    "{}/CAE_super_mario_bros_{}_{}_enc2".format(
                        args.trained_models_path, args.world, args.stage))

                torch.save(
                    A3C_shared_model.state_dict(),
                    "{}/a3c_super_mario_bros_{}_{}_enc2".format(
                        args.trained_models_path, args.world, args.stage))
            print("Process {}. Episode {}".format(index, curr_episode))
        curr_episode += 1
        a3c_local_model.load_state_dict(A3C_shared_model.state_dict())

        # strict is false because we are only loading the weights for the encoder layers
        # gpu check
        if (args.use_cuda):

            CAE_local_model.load_state_dict(
                torch.load("{}/CAE_super_mario_bros_1_1_enc2".format(
                    args.trained_models_path)))
        else:

            CAE_local_model.load_state_dict(
                torch.load("{}/CAE_super_mario_bros_1_1_enc2".format(
                    args.trained_models_path),
                           map_location='cpu'))

        CAE_local_model.eval()
        # zeroed hidden and cell states are fed to the LSTM at the start of each episode
        if done:
            hx = torch.zeros((1, 512), dtype=torch.float)
            cx = torch.zeros((1, 512), dtype=torch.float)
        else:
            hx = hx.detach()
            cx = cx.detach()
        if args.use_cuda:
            hx = hx.cuda()
            cx = cx.cuda()

        log_policies = []
        values = []

        rewards = []
        entropies = []
        # start training for each step in an episode
        for _ in range(args.num_steps):
            curr_step += 1
            # freeze the CAE worker's weights so they do not get updated
            for param in CAE_local_model.parameters():
                param.requires_grad = False
            # the state is sent as input to the local CAE model, which returns
            # the output of the last encoder layer
            output_cae = CAE_local_model(state)
            # the CAE output is then sent to the A3C part of the model, which
            # returns the policy logits, value estimate, and LSTM memories
            logits, value, hx, cx = a3c_local_model(output_cae, hx, cx)
            # the policy and log-policy are obtained through a softmax over
            # the logits; the action is sampled from the resulting categorical
            # distribution, and the entropy is kept for regularization
            policy = F.softmax(logits, dim=1)
            log_policy = F.log_softmax(logits, dim=1)
            entropy = -(policy * log_policy).sum(1, keepdim=True)
            m = Categorical(policy)
            action = m.sample().item()
            # take a step in the environment with the chosen action; it
            # returns the reward for the action and the next state
            state, reward, done, _ = env.step(action)
            state = torch.from_numpy(state)
            if args.use_cuda:
                state = state.cuda()
            # check whether the step or episode limit has been reached
            if (curr_step > args.max_steps) or (curr_episode > int(
                    args.max_episodes)):
                done = True
            #check if level is done
            if done:
                curr_step = 0
                state = torch.from_numpy(env.reset())
                if args.use_cuda:
                    state = state.cuda()
            # append all relevant data for the loss computation below
            values.append(value)
            log_policies.append(log_policy[0, action])
            rewards.append(reward)
            entropies.append(entropy)

            if done:
                break

        R = torch.zeros((1, 1), dtype=torch.float)
        if args.use_cuda:
            R = R.cuda()
        if not done:
            output_ = CAE_local_model(state)
            _, R, _, _ = a3c_local_model(output_, hx, cx)

        gae = torch.zeros((1, 1), dtype=torch.float)
        if args.use_cuda:
            gae = gae.cuda()
        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        next_value = R
        # calculate loss
        for value, log_policy, reward, entropy in list(
                zip(values, log_policies, rewards, entropies))[::-1]:
            gae = gae * args.gamma * args.tau
            gae = gae + reward + args.gamma * next_value.detach(
            ) - value.detach()
            next_value = value
            actor_loss = actor_loss + log_policy * gae
            R = R * args.gamma + reward
            critic_loss = critic_loss + (R - value)**2 / 2
            entropy_loss = entropy_loss + entropy

        total_loss = -actor_loss + critic_loss - args.beta * entropy_loss
        writer.add_scalar("Train_{}/Loss".format(index), total_loss,
                          curr_episode)
        A3C_optimizer.zero_grad()
        total_loss.backward()
        #update model
        for local_param, global_param in zip(a3c_local_model.parameters(),
                                             A3C_shared_model.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad

        A3C_optimizer.step()

        if curr_episode == int(args.max_episodes):
            print("Training process {} terminated".format(index))
            if save:
                end_time = timeit.default_timer()
                print('The code runs for %.2f s ' % (end_time - start_time))
            return
def shared_learn(args):

    os.environ['OMP_NUM_THREADS'] = '1'
    torch.manual_seed(123)
    # create path for logs
    if os.path.isdir(args.sum_path):
        shutil.rmtree(args.sum_path)
    os.makedirs(args.sum_path)
    if not os.path.isdir(args.trained_models_path):
        os.makedirs(args.trained_models_path)
    mp = _mp.get_context('spawn')

    # create initial mario environment
    env, num_states, num_actions = build_environment(args.world, args.stage)

    print('Num of states: {}'.format(num_states))  #4
    print('environment: {}'.format(env))
    print('Num of actions: {}'.format(num_actions))  #12

    # check if cuda is available else cpu
    device = torch.device('cuda' if (
        args.use_cuda and torch.cuda.is_available()) else 'cpu')

    CAE_shared_model = Convolutional_AutoEncoder()  #.to(device)
    A3C_shared_model = ActorCritic(num_states, num_actions)  #.to(device)
    # if starting a new stage, resume from the previously saved model
    if args.new_stage:
        A3C_shared_model.load_state_dict(
            torch.load('{}/a3c_super_mario_bros_{}_{}_enc2'.format(
                args.trained_models_path, args.world, args.stage)))
        A3C_shared_model.eval()
    # GPU check
    if (args.use_cuda and torch.cuda.is_available()):
        A3C_shared_model.cuda()
        CAE_shared_model.cuda()
    # shares memory with worker instances
    CAE_shared_model.share_memory()

    A3C_shared_model.share_memory()

    print('A3C')
    print(A3C_shared_model)
    # initialize the optimizers
    optimizer_cae = CAE_shared_model.createLossAndOptimizer(
        CAE_shared_model, 0.001)
    optimizer_a3c = SharedAdam(A3C_shared_model.parameters(), lr=args.lr)
    #optimizer.share_memory()

    # processes
    workers = []

    # start the training workers (one per process)
    for rank in range(args.num_processes):
        worker = mp.Process(target=train_a3c,
                            args=(rank, args, optimizer_a3c,
                                  A3C_shared_model, CAE_shared_model,
                                  optimizer_cae, True))
        worker.start()
        workers.append(worker)

    # test worker
    worker = mp.Process(target=test_a3c,
                        args=(args.num_processes, args, A3C_shared_model,
                              CAE_shared_model))
    worker.start()
    workers.append(worker)

    # join all processes
    for worker in workers:
        worker.join()
Example #9
def local_train(index, opt, global_model, global_icm, optimizer, save=False):
    torch.manual_seed(123 + index)
    if save:
        start_time = timeit.default_timer()
    writer = SummaryWriter(opt.log_path)
    env, num_states, num_actions = create_train_env(
        index + 1, opt, "{}/test.mp4".format(opt.output_path))
    local_model = ActorCritic(num_states, num_actions)
    local_icm = IntrinsicCuriosityModule(num_states, num_actions)
    if opt.use_gpu:
        local_model.cuda()
        local_icm.cuda()
    local_model.train()
    local_icm.train()
    inv_criterion = nn.CrossEntropyLoss()
    fwd_criterion = nn.MSELoss()
    state = torch.from_numpy(env.reset(False, False, True))
    if opt.use_gpu:
        state = state.cuda()
    round_done, stage_done, game_done = False, False, True
    curr_step = 0
    total_step = 0
    curr_episode = 0
    return_eps = 0
    next_save = False
    while True:
        if save and next_save:
            next_save = False
            saved_path = opt.saved_path + "/" + str(total_step // 1000) + "K"
            if not os.path.isdir(saved_path):
                os.makedirs(saved_path)
            torch.save(global_model.state_dict(), "{}/a3c".format(saved_path))
            torch.save(global_icm.state_dict(), "{}/icm".format(saved_path))
        #curr_episode += 1
        local_model.load_state_dict(global_model.state_dict())
        local_icm.load_state_dict(global_icm.state_dict())
        if round_done or stage_done or game_done:
            h_0 = torch.zeros((1, 256), dtype=torch.float)
            c_0 = torch.zeros((1, 256), dtype=torch.float)
        else:
            h_0 = h_0.detach()
            c_0 = c_0.detach()
        if opt.use_gpu:
            h_0 = h_0.cuda()
            c_0 = c_0.cuda()

        log_policies = []
        values = []
        rewards = []
        entropies = []
        inv_losses = []
        fwd_losses = []
        action_cnt = [0] * num_actions
        highest_position = env.game.newGame.Players[0].getPosition()
        first_policy = None

        for i in range(opt.num_local_steps):
            total_step += 1
            if total_step % opt.save_interval == 0 and total_step > 0:
                next_save = True
            curr_step += 1
            logits, value, h_0, c_0 = local_model(state, h_0, c_0)
            policy = F.softmax(logits, dim=1)
            log_policy = F.log_softmax(logits, dim=1)
            entropy = -(policy * log_policy).sum(1, keepdim=True)

            m = Categorical(policy)
            action = m.sample().item()
            action_cnt[action] += 1
            if i == 0:
                first_policy = policy

            next_state, reward, round_done, stage_done, game_done = env.step(
                action)
            return_eps += reward
            highest_position = max(highest_position,
                                   env.game.newGame.Players[0].getPosition(),
                                   key=lambda p: -p[1])
            next_state = torch.from_numpy(next_state)
            if opt.use_gpu:
                next_state = next_state.cuda()
            action_oh = torch.zeros((1, num_actions))  # one-hot action
            action_oh[0, action] = 1
            if opt.use_gpu:
                action_oh = action_oh.cuda()
            pred_logits, pred_phi, phi = local_icm(state, next_state,
                                                   action_oh)
            if opt.use_gpu:
                inv_loss = inv_criterion(pred_logits,
                                         torch.tensor([action]).cuda())
            else:
                inv_loss = inv_criterion(pred_logits, torch.tensor([action]))
            fwd_loss = fwd_criterion(pred_phi, phi) / 2
            intrinsic_reward = opt.eta * fwd_loss.detach()
            reward += intrinsic_reward

            if curr_step >= opt.max_steps:
                round_done, stage_done, game_done = False, False, True

            if round_done or stage_done or game_done:
                curr_step = 0
                curr_episode += 1
                next_state = torch.from_numpy(
                    env.reset(round_done, stage_done, game_done))
                if opt.use_gpu:
                    next_state = next_state.cuda()
                if save:
                    writer.add_scalar("Train_{}/Return".format(index),
                                      return_eps, curr_episode)
                return_eps = 0

            values.append(value)
            log_policies.append(log_policy[0, action])
            rewards.append(reward)
            entropies.append(entropy)
            inv_losses.append(inv_loss)
            fwd_losses.append(fwd_loss)
            state = next_state
            if round_done or stage_done or game_done:
                break

        R = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            R = R.cuda()
        if not (round_done or stage_done or game_done):
            _, R, _, _ = local_model(state, h_0, c_0)

        gae = torch.zeros((1, 1), dtype=torch.float)
        if opt.use_gpu:
            gae = gae.cuda()
        actor_loss = 0
        critic_loss = 0
        entropy_loss = 0
        curiosity_loss = 0
        next_value = R

        for value, log_policy, reward, entropy, inv, fwd in list(
                zip(values, log_policies, rewards, entropies, inv_losses,
                    fwd_losses))[::-1]:
            gae = gae * opt.gamma * opt.tau
            gae = gae + reward + opt.gamma * next_value.detach(
            ) - value.detach()
            next_value = value
            actor_loss = actor_loss + log_policy * gae
            R = R * opt.gamma + reward
            critic_loss = critic_loss + (R - value)**2 / 2
            entropy_loss = entropy_loss + entropy
            curiosity_loss = curiosity_loss + (1 -
                                               opt.beta) * inv + opt.beta * fwd

        total_loss = opt.lambda_ * (-actor_loss + critic_loss -
                                    opt.sigma * entropy_loss) + curiosity_loss
        if save:
            writer.add_scalar("Train_{}/Loss".format(index), total_loss,
                              total_step)
            c_loss = curiosity_loss.item()
            t_loss = total_loss.item()
            print("Process {}. Episode {}. A3C Loss: {}. ICM Loss: {}.".format(
                index, curr_episode, t_loss - c_loss, c_loss))
            print(
                "# Actions Tried: {}. Highest Position: {}. First Policy: {}".
                format(action_cnt, highest_position,
                       first_policy.cpu().detach().numpy()))
        optimizer.zero_grad()
        total_loss.backward()

        for local_param, global_param in zip(local_model.parameters(),
                                             global_model.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad
        for local_param, global_param in zip(local_icm.parameters(),
                                             global_icm.parameters()):
            if global_param.grad is not None:
                break
            global_param._grad = local_param.grad

        optimizer.step()

        if curr_episode == int(opt.num_global_steps / opt.num_local_steps):
            print("Training process {} terminated".format(index))
            if save:
                end_time = timeit.default_timer()
                print('The code runs for %.2f s ' % (end_time - start_time))
            return